From 62ca280bc81f107ea9b00704081bee1050feed2d Mon Sep 17 00:00:00 2001
From: "Stephen R. van den Berg" <srb@cuci.nl>
Date: Sun, 12 Nov 2017 19:27:44 +0100
Subject: [PATCH] Crypto.SCRAM: New module.

---
 lib/modules/Crypto.pmod/SCRAM.pike   | 206 +++++++++++++++++++++++++++
 lib/modules/Crypto.pmod/module.pmod  |  18 +++
 lib/modules/Crypto.pmod/testsuite.in |  24 ++++
 lib/modules/Sql.pmod/pgsql.pike      |  65 +++------
 lib/modules/Sql.pmod/pgsql_util.pmod |  20 ---
 5 files changed, 264 insertions(+), 69 deletions(-)
 create mode 100644 lib/modules/Crypto.pmod/SCRAM.pike

diff --git a/lib/modules/Crypto.pmod/SCRAM.pike b/lib/modules/Crypto.pmod/SCRAM.pike
new file mode 100644
index 0000000000..61956cc4ea
--- /dev/null
+++ b/lib/modules/Crypto.pmod/SCRAM.pike
@@ -0,0 +1,206 @@
+
+//! SCRAM, defined by @rfc{5802@}.
+//!
+//! This implements both the client- and the serverside.
+//! You normally run either the server or the client, but if you would
+//! run both (use a separate client and a separate server object!),
+//! the sequence would be:
+//!
+//! @[client_1] -> @[server_1] -> @[server_2] -> @[client_2] ->
+//! @[server_3] -> @[client_3]
+//!
+//! @note
+//!   This implementation does not pretend to support the full protocol.
+//!   Most notably optional extension arguments are not supported (yet).
+
+#pike __REAL_VERSION__
+#pragma strict_types
+#require constant(Crypto.Hash)
+
+private .Hash H;  // hash object
+
+private string(8bit) first, nonce;
+
+private string(7bit) encode64(string(8bit) raw) {
+  return MIME.encode_base64(raw, 1);
+}
+
+private string(7bit) randomstring() {
+  return encode64(random_string(18));
+}
+
+private .MAC.State HMAC(string(8bit) key) {
+  return H->HMAC(key);
+}
+
+private string(7bit) clientproof(string(8bit) salted_password) {
+  .MAC.State hmacsaltedpw = HMAC(salted_password);
+  salted_password = hmacsaltedpw("Client Key");
+  // Returns ServerSignature through nonce
+  nonce = encode64(HMAC(hmacsaltedpw("Server Key"))(first));
+  return encode64([string(8bit)]
+                  (salted_password ^ HMAC(H->hash(salted_password))(first)));
+}
+
+//! Step 0 in the SCRAM handshake, prior to creating the object,
+//! you need to have agreed with your peer on the hashfunction to be used.
+//!
+//! @param h
+//!   The hash object on which the SCRAM object should base its
+//!   operations. Typical input is @[Crypto.SHA256].
+//!
+//! @note
+//! If you are a client, you must use the @ref{client_*@} methods; if you are
+//! a server, you must use the @ref{server_*@} methods.
+//! You cannot mix both client and server methods in a single object.
+//!
+//! @seealso
+//!   @[client_1], @[server_1]
+protected void create(.Hash h) {
+  H = h;
+}
+
+//! Client-side step 1 in the SCRAM handshake.
+//!
+//! @param username
+//!   The username to feed to the server.  Some servers already received
+//!   the username through an alternate channel (usually during
+//!   the hash-function selection handshake), in which case it
+//!   should be omitted here.
+//!
+//! @returns
+//!   The client-first request to send to the server.
+//!
+//! @seealso
+//!   @[client_2]
+string(7bit) client_1(void|string username) {
+  nonce = randomstring();
+  return [string(7bit)](first = [string(8bit)]sprintf("n,,n=%s,r=%s",
+    username && username != "" ? Standards.IDNA.to_ascii(username, 1) : "",
+    nonce));
+}
+
+//! Server-side step 1 in the SCRAM handshake.
+//!
+//! @param line
+//!   The received client-first request from the client.
+//!
+//! @returns
+//!   The username specified by the client.  Returns null
+//!   if the response could not be parsed.
+//!
+//! @seealso
+//!   @[server_2]
+string server_1(string(8bit) line) {
+  constant format = "n,,n=%s,r=%s";
+  string username, r;
+  catch {
+    first = [string(8bit)]line[3..];
+    [username, r] = array_sscanf(line, format);
+    nonce = [string(8bit)]r;
+    r = Standards.IDNA.to_unicode(username);
+  };
+  return r;
+}
+
+//! Server-side step 2 in the SCRAM handshake.
+//!
+//! @param salt
+//!   The salt corresponding to the username that has been specified earlier.
+//!
+//! @param iters
+//!   The number of iterations the hashing algorithm should perform
+//!   to compute the authentication hash.
+//!
+//! @returns
+//!   The server-first challenge to send to the client.
+//!
+//! @seealso
+//!   @[server_3]
+string(7bit) server_2(string(8bit) salt, int iters) {
+  string response = sprintf("r=%s,s=%s,i=%d",
+    nonce += randomstring(), encode64(salt), iters);
+  first += "," + response + ",";
+  return [string(7bit)]response;
+}
+
+//! Client-side step 2 in the SCRAM handshake.
+//!
+//! @param line
+//!   The received server-first challenge from the server.
+//!
+//! @param pass
+//!   The password to feed to the server.
+//!
+//! @returns
+//!   The client-final response to send to the server.  If the response is
+//!   null, the server sent something unacceptable or unparseable.
+//!
+//! @seealso
+//!   @[client_3]
+string(7bit) client_2(string(8bit) line, string pass) {
+  constant format = "r=%s,s=%s,i=%d";
+  string r, salt;
+  int iters;
+  if (!catch([r, salt, iters] = array_sscanf(line, format))
+      && iters > 0
+      && has_prefix(r, nonce)) {
+    line = [string(8bit)]sprintf("c=biws,r=%s", r);
+    first = [string(8bit)]sprintf("%s,r=%s,s=%s,i=%d,%s",
+                                  first[3..], r, salt, iters, line);
+    if (pass != "")
+      pass = Standards.IDNA.to_ascii(pass);
+    salt = MIME.decode_base64(salt);
+    nonce = [string(8bit)]sprintf("%s,%s,%d", pass, salt, iters);
+    if (!(r = .SCRAM_get_salted_password(H, nonce))) {
+      r = [string(8bit)]H->pbkdf2(pass, salt, iters, H->digest_size());
+      .SCRAM_set_salted_password([string(8bit)]r, H, nonce);
+    }
+    salt = sprintf("%s,p=%s", line, clientproof([string(8bit)]r));
+    first = 0;                         // Free memory
+  } else
+    salt = 0;
+  return [string(7bit)]salt;
+}
+
+//! Final server-side step in the SCRAM handshake.
+//!
+//! @param line
+//!   The received client-final challenge and response from the client.
+//!
+//! @param salted_password
+//!   The salted (using the salt provided earlier) password belonging
+//!   to the specified username.
+//!
+//! @returns
+//!   The server-final response to send to the client.  If the response
+//!   is null, the client did not supply the correct credentials or
+//!   the response was unparseable.
+string(7bit) server_3(string(8bit) line,
+ string(8bit) salted_password) {
+  constant format = "c=biws,r=%s,p=%s";
+  string r, p;
+  if (!catch([r, p] = array_sscanf(line, format))
+      && r == nonce) {
+    first += sprintf("c=biws,r=%s", r);
+    p = p == clientproof(salted_password) && sprintf("v=%s", nonce);
+  }
+  return [string(7bit)]p;
+}
+
+//! Final client-side step in the SCRAM handshake.  If we get this far, the
+//! server has already verified that we supplied the correct credentials.
+//! If this step fails, it means the server does not have our
+//! credentials at all and is an imposter.
+//!
+//! @param line
+//!   The received server-final verification response.
+//!
+//! @returns
+//!   True if the server is valid, false if the server is invalid.
+int(0..1) client_3(string(8bit) line) {
+  constant format = "v=%s";
+  string v;
+  return !catch([v] = array_sscanf(line, format))
+         && v == nonce;
+}
diff --git a/lib/modules/Crypto.pmod/module.pmod b/lib/modules/Crypto.pmod/module.pmod
index 3a80b6a933..31f6bce086 100644
--- a/lib/modules/Crypto.pmod/module.pmod
+++ b/lib/modules/Crypto.pmod/module.pmod
@@ -62,6 +62,24 @@ class Sign {
   protected State `()() { return State(); }
 }
 
+// Salted password cache for SCRAM
+
+private mapping (Hash:mapping(string:string(8bit)))
+ SCRAM_salted_password_cache = ([]);
+
+final string(8bit) SCRAM_get_salted_password(Hash h, string key) {
+  mapping(string:string(8bit)) m = SCRAM_salted_password_cache[h];
+  return m && m[key];
+}
+
+final void SCRAM_set_salted_password(string(8bit) SaltedPassword,
+ Hash h, string key) {
+  mapping(string:string(8bit)) m = SCRAM_salted_password_cache[h];
+  if (!m || sizeof(m) > 16)
+    SCRAM_salted_password_cache[h] = m = ([]);
+  m[key] = SaltedPassword;
+}
+
 //! Hashes a @[password] together with a @[salt] with the
 //! crypt_md5 algorithm and returns the result.
 //!
diff --git a/lib/modules/Crypto.pmod/testsuite.in b/lib/modules/Crypto.pmod/testsuite.in
index 77b2891c9c..3f33683c46 100644
--- a/lib/modules/Crypto.pmod/testsuite.in
+++ b/lib/modules/Crypto.pmod/testsuite.in
@@ -554,6 +554,30 @@ test_eq( RSA->rsa_unpad(Gmp.mpz(865312297808908525818041211045691652263021981338
 test_equal(Web.decode_jwk(Web.encode_jwk(RSA, 1)), RSA)
 test_do( add_constant("RSA") )
 
+define(test_SCRAM, [[
+  cond_resolv($1, [[
+    test_any([[
+      Crypto.SCRAM client = Crypto.SCRAM($1);
+      Crypto.SCRAM server = Crypto.SCRAM($1);
+      return $2 == server->server_1(client->client_1($2)) &&
+        client->client_3(server->server_3(
+          client->client_2(server->server_2(MIME.decode_base64($5), $6), $4),
+          MIME.decode_base64($3))
+         || "v=")
+    ]], $7)
+  ]])
+]])
+
+test_SCRAM(Crypto.SHA256, "anyuser",
+     "5EP5GZspS3GlBRRj+FOIwYuLeVMtdcbOeU1n8Sy1uQk=",
+    "scramhard",   "4y5+72rGdDvI77fWm1AHqQ==", 4096, 1)
+test_SCRAM(Crypto.SHA256, "",
+     "5EP5GZspS3GlBRRj+FOIwYuLeVMtdcbOeU1n8Sy1uQk=",
+    "scramhard",   "4y5+72rGdDvI77fWm1AHqQ==", 4096, 1)
+test_SCRAM(Crypto.SHA256, "anyuser",
+     "5EP5GZspS3GlBRRj+FOIwYuLeVMtdcbOeU1n8Sy1uQk=",
+    "scramharder", "4y5+72rGdDvI77fWm1AHqQ==", 4096, 0)
+
 // Hash, JWA
 define(test_jwk_hmac, [[
   cond_resolv($1, [[
diff --git a/lib/modules/Sql.pmod/pgsql.pike b/lib/modules/Sql.pmod/pgsql.pike
index c5f63a4341..d480457fc0 100644
--- a/lib/modules/Sql.pmod/pgsql.pike
+++ b/lib/modules/Sql.pmod/pgsql.pike
@@ -105,7 +105,7 @@ private function (:void) readyforquery_cb;
 final string _host;
 final int _port;
 private string database, user, pass;
-private string cnonce;
+private Crypto.SCRAM SASLcontext;
 private Thread.Condition waitforauthready;
 final Thread.Mutex _shortmux;
 final int _readyforquerycount;
@@ -707,7 +707,7 @@ private void procmessage() {
           return msgresponse;
         };
         case 'R': {
-          void authresponse(array msg) {
+          void authresponse(string|array msg) {
             object cs = ci->start();
             CHAIN(cs)->add_int8('p')->add_hstring(msg, 4, 4);
             cs->sendcmd(SENDOUT); // No flushing, PostgreSQL 9.4 disapproves
@@ -718,7 +718,7 @@ private void procmessage() {
           switch(authtype = cr->read_int32()) {
             case 0:
               PD("Ok\n");
-              if (cnonce) {
+              if (SASLcontext) {
                 PD("Authentication validation still in progress\n");
                 errtype = PROTOCOLUNSUPPORTED;
               } else
@@ -787,8 +787,8 @@ private void procmessage() {
 #endif
               }
               if (k) {
-                cnonce = MIME.encode_base64(random_string(18));
-                word = "n,,n=,r=" + cnonce;
+                SASLcontext = Crypto.SCRAM(Crypto.SHA256);
+                word = SASLcontext.client_1();
                 authresponse(({
                   "SCRAM-SHA-256", 0, sprintf("%4c", sizeof(word)), word
                  }));
@@ -802,61 +802,28 @@ private void procmessage() {
               break;
             }
             case 11: {
-              string r, salt;
-              int iters;
-#define HMAC256(key, data)	(Crypto.SHA256.HMAC(key)(data))
-#define HASH256(data)		(Crypto.SHA256.hash(data))
               PD("AuthenticationSASLContinue\n");
-              { Stdio.Buffer tb = cr->read_buffer(msglen);
-                [r, salt, iters] = tb->sscanf("r=%s,s=%s,i=%d");
+              string response;
+              if (response
+               = SASLcontext.client_2(cr->read(msglen), pass))
+                authresponse(response);
+              else
+                errtype = PROTOCOLERROR;
 #ifdef PG_DEBUG
-                if (sizeof(tb))
-                  errtype = PROTOCOLERROR;
-                msglen = 0;
+              msglen = 0;
 #endif
-              }
-              if (!has_prefix(r, cnonce) || iters <= 0)
-                errtype = PROTOCOLERROR;
-              else {
-                string SaltedPassword;
-                string biws = sprintf("c=biws,r=%s", r);
-                r = sprintf("n=,r=%s,r=%s,s=%s,i=%d,%s",
-                            cnonce, r, salt, iters, biws);
-                if (!(SaltedPassword
-                  = .pgsql_util.get_salted_password(pass, salt, iters))) {
-                  SaltedPassword = cnonce =
-                   HMAC256(pass, MIME.decode_base64(salt) + "\0\0\0\1");
-                  int i = iters;
-                  while (--i)
-                    SaltedPassword ^= cnonce = HMAC256(pass, cnonce);
-                  .pgsql_util.set_salted_password(pass, salt, iters,
-                   SaltedPassword);
-                }
-                salt = HMAC256(SaltedPassword, "Client Key");
-                authresponse(({
-                  sprintf("%s,p=%s", biws,
-                   MIME.encode_base64(salt ^ HMAC256(HASH256(salt), r)))
-                 }));
-                cnonce = HMAC256(HMAC256(SaltedPassword, "Server Key"), r);
-              }
               break;
             }
-            case 12: {
-              string v;
+            case 12:
               PD("AuthenticationSASLFinal\n");
-              Stdio.Buffer tb = cr->read_buffer(msglen);
-              [v] = tb->sscanf("v=%s");
-              if (MIME.decode_base64(v) != cnonce)
-                errtype = PROTOCOLERROR;
+              if (SASLcontext.client_3(cr->read(msglen)))
+                SASLcontext = 0;	// Clears context and approves server
               else
-                cnonce = 0;		// Clears cnonce and approves server
+                errtype = PROTOCOLERROR;
 #ifdef PG_DEBUG
-              if (sizeof(tb))
-                errtype=PROTOCOLERROR;
               msglen=0;
 #endif
               break;
-            }
             default:
               PD("Unknown Authentication Method %c\n",authtype);
               errtype=PROTOCOLUNSUPPORTED;
diff --git a/lib/modules/Sql.pmod/pgsql_util.pmod b/lib/modules/Sql.pmod/pgsql_util.pmod
index c4c8b3d3e6..e3b1a2f3d3 100644
--- a/lib/modules/Sql.pmod/pgsql_util.pmod
+++ b/lib/modules/Sql.pmod/pgsql_util.pmod
@@ -93,26 +93,6 @@ private Regexp iregexp(string expr) {
   return Regexp(ret->read());
 }
 
-private Thread.Mutex cachemutex = Thread.Mutex();
-private string cached_pass, cached_salt, cached_SaltedPassword;
-private int cached_iters;
-
-final string get_salted_password(string pass, string salt, int iters) {
-  Thread.MutexKey lock = cachemutex->lock();
-  string ret = cached_pass == pass && cached_salt == salt
-   && cached_iters == iters && cached_SaltedPassword;
-  lock = 0;
-  return ret;
-}
-
-final void set_salted_password(string pass, string salt, int iters,
- string SaltedPassword) {
-  Thread.MutexKey lock = cachemutex->lock();
-  cached_pass = pass; cached_salt = salt; cached_iters = iters;
-  cached_SaltedPassword = SaltedPassword;
-  lock = 0;
-}
-
 final void closestatement(bufcon|conxsess plugbuffer,string oldprep) {
   if(oldprep) {
     PD("Close statement %s\n",oldprep);
-- 
GitLab