Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support tls-server-end-point channel binding for SCRAM-SHA-256 #1181

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions buf.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ func (b *readBuf) string() string {
return string(s)
}

func (b *readBuf) strings() []string {
ss := []string{}
for (*b)[0] != 0 {
i := bytes.IndexByte(*b, 0)
if i < 0 {
errorf("invalid message format; expected string terminator")
}
s := (*b)[:i]
*b = (*b)[i+1:]
ss = append(ss, string(s))
}
return ss
}

func (b *readBuf) next(n int) (v []byte) {
v = (*b)[:n]
*b = (*b)[n:]
Expand Down
73 changes: 71 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ type conn struct {

// GSSAPI context
gss GSS

// channel binding data used for SCRAM-SHA-256-PLUS
tlsServerEndPoint []byte
}

type syncErr struct {
Expand Down Expand Up @@ -1129,7 +1132,20 @@ func (cn *conn) ssl(o values) error {
return ErrSSLNotSupported
}

cn.c, err = upgrade(cn.c)
conn, err := upgrade(cn.c)
if err != nil {
return err
}

if o["channel_binding"] != "disable" {
cb, err := tlsServerEndPoint(conn)
if err != nil {
return err
}
cn.tlsServerEndPoint = cb
}

cn.c = conn
return err
}

Expand Down Expand Up @@ -1207,8 +1223,16 @@ func (cn *conn) startup(o values) {
func (cn *conn) auth(r *readBuf, o values) {
switch code := r.int32(); code {
case 0:
if o["channel_binding"] == "required" {
errorf("SCRAM-SHA-256 protocol error: channel binding required")
}

// OK
case 3:
if o["channel_binding"] == "required" {
errorf("SCRAM-SHA-256 protocol error: channel binding required")
}

w := cn.writeBuf('p')
w.string(o["password"])
cn.send(w)
Expand All @@ -1222,6 +1246,10 @@ func (cn *conn) auth(r *readBuf, o values) {
errorf("unexpected authentication response: %q", t)
}
case 5:
if o["channel_binding"] == "required" {
errorf("SCRAM-SHA-256 protocol error: channel binding required")
}

s := string(r.next(4))
w := cn.writeBuf('p')
w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
Expand All @@ -1236,6 +1264,10 @@ func (cn *conn) auth(r *readBuf, o values) {
errorf("unexpected authentication response: %q", t)
}
case 7: // GSSAPI, startup
if o["channel_binding"] == "required" {
errorf("SCRAM-SHA-256 protocol error: channel binding required")
}

if newGss == nil {
errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
}
Expand Down Expand Up @@ -1271,6 +1303,9 @@ func (cn *conn) auth(r *readBuf, o values) {
cn.gss = cli

case 8: // GSSAPI continue
if o["channel_binding"] == "required" {
errorf("SCRAM-SHA-256 protocol error: channel binding required")
}

if cn.gss == nil {
errorf("GSSAPI protocol error")
Expand All @@ -1289,15 +1324,49 @@ func (cn *conn) auth(r *readBuf, o values) {
// from the server..

case 10:
supported := r.strings()

scramSha256 := false
scramSha256Plus := false
for _, s := range supported {
switch s {
case "SCRAM-SHA-256":
scramSha256 = true
case "SCRAM-SHA-256-PLUS":
scramSha256Plus = true
}
}

sc := scram.NewClient(sha256.New, o["user"], o["password"])

// channel binding is supported by the client
if cn.tlsServerEndPoint != nil {
sc.WithTlsServerEndPoint(cn.tlsServerEndPoint)
}

var selected string
// SCRAM-SHA-256-PLUS always takes preference.
if cn.tlsServerEndPoint != nil && scramSha256Plus {
sc.UseChannelBinding()
selected = "SCRAM-SHA-256-PLUS"
} else if scramSha256 {
selected = "SCRAM-SHA-256"
} else {
errorf("SCRAM-SHA-256 protocol error")
}

if o["channel_binding"] == "required" && selected != "SCRAM-SHA-256-PLUS" {
errorf("SCRAM-SHA-256 protocol error: channel binding required")
}

sc.Step(nil)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut := sc.Out()

w := cn.writeBuf('p')
w.string("SCRAM-SHA-256")
w.string(selected)
w.int32(len(scOut))
w.bytes(scOut)
cn.send(w)
Expand Down
80 changes: 62 additions & 18 deletions scram/scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
// Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802.
//
// http://tools.ietf.org/html/rfc5802
//
package scram

import (
Expand All @@ -43,17 +42,16 @@ import (
//
// A Client may be used within a SASL conversation with logic resembling:
//
// var in []byte
// var client = scram.NewClient(sha1.New, user, pass)
// for client.Step(in) {
// out := client.Out()
// // send out to server
// in := serverOut
// }
// if client.Err() != nil {
// // auth failed
// }
//
// var in []byte
// var client = scram.NewClient(sha1.New, user, pass)
// for client.Step(in) {
// out := client.Out()
// // send out to server
// in := serverOut
// }
// if client.Err() != nil {
// // auth failed
// }
type Client struct {
newHash func() hash.Hash

Expand All @@ -67,14 +65,17 @@ type Client struct {
serverNonce []byte
saltedPass []byte
authMsg bytes.Buffer

// channel binding data used for SCRAM-SHA-256-PLUS
tlsServerEndPoint []byte
channelBinding bool
}

// NewClient returns a new SCRAM-* client with the provided hash algorithm.
//
// For SCRAM-SHA-256, for example, use:
//
// client := scram.NewClient(sha256.New, user, pass)
//
// client := scram.NewClient(sha256.New, user, pass)
func NewClient(newHash func() hash.Hash, user, pass string) *Client {
c := &Client{
newHash: newHash,
Expand All @@ -86,6 +87,16 @@ func NewClient(newHash func() hash.Hash, user, pass string) *Client {
return c
}

// Out returns the data to be sent to the server in the current step.
func (c *Client) WithTlsServerEndPoint(cb []byte) {
c.tlsServerEndPoint = cb
}

// Out returns the data to be sent to the server in the current step.
func (c *Client) UseChannelBinding() {
c.channelBinding = true
}

// Out returns the data to be sent to the server in the current step.
func (c *Client) Out() []byte {
if c.out.Len() == 0 {
Expand Down Expand Up @@ -143,7 +154,17 @@ func (c *Client) step1(in []byte) error {
c.authMsg.WriteString(",r=")
c.authMsg.Write(c.clientNonce)

c.out.WriteString("n,,")
if c.tlsServerEndPoint != nil && c.channelBinding {
// we support channel binding, and so does the server
c.out.WriteString("p=tls-server-end-point,,")
} else if c.tlsServerEndPoint != nil {
// we support channel binding, but the server doesn't.
c.out.WriteString("y,,")
} else {
// we do not support channel binding.
c.out.WriteString("n,,")
}

c.out.Write(c.authMsg.Bytes())
return nil
}
Expand Down Expand Up @@ -185,11 +206,34 @@ func (c *Client) step2(in []byte) error {
}
c.saltPassword(salt, iterCount)

c.authMsg.WriteString(",c=biws,r=")
c.authMsg.Write(c.serverNonce)
// channel binding:
c.authMsg.WriteString(",c=")
c.out.WriteString("c=")

var mode string
if c.tlsServerEndPoint != nil && c.channelBinding {
// we support channel binding, and so does the server
data := []byte("p=tls-server-end-point,,")
data = append(data, c.tlsServerEndPoint...)

mode = base64.StdEncoding.EncodeToString(data)
} else if c.tlsServerEndPoint != nil {
// we support channel binding, but the server doesn't.
mode = "eSws"
} else {
// we do not support channel binding.
mode = "biws"
}
c.authMsg.WriteString(mode)
c.out.WriteString(mode)

// server nonce
c.authMsg.WriteString(",r=")
c.out.WriteString(",r=")

c.out.WriteString("c=biws,r=")
c.authMsg.Write(c.serverNonce)
c.out.Write(c.serverNonce)

c.out.WriteString(",p=")
c.out.Write(c.clientProof())
return nil
Expand Down
30 changes: 28 additions & 2 deletions ssl.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pq

import (
"crypto"
"crypto/tls"
"crypto/x509"
"io/ioutil"
Expand All @@ -13,7 +14,7 @@ import (

// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
// related settings. The function is nil when no upgrade should take place.
func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
func ssl(o values) (func(net.Conn) (*tls.Conn, error), error) {
verifyCaOnly := false
tlsConf := tls.Config{}
switch mode := o["sslmode"]; mode {
Expand Down Expand Up @@ -77,7 +78,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
// also initiates renegotiations and cannot be reconfigured.
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient

return func(conn net.Conn) (net.Conn, error) {
return func(conn net.Conn) (*tls.Conn, error) {
client := tls.Client(conn, &tlsConf)
if verifyCaOnly {
err := sslVerifyCertificateAuthority(client, &tlsConf)
Expand Down Expand Up @@ -202,3 +203,28 @@ func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error
_, err = certs[0].Verify(opts)
return err
}

func tlsServerEndPoint(conn *tls.Conn) ([]byte, error) {
err := conn.Handshake()
if err != nil {
return nil, err
}

cert := conn.ConnectionState().PeerCertificates[0]

// choose the channel binding hash type
// Use the same hash type used for the certificate signature, except for MD5 and SHA-1 which
// use SHA256
hashType := crypto.SHA256
switch cert.SignatureAlgorithm {
case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS:
hashType = crypto.SHA384
case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS:
hashType = crypto.SHA512
}

hasher := hashType.New()
_, _ = hasher.Write(cert.Raw)
data := hasher.Sum(nil)
return data, nil
}