From 41b4c6735ed435110155798811db5c4dcee379e4 Mon Sep 17 00:00:00 2001
From: "Stephen R. van den Berg" <srb@cuci.nl>
Date: Fri, 10 Nov 2017 10:43:16 +0100
Subject: [PATCH] pgsql: Simplify and bolster the code by using
 Thread.ResourceCount.

---
 lib/modules/Sql.pmod/pgsql.pike      |  19 ++--
 lib/modules/Sql.pmod/pgsql_util.pmod | 128 +++++++++------------------
 2 files changed, 50 insertions(+), 97 deletions(-)

diff --git a/lib/modules/Sql.pmod/pgsql.pike b/lib/modules/Sql.pmod/pgsql.pike
index 8358616b36..c5f63a4341 100644
--- a/lib/modules/Sql.pmod/pgsql.pike
+++ b/lib/modules/Sql.pmod/pgsql.pike
@@ -65,8 +65,8 @@ final int _fetchlimit=FETCHLIMIT;
 final Thread.Mutex _unnamedportalmux;
 private Thread.Mutex unnamedstatement;
 private Thread.MutexKey termlock;
-final int _portalsinflight;
-final int _statementsinflight;
+private Thread.ResourceCountKey backendreg;
+final Thread.ResourceCount  _portalsinflight, _statementsinflight;
 final int _wasparallelisable;
 final int _intransaction;
 
@@ -108,10 +108,9 @@ private string database, user, pass;
 private string cnonce;
 private Thread.Condition waitforauthready;
 final Thread.Mutex _shortmux;
-final Thread.Condition _readyforcommit;
-final int _waittocommit, _readyforquerycount;
+final int _readyforquerycount;
 
-private string _sprintf(int type, void|mapping flags) {
+protected string _sprintf(int type) {
   string res=UNDEFINED;
   switch(type) {
     case 'O':
@@ -213,7 +212,7 @@ protected void create(void|string host, void|string database,
 
   if(!_port)
     _port = PGSQL_DEFAULT_PORT;
-  .pgsql_util.register_backend();
+  backendreg = .pgsql_util.register_backend();
   _shortmux=Thread.Mutex();
   PD("Connect\n");
   waitforauthready = Thread.Condition();
@@ -226,8 +225,8 @@ protected void create(void|string host, void|string database,
   _unnamedportalmux = Thread.Mutex();
   unnamedstatement = Thread.Mutex();
   readyforquery_cb = connect_cb;
-  _portalsinflight = 0;
-  _statementsinflight = 0;
+  _portalsinflight = Thread.ResourceCount();
+  _statementsinflight = Thread.ResourceCount();
   _wasparallelisable = 0;
 }
 
@@ -1143,7 +1142,7 @@ private void procmessage() {
 #ifdef PG_DEBUGMORE
           showportalstack("ERRORRESPONSE");
 #endif
-          if (!_portalsinflight && !_readyforquerycount)
+          if (_portalsinflight->drained() && !_readyforquerycount)
             sendsync();
           PD("%O ErrorResponse %O\n",
            objectp(portal)&&(portal._portalname||portal._preparedname),
@@ -1345,7 +1344,7 @@ private void procmessage() {
 protected void destroy() {
   string errstring;
   mixed err = catch(close());
-  .pgsql_util.unregister_backend();
+  backendreg = 0;
   /*
    * Flush out any asynchronously reported errors to stderr; because we are
    * inside a destructor, throwing an error will not work anymore.
diff --git a/lib/modules/Sql.pmod/pgsql_util.pmod b/lib/modules/Sql.pmod/pgsql_util.pmod
index aaa05fafc0..c4c8b3d3e6 100644
--- a/lib/modules/Sql.pmod/pgsql_util.pmod
+++ b/lib/modules/Sql.pmod/pgsql_util.pmod
@@ -22,7 +22,7 @@
 final Pike.Backend local_backend = Pike.SmallBackend();
 
 private Thread.Mutex backendmux = Thread.Mutex();
-private int clientsregistered;
+private Thread.ResourceCount clientsregistered = Thread.ResourceCount();
 
 constant emptyarray = ({});
 constant describenodata
@@ -127,7 +127,7 @@ private void run_local_backend() {
     looponce=0;
     if(lock=backendmux->trylock()) {
       PD("Starting local backend\n");
-      while (clientsregistered			// Autoterminate when not needed
+      while (!clientsregistered->drained()	// Autoterminate when not needed
        || sizeof(local_backend->call_out_info())) {
         mixed err;
         if (err = catch(local_backend(4096.0)))
@@ -135,22 +135,19 @@ private void run_local_backend() {
       }
       PD("Terminating local backend\n");
       lock=0;
-      looponce=clientsregistered;
+      looponce = !clientsregistered->drained();
     }
   } while(looponce);
 }
 
 //! Registers yourself as a user of this backend.  If the backend
 //! has not been started yet, it will be spawned automatically.
-final void register_backend() {
-  if(!clientsregistered++)
+final Thread.ResourceCountKey register_backend() {
+  int startbackend = clientsregistered->drained();
+  Thread.ResourceCountKey key = clientsregistered->acquire();
+  if (startbackend)
     Thread.Thread(run_local_backend);
-}
-
-//! Unregisters yourself as a user of this backend.  If there are
-//! no longer any registered users, the backend will be terminated.
-final void unregister_backend() {
-  --clientsregistered;
+  return key;
 }
 
 final void throwdelayederror(object parent) {
@@ -192,7 +189,7 @@ private inline mixed callout(function(mixed ...:void) f,
 
 class bufcon {
   inherit Stdio.Buffer;
-  private int dirty;
+  private Thread.ResourceCountKey dirty;
 
 #ifdef PG_DEBUGRACE
   final bufcon `chain() {
@@ -202,18 +199,16 @@ class bufcon {
 
   private conxion realbuffer;
 
-  protected void create(conxion _realbuffer) {
+  private void create(conxion _realbuffer) {
     realbuffer=_realbuffer;
   }
 
-  final int `stashcount() {
+  final Thread.ResourceCount `stashcount() {
     return realbuffer->stashcount;
   }
 
   final bufcon start(void|int waitforreal) {
-    dirty = 1;
-    realbuffer->stashcount++;
-    dirty = 2;
+    dirty = realbuffer->stashcount->acquire();
 #ifdef PG_DEBUG
     if(waitforreal)
       error("pgsql.bufcon not allowed here\n");
@@ -235,12 +230,7 @@ class bufcon {
      realbuffer->socket->query_fd(), mode, realbuffer->stashflushmode);
     if (mode > realbuffer->stashflushmode)
       realbuffer->stashflushmode = mode;
-    dirty = 1;
-    if(!--realbuffer->stashcount)
-      dirty = 0, realbuffer->stashavail.signal();
-    else
-      dirty = 0;
-    lock=0;
+    dirty = 0;
     this->clear();
     if(lock=realbuffer->nostash->trylock(1)) {
 #ifdef PG_DEBUGRACE
@@ -255,21 +245,6 @@ class bufcon {
 #endif
     }
   }
-
-  protected void _destruct() {
-    switch (dirty) {
-      case 1:
-        werror("FIXME: Race condition detected %s\n",
-         describe_backtrace(({"", backtrace()[..<1]})));
-        if (!realbuffer->stashcount)
-          break;
-      case 2:
-        Thread.MutexKey lock = realbuffer->shortmux->lock(2);
-        if (!--realbuffer->stashcount)
-          realbuffer->stashavail.signal();
-        lock = 0;
-    }
-  }
 };
 
 class conxiin {
@@ -280,7 +255,7 @@ class conxiin {
   final int procmsg;
   private int didreadcb;
 
-  protected bool range_error(int howmuch) {
+  protected final bool range_error(int howmuch) {
 #ifdef PG_DEBUG
     if(howmuch<=0)
       error("Out of range %d\n",howmuch);
@@ -311,7 +286,7 @@ class conxiin {
     return 0;
   }
 
-  protected void create() {
+  private void create() {
     i::create();
     fillreadmux=Thread.Mutex();
     fillread=Thread.Condition();
@@ -320,7 +295,7 @@ class conxiin {
 
 class sfile {
   inherit Stdio.File;
-  int query_fd() {
+  final int query_fd() {
     return is_open() ? ::query_fd() : -1;
   }
 };
@@ -343,7 +318,7 @@ class conxion {
   final Thread.Condition stashavail;
   final Stdio.Buffer stash;
   final int stashflushmode;
-  final int stashcount;
+  final Thread.ResourceCount stashcount;
   final int synctransact;
 #ifdef PG_DEBUGRACE
   final mixed nostrack;
@@ -368,8 +343,7 @@ class conxion {
 #endif
       started = lock;
       lock=shortmux->lock();
-      if(stashcount)
-        PT(stashavail.wait(lock));
+      stashcount->wait_till_drained(lock);
       mode = getstash(KEEP);
       lock=0;
       if (mode > KEEP)
@@ -476,7 +450,7 @@ outer:
       return -1;
   }
 
-  protected void destroy() {
+  private void destroy() {
     PD("%d>Close conxion %d\n", socket ? socket->query_fd() : -1, !!nostash);
     int|.pgsql_util.sql_result portal;
     if (qportals)			// CancelRequest does not use qportals
@@ -555,16 +529,16 @@ outer:
           catch(fd=socket->query_fd());
         res=predef::sprintf("conxion  fd: %d input queue: %d/%d "
                     "queued portals: %d  output queue: %d/%d\n"
-                    "started: %d  stashcount: %d\n",
+                    "started: %d\n",
                     fd,sizeof(i),i->_size_object(),
                     qportals && qportals->size(), sizeof(this), _size_object(),
-                    !!started, stashcount);
+                    !!started);
         break;
     }
     return res;
   }
 
-  protected void create(object pgsqlsess,Thread.Queue _qportals,int nossl) {
+  private void create(object pgsqlsess,Thread.Queue _qportals,int nossl) {
     o::create();
     qportals = _qportals;
     synctransact = 1;
@@ -576,6 +550,7 @@ outer:
     stashavail=Thread.Condition();
     stashqueue=Thread.Queue();
     stash=Stdio.Buffer();
+    stashcount = Thread.ResourceCount();
     Thread.Thread(connectloop,pgsqlsess,nossl);
   }
 };
@@ -584,7 +559,7 @@ outer:
 class conxsess {
   final conxion chain;
 
-  void create(conxion parent) {
+  private void create(conxion parent) {
     if (parent->started)
       werror("Overwriting conxsess %s %s\n",
         describe_backtrace(({"new ", backtrace()[..<1]})),
@@ -598,7 +573,7 @@ class conxsess {
     chain = 0;
   }
 
-  void destroy() {
+  private void destroy() {
     if (chain)
       werror("Untransmitted conxsess %s\n",
        describe_backtrace(({"", backtrace()[..<1]})));
@@ -632,6 +607,7 @@ class sql_result {
   private int portalbuffersize;
   private Thread.Mutex closemux;
   private Thread.Queue datarows;
+  private Thread.ResourceCountKey stmtifkey, portalsifkey;
   private array(mapping(string:mixed)) datarowdesc;
   private array(int) datarowtypes;	// types from datarowdesc
   private string statuscmdcomplete;
@@ -647,7 +623,7 @@ class sql_result {
   private function(:void) gottimeout;
   private int timeout;
 
-  private string _sprintf(int type, void|mapping flags) {
+  protected string _sprintf(int type) {
     string res=UNDEFINED;
     switch(type) {
       case 'O':
@@ -1037,17 +1013,10 @@ class sql_result {
         if(sizeof(datarowtypes))
           plugbuffer->add_ints(map(datarowtypes,oidformat),2);
         else if (syncparse < 0 && !pgsqlsess->_wasparallelisable
-         && pgsqlsess->_statementsinflight > 1) {
-          lock=pgsqlsess->_shortmux->lock();
-          // Decrement temporarily to account for ourselves
-          if(--pgsqlsess->_statementsinflight) {
-            pgsqlsess->_waittocommit++;
-            PD("Commit waiting for statements to finish\n");
-            catch(PT(pgsqlsess->_readyforcommit->wait(lock)));
-            pgsqlsess->_waittocommit--;
-          }
-          // Increment again to account for ourselves
-          pgsqlsess->_statementsinflight++;
+         && !pgsqlsess->_statementsinflight->drained(1)) {
+          lock = pgsqlsess->_shortmux->lock();
+          PD("Commit waiting for statements to finish\n");
+          catch(PT(pgsqlsess->_statementsinflight->wait_till_drained(lock, 1)));
         }
         lock=0;
         PD("Bind portal %O statement %O\n",_portalname,_preparedname);
@@ -1082,17 +1051,10 @@ class sql_result {
     _state=PARSING;
     Thread.MutexKey lockc = pgsqlsess->_shortmux->lock();
     if (syncparse || syncparse < 0 && pgsqlsess->_wasparallelisable) {
-      if(pgsqlsess->_statementsinflight) {
-        pgsqlsess->_waittocommit++;
-        PD("Commit waiting for statements to finish\n");
-        // Do NOT put this in a function, it would require passing a lock
-        // variable on the argumentstack to the function, which will cause
-        // unpredictable lock release issues due to the extra copy on the stack
-        catch(PT(pgsqlsess->_readyforcommit->wait(lockc)));
-        pgsqlsess->_waittocommit--;
-      }
+      PD("Commit waiting for statements to finish\n");
+      catch(PT(pgsqlsess->_statementsinflight->wait_till_drained(lockc)));
     }
-    pgsqlsess->_statementsinflight++;
+    stmtifkey = pgsqlsess->_statementsinflight->acquire();
     lockc = 0;
     lock=0;
     statuscmdcomplete=UNDEFINED;
@@ -1105,11 +1067,7 @@ class sql_result {
       lock = closemux->lock();
     if (_state <= BOUND) {
       _state = COMMITTED;
-      lock = pgsqlsess->_shortmux->lock();
-      if (!--pgsqlsess->_statementsinflight && pgsqlsess->_waittocommit) {
-        PD("Signal no statements in flight\n");
-        catch(pgsqlsess->_readyforcommit->signal());
-      }
+      stmtifkey = 0;
     }
     lock = 0;
   }
@@ -1117,9 +1075,7 @@ class sql_result {
   final void _bindportal() {
     Thread.MutexKey lock = closemux->lock();
     _state=BOUND;
-    Thread.MutexKey lockc = pgsqlsess->_shortmux->lock();
-    pgsqlsess->_portalsinflight++;
-    lockc = 0;
+    portalsifkey = pgsqlsess->_portalsinflight->acquire();
     lock=0;
   }
 
@@ -1132,12 +1088,12 @@ class sql_result {
       case COPYINPROGRESS:
       case COMMITTED:
       case BOUND:
-        --pgsqlsess->_portalsinflight;
+        portalsifkey = 0;
     }
     switch(_state) {
       case BOUND:
       case PARSING:
-        --pgsqlsess->_statementsinflight;
+        stmtifkey = 0;
     }
     _state = PURGED;
     lock=0;
@@ -1174,10 +1130,9 @@ class sql_result {
           retval=FLUSHSEND;
         } else
           _unnamedportalkey=0;
-        Thread.MutexKey lockc=pgsqlsess->_shortmux->lock();
-        if(!--pgsqlsess->_portalsinflight) {
-          if(!pgsqlsess->_waittocommit && !plugbuffer->stashcount
-           && transtype != TRANSBEGIN)
+        portalsifkey = 0;
+        if (pgsqlsess->_portalsinflight->drained()) {
+          if (plugbuffer->stashcount->drained() && transtype != TRANSBEGIN)
            /*
             * stashcount will be non-zero if a parse request has been queued
             * before the close was initiated.
@@ -1186,7 +1141,6 @@ class sql_result {
             pgsqlsess->_readyforquerycount++, retval=SYNCSEND;
           pgsqlsess->_pportalcount=0;
         }
-        lockc=0;
     }
     lock=0;
     return retval;
-- 
GitLab