From 0eebf7a795f7f76fc77e2a97b0e823815f5838cc Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 24 Nov 2024 12:16:37 +0000 Subject: [PATCH 1/3] feat: support tls-server-end-point channel binding for SCRAM-SHA-256 --- buf.go | 14 +++++++++ conn.go | 48 ++++++++++++++++++++++++++++-- scram/scram.go | 80 ++++++++++++++++++++++++++++++++++++++------------ ssl.go | 30 +++++++++++++++++-- 4 files changed, 150 insertions(+), 22 deletions(-) diff --git a/buf.go b/buf.go index 4b0a0a8f7..0256fcd99 100644 --- a/buf.go +++ b/buf.go @@ -38,6 +38,20 @@ func (b *readBuf) string() string { return string(s) } +func (b *readBuf) strings() []string { + ss := []string{} + for (*b)[0] != 0 { + i := bytes.IndexByte(*b, 0) + if i < 0 { + errorf("invalid message format; expected string terminator") + } + s := (*b)[:i] + *b = (*b)[i+1:] + ss = append(ss, string(s)) + } + return ss +} + func (b *readBuf) next(n int) (v []byte) { v = (*b)[:n] *b = (*b)[n:] diff --git a/conn.go b/conn.go index bc0983608..74c169063 100644 --- a/conn.go +++ b/conn.go @@ -170,6 +170,9 @@ type conn struct { // GSSAPI context gss GSS + + // channel binding data used for SCRAM-SHA-256-PLUS + tlsServerEndPoint []byte } type syncErr struct { @@ -1129,7 +1132,18 @@ func (cn *conn) ssl(o values) error { return ErrSSLNotSupported } - cn.c, err = upgrade(cn.c) + conn, err := upgrade(cn.c) + if err != nil { + return err + } + + cb, err := tlsServerEndPoint(conn) + if err != nil { + return err + } + + cn.c = conn + cn.tlsServerEndPoint = cb return err } @@ -1289,7 +1303,37 @@ func (cn *conn) auth(r *readBuf, o values) { // from the server.. case 10: + supported := r.strings() + + scramSha256 := false + scramSha256Plus := false + for _, s := range supported { + switch s { + case "SCRAM-SHA-256": + scramSha256 = true + case "SCRAM-SHA-256-PLUS": + scramSha256Plus = true + } + } + sc := scram.NewClient(sha256.New, o["user"], o["password"]) + + // channel binding is supported by the client + if cn.tlsServerEndPoint != nil { + sc.WithTlsServerEndPoint(cn.tlsServerEndPoint) + } + + var selected string + // SCRAM-SHA-256-PLUS always takes preference. + if cn.tlsServerEndPoint != nil && scramSha256Plus { + sc.UseChannelBinding() + selected = "SCRAM-SHA-256-PLUS" + } else if scramSha256 { + selected = "SCRAM-SHA-256" + } else { + errorf("SCRAM-SHA-256 protocol error") + } + sc.Step(nil) if sc.Err() != nil { errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) @@ -1297,7 +1341,7 @@ func (cn *conn) auth(r *readBuf, o values) { scOut := sc.Out() w := cn.writeBuf('p') - w.string("SCRAM-SHA-256") + w.string(selected) w.int32(len(scOut)) w.bytes(scOut) cn.send(w) diff --git a/scram/scram.go b/scram/scram.go index 477216b60..26e4e4d75 100644 --- a/scram/scram.go +++ b/scram/scram.go @@ -25,7 +25,6 @@ // Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. // // http://tools.ietf.org/html/rfc5802 -// package scram import ( @@ -43,17 +42,16 @@ import ( // // A Client may be used within a SASL conversation with logic resembling: // -// var in []byte -// var client = scram.NewClient(sha1.New, user, pass) -// for client.Step(in) { -// out := client.Out() -// // send out to server -// in := serverOut -// } -// if client.Err() != nil { -// // auth failed -// } -// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } type Client struct { newHash func() hash.Hash @@ -67,14 +65,17 @@ type Client struct { serverNonce []byte saltedPass []byte authMsg bytes.Buffer + + // channel binding data used for SCRAM-SHA-256-PLUS + tlsServerEndPoint []byte + channelBinding bool } // NewClient returns a new SCRAM-* client with the provided hash algorithm. // // For SCRAM-SHA-256, for example, use: // -// client := scram.NewClient(sha256.New, user, pass) -// +// client := scram.NewClient(sha256.New, user, pass) func NewClient(newHash func() hash.Hash, user, pass string) *Client { c := &Client{ newHash: newHash, @@ -86,6 +87,16 @@ func NewClient(newHash func() hash.Hash, user, pass string) *Client { return c } +// Out returns the data to be sent to the server in the current step. +func (c *Client) WithTlsServerEndPoint(cb []byte) { + c.tlsServerEndPoint = cb +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) UseChannelBinding() { + c.channelBinding = true +} + // Out returns the data to be sent to the server in the current step. func (c *Client) Out() []byte { if c.out.Len() == 0 { @@ -143,7 +154,17 @@ func (c *Client) step1(in []byte) error { c.authMsg.WriteString(",r=") c.authMsg.Write(c.clientNonce) - c.out.WriteString("n,,") + if c.tlsServerEndPoint != nil && c.channelBinding { + // we support channel binding, and so does the server + c.out.WriteString("p=tls-server-end-point,,") + } else if c.tlsServerEndPoint != nil { + // we support channel binding, but the server doesn't. + c.out.WriteString("y,,") + } else { + // we do not support channel binding. + c.out.WriteString("n,,") + } + c.out.Write(c.authMsg.Bytes()) return nil } @@ -185,11 +206,34 @@ func (c *Client) step2(in []byte) error { } c.saltPassword(salt, iterCount) - c.authMsg.WriteString(",c=biws,r=") - c.authMsg.Write(c.serverNonce) + // channel binding: + c.authMsg.WriteString(",c=") + c.out.WriteString("c=") + + var mode string + if c.tlsServerEndPoint != nil && c.channelBinding { + // we support channel binding, and so does the server + data := []byte("p=tls-server-end-point,,") + data = append(data, c.tlsServerEndPoint...) + + mode = base64.StdEncoding.EncodeToString(data) + } else if c.tlsServerEndPoint != nil { + // we support channel binding, but the server doesn't. + mode = "eSws" + } else { + // we do not support channel binding. + mode = "biws" + } + c.authMsg.WriteString(mode) + c.out.WriteString(mode) + + // server nonce + c.authMsg.WriteString(",r=") + c.out.WriteString(",r=") - c.out.WriteString("c=biws,r=") + c.authMsg.Write(c.serverNonce) c.out.Write(c.serverNonce) + c.out.WriteString(",p=") c.out.Write(c.clientProof()) return nil diff --git a/ssl.go b/ssl.go index 36b61ba45..731e620b2 100644 --- a/ssl.go +++ b/ssl.go @@ -1,6 +1,7 @@ package pq import ( + "crypto" "crypto/tls" "crypto/x509" "io/ioutil" @@ -13,7 +14,7 @@ import ( // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. -func ssl(o values) (func(net.Conn) (net.Conn, error), error) { +func ssl(o values) (func(net.Conn) (*tls.Conn, error), error) { verifyCaOnly := false tlsConf := tls.Config{} switch mode := o["sslmode"]; mode { @@ -77,7 +78,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // also initiates renegotiations and cannot be reconfigured. tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient - return func(conn net.Conn) (net.Conn, error) { + return func(conn net.Conn) (*tls.Conn, error) { client := tls.Client(conn, &tlsConf) if verifyCaOnly { err := sslVerifyCertificateAuthority(client, &tlsConf) @@ -202,3 +203,28 @@ func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error _, err = certs[0].Verify(opts) return err } + +func tlsServerEndPoint(conn *tls.Conn) ([]byte, error) { + err := conn.Handshake() + if err != nil { + return nil, err + } + + cert := conn.ConnectionState().PeerCertificates[0] + + // choose the channel binding hash type + // Use the same hash type used for the certificate signature, except for MD5 and SHA-1 which + // use SHA256 + hashType := crypto.SHA256 + switch cert.SignatureAlgorithm { + case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS: + hashType = crypto.SHA384 + case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS: + hashType = crypto.SHA512 + } + + hasher := hashType.New() + _, _ = hasher.Write(cert.Raw) + data := hasher.Sum(nil) + return data, nil +} From 0d8ed4c5e4b552c6430c658445e07804d7973614 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 24 Nov 2024 12:27:08 +0000 Subject: [PATCH 2/3] add channel_binding require --- conn.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 74c169063..8f82890b7 100644 --- a/conn.go +++ b/conn.go @@ -1137,13 +1137,15 @@ func (cn *conn) ssl(o values) error { return err } - cb, err := tlsServerEndPoint(conn) - if err != nil { - return err + if o["channel_binding"] != "disable" { + cb, err := tlsServerEndPoint(conn) + if err != nil { + return err + } + cn.tlsServerEndPoint = cb } cn.c = conn - cn.tlsServerEndPoint = cb return err } @@ -1334,6 +1336,10 @@ func (cn *conn) auth(r *readBuf, o values) { errorf("SCRAM-SHA-256 protocol error") } + if o["channel_binding"] == "required" && selected != "SCRAM-SHA-256-PLUS" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + sc.Step(nil) if sc.Err() != nil { errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) From cff584e6d4d36d71ea6ee2a47c165bd41db45297 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 24 Nov 2024 12:28:31 +0000 Subject: [PATCH 3/3] add channel binding requirement checks to all other auth branches --- conn.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/conn.go b/conn.go index 8f82890b7..5b111d2b5 100644 --- a/conn.go +++ b/conn.go @@ -1223,8 +1223,16 @@ func (cn *conn) startup(o values) { func (cn *conn) auth(r *readBuf, o values) { switch code := r.int32(); code { case 0: + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + // OK case 3: + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + w := cn.writeBuf('p') w.string(o["password"]) cn.send(w) @@ -1238,6 +1246,10 @@ func (cn *conn) auth(r *readBuf, o values) { errorf("unexpected authentication response: %q", t) } case 5: + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + s := string(r.next(4)) w := cn.writeBuf('p') w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) @@ -1252,6 +1264,10 @@ func (cn *conn) auth(r *readBuf, o values) { errorf("unexpected authentication response: %q", t) } case 7: // GSSAPI, startup + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + if newGss == nil { errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") } @@ -1287,6 +1303,9 @@ func (cn *conn) auth(r *readBuf, o values) { cn.gss = cli case 8: // GSSAPI continue + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } if cn.gss == nil { errorf("GSSAPI protocol error")