From eded558fc894c1444a17f345e9e5c5395459afde Mon Sep 17 00:00:00 2001 From: Roland Groen Date: Tue, 5 Nov 2024 18:02:01 +0100 Subject: [PATCH] Support multiple SAN OtherName values in certificates Updated functions to handle multiple SAN OtherName values in certificates and modified tests accordingly. Refactored validation logic to accommodate lists of SAN OtherName values. --- vdr/didx509/resolver_test.go | 53 +++++++++++++++++++-- vdr/didx509/validation.go | 4 +- vdr/didx509/x509_utils.go | 20 ++++---- vdr/didx509/x509_utils_test.go | 86 +++++++++++++++++++--------------- 4 files changed, 109 insertions(+), 54 deletions(-) diff --git a/vdr/didx509/resolver_test.go b/vdr/didx509/resolver_test.go index 0fb808aad..b83ddc228 100644 --- a/vdr/didx509/resolver_test.go +++ b/vdr/didx509/resolver_test.go @@ -42,7 +42,8 @@ func TestManager_Resolve_OtherName(t *testing.T) { metadata := resolver2.ResolveMetadata{} otherNameValue := "A_BIG_STRING" - _, certChain, rootCertificate, _, signingCert, err := BuildCertChain(otherNameValue) + otherNameValueSecondary := "A_SECOND_STRING" + _, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{otherNameValue, otherNameValueSecondary}) require.NoError(t, err) metadata.JwtProtectedHeaders = make(map[string]interface{}) metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain @@ -82,7 +83,7 @@ func TestManager_Resolve_OtherName(t *testing.T) { didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) }) - t.Run("happy flow, policy depth of 1", func(t *testing.T) { + t.Run("happy flow, policy depth of 1 and primary value", func(t *testing.T) { validator.EXPECT().ValidateStrict(gomock.Any()) resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) @@ -94,7 +95,21 @@ func TestManager_Resolve_OtherName(t *testing.T) { didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) }) - t.Run("happy flow, policy depth of 2", func(t *testing.T) { + t.Run("happy flow, policy depth of 1 and secondary value", func(t *testing.T) { + rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValueSecondary)) + + validator.EXPECT().ValidateStrict(gomock.Any()) + resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) + + require.NoError(t, err) + assert.NotNil(t, resolve) + require.NoError(t, err) + assert.NotNil(t, documentMetadata) + // Check that the DID url is did#0 + didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") + assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) + }) + t.Run("happy flow, policy depth of 2 of type OU", func(t *testing.T) { rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::subject:OU:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, "The%20A-Team")) validator.EXPECT().ValidateStrict(gomock.Any()) @@ -108,6 +123,34 @@ func TestManager_Resolve_OtherName(t *testing.T) { didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) }) + t.Run("happy flow, policy depth of 2, primary and secondary", func(t *testing.T) { + rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, otherNameValueSecondary)) + + validator.EXPECT().ValidateStrict(gomock.Any()) + resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) + + require.NoError(t, err) + assert.NotNil(t, resolve) + require.NoError(t, err) + assert.NotNil(t, documentMetadata) + // Check that the DID url is did#0 + didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") + assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) + }) + t.Run("happy flow, policy depth of 2, secondary and primary", func(t *testing.T) { + rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, otherNameValueSecondary)) + + validator.EXPECT().ValidateStrict(gomock.Any()) + resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) + + require.NoError(t, err) + assert.NotNil(t, resolve) + require.NoError(t, err) + assert.NotNil(t, documentMetadata) + // Check that the DID url is did#0 + didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") + assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) + }) t.Run("happy flow with only x5t header", func(t *testing.T) { delete(metadata.JwtProtectedHeaders, X509CertThumbprintS256Header) validator.EXPECT().ValidateStrict(gomock.Any()) @@ -236,7 +279,7 @@ func TestManager_Resolve_San_Generic(t *testing.T) { resolver := NewResolver(validator) metadata := resolver2.ResolveMetadata{} - _, certChain, rootCertificate, _, signingCert, err := BuildCertChain("") + _, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{}) require.NoError(t, err) metadata.JwtProtectedHeaders = make(map[string]interface{}) metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain @@ -316,7 +359,7 @@ func TestManager_Resolve_Subject(t *testing.T) { metadata := resolver2.ResolveMetadata{} otherNameValue := "A_BIG_STRING" - _, certChain, rootCertificate, _, signingCert, err := BuildCertChain(otherNameValue) + _, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{otherNameValue}) require.NoError(t, err) metadata.JwtProtectedHeaders = make(map[string]interface{}) metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain diff --git a/vdr/didx509/validation.go b/vdr/didx509/validation.go index f70f1193c..22a3153da 100644 --- a/vdr/didx509/validation.go +++ b/vdr/didx509/validation.go @@ -144,11 +144,11 @@ type validationFunction func(cert *x509.Certificate, key string, value string) e // validatorMap maps PolicyKey to their corresponding validation functions for certificate attributes. var validatorMap = map[PolicyKey]validationFunction{ SanPolicyOtherName: func(cert *x509.Certificate, key string, value string) error { - nameValue, err := findOtherNameValue(cert) + nameValues, err := findOtherNameValues(cert) if err != nil { return err } - if nameValue != value { + if !slices.Contains(nameValues, value) { return fmt.Errorf("the SAN attribute %s does not match the query", key) } return nil diff --git a/vdr/didx509/x509_utils.go b/vdr/didx509/x509_utils.go index c6111e99f..98fe2f817 100644 --- a/vdr/didx509/x509_utils.go +++ b/vdr/didx509/x509_utils.go @@ -91,19 +91,19 @@ var ( OtherNameType = asn1.ObjectIdentifier{2, 5, 5, 5} ) -// findOtherNameValue extracts the value of a specified OtherName type from the certificate -func findOtherNameValue(cert *x509.Certificate) (string, error) { +// findOtherNameValues extracts the value of a specified OtherName types from the certificate +func findOtherNameValues(cert *x509.Certificate) ([]string, error) { for _, extension := range cert.Extensions { if extension.Id.Equal(SubjectAlternativeNameType) { - return findSanValue(extension) + return findSanValues(extension) } } - return "", nil + return make([]string, 0), nil } -// findSanValue extracts the SAN value from a given pkix.Extension, returning the resulting value or an error. -func findSanValue(extension pkix.Extension) (string, error) { - value := "" +// findSanValues extracts the SAN values from a given pkix.Extension, returning the resulting values or an error. +func findSanValues(extension pkix.Extension) ([]string, error) { + var values []string err := forEachSan(extension.Value, func(data []byte) error { var other OtherName _, err := asn1.UnmarshalWithParams(data, &other, "tag:0") @@ -111,17 +111,19 @@ func findSanValue(extension pkix.Extension) (string, error) { return err } if other.TypeID.Equal(OtherNameType) { + var value string _, err = asn1.Unmarshal(other.Value.Bytes, &value) if err != nil { return err } + values = append(values, value) } return nil }) if err != nil { - return "", err + return make([]string, 0), err } - return value, err + return values, err } // forEachSan processes each SAN extension in the certificate diff --git a/vdr/didx509/x509_utils_test.go b/vdr/didx509/x509_utils_test.go index 7c2144217..fc4de62f3 100644 --- a/vdr/didx509/x509_utils_test.go +++ b/vdr/didx509/x509_utils_test.go @@ -32,13 +32,14 @@ import ( "github.com/lestrrat-go/jwx/v2/cert" "math/big" "net" + "slices" "strings" "testing" "time" ) // BuildCertChain generates a certificate chain, including root, intermediate, and signing certificates. -func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain *cert.Chain, rootCertificate *x509.Certificate, signingKey *rsa.PrivateKey, signingCert *x509.Certificate, err error) { +func BuildCertChain(identifiers []string) (chainCerts [4]*x509.Certificate, chain *cert.Chain, rootCertificate *x509.Certificate, signingKey *rsa.PrivateKey, signingCert *x509.Certificate, err error) { chainCerts = [4]*x509.Certificate{} chain = &cert.Chain{} rootKey, rootCert, rootPem, err := buildRootCert() @@ -68,7 +69,7 @@ func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain * return chainCerts, nil, nil, nil, nil, err } - signingKey, signingCert, signingPEM, err := buildSigningCert(identifier, intermediateL2Cert, intermediateL2Key, "32121323") + signingKey, signingCert, signingPEM, err := buildSigningCert(identifiers, intermediateL2Cert, intermediateL2Key, "32121323") if err != nil { return chainCerts, nil, nil, nil, nil, err } @@ -80,12 +81,12 @@ func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain * return chainCerts, chain, rootCert, signingKey, signingCert, nil } -func buildSigningCert(identifier string, intermediateL2Cert *x509.Certificate, intermediateL2Key *rsa.PrivateKey, serialNumber string) (*rsa.PrivateKey, *x509.Certificate, []byte, error) { +func buildSigningCert(identifiers []string, intermediateL2Cert *x509.Certificate, intermediateL2Key *rsa.PrivateKey, serialNumber string) (*rsa.PrivateKey, *x509.Certificate, []byte, error) { signingKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, nil, err } - signingTmpl, err := SigningCertTemplate(nil, identifier) + signingTmpl, err := SigningCertTemplate(nil, identifiers) if err != nil { return nil, nil, nil, err } @@ -152,7 +153,7 @@ func CertTemplate(serialNumber *big.Int) (*x509.Certificate, error) { } // SigningCertTemplate creates a x509.Certificate template for a signing certificate with an optional serial number. -func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certificate, error) { +func SigningCertTemplate(serialNumber *big.Int, identifiers []string) (*x509.Certificate, error) { // generate a random serial number (a real cert authority would have some logic behind this) if serialNumber == nil { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 8) @@ -179,8 +180,8 @@ func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certif tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} // Either the ExtraExtensions SubjectAlternativeNameType is set, or the Subject Alternate Name values are set, // both don't mix - if identifier != "" { - err := setSanAlternativeName(&tmpl, identifier) + if len(identifiers) > 0 { + err := setSanAlternativeName(&tmpl, identifiers) if err != nil { return nil, err } @@ -192,27 +193,30 @@ func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certif return &tmpl, nil } -func setSanAlternativeName(tmpl *x509.Certificate, identifier string) error { - raw, err := toRawValue(identifier, "ia5") - if err != nil { - return err - } - otherName := OtherName{ - TypeID: OtherNameType, - Value: asn1.RawValue{ - Class: 2, - Tag: 0, - IsCompound: true, - Bytes: raw.FullBytes, - }, - } +func setSanAlternativeName(tmpl *x509.Certificate, identifiers []string) error { + var list []asn1.RawValue - raw, err = toRawValue(otherName, "tag:0") - if err != nil { - return err + for _, identifier := range identifiers { + raw, err := toRawValue(identifier, "ia5") + if err != nil { + return err + } + otherName := OtherName{ + TypeID: OtherNameType, + Value: asn1.RawValue{ + Class: 2, + Tag: 0, + IsCompound: true, + Bytes: raw.FullBytes, + }, + } + + raw, err = toRawValue(otherName, "tag:0") + if err != nil { + return err + } + list = append(list, *raw) } - var list []asn1.RawValue - list = append(list, *raw) marshal, err := asn1.Marshal(list) if err != nil { return err @@ -261,7 +265,7 @@ func CreateCert(template, parent *x509.Certificate, pub interface{}, parentPriv func TestFindOtherNameValue(t *testing.T) { t.Parallel() key, certificate, _, err := buildRootCert() - _, signingCert, _, err := buildSigningCert("123", certificate, key, "4567") + _, signingCert, _, err := buildSigningCert([]string{"123", "321"}, certificate, key, "4567") if err != nil { t.Fatalf("failed to build root certificate: %v", err) } @@ -279,22 +283,28 @@ func TestFindOtherNameValue(t *testing.T) { wantErr: false, }, { - name: "with extensions", + name: "with extensions first", cert: signingCert, want: "123", wantErr: false, }, + { + name: "with extensions second", + cert: signingCert, + want: "321", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotName, err := findOtherNameValue(tt.cert) + gotName, err := findOtherNameValues(tt.cert) if (err != nil) != tt.wantErr { - t.Errorf("findOtherNameValue() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("findOtherNameValues() error = %v, wantErr %v", err, tt.wantErr) return } - if gotName != tt.want { - t.Errorf("findOtherNameValue() = %v, want %v", gotName, tt.want) + if tt.want != "" && !slices.Contains(gotName, tt.want) { + t.Errorf("findOtherNameValues() = %v, want %v", gotName, tt.want) } }) } @@ -308,7 +318,7 @@ func TestFindCertificateByHash(t *testing.T) { } return base64.RawURLEncoding.EncodeToString(h) } - chainCerts, _, _, _, _, err := BuildCertChain("123") + chainCerts, _, _, _, _, err := BuildCertChain([]string{"123"}) if err != nil { t.Error(err) } @@ -409,7 +419,7 @@ func TestParseChain(t *testing.T) { } return &chain } - certs, chain, _, _, _, _ := BuildCertChain("123") + certs, chain, _, _, _, _ := BuildCertChain([]string{"123"}) invalidPEM := `-----BEGIN CERTIFICATE----- Y29ycnVwdCBjZXJ0aWZpY2F0ZQo= @@ -704,7 +714,7 @@ func TestFindSanValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - val, foundErr := findSanValue(pkix.Extension{ + val, foundErr := findSanValues(pkix.Extension{ Value: tt.rest, }) if foundErr != nil { @@ -719,15 +729,15 @@ func TestFindSanValue(t *testing.T) { t.Errorf("forEachSan() error = %v", foundErr) } if foundErr.Error() != tt.wantError.Error() { - t.Errorf("findSanValue() error = %v, want: %v", foundErr, tt.wantError) + t.Errorf("findSanValues() error = %v, want: %v", foundErr, tt.wantError) return } } } } - if val != tt.expectedValue { - t.Errorf("findSanValue() = %v, want: %v", val, tt.expectedValue) + if tt.expectedValue != "" && !slices.Contains(val, tt.expectedValue) { + t.Errorf("findSanValues() = %v, want: %v", val, tt.expectedValue) } })