From b018a340eb5b00f2442710cc3fc93337dfc39145 Mon Sep 17 00:00:00 2001
From: "Stephen R. van den Berg" <srb@cuci.nl>
Date: Sat, 11 Nov 2017 03:08:55 +0100
Subject: [PATCH] pgsql: Postgresql 10 scram-SHA256 authentication support.

---
 lib/modules/Sql.pmod/pgsql.pike      | 133 +++++++++++++++++++++++----
 lib/modules/Sql.pmod/pgsql_util.pmod |  14 +++
 2 files changed, 128 insertions(+), 19 deletions(-)

diff --git a/lib/modules/Sql.pmod/pgsql.pike b/lib/modules/Sql.pmod/pgsql.pike
index 3ed2399b46..e2062c5d52 100644
--- a/lib/modules/Sql.pmod/pgsql.pike
+++ b/lib/modules/Sql.pmod/pgsql.pike
@@ -107,6 +107,7 @@ private function (:void) readyforquery_cb;
 final string _host;
 final int _port;
 private string database, user, pass;
+private string cnonce, ServerSignature;
 private Thread.Condition waitforauthready;
 final Thread.Mutex _shortmux;
 final Thread.Condition _readyforcommit;
@@ -204,12 +205,12 @@ private string _sprintf(int type, void|mapping flags) {
 protected void create(void|string host, void|string database,
                       void|string user, void|string pass,
                       void|mapping(string:mixed) options) {
-  this::pass = pass;
+  this::pass = Standards.IDNA.to_ascii(pass);
   if(pass) {
     String.secure(pass);
     pass = "CENSORED";
   }
-  this::user = user;
+  this::user = Standards.IDNA.to_ascii(user, 1);
   this::database = database;
   _options = options || ([]);
 
@@ -718,16 +719,25 @@ private void procmessage() {
           return msgresponse;
         };
         case 'R': {
+          void authresponse(array msg) {
+            object cs = ci->start();
+            CHAIN(cs)->add_int8('p')->add_hstring(msg, 4, 4);
+            cs->sendcmd(SENDOUT); // No flushing, PostgreSQL 9.4 disapproves
+          }
           PD("Authentication ");
-          string sendpass;
           msglen-=4+4;
-          int authtype=cr->read_int32();
-          switch(authtype) {
+          int authtype, k;
+          switch(authtype = cr->read_int32()) {
             case 0:
               PD("Ok\n");
-              .pgsql_util.local_backend->remove_call_out(reconnect);
-              reconnectdelay=0;
-              cancelsecret="";
+              if (cnonce) {
+                PD("Authentication validation still in progress\n");
+                errtype = PROTOCOLUNSUPPORTED;
+              } else {
+                .pgsql_util.local_backend->remove_call_out(reconnect);
+                reconnectdelay=0;
+                cancelsecret="";
+              }
               break;
             case 2:
               PD("KerberosV5\n");
@@ -735,7 +745,7 @@ private void procmessage() {
               break;
             case 3:
               PD("ClearTextPassword\n");
-              sendpass=pass;
+              authresponse(({pass, 0}));
               break;
             case 4:
               PD("CryptPassword\n");
@@ -748,8 +758,8 @@ private void procmessage() {
                 errtype=PROTOCOLERROR;
 #endif
 #define md5hex(x) String.string2hex(Crypto.MD5.hash(x))
-              sendpass=md5hex(pass+user);
-              sendpass="md5"+md5hex(sendpass+cr->read(msglen));
+              authresponse(({"md5",
+               md5hex(md5hex(pass + user) + cr->read(msglen)), 0}));
 #ifdef PG_DEBUG
               msglen=0;
 #endif
@@ -769,29 +779,114 @@ private void procmessage() {
             case 8:
               PD("GSSContinue\n");
               errtype=PROTOCOLUNSUPPORTED;
-              cancelsecret=cr->read(msglen);		// Actually SSauthdata
+              cr->read(msglen);				// SSauthdata
 #ifdef PG_DEBUG
               if(msglen<1)
                 errtype=PROTOCOLERROR;
               msglen=0;
 #endif
               break;
+            case 10: {
+              string word;
+              PD("AuthenticationSASL\n");
+              k = 0;
+              while (sizeof(word = cr->read_cstring())) {
+                switch (word) {
+                  case "SCRAM-SHA-256":
+                    k = 1;
+                }
+#ifdef PG_DEBUG
+                msglen -= sizeof(word) + 1;
+                if (msglen < 1)
+                  break;
+#endif
+              }
+              if (k) {
+                cnonce = MIME.encode_base64(Random.System().random_string(18));
+                word = "n,,n=,r=" + cnonce;
+                authresponse(({
+                  "SCRAM-SHA-256", 0, sprintf("%4c", sizeof(word)), word
+                 }));
+              } else
+                errtype = PROTOCOLUNSUPPORTED;
+#ifdef PG_DEBUG
+              if(msglen != 1)
+                errtype=PROTOCOLERROR;
+              msglen=0;
+#endif
+              break;
+            }
+            case 11: {
+              string r, salt;
+              int iters;
+#define HMAC256(key, data)	(Crypto.SHA256.HMAC(key)(data))
+#define HASH256(data)		(Crypto.SHA256.hash(data))
+              PD("AuthenticationSASLContinue\n");
+              { Stdio.Buffer tb = cr->read_buffer(msglen);
+                [r, salt, iters] = tb->sscanf("r=%s,s=%s,i=%d");
+#ifdef PG_DEBUG
+                if (sizeof(tb))
+                  errtype = PROTOCOLERROR;
+                msglen = 0;
+#endif
+              }
+              if (!has_prefix(r, cnonce) || iters <= 0)
+                errtype = PROTOCOLERROR;
+              else {
+                string SaltedPassword;
+                if (!(SaltedPassword
+                  = .pgsql_util.get_salted_password(pass, salt, iters))) {
+                  SaltedPassword = ServerSignature =
+                   HMAC256(pass, MIME.decode_base64(salt) + "\0\0\0\1");
+                  int i = iters;
+                  while (--i)
+                    SaltedPassword ^= ServerSignature
+                     = HMAC256(pass, ServerSignature);
+                  .pgsql_util.set_salted_password(pass, salt, iters,
+                   SaltedPassword);
+                }
+                salt = sprintf("n=,r=%s,r=%s,s=%s,i=%d,c=biws,r=%s",
+                               cnonce, r, salt, iters, r);
+                ServerSignature = HMAC256(SaltedPassword, "Client Key");
+                authresponse(({
+                  sprintf("c=biws,r=%s,p=%s",
+                   r, MIME.encode_base64(
+                        ServerSignature
+                        ^ HMAC256(HASH256(ServerSignature), salt)
+                       ))
+                 }));
+                ServerSignature = HMAC256(HMAC256(SaltedPassword, "Server Key"),
+                  salt);
+              }
+              break;
+            }
+            case 12: {
+              string v;
+              PD("AuthenticationSASLFinal\n");
+              Stdio.Buffer tb = cr->read_buffer(msglen);
+              [v] = tb->sscanf("v=%s");
+              v = MIME.decode_base64(v);
+              if (v != ServerSignature)
+                errtype = PROTOCOLERROR;
+              else
+                cnonce = 0;		// Clears cnonce and approves server
+#ifdef PG_DEBUG
+              if (sizeof(tb))
+                errtype=PROTOCOLERROR;
+              msglen=0;
+#endif
+              break;
+            }
             default:
               PD("Unknown Authentication Method %c\n",authtype);
               errtype=PROTOCOLUNSUPPORTED;
               break;
           }
           switch(errtype) {
-            case NOERROR:
-              if(cancelsecret!="") {
-                object cs = ci->start();
-                CHAIN(cs)->add_int8('p')->add_hstring(({sendpass, 0}), 4, 4);
-                cs->sendcmd(SENDOUT);
-              }
-              break;	// No flushing here, PostgreSQL 9.4 disapproves
             default:
             case PROTOCOLUNSUPPORTED:
               ERROR("Unsupported authenticationmethod %c\n",authtype);
+            case NOERROR:
               break;
           }
           break;
diff --git a/lib/modules/Sql.pmod/pgsql_util.pmod b/lib/modules/Sql.pmod/pgsql_util.pmod
index de8032dbe5..0b9ca3433a 100644
--- a/lib/modules/Sql.pmod/pgsql_util.pmod
+++ b/lib/modules/Sql.pmod/pgsql_util.pmod
@@ -93,6 +93,20 @@ private Regexp iregexp(string expr) {
   return Regexp(ret->read());
 }
 
+private string cached_pass, cached_salt, cached_SaltedPassword;
+private int cached_iters;
+
+final string get_salted_password(string pass, string salt, int iters) {
+  return cached_pass == pass && cached_salt == salt && cached_iters == iters
+   && cached_SaltedPassword;
+}
+
+final void set_salted_password(string pass, string salt, int iters,
+ string SaltedPassword) {
+  cached_pass = pass; cached_salt = salt; cached_iters = iters;
+  cached_SaltedPassword = SaltedPassword;
+}
+
 final void closestatement(bufcon|conxsess plugbuffer,string oldprep) {
   if(oldprep) {
     PD("Close statement %s\n",oldprep);
-- 
GitLab