diff --git a/crypto.go b/crypto.go index 13a5cd55..403d8793 100644 --- a/crypto.go +++ b/crypto.go @@ -27,6 +27,7 @@ import ( "io" "math/big" "net" + "net/url" "time" "strings" @@ -341,7 +342,7 @@ func generateCertificateAuthorityWithKeyInternal( ) (certificate, error) { ca := certificate{} - template, err := getBaseCertTemplate(cn, nil, nil, daysValid) + template, err := getBaseCertTemplate(cn, nil, nil, nil, daysValid) if err != nil { return ca, err } @@ -360,19 +361,21 @@ func generateSelfSignedCertificate( cn string, ips []interface{}, alternateDNS []interface{}, + alternateURIs []interface{}, daysValid int, ) (certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return certificate{}, fmt.Errorf("error generating rsa key: %s", err) } - return generateSelfSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, priv) + return generateSelfSignedCertificateWithKeyInternal(cn, ips, alternateDNS, alternateURIs, daysValid, priv) } func generateSelfSignedCertificateWithPEMKey( cn string, ips []interface{}, alternateDNS []interface{}, + alternateURIs []interface{}, daysValid int, privPEM string, ) (certificate, error) { @@ -380,19 +383,20 @@ func generateSelfSignedCertificateWithPEMKey( if err != nil { return certificate{}, fmt.Errorf("parsing private key: %s", err) } - return generateSelfSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, priv) + return generateSelfSignedCertificateWithKeyInternal(cn, ips, alternateDNS, alternateURIs, daysValid, priv) } func generateSelfSignedCertificateWithKeyInternal( cn string, ips []interface{}, alternateDNS []interface{}, + alternateURIs []interface{}, daysValid int, priv crypto.PrivateKey, ) (certificate, error) { cert := certificate{} - template, err := getBaseCertTemplate(cn, ips, alternateDNS, daysValid) + template, err := getBaseCertTemplate(cn, ips, alternateDNS, alternateURIs, daysValid) if err != nil { return cert, err } @@ -406,6 +410,7 @@ func generateSignedCertificate( cn string, ips []interface{}, alternateDNS []interface{}, + alternateURIs []interface{}, daysValid int, ca certificate, ) (certificate, error) { @@ -413,13 +418,14 @@ func generateSignedCertificate( if err != nil { return certificate{}, fmt.Errorf("error generating rsa key: %s", err) } - return generateSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, ca, priv) + return generateSignedCertificateWithKeyInternal(cn, ips, alternateDNS, alternateURIs, daysValid, ca, priv) } func generateSignedCertificateWithPEMKey( cn string, ips []interface{}, alternateDNS []interface{}, + alternateURIs []interface{}, daysValid int, ca certificate, privPEM string, @@ -428,13 +434,14 @@ func generateSignedCertificateWithPEMKey( if err != nil { return certificate{}, fmt.Errorf("parsing private key: %s", err) } - return generateSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, ca, priv) + return generateSignedCertificateWithKeyInternal(cn, ips, alternateDNS, alternateURIs, daysValid, ca, priv) } func generateSignedCertificateWithKeyInternal( cn string, ips []interface{}, alternateDNS []interface{}, + alternateURIs []interface{}, daysValid int, ca certificate, priv crypto.PrivateKey, @@ -460,7 +467,7 @@ func generateSignedCertificateWithKeyInternal( ) } - template, err := getBaseCertTemplate(cn, ips, alternateDNS, daysValid) + template, err := getBaseCertTemplate(cn, ips, alternateDNS, alternateURIs, daysValid) if err != nil { return cert, err } @@ -519,6 +526,7 @@ func getBaseCertTemplate( cn string, ips []interface{}, alternateDNS []interface{}, + alternateURIs []interface{}, daysValid int, ) (*x509.Certificate, error) { ipAddresses, err := getNetIPs(ips) @@ -529,6 +537,10 @@ func getBaseCertTemplate( if err != nil { return nil, err } + uris, err := getURIs(alternateURIs) + if err != nil { + return nil, err + } serialNumberUpperBound := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberUpperBound) if err != nil { @@ -541,6 +553,7 @@ func getBaseCertTemplate( }, IPAddresses: ipAddresses, DNSNames: dnsNames, + URIs: uris, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour * 24 * time.Duration(daysValid)), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, @@ -594,6 +607,27 @@ func getAlternateDNSStrs(alternateDNS []interface{}) ([]string, error) { return alternateDNSStrs, nil } +func getURIs(uris []interface{}) ([]*url.URL, error) { + if uris == nil { + return []*url.URL{}, nil + } + var uriStr string + var ok bool + urlURIs := make([]*url.URL, len(uris)) + for i, uri := range uris { + uriStr, ok = uri.(string) + if !ok { + return nil, fmt.Errorf("error parsing uri: %v is not a string", uri) + } + u, err := url.Parse(uriStr) + if err != nil { + return nil, err + } + urlURIs[i] = u + } + return urlURIs, nil +} + func encryptAES(password string, plaintext string) (string, error) { if plaintext == "" { return "", nil diff --git a/crypto_test.go b/crypto_test.go index 449e7ffd..f1a5ceaa 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -303,6 +303,8 @@ func testGenSelfSignedCert(t *testing.T, keyAlgo *string) { ip2 = "10.0.0.2" dns1 = "bar.com" dns2 = "bat.com" + uri1 = "https://www.example.com" + uri2 = "spiffe://example.com/workload" ) var genSelfSignedCertExpr string @@ -313,7 +315,7 @@ func testGenSelfSignedCert(t *testing.T, keyAlgo *string) { } tpl := fmt.Sprintf( - `{{- $cert := %s "%s" (list "%s" "%s") (list "%s" "%s") 365 }} + `{{- $cert := %s "%s" (list "%s" "%s") (list "%s" "%s") (list "%s" "%s") 365 }} {{ $cert.Cert }}`, genSelfSignedCertExpr, cn, @@ -321,6 +323,8 @@ func testGenSelfSignedCert(t *testing.T, keyAlgo *string) { ip2, dns1, dns2, + uri1, + uri2, ) out, err := runRaw(tpl, nil) @@ -342,6 +346,8 @@ func testGenSelfSignedCert(t *testing.T, keyAlgo *string) { assert.Equal(t, ip2, cert.IPAddresses[1].String()) assert.Contains(t, cert.DNSNames, dns1) assert.Contains(t, cert.DNSNames, dns2) + assert.Equal(t, uri1, cert.URIs[0].String()) + assert.Equal(t, uri2, cert.URIs[1].String()) assert.False(t, cert.IsCA) } @@ -366,6 +372,8 @@ func testGenSignedCert(t *testing.T, caKeyAlgo, certKeyAlgo *string) { ip2 = "10.0.0.2" dns1 = "bar.com" dns2 = "bat.com" + uri1 = "https://www.example.com" + uri2 = "spiffe://example.com/workload" ) var genCAExpr, genSignedCertExpr string @@ -382,7 +390,7 @@ func testGenSignedCert(t *testing.T, caKeyAlgo, certKeyAlgo *string) { tpl := fmt.Sprintf( `{{- $ca := %s "foo" 365 }} -{{- $cert := %s "%s" (list "%s" "%s") (list "%s" "%s") 365 $ca }} +{{- $cert := %s "%s" (list "%s" "%s") (list "%s" "%s") (list "%s" "%s") 365 $ca }} {{ $cert.Cert }} `, genCAExpr, @@ -392,6 +400,8 @@ func testGenSignedCert(t *testing.T, caKeyAlgo, certKeyAlgo *string) { ip2, dns1, dns2, + uri1, + uri2, ) out, err := runRaw(tpl, nil) if err != nil { @@ -413,6 +423,8 @@ func testGenSignedCert(t *testing.T, caKeyAlgo, certKeyAlgo *string) { assert.Equal(t, ip2, cert.IPAddresses[1].String()) assert.Contains(t, cert.DNSNames, dns1) assert.Contains(t, cert.DNSNames, dns2) + assert.Equal(t, uri1, cert.URIs[0].String()) + assert.Equal(t, uri2, cert.URIs[1].String()) assert.False(t, cert.IsCA) }