Skip to content

Commit

Permalink
Support multiple SAN OtherName values in certificates
Browse files Browse the repository at this point in the history
Updated functions to handle multiple SAN OtherName values in certificates and modified tests accordingly. Refactored validation logic to accommodate lists of SAN OtherName values.
  • Loading branch information
rolandgroen committed Nov 5, 2024
1 parent 4ada7aa commit eded558
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 54 deletions.
53 changes: 48 additions & 5 deletions vdr/didx509/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ func TestManager_Resolve_OtherName(t *testing.T) {
metadata := resolver2.ResolveMetadata{}

otherNameValue := "A_BIG_STRING"
_, certChain, rootCertificate, _, signingCert, err := BuildCertChain(otherNameValue)
otherNameValueSecondary := "A_SECOND_STRING"
_, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{otherNameValue, otherNameValueSecondary})
require.NoError(t, err)
metadata.JwtProtectedHeaders = make(map[string]interface{})
metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain
Expand Down Expand Up @@ -82,7 +83,7 @@ func TestManager_Resolve_OtherName(t *testing.T) {
didUrl, err := did.ParseDIDURL(rootDID.String() + "#0")
assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl))
})
t.Run("happy flow, policy depth of 1", func(t *testing.T) {
t.Run("happy flow, policy depth of 1 and primary value", func(t *testing.T) {
validator.EXPECT().ValidateStrict(gomock.Any())
resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata)

Expand All @@ -94,7 +95,21 @@ func TestManager_Resolve_OtherName(t *testing.T) {
didUrl, err := did.ParseDIDURL(rootDID.String() + "#0")
assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl))
})
t.Run("happy flow, policy depth of 2", func(t *testing.T) {
t.Run("happy flow, policy depth of 1 and secondary value", func(t *testing.T) {
rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValueSecondary))

validator.EXPECT().ValidateStrict(gomock.Any())
resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata)

require.NoError(t, err)
assert.NotNil(t, resolve)
require.NoError(t, err)
assert.NotNil(t, documentMetadata)
// Check that the DID url is did#0
didUrl, err := did.ParseDIDURL(rootDID.String() + "#0")
assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl))
})
t.Run("happy flow, policy depth of 2 of type OU", func(t *testing.T) {
rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::subject:OU:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, "The%20A-Team"))

validator.EXPECT().ValidateStrict(gomock.Any())
Expand All @@ -108,6 +123,34 @@ func TestManager_Resolve_OtherName(t *testing.T) {
didUrl, err := did.ParseDIDURL(rootDID.String() + "#0")
assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl))
})
t.Run("happy flow, policy depth of 2, primary and secondary", func(t *testing.T) {
rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, otherNameValueSecondary))

validator.EXPECT().ValidateStrict(gomock.Any())
resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata)

require.NoError(t, err)
assert.NotNil(t, resolve)
require.NoError(t, err)
assert.NotNil(t, documentMetadata)
// Check that the DID url is did#0
didUrl, err := did.ParseDIDURL(rootDID.String() + "#0")
assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl))
})
t.Run("happy flow, policy depth of 2, secondary and primary", func(t *testing.T) {
rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, otherNameValueSecondary))

validator.EXPECT().ValidateStrict(gomock.Any())
resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata)

require.NoError(t, err)
assert.NotNil(t, resolve)
require.NoError(t, err)
assert.NotNil(t, documentMetadata)
// Check that the DID url is did#0
didUrl, err := did.ParseDIDURL(rootDID.String() + "#0")
assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl))
})
t.Run("happy flow with only x5t header", func(t *testing.T) {
delete(metadata.JwtProtectedHeaders, X509CertThumbprintS256Header)
validator.EXPECT().ValidateStrict(gomock.Any())
Expand Down Expand Up @@ -236,7 +279,7 @@ func TestManager_Resolve_San_Generic(t *testing.T) {
resolver := NewResolver(validator)
metadata := resolver2.ResolveMetadata{}

_, certChain, rootCertificate, _, signingCert, err := BuildCertChain("")
_, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{})
require.NoError(t, err)
metadata.JwtProtectedHeaders = make(map[string]interface{})
metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain
Expand Down Expand Up @@ -316,7 +359,7 @@ func TestManager_Resolve_Subject(t *testing.T) {
metadata := resolver2.ResolveMetadata{}

otherNameValue := "A_BIG_STRING"
_, certChain, rootCertificate, _, signingCert, err := BuildCertChain(otherNameValue)
_, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{otherNameValue})
require.NoError(t, err)
metadata.JwtProtectedHeaders = make(map[string]interface{})
metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain
Expand Down
4 changes: 2 additions & 2 deletions vdr/didx509/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ type validationFunction func(cert *x509.Certificate, key string, value string) e
// validatorMap maps PolicyKey to their corresponding validation functions for certificate attributes.
var validatorMap = map[PolicyKey]validationFunction{
SanPolicyOtherName: func(cert *x509.Certificate, key string, value string) error {
nameValue, err := findOtherNameValue(cert)
nameValues, err := findOtherNameValues(cert)
if err != nil {
return err
}
if nameValue != value {
if !slices.Contains(nameValues, value) {
return fmt.Errorf("the SAN attribute %s does not match the query", key)
}
return nil
Expand Down
20 changes: 11 additions & 9 deletions vdr/didx509/x509_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,37 +91,39 @@ var (
OtherNameType = asn1.ObjectIdentifier{2, 5, 5, 5}
)

// findOtherNameValue extracts the value of a specified OtherName type from the certificate
func findOtherNameValue(cert *x509.Certificate) (string, error) {
// findOtherNameValues extracts the value of a specified OtherName types from the certificate
func findOtherNameValues(cert *x509.Certificate) ([]string, error) {
for _, extension := range cert.Extensions {
if extension.Id.Equal(SubjectAlternativeNameType) {
return findSanValue(extension)
return findSanValues(extension)
}
}
return "", nil
return make([]string, 0), nil
}

// findSanValue extracts the SAN value from a given pkix.Extension, returning the resulting value or an error.
func findSanValue(extension pkix.Extension) (string, error) {
value := ""
// findSanValues extracts the SAN values from a given pkix.Extension, returning the resulting values or an error.
func findSanValues(extension pkix.Extension) ([]string, error) {
var values []string
err := forEachSan(extension.Value, func(data []byte) error {
var other OtherName
_, err := asn1.UnmarshalWithParams(data, &other, "tag:0")
if err != nil {
return err
}
if other.TypeID.Equal(OtherNameType) {
var value string
_, err = asn1.Unmarshal(other.Value.Bytes, &value)
if err != nil {
return err
}
values = append(values, value)
}
return nil
})
if err != nil {
return "", err
return make([]string, 0), err
}
return value, err
return values, err
}

// forEachSan processes each SAN extension in the certificate
Expand Down
86 changes: 48 additions & 38 deletions vdr/didx509/x509_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ import (
"github.com/lestrrat-go/jwx/v2/cert"
"math/big"
"net"
"slices"
"strings"
"testing"
"time"
)

// BuildCertChain generates a certificate chain, including root, intermediate, and signing certificates.
func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain *cert.Chain, rootCertificate *x509.Certificate, signingKey *rsa.PrivateKey, signingCert *x509.Certificate, err error) {
func BuildCertChain(identifiers []string) (chainCerts [4]*x509.Certificate, chain *cert.Chain, rootCertificate *x509.Certificate, signingKey *rsa.PrivateKey, signingCert *x509.Certificate, err error) {
chainCerts = [4]*x509.Certificate{}
chain = &cert.Chain{}
rootKey, rootCert, rootPem, err := buildRootCert()
Expand Down Expand Up @@ -68,7 +69,7 @@ func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain *
return chainCerts, nil, nil, nil, nil, err
}

signingKey, signingCert, signingPEM, err := buildSigningCert(identifier, intermediateL2Cert, intermediateL2Key, "32121323")
signingKey, signingCert, signingPEM, err := buildSigningCert(identifiers, intermediateL2Cert, intermediateL2Key, "32121323")
if err != nil {
return chainCerts, nil, nil, nil, nil, err
}
Expand All @@ -80,12 +81,12 @@ func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain *
return chainCerts, chain, rootCert, signingKey, signingCert, nil
}

func buildSigningCert(identifier string, intermediateL2Cert *x509.Certificate, intermediateL2Key *rsa.PrivateKey, serialNumber string) (*rsa.PrivateKey, *x509.Certificate, []byte, error) {
func buildSigningCert(identifiers []string, intermediateL2Cert *x509.Certificate, intermediateL2Key *rsa.PrivateKey, serialNumber string) (*rsa.PrivateKey, *x509.Certificate, []byte, error) {
signingKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, nil, err
}
signingTmpl, err := SigningCertTemplate(nil, identifier)
signingTmpl, err := SigningCertTemplate(nil, identifiers)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -152,7 +153,7 @@ func CertTemplate(serialNumber *big.Int) (*x509.Certificate, error) {
}

// SigningCertTemplate creates a x509.Certificate template for a signing certificate with an optional serial number.
func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certificate, error) {
func SigningCertTemplate(serialNumber *big.Int, identifiers []string) (*x509.Certificate, error) {
// generate a random serial number (a real cert authority would have some logic behind this)
if serialNumber == nil {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 8)
Expand All @@ -179,8 +180,8 @@ func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certif
tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
// Either the ExtraExtensions SubjectAlternativeNameType is set, or the Subject Alternate Name values are set,
// both don't mix
if identifier != "" {
err := setSanAlternativeName(&tmpl, identifier)
if len(identifiers) > 0 {
err := setSanAlternativeName(&tmpl, identifiers)
if err != nil {
return nil, err
}
Expand All @@ -192,27 +193,30 @@ func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certif
return &tmpl, nil
}

func setSanAlternativeName(tmpl *x509.Certificate, identifier string) error {
raw, err := toRawValue(identifier, "ia5")
if err != nil {
return err
}
otherName := OtherName{
TypeID: OtherNameType,
Value: asn1.RawValue{
Class: 2,
Tag: 0,
IsCompound: true,
Bytes: raw.FullBytes,
},
}
func setSanAlternativeName(tmpl *x509.Certificate, identifiers []string) error {
var list []asn1.RawValue

raw, err = toRawValue(otherName, "tag:0")
if err != nil {
return err
for _, identifier := range identifiers {
raw, err := toRawValue(identifier, "ia5")
if err != nil {
return err
}
otherName := OtherName{
TypeID: OtherNameType,
Value: asn1.RawValue{
Class: 2,
Tag: 0,
IsCompound: true,
Bytes: raw.FullBytes,
},
}

raw, err = toRawValue(otherName, "tag:0")
if err != nil {
return err
}
list = append(list, *raw)
}
var list []asn1.RawValue
list = append(list, *raw)
marshal, err := asn1.Marshal(list)
if err != nil {
return err
Expand Down Expand Up @@ -261,7 +265,7 @@ func CreateCert(template, parent *x509.Certificate, pub interface{}, parentPriv
func TestFindOtherNameValue(t *testing.T) {
t.Parallel()
key, certificate, _, err := buildRootCert()
_, signingCert, _, err := buildSigningCert("123", certificate, key, "4567")
_, signingCert, _, err := buildSigningCert([]string{"123", "321"}, certificate, key, "4567")
if err != nil {
t.Fatalf("failed to build root certificate: %v", err)
}
Expand All @@ -279,22 +283,28 @@ func TestFindOtherNameValue(t *testing.T) {
wantErr: false,
},
{
name: "with extensions",
name: "with extensions first",
cert: signingCert,
want: "123",
wantErr: false,
},
{
name: "with extensions second",
cert: signingCert,
want: "321",
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotName, err := findOtherNameValue(tt.cert)
gotName, err := findOtherNameValues(tt.cert)
if (err != nil) != tt.wantErr {
t.Errorf("findOtherNameValue() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("findOtherNameValues() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotName != tt.want {
t.Errorf("findOtherNameValue() = %v, want %v", gotName, tt.want)
if tt.want != "" && !slices.Contains(gotName, tt.want) {
t.Errorf("findOtherNameValues() = %v, want %v", gotName, tt.want)
}
})
}
Expand All @@ -308,7 +318,7 @@ func TestFindCertificateByHash(t *testing.T) {
}
return base64.RawURLEncoding.EncodeToString(h)
}
chainCerts, _, _, _, _, err := BuildCertChain("123")
chainCerts, _, _, _, _, err := BuildCertChain([]string{"123"})
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -409,7 +419,7 @@ func TestParseChain(t *testing.T) {
}
return &chain
}
certs, chain, _, _, _, _ := BuildCertChain("123")
certs, chain, _, _, _, _ := BuildCertChain([]string{"123"})

invalidPEM := `-----BEGIN CERTIFICATE-----
Y29ycnVwdCBjZXJ0aWZpY2F0ZQo=
Expand Down Expand Up @@ -704,7 +714,7 @@ func TestFindSanValue(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, foundErr := findSanValue(pkix.Extension{
val, foundErr := findSanValues(pkix.Extension{
Value: tt.rest,
})
if foundErr != nil {
Expand All @@ -719,15 +729,15 @@ func TestFindSanValue(t *testing.T) {
t.Errorf("forEachSan() error = %v", foundErr)
}
if foundErr.Error() != tt.wantError.Error() {
t.Errorf("findSanValue() error = %v, want: %v", foundErr, tt.wantError)
t.Errorf("findSanValues() error = %v, want: %v", foundErr, tt.wantError)
return
}
}
}
}

if val != tt.expectedValue {
t.Errorf("findSanValue() = %v, want: %v", val, tt.expectedValue)
if tt.expectedValue != "" && !slices.Contains(val, tt.expectedValue) {
t.Errorf("findSanValues() = %v, want: %v", val, tt.expectedValue)
}

})
Expand Down

0 comments on commit eded558

Please sign in to comment.