From 26208313da4c28110bddaf9d19180d64d2295fc0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Apr 2024 15:14:21 +0000 Subject: [PATCH] build(deps): bump github.com/jackc/pgproto3/v2 from 2.0.6 to 2.3.3 Bumps [github.com/jackc/pgproto3/v2](https://github.com/jackc/pgproto3) from 2.0.6 to 2.3.3. - [Commits](https://github.com/jackc/pgproto3/compare/v2.0.6...v2.3.3) --- updated-dependencies: - dependency-name: github.com/jackc/pgproto3/v2 dependency-type: indirect ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 3 +- vendor/github.com/jackc/pgproto3/v2/README.md | 6 + .../v2/authentication_cleartext_password.go | 20 +- .../jackc/pgproto3/v2/authentication_gss.go | 58 ++++ .../v2/authentication_gss_continue.go | 67 +++++ .../v2/authentication_md5_password.go | 41 ++- .../jackc/pgproto3/v2/authentication_ok.go | 20 +- .../jackc/pgproto3/v2/authentication_sasl.go | 23 +- .../v2/authentication_sasl_continue.go | 41 ++- .../pgproto3/v2/authentication_sasl_final.go | 41 ++- .../github.com/jackc/pgproto3/v2/backend.go | 114 +++++-- .../jackc/pgproto3/v2/backend_key_data.go | 7 +- vendor/github.com/jackc/pgproto3/v2/bind.go | 54 +++- .../jackc/pgproto3/v2/bind_complete.go | 4 +- .../jackc/pgproto3/v2/cancel_request.go | 4 +- vendor/github.com/jackc/pgproto3/v2/close.go | 39 ++- .../jackc/pgproto3/v2/close_complete.go | 4 +- .../jackc/pgproto3/v2/command_complete.go | 32 +- .../jackc/pgproto3/v2/copy_both_response.go | 41 ++- .../github.com/jackc/pgproto3/v2/copy_data.go | 27 +- .../github.com/jackc/pgproto3/v2/copy_done.go | 4 +- .../github.com/jackc/pgproto3/v2/copy_fail.go | 14 +- .../jackc/pgproto3/v2/copy_in_response.go | 39 ++- .../jackc/pgproto3/v2/copy_out_response.go | 39 ++- .../github.com/jackc/pgproto3/v2/data_row.go | 40 ++- .../github.com/jackc/pgproto3/v2/describe.go | 38 ++- .../jackc/pgproto3/v2/empty_query_response.go | 4 +- .../jackc/pgproto3/v2/error_response.go | 278 ++++++++++++------ .../github.com/jackc/pgproto3/v2/execute.go | 13 +- vendor/github.com/jackc/pgproto3/v2/flush.go | 4 +- .../github.com/jackc/pgproto3/v2/frontend.go | 38 ++- .../jackc/pgproto3/v2/function_call.go | 102 +++++++ .../pgproto3/v2/function_call_response.go | 28 +- .../jackc/pgproto3/v2/gss_enc_request.go | 4 +- .../jackc/pgproto3/v2/gss_response.go | 46 +++ .../github.com/jackc/pgproto3/v2/no_data.go | 4 +- .../jackc/pgproto3/v2/notice_response.go | 6 +- .../pgproto3/v2/notification_response.go | 12 +- .../pgproto3/v2/parameter_description.go | 15 +- .../jackc/pgproto3/v2/parameter_status.go | 14 +- vendor/github.com/jackc/pgproto3/v2/parse.go | 15 +- .../jackc/pgproto3/v2/parse_complete.go | 4 +- .../jackc/pgproto3/v2/password_message.go | 14 +- .../github.com/jackc/pgproto3/v2/pgproto3.go | 53 +++- .../jackc/pgproto3/v2/portal_suspended.go | 4 +- vendor/github.com/jackc/pgproto3/v2/query.go | 11 +- .../jackc/pgproto3/v2/ready_for_query.go | 25 +- .../jackc/pgproto3/v2/row_description.go | 46 ++- .../pgproto3/v2/sasl_initial_response.go | 32 +- .../jackc/pgproto3/v2/sasl_response.go | 26 +- .../jackc/pgproto3/v2/ssl_request.go | 4 +- .../jackc/pgproto3/v2/startup_message.go | 6 +- vendor/github.com/jackc/pgproto3/v2/sync.go | 4 +- .../github.com/jackc/pgproto3/v2/terminate.go | 4 +- vendor/modules.txt | 2 +- 56 files changed, 1277 insertions(+), 363 deletions(-) create mode 100644 vendor/github.com/jackc/pgproto3/v2/authentication_gss.go create mode 100644 vendor/github.com/jackc/pgproto3/v2/authentication_gss_continue.go create mode 100644 vendor/github.com/jackc/pgproto3/v2/function_call.go create mode 100644 vendor/github.com/jackc/pgproto3/v2/gss_response.go diff --git a/go.mod b/go.mod index 961ce6678..86a626c98 100644 --- a/go.mod +++ b/go.mod @@ -68,7 +68,7 @@ require ( github.com/jackc/pgconn v1.8.1 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgproto3/v2 v2.0.6 // indirect + github.com/jackc/pgproto3/v2 v2.3.3 // indirect github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect github.com/jackc/pgtype v1.7.0 // indirect github.com/jackc/pgx/v4 v4.11.0 // indirect diff --git a/go.sum b/go.sum index e1fcee307..0569320a2 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,9 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUOag= +github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= diff --git a/vendor/github.com/jackc/pgproto3/v2/README.md b/vendor/github.com/jackc/pgproto3/v2/README.md index 565b3efd5..77a31700a 100644 --- a/vendor/github.com/jackc/pgproto3/v2/README.md +++ b/vendor/github.com/jackc/pgproto3/v2/README.md @@ -1,6 +1,12 @@ [![](https://godoc.org/github.com/jackc/pgproto3?status.svg)](https://godoc.org/github.com/jackc/pgproto3) [![Build Status](https://travis-ci.org/jackc/pgproto3.svg)](https://travis-ci.org/jackc/pgproto3) +--- + +This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository. + +--- + # pgproto3 Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_cleartext_password.go b/vendor/github.com/jackc/pgproto3/v2/authentication_cleartext_password.go index dd82c7a77..1ec219bc3 100644 --- a/vendor/github.com/jackc/pgproto3/v2/authentication_cleartext_password.go +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_cleartext_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -14,6 +15,9 @@ type AuthenticationCleartextPassword struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationCleartextPassword) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationCleartextPassword) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { @@ -31,9 +35,17 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) - return dst + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationCleartextPassword", + }) } diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_gss.go b/vendor/github.com/jackc/pgproto3/v2/authentication_gss.go new file mode 100644 index 000000000..425be6efa --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_gss.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type AuthenticationGSS struct{} + +func (a *AuthenticationGSS) Backend() {} + +func (a *AuthenticationGSS) AuthenticationResponse() {} + +func (a *AuthenticationGSS) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSS { + return errors.New("bad auth type") + } + return nil +} + +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeGSS) + return finishMessage(dst, sp) +} + +func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSS", + }) +} + +func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_gss_continue.go b/vendor/github.com/jackc/pgproto3/v2/authentication_gss_continue.go new file mode 100644 index 000000000..42a70daf9 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_gss_continue.go @@ -0,0 +1,67 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type AuthenticationGSSContinue struct { + Data []byte +} + +func (a *AuthenticationGSSContinue) Backend() {} + +func (a *AuthenticationGSSContinue) AuthenticationResponse() {} + +func (a *AuthenticationGSSContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSSCont { + return errors.New("bad auth type") + } + + a.Data = src[4:] + return nil +} + +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeGSSCont) + dst = append(dst, a.Data...) + return finishMessage(dst, sp) +} + +func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSSContinue", + Data: a.Data, + }) +} + +func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + a.Data = msg.Data + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_md5_password.go b/vendor/github.com/jackc/pgproto3/v2/authentication_md5_password.go index d505d2649..9c0f5ee08 100644 --- a/vendor/github.com/jackc/pgproto3/v2/authentication_md5_password.go +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_md5_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -15,6 +16,9 @@ type AuthenticationMD5Password struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationMD5Password) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationMD5Password) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationMD5Password) Decode(src []byte) error { @@ -34,10 +38,39 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 12) +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeMD5Password) dst = append(dst, src.Salt[:]...) - return dst + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Salt [4]byte + }{ + Type: "AuthenticationMD5Password", + Salt: src.Salt, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Salt [4]byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Salt = msg.Salt + return nil } diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_ok.go b/vendor/github.com/jackc/pgproto3/v2/authentication_ok.go index 7b13c6e01..021f820fe 100644 --- a/vendor/github.com/jackc/pgproto3/v2/authentication_ok.go +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_ok.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -14,6 +15,9 @@ type AuthenticationOk struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationOk) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationOk) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationOk) Decode(src []byte) error { @@ -31,9 +35,17 @@ func (dst *AuthenticationOk) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationOk) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeOk) - return dst + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationOk) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationOK", + }) } diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl.go b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl.go index c57ae32de..b56461cd3 100644 --- a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl.go +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -16,6 +17,9 @@ type AuthenticationSASL struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASL) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASL) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASL) Decode(src []byte) error { @@ -42,10 +46,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASL) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASL) for _, s := range src.AuthMechanisms { @@ -54,7 +56,16 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + return finishMessage(dst, sp) +} - return dst +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASL) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanisms []string + }{ + Type: "AuthenticationSASL", + AuthMechanisms: src.AuthMechanisms, + }) } diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_continue.go b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_continue.go index 1b918a6ef..d405b1293 100644 --- a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_continue.go +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_continue.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -15,6 +16,9 @@ type AuthenticationSASLContinue struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLContinue) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLContinue) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLContinue) Decode(src []byte) error { @@ -34,15 +38,38 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) - dst = append(dst, src.Data...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLContinue", + Data: string(src.Data), + }) +} - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } - return dst + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil } diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_final.go b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_final.go index 11d356600..c34ac4e6b 100644 --- a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_final.go +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_final.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -15,6 +16,9 @@ type AuthenticationSASLFinal struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLFinal) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLFinal) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLFinal) Decode(src []byte) error { @@ -34,15 +38,38 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) - dst = append(dst, src.Data...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Unmarshaler. +func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLFinal", + Data: string(src.Data), + }) +} - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } - return dst + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil } diff --git a/vendor/github.com/jackc/pgproto3/v2/backend.go b/vendor/github.com/jackc/pgproto3/v2/backend.go index 1f854c693..6eabcd85f 100644 --- a/vendor/github.com/jackc/pgproto3/v2/backend.go +++ b/vendor/github.com/jackc/pgproto3/v2/backend.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "errors" "fmt" "io" ) @@ -12,27 +13,35 @@ type Backend struct { w io.Writer // Frontend message flyweights - bind Bind - cancelRequest CancelRequest - _close Close - copyFail CopyFail - describe Describe - execute Execute - flush Flush - gssEncRequest GSSEncRequest - parse Parse - passwordMessage PasswordMessage - query Query - sslRequest SSLRequest - startupMessage StartupMessage - sync Sync - terminate Terminate + bind Bind + cancelRequest CancelRequest + _close Close + copyFail CopyFail + copyData CopyData + copyDone CopyDone + describe Describe + execute Execute + flush Flush + functionCall FunctionCall + gssEncRequest GSSEncRequest + parse Parse + query Query + sslRequest SSLRequest + startupMessage StartupMessage + sync Sync + terminate Terminate bodyLen int msgType byte partialMsg bool + authType uint32 } +const ( + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. +) + // NewBackend creates a new Backend. func NewBackend(cr ChunkReader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} @@ -40,7 +49,12 @@ func NewBackend(cr ChunkReader, w io.Writer) *Backend { // Send sends a message to the frontend. func (b *Backend) Send(msg BackendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) + buf, err := msg.Encode(nil) + if err != nil { + return err + } + + _, err = b.w.Write(buf) return err } @@ -54,9 +68,13 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { } msgSize := int(binary.BigEndian.Uint32(buf) - 4) + if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { + return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) + } + buf, err = b.cr.Next(msgSize) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } code := binary.BigEndian.Uint32(buf) @@ -96,12 +114,15 @@ func (b *Backend) Receive() (FrontendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.msgType = header[0] b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 b.partialMsg = true + if b.bodyLen < 0 { + return nil, errors.New("invalid message with negative body length received") + } } var msg FrontendMessage @@ -114,14 +135,34 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.describe case 'E': msg = &b.execute + case 'F': + msg = &b.functionCall case 'f': msg = &b.copyFail + case 'd': + msg = &b.copyData + case 'c': + msg = &b.copyDone case 'H': msg = &b.flush case 'P': msg = &b.parse case 'p': - msg = &b.passwordMessage + switch b.authType { + case AuthTypeSASL: + msg = &SASLInitialResponse{} + case AuthTypeSASLContinue: + msg = &SASLResponse{} + case AuthTypeSASLFinal: + msg = &SASLResponse{} + case AuthTypeGSS, AuthTypeGSSCont: + msg = &GSSResponse{} + case AuthTypeCleartextPassword, AuthTypeMD5Password: + fallthrough + default: + // to maintain backwards compatability + msg = &PasswordMessage{} + } case 'Q': msg = &b.query case 'S': @@ -134,7 +175,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { msgBody, err := b.cr.Next(b.bodyLen) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.partialMsg = false @@ -142,3 +183,36 @@ func (b *Backend) Receive() (FrontendMessage, error) { err = msg.Decode(msgBody) return msg, err } + +// SetAuthType sets the authentication type in the backend. +// Since multiple message types can start with 'p', SetAuthType allows +// contextual identification of FrontendMessages. For example, in the +// PG message flow documentation for PasswordMessage: +// +// Byte1('p') +// +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. +// +// Since the Frontend does not know about the state of a backend, it is important +// to call SetAuthType() after an authentication request is received by the Frontend. +func (b *Backend) SetAuthType(authType uint32) error { + switch authType { + case AuthTypeOk, + AuthTypeCleartextPassword, + AuthTypeMD5Password, + AuthTypeSCMCreds, + AuthTypeGSS, + AuthTypeGSSCont, + AuthTypeSSPI, + AuthTypeSASL, + AuthTypeSASLContinue, + AuthTypeSASLFinal: + b.authType = authType + default: + return fmt.Errorf("authType not recognized: %d", authType) + } + + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/backend_key_data.go b/vendor/github.com/jackc/pgproto3/v2/backend_key_data.go index ca20dd259..0a3d5e55f 100644 --- a/vendor/github.com/jackc/pgproto3/v2/backend_key_data.go +++ b/vendor/github.com/jackc/pgproto3/v2/backend_key_data.go @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BackendKeyData) Encode(dst []byte) []byte { - dst = append(dst, 'K') - dst = pgio.AppendUint32(dst, 12) +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/bind.go b/vendor/github.com/jackc/pgproto3/v2/bind.go index 52372095d..dd5503b11 100644 --- a/vendor/github.com/jackc/pgproto3/v2/bind.go +++ b/vendor/github.com/jackc/pgproto3/v2/bind.go @@ -5,6 +5,9 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "fmt" + "math" "github.com/jackc/pgio" ) @@ -107,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Bind) Encode(dst []byte) []byte { - dst = append(dst, 'B') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Bind) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'B') dst = append(dst, src.DestinationPortal...) dst = append(dst, 0) dst = append(dst, src.PreparedStatement...) dst = append(dst, 0) + if len(src.ParameterFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many parameter format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { dst = pgio.AppendInt16(dst, fc) } + if len(src.Parameters) > math.MaxUint16 { + return nil, errors.New("too many parameters") + } dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { @@ -133,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte { dst = append(dst, p...) } + if len(src.ResultFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many result format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { dst = pgio.AppendInt16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -181,3 +189,35 @@ func (src Bind) MarshalJSON() ([]byte, error) { ResultFormatCodes: src.ResultFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Bind) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.DestinationPortal = msg.DestinationPortal + dst.PreparedStatement = msg.PreparedStatement + dst.ParameterFormatCodes = msg.ParameterFormatCodes + dst.Parameters = make([][]byte, len(msg.Parameters)) + dst.ResultFormatCodes = msg.ResultFormatCodes + for n, parameter := range msg.Parameters { + dst.Parameters[n], err = getValueFromJSON(parameter) + if err != nil { + return fmt.Errorf("cannot get param %d: %w", n, err) + } + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/bind_complete.go b/vendor/github.com/jackc/pgproto3/v2/bind_complete.go index 3be256c89..bacf30d88 100644 --- a/vendor/github.com/jackc/pgproto3/v2/bind_complete.go +++ b/vendor/github.com/jackc/pgproto3/v2/bind_complete.go @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BindComplete) Encode(dst []byte) []byte { - return append(dst, '2', 0, 0, 0, 4) +func (src *BindComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '2', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/cancel_request.go b/vendor/github.com/jackc/pgproto3/v2/cancel_request.go index 942e404be..76acb3fc9 100644 --- a/vendor/github.com/jackc/pgproto3/v2/cancel_request.go +++ b/vendor/github.com/jackc/pgproto3/v2/cancel_request.go @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *CancelRequest) Encode(dst []byte) []byte { +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 16) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/close.go b/vendor/github.com/jackc/pgproto3/v2/close.go index 382969093..0b50f27cb 100644 --- a/vendor/github.com/jackc/pgproto3/v2/close.go +++ b/vendor/github.com/jackc/pgproto3/v2/close.go @@ -3,8 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgio" + "errors" ) type Close struct { @@ -36,18 +35,12 @@ func (dst *Close) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Close) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Close) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -62,3 +55,27 @@ func (src Close) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Close) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Close.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/close_complete.go b/vendor/github.com/jackc/pgproto3/v2/close_complete.go index 1d7b8f085..833f7a12c 100644 --- a/vendor/github.com/jackc/pgproto3/v2/close_complete.go +++ b/vendor/github.com/jackc/pgproto3/v2/close_complete.go @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CloseComplete) Encode(dst []byte) []byte { - return append(dst, '3', 0, 0, 0, 4) +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '3', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/command_complete.go b/vendor/github.com/jackc/pgproto3/v2/command_complete.go index b5106fdaf..9d822064d 100644 --- a/vendor/github.com/jackc/pgproto3/v2/command_complete.go +++ b/vendor/github.com/jackc/pgproto3/v2/command_complete.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgio" ) type CommandComplete struct { @@ -28,17 +26,11 @@ func (dst *CommandComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CommandComplete) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.CommandTag...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -51,3 +43,21 @@ func (src CommandComplete) MarshalJSON() ([]byte, error) { CommandTag: string(src.CommandTag), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CommandComplete) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + CommandTag string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.CommandTag = []byte(msg.CommandTag) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_both_response.go b/vendor/github.com/jackc/pgproto3/v2/copy_both_response.go index 2d58f820e..4bf3ef325 100644 --- a/vendor/github.com/jackc/pgproto3/v2/copy_both_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/copy_both_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgio" ) @@ -43,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyBothResponse) Encode(dst []byte) []byte { - dst = append(dst, 'W') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'W') + dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -68,3 +69,27 @@ func (src CopyBothResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyBothResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_data.go b/vendor/github.com/jackc/pgproto3/v2/copy_data.go index 7d6002fe0..89ecdd4dd 100644 --- a/vendor/github.com/jackc/pgproto3/v2/copy_data.go +++ b/vendor/github.com/jackc/pgproto3/v2/copy_data.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgio" ) type CopyData struct { @@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyData) Encode(dst []byte) []byte { - dst = append(dst, 'd') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) +func (src *CopyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'd') dst = append(dst, src.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -42,3 +39,21 @@ func (src CopyData) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyData) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_done.go b/vendor/github.com/jackc/pgproto3/v2/copy_done.go index 0e13282bf..040814dbd 100644 --- a/vendor/github.com/jackc/pgproto3/v2/copy_done.go +++ b/vendor/github.com/jackc/pgproto3/v2/copy_done.go @@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyDone) Encode(dst []byte) []byte { - return append(dst, 'c', 0, 0, 0, 4) +func (src *CopyDone) Encode(dst []byte) ([]byte, error) { + return append(dst, 'c', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_fail.go b/vendor/github.com/jackc/pgproto3/v2/copy_fail.go index 78ff0b30b..72a85fd09 100644 --- a/vendor/github.com/jackc/pgproto3/v2/copy_fail.go +++ b/vendor/github.com/jackc/pgproto3/v2/copy_fail.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgio" ) type CopyFail struct { @@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyFail) Encode(dst []byte) []byte { - dst = append(dst, 'f') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CopyFail) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'f') dst = append(dst, src.Message...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_in_response.go b/vendor/github.com/jackc/pgproto3/v2/copy_in_response.go index 5f2595b87..bfc3ee073 100644 --- a/vendor/github.com/jackc/pgproto3/v2/copy_in_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/copy_in_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgio" ) @@ -43,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyInResponse) Encode(dst []byte) []byte { - dst = append(dst, 'G') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'G') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -69,3 +70,27 @@ func (src CopyInResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyInResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyInResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_out_response.go b/vendor/github.com/jackc/pgproto3/v2/copy_out_response.go index 8538dfc7d..265e35f93 100644 --- a/vendor/github.com/jackc/pgproto3/v2/copy_out_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/copy_out_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgio" ) @@ -42,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyOutResponse) Encode(dst []byte) []byte { - dst = append(dst, 'H') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'H') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -69,3 +70,27 @@ func (src CopyOutResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyOutResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/data_row.go b/vendor/github.com/jackc/pgproto3/v2/data_row.go index 5fa3c5d8c..d755515cc 100644 --- a/vendor/github.com/jackc/pgproto3/v2/data_row.go +++ b/vendor/github.com/jackc/pgproto3/v2/data_row.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "math" "github.com/jackc/pgio" ) @@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *DataRow) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *DataRow) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') + if len(src.Values) > math.MaxUint16 { + return nil, errors.New("too many values") + } dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { @@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte { dst = append(dst, v...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -115,3 +116,28 @@ func (src DataRow) MarshalJSON() ([]byte, error) { Values: formattedValues, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *DataRow) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Values []map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Values = make([][]byte, len(msg.Values)) + for n, parameter := range msg.Values { + var err error + dst.Values[n], err = getValueFromJSON(parameter) + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/describe.go b/vendor/github.com/jackc/pgproto3/v2/describe.go index 308f582e7..89feff215 100644 --- a/vendor/github.com/jackc/pgproto3/v2/describe.go +++ b/vendor/github.com/jackc/pgproto3/v2/describe.go @@ -3,8 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgio" + "errors" ) type Describe struct { @@ -36,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Describe) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Describe) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -62,3 +55,26 @@ func (src Describe) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Describe) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Describe.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/empty_query_response.go b/vendor/github.com/jackc/pgproto3/v2/empty_query_response.go index 2b85e744b..cb6cca073 100644 --- a/vendor/github.com/jackc/pgproto3/v2/empty_query_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/empty_query_response.go @@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *EmptyQueryResponse) Encode(dst []byte) []byte { - return append(dst, 'I', 0, 0, 0, 4) +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) { + return append(dst, 'I', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/error_response.go b/vendor/github.com/jackc/pgproto3/v2/error_response.go index d444798bc..6ef9bd061 100644 --- a/vendor/github.com/jackc/pgproto3/v2/error_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/error_response.go @@ -2,28 +2,29 @@ package pgproto3 import ( "bytes" - "encoding/binary" + "encoding/json" "strconv" ) type ErrorResponse struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string UnknownFields map[byte]string } @@ -56,6 +57,8 @@ func (dst *ErrorResponse) Decode(src []byte) error { switch k { case 'S': dst.Severity = v + case 'V': + dst.SeverityUnlocalized = v case 'C': dst.Code = v case 'M': @@ -107,112 +110,217 @@ func (dst *ErrorResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ErrorResponse) Encode(dst []byte) []byte { - return append(dst, src.marshalBinary('E')...) +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') + dst = src.appendFields(dst) + return finishMessage(dst, sp) } -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte(typeByte) - buf.Write(bigEndian.Uint32(0)) - +func (src *ErrorResponse) appendFields(dst []byte) []byte { if src.Severity != "" { - buf.WriteByte('S') - buf.WriteString(src.Severity) - buf.WriteByte(0) + dst = append(dst, 'S') + dst = append(dst, src.Severity...) + dst = append(dst, 0) + } + if src.SeverityUnlocalized != "" { + dst = append(dst, 'V') + dst = append(dst, src.SeverityUnlocalized...) + dst = append(dst, 0) } if src.Code != "" { - buf.WriteByte('C') - buf.WriteString(src.Code) - buf.WriteByte(0) + dst = append(dst, 'C') + dst = append(dst, src.Code...) + dst = append(dst, 0) } if src.Message != "" { - buf.WriteByte('M') - buf.WriteString(src.Message) - buf.WriteByte(0) + dst = append(dst, 'M') + dst = append(dst, src.Message...) + dst = append(dst, 0) } if src.Detail != "" { - buf.WriteByte('D') - buf.WriteString(src.Detail) - buf.WriteByte(0) + dst = append(dst, 'D') + dst = append(dst, src.Detail...) + dst = append(dst, 0) } if src.Hint != "" { - buf.WriteByte('H') - buf.WriteString(src.Hint) - buf.WriteByte(0) + dst = append(dst, 'H') + dst = append(dst, src.Hint...) + dst = append(dst, 0) } if src.Position != 0 { - buf.WriteByte('P') - buf.WriteString(strconv.Itoa(int(src.Position))) - buf.WriteByte(0) + dst = append(dst, 'P') + dst = append(dst, strconv.Itoa(int(src.Position))...) + dst = append(dst, 0) } if src.InternalPosition != 0 { - buf.WriteByte('p') - buf.WriteString(strconv.Itoa(int(src.InternalPosition))) - buf.WriteByte(0) + dst = append(dst, 'p') + dst = append(dst, strconv.Itoa(int(src.InternalPosition))...) + dst = append(dst, 0) } if src.InternalQuery != "" { - buf.WriteByte('q') - buf.WriteString(src.InternalQuery) - buf.WriteByte(0) + dst = append(dst, 'q') + dst = append(dst, src.InternalQuery...) + dst = append(dst, 0) } if src.Where != "" { - buf.WriteByte('W') - buf.WriteString(src.Where) - buf.WriteByte(0) + dst = append(dst, 'W') + dst = append(dst, src.Where...) + dst = append(dst, 0) } if src.SchemaName != "" { - buf.WriteByte('s') - buf.WriteString(src.SchemaName) - buf.WriteByte(0) + dst = append(dst, 's') + dst = append(dst, src.SchemaName...) + dst = append(dst, 0) } if src.TableName != "" { - buf.WriteByte('t') - buf.WriteString(src.TableName) - buf.WriteByte(0) + dst = append(dst, 't') + dst = append(dst, src.TableName...) + dst = append(dst, 0) } if src.ColumnName != "" { - buf.WriteByte('c') - buf.WriteString(src.ColumnName) - buf.WriteByte(0) + dst = append(dst, 'c') + dst = append(dst, src.ColumnName...) + dst = append(dst, 0) } if src.DataTypeName != "" { - buf.WriteByte('d') - buf.WriteString(src.DataTypeName) - buf.WriteByte(0) + dst = append(dst, 'd') + dst = append(dst, src.DataTypeName...) + dst = append(dst, 0) } if src.ConstraintName != "" { - buf.WriteByte('n') - buf.WriteString(src.ConstraintName) - buf.WriteByte(0) + dst = append(dst, 'n') + dst = append(dst, src.ConstraintName...) + dst = append(dst, 0) } if src.File != "" { - buf.WriteByte('F') - buf.WriteString(src.File) - buf.WriteByte(0) + dst = append(dst, 'F') + dst = append(dst, src.File...) + dst = append(dst, 0) } if src.Line != 0 { - buf.WriteByte('L') - buf.WriteString(strconv.Itoa(int(src.Line))) - buf.WriteByte(0) + dst = append(dst, 'L') + dst = append(dst, strconv.Itoa(int(src.Line))...) + dst = append(dst, 0) } if src.Routine != "" { - buf.WriteByte('R') - buf.WriteString(src.Routine) - buf.WriteByte(0) + dst = append(dst, 'R') + dst = append(dst, src.Routine...) + dst = append(dst, 0) } for k, v := range src.UnknownFields { - buf.WriteByte(k) - buf.WriteByte(0) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + dst = append(dst, 0) + + return dst +} - return buf.Bytes() +// MarshalJSON implements encoding/json.Marshaler. +func (src ErrorResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + }{ + Type: "ErrorResponse", + Severity: src.Severity, + SeverityUnlocalized: src.SeverityUnlocalized, + Code: src.Code, + Message: src.Message, + Detail: src.Detail, + Hint: src.Hint, + Position: src.Position, + InternalPosition: src.InternalPosition, + InternalQuery: src.InternalQuery, + Where: src.Where, + SchemaName: src.SchemaName, + TableName: src.TableName, + ColumnName: src.ColumnName, + DataTypeName: src.DataTypeName, + ConstraintName: src.ConstraintName, + File: src.File, + Line: src.Line, + Routine: src.Routine, + UnknownFields: src.UnknownFields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + dst.UnknownFields = msg.UnknownFields + + return nil } diff --git a/vendor/github.com/jackc/pgproto3/v2/execute.go b/vendor/github.com/jackc/pgproto3/v2/execute.go index 8bae61332..efb9e1e21 100644 --- a/vendor/github.com/jackc/pgproto3/v2/execute.go +++ b/vendor/github.com/jackc/pgproto3/v2/execute.go @@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Execute) Encode(dst []byte) []byte { - dst = append(dst, 'E') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Execute) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') dst = append(dst, src.Portal...) dst = append(dst, 0) - dst = pgio.AppendUint32(dst, src.MaxRows) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/flush.go b/vendor/github.com/jackc/pgproto3/v2/flush.go index 2725f6894..e5dc1fbbd 100644 --- a/vendor/github.com/jackc/pgproto3/v2/flush.go +++ b/vendor/github.com/jackc/pgproto3/v2/flush.go @@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Flush) Encode(dst []byte) []byte { - return append(dst, 'H', 0, 0, 0, 4) +func (src *Flush) Encode(dst []byte) ([]byte, error) { + return append(dst, 'H', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/frontend.go b/vendor/github.com/jackc/pgproto3/v2/frontend.go index b8f545ca8..623b0a98e 100644 --- a/vendor/github.com/jackc/pgproto3/v2/frontend.go +++ b/vendor/github.com/jackc/pgproto3/v2/frontend.go @@ -16,6 +16,8 @@ type Frontend struct { authenticationOk AuthenticationOk authenticationCleartextPassword AuthenticationCleartextPassword authenticationMD5Password AuthenticationMD5Password + authenticationGSS AuthenticationGSS + authenticationGSSContinue AuthenticationGSSContinue authenticationSASL AuthenticationSASL authenticationSASLContinue AuthenticationSASLContinue authenticationSASLFinal AuthenticationSASLFinal @@ -45,6 +47,7 @@ type Frontend struct { bodyLen int msgType byte partialMsg bool + authType uint32 } // NewFrontend creates a new Frontend. @@ -54,7 +57,11 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { // Send sends a message to the backend. func (f *Frontend) Send(msg FrontendMessage) error { - _, err := f.w.Write(msg.Encode(nil)) + buf, err := msg.Encode(nil) + if err != nil { + return err + } + _, err = f.w.Write(buf) return err } @@ -76,6 +83,9 @@ func (f *Frontend) Receive() (BackendMessage, error) { f.msgType = header[0] f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 f.partialMsg = true + if f.bodyLen < 0 { + return nil, errors.New("invalid message with negative body length received") + } } msgBody, err := f.cr.Next(f.bodyLen) @@ -146,10 +156,16 @@ func (f *Frontend) Receive() (BackendMessage, error) { } // Authentication message type constants. +// See src/include/libpq/pqcomm.h for all +// constants. const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSCMCreds = 6 + AuthTypeGSS = 7 + AuthTypeGSSCont = 8 + AuthTypeSSPI = 9 AuthTypeSASL = 10 AuthTypeSASLContinue = 11 AuthTypeSASLFinal = 12 @@ -159,15 +175,23 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er if len(src) < 4 { return nil, errors.New("authentication message too short") } - authType := binary.BigEndian.Uint32(src[:4]) + f.authType = binary.BigEndian.Uint32(src[:4]) - switch authType { + switch f.authType { case AuthTypeOk: return &f.authenticationOk, nil case AuthTypeCleartextPassword: return &f.authenticationCleartextPassword, nil case AuthTypeMD5Password: return &f.authenticationMD5Password, nil + case AuthTypeSCMCreds: + return nil, errors.New("AuthTypeSCMCreds is unimplemented") + case AuthTypeGSS: + return &f.authenticationGSS, nil + case AuthTypeGSSCont: + return &f.authenticationGSSContinue, nil + case AuthTypeSSPI: + return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: return &f.authenticationSASL, nil case AuthTypeSASLContinue: @@ -175,6 +199,12 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er case AuthTypeSASLFinal: return &f.authenticationSASLFinal, nil default: - return nil, fmt.Errorf("unknown authentication type: %d", authType) + return nil, fmt.Errorf("unknown authentication type: %d", f.authType) } } + +// GetAuthType returns the authType used in the current state of the frontend. +// See SetAuthType for more information. +func (f *Frontend) GetAuthType() uint32 { + return f.authType +} diff --git a/vendor/github.com/jackc/pgproto3/v2/function_call.go b/vendor/github.com/jackc/pgproto3/v2/function_call.go new file mode 100644 index 000000000..5d799c4d7 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/function_call.go @@ -0,0 +1,102 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "math" + + "github.com/jackc/pgio" +) + +type FunctionCall struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*FunctionCall) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *FunctionCall) Decode(src []byte) error { + *dst = FunctionCall{} + rp := 0 + // Specifies the object ID of the function to call. + dst.Function = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; + // or it can equal the actual number of arguments. + nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + argumentCodes := make([]uint16, nArgumentCodes) + for i := 0; i < nArgumentCodes; i++ { + // The argument format codes. Each must presently be zero (text) or one (binary). + ac := binary.BigEndian.Uint16(src[rp:]) + if ac != 0 && ac != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + argumentCodes[i] = ac + rp += 2 + } + dst.ArgFormatCodes = argumentCodes + + // Specifies the number of arguments being supplied to the function. + nArguments := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + arguments := make([][]byte, nArguments) + for i := 0; i < nArguments; i++ { + // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { + arguments[i] = nil + } else { + // The value of the argument, in the format indicated by the associated format code. n is the above length. + argumentValue := src[rp : rp+argumentLength] + rp += argumentLength + arguments[i] = argumentValue + } + } + dst.Arguments = arguments + // The format code for the function result. Must presently be zero (text) or one (binary). + resultFormatCode := binary.BigEndian.Uint16(src[rp:]) + if resultFormatCode != 0 && resultFormatCode != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'F') + dst = pgio.AppendUint32(dst, src.Function) + + if len(src.ArgFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many arg format codes") + } + dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) + for _, argFormatCode := range src.ArgFormatCodes { + dst = pgio.AppendUint16(dst, argFormatCode) + } + + if len(src.Arguments) > math.MaxUint16 { + return nil, errors.New("too many arguments") + } + dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) + for _, argument := range src.Arguments { + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } + dst = pgio.AppendUint16(dst, src.ResultFormatCode) + return finishMessage(dst, sp) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/function_call_response.go b/vendor/github.com/jackc/pgproto3/v2/function_call_response.go index 5cc2d4d29..abc14f0d1 100644 --- a/vendor/github.com/jackc/pgproto3/v2/function_call_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/function_call_response.go @@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCallResponse) Encode(dst []byte) []byte { - dst = append(dst, 'V') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'V') if src.Result == nil { dst = pgio.AppendInt32(dst, -1) @@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { dst = append(dst, src.Result...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -81,3 +77,21 @@ func (src FunctionCallResponse) MarshalJSON() ([]byte, error) { Result: formattedValue, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Result map[string]string + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.Result, err = getValueFromJSON(msg.Result) + return err +} diff --git a/vendor/github.com/jackc/pgproto3/v2/gss_enc_request.go b/vendor/github.com/jackc/pgproto3/v2/gss_enc_request.go index cf405a3e0..f6e4f6627 100644 --- a/vendor/github.com/jackc/pgproto3/v2/gss_enc_request.go +++ b/vendor/github.com/jackc/pgproto3/v2/gss_enc_request.go @@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *GSSEncRequest) Encode(dst []byte) []byte { +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, gssEncReqNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/gss_response.go b/vendor/github.com/jackc/pgproto3/v2/gss_response.go new file mode 100644 index 000000000..10d937759 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/gss_response.go @@ -0,0 +1,46 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type GSSResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (g *GSSResponse) Frontend() {} + +func (g *GSSResponse) Decode(data []byte) error { + g.Data = data + return nil +} + +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') + dst = append(dst, g.Data...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (g *GSSResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "GSSResponse", + Data: g.Data, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (g *GSSResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + g.Data = msg.Data + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/no_data.go b/vendor/github.com/jackc/pgproto3/v2/no_data.go index d8f85d38a..cbcaad40c 100644 --- a/vendor/github.com/jackc/pgproto3/v2/no_data.go +++ b/vendor/github.com/jackc/pgproto3/v2/no_data.go @@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoData) Encode(dst []byte) []byte { - return append(dst, 'n', 0, 0, 0, 4) +func (src *NoData) Encode(dst []byte) ([]byte, error) { + return append(dst, 'n', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/notice_response.go b/vendor/github.com/jackc/pgproto3/v2/notice_response.go index 4ac28a791..497aba6dd 100644 --- a/vendor/github.com/jackc/pgproto3/v2/notice_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/notice_response.go @@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoticeResponse) Encode(dst []byte) []byte { - return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'N') + dst = (*ErrorResponse)(src).appendFields(dst) + return finishMessage(dst, sp) } diff --git a/vendor/github.com/jackc/pgproto3/v2/notification_response.go b/vendor/github.com/jackc/pgproto3/v2/notification_response.go index e762eb967..5be3edd3c 100644 --- a/vendor/github.com/jackc/pgproto3/v2/notification_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/notification_response.go @@ -41,20 +41,14 @@ func (dst *NotificationResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NotificationResponse) Encode(dst []byte) []byte { - dst = append(dst, 'A') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'A') dst = pgio.AppendUint32(dst, src.PID) dst = append(dst, src.Channel...) dst = append(dst, 0) dst = append(dst, src.Payload...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/parameter_description.go b/vendor/github.com/jackc/pgproto3/v2/parameter_description.go index e28965c8a..fec0fce85 100644 --- a/vendor/github.com/jackc/pgproto3/v2/parameter_description.go +++ b/vendor/github.com/jackc/pgproto3/v2/parameter_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgio" ) @@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterDescription) Encode(dst []byte) []byte { - dst = append(dst, 't') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 't') + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/parameter_status.go b/vendor/github.com/jackc/pgproto3/v2/parameter_status.go index c4021d92f..9ee0720b5 100644 --- a/vendor/github.com/jackc/pgproto3/v2/parameter_status.go +++ b/vendor/github.com/jackc/pgproto3/v2/parameter_status.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgio" ) type ParameterStatus struct { @@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterStatus) Encode(dst []byte) []byte { - dst = append(dst, 'S') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'S') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Value...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/parse.go b/vendor/github.com/jackc/pgproto3/v2/parse.go index 723885d41..7dd06990e 100644 --- a/vendor/github.com/jackc/pgproto3/v2/parse.go +++ b/vendor/github.com/jackc/pgproto3/v2/parse.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgio" ) @@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Parse) Encode(dst []byte) []byte { - dst = append(dst, 'P') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Parse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'P') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Query...) dst = append(dst, 0) + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/parse_complete.go b/vendor/github.com/jackc/pgproto3/v2/parse_complete.go index 92c9498b6..cff9e27d0 100644 --- a/vendor/github.com/jackc/pgproto3/v2/parse_complete.go +++ b/vendor/github.com/jackc/pgproto3/v2/parse_complete.go @@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParseComplete) Encode(dst []byte) []byte { - return append(dst, '1', 0, 0, 0, 4) +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '1', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/password_message.go b/vendor/github.com/jackc/pgproto3/v2/password_message.go index 4b68b31a8..d820d3275 100644 --- a/vendor/github.com/jackc/pgproto3/v2/password_message.go +++ b/vendor/github.com/jackc/pgproto3/v2/password_message.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgio" ) type PasswordMessage struct { @@ -14,6 +12,9 @@ type PasswordMessage struct { // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} +// Frontend identifies this message as an authentication response. +func (*PasswordMessage) InitialResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *PasswordMessage) Decode(src []byte) error { @@ -29,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PasswordMessage) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) - +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Password...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/pgproto3.go b/vendor/github.com/jackc/pgproto3/v2/pgproto3.go index fe7b085bc..aa4167c45 100644 --- a/vendor/github.com/jackc/pgproto3/v2/pgproto3.go +++ b/vendor/github.com/jackc/pgproto3/v2/pgproto3.go @@ -1,6 +1,16 @@ package pgproto3 -import "fmt" +import ( + "encoding/hex" + "errors" + "fmt" + + "github.com/jackc/pgio" +) + +// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff. +const maxMessageBodyLen = (0x3fffffff - 1) // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. @@ -10,7 +20,7 @@ type Message interface { Decode(data []byte) error // Encode appends itself to dst and returns the new buffer. - Encode(dst []byte) []byte + Encode(dst []byte) ([]byte, error) } type FrontendMessage interface { @@ -23,6 +33,11 @@ type BackendMessage interface { Backend() // no-op method to distinguish frontend from backend methods } +type AuthenticationResponseMessage interface { + BackendMessage + AuthenticationResponse() // no-op method to distinguish authentication responses +} + type invalidMessageLenErr struct { messageType string expectedLen int @@ -40,3 +55,37 @@ type invalidMessageFormatErr struct { func (e *invalidMessageFormatErr) Error() string { return fmt.Sprintf("%s body is invalid", e.messageType) } + +// getValueFromJSON gets the value from a protocol message representation in JSON. +func getValueFromJSON(v map[string]string) ([]byte, error) { + if v == nil { + return nil, nil + } + if text, ok := v["text"]; ok { + return []byte(text), nil + } + if binary, ok := v["binary"]; ok { + return hex.DecodeString(binary) + } + return nil, errors.New("unknown protocol representation") +} + +// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to +// dst. It returns the new buffer and the position of the message length placeholder. +func beginMessage(dst []byte, t byte) ([]byte, int) { + dst = append(dst, t) + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + return dst, sp +} + +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer. +func finishMessage(dst []byte, sp int) ([]byte, error) { + messageBodyLen := len(dst[sp:]) + if messageBodyLen > maxMessageBodyLen { + return nil, errors.New("message body too large") + } + pgio.SetInt32(dst[sp:], int32(messageBodyLen)) + return dst, nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/portal_suspended.go b/vendor/github.com/jackc/pgproto3/v2/portal_suspended.go index 1a9e7bfb1..9e2f8cbc4 100644 --- a/vendor/github.com/jackc/pgproto3/v2/portal_suspended.go +++ b/vendor/github.com/jackc/pgproto3/v2/portal_suspended.go @@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PortalSuspended) Encode(dst []byte) []byte { - return append(dst, 's', 0, 0, 0, 4) +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) { + return append(dst, 's', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/query.go b/vendor/github.com/jackc/pgproto3/v2/query.go index 41c93b4a8..aebdfde89 100644 --- a/vendor/github.com/jackc/pgproto3/v2/query.go +++ b/vendor/github.com/jackc/pgproto3/v2/query.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgio" ) type Query struct { @@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Query) Encode(dst []byte) []byte { - dst = append(dst, 'Q') - dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) - +func (src *Query) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'Q') dst = append(dst, src.String...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/ready_for_query.go b/vendor/github.com/jackc/pgproto3/v2/ready_for_query.go index 879afe390..a56af9fb2 100644 --- a/vendor/github.com/jackc/pgproto3/v2/ready_for_query.go +++ b/vendor/github.com/jackc/pgproto3/v2/ready_for_query.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/json" + "errors" ) type ReadyForQuery struct { @@ -24,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ReadyForQuery) Encode(dst []byte) []byte { - return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil } // MarshalJSON implements encoding/json.Marshaler. @@ -38,3 +39,23 @@ func (src ReadyForQuery) MarshalJSON() ([]byte, error) { TxStatus: string(src.TxStatus), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + TxStatus string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.TxStatus) != 1 { + return errors.New("invalid length for ReadyForQuery.TxStatus") + } + dst.TxStatus = msg.TxStatus[0] + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/row_description.go b/vendor/github.com/jackc/pgproto3/v2/row_description.go index d9b8c7c98..3f6b2c649 100644 --- a/vendor/github.com/jackc/pgproto3/v2/row_description.go +++ b/vendor/github.com/jackc/pgproto3/v2/row_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgio" ) @@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *RowDescription) Encode(dst []byte) []byte { - dst = append(dst, 'T') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *RowDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'T') + if len(src.Fields) > math.MaxUint16 { + return nil, errors.New("too many fields") + } dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { dst = append(dst, fd.Name...) @@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendInt16(dst, fd.Format) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -132,3 +133,34 @@ func (src RowDescription) MarshalJSON() ([]byte, error) { Fields: src.Fields, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *RowDescription) UnmarshalJSON(data []byte) error { + var msg struct { + Fields []struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + } + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.Fields = make([]FieldDescription, len(msg.Fields)) + for n, field := range msg.Fields { + dst.Fields[n] = FieldDescription{ + Name: []byte(field.Name), + TableOID: field.TableOID, + TableAttributeNumber: field.TableAttributeNumber, + DataTypeOID: field.DataTypeOID, + DataTypeSize: field.DataTypeSize, + TypeModifier: field.TypeModifier, + Format: field.Format, + } + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/sasl_initial_response.go b/vendor/github.com/jackc/pgproto3/v2/sasl_initial_response.go index 0bf8a9e56..1938f6582 100644 --- a/vendor/github.com/jackc/pgproto3/v2/sasl_initial_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/sasl_initial_response.go @@ -2,7 +2,6 @@ package pgproto3 import ( "bytes" - "encoding/hex" "encoding/json" "errors" @@ -39,10 +38,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLInitialResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, []byte(src.AuthMechanism)...) dst = append(dst, 0) @@ -50,9 +47,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte { dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = append(dst, src.Data...) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -64,6 +59,25 @@ func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { }{ Type: "SASLInitialResponse", AuthMechanism: src.AuthMechanism, - Data: hex.EncodeToString(src.Data), + Data: string(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + AuthMechanism string + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.AuthMechanism = msg.AuthMechanism + dst.Data = []byte(msg.Data) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/sasl_response.go b/vendor/github.com/jackc/pgproto3/v2/sasl_response.go index 21be6d755..f4a131858 100644 --- a/vendor/github.com/jackc/pgproto3/v2/sasl_response.go +++ b/vendor/github.com/jackc/pgproto3/v2/sasl_response.go @@ -1,10 +1,7 @@ package pgproto3 import ( - "encoding/hex" "encoding/json" - - "github.com/jackc/pgio" ) type SASLResponse struct { @@ -22,13 +19,10 @@ func (dst *SASLResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) - +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Data...) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. @@ -38,6 +32,18 @@ func (src SASLResponse) MarshalJSON() ([]byte, error) { Data string }{ Type: "SASLResponse", - Data: hex.EncodeToString(src.Data), + Data: string(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.Data = []byte(msg.Data) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/ssl_request.go b/vendor/github.com/jackc/pgproto3/v2/ssl_request.go index 96ce489e5..8feff1a28 100644 --- a/vendor/github.com/jackc/pgproto3/v2/ssl_request.go +++ b/vendor/github.com/jackc/pgproto3/v2/ssl_request.go @@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *SSLRequest) Encode(dst []byte) []byte { +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, sslRequestNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/startup_message.go b/vendor/github.com/jackc/pgproto3/v2/startup_message.go index 5f1cd24f7..255ea22d6 100644 --- a/vendor/github.com/jackc/pgproto3/v2/startup_message.go +++ b/vendor/github.com/jackc/pgproto3/v2/startup_message.go @@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *StartupMessage) Encode(dst []byte) []byte { +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) { sp := len(dst) dst = pgio.AppendInt32(dst, -1) @@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/sync.go b/vendor/github.com/jackc/pgproto3/v2/sync.go index 5db8e07ac..ea4fc9594 100644 --- a/vendor/github.com/jackc/pgproto3/v2/sync.go +++ b/vendor/github.com/jackc/pgproto3/v2/sync.go @@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Sync) Encode(dst []byte) []byte { - return append(dst, 'S', 0, 0, 0, 4) +func (src *Sync) Encode(dst []byte) ([]byte, error) { + return append(dst, 'S', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/github.com/jackc/pgproto3/v2/terminate.go b/vendor/github.com/jackc/pgproto3/v2/terminate.go index 135191eae..35a9dc837 100644 --- a/vendor/github.com/jackc/pgproto3/v2/terminate.go +++ b/vendor/github.com/jackc/pgproto3/v2/terminate.go @@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Terminate) Encode(dst []byte) []byte { - return append(dst, 'X', 0, 0, 0, 4) +func (src *Terminate) Encode(dst []byte) ([]byte, error) { + return append(dst, 'X', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/vendor/modules.txt b/vendor/modules.txt index 818d314bc..88cd958e2 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -187,7 +187,7 @@ github.com/jackc/pgio # github.com/jackc/pgpassfile v1.0.0 ## explicit; go 1.12 github.com/jackc/pgpassfile -# github.com/jackc/pgproto3/v2 v2.0.6 +# github.com/jackc/pgproto3/v2 v2.3.3 ## explicit; go 1.12 github.com/jackc/pgproto3/v2 # github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b