From 20e95672eb76c6b5398f298ebcd1cea94751b9d8 Mon Sep 17 00:00:00 2001
From: Marcus Comstedt <marcus@mc.pp.se>
Date: Wed, 31 Mar 1999 21:30:48 +0200
Subject: [PATCH] rsql now supports almost all DBAPI calls.

Rev: bin/rsqld.pike:1.2
Rev: lib/modules/Sql.pmod/rsql.pike:1.2
---
 bin/rsqld.pike                 | 127 ++++++++++++++++++++++++----
 lib/modules/Sql.pmod/rsql.pike | 147 +++++++++++++++++++++++++++------
 2 files changed, 234 insertions(+), 40 deletions(-)

diff --git a/bin/rsqld.pike b/bin/rsqld.pike
index a63129f0fd..0a0e602391 100644
--- a/bin/rsqld.pike
+++ b/bin/rsqld.pike
@@ -14,8 +14,9 @@ class Connection
   static int expected;
   static mapping(int:function) commandset = ([]);
   static object sqlobj;
-  static mapping(int:object) queries = ([]);
-  static int qid;
+  static mapping(string:object) queries = ([]);
+  static int qbase, qid;
+  static private mapping(string:string) users;
 
   static void timeout()
   {
@@ -109,8 +110,17 @@ class Connection
 
   static int cmd_login(array(string) userpw)
   {
-    // No security right now...
-    int authorized = 1;
+    int authorized = 0;
+
+    if(!users)
+      authorized = 1;  // No security enabled
+    else if(sizeof(userpw) && userpw[0] && !zero_type(users[userpw[0]])) {
+      if(!stringp(users[userpw[0]]))
+	authorized = 1;  // Password-less user
+      else if(sizeof(userpw)>1 && userpw[1] &&
+	      crypt(userpw[1], users[userpw[0]]))
+	authorized = 1;  // Correct password for user
+    }
 
     if(authorized)
       commandset_1();
@@ -145,27 +155,104 @@ class Connection
     return sqlobj->server_info();
   }
 
-  static int cmd_bigquery(string q)
+  static string cmd_hostinfo()
+  {
+    return sqlobj->host_info();
+  }
+
+  static void cmd_shutdown()
+  {
+    return sqlobj->shutdown();
+  }
+
+  static void cmd_reload()
+  {
+    return sqlobj->reload();
+  }
+
+  static array(string) cmd_listdbs(string wild)
+  {
+    return sqlobj->list_dbs(wild);
+  }
+
+  static array(string) cmd_listtables(string wild)
+  {
+    return sqlobj->list_tables(wild);
+  }
+
+  static array(mapping(string:mixed)) cmd_listflds(array(string) args)
   {
-    if(!qid)
+    return sqlobj->list_fields(@args);
+  }
+
+  static private string make_id()
+  {
+    if(!qid) {
       qid=1;
+      qbase++;
+    }
+    return sprintf("%4c%4c", qbase, qid++);
+  }
+
+  static string cmd_bigquery(string q)
+  {
     object res = sqlobj->big_query(q);
-    return res && ((queries[qid]=res), qid++);
+    if(!res)
+      return 0;
+    string qid = make_id();
+    queries[qid] = res;
+    return qid;
   }
 
-  static void cmd_zapquery(int qid)
+  static void cmd_zapquery(string qid)
   {
     m_delete(queries, qid);
   }
 
-  static int|array(string|int) cmd_fetchrow(int qid)
+  static object get_query(string qid)
+  {
+    return queries[qid] || 
+      (throw (({"Query ID has expired.\n", backtrace()})),0);
+  }
+
+  static array(mapping(string:mixed)) cmd_query(array args)
+  {
+    return sqlobj->query(@args);
+  }
+
+  static int|array(string|int) cmd_fetchrow(string qid)
+  {
+    return get_query(qid)->fetch_row();
+  }
+
+  static array(mapping(string:mixed)) cmd_fetchfields(string qid)
+  {
+    return get_query(qid)->fetch_fields();
+  }
+
+  static int cmd_numrows(string qid)
+  {
+    return get_query(qid)->num_rows();
+  }
+
+  static int cmd_numfields(string qid)
+  {
+    return get_query(qid)->num_fields();
+  }
+
+  static int cmd_eof(string qid)
+  {
+    return get_query(qid)->eof();
+  }
+
+  static void cmd_seek(array(string|int) args)
   {
-    return queries[qid]->fetch_row();
+    get_query(args[0])->seek(args[1]);
   }
 
-  static array(mapping(string:mixed)) cmd_fetchfields(int qid)
+  static string cmd_quote(string s)
   {
-    return queries[qid]->fetch_fields();
+    return sqlobj->quote(s);
   }
 
   static void commandset_0()
@@ -178,8 +265,12 @@ class Connection
     commandset_0();
     commandset |= ([ 'D': cmd_selectdb, 'E': cmd_error,
 		     'C': cmd_create, 'X': cmd_drop, 'I': cmd_srvinfo,
+		     'i': cmd_hostinfo, 's': cmd_shutdown, 'r': cmd_reload, 
+		     'l': cmd_listdbs, 't': cmd_listtables, 'f': cmd_listflds,
 		     'Q': cmd_bigquery, 'Z': cmd_zapquery,
-		     'R': cmd_fetchrow, 'F': cmd_fetchfields]);
+		     'R': cmd_fetchrow, 'F': cmd_fetchfields,
+		     'N': cmd_numrows, 'n': cmd_numfields, 'e': cmd_eof,
+		     'S': cmd_seek, '@': cmd_query, 'q': cmd_quote]);
   }
 
   static void client_ident(string s)
@@ -198,10 +289,12 @@ class Connection
     }
   }
 
-  void create(object s)
+  void create(object s, mapping(string:string) u)
   {
+    users = u;
     clntsock = s;
     got = to_write = "";
+    qbase = time();
     expect(8, client_ident);
     clntsock->set_nonblocking(read_callback, write_callback, close_callback);
   }
@@ -210,16 +303,18 @@ class Connection
 class Server
 {
   static object servsock;
+  static private mapping(string:string) users;
 
   static void accept_callback()
   {
     object s = servsock->accept();
     if(s)
-      Connection(s);
+      Connection(s, users);
   }
 
-  void create(int|void port, string|void ip)
+  void create(int|void port, string|void ip, mapping(string:string)|void usrs)
   {
+    users = usrs;
     if(!port)
       port = RSQL_PORT;
     servsock = Stdio.Port();
diff --git a/lib/modules/Sql.pmod/rsql.pike b/lib/modules/Sql.pmod/rsql.pike
index 2e129dbb6d..799c5308f9 100644
--- a/lib/modules/Sql.pmod/rsql.pike
+++ b/lib/modules/Sql.pmod/rsql.pike
@@ -3,13 +3,66 @@
 #define RSQL_PORT 3994
 #define RSQL_VERSION 1
 
+
+#if constant(thread_create)
+#define LOCK object key=mutex->lock()
+#define UNLOCK destruct(key)
+static private object(Thread.Mutex) mutex = Thread.Mutex();
+#else
+#define LOCK
+#define UNLOCK
+#endif
+
+
 static object(Stdio.File) sock;
 static int seqno = 0;
 
-static mixed do_request(int cmd, mixed|void arg)
+static private string host, user, pw;
+static private int port;
+
+static void low_reconnect()
 {
+  object losock = Stdio.File();
+  if(sock)
+    destruct(sock);
+  if(!losock->connect(host, port|RSQL_PORT))
+    throw(({"Can't connect to "+host+(port? ":"+port:"")+": "+
+	    strerror(losock->errno())+"\n", backtrace()}));
+  if(8!=losock->write(sprintf("RSQL%4c", RSQL_VERSION)) ||
+     losock->read(4) != "SQL!") {
+    destruct(losock);
+    throw(({"Initial handshake error on "+host+(port? ":"+port:"")+"\n",
+	    backtrace()}));    
+  }
+  sock = losock;
+  if(!do_request('L', ({user,pw}), 1)) {
+    sock = 0;
+    if(losock)
+      destruct(losock);
+    throw(({"Login refused on "+host+(port? ":"+port:"")+"\n",
+	    backtrace()}));        
+  }
+}
+
+static void low_connect(string the_host, int the_port, string the_user,
+			string the_pw)
+{
+  host = the_host;
+  port = the_port;
+  user = the_user;
+  pw = the_pw;
+  low_reconnect();
+}
+
+static mixed do_request(int cmd, mixed|void arg, int|void noreconnect)
+{
+  LOCK;
   if(!sock)
-    throw(({"Not connected.\n", backtrace()}));
+    if(noreconnect) {
+      UNLOCK;
+      throw(({"No connection\n", backtrace()}));
+    } else
+      low_reconnect();
   arg = (arg? encode_value(arg) : "");
   sock->write(sprintf("?<%c>%4c%4c%s", cmd, ++seqno, sizeof(arg), arg));
   string res;
@@ -20,8 +73,12 @@ static mixed do_request(int cmd, mixed|void arg)
     mixed rdat = (rlen? sock->read(rlen) : "");
     if((!rdat) || sizeof(rdat)!=rlen) {
       destruct(sock);
-      throw(({"RSQL Phase error, disconnected\n", backtrace()}));
+      UNLOCK;
+      if(noreconnect)
+	throw(({"RSQL Phase error, disconnected\n", backtrace()}));
+      else return do_request(cmd, arg, 1);
     }
+    UNLOCK;
     rdat = (sizeof(rdat)? decode_value(rdat):0);
     switch(res[0]) {
     case '.': return rdat;
@@ -30,7 +87,10 @@ static mixed do_request(int cmd, mixed|void arg)
     throw(({"Internal error\n", backtrace()}));    
   } else {
     destruct(sock);
-    throw(({"RSQL Phase error, disconnected\n", backtrace()}));
+    UNLOCK;
+    if(noreconnect)
+      throw(({"RSQL Phase error, disconnected\n", backtrace()}));
+    else return do_request(cmd, arg, 1);
   }
 }
 
@@ -54,9 +114,44 @@ void drop_db(string db)
   do_request('X', db);
 }
 
-void server_info()
+string server_info()
 {
-  do_request('I');
+  return do_request('I');
+}
+
+string host_info()
+{
+  return do_request('i');
+}
+
+void shutdown()
+{
+  do_request('s');
+}
+
+void reload()
+{
+  do_request('r');
+}
+
+array(string) list_dbs(string|void wild)
+{
+  return do_request('l', wild);
+}
+
+array(string) list_tables(string|void wild)
+{
+  return do_request('t', wild);
+}
+
+array(mapping(string:mixed)) list_fields(string ... args)
+{
+  return do_request('f', args);
+}
+
+string quote(string s)
+{
+  return do_request('q');
 }
 
 object big_query(string q)
@@ -82,6 +177,26 @@ object big_query(string q)
       return do_request('F', qid);
     }
 
+    int num_rows()
+    {
+      return do_request('N', qid);
+    }
+
+    int num_fields()
+    {
+      return do_request('n', qid);
+    }
+
+    int eof()
+    {
+      return do_request('e', qid);
+    }
+
+    void seek(int skip)
+    {
+      return do_request('S', ({qid,skip}));
+    }
+
     void create(function(int,mixed:mixed) d_r, mixed i)
     {
       do_request = d_r;
@@ -91,25 +206,9 @@ object big_query(string q)
   }(do_request, qid);
 }
 
-static void low_connect(string host, int port, string user, string pw)
+array(mapping(string:mixed)) query(mixed ... args)
 {
-  object losock = Stdio.File();
-  if(!losock->connect(host, port|RSQL_PORT))
-    throw(({"Can't connect to "+host+(port? ":"+port:"")+": "+
-	    strerror(losock->errno())+"\n", backtrace()}));
-  if(8!=losock->write(sprintf("RSQL%4c", RSQL_VERSION)) ||
-     losock->read(4) != "SQL!") {
-    destruct(losock);
-    throw(({"Initial handshake error on "+host+(port? ":"+port:"")+"\n",
-	    backtrace()}));    
-  }
-  sock = losock;
-  if(!do_request('L', ({user,pw}))) {
-    sock = 0;
-    destruct(losock);
-    throw(({"Login refused on "+host+(port? ":"+port:"")+"\n",
-	    backtrace()}));        
-  }
+  return do_request('@', args);
 }
 
 void create(string|void host, string|void db, string|void user, string|void pw)
-- 
GitLab