diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index 89625fbb49d3..53ebc9d74834 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -8,10 +8,13 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/ocsp" "github.com/hashicorp/vault/sdk/logical" ) @@ -20,6 +23,13 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, if err := b.Setup(ctx, conf); err != nil { return nil, err } + bConf, err := b.Config(ctx, conf.StorageView) + if err != nil { + return nil, err + } + if bConf != nil { + b.updatedConfig(bConf) + } if err := b.lockThenpopulateCRLs(ctx, conf.StorageView); err != nil { return nil, err } @@ -49,7 +59,6 @@ func Backend() *backend { } b.crlUpdateMutex = &sync.RWMutex{} - return &b } @@ -57,8 +66,11 @@ type backend struct { *framework.Backend MapCertId *framework.PathMap - crls map[string]CRLInfo - crlUpdateMutex *sync.RWMutex + crls map[string]CRLInfo + crlUpdateMutex *sync.RWMutex + ocspClientMutex sync.RWMutex + ocspClient *ocsp.Client + configUpdated atomic.Bool } func (b *backend) invalidate(_ context.Context, key string) { @@ -67,9 +79,25 @@ func (b *backend) invalidate(_ context.Context, key string) { b.crlUpdateMutex.Lock() defer b.crlUpdateMutex.Unlock() b.crls = nil + case key == "config": + b.configUpdated.Store(true) } } +func (b *backend) initOCSPClient(cacheSize int) { + b.ocspClient = ocsp.New(func() hclog.Logger { + return b.Logger() + }, cacheSize) +} + +func (b *backend) updatedConfig(config *config) error { + b.ocspClientMutex.Lock() + defer b.ocspClientMutex.Unlock() + b.initOCSPClient(config.OcspCacheSize) + b.configUpdated.Store(false) + return nil +} + func (b *backend) fetchCRL(ctx context.Context, storage logical.Storage, name string, crl *CRLInfo) error { response, err := http.Get(crl.CDP.Url) if err != nil { @@ -104,6 +132,19 @@ func (b *backend) updateCRLs(ctx context.Context, req *logical.Request) error { return errs.ErrorOrNil() } +func (b *backend) storeConfig(ctx context.Context, storage logical.Storage, config *config) error { + entry, err := logical.StorageEntryJSON("config", config) + if err != nil { + return err + } + + if err := storage.Put(ctx, entry); err != nil { + return err + } + b.updatedConfig(config) + return nil +} + const backendHelp = ` The "cert" credential provider allows authentication using TLS client certificates. A client connects to Vault and uses diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 062fc156bc7a..9764cf608e42 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -1062,12 +1062,13 @@ func TestBackend_CRLs(t *testing.T) { } func testFactory(t *testing.T) logical.Backend { + storage := &logical.InmemStorage{} b, err := Factory(context.Background(), &logical.BackendConfig{ System: &logical.StaticSystemView{ DefaultLeaseTTLVal: 1000 * time.Second, MaxLeaseTTLVal: 1800 * time.Second, }, - StorageView: &logical.InmemStorage{}, + StorageView: storage, }) if err != nil { t.Fatalf("error: %s", err) @@ -1893,27 +1894,33 @@ type allowed struct { metadata_ext string // allowed metadata extensions to add to identity alias } -func testAccStepCert( - t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool, -) logicaltest.TestStep { +func testAccStepCert(t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool) logicaltest.TestStep { + return testAccStepCertWithExtraParams(t, name, cert, policies, testData, expectError, nil) +} + +func testAccStepCertWithExtraParams(t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool, extraParams map[string]interface{}) logicaltest.TestStep { + data := map[string]interface{}{ + "certificate": string(cert), + "policies": policies, + "display_name": name, + "allowed_names": testData.names, + "allowed_common_names": testData.common_names, + "allowed_dns_sans": testData.dns, + "allowed_email_sans": testData.emails, + "allowed_uri_sans": testData.uris, + "allowed_organizational_units": testData.organizational_units, + "required_extensions": testData.ext, + "allowed_metadata_extensions": testData.metadata_ext, + "lease": 1000, + } + for k, v := range extraParams { + data[k] = v + } return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "certs/" + name, ErrorOk: expectError, - Data: map[string]interface{}{ - "certificate": string(cert), - "policies": policies, - "display_name": name, - "allowed_names": testData.names, - "allowed_common_names": testData.common_names, - "allowed_dns_sans": testData.dns, - "allowed_email_sans": testData.emails, - "allowed_uri_sans": testData.uris, - "allowed_organizational_units": testData.organizational_units, - "required_extensions": testData.ext, - "allowed_metadata_extensions": testData.metadata_ext, - "lease": 1000, - }, + Data: data, Check: func(resp *logical.Response) error { if resp == nil && expectError { return fmt.Errorf("expected error but received nil") diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index 00e103b51a56..065829da9a50 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -7,7 +7,8 @@ import ( "strings" "time" - sockaddr "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" @@ -47,7 +48,32 @@ Must be x509 PEM encoded.`, EditType: "file", }, }, - + "ocsp_enabled": { + Type: framework.TypeBool, + Description: `Whether to attempt OCSP verification of certificates at login`, + }, + "ocsp_ca_certificates": { + Type: framework.TypeString, + Description: `Any additional CA certificates needed to communicate with OCSP servers`, + DisplayAttrs: &framework.DisplayAttributes{ + EditType: "file", + }, + }, + "ocsp_servers_override": { + Type: framework.TypeCommaStringSlice, + Description: `A comma-separated list of OCSP server addresses. If unset, the OCSP server is determined +from the AuthorityInformationAccess extension on the certificate being inspected.`, + }, + "ocsp_fail_open": { + Type: framework.TypeBool, + Default: false, + Description: "If set to true, if an OCSP revocation cannot be made successfully, login will proceed rather than failing. If false, failing to get an OCSP status fails the request.", + }, + "ocsp_query_all_servers": { + Type: framework.TypeBool, + Default: false, + Description: "If set to true, rather than accepting the first successful OCSP response, query all servers and consider the certificate valid only if all servers agree.", + }, "allowed_names": { Type: framework.TypeCommaStringSlice, Description: `A comma-separated list of names. @@ -294,6 +320,21 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr if certificateRaw, ok := d.GetOk("certificate"); ok { cert.Certificate = certificateRaw.(string) } + if ocspCertificatesRaw, ok := d.GetOk("ocsp_ca_certificates"); ok { + cert.OcspCaCertificates = ocspCertificatesRaw.(string) + } + if ocspEnabledRaw, ok := d.GetOk("ocsp_enabled"); ok { + cert.OcspEnabled = ocspEnabledRaw.(bool) + } + if ocspServerOverrides, ok := d.GetOk("ocsp_servers_override"); ok { + cert.OcspServersOverride = ocspServerOverrides.([]string) + } + if ocspFailOpen, ok := d.GetOk("ocsp_fail_open"); ok { + cert.OcspFailOpen = ocspFailOpen.(bool) + } + if ocspQueryAll, ok := d.GetOk("ocsp_query_all_servers"); ok { + cert.OcspQueryAllServers = ocspQueryAll.(bool) + } if displayNameRaw, ok := d.GetOk("display_name"); ok { cert.DisplayName = displayNameRaw.(string) } @@ -399,7 +440,7 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr } } if !clientAuth { - return logical.ErrorResponse("non-CA certificates should have TLS client authentication set as an extended key usage"), nil + return logical.ErrorResponse("nonCA certificates should have TLS client authentication set as an extended key usage"), nil } } @@ -438,6 +479,12 @@ type CertEntry struct { RequiredExtensions []string AllowedMetadataExtensions []string BoundCIDRs []*sockaddr.SockAddrMarshaler + + OcspCaCertificates string + OcspEnabled bool + OcspServersOverride []string + OcspFailOpen bool + OcspQueryAllServers bool } const pathCertHelpSyn = ` @@ -449,6 +496,7 @@ This endpoint allows you to create, read, update, and delete trusted certificate that are allowed to authenticate. Deleting a certificate will not revoke auth for prior authenticated connections. -To do this, do a revoke on "login". If you don't need to revoke login immediately, +To do this, do a revoke on "login". If you don'log need to revoke login immediately, then the next renew will cause the lease to expire. + ` diff --git a/builtin/credential/cert/path_config.go b/builtin/credential/cert/path_config.go index 9cc17f3a6aaf..c08992af15c4 100644 --- a/builtin/credential/cert/path_config.go +++ b/builtin/credential/cert/path_config.go @@ -8,6 +8,8 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +const maxCacheSize = 100000 + func pathConfig(b *backend) *framework.Path { return &framework.Path{ Pattern: "config", @@ -22,6 +24,11 @@ func pathConfig(b *backend) *framework.Path { Default: false, Description: `If set, metadata of the certificate including the metadata corresponding to allowed_metadata_extensions will be stored in the alias. Defaults to false.`, }, + "ocsp_cache_size": { + Type: framework.TypeInt, + Default: 100, + Description: `The size of the in memory OCSP response cache, shared by all configured certs`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -32,18 +39,25 @@ func pathConfig(b *backend) *framework.Path { } func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - disableBinding := data.Get("disable_binding").(bool) - enableIdentityAliasMetadata := data.Get("enable_identity_alias_metadata").(bool) - - entry, err := logical.StorageEntryJSON("config", config{ - DisableBinding: disableBinding, - EnableIdentityAliasMetadata: enableIdentityAliasMetadata, - }) + config, err := b.Config(ctx, req.Storage) if err != nil { return nil, err } - if err := req.Storage.Put(ctx, entry); err != nil { + if disableBindingRaw, ok := data.GetOk("disable_binding"); ok { + config.DisableBinding = disableBindingRaw.(bool) + } + if enableIdentityAliasMetadataRaw, ok := data.GetOk("enable_identity_alias_metadata"); ok { + config.EnableIdentityAliasMetadata = enableIdentityAliasMetadataRaw.(bool) + } + if cacheSizeRaw, ok := data.GetOk("ocsp_cache_size"); ok { + cacheSize := cacheSizeRaw.(int) + if cacheSize < 2 || cacheSize > maxCacheSize { + return logical.ErrorResponse("invalid cache size, must be >= 2 and <= %d", maxCacheSize), nil + } + config.OcspCacheSize = cacheSize + } + if err := b.storeConfig(ctx, req.Storage, config); err != nil { return nil, err } return nil, nil @@ -58,6 +72,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f data := map[string]interface{}{ "disable_binding": cfg.DisableBinding, "enable_identity_alias_metadata": cfg.EnableIdentityAliasMetadata, + "ocsp_cache_size": cfg.OcspCacheSize, } return &logical.Response{ @@ -85,4 +100,5 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error type config struct { DisableBinding bool `json:"disable_binding"` EnableIdentityAliasMetadata bool `json:"enable_identity_alias_metadata"` + OcspCacheSize int `json:"ocsp_cache_size"` } diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 11e63d75eaad..36144791b5a8 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -12,6 +12,8 @@ import ( "fmt" "strings" + "github.com/hashicorp/vault/sdk/helper/ocsp" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/policyutil" @@ -81,6 +83,9 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra if err != nil { return nil, err } + if b.configUpdated.Load() { + b.updatedConfig(config) + } if b.crls == nil { // Probably invalidated due to replication, but we need these to proceed @@ -161,6 +166,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f if err != nil { return nil, err } + if b.configUpdated.Load() { + b.updatedConfig(config) + } if b.crls == nil { if err := b.populateCRLs(ctx, req.Storage); err != nil { @@ -237,8 +245,8 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d certName = d.Get("name").(string) } - // Load the trusted certificates - roots, trusted, trustedNonCAs := b.loadTrustedCerts(ctx, req.Storage, certName) + // Load the trusted certificates and other details + roots, trusted, trustedNonCAs, verifyConf := b.loadTrustedCerts(ctx, req.Storage, certName) // Get the list of full chains matching the connection and validates the // certificate itself @@ -247,6 +255,11 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d return nil, nil, err } + var extraCas []*x509.Certificate + for _, t := range trusted { + extraCas = append(extraCas, t.Certificates...) + } + // If trustedNonCAs is not empty it means that client had registered a non-CA cert // with the backend. if len(trustedNonCAs) != 0 { @@ -254,9 +267,14 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d tCert := trustedNonCA.Certificates[0] // Check for client cert being explicitly listed in the config (and matching other constraints) if tCert.SerialNumber.Cmp(clientCert.SerialNumber) == 0 && - bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) && - b.matchesConstraints(clientCert, trustedNonCA.Certificates, trustedNonCA) { - return trustedNonCA, nil, nil + bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) { + matches, err := b.matchesConstraints(ctx, clientCert, trustedNonCA.Certificates, trustedNonCA, verifyConf) + if err != nil { + return nil, nil, err + } + if matches { + return trustedNonCA, nil, nil + } } } } @@ -273,10 +291,15 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d for _, tCert := range trust.Certificates { // For each certificate in the entry for _, chain := range trustedChains { // For each root chain that we matched for _, cCert := range chain { // For each cert in the matched chain - if tCert.Equal(cCert) && // ParsedCert intersects with matched chain - b.matchesConstraints(clientCert, chain, trust) { // validate client cert + matched chain against the config - // Add the match to the list - matches = append(matches, trust) + if tCert.Equal(cCert) { // ParsedCert intersects with matched chain + match, err := b.matchesConstraints(ctx, clientCert, chain, trust, verifyConf) // validate client cert + matched chain against the config + if err != nil { + return nil, nil, err + } + if match { + // Add the match to the list + matches = append(matches, trust) + } } } } @@ -292,8 +315,10 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d return matches[0], nil, nil } -func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool { - return !b.checkForChainInCRLs(trustedChain) && +func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certificate, trustedChain []*x509.Certificate, + config *ParsedCert, conf *ocsp.VerifyConfig, +) (bool, error) { + soFar := !b.checkForChainInCRLs(trustedChain) && b.matchesNames(clientCert, config) && b.matchesCommonName(clientCert, config) && b.matchesDNSSANs(clientCert, config) && @@ -301,6 +326,14 @@ func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain b.matchesURISANs(clientCert, config) && b.matchesOrganizationalUnits(clientCert, config) && b.matchesCertificateExtensions(clientCert, config) + if config.Entry.OcspEnabled { + ocspGood, err := b.checkForCertInOCSP(ctx, clientCert, trustedChain, conf) + if err != nil { + return false, err + } + soFar = soFar && ocspGood + } + return soFar, nil } // matchesNames verifies that the certificate matches at least one configured @@ -447,7 +480,7 @@ func (b *backend) matchesCertificateExtensions(clientCert *x509.Certificate, con asn1.Unmarshal(ext.Value, &parsedValue) clientExtMap[ext.Id.String()] = parsedValue } - // If any of the required extensions don't match the constraint fails + // If any of the required extensions don'log match the constraint fails for _, requiredExt := range config.Entry.RequiredExtensions { reqExt := strings.SplitN(requiredExt, ":", 2) clientExtValue, clientExtValueOk := clientExtMap[reqExt[0]] @@ -491,7 +524,7 @@ func (b *backend) certificateExtensionsMetadata(clientCert *x509.Certificate, co } // loadTrustedCerts is used to load all the trusted certificates from the backend -func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) { +func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert, conf *ocsp.VerifyConfig) { pool = x509.NewCertPool() trusted = make([]*ParsedCert, 0) trustedNonCAs = make([]*ParsedCert, 0) @@ -508,6 +541,7 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, } } + conf = &ocsp.VerifyConfig{} for _, name := range names { entry, err := b.Cert(ctx, storage, strings.TrimPrefix(name, "cert/")) if err != nil { @@ -515,7 +549,7 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, continue } if entry == nil { - // This could happen when the certName was provided and the cert doesn't exist, + // This could happen when the certName was provided and the cert doesn'log exist, // or just if between the LIST and the GET the cert was deleted. continue } @@ -525,6 +559,8 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, b.Logger().Error("failed to parse certificate", "name", name) continue } + parsed = append(parsed, parsePEM([]byte(entry.OcspCaCertificates))...) + if !parsed[0].IsCA { trustedNonCAs = append(trustedNonCAs, &ParsedCert{ Entry: entry, @@ -541,10 +577,33 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, Certificates: parsed, }) } + if entry.OcspEnabled { + conf.OcspEnabled = true + conf.OcspServersOverride = append(conf.OcspServersOverride, entry.OcspServersOverride...) + if entry.OcspFailOpen { + conf.OcspFailureMode = ocsp.FailOpenTrue + } else { + conf.OcspFailureMode = ocsp.FailOpenFalse + } + conf.QueryAllServers = conf.QueryAllServers || entry.OcspQueryAllServers + } } return } +func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certificate, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) { + if !conf.OcspEnabled || len(chain) < 2 { + return true, nil + } + b.ocspClientMutex.RLock() + defer b.ocspClientMutex.RUnlock() + err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[1], conf) + if err != nil { + return false, nil + } + return true, nil +} + func (b *backend) checkForChainInCRLs(chain []*x509.Certificate) bool { badChain := false for _, cert := range chain { diff --git a/builtin/credential/cert/path_login_test.go b/builtin/credential/cert/path_login_test.go index a01ec981663f..f69444270f39 100644 --- a/builtin/credential/cert/path_login_test.go +++ b/builtin/credential/cert/path_login_test.go @@ -1,24 +1,61 @@ package cert import ( + "context" "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "fmt" "io/ioutil" "math/big" mathrand "math/rand" "net" + "net/http" "os" "path/filepath" "strings" "testing" "time" + "github.com/hashicorp/vault/sdk/helper/certutil" + + "golang.org/x/crypto/ocsp" + logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical" "github.com/hashicorp/vault/sdk/logical" ) +var ocspPort int + +var source InMemorySource + +type testLogger struct{} + +func (t *testLogger) Log(args ...any) { + fmt.Printf("%v", args) +} + +func TestMain(m *testing.M) { + source = make(InMemorySource) + + listener, err := net.Listen("tcp", ":0") + if err != nil { + return + } + + ocspPort = listener.Addr().(*net.TCPAddr).Port + srv := &http.Server{ + Addr: "localhost:0", + Handler: NewResponder(&testLogger{}, source, nil), + } + go func() { + srv.Serve(listener) + }() + defer srv.Shutdown(context.Background()) + m.Run() +} + func TestCert_RoleResolve(t *testing.T) { certTemplate := &x509.Certificate{ Subject: pkix.Name{ @@ -159,6 +196,34 @@ func testAccStepResolveRoleExpectRoleResolutionToFail(t *testing.T, connState tl } } +func testAccStepResolveRoleOCSPFail(t *testing.T, connState tls.ConnectionState, certName string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ResolveRoleOperation, + Path: "login", + Unauthenticated: true, + ConnState: &connState, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if resp == nil || !resp.IsError() { + t.Fatalf("Response was not an error: resp:%#v", resp) + } + + errString, ok := resp.Data["error"].(string) + if !ok { + t.Fatal("Error not part of response.") + } + + if !strings.Contains(errString, "no chain matching") { + t.Fatalf("Error was not due to OCSP failure. Error: %s", errString) + } + return nil + }, + Data: map[string]interface{}{ + "name": certName, + }, + } +} + func TestCert_RoleResolve_RoleDoesNotExist(t *testing.T) { certTemplate := &x509.Certificate{ Subject: pkix.Name{ @@ -197,3 +262,97 @@ func TestCert_RoleResolve_RoleDoesNotExist(t *testing.T) { }, }) } + +func TestCert_RoleResolveOCSP(t *testing.T) { + cases := []struct { + name string + failOpen bool + certStatus int + errExpected bool + }{ + {"failFalseGoodCert", false, ocsp.Good, false}, + {"failFalseRevokedCert", false, ocsp.Revoked, true}, + {"failFalseUnknownCert", false, ocsp.Unknown, true}, + {"failTrueGoodCert", true, ocsp.Good, false}, + {"failTrueRevokedCert", true, ocsp.Revoked, true}, + {"failTrueUnknownCert", true, ocsp.Unknown, false}, + } + certTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "example.com", + }, + DNSNames: []string{"example.com"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + OCSPServer: []string{fmt.Sprintf("http://localhost:%d", ocspPort)}, + } + tempDir, connState, err := generateTestCertAndConnState(t, certTemplate) + if tempDir != "" { + defer os.RemoveAll(tempDir) + } + if err != nil { + t.Fatalf("error testing connection state: %v", err) + } + ca, err := ioutil.ReadFile(filepath.Join(tempDir, "ca_cert.pem")) + if err != nil { + t.Fatalf("err: %v", err) + } + + issuer := parsePEM(ca) + pkf, err := ioutil.ReadFile(filepath.Join(tempDir, "ca_key.pem")) + if err != nil { + t.Fatalf("err: %v", err) + } + pk, err := certutil.ParsePEMBundle(string(pkf)) + if err != nil { + t.Fatalf("err: %v", err) + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + resp, err := ocsp.CreateResponse(issuer[0], issuer[0], ocsp.Response{ + Status: c.certStatus, + SerialNumber: certTemplate.SerialNumber, + ProducedAt: time.Now(), + ThisUpdate: time.Now(), + NextUpdate: time.Now().Add(time.Hour), + }, pk.PrivateKey) + if err != nil { + t.Fatal(err) + } + source[certTemplate.SerialNumber.String()] = resp + + b := testFactory(t) + b.(*backend).ocspClient.ClearCache() + var resolveStep logicaltest.TestStep + var loginStep logicaltest.TestStep + if c.errExpected { + loginStep = testAccStepLoginWithNameInvalid(t, connState, "web") + resolveStep = testAccStepResolveRoleOCSPFail(t, connState, "web") + } else { + loginStep = testAccStepLoginWithName(t, connState, "web") + resolveStep = testAccStepResolveRoleWithName(t, connState, "web") + } + logicaltest.Test(t, logicaltest.TestCase{ + CredentialBackend: b, + Steps: []logicaltest.TestStep{ + testAccStepCertWithExtraParams(t, "web", ca, "foo", allowed{dns: "example.com"}, false, + map[string]interface{}{"ocsp_enabled": true, "ocsp_fail_open": c.failOpen}), + loginStep, + resolveStep, + }, + }) + }) + } +} + +func serialFromBigInt(serial *big.Int) string { + return strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":")) +} diff --git a/builtin/credential/cert/test_responder.go b/builtin/credential/cert/test_responder.go new file mode 100644 index 000000000000..1c7c75b2ff33 --- /dev/null +++ b/builtin/credential/cert/test_responder.go @@ -0,0 +1,301 @@ +// Package ocsp implements an OCSP responder based on a generic storage backend. +// It provides a couple of sample implementations. +// Because OCSP responders handle high query volumes, we have to be careful +// about how much logging we do. Error-level logs are reserved for problems +// internal to the server, that can be fixed by an administrator. Any type of +// incorrect input from a user should be logged and Info or below. For things +// that are logged on every request, Debug is the appropriate level. +// +// From https://github.com/cloudflare/cfssl/blob/master/ocsp/responder.go + +package cert + +import ( + "crypto" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "golang.org/x/crypto/ocsp" +) + +var ( + malformedRequestErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x01} + internalErrorErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x02} + tryLaterErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x03} + sigRequredErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x05} + unauthorizedErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x06} + + // ErrNotFound indicates the request OCSP response was not found. It is used to + // indicate that the responder should reply with unauthorizedErrorResponse. + ErrNotFound = errors.New("Request OCSP Response not found") +) + +// Source represents the logical source of OCSP responses, i.e., +// the logic that actually chooses a response based on a request. In +// order to create an actual responder, wrap one of these in a Responder +// object and pass it to http.Handle. By default the Responder will set +// the headers Cache-Control to "max-age=(response.NextUpdate-now), public, no-transform, must-revalidate", +// Last-Modified to response.ThisUpdate, Expires to response.NextUpdate, +// ETag to the SHA256 hash of the response, and Content-Type to +// application/ocsp-response. If you want to override these headers, +// or set extra headers, your source should return a http.Header +// with the headers you wish to set. If you don'log want to set any +// extra headers you may return nil instead. +type Source interface { + Response(*ocsp.Request) ([]byte, http.Header, error) +} + +// An InMemorySource is a map from serialNumber -> der(response) +type InMemorySource map[string][]byte + +// Response looks up an OCSP response to provide for a given request. +// InMemorySource looks up a response purely based on serial number, +// without regard to what issuer the request is asking for. +func (src InMemorySource) Response(request *ocsp.Request) ([]byte, http.Header, error) { + response, present := src[request.SerialNumber.String()] + if !present { + return nil, nil, ErrNotFound + } + return response, nil, nil +} + +// Stats is a basic interface that allows users to record information +// about returned responses +type Stats interface { + ResponseStatus(ocsp.ResponseStatus) +} + +type logger interface { + Log(args ...any) +} + +// A Responder object provides the HTTP logic to expose a +// Source of OCSP responses. +type Responder struct { + log logger + Source Source + stats Stats +} + +// NewResponder instantiates a Responder with the give Source. +func NewResponder(t logger, source Source, stats Stats) *Responder { + return &Responder{ + Source: source, + stats: stats, + log: t, + } +} + +func overrideHeaders(response http.ResponseWriter, headers http.Header) { + for k, v := range headers { + if len(v) == 1 { + response.Header().Set(k, v[0]) + } else if len(v) > 1 { + response.Header().Del(k) + for _, e := range v { + response.Header().Add(k, e) + } + } + } +} + +// hashToString contains mappings for the only hash functions +// x/crypto/ocsp supports +var hashToString = map[crypto.Hash]string{ + crypto.SHA1: "SHA1", + crypto.SHA256: "SHA256", + crypto.SHA384: "SHA384", + crypto.SHA512: "SHA512", +} + +// A Responder can process both GET and POST requests. The mapping +// from an OCSP request to an OCSP response is done by the Source; +// the Responder simply decodes the request, and passes back whatever +// response is provided by the source. +// Note: The caller must use http.StripPrefix to strip any path components +// (including '/') on GET requests. +// Do not use this responder in conjunction with http.NewServeMux, because the +// default handler will try to canonicalize path components by changing any +// strings of repeated '/' into a single '/', which will break the base64 +// encoding. +func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Request) { + // By default we set a 'max-age=0, no-cache' Cache-Control header, this + // is only returned to the client if a valid authorized OCSP response + // is not found or an error is returned. If a response if found the header + // will be altered to contain the proper max-age and modifiers. + response.Header().Add("Cache-Control", "max-age=0, no-cache") + // Read response from request + var requestBody []byte + var err error + switch request.Method { + case "GET": + base64Request, err := url.QueryUnescape(request.URL.Path) + if err != nil { + rs.log.Log("Error decoding URL:", request.URL.Path) + response.WriteHeader(http.StatusBadRequest) + return + } + // url.QueryUnescape not only unescapes %2B escaping, but it additionally + // turns the resulting '+' into a space, which makes base64 decoding fail. + // So we go back afterwards and turn ' ' back into '+'. This means we + // accept some malformed input that includes ' ' or %20, but that's fine. + base64RequestBytes := []byte(base64Request) + for i := range base64RequestBytes { + if base64RequestBytes[i] == ' ' { + base64RequestBytes[i] = '+' + } + } + // In certain situations a UA may construct a request that has a double + // slash between the host name and the base64 request body due to naively + // constructing the request URL. In that case strip the leading slash + // so that we can still decode the request. + if len(base64RequestBytes) > 0 && base64RequestBytes[0] == '/' { + base64RequestBytes = base64RequestBytes[1:] + } + requestBody, err = base64.StdEncoding.DecodeString(string(base64RequestBytes)) + if err != nil { + rs.log.Log("Error decoding base64 from URL", string(base64RequestBytes)) + response.WriteHeader(http.StatusBadRequest) + return + } + case "POST": + requestBody, err = ioutil.ReadAll(request.Body) + if err != nil { + rs.log.Log("Problem reading body of POST", err) + response.WriteHeader(http.StatusBadRequest) + return + } + default: + response.WriteHeader(http.StatusMethodNotAllowed) + return + } + b64Body := base64.StdEncoding.EncodeToString(requestBody) + rs.log.Log("Received OCSP request", b64Body) + + // All responses after this point will be OCSP. + // We could check for the content type of the request, but that + // seems unnecessariliy restrictive. + response.Header().Add("Content-Type", "application/ocsp-response") + + // Parse response as an OCSP request + // XXX: This fails if the request contains the nonce extension. + // We don'log intend to support nonces anyway, but maybe we + // should return unauthorizedRequest instead of malformed. + ocspRequest, err := ocsp.ParseRequest(requestBody) + if err != nil { + rs.log.Log("Error decoding request body", b64Body) + response.WriteHeader(http.StatusBadRequest) + response.Write(malformedRequestErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.Malformed) + } + return + } + + // Look up OCSP response from source + ocspResponse, headers, err := rs.Source.Response(ocspRequest) + if err != nil { + if err == ErrNotFound { + rs.log.Log("No response found for request: serial %x, request body %s", + ocspRequest.SerialNumber, b64Body) + response.Write(unauthorizedErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.Unauthorized) + } + return + } + rs.log.Log("Error retrieving response for request: serial %x, request body %s, error", + ocspRequest.SerialNumber, b64Body, err) + response.WriteHeader(http.StatusInternalServerError) + response.Write(internalErrorErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.InternalError) + } + return + } + + parsedResponse, err := ocsp.ParseResponse(ocspResponse, nil) + if err != nil { + rs.log.Log("Error parsing response for serial %x", + ocspRequest.SerialNumber, err) + response.Write(internalErrorErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.InternalError) + } + return + } + + // Write OCSP response to response + response.Header().Add("Last-Modified", parsedResponse.ThisUpdate.Format(time.RFC1123)) + response.Header().Add("Expires", parsedResponse.NextUpdate.Format(time.RFC1123)) + now := time.Now() + maxAge := 0 + if now.Before(parsedResponse.NextUpdate) { + maxAge = int(parsedResponse.NextUpdate.Sub(now) / time.Second) + } else { + // TODO(#530): we want max-age=0 but this is technically an authorized OCSP response + // (despite being stale) and 5019 forbids attaching no-cache + maxAge = 0 + } + response.Header().Set( + "Cache-Control", + fmt.Sprintf( + "max-age=%d, public, no-transform, must-revalidate", + maxAge, + ), + ) + responseHash := sha256.Sum256(ocspResponse) + response.Header().Add("ETag", fmt.Sprintf("\"%X\"", responseHash)) + + if headers != nil { + overrideHeaders(response, headers) + } + + // RFC 7232 says that a 304 response must contain the above + // headers if they would also be sent for a 200 for the same + // request, so we have to wait until here to do this + if etag := request.Header.Get("If-None-Match"); etag != "" { + if etag == fmt.Sprintf("\"%X\"", responseHash) { + response.WriteHeader(http.StatusNotModified) + return + } + } + response.WriteHeader(http.StatusOK) + response.Write(ocspResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.Success) + } +} + +/* +Copyright (c) 2014 CloudFlare Inc. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ diff --git a/builtin/logical/pki/ocsp_test.go b/builtin/logical/pki/ocsp_test.go index c72ff0305663..772a1859e6a8 100644 --- a/builtin/logical/pki/ocsp_test.go +++ b/builtin/logical/pki/ocsp_test.go @@ -41,7 +41,7 @@ func TestOcsp_Disabled(t *testing.T) { "ocsp_disable": "true", }) requireSuccessNilResponse(t, resp, err) - resp, err = sendOcspRequest(t, b, s, localTT.reqType, testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) + resp, err = SendOcspRequest(t, b, s, localTT.reqType, testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) require.NoError(t, err) requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 401, resp.Data["http_status_code"]) @@ -63,7 +63,7 @@ func TestOcsp_UnknownIssuerWithNoDefault(t *testing.T) { // Create another completely empty mount so the created issuer/certificate above is unknown b, s := createBackendWithStorage(t) - resp, err := sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) + resp, err := SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) require.NoError(t, err) requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 401, resp.Data["http_status_code"]) @@ -85,7 +85,7 @@ func TestOcsp_WrongIssuerInRequest(t *testing.T) { }) requireSuccessNonNilResponse(t, resp, err, "revoke") - resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer2, crypto.SHA1) + resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer2, crypto.SHA1) require.NoError(t, err) requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 200, resp.Data["http_status_code"]) @@ -167,7 +167,7 @@ func TestOcsp_InvalidIssuerIdInRevocationEntry(t *testing.T) { require.NoError(t, err, "failed writing out new revocation entry: %v", revEntry) // Send the request - resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) + resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) require.NoError(t, err) requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 200, resp.Data["http_status_code"]) @@ -220,7 +220,7 @@ func TestOcsp_UnknownIssuerIdWithDefaultHavingOcspUsageRemoved(t *testing.T) { requireSuccessNonNilResponse(t, resp, err, "failed resetting usage flags on issuer2") // Send the request - resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) + resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) require.NoError(t, err) requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 401, resp.Data["http_status_code"]) @@ -257,7 +257,7 @@ func TestOcsp_RevokedCertHasIssuerWithoutOcspUsage(t *testing.T) { require.False(t, usages.HasUsage(OCSPSigningUsage)) // Request an OCSP request from it, we should get an Unauthorized response back - resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) + resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) requireSuccessNonNilResponse(t, resp, err, "ocsp get request") requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 401, resp.Data["http_status_code"]) @@ -296,7 +296,7 @@ func TestOcsp_RevokedCertHasIssuerWithoutAKey(t *testing.T) { requireSuccessNonNilResponse(t, resp, err, "failed deleting key") // Request an OCSP request from it, we should get an Unauthorized response back - resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) + resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) requireSuccessNonNilResponse(t, resp, err, "ocsp get request") requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 401, resp.Data["http_status_code"]) @@ -342,7 +342,7 @@ func TestOcsp_MultipleMatchingIssuersOneWithoutSigningUsage(t *testing.T) { require.False(t, usages.HasUsage(OCSPSigningUsage)) // Request an OCSP request from it, we should get a Good response back, from the rotated cert - resp, err = sendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) + resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) requireSuccessNonNilResponse(t, resp, err, "ocsp get request") requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 200, resp.Data["http_status_code"]) @@ -410,7 +410,7 @@ func runOcspRequestTest(t *testing.T, requestType string, caKeyType string, caKe b, s, testEnv := setupOcspEnvWithCaKeyConfig(t, caKeyType, caKeyBits, caKeySigBits) // Non-revoked cert - resp, err := sendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash) + resp, err := SendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash) requireSuccessNonNilResponse(t, resp, err, "ocsp get request") requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 200, resp.Data["http_status_code"]) @@ -435,7 +435,7 @@ func runOcspRequestTest(t *testing.T, requestType string, caKeyType string, caKe }) requireSuccessNonNilResponse(t, resp, err, "revoke") - resp, err = sendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash) + resp, err = SendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer1, testEnv.issuer1, requestHash) requireSuccessNonNilResponse(t, resp, err, "ocsp get request with revoked") requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 200, resp.Data["http_status_code"]) @@ -455,7 +455,7 @@ func runOcspRequestTest(t *testing.T, requestType string, caKeyType string, caKe requireOcspResponseSignedBy(t, ocspResp, testEnv.issuer1) // Request status for our second issuer - resp, err = sendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer2, testEnv.issuer2, requestHash) + resp, err = SendOcspRequest(t, b, s, requestType, testEnv.leafCertIssuer2, testEnv.issuer2, requestHash) requireSuccessNonNilResponse(t, resp, err, "ocsp get request") requireFieldsSetInResp(t, resp, "http_content_type", "http_status_code", "http_raw_body") require.Equal(t, 200, resp.Data["http_status_code"]) @@ -569,7 +569,7 @@ func setupOcspEnvWithCaKeyConfig(t *testing.T, keyType string, caKeyBits int, ca return b, s, testEnv } -func sendOcspRequest(t *testing.T, b *backend, s logical.Storage, getOrPost string, cert, issuer *x509.Certificate, requestHash crypto.Hash) (*logical.Response, error) { +func SendOcspRequest(t *testing.T, b *backend, s logical.Storage, getOrPost string, cert, issuer *x509.Certificate, requestHash crypto.Hash) (*logical.Response, error) { ocspRequest := generateRequest(t, requestHash, cert, issuer) switch strings.ToLower(getOrPost) { @@ -578,7 +578,7 @@ func sendOcspRequest(t *testing.T, b *backend, s logical.Storage, getOrPost stri case "post": return sendOcspPostRequest(b, s, ocspRequest) default: - t.Fatalf("unsupported value for sendOcspRequest getOrPost arg: %s", getOrPost) + t.Fatalf("unsupported value for SendOcspRequest getOrPost arg: %s", getOrPost) } return nil, nil } diff --git a/changelog/17093.txt b/changelog/17093.txt new file mode 100644 index 000000000000..a51f3de8ff4b --- /dev/null +++ b/changelog/17093.txt @@ -0,0 +1,3 @@ +```release-note:improvement +auth/cert: Add configurable support for validating client certs with OCSP. +``` \ No newline at end of file diff --git a/sdk/go.mod b/sdk/go.mod index 13351351c8d0..eb27efc1a8ce 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -17,6 +17,7 @@ require ( github.com/hashicorp/go-kms-wrapping/entropy/v2 v2.0.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-plugin v1.4.5 + github.com/hashicorp/go-retryablehttp v0.5.3 github.com/hashicorp/go-secure-stdlib/base62 v0.1.1 github.com/hashicorp/go-secure-stdlib/mlock v0.1.1 github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 @@ -45,6 +46,7 @@ require ( github.com/fatih/color v1.7.0 // indirect github.com/frankban/quicktest v1.10.0 // indirect github.com/go-asn1-ber/asn1-ber v1.3.1 // indirect + github.com/hashicorp/go-cleanhttp v0.5.0 // indirect github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.6 // indirect diff --git a/sdk/go.sum b/sdk/go.sum index 9107c46e1be6..2c9a7fd11f0c 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -87,6 +87,7 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.0 h1:wvCrVc9TjDls6+YGAF2hAifE1E5U1+b4tH6KdvN3Gig= github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-hclog v0.16.2 h1:K4ev2ib4LdQETX5cSZBG0DVLk1jwGqSPXBjdah3veNs= github.com/hashicorp/go-hclog v0.16.2/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= @@ -100,6 +101,7 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-plugin v1.4.5 h1:oTE/oQR4eghggRg8VY7PAz3dr++VwDNBGCcOfIvHpBo= github.com/hashicorp/go-plugin v1.4.5/go.mod h1:viDMjcLJuDui6pXb8U4HVfb8AamCWhHGUjr2IrTF67s= +github.com/hashicorp/go-retryablehttp v0.5.3 h1:QlWt0KvWT0lq8MFppF9tsJGF+ynG7ztc2KIPhzRGk7s= github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= github.com/hashicorp/go-secure-stdlib/base62 v0.1.1 h1:6KMBnfEv0/kLAz0O76sliN5mXbCDcLfs2kP7ssP7+DQ= github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go new file mode 100644 index 000000000000..e54fdeface46 --- /dev/null +++ b/sdk/helper/ocsp/client.go @@ -0,0 +1,1059 @@ +// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + +package ocsp + +import ( + "bytes" + "context" + "crypto" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "errors" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" + lru "github.com/hashicorp/golang-lru" + "github.com/hashicorp/vault/sdk/helper/certutil" + "golang.org/x/crypto/ocsp" +) + +// FailOpenMode is OCSP fail open mode. FailOpenTrue by default and may +// set to ocspModeFailClosed for fail closed mode +type FailOpenMode uint32 + +type requestFunc func(method, urlStr string, body interface{}) (*retryablehttp.Request, error) + +type clientInterface interface { + Do(req *retryablehttp.Request) (*http.Response, error) +} + +const ( + httpHeaderContentType = "Content-Type" + httpHeaderAccept = "accept" + httpHeaderContentLength = "Content-Length" + httpHeaderHost = "Host" + ocspRequestContentType = "application/ocsp-request" + ocspResponseContentType = "application/ocsp-response" +) + +const ( + ocspFailOpenNotSet FailOpenMode = iota + // FailOpenTrue represents OCSP fail open mode. + FailOpenTrue + // FailOpenFalse represents OCSP fail closed mode. + FailOpenFalse +) + +const ( + ocspModeFailOpen = "FAIL_OPEN" + ocspModeFailClosed = "FAIL_CLOSED" + ocspModeInsecure = "INSECURE" +) + +const ocspCacheKey = "ocsp_cache" + +const ( + // defaultOCSPResponderTimeout is the total timeout for OCSP responder. + defaultOCSPResponderTimeout = 10 * time.Second +) + +const ( + // cacheExpire specifies cache data expiration time in seconds. + cacheExpire = float64(24 * 60 * 60) +) + +type ocspCachedResponse struct { + time float64 + producedAt float64 + thisUpdate float64 + nextUpdate float64 + status ocspStatusCode +} + +type Client struct { + // caRoot includes the CA certificates. + caRoot map[string]*x509.Certificate + // certPool includes the CA certificates. + certPool *x509.CertPool + ocspResponseCache *lru.TwoQueueCache + ocspResponseCacheLock sync.RWMutex + // cacheUpdated is true if the memory cache is updated + cacheUpdated bool + logFactory func() hclog.Logger +} + +type ocspStatusCode int + +type ocspStatus struct { + code ocspStatusCode + err error +} + +const ( + ocspSuccess ocspStatusCode = 0 + ocspStatusGood ocspStatusCode = -1 + ocspStatusRevoked ocspStatusCode = -2 + ocspStatusUnknown ocspStatusCode = -3 + ocspStatusOthers ocspStatusCode = -4 + ocspFailedDecomposeRequest ocspStatusCode = -5 + ocspInvalidValidity ocspStatusCode = -6 + ocspMissedCache ocspStatusCode = -7 + ocspCacheExpired ocspStatusCode = -8 +) + +// copied from crypto/ocsp.go +type certID struct { + HashAlgorithm pkix.AlgorithmIdentifier + NameHash []byte + IssuerKeyHash []byte + SerialNumber *big.Int +} + +// cache key +type certIDKey struct { + NameHash string + IssuerKeyHash string + SerialNumber string +} + +// copied from crypto/ocsp +var hashOIDs = map[crypto.Hash]asn1.ObjectIdentifier{ + crypto.SHA1: asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}), + crypto.SHA256: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 1}), + crypto.SHA384: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 2}), + crypto.SHA512: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 3}), +} + +// copied from crypto/ocsp +func getOIDFromHashAlgorithm(target crypto.Hash) (asn1.ObjectIdentifier, error) { + for hash, oid := range hashOIDs { + if hash == target { + return oid, nil + } + } + return nil, fmt.Errorf("no valid OID is found for the hash algorithm: %v", target) +} + +func (c *Client) ClearCache() { + c.ocspResponseCache.Purge() +} + +func (c *Client) getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto.Hash { + for hash, oid := range hashOIDs { + if oid.Equal(target.Algorithm) { + return hash + } + } + // no valid hash algorithm is found for the oid. Falling back to SHA1 + return crypto.SHA1 +} + +// isInValidityRange checks the validity +func isInValidityRange(currTime, nextUpdate time.Time) bool { + return !nextUpdate.IsZero() && !currTime.After(nextUpdate) +} + +func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { + r, err := ocsp.ParseRequest(ocspReq) + if err != nil { + return nil, &ocspStatus{ + code: ocspFailedDecomposeRequest, + err: err, + } + } + + // encode CertID, used as a key in the cache + encodedCertID := &certIDKey{ + base64.StdEncoding.EncodeToString(r.IssuerNameHash), + base64.StdEncoding.EncodeToString(r.IssuerKeyHash), + r.SerialNumber.String(), + } + return encodedCertID, &ocspStatus{ + code: ocspSuccess, + } +} + +func (c *Client) encodeCertIDKey(certIDKeyBase64 string) (*certIDKey, error) { + r, err := base64.StdEncoding.DecodeString(certIDKeyBase64) + if err != nil { + return nil, err + } + var cid certID + rest, err := asn1.Unmarshal(r, &cid) + if err != nil { + // error in parsing + return nil, err + } + if len(rest) > 0 { + // extra bytes to the end + return nil, err + } + return &certIDKey{ + base64.StdEncoding.EncodeToString(cid.NameHash), + base64.StdEncoding.EncodeToString(cid.IssuerKeyHash), + cid.SerialNumber.String(), + }, nil +} + +func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Certificate) (*ocspStatus, error) { + c.ocspResponseCacheLock.RLock() + var cacheValue *ocspCachedResponse + v, ok := c.ocspResponseCache.Get(*encodedCertID) + if ok { + cacheValue = v.(*ocspCachedResponse) + } + c.ocspResponseCacheLock.RUnlock() + + status, err := c.extractOCSPCacheResponseValue(cacheValue, subject, issuer) + if err != nil { + return nil, err + } + if !isValidOCSPStatus(status.code) { + c.deleteOCSPCache(encodedCertID) + } + return status, err +} + +func (c *Client) deleteOCSPCache(encodedCertID *certIDKey) { + c.ocspResponseCacheLock.Lock() + c.ocspResponseCache.Remove(*encodedCertID) + c.cacheUpdated = true + c.ocspResponseCacheLock.Unlock() +} + +func validateOCSP(ocspRes *ocsp.Response) (*ocspStatus, error) { + curTime := time.Now() + + if ocspRes == nil { + return nil, errors.New("OCSP Response is nil") + } + if !isInValidityRange(curTime, ocspRes.NextUpdate) { + return &ocspStatus{ + code: ocspInvalidValidity, + err: fmt.Errorf("invalid validity: producedAt: %v, thisUpdate: %v, nextUpdate: %v", ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate), + }, nil + } + return returnOCSPStatus(ocspRes), nil +} + +func returnOCSPStatus(ocspRes *ocsp.Response) *ocspStatus { + switch ocspRes.Status { + case ocsp.Good: + return &ocspStatus{ + code: ocspStatusGood, + err: nil, + } + case ocsp.Revoked: + return &ocspStatus{ + code: ocspStatusRevoked, + } + case ocsp.Unknown: + return &ocspStatus{ + code: ocspStatusUnknown, + err: errors.New("OCSP status unknown."), + } + default: + return &ocspStatus{ + code: ocspStatusOthers, + err: fmt.Errorf("OCSP others. %v", ocspRes.Status), + } + } +} + +// retryOCSP is the second level of retry method if the returned contents are corrupted. It often happens with OCSP +// serer and retry helps. +func (c *Client) retryOCSP( + ctx context.Context, + client clientInterface, + req requestFunc, + ocspHost *url.URL, + headers map[string]string, + reqBody []byte, + issuer *x509.Certificate, +) (ocspRes *ocsp.Response, ocspResBytes []byte, ocspS *ocspStatus, err error) { + origHost := *ocspHost + doRequest := func(request *retryablehttp.Request) (*http.Response, error) { + if err != nil { + return nil, err + } + if request != nil { + request = request.WithContext(ctx) + for k, v := range headers { + request.Header[k] = append(request.Header[k], v) + } + } + res, err := client.Do(request) + if err != nil { + return nil, err + } + c.Logger().Debug("StatusCode from OCSP Server:", "statusCode", res.StatusCode) + return res, err + } + + ocspHost.Path = ocspHost.Path + "/" + base64.StdEncoding.EncodeToString(reqBody) + var res *http.Response + request, err := req("GET", ocspHost.String(), nil) + if err != nil { + return nil, nil, nil, err + } + if res, err = doRequest(request); err != nil { + return nil, nil, nil, err + } else { + defer res.Body.Close() + } + if res.StatusCode == http.StatusMethodNotAllowed { + request, err := req("POST", origHost.String(), bytes.NewBuffer(reqBody)) + if err != nil { + return nil, nil, nil, err + } + if res, err := doRequest(request); err != nil { + return nil, nil, nil, err + } else { + defer res.Body.Close() + } + } + if res.StatusCode != http.StatusOK { + return nil, nil, nil, fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status) + } + ocspResBytes, err = io.ReadAll(res.Body) + if err != nil { + return nil, nil, nil, err + } + ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) + if err != nil { + return nil, nil, nil, err + } + + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspSuccess, + }, nil +} + +// GetRevocationStatus checks the certificate revocation status for subject using issuer certificate. +func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, conf *VerifyConfig) (*ocspStatus, error) { + status, ocspReq, encodedCertID, err := c.validateWithCache(subject, issuer) + if err != nil { + return nil, err + } + if isValidOCSPStatus(status.code) { + return status, nil + } + if ocspReq == nil || encodedCertID == nil { + return status, nil + } + c.Logger().Debug("cache missed", "server", subject.OCSPServer) + if len(subject.OCSPServer) == 0 && len(conf.OcspServersOverride) == 0 { + return nil, fmt.Errorf("no OCSP responder URL: subject: %v", subject.Subject) + } + ocspHosts := subject.OCSPServer + if len(conf.OcspServersOverride) > 0 { + ocspHosts = conf.OcspServersOverride + } + + var wg sync.WaitGroup + + ocspStatuses := make([]*ocspStatus, len(ocspHosts)) + ocspResponses := make([]*ocsp.Response, len(ocspHosts)) + errors := make([]error, len(ocspHosts)) + + for i, ocspHost := range ocspHosts { + u, err := url.Parse(ocspHost) + if err != nil { + return nil, err + } + + hostname := u.Hostname() + + headers := make(map[string]string) + headers[httpHeaderContentType] = ocspRequestContentType + headers[httpHeaderAccept] = ocspResponseContentType + headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq)) + headers[httpHeaderHost] = hostname + timeout := defaultOCSPResponderTimeout + + ocspClient := retryablehttp.NewClient() + ocspClient.HTTPClient.Timeout = timeout + ocspClient.HTTPClient.Transport = newInsecureOcspTransport(conf.ExtraCas) + + doRequest := func() error { + if conf.QueryAllServers { + defer wg.Done() + } + ocspRes, _, ocspS, err := c.retryOCSP( + ctx, ocspClient, retryablehttp.NewRequest, u, headers, ocspReq, issuer) + ocspResponses[i] = ocspRes + if err != nil { + errors[i] = err + return err + } + if ocspS.code != ocspSuccess { + ocspStatuses[i] = ocspS + return nil + } + + ret, err := validateOCSP(ocspRes) + if err != nil { + errors[i] = err + return err + } + if isValidOCSPStatus(ret.code) { + ocspStatuses[i] = ret + } + return nil + } + if conf.QueryAllServers { + wg.Add(1) + go doRequest() + } else { + err = doRequest() + if err == nil { + break + } + } + } + if conf.QueryAllServers { + wg.Wait() + } + // Good by default + var ret *ocspStatus + ocspRes := ocspResponses[0] + var firstError error + for i := range ocspHosts { + if errors[i] != nil { + if firstError == nil { + firstError = errors[i] + } + } else if ocspStatuses[i] != nil { + switch ocspStatuses[i].code { + case ocspStatusRevoked: + ret = ocspStatuses[i] + ocspRes = ocspResponses[i] + break + case ocspStatusGood: + // Use this response only if we don't have a status already, or if what we have was unknown + if ret == nil || ret.code == ocspStatusUnknown { + ret = ocspStatuses[i] + ocspRes = ocspResponses[i] + } + case ocspStatusUnknown: + if ret == nil { + // We may want to use this as the overall result + ret = ocspStatuses[i] + ocspRes = ocspResponses[i] + } + } + } + } + + // If no server reported the cert revoked, but we did have an error, report it + if (ret == nil || ret.code == ocspStatusUnknown) && firstError != nil { + return nil, firstError + } + // otherwise ret should contain a response for the overall request + + if !isValidOCSPStatus(ret.code) { + return ret, nil + } + v := ocspCachedResponse{ + status: ret.code, + time: float64(time.Now().UTC().Unix()), + producedAt: float64(ocspRes.ProducedAt.UTC().Unix()), + thisUpdate: float64(ocspRes.ThisUpdate.UTC().Unix()), + nextUpdate: float64(ocspRes.NextUpdate.UTC().Unix()), + } + + c.ocspResponseCacheLock.Lock() + c.ocspResponseCache.Add(*encodedCertID, &v) + c.cacheUpdated = true + c.ocspResponseCacheLock.Unlock() + return ret, nil +} + +func isValidOCSPStatus(status ocspStatusCode) bool { + return status == ocspStatusGood || status == ocspStatusRevoked || status == ocspStatusUnknown +} + +type VerifyConfig struct { + OcspEnabled bool + ExtraCas []*x509.Certificate + OcspServersOverride []string + OcspFailureMode FailOpenMode + QueryAllServers bool +} + +// VerifyLeafCertificate verifies just the subject against it's direct issuer +func (c *Client) VerifyLeafCertificate(ctx context.Context, subject, issuer *x509.Certificate, conf *VerifyConfig) error { + results, err := c.GetRevocationStatus(ctx, subject, issuer, conf) + if err != nil { + return err + } + if results.code == ocspStatusGood { + return nil + } else { + serial := issuer.SerialNumber + serialHex := strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":")) + if results.code == ocspStatusRevoked { + return fmt.Errorf("certificate with serial number %s has been revoked", serialHex) + } else if conf.OcspFailureMode == FailOpenFalse { + return fmt.Errorf("unknown OCSP status for cert with serial number %s", strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":"))) + } else { + c.Logger().Warn("could not validate OCSP status for cert, but continuing in fail open mode", "serial", serialHex) + } + } + return nil +} + +// VerifyPeerCertificate verifies all of certificate revocation status +func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, conf *VerifyConfig) error { + for i := 0; i < len(verifiedChains); i++ { + // Certificate signed by Root CA. This should be one before the last in the Certificate Chain + numberOfNoneRootCerts := len(verifiedChains[i]) - 1 + if !verifiedChains[i][numberOfNoneRootCerts].IsCA || string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer) != string(verifiedChains[i][numberOfNoneRootCerts].RawSubject) { + // Check if the last Non Root Cert is also a CA or is self signed. + // if the last certificate is not, add it to the list + rca := c.caRoot[string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer)] + if rca == nil { + return fmt.Errorf("failed to find root CA. pkix.name: %v", verifiedChains[i][numberOfNoneRootCerts].Issuer) + } + verifiedChains[i] = append(verifiedChains[i], rca) + numberOfNoneRootCerts++ + } + results, err := c.GetAllRevocationStatus(ctx, verifiedChains[i], conf) + if err != nil { + return err + } + if r := c.canEarlyExitForOCSP(results, numberOfNoneRootCerts, conf); r != nil { + return r.err + } + } + + return nil +} + +func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, conf *VerifyConfig) *ocspStatus { + msg := "" + if conf.OcspFailureMode == FailOpenFalse { + // Fail closed. any error is returned to stop connection + for _, r := range results { + if r.err != nil { + return r + } + } + } else { + // Fail open and all results are valid. + allValid := len(results) == chainSize + for _, r := range results { + if !isValidOCSPStatus(r.code) { + allValid = false + break + } + } + for _, r := range results { + if allValid && r.code == ocspStatusRevoked { + return r + } + if r != nil && r.code != ocspStatusGood && r.err != nil { + msg += "" + r.err.Error() + } + } + } + if len(msg) > 0 { + c.Logger().Warn( + "OCSP is set to fail-open, and could not retrieve OCSP based revocation checking but proceeding.", "detail", msg) + } + return nil +} + +func (c *Client) validateWithCacheForAllCertificates(verifiedChains []*x509.Certificate) (bool, error) { + n := len(verifiedChains) - 1 + for j := 0; j < n; j++ { + subject := verifiedChains[j] + issuer := verifiedChains[j+1] + status, _, _, err := c.validateWithCache(subject, issuer) + if err != nil { + return false, err + } + if !isValidOCSPStatus(status.code) { + return false, nil + } + } + return true, nil +} + +func (c *Client) validateWithCache(subject, issuer *x509.Certificate) (*ocspStatus, []byte, *certIDKey, error) { + ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to create OCSP request from the certificates: %v", err) + } + encodedCertID, ocspS := extractCertIDKeyFromRequest(ocspReq) + if ocspS.code != ocspSuccess { + return nil, nil, nil, fmt.Errorf("failed to extract CertID from OCSP Request: %v", err) + } + status, err := c.checkOCSPResponseCache(encodedCertID, subject, issuer) + if err != nil { + return nil, nil, nil, err + } + return status, ocspReq, encodedCertID, nil +} + +func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate, conf *VerifyConfig) ([]*ocspStatus, error) { + _, err := c.validateWithCacheForAllCertificates(verifiedChains) + if err != nil { + return nil, err + } + n := len(verifiedChains) - 1 + results := make([]*ocspStatus, n) + for j := 0; j < n; j++ { + results[j], err = c.GetRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1], conf) + if err != nil { + return nil, err + } + if !isValidOCSPStatus(results[j].code) { + return results, nil + } + } + return results, nil +} + +// verifyPeerCertificateSerial verifies the certificate revocation status in serial. +func (c *Client) verifyPeerCertificateSerial(conf *VerifyConfig) func(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { + return func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { + return c.VerifyPeerCertificate(context.TODO(), verifiedChains, conf) + } +} + +func (c *Client) extractOCSPCacheResponseValueWithoutSubject(cacheValue ocspCachedResponse) (*ocspStatus, error) { + return c.extractOCSPCacheResponseValue(&cacheValue, nil, nil) +} + +func (c *Client) extractOCSPCacheResponseValue(cacheValue *ocspCachedResponse, subject, issuer *x509.Certificate) (*ocspStatus, error) { + subjectName := "Unknown" + if subject != nil { + subjectName = subject.Subject.CommonName + } + + curTime := time.Now() + if cacheValue == nil { + return &ocspStatus{ + code: ocspMissedCache, + err: fmt.Errorf("miss cache data. subject: %v", subjectName), + }, nil + } + currentTime := float64(curTime.UTC().Unix()) + if currentTime-cacheValue.time >= cacheExpire { + return &ocspStatus{ + code: ocspCacheExpired, + err: fmt.Errorf("cache expired. current: %v, cache: %v", + time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(cacheValue.time), 0).UTC()), + }, nil + } + + return validateOCSP(&ocsp.Response{ + ProducedAt: time.Unix(int64(cacheValue.producedAt), 0).UTC(), + ThisUpdate: time.Unix(int64(cacheValue.thisUpdate), 0).UTC(), + NextUpdate: time.Unix(int64(cacheValue.nextUpdate), 0).UTC(), + Status: int(cacheValue.status), + }) +} + +/* +// writeOCSPCache writes a OCSP Response cache +func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) error { + c.Logger().Debug("writing OCSP Response cache") + t := time.Now() + m := make(map[string][]interface{}) + keys := c.ocspResponseCache.Keys() + if len(keys) > persistedCacheSize { + keys = keys[:persistedCacheSize] + } + for _, k := range keys { + e, ok := c.ocspResponseCache.Get(k) + if ok { + entry := e.(*ocspCachedResponse) + // Don't store if expired + if isInValidityRange(t, time.Unix(int64(entry.thisUpdate), 0), time.Unix(int64(entry.nextUpdate), 0)) { + key := k.(certIDKey) + cacheKeyInBase64, err := decodeCertIDKey(&key) + if err != nil { + return err + } + m[cacheKeyInBase64] = []interface{}{entry.status, entry.time, entry.producedAt, entry.thisUpdate, entry.nextUpdate} + } + } + } + + v, err := jsonutil.EncodeJSONAndCompress(m, nil) + if err != nil { + return err + } + entry := logical.StorageEntry{ + Key: ocspCacheKey, + Value: v, + } + return storage.Put(ctx, &entry) +} + +// readOCSPCache reads a OCSP Response cache from storage +func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) error { + c.Logger().Debug("reading OCSP Response cache") + + entry, err := storage.Get(ctx, ocspCacheKey) + if err != nil { + return err + } + if entry == nil { + return nil + } + var untypedCache map[string][]interface{} + + err = jsonutil.DecodeJSON(entry.Value, &untypedCache) + if err != nil { + return errors.New("failed to unmarshal OCSP cache") + } + + for k, v := range untypedCache { + key, err := c.encodeCertIDKey(k) + if err != nil { + return err + } + var times [4]float64 + for i, t := range v[1:] { + if jn, ok := t.(json.Number); ok { + times[i], err = jn.Float64() + if err != nil { + return err + } + } else { + times[i] = t.(float64) + } + } + var status int + if jn, ok := v[0].(json.Number); ok { + s, err := jn.Int64() + if err != nil { + return err + } + status = int(s) + } else { + status = v[0].(int) + } + + c.ocspResponseCache.Add(*key, &ocspCachedResponse{ + status: ocspStatusCode(status), + time: times[0], + producedAt: times[1], + thisUpdate: times[2], + nextUpdate: times[3], + }) + } + + return nil +} +*/ + +func New(logFactory func() hclog.Logger, cacheSize int) *Client { + if cacheSize < 100 { + cacheSize = 100 + } + cache, _ := lru.New2Q(cacheSize) + c := Client{ + caRoot: make(map[string]*x509.Certificate), + ocspResponseCache: cache, + logFactory: logFactory, + } + + return &c +} + +func (c *Client) Logger() hclog.Logger { + return c.logFactory() +} + +// insecureOcspTransport is the transport object that doesn't do certificate revocation check. +func newInsecureOcspTransport(extraCas []*x509.Certificate) *http.Transport { + // Get the SystemCertPool, continue with an empty pool on error + rootCAs, _ := x509.SystemCertPool() + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + for _, c := range extraCas { + rootCAs.AddCert(c) + } + config := &tls.Config{ + RootCAs: rootCAs, + } + return &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Minute, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSClientConfig: config, + } +} + +// NewTransport includes the certificate revocation check with OCSP in sequential. +func (c *Client) NewTransport(conf *VerifyConfig) *http.Transport { + rootCAs := c.certPool + if rootCAs == nil { + rootCAs, _ = x509.SystemCertPool() + } + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + for _, c := range conf.ExtraCas { + rootCAs.AddCert(c) + } + return &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCAs, + VerifyPeerCertificate: c.verifyPeerCertificateSerial(conf), + }, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Minute, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + } +} + +/* +func (c *Client) WriteCache(ctx context.Context, storage logical.Storage) error { + c.ocspResponseCacheLock.Lock() + defer c.ocspResponseCacheLock.Unlock() + if c.cacheUpdated { + err := c.writeOCSPCache(ctx, storage) + if err == nil { + c.cacheUpdated = false + } + return err + } + return nil +} + +func (c *Client) ReadCache(ctx context.Context, storage logical.Storage) error { + c.ocspResponseCacheLock.Lock() + defer c.ocspResponseCacheLock.Unlock() + return c.readOCSPCache(ctx, storage) +} +*/ +/* + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go new file mode 100644 index 000000000000..2f3f1976d2a8 --- /dev/null +++ b/sdk/helper/ocsp/ocsp_test.go @@ -0,0 +1,530 @@ +// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + +package ocsp + +import ( + "bytes" + "context" + "crypto" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "testing" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" + lru "github.com/hashicorp/golang-lru" + "golang.org/x/crypto/ocsp" +) + +func TestOCSP(t *testing.T) { + targetURL := []string{ + "https://sfcdev1.blob.core.windows.net/", + "https://sfctest0.snowflakecomputing.com/", + "https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64", + } + + conf := VerifyConfig{ + OcspFailureMode: FailOpenFalse, + } + c := New(testLogFactory, 10) + transports := []*http.Transport{ + newInsecureOcspTransport(nil), + c.NewTransport(&conf), + } + + for _, tgt := range targetURL { + c.ocspResponseCache, _ = lru.New2Q(10) + for _, tr := range transports { + c := &http.Client{ + Transport: tr, + Timeout: 30 * time.Second, + } + req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("fail to create a request. err: %v", err) + } + res, err := c.Do(req) + if err != nil { + t.Fatalf("failed to GET contents. err: %v", err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("failed to read content body for %v", tgt) + } + + } + } +} + +/** +// Used for development, requires an active Vault with PKI setup +func TestMultiOCSP(t *testing.T) { + + targetURL := []string{ + "https://localhost:8200/v1/pki/ocsp", + "https://localhost:8200/v1/pki/ocsp", + "https://localhost:8200/v1/pki/ocsp", + } + + b, _ := pem.Decode([]byte(vaultCert)) + caCert, _ := x509.ParseCertificate(b.Bytes) + conf := VerifyConfig{ + OcspFailureMode: FailOpenFalse, + QueryAllServers: true, + OcspServersOverride: targetURL, + ExtraCas: []*x509.Certificate{caCert}, + } + c := New(testLogFactory, 10) + transports := []*http.Transport{ + newInsecureOcspTransport(conf.ExtraCas), + c.NewTransport(&conf), + } + + tgt := "https://localhost:8200/v1/pki/ca/pem" + c.ocspResponseCache, _ = lru.New2Q(10) + for _, tr := range transports { + c := &http.Client{ + Transport: tr, + Timeout: 30 * time.Second, + } + req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("fail to create a request. err: %v", err) + } + res, err := c.Do(req) + if err != nil { + t.Fatalf("failed to GET contents. err: %v", err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("failed to read content body for %v", tgt) + } + } +} +*/ + +func TestUnitEncodeCertIDGood(t *testing.T) { + targetURLs := []string{ + "faketestaccount.snowflakecomputing.com:443", + "s3-us-west-2.amazonaws.com:443", + "sfcdev1.blob.core.windows.net:443", + } + for _, tt := range targetURLs { + chainedCerts := getCert(tt) + for i := 0; i < len(chainedCerts)-1; i++ { + subject := chainedCerts[i] + issuer := chainedCerts[i+1] + ocspServers := subject.OCSPServer + if len(ocspServers) == 0 { + t.Fatalf("no OCSP server is found. cert: %v", subject.Subject) + } + ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) + if err != nil { + t.Fatalf("failed to create OCSP request. err: %v", err) + } + var ost *ocspStatus + _, ost = extractCertIDKeyFromRequest(ocspReq) + if ost.err != nil { + t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) + } + // better hash. Not sure if the actual OCSP server accepts this, though. + ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) + if err != nil { + t.Fatalf("failed to create OCSP request. err: %v", err) + } + _, ost = extractCertIDKeyFromRequest(ocspReq) + if ost.err != nil { + t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) + } + // tweaked request binary + ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) + if err != nil { + t.Fatalf("failed to create OCSP request. err: %v", err) + } + ocspReq[10] = 0 // random change + _, ost = extractCertIDKeyFromRequest(ocspReq) + if ost.err == nil { + t.Fatal("should have failed") + } + } + } +} + +func TestUnitCheckOCSPResponseCache(t *testing.T) { + c := New(testLogFactory, 10) + dummyKey0 := certIDKey{ + NameHash: "dummy0", + IssuerKeyHash: "dummy0", + SerialNumber: "dummy0", + } + dummyKey := certIDKey{ + NameHash: "dummy1", + IssuerKeyHash: "dummy1", + SerialNumber: "dummy1", + } + currentTime := float64(time.Now().UTC().Unix()) + c.ocspResponseCache.Add(dummyKey0, &ocspCachedResponse{time: currentTime}) + subject := &x509.Certificate{} + issuer := &x509.Certificate{} + ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspMissedCache { + t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code) + } + // old timestamp + c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(1395054952)}) + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspCacheExpired { + t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code) + } + + // invalid validity + c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(currentTime - 1000)}) + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil) + if err == nil && isValidOCSPStatus(ost.code) { + t.Fatalf("should have failed.") + } +} + +func TestUnitValidateOCSP(t *testing.T) { + ocspRes := &ocsp.Response{} + ost, err := validateOCSP(ocspRes) + if err == nil && isValidOCSPStatus(ost.code) { + t.Fatalf("should have failed.") + } + + currentTime := time.Now() + ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour) + ocspRes.NextUpdate = currentTime.Add(2 * time.Hour) + ocspRes.Status = ocsp.Revoked + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + + if ost.code != ocspStatusRevoked { + t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code) + } + ocspRes.Status = ocsp.Good + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + + if ost.code != ocspStatusGood { + t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code) + } + ocspRes.Status = ocsp.Unknown + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspStatusUnknown { + t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code) + } + ocspRes.Status = ocsp.ServerFailed + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspStatusOthers { + t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code) + } +} + +func TestUnitEncodeCertID(t *testing.T) { + var st *ocspStatus + _, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2}) + if st.code != ocspFailedDecomposeRequest { + t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code) + } +} + +func getCert(addr string) []*x509.Certificate { + tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second) + if err != nil { + panic(err) + } + defer tcpConn.Close() + + err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + panic(err) + } + config := tls.Config{InsecureSkipVerify: true, ServerName: addr} + + conn := tls.Client(tcpConn, &config) + defer conn.Close() + + err = conn.Handshake() + if err != nil { + panic(err) + } + + state := conn.ConnectionState() + + return state.PeerCertificates +} + +func TestOCSPRetry(t *testing.T) { + c := New(testLogFactory, 10) + certs := getCert("s3-us-west-2.amazonaws.com:443") + dummyOCSPHost := &url.URL{ + Scheme: "https", + Host: "dummyOCSPHost", + } + client := &fakeHTTPClient{ + cnt: 3, + success: true, + body: []byte{1, 2, 3}, + logger: hclog.New(hclog.DefaultOptions), + t: t, + } + res, b, st, err := c.retryOCSP( + context.TODO(), + client, fakeRequestFunc, + dummyOCSPHost, + make(map[string]string), []byte{0}, certs[len(certs)-1]) + if err == nil { + fmt.Printf("should fail: %v, %v, %v\n", res, b, st) + } + client = &fakeHTTPClient{ + cnt: 30, + success: true, + body: []byte{1, 2, 3}, + logger: hclog.New(hclog.DefaultOptions), + t: t, + } + res, b, st, err = c.retryOCSP( + context.TODO(), + client, fakeRequestFunc, + dummyOCSPHost, + make(map[string]string), []byte{0}, certs[len(certs)-1]) + if err == nil { + fmt.Printf("should fail: %v, %v, %v\n", res, b, st) + } +} + +type tcCanEarlyExit struct { + results []*ocspStatus + resultLen int + retFailOpen *ocspStatus + retFailClosed *ocspStatus +} + +func TestCanEarlyExitForOCSP(t *testing.T) { + testcases := []tcCanEarlyExit{ + { // 0 + results: []*ocspStatus{ + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: nil, + retFailClosed: nil, + }, + { // 1 + results: []*ocspStatus{ + { + code: ocspStatusRevoked, + err: errors.New("revoked"), + }, + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + }, + { // 2 + results: []*ocspStatus{ + { + code: ocspStatusUnknown, + err: errors.New("unknown"), + }, + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: nil, + retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")}, + }, + { // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included. + results: []*ocspStatus{ + { + code: ocspStatusRevoked, + err: errors.New("revoked"), + }, + { + code: ocspInvalidValidity, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: nil, + retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + }, + { // 4: not taken as revoked if the number of results don't match the expected results. + results: []*ocspStatus{ + { + code: ocspStatusRevoked, + err: errors.New("revoked"), + }, + { + code: ocspStatusGood, + }, + }, + resultLen: 3, + retFailOpen: nil, + retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + }, + } + c := New(testLogFactory, 10) + for idx, tt := range testcases { + expectedLen := len(tt.results) + if tt.resultLen > 0 { + expectedLen = tt.resultLen + } + r := c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenTrue}) + if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) { + t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r) + } + r = c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenFalse}) + if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) { + t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r) + } + } +} + +var testLogger = hclog.New(hclog.DefaultOptions) + +func testLogFactory() hclog.Logger { + return testLogger +} + +type fakeHTTPClient struct { + cnt int // number of retry + success bool // return success after retry in cnt times + timeout bool // timeout + body []byte // return body + t *testing.T + logger hclog.Logger + redirected bool +} + +func (c *fakeHTTPClient) Do(_ *retryablehttp.Request) (*http.Response, error) { + c.cnt-- + if c.cnt < 0 { + c.cnt = 0 + } + c.t.Log("fakeHTTPClient.cnt", c.cnt) + + var retcode int + if !c.redirected { + c.redirected = true + c.cnt++ + retcode = 405 + } else if c.success && c.cnt == 1 { + retcode = 200 + } else { + if c.timeout { + // simulate timeout + time.Sleep(time.Second * 1) + return nil, &fakeHTTPError{ + err: "Whatever reason (Client.Timeout exceeded while awaiting headers)", + timeout: true, + } + } + retcode = 0 + } + + ret := &http.Response{ + StatusCode: retcode, + Body: &fakeResponseBody{body: c.body}, + } + return ret, nil +} + +type fakeHTTPError struct { + err string + timeout bool +} + +func (e *fakeHTTPError) Error() string { return e.err } +func (e *fakeHTTPError) Timeout() bool { return e.timeout } +func (e *fakeHTTPError) Temporary() bool { return true } + +type fakeResponseBody struct { + body []byte + cnt int +} + +func (b *fakeResponseBody) Read(p []byte) (n int, err error) { + if b.cnt == 0 { + copy(p, b.body) + b.cnt = 1 + return len(b.body), nil + } + b.cnt = 0 + return 0, io.EOF +} + +func (b *fakeResponseBody) Close() error { + return nil +} + +func fakeRequestFunc(_, _ string, _ interface{}) (*retryablehttp.Request, error) { + return nil, nil +} + +const vaultCert = `-----BEGIN CERTIFICATE----- +MIIDuTCCAqGgAwIBAgIUA6VeVD1IB5rXcCZRAqPO4zr/GAMwDQYJKoZIhvcNAQEL +BQAwcjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAlZBMREwDwYDVQQHDAhTb21lQ2l0 +eTESMBAGA1UECgwJTXlDb21wYW55MRMwEQYDVQQLDApNeURpdmlzaW9uMRowGAYD +VQQDDBF3d3cuY29uaHVnZWNvLmNvbTAeFw0yMjA5MDcxOTA1MzdaFw0yNDA5MDYx +OTA1MzdaMHIxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJWQTERMA8GA1UEBwwIU29t +ZUNpdHkxEjAQBgNVBAoMCU15Q29tcGFueTETMBEGA1UECwwKTXlEaXZpc2lvbjEa +MBgGA1UEAwwRd3d3LmNvbmh1Z2Vjby5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQDL9qzEXi4PIafSAqfcwcmjujFvbG1QZbI8swxnD+w8i4ufAQU5 +LDmvMrGo3ZbhJ0mCihYmFxpjhRdP2raJQ9TysHlPXHtDRpr9ckWTKBz2oIfqVtJ2 +qzteQkWCkDAO7kPqzgCFsMeoMZeONRkeGib0lEzQAbW/Rqnphg8zVVkyQ71DZ7Pc +d5WkC2E28kKcSramhWfVFpxG3hSIrLOX2esEXteLRzKxFPf+gi413JZFKYIWrebP +u5t0++MLNpuX322geoki4BWMjQsd47XILmxZ4aj33ScZvdrZESCnwP76hKIxg9mO +lMxrqSWKVV5jHZrElSEj9LYJgDO1Y6eItn7hAgMBAAGjRzBFMAsGA1UdDwQEAwIE +MDATBgNVHSUEDDAKBggrBgEFBQcDATAhBgNVHREEGjAYggtleGFtcGxlLmNvbYIJ +bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQA5dPdf5SdtMwe2uSspO/EuWqbM +497vMQBW1Ey8KRKasJjhvOVYMbe7De5YsnW4bn8u5pl0zQGF4hEtpmifAtVvziH/ +K+ritQj9VVNbLLCbFcg+b0kfjt4yrDZ64vWvIeCgPjG1Kme8gdUUWgu9dOud5gdx +qg/tIFv4TRS/eIIymMlfd9owOD3Ig6S5fy4NaAJFAwXf8+3Rzuc+e7JSAPgAufjh +tOTWinxvoiOLuYwo9CyGgq4qKBFsrY0aE0gdA7oTQkpbEbo2EbqiWUl/PTCl1Y4Z +nSZ0n+4q9QC9RLrWwYTwh838d5RVLUst2mBKSA+vn7YkqmBJbdBC6nkd7n7H +-----END CERTIFICATE----- +`