From 8f34df949f984d5ee02639cdbae4ffc7ac9e6167 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Thu, 8 Sep 2022 15:40:06 -0500 Subject: [PATCH 01/39] wip --- sdk/helper/ocsp/client.go | 1082 +++++++++++++++++++++++++++++++++ sdk/helper/ocsp/ocsp_test.go | 506 +++++++++++++++ sdk/helper/ocsp/retry.go | 550 +++++++++++++++++ sdk/helper/ocsp/retry_test.go | 271 +++++++++ 4 files changed, 2409 insertions(+) create mode 100644 sdk/helper/ocsp/client.go create mode 100644 sdk/helper/ocsp/ocsp_test.go create mode 100644 sdk/helper/ocsp/retry.go create mode 100644 sdk/helper/ocsp/retry_test.go diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go new file mode 100644 index 000000000000..dda265e70a13 --- /dev/null +++ b/sdk/helper/ocsp/client.go @@ -0,0 +1,1082 @@ +// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + +package ocsp + +import ( + "context" + "crypto" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/logical" + "golang.org/x/crypto/ocsp" + "io" + "io/ioutil" + "math/big" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// OCSPFailOpenMode is OCSP fail open mode. OCSPFailOpenTrue by default and may +// set to ocspModeFailClosed for fail closed mode +type OCSPFailOpenMode uint32 + +const ( + ocspFailOpenNotSet OCSPFailOpenMode = iota + // OCSPFailOpenTrue represents OCSP fail open mode. + OCSPFailOpenTrue + // OCSPFailOpenFalse represents OCSP fail closed mode. + OCSPFailOpenFalse +) +const ( + ocspModeFailOpen = "FAIL_OPEN" + ocspModeFailClosed = "FAIL_CLOSED" + ocspModeInsecure = "INSECURE" +) + +// OCSP fail open mode +var ocspFailOpen = OCSPFailOpenTrue + +const ( + // defaultOCSPCacheServerTimeout is the total timeout for OCSP cache server. + defaultOCSPCacheServerTimeout = 5 * time.Second + + // defaultOCSPResponderTimeout is the total timeout for OCSP responder. + defaultOCSPResponderTimeout = 10 * time.Second +) + +const ( + cacheFileBaseName = "ocsp_response_cache.json" + // cacheExpire specifies cache data expiration time in seconds. + cacheExpire = float64(24 * 60 * 60) + cacheServerURL = "http://ocsp.snowflakecomputing.com" + cacheServerEnabledEnv = "SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED" + cacheServerURLEnv = "SF_OCSP_RESPONSE_CACHE_SERVER_URL" + cacheDirEnv = "SF_OCSP_RESPONSE_CACHE_DIR" + ocspRetryURLEnv = "SF_OCSP_RESPONSE_RETRY_URL" +) + +const ( + ocspTestInjectValidityErrorEnv = "SF_OCSP_TEST_INJECT_VALIDITY_ERROR" + ocspTestInjectUnknownStatusEnv = "SF_OCSP_TEST_INJECT_UNKNOWN_STATUS" + ocspTestResponseCacheServerTimeoutEnv = "SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT" + ocspTestResponderTimeoutEnv = "SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT" + ocspTestResponderURLEnv = "SF_OCSP_TEST_RESPONDER_URL" + ocspTestNoOCSPURLEnv = "SF_OCSP_TEST_NO_OCSP_RESPONDER_URL" +) + +const ( + tolerableValidityRatio = 100 // buffer for certificate revocation update time + maxClockSkew = 900 * time.Second // buffer for clock skew +) + +type Client struct { + // caRoot includes the CA certificates. + caRoot map[string]*x509.Certificate + // certPOol includes the CA certificates. + certPool *x509.CertPool + ocspResponseCache map[certIDKey][]interface{} + ocspResponseCacheLock sync.RWMutex + // cacheUpdated is true if the memory cache is updated + cacheUpdated bool + logFactory func() hclog.Logger +} + +type ocspStatusCode int + +type ocspStatus struct { + code ocspStatusCode + err error +} + +const ( + ocspSuccess ocspStatusCode = 0 + ocspStatusGood ocspStatusCode = -1 + ocspStatusRevoked ocspStatusCode = -2 + ocspStatusUnknown ocspStatusCode = -3 + ocspStatusOthers ocspStatusCode = -4 + ocspNoServer ocspStatusCode = -5 + ocspFailedParseOCSPHost ocspStatusCode = -6 + ocspFailedComposeRequest ocspStatusCode = -7 + ocspFailedDecomposeRequest ocspStatusCode = -8 + ocspFailedSubmit ocspStatusCode = -9 + ocspFailedResponse ocspStatusCode = -10 + ocspFailedExtractResponse ocspStatusCode = -11 + ocspFailedParseResponse ocspStatusCode = -12 + ocspInvalidValidity ocspStatusCode = -13 + ocspMissedCache ocspStatusCode = -14 + ocspCacheExpired ocspStatusCode = -15 + ocspFailedDecodeResponse ocspStatusCode = -16 +) + +// copied from crypto/ocsp.go +type certID struct { + HashAlgorithm pkix.AlgorithmIdentifier + NameHash []byte + IssuerKeyHash []byte + SerialNumber *big.Int +} + +// cache key +type certIDKey struct { + HashAlgorithm crypto.Hash + NameHash string + IssuerKeyHash string + SerialNumber string +} + +// copied from crypto/ocsp +var hashOIDs = map[crypto.Hash]asn1.ObjectIdentifier{ + crypto.SHA1: asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}), + crypto.SHA256: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 1}), + crypto.SHA384: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 2}), + crypto.SHA512: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 3}), +} + +// copied from crypto/ocsp +func (c *Client) getOIDFromHashAlgorithm(target crypto.Hash) asn1.ObjectIdentifier { + for hash, oid := range hashOIDs { + if hash == target { + return oid + } + } + c.Logger().Error("no valid OID is found for the hash algorithm", "target", target) + return nil +} + +func (c *Client) getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto.Hash { + for hash, oid := range hashOIDs { + if oid.Equal(target.Algorithm) { + return hash + } + } + c.Logger().Error("no valid hash algorithm is found for the oid. Falling back to SHA1", "target", target) + return crypto.SHA1 +} + +// calcTolerableValidity returns the maximum validity buffer +func calcTolerableValidity(thisUpdate, nextUpdate time.Time) time.Duration { + return durationMax(nextUpdate.Sub(thisUpdate)/tolerableValidityRatio, maxClockSkew) +} + +func durationMax(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +func durationMin(a, b time.Duration) time.Duration { + return durationMax(b, a) +} + +// isInValidityRange checks the validity +func isInValidityRange(currTime, thisUpdate, nextUpdate time.Time) bool { + if currTime.Sub(thisUpdate.Add(-maxClockSkew)) < 0 { + return false + } + if nextUpdate.Add(calcTolerableValidity(thisUpdate, nextUpdate)).Sub(currTime) < 0 { + return false + } + return true +} + +func isTestInvalidValidity() bool { + return strings.EqualFold(os.Getenv(ocspTestInjectValidityErrorEnv), "true") +} + +func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { + r, err := ocsp.ParseRequest(ocspReq) + if err != nil { + return nil, &ocspStatus{ + code: ocspFailedDecomposeRequest, + err: err, + } + } + + // encode CertID, used as a key in the cache + encodedCertID := &certIDKey{ + r.HashAlgorithm, + base64.StdEncoding.EncodeToString(r.IssuerNameHash), + base64.StdEncoding.EncodeToString(r.IssuerKeyHash), + r.SerialNumber.String(), + } + return encodedCertID, &ocspStatus{ + code: ocspSuccess, + } +} + +func (c *Client) encodeCertIDKey(certIDKeyBase64 string) *certIDKey { + r, err := base64.StdEncoding.DecodeString(certIDKeyBase64) + if err != nil { + return nil + } + var cid certID + rest, err := asn1.Unmarshal(r, &cid) + if err != nil { + // error in parsing + return nil + } + if len(rest) > 0 { + // extra bytes to the end + return nil + } + return &certIDKey{ + c.getHashAlgorithmFromOID(cid.HashAlgorithm), + base64.StdEncoding.EncodeToString(cid.NameHash), + base64.StdEncoding.EncodeToString(cid.IssuerKeyHash), + cid.SerialNumber.String(), + } +} + +func (c *Client) decodeCertIDKey(k *certIDKey) string { + serialNumber := new(big.Int) + serialNumber.SetString(k.SerialNumber, 10) + nameHash, err := base64.StdEncoding.DecodeString(k.NameHash) + if err != nil { + return "" + } + issuerKeyHash, err := base64.StdEncoding.DecodeString(k.IssuerKeyHash) + if err != nil { + return "" + } + encodedCertID, err := asn1.Marshal(certID{ + pkix.AlgorithmIdentifier{ + Algorithm: c.getOIDFromHashAlgorithm(k.HashAlgorithm), + Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, + }, + nameHash, + issuerKeyHash, + serialNumber, + }) + if err != nil { + return "" + } + return base64.StdEncoding.EncodeToString(encodedCertID) +} + +func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Certificate) (*ocspStatus, error) { + if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { + return &ocspStatus{code: ocspNoServer}, nil + } + c.ocspResponseCacheLock.RLock() + gotValueFromCache := c.ocspResponseCache[*encodedCertID] + c.ocspResponseCacheLock.RUnlock() + + status, err := c.extractOCSPCacheResponseValue(gotValueFromCache, subject, issuer) + if err != nil { + return nil, err + } + if !isValidOCSPStatus(status.code) { + c.deleteOCSPCache(encodedCertID) + } + return status, err +} + +func (c *Client) deleteOCSPCache(encodedCertID *certIDKey) { + c.ocspResponseCacheLock.Lock() + delete(c.ocspResponseCache, *encodedCertID) + c.cacheUpdated = true + c.ocspResponseCacheLock.Unlock() +} + +func validateOCSP(ocspRes *ocsp.Response) (*ocspStatus, error) { + curTime := time.Now() + + if ocspRes == nil { + return nil, errors.New("OCSP Response is nil") + } + if isTestInvalidValidity() || !isInValidityRange(curTime, ocspRes.ThisUpdate, ocspRes.NextUpdate) { + return nil, fmt.Errorf("invalid validity: producedAt: %v, thisUpdate: %v, nextUpdate: %v", ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate) + } + if isTestUnknownStatus() { + ocspRes.Status = ocsp.Unknown + } + return returnOCSPStatus(ocspRes), nil +} + +func returnOCSPStatus(ocspRes *ocsp.Response) *ocspStatus { + switch ocspRes.Status { + case ocsp.Good: + return &ocspStatus{ + code: ocspStatusGood, + err: nil, + } + case ocsp.Revoked: + return &ocspStatus{ + code: ocspStatusRevoked, + } + case ocsp.Unknown: + return &ocspStatus{ + code: ocspStatusUnknown, + err: errors.New("OCSP status unknown."), + } + default: + return &ocspStatus{ + code: ocspStatusOthers, + err: fmt.Errorf("OCSP others. %v", ocspRes.Status), + } + } +} + +func isTestUnknownStatus() bool { + return strings.EqualFold(os.Getenv(ocspTestInjectUnknownStatusEnv), "true") +} + +func (c *Client) checkOCSPCacheServer( + ctx context.Context, + client clientInterface, + req requestFunc, + ocspServerHost *url.URL, + totalTimeout time.Duration) ( + cacheContent *map[string][]interface{}, + ocspS *ocspStatus) { + var respd map[string][]interface{} + headers := make(map[string]string) + res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout).execute() + if err != nil { + c.Logger().Error("failed to get OCSP cache from OCSP Cache Server. ", "err", err) + return nil, &ocspStatus{ + code: ocspFailedSubmit, + err: err, + } + } + defer res.Body.Close() + c.Logger().Debug("StatusCode from OCSP Cache Server", "statusCode", res.StatusCode) + if res.StatusCode != http.StatusOK { + return nil, &ocspStatus{ + code: ocspFailedResponse, + err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), + } + } + c.Logger().Debug("reading contents") + + dec := json.NewDecoder(res.Body) + for { + if err := dec.Decode(&respd); err == io.EOF { + break + } else if err != nil { + c.Logger().Error("failed to decode OCSP cache.", "err", err) + return nil, &ocspStatus{ + code: ocspFailedExtractResponse, + err: err, + } + } + } + return &respd, &ocspStatus{ + code: ocspSuccess, + } +} + +// retryOCSP is the second level of retry method if the returned contents are corrupted. It often happens with OCSP +// serer and retry helps. +func (c *Client) retryOCSP( + ctx context.Context, + client clientInterface, + req requestFunc, + ocspHost *url.URL, + headers map[string]string, + reqBody []byte, + issuer *x509.Certificate, + totalTimeout time.Duration) ( + ocspRes *ocsp.Response, + ocspResBytes []byte, + ocspS *ocspStatus) { + multiplier := 1 + if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { + multiplier = 3 // up to 3 times for Fail Close mode + } + res, err := newRetryHTTP( + ctx, client, req, ocspHost, headers, + totalTimeout*time.Duration(multiplier)).doPost().setBody(reqBody).execute() + if err != nil { + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspFailedSubmit, + err: err, + } + } + defer res.Body.Close() + c.Logger().Debug("StatusCode from OCSP Server:", "statusCode", res.StatusCode) + if res.StatusCode != http.StatusOK { + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspFailedResponse, + err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), + } + } + c.Logger().Debug("reading contents") + ocspResBytes, err = ioutil.ReadAll(res.Body) + if err != nil { + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspFailedExtractResponse, + err: err, + } + } + c.Logger().Debug("parsing OCSP response") + ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) + if err != nil { + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspFailedParseResponse, + err: err, + } + } + + return ocspRes, ocspResBytes, &ocspStatus{ + code: ocspSuccess, + } +} + +// getRevocationStatus checks the certificate revocation status for subject using issuer certificate. +func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) (*ocspStatus, error) { + c.Logger().Info("get-revocation-status", "subject", subject.Subject, "issuer", issuer.Subject) + + status, ocspReq, encodedCertID, err := c.ValidateWithCache(subject, issuer) + if err != nil { + return nil, err + } + if isValidOCSPStatus(status.code) { + return status, nil + } + if ocspReq == nil || encodedCertID == nil { + return status, nil + } + c.Logger().Info("cache missed") + c.Logger().Info("OCSP: ", "server", subject.OCSPServer) + if len(subject.OCSPServer) == 0 || isTestNoOCSPURL() { + return nil, fmt.Errorf("no OCSP responder URL: subject: %v", subject.Subject) + } + ocspHost := subject.OCSPServer[0] + u, err := url.Parse(ocspHost) + if err != nil { + return nil, err + } + hostnameStr := os.Getenv(ocspTestResponderURLEnv) + var hostname string + if retryURL := os.Getenv(ocspRetryURLEnv); retryURL != "" { + hostname = fmt.Sprintf(retryURL, u.Hostname(), base64.StdEncoding.EncodeToString(ocspReq)) + } else { + hostname = u.Hostname() + } + if hostnameStr != "" { + u0, err := url.Parse(hostnameStr) + if err == nil { + hostname = u0.Hostname() + u = u0 + } + } + headers := make(map[string]string) + headers[httpHeaderContentType] = "application/ocsp-request" + headers[httpHeaderAccept] = "application/ocsp-response" + headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq)) + headers[httpHeaderHost] = hostname + timeoutStr := os.Getenv(ocspTestResponderTimeoutEnv) + timeout := defaultOCSPResponderTimeout + if timeoutStr != "" { + var timeoutInMilliseconds int + timeoutInMilliseconds, err = strconv.Atoi(timeoutStr) + if err == nil { + timeout = time.Duration(timeoutInMilliseconds) * time.Millisecond + } + } + ocspClient := &http.Client{ + Timeout: timeout, + Transport: snowflakeInsecureTransport, + } + ocspRes, ocspResBytes, ocspS := c.retryOCSP( + ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer, timeout) + if ocspS.code != ocspSuccess { + return ocspS, nil + } + + ret, err := validateOCSP(ocspRes) + if err != nil { + return nil, err + } + if !isValidOCSPStatus(ret.code) { + return ret, nil // return invalid + } + v := []interface{}{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)} + c.ocspResponseCacheLock.Lock() + c.ocspResponseCache[*encodedCertID] = v + c.cacheUpdated = true + c.ocspResponseCacheLock.Unlock() + return ret, nil +} + +func isTestNoOCSPURL() bool { + return strings.EqualFold(os.Getenv(ocspTestNoOCSPURLEnv), "true") +} + +func isValidOCSPStatus(status ocspStatusCode) bool { + return status == ocspStatusGood || status == ocspStatusRevoked || status == ocspStatusUnknown +} + +// VerifyPeerCertificate verifies all of certificate revocation status +func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate) (err error) { + for i := 0; i < len(verifiedChains); i++ { + // Certificate signed by Root CA. This should be one before the last in the Certificate Chain + numberOfNoneRootCerts := len(verifiedChains[i]) - 1 + if !verifiedChains[i][numberOfNoneRootCerts].IsCA || string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer) != string(verifiedChains[i][numberOfNoneRootCerts].RawSubject) { + // Check if the last Non Root Cert is also a CA or is self signed. + // if the last certificate is not, add it to the list + rca := c.caRoot[string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer)] + if rca == nil { + return fmt.Errorf("failed to find root CA. pkix.name: %v", verifiedChains[i][numberOfNoneRootCerts].Issuer) + } + verifiedChains[i] = append(verifiedChains[i], rca) + numberOfNoneRootCerts++ + } + results, err := c.getAllRevocationStatus(ctx, verifiedChains[i]) + if err != nil { + return err + } + if r := c.canEarlyExitForOCSP(results, numberOfNoneRootCerts); r != nil { + return r.err + } + } + + return nil +} + +func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int) *ocspStatus { + msg := "" + if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { + // Fail closed. any error is returned to stop connection + for _, r := range results { + if r.err != nil { + return r + } + } + } else { + // Fail open and all results are valid. + allValid := len(results) == chainSize + for _, r := range results { + if !isValidOCSPStatus(r.code) { + allValid = false + break + } + } + for _, r := range results { + if allValid && r.code == ocspStatusRevoked { + return r + } + if r != nil && r.code != ocspStatusGood && r.err != nil { + msg += "" + r.err.Error() + } + } + } + if len(msg) > 0 { + c.Logger().Warn( + "WARNING!!! Using fail-open to connect. Driver is connecting to an "+ + "HTTPS endpoint without OCSP based Certificate Revocation checking "+ + "as it could not obtain a valid OCSP Response to use from the CA OCSP "+ + "responder", "detail", msg[1:]) + } + return nil +} + +func (c *Client) ValidateWithCacheForAllCertificates(verifiedChains []*x509.Certificate) (bool, error) { + n := len(verifiedChains) - 1 + for j := 0; j < n; j++ { + subject := verifiedChains[j] + issuer := verifiedChains[j+1] + status, _, _, err := c.ValidateWithCache(subject, issuer) + if err != nil { + return false, err + } + if !isValidOCSPStatus(status.code) { + return false, nil + } + } + return true, nil +} + +func (c *Client) ValidateWithCache(subject, issuer *x509.Certificate) (*ocspStatus, []byte, *certIDKey, error) { + ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to create OCSP request from the certificates: %v", err) + } + encodedCertID, ocspS := extractCertIDKeyFromRequest(ocspReq) + if ocspS.code != ocspSuccess { + return nil, nil, nil, fmt.Errorf("failed to extract CertID from OCSP Request: %v", err) + } + status, err := c.checkOCSPResponseCache(encodedCertID, subject, issuer) + if err != nil { + return nil, nil, nil, err + } + return status, ocspReq, encodedCertID, nil +} + +func (c *Client) getAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate) ([]*ocspStatus, error) { + _, err := c.ValidateWithCacheForAllCertificates(verifiedChains) + if err != nil { + return nil, err + } + n := len(verifiedChains) - 1 + results := make([]*ocspStatus, n) + for j := 0; j < n; j++ { + results[j], err = c.getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1]) + if err != nil { + return nil, err + } + if !isValidOCSPStatus(results[j].code) { + return results, nil + } + } + return results, nil +} + +// verifyPeerCertificateSerial verifies the certificate revocation status in serial. +func (c *Client) verifyPeerCertificateSerial(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { + return c.VerifyPeerCertificate(context.TODO(), verifiedChains) +} + +/* +// initOCSPCache initializes OCSP Response cache file. +func (c *Client) initOCSPCache() { + if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { + return + } + ocspResponseCache = make(map[certIDKey][]interface{}) + ocspResponseCacheLock = &sync.RWMutex{} + + c.Logger().Info("reading OCSP Response cache file.", "filename", cacheFileName) + f, err := os.OpenFile(cacheFileName, os.O_CREATE|os.O_RDONLY, os.ModePerm) + if err != nil { + c.Logger().Debug("failed to open. Ignored.", "err", err) + return + } + defer f.Close() + + buf := make(map[string][]interface{}) + + r := bufio.NewReader(f) + dec := json.NewDecoder(r) + for { + if err = dec.Decode(&buf); err == io.EOF { + break + } else if err != nil { + c.Logger().Debug("failed to read. Ignored.", "err", err) + return + } + } + for k, cacheValue := range buf { + status := c.extractOCSPCacheResponseValueWithoutSubject(cacheValue) + if !isValidOCSPStatus(status.code) { + continue + } + cacheKey := c.encodeCertIDKey(k) + ocspResponseCache[*cacheKey] = cacheValue + + } + cacheUpdated = false +}*/ +func (c *Client) extractOCSPCacheResponseValueWithoutSubject(cacheValue []interface{}) (*ocspStatus, error) { + return c.extractOCSPCacheResponseValue(cacheValue, nil, nil) +} + +func (c *Client) extractOCSPCacheResponseValue(cacheValue []interface{}, subject, issuer *x509.Certificate) (*ocspStatus, error) { + subjectName := "Unknown" + if subject != nil { + subjectName = subject.Subject.CommonName + } + + curTime := time.Now() + if len(cacheValue) != 2 { + return &ocspStatus{ + code: ocspMissedCache, + err: fmt.Errorf("miss cache data. subject: %v", subjectName), + }, nil + } + if ts, ok := cacheValue[0].(float64); ok { + currentTime := float64(curTime.UTC().Unix()) + if currentTime-ts >= cacheExpire { + return &ocspStatus{ + code: ocspCacheExpired, + err: fmt.Errorf("cache expired. current: %v, cache: %v", + time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(ts), 0).UTC()), + }, nil + } + } else { + return nil, errors.New("the first cache element is not float64") + } + var err error + var r *ocsp.Response + if s, ok := cacheValue[1].(string); ok { + var b []byte + b, err = base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err) + + } + // check the revocation status here + r, err = ocsp.ParseResponse(b, issuer) + if err != nil { + c.Logger().Warn("the second cache element is not a valid OCSP Response. Ignored.", "subject", subjectName) + return nil, fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err) + } + } else { + return nil, errors.New("the second cache element is not string") + + } + return validateOCSP(r) +} + +const storageValueKey = "backingStore" + +// writeOCSPCacheFile writes a OCSP Response cache file. This is called if all revocation status is success. +// lock file is used to mitigate race condition with other process. +func (c *Client) writeOCSPCacheFile(ctx context.Context, storage logical.Storage) error { + c.Logger().Debug("writing OCSP Response cache file") + + buf := make(map[string][]interface{}) + for k, v := range c.ocspResponseCache { + cacheKeyInBase64 := c.decodeCertIDKey(&k) + buf[cacheKeyInBase64] = v + } + + j, err := json.Marshal(buf) + if err != nil { + return errors.New("failed to convert OCSP Response cache to JSON") + } + + entry := logical.StorageEntry{ + Key: "ocsp_cache", + Value: j, + } + return storage.Put(ctx, &entry) +} + +/* +// readCACerts read a set of root CAs +func (c *Client) readCACerts() { + raw := []byte(c.caRootPEM) + certPool = x509.NewCertPool() + caRoot = make(map[string]*x509.Certificate) + var p *pem.Block + for { + p, raw = pem.Decode(raw) + if p == nil { + break + } + if p.Type != "CERTIFICATE" { + continue + } + c, err := x509.ParseCertificate(p.Bytes) + if err != nil { + panic("failed to parse CA certificate.") + } + certPool.AddCert(c) + caRoot[string(c.RawSubject)] = c + } +} + +// createOCSPCacheDir creates OCSP response cache directory and set the cache file name. +func createOCSPCacheDir() { + if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { + c.Logger().Info(`OCSP Cache Server disabled. All further access and use of + OCSP Cache will be disabled for this OCSP Status Query`) + return + } + cacheDir = os.Getenv(cacheDirEnv) + if cacheDir == "" { + cacheDir = os.Getenv("SNOWFLAKE_TEST_WORKSPACE") + } + if cacheDir == "" { + switch runtime.GOOS { + case "windows": + cacheDir = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches") + case "darwin": + home := os.Getenv("HOME") + if home == "" { + c.Logger().Info("HOME is blank.") + } + cacheDir = filepath.Join(home, "Library", "Caches", "Snowflake") + default: + home := os.Getenv("HOME") + if home == "" { + c.Logger().Info("HOME is blank") + } + cacheDir = filepath.Join(home, ".cache", "snowflake") + } + } + + if _, err := os.Stat(cacheDir); os.IsNotExist(err) { + if err = os.MkdirAll(cacheDir, os.ModePerm); err != nil { + c.Logger().Debugf("failed to create cache directory. %v, err: %v. ignored", cacheDir, err) + } + } + cacheFileName = filepath.Join(cacheDir, cacheFileBaseName) + c.Logger().Infof("reset OCSP cache file. %v", cacheFileName) +} +*/ +func New(logFactory func() hclog.Logger) *Client { + c := Client{ + caRoot: make(map[string]*x509.Certificate), + ocspResponseCache: make(map[certIDKey][]interface{}), + logFactory: logFactory, + } + + return &c +} + +func (c *Client) Logger() hclog.Logger { + return c.logFactory() +} + +// snowflakeInsecureTransport is the transport object that doesn't do certificate revocation check. +var snowflakeInsecureTransport = &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Minute, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, +} + +// SnowflakeTransport includes the certificate revocation check with OCSP in sequential. By default, the driver uses +// this transport object. +func (c *Client) NewTransport() *http.Transport { + return &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: c.certPool, + VerifyPeerCertificate: c.verifyPeerCertificateSerial, + }, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Minute, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + } +} + +func (c *Client) WriteCache(ctx context.Context, storage logical.Storage) error { + c.ocspResponseCacheLock.Lock() + if c.cacheUpdated { + defer c.ocspResponseCacheLock.Unlock() + if c.cacheUpdated { + return c.writeOCSPCacheFile(ctx, storage) + } + c.cacheUpdated = false + } + return nil +} + +/* + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go new file mode 100644 index 000000000000..934061abd44d --- /dev/null +++ b/sdk/helper/ocsp/ocsp_test.go @@ -0,0 +1,506 @@ +// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + +package ocsp + +import ( + "bytes" + "context" + "crypto" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "github.com/hashicorp/go-hclog" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "testing" + "time" + + "golang.org/x/crypto/ocsp" +) + +func TestOCSP(t *testing.T) { + cacheServerEnabled := []string{ + "true", + "false", + } + targetURL := []string{ + "https://sfctest0.snowflakecomputing.com/", + "https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64", + "https://sfcdev1.blob.core.windows.net/", + } + + c := New(testLogFactory) + transports := []*http.Transport{ + snowflakeInsecureTransport, + c.NewTransport(), + } + + for _, enabled := range cacheServerEnabled { + for _, tgt := range targetURL { + _ = os.Setenv(cacheServerEnabledEnv, enabled) + //_ = os.Remove(cacheFileName) // clear cache file + c.ocspResponseCache = make(map[certIDKey][]interface{}) + for _, tr := range transports { + c := &http.Client{ + Transport: tr, + Timeout: 30 * time.Second, + } + req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("fail to create a request. err: %v", err) + } + res, err := c.Do(req) + if err != nil { + t.Fatalf("failed to GET contents. err: %v", err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("failed to read content body for %v", tgt) + } + + } + } + } + _ = os.Unsetenv(cacheServerEnabledEnv) +} + +type tcValidityRange struct { + thisTime time.Time + nextTime time.Time + ret bool +} + +func TestUnitIsInValidityRange(t *testing.T) { + currentTime := time.Now() + testcases := []tcValidityRange{ + { + // basic tests + thisTime: currentTime.Add(-100 * time.Second), + nextTime: currentTime.Add(maxClockSkew), + ret: true, + }, + { + // on the border + thisTime: currentTime.Add(maxClockSkew), + nextTime: currentTime.Add(maxClockSkew), + ret: true, + }, + { + // 1 earlier late + thisTime: currentTime.Add(maxClockSkew + 1*time.Second), + nextTime: currentTime.Add(maxClockSkew), + ret: false, + }, + { + // on the border + thisTime: currentTime.Add(-maxClockSkew), + nextTime: currentTime.Add(-maxClockSkew), + ret: true, + }, + { + // around the border + thisTime: currentTime.Add(-24*time.Hour - 40*time.Second), + nextTime: currentTime.Add(-24*time.Hour/time.Duration(100) - 40*time.Second), + ret: false, + }, + { + // on the border + thisTime: currentTime.Add(-48*time.Hour - 29*time.Minute), + nextTime: currentTime.Add(-48 * time.Hour / time.Duration(100)), + ret: true, + }, + } + for _, tc := range testcases { + if tc.ret != isInValidityRange(currentTime, tc.thisTime, tc.nextTime) { + t.Fatalf("failed to check validity. should be: %v, currentTime: %v, thisTime: %v, nextTime: %v", tc.ret, currentTime, tc.thisTime, tc.nextTime) + } + } +} + +func TestUnitEncodeCertIDGood(t *testing.T) { + targetURLs := []string{ + "faketestaccount.snowflakecomputing.com:443", + "s3-us-west-2.amazonaws.com:443", + "sfcdev1.blob.core.windows.net:443", + } + for _, tt := range targetURLs { + chainedCerts := getCert(tt) + for i := 0; i < len(chainedCerts)-1; i++ { + subject := chainedCerts[i] + issuer := chainedCerts[i+1] + ocspServers := subject.OCSPServer + if len(ocspServers) == 0 { + t.Fatalf("no OCSP server is found. cert: %v", subject.Subject) + } + ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) + if err != nil { + t.Fatalf("failed to create OCSP request. err: %v", err) + } + var ost *ocspStatus + _, ost = extractCertIDKeyFromRequest(ocspReq) + if ost.err != nil { + t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) + } + // better hash. Not sure if the actual OCSP server accepts this, though. + ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) + if err != nil { + t.Fatalf("failed to create OCSP request. err: %v", err) + } + _, ost = extractCertIDKeyFromRequest(ocspReq) + if ost.err != nil { + t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) + } + // tweaked request binary + ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) + if err != nil { + t.Fatalf("failed to create OCSP request. err: %v", err) + } + ocspReq[10] = 0 // random change + _, ost = extractCertIDKeyFromRequest(ocspReq) + if ost.err == nil { + t.Fatal("should have failed") + } + } + } +} + +func TestUnitCheckOCSPResponseCache(t *testing.T) { + c := New(testLogFactory) + dummyKey0 := certIDKey{ + HashAlgorithm: crypto.SHA1, + NameHash: "dummy0", + IssuerKeyHash: "dummy0", + SerialNumber: "dummy0", + } + dummyKey := certIDKey{ + HashAlgorithm: crypto.SHA1, + NameHash: "dummy1", + IssuerKeyHash: "dummy1", + SerialNumber: "dummy1", + } + b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE")) + currentTime := float64(time.Now().UTC().Unix()) + c.ocspResponseCache[dummyKey0] = []interface{}{currentTime, b64Key} + subject := &x509.Certificate{} + issuer := &x509.Certificate{} + ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspMissedCache { + t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code) + } + // old timestamp + c.ocspResponseCache[dummyKey] = []interface{}{float64(1395054952), b64Key} + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspCacheExpired { + t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code) + } + // future timestamp + c.ocspResponseCache[dummyKey] = []interface{}{float64(1805054952), b64Key} + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err == nil { + t.Fatalf("should have failed.") + } + // actual OCSP but it fails to parse, because an invalid issuer certificate is given. + actualOcspResponse := "MIIB0woBAKCCAcwwggHIBgkrBgEFBQcwAQEEggG5MIIBtTCBnqIWBBSxPsNpA/i/RwHUmCYaCALvY2QrwxgPMjAxNz" + + "A1MTYyMjAwMDBaMHMwcTBJMAkGBSsOAwIaBQAEFN+qEuMosQlBk+KfQoLOR0BClVijBBSxPsNpA/i/RwHUmCYaCALvY2QrwwIQBOHnp" + + "Nxc8vNtwCtCuF0Vn4AAGA8yMDE3MDUxNjIyMDAwMFqgERgPMjAxNzA1MjMyMjAwMDBaMA0GCSqGSIb3DQEBCwUAA4IBAQCuRGwqQsKy" + + "IAAGHgezTfG0PzMYgGD/XRDhU+2i08WTJ4Zs40Lu88cBeRXWF3iiJSpiX3/OLgfI7iXmHX9/sm2SmeNWc0Kb39bk5Lw1jwezf8hcI9+" + + "mZHt60vhUgtgZk21SsRlTZ+S4VXwtDqB1Nhv6cnSnfrL2A9qJDZS2ltPNOwebWJnznDAs2dg+KxmT2yBXpHM1kb0EOolWvNgORbgIgB" + + "koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + + "Jb4xI/OTXIrRA5TsT4cCioE" + // issuer is not a true issuer certificate + c.ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse} + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err == nil { + t.Fatalf("should have failed.") + } + // invalid validity + c.ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse} + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil) + if err == nil { + t.Fatalf("should have failed.") + } + // wrong timestamp type + c.ocspResponseCache[dummyKey] = []interface{}{uint32(currentTime - 1000), 123456} + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err == nil { + t.Fatalf("should have failed.") + } + + // wrong value type + c.ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), 123456} + ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) + if err == nil { + t.Fatalf("should have failed.") + } +} + +func TestUnitValidateOCSP(t *testing.T) { + ocspRes := &ocsp.Response{} + ost, err := validateOCSP(ocspRes) + if err == nil { + t.Fatalf("should have failed.") + } + + currentTime := time.Now() + ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour) + ocspRes.NextUpdate = currentTime.Add(2 * time.Hour) + ocspRes.Status = ocsp.Revoked + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + + if ost.code != ocspStatusRevoked { + t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code) + } + ocspRes.Status = ocsp.Good + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + + if ost.code != ocspStatusGood { + t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code) + } + ocspRes.Status = ocsp.Unknown + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspStatusUnknown { + t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code) + } + ocspRes.Status = ocsp.ServerFailed + ost, err = validateOCSP(ocspRes) + if err != nil { + t.Fatal(err) + } + if ost.code != ocspStatusOthers { + t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code) + } +} + +func TestUnitEncodeCertID(t *testing.T) { + var st *ocspStatus + _, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2}) + if st.code != ocspFailedDecomposeRequest { + t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code) + } +} + +func getCert(addr string) []*x509.Certificate { + tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second) + if err != nil { + panic(err) + } + defer tcpConn.Close() + + err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + panic(err) + } + config := tls.Config{InsecureSkipVerify: true, ServerName: addr} + + conn := tls.Client(tcpConn, &config) + defer conn.Close() + + err = conn.Handshake() + if err != nil { + panic(err) + } + + state := conn.ConnectionState() + + return state.PeerCertificates +} + +func TestOCSPRetry(t *testing.T) { + c := New(testLogFactory) + certs := getCert("s3-us-west-2.amazonaws.com:443") + dummyOCSPHost := &url.URL{ + Scheme: "https", + Host: "dummyOCSPHost", + } + client := &fakeHTTPClient{ + cnt: 3, + success: true, + body: []byte{1, 2, 3}, + logger: hclog.New(hclog.DefaultOptions), + } + res, b, st := c.retryOCSP( + context.TODO(), + client, fakeRequestFunc, + dummyOCSPHost, + make(map[string]string), []byte{0}, certs[len(certs)-1], 10*time.Second) + if st.err == nil { + fmt.Printf("should fail: %v, %v, %v\n", res, b, st) + } + client = &fakeHTTPClient{ + cnt: 30, + success: true, + body: []byte{1, 2, 3}, + logger: hclog.New(hclog.DefaultOptions), + } + res, b, st = c.retryOCSP( + context.TODO(), + client, fakeRequestFunc, + dummyOCSPHost, + make(map[string]string), []byte{0}, certs[len(certs)-1], 5*time.Second) + if st.err == nil { + fmt.Printf("should fail: %v, %v, %v\n", res, b, st) + } +} + +func TestOCSPCacheServerRetry(t *testing.T) { + c := New(testLogFactory) + dummyOCSPHost := &url.URL{ + Scheme: "https", + Host: "dummyOCSPHost", + } + client := &fakeHTTPClient{ + cnt: 3, + success: true, + body: []byte{1, 2, 3}, + logger: hclog.New(hclog.DefaultOptions), + } + res, st := c.checkOCSPCacheServer( + context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second) + if st.err == nil { + t.Errorf("should fail: %v", res) + } + client = &fakeHTTPClient{ + cnt: 30, + success: true, + body: []byte{1, 2, 3}, + logger: hclog.New(hclog.DefaultOptions), + } + res, st = c.checkOCSPCacheServer( + context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second) + if st.err == nil { + t.Errorf("should fail: %v", res) + } +} + +type tcCanEarlyExit struct { + results []*ocspStatus + resultLen int + retFailOpen *ocspStatus + retFailClosed *ocspStatus +} + +func TestCanEarlyExitForOCSP(t *testing.T) { + testcases := []tcCanEarlyExit{ + { // 0 + results: []*ocspStatus{ + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: nil, + retFailClosed: nil, + }, + { // 1 + results: []*ocspStatus{ + { + code: ocspStatusRevoked, + err: errors.New("revoked"), + }, + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + }, + { // 2 + results: []*ocspStatus{ + { + code: ocspStatusUnknown, + err: errors.New("unknown"), + }, + { + code: ocspStatusGood, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: nil, + retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")}, + }, + { // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included. + results: []*ocspStatus{ + { + code: ocspStatusRevoked, + err: errors.New("revoked"), + }, + { + code: ocspInvalidValidity, + }, + { + code: ocspStatusGood, + }, + }, + retFailOpen: nil, + retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + }, + { // 4: not taken as revoked if the number of results don't match the expected results. + results: []*ocspStatus{ + { + code: ocspStatusRevoked, + err: errors.New("revoked"), + }, + { + code: ocspStatusGood, + }, + }, + resultLen: 3, + retFailOpen: nil, + retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, + }, + } + c := New(testLogFactory) + for idx, tt := range testcases { + ocspFailOpen = OCSPFailOpenTrue + expectedLen := len(tt.results) + if tt.resultLen > 0 { + expectedLen = tt.resultLen + } + r := c.canEarlyExitForOCSP(tt.results, expectedLen) + if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) { + t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r) + } + ocspFailOpen = OCSPFailOpenFalse + r = c.canEarlyExitForOCSP(tt.results, expectedLen) + if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) { + t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r) + } + } +} + +var testLogger = hclog.New(hclog.DefaultOptions) + +func testLogFactory() hclog.Logger { + return testLogger +} diff --git a/sdk/helper/ocsp/retry.go b/sdk/helper/ocsp/retry.go new file mode 100644 index 000000000000..4060abb87a18 --- /dev/null +++ b/sdk/helper/ocsp/retry.go @@ -0,0 +1,550 @@ +// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + +package ocsp + +import ( + "bytes" + "context" + "crypto/x509" + "fmt" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" + "io" + "math/rand" + "net/http" + "net/url" + "runtime" + "strconv" + "strings" + "sync" + "time" +) + +const ( + httpHeaderContentType = "Content-Type" + httpHeaderAccept = "accept" + httpHeaderUserAgent = "User-Agent" + httpHeaderServiceName = "X-Snowflake-Service" + httpHeaderContentLength = "Content-Length" + httpHeaderHost = "Host" + httpHeaderValueOctetStream = "application/octet-stream" + httpHeaderContentEncoding = "Content-Encoding" +) + +var random *rand.Rand + +func init() { + random = rand.New(rand.NewSource(time.Now().UnixNano())) +} + +const ( + // requestGUIDKey is attached to every request against Snowflake + requestGUIDKey string = "request_guid" + // retryCounterKey is attached to query-request from the second time + retryCounterKey string = "retryCounter" + // requestIDKey is attached to all requests to Snowflake + requestIDKey string = "requestId" +) + +// This class takes in an url during construction and replaces the value of +// request_guid every time replace() is called. If the url does not contain +// request_guid, just return the original url +type requestGUIDReplacer interface { + // replace the url with new ID + replace() *url.URL +} + +// Make requestGUIDReplacer given a url string +func newRequestGUIDReplace(urlPtr *url.URL) requestGUIDReplacer { + values, err := url.ParseQuery(urlPtr.RawQuery) + if err != nil { + // nop if invalid query parameters + return &transientReplace{urlPtr} + } + if len(values.Get(requestGUIDKey)) == 0 { + // nop if no request_guid is included. + return &transientReplace{urlPtr} + } + + return &requestGUIDReplace{urlPtr, values} +} + +// this replacer does nothing but replace the url +type transientReplace struct { + urlPtr *url.URL +} + +func (replacer *transientReplace) replace() *url.URL { + return replacer.urlPtr +} + +/* +requestGUIDReplacer is a one-shot object that is created out of the retry loop and +called with replace to change the retry_guid's value upon every retry +*/ +type requestGUIDReplace struct { + urlPtr *url.URL + urlValues url.Values +} + +/** +This function would replace they value of the requestGUIDKey in a url with a newly +generated UUID +*/ +func (replacer *requestGUIDReplace) replace() *url.URL { + replacer.urlValues.Del(requestGUIDKey) + uuid, _ := uuid.GenerateUUID() + replacer.urlValues.Add(requestGUIDKey, uuid) + replacer.urlPtr.RawQuery = replacer.urlValues.Encode() + return replacer.urlPtr +} + +type retryCounterUpdater interface { + replaceOrAdd(retry int) *url.URL +} + +type retryCounterUpdate struct { + urlPtr *url.URL + urlValues url.Values +} + +// this replacer does nothing but replace the url +type transientReplaceOrAdd struct { + urlPtr *url.URL +} + +func (replaceOrAdder *transientReplaceOrAdd) replaceOrAdd(retry int) *url.URL { + return replaceOrAdder.urlPtr +} + +func (replacer *retryCounterUpdate) replaceOrAdd(retry int) *url.URL { + replacer.urlValues.Del(retryCounterKey) + replacer.urlValues.Add(retryCounterKey, strconv.Itoa(retry)) + replacer.urlPtr.RawQuery = replacer.urlValues.Encode() + return replacer.urlPtr +} + +// Snowflake Server Endpoints +const ( + loginRequestPath = "/session/v1/login-request" + queryRequestPath = "/queries/v1/query-request" + tokenRequestPath = "/session/token-request" + abortRequestPath = "/queries/v1/abort-request" + authenticatorRequestPath = "/session/authenticator-request" + sessionRequestPath = "/session" + heartBeatPath = "/session/heartbeat" +) + +func newRetryUpdate(urlPtr *url.URL) retryCounterUpdater { + if !strings.HasPrefix(urlPtr.Path, queryRequestPath) { + // nop if not query-request + return &transientReplaceOrAdd{urlPtr} + } + values, err := url.ParseQuery(urlPtr.RawQuery) + if err != nil { + // nop if the URL is not valid + return &transientReplaceOrAdd{urlPtr} + } + return &retryCounterUpdate{urlPtr, values} +} + +type waitAlgo struct { + mutex *sync.Mutex // required for random.Int63n + base time.Duration // base wait time + cap time.Duration // maximum wait time +} + +func randSecondDuration(n time.Duration) time.Duration { + return time.Duration(random.Int63n(int64(n/time.Second))) * time.Second +} + +// decorrelated jitter backoff +func (w *waitAlgo) decorr(attempt int, sleep time.Duration) time.Duration { + w.mutex.Lock() + defer w.mutex.Unlock() + t := 3*sleep - w.base + switch { + case t > 0: + return durationMin(w.cap, randSecondDuration(t)+w.base) + case t < 0: + return durationMin(w.cap, randSecondDuration(-t)+3*sleep) + } + return w.base +} + +var defaultWaitAlgo = &waitAlgo{ + mutex: &sync.Mutex{}, + base: 5 * time.Second, + cap: 160 * time.Second, +} + +type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) + +type clientInterface interface { + Do(req *http.Request) (*http.Response, error) +} + +type retryHTTP struct { + ctx context.Context + client clientInterface + req requestFunc + method string + fullURL *url.URL + headers map[string]string + body []byte + timeout time.Duration + raise4XX bool + logger hclog.Logger +} + +func newRetryHTTP(ctx context.Context, + client clientInterface, + req requestFunc, + fullURL *url.URL, + headers map[string]string, + timeout time.Duration) *retryHTTP { + instance := retryHTTP{} + instance.ctx = ctx + instance.client = client + instance.req = req + instance.method = "GET" + instance.fullURL = fullURL + instance.headers = headers + instance.body = nil + instance.timeout = timeout + instance.raise4XX = false + instance.logger = hclog.New(hclog.DefaultOptions) + return &instance +} + +func (r *retryHTTP) doRaise4XX(raise4XX bool) *retryHTTP { + r.raise4XX = raise4XX + return r +} + +func (r *retryHTTP) doPost() *retryHTTP { + r.method = "POST" + return r +} + +func (r *retryHTTP) setBody(body []byte) *retryHTTP { + r.body = body + return r +} + +func (r *retryHTTP) execute() (res *http.Response, err error) { + totalTimeout := r.timeout + r.logger.Info("retryHTTP", "totalTimeout", totalTimeout) + retryCounter := 0 + sleepTime := time.Duration(0) + + var rIDReplacer requestGUIDReplacer + var rUpdater retryCounterUpdater + + for { + r.logger.Debug("retry count", "retryCounter", retryCounter) + req, err := r.req(r.method, r.fullURL.String(), bytes.NewReader(r.body)) + if err != nil { + return nil, err + } + if req != nil { + // req can be nil in tests + req = req.WithContext(r.ctx) + } + for k, v := range r.headers { + req.Header.Set(k, v) + } + res, err = r.client.Do(req) + if err != nil { + // check if it can retry. + doExit, err := r.isRetryableError(err) + if doExit { + return res, err + } + // cannot just return 4xx and 5xx status as the error can be sporadic. run often helps. + r.logger.Warn( + "failed http connection. no response is returned. retrying...", "err", err) + } else { + if res.StatusCode == http.StatusOK || r.raise4XX && res != nil && res.StatusCode >= 400 && res.StatusCode < 500 { + // exit if success + // or + // abort connection if raise4XX flag is enabled and the range of HTTP status code are 4XX. + // This is currently used for Snowflake login. The caller must generate an error object based on HTTP status. + break + } + r.logger.Warn( + "failed http connection. retrying...\n", "statusCode", res.StatusCode) + res.Body.Close() + } + // uses decorrelated jitter backoff + sleepTime = defaultWaitAlgo.decorr(retryCounter, sleepTime) + + if totalTimeout > 0 { + r.logger.Info("to timeout: ", "totalTimeout", totalTimeout) + // if any timeout is set + totalTimeout -= sleepTime + if totalTimeout <= 0 { + if err != nil { + return nil, err + } + if res != nil { + return nil, fmt.Errorf("timeout after %s. HTTP Status: %v. Hanging?", r.timeout, res.StatusCode) + } + return nil, fmt.Errorf("timeout after %s. Hanging?", r.timeout) + } + } + retryCounter++ + if rIDReplacer == nil { + rIDReplacer = newRequestGUIDReplace(r.fullURL) + } + r.fullURL = rIDReplacer.replace() + if rUpdater == nil { + rUpdater = newRetryUpdate(r.fullURL) + } + r.fullURL = rUpdater.replaceOrAdd(retryCounter) + r.logger.Info("sleeping to retry", "sleepTime", sleepTime, "totalTimeout", totalTimeout) + + await := time.NewTimer(sleepTime) + select { + case <-await.C: + // retry the request + case <-r.ctx.Done(): + await.Stop() + return res, r.ctx.Err() + } + } + return res, err +} + +func (r *retryHTTP) isRetryableError(err error) (bool, error) { + urlError, isURLError := err.(*url.Error) + if isURLError { + // context cancel or timeout + if urlError.Err == context.DeadlineExceeded || urlError.Err == context.Canceled { + return true, urlError.Err + } + if urlError.Err.Error() == "OCSP status revoked" { + // Certificate Revoked + return true, nil + } + if _, ok := urlError.Err.(x509.CertificateInvalidError); ok { + // Certificate is invalid + return true, err + } + if _, ok := urlError.Err.(x509.UnknownAuthorityError); ok { + // Certificate is self-signed + return true, err + } + errString := urlError.Err.Error() + if runtime.GOOS == "darwin" && strings.HasPrefix(errString, "x509:") && strings.HasSuffix(errString, "certificate is expired") { + // Certificate is expired + return true, err + } + + } + return false, err +} + +/* + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ diff --git a/sdk/helper/ocsp/retry_test.go b/sdk/helper/ocsp/retry_test.go new file mode 100644 index 000000000000..8a68f13217b9 --- /dev/null +++ b/sdk/helper/ocsp/retry_test.go @@ -0,0 +1,271 @@ +// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. + +package ocsp + +import ( + "context" + "github.com/hashicorp/go-hclog" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "testing" + "time" +) + +func fakeRequestFunc(_, _ string, _ io.Reader) (*http.Request, error) { + return nil, nil +} + +type fakeHTTPError struct { + err string + timeout bool +} + +func (e *fakeHTTPError) Error() string { return e.err } +func (e *fakeHTTPError) Timeout() bool { return e.timeout } +func (e *fakeHTTPError) Temporary() bool { return true } + +type fakeResponseBody struct { + body []byte + cnt int +} + +func (b *fakeResponseBody) Read(p []byte) (n int, err error) { + if b.cnt == 0 { + copy(p, b.body) + b.cnt = 1 + return len(b.body), nil + } + b.cnt = 0 + return 0, io.EOF +} + +func (b *fakeResponseBody) Close() error { + return nil +} + +type fakeHTTPClient struct { + cnt int // number of retry + success bool // return success after retry in cnt times + timeout bool // timeout + body []byte // return body + logger hclog.Logger +} + +func (c *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + c.cnt-- + if c.cnt < 0 { + c.cnt = 0 + } + c.logger.Info("fakeHTTPClient", "cnt", c.cnt) + + var retcode int + if c.success && c.cnt == 0 { + retcode = 200 + } else { + if c.timeout { + // simulate timeout + time.Sleep(time.Second * 1) + return nil, &fakeHTTPError{ + err: "Whatever reason (Client.Timeout exceeded while awaiting headers)", + timeout: true, + } + } + retcode = 0 + } + + ret := &http.Response{ + StatusCode: retcode, + Body: &fakeResponseBody{body: c.body}, + } + return ret, nil +} + +func TestRequestGUID(t *testing.T) { + var ridReplacer requestGUIDReplacer + var testURL *url.URL + var actualURL *url.URL + retryTime := 4 + + // empty url + testURL = &url.URL{} + ridReplacer = newRequestGUIDReplace(testURL) + for i := 0; i < retryTime; i++ { + actualURL = ridReplacer.replace() + if actualURL.String() != "" { + t.Fatalf("empty url not replaced by an empty one, got %s", actualURL) + } + } + + // url with on retry id + testURL = &url.URL{ + Path: "/" + requestIDKey + "=123-1923-9?param2=value", + } + ridReplacer = newRequestGUIDReplace(testURL) + for i := 0; i < retryTime; i++ { + actualURL = ridReplacer.replace() + + if actualURL != testURL { + t.Fatalf("url without retry id not replaced by origin one, got %s", actualURL) + } + } + + // url with retry id + // With both prefix and suffix + prefix := "/" + requestIDKey + "=123-1923-9?" + requestGUIDKey + "=" + suffix := "?param2=value" + testURL = &url.URL{ + Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix, + } + ridReplacer = newRequestGUIDReplace(testURL) + for i := 0; i < retryTime; i++ { + actualURL = ridReplacer.replace() + if (!strings.HasPrefix(actualURL.Path, prefix)) || + (!strings.HasSuffix(actualURL.Path, suffix)) || + len(testURL.Path) != len(actualURL.Path) { + t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL) + } + } + + // With no suffix + prefix = "/" + requestIDKey + "=123-1923-9?" + requestGUIDKey + "=" + suffix = "" + testURL = &url.URL{ + Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix, + } + ridReplacer = newRequestGUIDReplace(testURL) + for i := 0; i < retryTime; i++ { + actualURL = ridReplacer.replace() + if (!strings.HasPrefix(actualURL.Path, prefix)) || + (!strings.HasSuffix(actualURL.Path, suffix)) || + len(testURL.Path) != len(actualURL.Path) { + t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL) + } + + } + // With no prefix + prefix = requestGUIDKey + "=" + suffix = "?param2=value" + testURL = &url.URL{ + Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix, + } + ridReplacer = newRequestGUIDReplace(testURL) + for i := 0; i < retryTime; i++ { + actualURL = ridReplacer.replace() + if (!strings.HasPrefix(actualURL.Path, prefix)) || + (!strings.HasSuffix(actualURL.Path, suffix)) || + len(testURL.Path) != len(actualURL.Path) { + t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL) + } + } +} + +func TestRetryQuerySuccess(t *testing.T) { + c := New(testLogFactory) + c.Logger().Info("Retry N times and Success") + client := &fakeHTTPClient{ + cnt: 3, + success: true, + } + urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid&clientStartTime=123456") + if err != nil { + t.Fatal("failed to parse the test URL") + } + _, err = newRetryHTTP(context.TODO(), + client, + fakeRequestFunc, urlPtr, make(map[string]string), 60*time.Second).doPost().setBody([]byte{0}).execute() + if err != nil { + t.Fatal("failed to run retry") + } + var values url.Values + values, err = url.ParseQuery(urlPtr.RawQuery) + if err != nil { + t.Fatal("failed to fail to parse the URL") + } + retry, err := strconv.Atoi(values.Get(retryCounterKey)) + if err != nil { + t.Fatalf("failed to get retry counter: %v", err) + } + if retry < 2 { + t.Fatalf("not enough retry counter: %v", retry) + } +} +func TestRetryQueryFail(t *testing.T) { + c := New(testLogFactory) + c.Logger().Info("Retry N times and Fail") + client := &fakeHTTPClient{ + cnt: 4, + success: false, + } + urlPtr, err := url.Parse("https://fakeaccountretryfail.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid&clientStartTime=123456") + if err != nil { + t.Fatal("failed to parse the test URL") + } + _, err = newRetryHTTP(context.TODO(), + client, + fakeRequestFunc, urlPtr, make(map[string]string), 60*time.Second).doPost().setBody([]byte{0}).execute() + if err == nil { + t.Fatal("should fail to run retry") + } + var values url.Values + values, err = url.ParseQuery(urlPtr.RawQuery) + if err != nil { + t.Fatalf("failed to fail to parse the URL: %v", err) + } + retry, err := strconv.Atoi(values.Get(retryCounterKey)) + if err != nil { + t.Fatalf("failed to get retry counter: %v", err) + } + if retry < 2 { + t.Fatalf("not enough retry counter: %v", retry) + } +} +func TestRetryLoginRequest(t *testing.T) { + client := &fakeHTTPClient{ + cnt: 3, + success: true, + timeout: true, + logger: hclog.New(hclog.DefaultOptions), + } + client.logger.Info("Retry N times for timeouts and Success") + urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid") + if err != nil { + t.Fatal("failed to parse the test URL") + } + _, err = newRetryHTTP(context.TODO(), + client, + fakeRequestFunc, urlPtr, make(map[string]string), 60*time.Second).doPost().setBody([]byte{0}).execute() + if err != nil { + t.Fatal("failed to run retry") + } + var values url.Values + values, err = url.ParseQuery(urlPtr.RawQuery) + if err != nil { + t.Fatalf("failed to fail to parse the URL: %v", err) + } + if values.Get(retryCounterKey) != "" { + t.Fatalf("no retry counter should be attached: %v", retryCounterKey) + } + client.logger.Info("Retry N times for timeouts and Fail") + client = &fakeHTTPClient{ + cnt: 10, + success: false, + timeout: true, + logger: hclog.New(hclog.DefaultOptions), + } + _, err = newRetryHTTP(context.TODO(), + client, + fakeRequestFunc, urlPtr, make(map[string]string), 10*time.Second).doPost().setBody([]byte{0}).execute() + if err == nil { + t.Fatal("should fail to run retry") + } + values, err = url.ParseQuery(urlPtr.RawQuery) + if err != nil { + t.Fatalf("failed to fail to parse the URL: %v", err) + } + if values.Get(retryCounterKey) != "" { + t.Fatalf("no retry counter should be attached: %v", retryCounterKey) + } +} From 5b6303265b9593b5fe256e284fa030142362bf35 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 9 Sep 2022 14:53:22 -0500 Subject: [PATCH 02/39] Add cached OCSP client support to Cert Auth --- builtin/credential/cert/backend.go | 23 +- builtin/credential/cert/path_certs.go | 29 +- builtin/credential/cert/path_login.go | 62 ++- sdk/go.mod | 2 + sdk/go.sum | 2 + sdk/helper/ocsp/client.go | 500 +++++++++++------------ sdk/helper/ocsp/ocsp_test.go | 130 ++++-- sdk/helper/ocsp/retry.go | 550 -------------------------- sdk/helper/ocsp/retry_test.go | 271 ------------- 9 files changed, 420 insertions(+), 1149 deletions(-) delete mode 100644 sdk/helper/ocsp/retry.go delete mode 100644 sdk/helper/ocsp/retry_test.go diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index db999426e90a..adffe2c3062d 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -2,6 +2,8 @@ package cert import ( "context" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/ocsp" "strings" "sync" @@ -14,6 +16,10 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, if err := b.Setup(ctx, conf); err != nil { return nil, err } + + if err := b.ocspClient.ReadCache(ctx, conf.StorageView); err != nil { + return nil, err + } return b, nil } @@ -33,13 +39,16 @@ func Backend() *backend { pathCerts(&b), pathCRLs(&b), }, - AuthRenew: b.pathLoginRenew, - Invalidate: b.invalidate, - BackendType: logical.TypeCredential, + AuthRenew: b.pathLoginRenew, + Invalidate: b.invalidate, + BackendType: logical.TypeCredential, + PeriodicFunc: b.periodFunc, } b.crlUpdateMutex = &sync.RWMutex{} - + b.ocspClient = ocsp.New(func() hclog.Logger { + return b.Logger() + }) return &b } @@ -48,7 +57,9 @@ type backend struct { MapCertId *framework.PathMap crls map[string]CRLInfo + ocspDisabled bool crlUpdateMutex *sync.RWMutex + ocspClient *ocsp.Client } func (b *backend) invalidate(_ context.Context, key string) { @@ -60,6 +71,10 @@ func (b *backend) invalidate(_ context.Context, key string) { } } +func (b *backend) periodFunc(ctx context.Context, request *logical.Request) error { + return b.ocspClient.WriteCache(ctx, request.Storage) +} + const backendHelp = ` The "cert" credential provider allows authentication using TLS client certificates. A client connects to Vault and uses diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index 00e103b51a56..b1d9e53470fc 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -47,7 +47,22 @@ Must be x509 PEM encoded.`, EditType: "file", }, }, - + "ocsp_enabled": { + Type: framework.TypeBool, + Description: `Whether to attempt OCSP verification of certificates at login`, + }, + "ocsp_ca_certificates": { + Type: framework.TypeString, + Description: `Any additional CA certificates needed to communicate with OCSP servers`, + DisplayAttrs: &framework.DisplayAttributes{ + EditType: "file", + }, + }, + "ocsp_servers_override": { + Type: framework.TypeCommaStringSlice, + Description: `A comma-separated list of OCSP server addresses. If unset, the OCSP server is determined +from the AuthorityInformationAccess extension on the certificate being inspected.`, + }, "allowed_names": { Type: framework.TypeCommaStringSlice, Description: `A comma-separated list of names. @@ -294,9 +309,18 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr if certificateRaw, ok := d.GetOk("certificate"); ok { cert.Certificate = certificateRaw.(string) } + if ocspCertificatesRaw, ok := d.GetOk("ocsp_ca_certificates"); ok { + cert.OcspCaCertificates = ocspCertificatesRaw.(string) + } + if ocspEnabledRaw, ok := d.GetOk("ocsp_enabled"); ok { + cert.OcspEnabled = ocspEnabledRaw.(bool) + } if displayNameRaw, ok := d.GetOk("display_name"); ok { cert.DisplayName = displayNameRaw.(string) } + if ocspServerOverrides, ok := d.GetOk("ocsp_servers_override"); ok { + cert.OcspServersOverride = ocspServerOverrides.([]string) + } if allowedNamesRaw, ok := d.GetOk("allowed_names"); ok { cert.AllowedNames = allowedNamesRaw.([]string) } @@ -424,6 +448,9 @@ type CertEntry struct { Name string Certificate string + OcspCaCertificates string + OcspEnabled bool + OcspServersOverride []string DisplayName string Policies []string TTL time.Duration diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index c78188b7cc9b..868d7f40b9bb 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -224,8 +224,8 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d certName = d.Get("name").(string) } - // Load the trusted certificates - roots, trusted, trustedNonCAs := b.loadTrustedCerts(ctx, req.Storage, certName) + // Load the trusted certificates and other details + roots, trusted, trustedNonCAs, ocspServersOverride := b.loadTrustedCerts(ctx, req.Storage, certName) // Get the list of full chains matching the connection and validates the // certificate itself @@ -234,6 +234,11 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d return nil, nil, err } + var extraCas []*x509.Certificate + for _, t := range trusted { + extraCas = append(extraCas, t.Certificates...) + } + // If trustedNonCAs is not empty it means that client had registered a non-CA cert // with the backend. if len(trustedNonCAs) != 0 { @@ -241,9 +246,14 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d tCert := trustedNonCA.Certificates[0] // Check for client cert being explicitly listed in the config (and matching other constraints) if tCert.SerialNumber.Cmp(clientCert.SerialNumber) == 0 && - bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) && - b.matchesConstraints(clientCert, trustedNonCA.Certificates, trustedNonCA) { - return trustedNonCA, nil, nil + bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) { + matches, err := b.matchesConstraints(ctx, clientCert, trustedNonCA.Certificates, trustedNonCA, extraCas, ocspServersOverride) + if err != nil { + return nil, nil, err + } + if matches { + return trustedNonCA, nil, nil + } } } } @@ -260,10 +270,15 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d for _, tCert := range trust.Certificates { // For each certificate in the entry for _, chain := range trustedChains { // For each root chain that we matched for _, cCert := range chain { // For each cert in the matched chain - if tCert.Equal(cCert) && // ParsedCert intersects with matched chain - b.matchesConstraints(clientCert, chain, trust) { // validate client cert + matched chain against the config - // Add the match to the list - matches = append(matches, trust) + if tCert.Equal(cCert) { // ParsedCert intersects with matched chain + match, err := b.matchesConstraints(ctx, clientCert, chain, trust, extraCas, ocspServersOverride) // validate client cert + matched chain against the config + if err != nil { + return nil, nil, err + } + if match { + // Add the match to the list + matches = append(matches, trust) + } } } } @@ -279,8 +294,9 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d return matches[0], nil, nil } -func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool { - return !b.checkForChainInCRLs(trustedChain) && +func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certificate, trustedChain []*x509.Certificate, + config *ParsedCert, extraCas []*x509.Certificate, ocspServersOverride []string) (bool, error) { + soFar := !b.checkForChainInCRLs(trustedChain) && b.matchesNames(clientCert, config) && b.matchesCommonName(clientCert, config) && b.matchesDNSSANs(clientCert, config) && @@ -288,6 +304,14 @@ func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain b.matchesURISANs(clientCert, config) && b.matchesOrganizationalUnits(clientCert, config) && b.matchesCertificateExtensions(clientCert, config) + if config.Entry.OcspEnabled { + ocspGood, err := b.checkForChainInOCSP(ctx, trustedChain, extraCas, ocspServersOverride) + if err != nil { + return false, err + } + soFar = soFar && ocspGood + } + return soFar, nil } // matchesNames verifies that the certificate matches at least one configured @@ -478,7 +502,7 @@ func (b *backend) certificateExtensionsMetadata(clientCert *x509.Certificate, co } // loadTrustedCerts is used to load all the trusted certificates from the backend -func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) { +func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert, ocspServersOverride []string) { pool = x509.NewCertPool() trusted = make([]*ParsedCert, 0) trustedNonCAs = make([]*ParsedCert, 0) @@ -512,6 +536,8 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, b.Logger().Error("failed to parse certificate", "name", name) continue } + parsed = append(parsed, parsePEM([]byte(entry.OcspCaCertificates))...) + if !parsed[0].IsCA { trustedNonCAs = append(trustedNonCAs, &ParsedCert{ Entry: entry, @@ -528,10 +554,22 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, Certificates: parsed, }) } + ocspServersOverride = entry.OcspServersOverride } return } +func (b *backend) checkForChainInOCSP(ctx context.Context, chain []*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string) (bool, error) { + if b.ocspDisabled || len(chain) < 2 { + return true, nil + } + err := b.ocspClient.VerifyPeerCertificate(ctx, [][]*x509.Certificate{chain}, extraCas, ocspServersOverride) + if err != nil { + return false, nil + } + return true, nil +} + func (b *backend) checkForChainInCRLs(chain []*x509.Certificate) bool { badChain := false for _, cert := range chain { diff --git a/sdk/go.mod b/sdk/go.mod index 6945d15d77b2..c855138f22b3 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -17,6 +17,7 @@ require ( github.com/hashicorp/go-kms-wrapping/entropy v0.1.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-plugin v1.4.3 + github.com/hashicorp/go-retryablehttp v0.5.3 github.com/hashicorp/go-secure-stdlib/base62 v0.1.1 github.com/hashicorp/go-secure-stdlib/mlock v0.1.1 github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 @@ -45,6 +46,7 @@ require ( github.com/fatih/color v1.7.0 // indirect github.com/frankban/quicktest v1.10.0 // indirect github.com/go-asn1-ber/asn1-ber v1.3.1 // indirect + github.com/hashicorp/go-cleanhttp v0.5.0 // indirect github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.6 // indirect diff --git a/sdk/go.sum b/sdk/go.sum index 7fd2f4ad8029..0f26be46fa85 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -88,6 +88,7 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.0 h1:wvCrVc9TjDls6+YGAF2hAifE1E5U1+b4tH6KdvN3Gig= github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-hclog v0.16.2 h1:K4ev2ib4LdQETX5cSZBG0DVLk1jwGqSPXBjdah3veNs= @@ -102,6 +103,7 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM= github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ= +github.com/hashicorp/go-retryablehttp v0.5.3 h1:QlWt0KvWT0lq8MFppF9tsJGF+ynG7ztc2KIPhzRGk7s= github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= github.com/hashicorp/go-secure-stdlib/base62 v0.1.1 h1:6KMBnfEv0/kLAz0O76sliN5mXbCDcLfs2kP7ssP7+DQ= github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index dda265e70a13..b696caae959b 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -3,6 +3,7 @@ package ocsp import ( + "bytes" "context" "crypto" "crypto/tls" @@ -14,6 +15,7 @@ import ( "errors" "fmt" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" "golang.org/x/crypto/ocsp" "io" @@ -34,6 +36,19 @@ import ( // set to ocspModeFailClosed for fail closed mode type OCSPFailOpenMode uint32 +type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) + +type clientInterface interface { + Do(req *http.Request) (*http.Response, error) +} + +const ( + httpHeaderContentType = "Content-Type" + httpHeaderAccept = "accept" + httpHeaderContentLength = "Content-Length" + httpHeaderHost = "Host" +) + const ( ocspFailOpenNotSet OCSPFailOpenMode = iota // OCSPFailOpenTrue represents OCSP fail open mode. @@ -41,14 +56,14 @@ const ( // OCSPFailOpenFalse represents OCSP fail closed mode. OCSPFailOpenFalse ) + const ( ocspModeFailOpen = "FAIL_OPEN" ocspModeFailClosed = "FAIL_CLOSED" ocspModeInsecure = "INSECURE" ) -// OCSP fail open mode -var ocspFailOpen = OCSPFailOpenTrue +const ocspCacheKey = "ocsp_cache" const ( // defaultOCSPCacheServerTimeout is the total timeout for OCSP cache server. @@ -83,16 +98,24 @@ const ( maxClockSkew = 900 * time.Second // buffer for clock skew ) +type ocspCachedResponse struct { + time float64 + resp string +} + type Client struct { // caRoot includes the CA certificates. caRoot map[string]*x509.Certificate // certPOol includes the CA certificates. certPool *x509.CertPool - ocspResponseCache map[certIDKey][]interface{} + ocspResponseCache map[certIDKey]ocspCachedResponse ocspResponseCacheLock sync.RWMutex // cacheUpdated is true if the memory cache is updated cacheUpdated bool logFactory func() hclog.Logger + + // OCSP fail open mode + ocspFailOpen OCSPFailOpenMode } type ocspStatusCode int @@ -147,14 +170,13 @@ var hashOIDs = map[crypto.Hash]asn1.ObjectIdentifier{ } // copied from crypto/ocsp -func (c *Client) getOIDFromHashAlgorithm(target crypto.Hash) asn1.ObjectIdentifier { +func getOIDFromHashAlgorithm(target crypto.Hash) (asn1.ObjectIdentifier, error) { for hash, oid := range hashOIDs { if hash == target { - return oid + return oid, nil } } - c.Logger().Error("no valid OID is found for the hash algorithm", "target", target) - return nil + return nil, fmt.Errorf("no valid OID is found for the hash algorithm: %v", target) } func (c *Client) getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto.Hash { @@ -219,43 +241,47 @@ func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { } } -func (c *Client) encodeCertIDKey(certIDKeyBase64 string) *certIDKey { +func (c *Client) encodeCertIDKey(certIDKeyBase64 string) (*certIDKey, error) { r, err := base64.StdEncoding.DecodeString(certIDKeyBase64) if err != nil { - return nil + return nil, err } var cid certID rest, err := asn1.Unmarshal(r, &cid) if err != nil { // error in parsing - return nil + return nil, err } if len(rest) > 0 { // extra bytes to the end - return nil + return nil, err } return &certIDKey{ c.getHashAlgorithmFromOID(cid.HashAlgorithm), base64.StdEncoding.EncodeToString(cid.NameHash), base64.StdEncoding.EncodeToString(cid.IssuerKeyHash), cid.SerialNumber.String(), - } + }, nil } -func (c *Client) decodeCertIDKey(k *certIDKey) string { +func decodeCertIDKey(k *certIDKey) (string, error) { serialNumber := new(big.Int) serialNumber.SetString(k.SerialNumber, 10) nameHash, err := base64.StdEncoding.DecodeString(k.NameHash) if err != nil { - return "" + return "", err } issuerKeyHash, err := base64.StdEncoding.DecodeString(k.IssuerKeyHash) if err != nil { - return "" + return "", err + } + hashAlgoOid, err := getOIDFromHashAlgorithm(k.HashAlgorithm) + if err != nil { + return "", err } encodedCertID, err := asn1.Marshal(certID{ pkix.AlgorithmIdentifier{ - Algorithm: c.getOIDFromHashAlgorithm(k.HashAlgorithm), + Algorithm: hashAlgoOid, Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, }, nameHash, @@ -263,9 +289,9 @@ func (c *Client) decodeCertIDKey(k *certIDKey) string { serialNumber, }) if err != nil { - return "" + return "", err } - return base64.StdEncoding.EncodeToString(encodedCertID) + return base64.StdEncoding.EncodeToString(encodedCertID), nil } func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Certificate) (*ocspStatus, error) { @@ -276,7 +302,7 @@ func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issue gotValueFromCache := c.ocspResponseCache[*encodedCertID] c.ocspResponseCacheLock.RUnlock() - status, err := c.extractOCSPCacheResponseValue(gotValueFromCache, subject, issuer) + status, err := c.extractOCSPCacheResponseValue(&gotValueFromCache, subject, issuer) if err != nil { return nil, err } @@ -343,24 +369,22 @@ func (c *Client) checkOCSPCacheServer( ocspServerHost *url.URL, totalTimeout time.Duration) ( cacheContent *map[string][]interface{}, - ocspS *ocspStatus) { + ocspS *ocspStatus, err error) { var respd map[string][]interface{} - headers := make(map[string]string) - res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout).execute() + + request, err := req("GET", ocspServerHost.Hostname(), nil) if err != nil { - c.Logger().Error("failed to get OCSP cache from OCSP Cache Server. ", "err", err) - return nil, &ocspStatus{ - code: ocspFailedSubmit, - err: err, - } + return nil, nil, err } + res, err := client.Do(request) + if err != nil { + return nil, nil, err + } // newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout).execute() + defer res.Body.Close() c.Logger().Debug("StatusCode from OCSP Cache Server", "statusCode", res.StatusCode) if res.StatusCode != http.StatusOK { - return nil, &ocspStatus{ - code: ocspFailedResponse, - err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), - } + return nil, nil, fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status) } c.Logger().Debug("reading contents") @@ -369,16 +393,12 @@ func (c *Client) checkOCSPCacheServer( if err := dec.Decode(&respd); err == io.EOF { break } else if err != nil { - c.Logger().Error("failed to decode OCSP cache.", "err", err) - return nil, &ocspStatus{ - code: ocspFailedExtractResponse, - err: err, - } + return nil, nil, err } } return &respd, &ocspStatus{ code: ocspSuccess, - } + }, nil } // retryOCSP is the second level of retry method if the returned contents are corrupted. It often happens with OCSP @@ -390,57 +410,49 @@ func (c *Client) retryOCSP( ocspHost *url.URL, headers map[string]string, reqBody []byte, - issuer *x509.Certificate, - totalTimeout time.Duration) ( + issuer *x509.Certificate) ( ocspRes *ocsp.Response, ocspResBytes []byte, - ocspS *ocspStatus) { - multiplier := 1 - if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { - multiplier = 3 // up to 3 times for Fail Close mode - } - res, err := newRetryHTTP( - ctx, client, req, ocspHost, headers, - totalTimeout*time.Duration(multiplier)).doPost().setBody(reqBody).execute() + ocspS *ocspStatus, err error) { + + request, err := req("POST", ocspHost.String(), bytes.NewBuffer(reqBody)) if err != nil { - return ocspRes, ocspResBytes, &ocspStatus{ - code: ocspFailedSubmit, - err: err, + return nil, nil, nil, err + } + if request != nil { + request = request.WithContext(ctx) + for k, v := range headers { + request.Header[k] = append(request.Header[k], v) } } + res, err := client.Do(request) + if err != nil { + return nil, nil, nil, err + } defer res.Body.Close() c.Logger().Debug("StatusCode from OCSP Server:", "statusCode", res.StatusCode) if res.StatusCode != http.StatusOK { - return ocspRes, ocspResBytes, &ocspStatus{ - code: ocspFailedResponse, - err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), - } + return nil, nil, nil, fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status) } c.Logger().Debug("reading contents") ocspResBytes, err = ioutil.ReadAll(res.Body) if err != nil { - return ocspRes, ocspResBytes, &ocspStatus{ - code: ocspFailedExtractResponse, - err: err, - } + return nil, nil, nil, err } c.Logger().Debug("parsing OCSP response") ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) if err != nil { - return ocspRes, ocspResBytes, &ocspStatus{ - code: ocspFailedParseResponse, - err: err, - } + return nil, nil, nil, err } return ocspRes, ocspResBytes, &ocspStatus{ code: ocspSuccess, - } + }, nil } // getRevocationStatus checks the certificate revocation status for subject using issuer certificate. -func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) (*ocspStatus, error) { - c.Logger().Info("get-revocation-status", "subject", subject.Subject, "issuer", issuer.Subject) +func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string) (*ocspStatus, error) { + c.Logger().Debug("get-revocation-status", "subject", subject.Subject, "issuer", issuer.Subject) status, ocspReq, encodedCertID, err := c.ValidateWithCache(subject, issuer) if err != nil { @@ -452,62 +464,60 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. if ocspReq == nil || encodedCertID == nil { return status, nil } - c.Logger().Info("cache missed") - c.Logger().Info("OCSP: ", "server", subject.OCSPServer) - if len(subject.OCSPServer) == 0 || isTestNoOCSPURL() { + c.Logger().Debug("cache missed") + c.Logger().Debug("OCSP: ", "server", subject.OCSPServer) + if len(subject.OCSPServer) == 0 && len(ocspServersOverride) == 0 { return nil, fmt.Errorf("no OCSP responder URL: subject: %v", subject.Subject) } - ocspHost := subject.OCSPServer[0] - u, err := url.Parse(ocspHost) - if err != nil { - return nil, err - } - hostnameStr := os.Getenv(ocspTestResponderURLEnv) - var hostname string - if retryURL := os.Getenv(ocspRetryURLEnv); retryURL != "" { - hostname = fmt.Sprintf(retryURL, u.Hostname(), base64.StdEncoding.EncodeToString(ocspReq)) - } else { - hostname = u.Hostname() + ocspHosts := subject.OCSPServer + if len(ocspServersOverride) > 0 { + ocspHosts = ocspServersOverride } - if hostnameStr != "" { - u0, err := url.Parse(hostnameStr) - if err == nil { - hostname = u0.Hostname() - u = u0 + + var ret *ocspStatus + var ocspResBytes []byte + for _, ocspHost := range ocspHosts { + u, err := url.Parse(ocspHost) + if err != nil { + return nil, err } - } - headers := make(map[string]string) - headers[httpHeaderContentType] = "application/ocsp-request" - headers[httpHeaderAccept] = "application/ocsp-response" - headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq)) - headers[httpHeaderHost] = hostname - timeoutStr := os.Getenv(ocspTestResponderTimeoutEnv) - timeout := defaultOCSPResponderTimeout - if timeoutStr != "" { - var timeoutInMilliseconds int - timeoutInMilliseconds, err = strconv.Atoi(timeoutStr) - if err == nil { - timeout = time.Duration(timeoutInMilliseconds) * time.Millisecond + + hostname := u.Hostname() + + headers := make(map[string]string) + headers[httpHeaderContentType] = "application/ocsp-request" + headers[httpHeaderAccept] = "application/ocsp-response" + headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq)) + headers[httpHeaderHost] = hostname + timeout := defaultOCSPResponderTimeout + + ocspClient := &http.Client{ + Timeout: timeout, + Transport: newInsecureOcspTransport(extraCas), + } + var ocspRes *ocsp.Response + var ocspS *ocspStatus + ocspRes, ocspResBytes, ocspS, err = c.retryOCSP( + ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer) + if err != nil { + return nil, err + } + if ocspS.code != ocspSuccess { + return ocspS, nil } - } - ocspClient := &http.Client{ - Timeout: timeout, - Transport: snowflakeInsecureTransport, - } - ocspRes, ocspResBytes, ocspS := c.retryOCSP( - ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer, timeout) - if ocspS.code != ocspSuccess { - return ocspS, nil - } - ret, err := validateOCSP(ocspRes) - if err != nil { - return nil, err + ret, err = validateOCSP(ocspRes) + if err != nil { + return nil, err + } + if isValidOCSPStatus(ret.code) { + break + } } if !isValidOCSPStatus(ret.code) { - return ret, nil // return invalid + return ret, nil } - v := []interface{}{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)} + v := ocspCachedResponse{time: float64(time.Now().UTC().Unix()), resp: base64.StdEncoding.EncodeToString(ocspResBytes)} c.ocspResponseCacheLock.Lock() c.ocspResponseCache[*encodedCertID] = v c.cacheUpdated = true @@ -515,16 +525,12 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. return ret, nil } -func isTestNoOCSPURL() bool { - return strings.EqualFold(os.Getenv(ocspTestNoOCSPURLEnv), "true") -} - func isValidOCSPStatus(status ocspStatusCode) bool { return status == ocspStatusGood || status == ocspStatusRevoked || status == ocspStatusUnknown } // VerifyPeerCertificate verifies all of certificate revocation status -func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate) (err error) { +func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string) (err error) { for i := 0; i < len(verifiedChains); i++ { // Certificate signed by Root CA. This should be one before the last in the Certificate Chain numberOfNoneRootCerts := len(verifiedChains[i]) - 1 @@ -538,7 +544,7 @@ func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]* verifiedChains[i] = append(verifiedChains[i], rca) numberOfNoneRootCerts++ } - results, err := c.getAllRevocationStatus(ctx, verifiedChains[i]) + results, err := c.GetAllRevocationStatus(ctx, verifiedChains[i], extraCas, ocspServersOverride) if err != nil { return err } @@ -552,7 +558,7 @@ func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]* func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int) *ocspStatus { msg := "" - if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { + if atomic.LoadUint32((*uint32)(&c.ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { // Fail closed. any error is returned to stop connection for _, r := range results { if r.err != nil { @@ -619,7 +625,7 @@ func (c *Client) ValidateWithCache(subject, issuer *x509.Certificate) (*ocspStat return status, ocspReq, encodedCertID, nil } -func (c *Client) getAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate) ([]*ocspStatus, error) { +func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains, extraCas []*x509.Certificate, ocspServersOverride []string) ([]*ocspStatus, error) { _, err := c.ValidateWithCacheForAllCertificates(verifiedChains) if err != nil { return nil, err @@ -627,7 +633,7 @@ func (c *Client) getAllRevocationStatus(ctx context.Context, verifiedChains []*x n := len(verifiedChains) - 1 results := make([]*ocspStatus, n) for j := 0; j < n; j++ { - results[j], err = c.getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1]) + results[j], err = c.getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1], extraCas, ocspServersOverride) if err != nil { return nil, err } @@ -639,193 +645,114 @@ func (c *Client) getAllRevocationStatus(ctx context.Context, verifiedChains []*x } // verifyPeerCertificateSerial verifies the certificate revocation status in serial. -func (c *Client) verifyPeerCertificateSerial(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { - return c.VerifyPeerCertificate(context.TODO(), verifiedChains) -} - -/* -// initOCSPCache initializes OCSP Response cache file. -func (c *Client) initOCSPCache() { - if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { - return - } - ocspResponseCache = make(map[certIDKey][]interface{}) - ocspResponseCacheLock = &sync.RWMutex{} - - c.Logger().Info("reading OCSP Response cache file.", "filename", cacheFileName) - f, err := os.OpenFile(cacheFileName, os.O_CREATE|os.O_RDONLY, os.ModePerm) - if err != nil { - c.Logger().Debug("failed to open. Ignored.", "err", err) - return +func (c *Client) verifyPeerCertificateSerial(extraCas []*x509.Certificate, ocspServersOverride []string) func(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { + return func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { + return c.VerifyPeerCertificate(context.TODO(), verifiedChains, extraCas, ocspServersOverride) } - defer f.Close() - - buf := make(map[string][]interface{}) - - r := bufio.NewReader(f) - dec := json.NewDecoder(r) - for { - if err = dec.Decode(&buf); err == io.EOF { - break - } else if err != nil { - c.Logger().Debug("failed to read. Ignored.", "err", err) - return - } - } - for k, cacheValue := range buf { - status := c.extractOCSPCacheResponseValueWithoutSubject(cacheValue) - if !isValidOCSPStatus(status.code) { - continue - } - cacheKey := c.encodeCertIDKey(k) - ocspResponseCache[*cacheKey] = cacheValue +} - } - cacheUpdated = false -}*/ -func (c *Client) extractOCSPCacheResponseValueWithoutSubject(cacheValue []interface{}) (*ocspStatus, error) { - return c.extractOCSPCacheResponseValue(cacheValue, nil, nil) +func (c *Client) extractOCSPCacheResponseValueWithoutSubject(cacheValue ocspCachedResponse) (*ocspStatus, error) { + return c.extractOCSPCacheResponseValue(&cacheValue, nil, nil) } -func (c *Client) extractOCSPCacheResponseValue(cacheValue []interface{}, subject, issuer *x509.Certificate) (*ocspStatus, error) { +func (c *Client) extractOCSPCacheResponseValue(cacheValue *ocspCachedResponse, subject, issuer *x509.Certificate) (*ocspStatus, error) { subjectName := "Unknown" if subject != nil { subjectName = subject.Subject.CommonName } curTime := time.Now() - if len(cacheValue) != 2 { + if cacheValue == nil { return &ocspStatus{ code: ocspMissedCache, err: fmt.Errorf("miss cache data. subject: %v", subjectName), }, nil } - if ts, ok := cacheValue[0].(float64); ok { - currentTime := float64(curTime.UTC().Unix()) - if currentTime-ts >= cacheExpire { - return &ocspStatus{ - code: ocspCacheExpired, - err: fmt.Errorf("cache expired. current: %v, cache: %v", - time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(ts), 0).UTC()), - }, nil - } - } else { - return nil, errors.New("the first cache element is not float64") + currentTime := float64(curTime.UTC().Unix()) + if currentTime-cacheValue.time >= cacheExpire { + return &ocspStatus{ + code: ocspCacheExpired, + err: fmt.Errorf("cache expired. current: %v, cache: %v", + time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(cacheValue.time), 0).UTC()), + }, nil } + var err error var r *ocsp.Response - if s, ok := cacheValue[1].(string); ok { - var b []byte - b, err = base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err) - - } - // check the revocation status here - r, err = ocsp.ParseResponse(b, issuer) - if err != nil { - c.Logger().Warn("the second cache element is not a valid OCSP Response. Ignored.", "subject", subjectName) - return nil, fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err) - } - } else { - return nil, errors.New("the second cache element is not string") + var b []byte + b, err = base64.StdEncoding.DecodeString(cacheValue.resp) + if err != nil { + return nil, fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err) } + // check the revocation status here + r, err = ocsp.ParseResponse(b, issuer) + if err != nil { + c.Logger().Warn("the second cache element is not a valid OCSP Response. Ignored.", "subject", subjectName) + return nil, fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err) + } + return validateOCSP(r) } const storageValueKey = "backingStore" -// writeOCSPCacheFile writes a OCSP Response cache file. This is called if all revocation status is success. -// lock file is used to mitigate race condition with other process. -func (c *Client) writeOCSPCacheFile(ctx context.Context, storage logical.Storage) error { +// writeOCSPCache writes a OCSP Response cache +func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) error { c.Logger().Debug("writing OCSP Response cache file") - buf := make(map[string][]interface{}) - for k, v := range c.ocspResponseCache { - cacheKeyInBase64 := c.decodeCertIDKey(&k) - buf[cacheKeyInBase64] = v + m := make(map[string][]interface{}) + for k, entry := range c.ocspResponseCache { + cacheKeyInBase64, err := decodeCertIDKey(&k) + if err != nil { + return err + } + m[cacheKeyInBase64] = []interface{}{entry.time, entry.resp} } - j, err := json.Marshal(buf) + v, err := jsonutil.EncodeJSONAndCompress(m, nil) if err != nil { - return errors.New("failed to convert OCSP Response cache to JSON") + return err } - entry := logical.StorageEntry{ - Key: "ocsp_cache", - Value: j, + Key: ocspCacheKey, + Value: v, } return storage.Put(ctx, &entry) } -/* -// readCACerts read a set of root CAs -func (c *Client) readCACerts() { - raw := []byte(c.caRootPEM) - certPool = x509.NewCertPool() - caRoot = make(map[string]*x509.Certificate) - var p *pem.Block - for { - p, raw = pem.Decode(raw) - if p == nil { - break - } - if p.Type != "CERTIFICATE" { - continue - } - c, err := x509.ParseCertificate(p.Bytes) - if err != nil { - panic("failed to parse CA certificate.") - } - certPool.AddCert(c) - caRoot[string(c.RawSubject)] = c +// readOCSPCache reads a OCSP Response cache from storage +func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) error { + c.Logger().Debug("reading OCSP Response cache entry") + + entry, err := storage.Get(ctx, ocspCacheKey) + if err != nil { + return err } -} + if entry == nil { + return nil + } + var untypedCache map[string][]interface{} -// createOCSPCacheDir creates OCSP response cache directory and set the cache file name. -func createOCSPCacheDir() { - if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { - c.Logger().Info(`OCSP Cache Server disabled. All further access and use of - OCSP Cache will be disabled for this OCSP Status Query`) - return - } - cacheDir = os.Getenv(cacheDirEnv) - if cacheDir == "" { - cacheDir = os.Getenv("SNOWFLAKE_TEST_WORKSPACE") - } - if cacheDir == "" { - switch runtime.GOOS { - case "windows": - cacheDir = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches") - case "darwin": - home := os.Getenv("HOME") - if home == "" { - c.Logger().Info("HOME is blank.") - } - cacheDir = filepath.Join(home, "Library", "Caches", "Snowflake") - default: - home := os.Getenv("HOME") - if home == "" { - c.Logger().Info("HOME is blank") - } - cacheDir = filepath.Join(home, ".cache", "snowflake") - } + err = jsonutil.DecodeJSON(entry.Value, &untypedCache) + if err != nil { + return errors.New("failed to unmarshal OCSP cache") } - if _, err := os.Stat(cacheDir); os.IsNotExist(err) { - if err = os.MkdirAll(cacheDir, os.ModePerm); err != nil { - c.Logger().Debugf("failed to create cache directory. %v, err: %v. ignored", cacheDir, err) + for k, v := range untypedCache { + key, err := c.encodeCertIDKey(k) + if err != nil { + return err } + c.ocspResponseCache[*key] = ocspCachedResponse{time: v[0].(float64), resp: v[1].(string)} } - cacheFileName = filepath.Join(cacheDir, cacheFileBaseName) - c.Logger().Infof("reset OCSP cache file. %v", cacheFileName) + return nil } -*/ + func New(logFactory func() hclog.Logger) *Client { c := Client{ caRoot: make(map[string]*x509.Certificate), - ocspResponseCache: make(map[certIDKey][]interface{}), + ocspResponseCache: make(map[certIDKey]ocspCachedResponse), logFactory: logFactory, } @@ -836,24 +763,37 @@ func (c *Client) Logger() hclog.Logger { return c.logFactory() } -// snowflakeInsecureTransport is the transport object that doesn't do certificate revocation check. -var snowflakeInsecureTransport = &http.Transport{ - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Minute, - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, +// insecureOcspTransport is the transport object that doesn't do certificate revocation check. +func newInsecureOcspTransport(extraCas []*x509.Certificate) *http.Transport { + // Get the SystemCertPool, continue with an empty pool on error + rootCAs, _ := x509.SystemCertPool() + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + for _, c := range extraCas { + rootCAs.AddCert(c) + } + config := &tls.Config{ + RootCAs: rootCAs, + } + return &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Minute, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSClientConfig: config, + } } -// SnowflakeTransport includes the certificate revocation check with OCSP in sequential. By default, the driver uses -// this transport object. -func (c *Client) NewTransport() *http.Transport { +// NewTransport includes the certificate revocation check with OCSP in sequential. +func (c *Client) NewTransport(extraCas []*x509.Certificate, ocspServersOverride []string) *http.Transport { return &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: c.certPool, - VerifyPeerCertificate: c.verifyPeerCertificateSerial, + VerifyPeerCertificate: c.verifyPeerCertificateSerial(extraCas, ocspServersOverride), }, MaxIdleConns: 10, IdleConnTimeout: 30 * time.Minute, @@ -867,16 +807,22 @@ func (c *Client) NewTransport() *http.Transport { func (c *Client) WriteCache(ctx context.Context, storage logical.Storage) error { c.ocspResponseCacheLock.Lock() + defer c.ocspResponseCacheLock.Unlock() if c.cacheUpdated { - defer c.ocspResponseCacheLock.Unlock() if c.cacheUpdated { - return c.writeOCSPCacheFile(ctx, storage) + return c.writeOCSPCache(ctx, storage) } c.cacheUpdated = false } return nil } +func (c *Client) ReadCache(ctx context.Context, storage logical.Storage) error { + c.ocspResponseCacheLock.Lock() + defer c.ocspResponseCacheLock.Unlock() + return c.readOCSPCache(ctx, storage) +} + /* Apache License Version 2.0, January 2004 diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 934061abd44d..7b00092e3df1 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "github.com/hashicorp/go-hclog" + "io" "io/ioutil" "net" "net/http" @@ -29,22 +30,22 @@ func TestOCSP(t *testing.T) { "false", } targetURL := []string{ + "https://sfcdev1.blob.core.windows.net/", "https://sfctest0.snowflakecomputing.com/", "https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64", - "https://sfcdev1.blob.core.windows.net/", } c := New(testLogFactory) transports := []*http.Transport{ - snowflakeInsecureTransport, - c.NewTransport(), + newInsecureOcspTransport(nil), + c.NewTransport(nil, nil), } for _, enabled := range cacheServerEnabled { for _, tgt := range targetURL { _ = os.Setenv(cacheServerEnabledEnv, enabled) //_ = os.Remove(cacheFileName) // clear cache file - c.ocspResponseCache = make(map[certIDKey][]interface{}) + c.ocspResponseCache = make(map[certIDKey]*ocspCachedResponse) for _, tr := range transports { c := &http.Client{ Transport: tr, @@ -186,7 +187,7 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { } b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE")) currentTime := float64(time.Now().UTC().Unix()) - c.ocspResponseCache[dummyKey0] = []interface{}{currentTime, b64Key} + c.ocspResponseCache[dummyKey0] = &ocspCachedResponse{currentTime, b64Key} subject := &x509.Certificate{} issuer := &x509.Certificate{} ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer) @@ -197,7 +198,7 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code) } // old timestamp - c.ocspResponseCache[dummyKey] = []interface{}{float64(1395054952), b64Key} + c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(1395054952), b64Key} ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) if err != nil { t.Fatal(err) @@ -206,7 +207,7 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code) } // future timestamp - c.ocspResponseCache[dummyKey] = []interface{}{float64(1805054952), b64Key} + c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(1805054952), b64Key} ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) if err == nil { t.Fatalf("should have failed.") @@ -220,30 +221,17 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { "koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + "Jb4xI/OTXIrRA5TsT4cCioE" // issuer is not a true issuer certificate - c.ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse} + c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(currentTime - 1000), actualOcspResponse} ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) if err == nil { t.Fatalf("should have failed.") } // invalid validity - c.ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse} + c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(currentTime - 1000), actualOcspResponse} ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil) if err == nil { t.Fatalf("should have failed.") } - // wrong timestamp type - c.ocspResponseCache[dummyKey] = []interface{}{uint32(currentTime - 1000), 123456} - ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) - if err == nil { - t.Fatalf("should have failed.") - } - - // wrong value type - c.ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), 123456} - ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) - if err == nil { - t.Fatalf("should have failed.") - } } func TestUnitValidateOCSP(t *testing.T) { @@ -338,13 +326,14 @@ func TestOCSPRetry(t *testing.T) { success: true, body: []byte{1, 2, 3}, logger: hclog.New(hclog.DefaultOptions), + t: t, } - res, b, st := c.retryOCSP( + res, b, st, err := c.retryOCSP( context.TODO(), client, fakeRequestFunc, dummyOCSPHost, - make(map[string]string), []byte{0}, certs[len(certs)-1], 10*time.Second) - if st.err == nil { + make(map[string]string), []byte{0}, certs[len(certs)-1]) + if err == nil { fmt.Printf("should fail: %v, %v, %v\n", res, b, st) } client = &fakeHTTPClient{ @@ -352,13 +341,14 @@ func TestOCSPRetry(t *testing.T) { success: true, body: []byte{1, 2, 3}, logger: hclog.New(hclog.DefaultOptions), + t: t, } - res, b, st = c.retryOCSP( + res, b, st, err = c.retryOCSP( context.TODO(), client, fakeRequestFunc, dummyOCSPHost, - make(map[string]string), []byte{0}, certs[len(certs)-1], 5*time.Second) - if st.err == nil { + make(map[string]string), []byte{0}, certs[len(certs)-1]) + if err == nil { fmt.Printf("should fail: %v, %v, %v\n", res, b, st) } } @@ -374,10 +364,11 @@ func TestOCSPCacheServerRetry(t *testing.T) { success: true, body: []byte{1, 2, 3}, logger: hclog.New(hclog.DefaultOptions), + t: t, } - res, st := c.checkOCSPCacheServer( + res, _, err := c.checkOCSPCacheServer( context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second) - if st.err == nil { + if err == nil { t.Errorf("should fail: %v", res) } client = &fakeHTTPClient{ @@ -385,10 +376,11 @@ func TestOCSPCacheServerRetry(t *testing.T) { success: true, body: []byte{1, 2, 3}, logger: hclog.New(hclog.DefaultOptions), + t: t, } - res, st = c.checkOCSPCacheServer( + res, _, err = c.checkOCSPCacheServer( context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second) - if st.err == nil { + if err == nil { t.Errorf("should fail: %v", res) } } @@ -482,7 +474,7 @@ func TestCanEarlyExitForOCSP(t *testing.T) { } c := New(testLogFactory) for idx, tt := range testcases { - ocspFailOpen = OCSPFailOpenTrue + c.ocspFailOpen = OCSPFailOpenTrue expectedLen := len(tt.results) if tt.resultLen > 0 { expectedLen = tt.resultLen @@ -491,7 +483,7 @@ func TestCanEarlyExitForOCSP(t *testing.T) { if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) { t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r) } - ocspFailOpen = OCSPFailOpenFalse + c.ocspFailOpen = OCSPFailOpenFalse r = c.canEarlyExitForOCSP(tt.results, expectedLen) if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) { t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r) @@ -504,3 +496,73 @@ var testLogger = hclog.New(hclog.DefaultOptions) func testLogFactory() hclog.Logger { return testLogger } + +type fakeHTTPClient struct { + cnt int // number of retry + success bool // return success after retry in cnt times + timeout bool // timeout + body []byte // return body + t *testing.T + logger hclog.Logger +} + +func (c *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + c.cnt-- + if c.cnt < 0 { + c.cnt = 0 + } + c.t.Log("fakeHTTPClient.cnt", c.cnt) + + var retcode int + if c.success && c.cnt == 0 { + retcode = 200 + } else { + if c.timeout { + // simulate timeout + time.Sleep(time.Second * 1) + return nil, &fakeHTTPError{ + err: "Whatever reason (Client.Timeout exceeded while awaiting headers)", + timeout: true, + } + } + retcode = 0 + } + + ret := &http.Response{ + StatusCode: retcode, + Body: &fakeResponseBody{body: c.body}, + } + return ret, nil +} + +type fakeHTTPError struct { + err string + timeout bool +} + +func (e *fakeHTTPError) Error() string { return e.err } +func (e *fakeHTTPError) Timeout() bool { return e.timeout } +func (e *fakeHTTPError) Temporary() bool { return true } + +type fakeResponseBody struct { + body []byte + cnt int +} + +func (b *fakeResponseBody) Read(p []byte) (n int, err error) { + if b.cnt == 0 { + copy(p, b.body) + b.cnt = 1 + return len(b.body), nil + } + b.cnt = 0 + return 0, io.EOF +} + +func (b *fakeResponseBody) Close() error { + return nil +} + +func fakeRequestFunc(_, _ string, _ io.Reader) (*http.Request, error) { + return nil, nil +} diff --git a/sdk/helper/ocsp/retry.go b/sdk/helper/ocsp/retry.go deleted file mode 100644 index 4060abb87a18..000000000000 --- a/sdk/helper/ocsp/retry.go +++ /dev/null @@ -1,550 +0,0 @@ -// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. - -package ocsp - -import ( - "bytes" - "context" - "crypto/x509" - "fmt" - "github.com/hashicorp/go-hclog" - "github.com/hashicorp/go-uuid" - "io" - "math/rand" - "net/http" - "net/url" - "runtime" - "strconv" - "strings" - "sync" - "time" -) - -const ( - httpHeaderContentType = "Content-Type" - httpHeaderAccept = "accept" - httpHeaderUserAgent = "User-Agent" - httpHeaderServiceName = "X-Snowflake-Service" - httpHeaderContentLength = "Content-Length" - httpHeaderHost = "Host" - httpHeaderValueOctetStream = "application/octet-stream" - httpHeaderContentEncoding = "Content-Encoding" -) - -var random *rand.Rand - -func init() { - random = rand.New(rand.NewSource(time.Now().UnixNano())) -} - -const ( - // requestGUIDKey is attached to every request against Snowflake - requestGUIDKey string = "request_guid" - // retryCounterKey is attached to query-request from the second time - retryCounterKey string = "retryCounter" - // requestIDKey is attached to all requests to Snowflake - requestIDKey string = "requestId" -) - -// This class takes in an url during construction and replaces the value of -// request_guid every time replace() is called. If the url does not contain -// request_guid, just return the original url -type requestGUIDReplacer interface { - // replace the url with new ID - replace() *url.URL -} - -// Make requestGUIDReplacer given a url string -func newRequestGUIDReplace(urlPtr *url.URL) requestGUIDReplacer { - values, err := url.ParseQuery(urlPtr.RawQuery) - if err != nil { - // nop if invalid query parameters - return &transientReplace{urlPtr} - } - if len(values.Get(requestGUIDKey)) == 0 { - // nop if no request_guid is included. - return &transientReplace{urlPtr} - } - - return &requestGUIDReplace{urlPtr, values} -} - -// this replacer does nothing but replace the url -type transientReplace struct { - urlPtr *url.URL -} - -func (replacer *transientReplace) replace() *url.URL { - return replacer.urlPtr -} - -/* -requestGUIDReplacer is a one-shot object that is created out of the retry loop and -called with replace to change the retry_guid's value upon every retry -*/ -type requestGUIDReplace struct { - urlPtr *url.URL - urlValues url.Values -} - -/** -This function would replace they value of the requestGUIDKey in a url with a newly -generated UUID -*/ -func (replacer *requestGUIDReplace) replace() *url.URL { - replacer.urlValues.Del(requestGUIDKey) - uuid, _ := uuid.GenerateUUID() - replacer.urlValues.Add(requestGUIDKey, uuid) - replacer.urlPtr.RawQuery = replacer.urlValues.Encode() - return replacer.urlPtr -} - -type retryCounterUpdater interface { - replaceOrAdd(retry int) *url.URL -} - -type retryCounterUpdate struct { - urlPtr *url.URL - urlValues url.Values -} - -// this replacer does nothing but replace the url -type transientReplaceOrAdd struct { - urlPtr *url.URL -} - -func (replaceOrAdder *transientReplaceOrAdd) replaceOrAdd(retry int) *url.URL { - return replaceOrAdder.urlPtr -} - -func (replacer *retryCounterUpdate) replaceOrAdd(retry int) *url.URL { - replacer.urlValues.Del(retryCounterKey) - replacer.urlValues.Add(retryCounterKey, strconv.Itoa(retry)) - replacer.urlPtr.RawQuery = replacer.urlValues.Encode() - return replacer.urlPtr -} - -// Snowflake Server Endpoints -const ( - loginRequestPath = "/session/v1/login-request" - queryRequestPath = "/queries/v1/query-request" - tokenRequestPath = "/session/token-request" - abortRequestPath = "/queries/v1/abort-request" - authenticatorRequestPath = "/session/authenticator-request" - sessionRequestPath = "/session" - heartBeatPath = "/session/heartbeat" -) - -func newRetryUpdate(urlPtr *url.URL) retryCounterUpdater { - if !strings.HasPrefix(urlPtr.Path, queryRequestPath) { - // nop if not query-request - return &transientReplaceOrAdd{urlPtr} - } - values, err := url.ParseQuery(urlPtr.RawQuery) - if err != nil { - // nop if the URL is not valid - return &transientReplaceOrAdd{urlPtr} - } - return &retryCounterUpdate{urlPtr, values} -} - -type waitAlgo struct { - mutex *sync.Mutex // required for random.Int63n - base time.Duration // base wait time - cap time.Duration // maximum wait time -} - -func randSecondDuration(n time.Duration) time.Duration { - return time.Duration(random.Int63n(int64(n/time.Second))) * time.Second -} - -// decorrelated jitter backoff -func (w *waitAlgo) decorr(attempt int, sleep time.Duration) time.Duration { - w.mutex.Lock() - defer w.mutex.Unlock() - t := 3*sleep - w.base - switch { - case t > 0: - return durationMin(w.cap, randSecondDuration(t)+w.base) - case t < 0: - return durationMin(w.cap, randSecondDuration(-t)+3*sleep) - } - return w.base -} - -var defaultWaitAlgo = &waitAlgo{ - mutex: &sync.Mutex{}, - base: 5 * time.Second, - cap: 160 * time.Second, -} - -type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) - -type clientInterface interface { - Do(req *http.Request) (*http.Response, error) -} - -type retryHTTP struct { - ctx context.Context - client clientInterface - req requestFunc - method string - fullURL *url.URL - headers map[string]string - body []byte - timeout time.Duration - raise4XX bool - logger hclog.Logger -} - -func newRetryHTTP(ctx context.Context, - client clientInterface, - req requestFunc, - fullURL *url.URL, - headers map[string]string, - timeout time.Duration) *retryHTTP { - instance := retryHTTP{} - instance.ctx = ctx - instance.client = client - instance.req = req - instance.method = "GET" - instance.fullURL = fullURL - instance.headers = headers - instance.body = nil - instance.timeout = timeout - instance.raise4XX = false - instance.logger = hclog.New(hclog.DefaultOptions) - return &instance -} - -func (r *retryHTTP) doRaise4XX(raise4XX bool) *retryHTTP { - r.raise4XX = raise4XX - return r -} - -func (r *retryHTTP) doPost() *retryHTTP { - r.method = "POST" - return r -} - -func (r *retryHTTP) setBody(body []byte) *retryHTTP { - r.body = body - return r -} - -func (r *retryHTTP) execute() (res *http.Response, err error) { - totalTimeout := r.timeout - r.logger.Info("retryHTTP", "totalTimeout", totalTimeout) - retryCounter := 0 - sleepTime := time.Duration(0) - - var rIDReplacer requestGUIDReplacer - var rUpdater retryCounterUpdater - - for { - r.logger.Debug("retry count", "retryCounter", retryCounter) - req, err := r.req(r.method, r.fullURL.String(), bytes.NewReader(r.body)) - if err != nil { - return nil, err - } - if req != nil { - // req can be nil in tests - req = req.WithContext(r.ctx) - } - for k, v := range r.headers { - req.Header.Set(k, v) - } - res, err = r.client.Do(req) - if err != nil { - // check if it can retry. - doExit, err := r.isRetryableError(err) - if doExit { - return res, err - } - // cannot just return 4xx and 5xx status as the error can be sporadic. run often helps. - r.logger.Warn( - "failed http connection. no response is returned. retrying...", "err", err) - } else { - if res.StatusCode == http.StatusOK || r.raise4XX && res != nil && res.StatusCode >= 400 && res.StatusCode < 500 { - // exit if success - // or - // abort connection if raise4XX flag is enabled and the range of HTTP status code are 4XX. - // This is currently used for Snowflake login. The caller must generate an error object based on HTTP status. - break - } - r.logger.Warn( - "failed http connection. retrying...\n", "statusCode", res.StatusCode) - res.Body.Close() - } - // uses decorrelated jitter backoff - sleepTime = defaultWaitAlgo.decorr(retryCounter, sleepTime) - - if totalTimeout > 0 { - r.logger.Info("to timeout: ", "totalTimeout", totalTimeout) - // if any timeout is set - totalTimeout -= sleepTime - if totalTimeout <= 0 { - if err != nil { - return nil, err - } - if res != nil { - return nil, fmt.Errorf("timeout after %s. HTTP Status: %v. Hanging?", r.timeout, res.StatusCode) - } - return nil, fmt.Errorf("timeout after %s. Hanging?", r.timeout) - } - } - retryCounter++ - if rIDReplacer == nil { - rIDReplacer = newRequestGUIDReplace(r.fullURL) - } - r.fullURL = rIDReplacer.replace() - if rUpdater == nil { - rUpdater = newRetryUpdate(r.fullURL) - } - r.fullURL = rUpdater.replaceOrAdd(retryCounter) - r.logger.Info("sleeping to retry", "sleepTime", sleepTime, "totalTimeout", totalTimeout) - - await := time.NewTimer(sleepTime) - select { - case <-await.C: - // retry the request - case <-r.ctx.Done(): - await.Stop() - return res, r.ctx.Err() - } - } - return res, err -} - -func (r *retryHTTP) isRetryableError(err error) (bool, error) { - urlError, isURLError := err.(*url.Error) - if isURLError { - // context cancel or timeout - if urlError.Err == context.DeadlineExceeded || urlError.Err == context.Canceled { - return true, urlError.Err - } - if urlError.Err.Error() == "OCSP status revoked" { - // Certificate Revoked - return true, nil - } - if _, ok := urlError.Err.(x509.CertificateInvalidError); ok { - // Certificate is invalid - return true, err - } - if _, ok := urlError.Err.(x509.UnknownAuthorityError); ok { - // Certificate is self-signed - return true, err - } - errString := urlError.Err.Error() - if runtime.GOOS == "darwin" && strings.HasPrefix(errString, "x509:") && strings.HasSuffix(errString, "certificate is expired") { - // Certificate is expired - return true, err - } - - } - return false, err -} - -/* - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ diff --git a/sdk/helper/ocsp/retry_test.go b/sdk/helper/ocsp/retry_test.go deleted file mode 100644 index 8a68f13217b9..000000000000 --- a/sdk/helper/ocsp/retry_test.go +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. - -package ocsp - -import ( - "context" - "github.com/hashicorp/go-hclog" - "io" - "net/http" - "net/url" - "strconv" - "strings" - "testing" - "time" -) - -func fakeRequestFunc(_, _ string, _ io.Reader) (*http.Request, error) { - return nil, nil -} - -type fakeHTTPError struct { - err string - timeout bool -} - -func (e *fakeHTTPError) Error() string { return e.err } -func (e *fakeHTTPError) Timeout() bool { return e.timeout } -func (e *fakeHTTPError) Temporary() bool { return true } - -type fakeResponseBody struct { - body []byte - cnt int -} - -func (b *fakeResponseBody) Read(p []byte) (n int, err error) { - if b.cnt == 0 { - copy(p, b.body) - b.cnt = 1 - return len(b.body), nil - } - b.cnt = 0 - return 0, io.EOF -} - -func (b *fakeResponseBody) Close() error { - return nil -} - -type fakeHTTPClient struct { - cnt int // number of retry - success bool // return success after retry in cnt times - timeout bool // timeout - body []byte // return body - logger hclog.Logger -} - -func (c *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { - c.cnt-- - if c.cnt < 0 { - c.cnt = 0 - } - c.logger.Info("fakeHTTPClient", "cnt", c.cnt) - - var retcode int - if c.success && c.cnt == 0 { - retcode = 200 - } else { - if c.timeout { - // simulate timeout - time.Sleep(time.Second * 1) - return nil, &fakeHTTPError{ - err: "Whatever reason (Client.Timeout exceeded while awaiting headers)", - timeout: true, - } - } - retcode = 0 - } - - ret := &http.Response{ - StatusCode: retcode, - Body: &fakeResponseBody{body: c.body}, - } - return ret, nil -} - -func TestRequestGUID(t *testing.T) { - var ridReplacer requestGUIDReplacer - var testURL *url.URL - var actualURL *url.URL - retryTime := 4 - - // empty url - testURL = &url.URL{} - ridReplacer = newRequestGUIDReplace(testURL) - for i := 0; i < retryTime; i++ { - actualURL = ridReplacer.replace() - if actualURL.String() != "" { - t.Fatalf("empty url not replaced by an empty one, got %s", actualURL) - } - } - - // url with on retry id - testURL = &url.URL{ - Path: "/" + requestIDKey + "=123-1923-9?param2=value", - } - ridReplacer = newRequestGUIDReplace(testURL) - for i := 0; i < retryTime; i++ { - actualURL = ridReplacer.replace() - - if actualURL != testURL { - t.Fatalf("url without retry id not replaced by origin one, got %s", actualURL) - } - } - - // url with retry id - // With both prefix and suffix - prefix := "/" + requestIDKey + "=123-1923-9?" + requestGUIDKey + "=" - suffix := "?param2=value" - testURL = &url.URL{ - Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix, - } - ridReplacer = newRequestGUIDReplace(testURL) - for i := 0; i < retryTime; i++ { - actualURL = ridReplacer.replace() - if (!strings.HasPrefix(actualURL.Path, prefix)) || - (!strings.HasSuffix(actualURL.Path, suffix)) || - len(testURL.Path) != len(actualURL.Path) { - t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL) - } - } - - // With no suffix - prefix = "/" + requestIDKey + "=123-1923-9?" + requestGUIDKey + "=" - suffix = "" - testURL = &url.URL{ - Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix, - } - ridReplacer = newRequestGUIDReplace(testURL) - for i := 0; i < retryTime; i++ { - actualURL = ridReplacer.replace() - if (!strings.HasPrefix(actualURL.Path, prefix)) || - (!strings.HasSuffix(actualURL.Path, suffix)) || - len(testURL.Path) != len(actualURL.Path) { - t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL) - } - - } - // With no prefix - prefix = requestGUIDKey + "=" - suffix = "?param2=value" - testURL = &url.URL{ - Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix, - } - ridReplacer = newRequestGUIDReplace(testURL) - for i := 0; i < retryTime; i++ { - actualURL = ridReplacer.replace() - if (!strings.HasPrefix(actualURL.Path, prefix)) || - (!strings.HasSuffix(actualURL.Path, suffix)) || - len(testURL.Path) != len(actualURL.Path) { - t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL) - } - } -} - -func TestRetryQuerySuccess(t *testing.T) { - c := New(testLogFactory) - c.Logger().Info("Retry N times and Success") - client := &fakeHTTPClient{ - cnt: 3, - success: true, - } - urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid&clientStartTime=123456") - if err != nil { - t.Fatal("failed to parse the test URL") - } - _, err = newRetryHTTP(context.TODO(), - client, - fakeRequestFunc, urlPtr, make(map[string]string), 60*time.Second).doPost().setBody([]byte{0}).execute() - if err != nil { - t.Fatal("failed to run retry") - } - var values url.Values - values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatal("failed to fail to parse the URL") - } - retry, err := strconv.Atoi(values.Get(retryCounterKey)) - if err != nil { - t.Fatalf("failed to get retry counter: %v", err) - } - if retry < 2 { - t.Fatalf("not enough retry counter: %v", retry) - } -} -func TestRetryQueryFail(t *testing.T) { - c := New(testLogFactory) - c.Logger().Info("Retry N times and Fail") - client := &fakeHTTPClient{ - cnt: 4, - success: false, - } - urlPtr, err := url.Parse("https://fakeaccountretryfail.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid&clientStartTime=123456") - if err != nil { - t.Fatal("failed to parse the test URL") - } - _, err = newRetryHTTP(context.TODO(), - client, - fakeRequestFunc, urlPtr, make(map[string]string), 60*time.Second).doPost().setBody([]byte{0}).execute() - if err == nil { - t.Fatal("should fail to run retry") - } - var values url.Values - values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatalf("failed to fail to parse the URL: %v", err) - } - retry, err := strconv.Atoi(values.Get(retryCounterKey)) - if err != nil { - t.Fatalf("failed to get retry counter: %v", err) - } - if retry < 2 { - t.Fatalf("not enough retry counter: %v", retry) - } -} -func TestRetryLoginRequest(t *testing.T) { - client := &fakeHTTPClient{ - cnt: 3, - success: true, - timeout: true, - logger: hclog.New(hclog.DefaultOptions), - } - client.logger.Info("Retry N times for timeouts and Success") - urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid") - if err != nil { - t.Fatal("failed to parse the test URL") - } - _, err = newRetryHTTP(context.TODO(), - client, - fakeRequestFunc, urlPtr, make(map[string]string), 60*time.Second).doPost().setBody([]byte{0}).execute() - if err != nil { - t.Fatal("failed to run retry") - } - var values url.Values - values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatalf("failed to fail to parse the URL: %v", err) - } - if values.Get(retryCounterKey) != "" { - t.Fatalf("no retry counter should be attached: %v", retryCounterKey) - } - client.logger.Info("Retry N times for timeouts and Fail") - client = &fakeHTTPClient{ - cnt: 10, - success: false, - timeout: true, - logger: hclog.New(hclog.DefaultOptions), - } - _, err = newRetryHTTP(context.TODO(), - client, - fakeRequestFunc, urlPtr, make(map[string]string), 10*time.Second).doPost().setBody([]byte{0}).execute() - if err == nil { - t.Fatal("should fail to run retry") - } - values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatalf("failed to fail to parse the URL: %v", err) - } - if values.Get(retryCounterKey) != "" { - t.Fatalf("no retry counter should be attached: %v", retryCounterKey) - } -} From b0798a7255260801f4f9272dd17d8eeb21b59567 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 9 Sep 2022 15:01:11 -0500 Subject: [PATCH 03/39] ->pointer --- sdk/helper/ocsp/client.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index b696caae959b..e5f5614f8605 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -108,7 +108,7 @@ type Client struct { caRoot map[string]*x509.Certificate // certPOol includes the CA certificates. certPool *x509.CertPool - ocspResponseCache map[certIDKey]ocspCachedResponse + ocspResponseCache map[certIDKey]*ocspCachedResponse ocspResponseCacheLock sync.RWMutex // cacheUpdated is true if the memory cache is updated cacheUpdated bool @@ -302,7 +302,7 @@ func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issue gotValueFromCache := c.ocspResponseCache[*encodedCertID] c.ocspResponseCacheLock.RUnlock() - status, err := c.extractOCSPCacheResponseValue(&gotValueFromCache, subject, issuer) + status, err := c.extractOCSPCacheResponseValue(gotValueFromCache, subject, issuer) if err != nil { return nil, err } @@ -519,7 +519,7 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. } v := ocspCachedResponse{time: float64(time.Now().UTC().Unix()), resp: base64.StdEncoding.EncodeToString(ocspResBytes)} c.ocspResponseCacheLock.Lock() - c.ocspResponseCache[*encodedCertID] = v + c.ocspResponseCache[*encodedCertID] = &v c.cacheUpdated = true c.ocspResponseCacheLock.Unlock() return ret, nil @@ -744,7 +744,16 @@ func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) err if err != nil { return err } - c.ocspResponseCache[*key] = ocspCachedResponse{time: v[0].(float64), resp: v[1].(string)} + var time float64 + if jn, ok := v[0].(json.Number); ok { + time, err = jn.Float64() + if err != nil { + return err + } + } else { + time = v[0].(float64) + } + c.ocspResponseCache[*key] = &ocspCachedResponse{time: time, resp: v[1].(string)} } return nil } @@ -752,7 +761,7 @@ func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) err func New(logFactory func() hclog.Logger) *Client { c := Client{ caRoot: make(map[string]*x509.Certificate), - ocspResponseCache: make(map[certIDKey]ocspCachedResponse), + ocspResponseCache: make(map[certIDKey]*ocspCachedResponse), logFactory: logFactory, } @@ -809,9 +818,7 @@ func (c *Client) WriteCache(ctx context.Context, storage logical.Storage) error c.ocspResponseCacheLock.Lock() defer c.ocspResponseCacheLock.Unlock() if c.cacheUpdated { - if c.cacheUpdated { - return c.writeOCSPCache(ctx, storage) - } + return c.writeOCSPCache(ctx, storage) c.cacheUpdated = false } return nil From 87bd5a382ffeaefe1eb1da4ad6531a831cb617eb Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 9 Sep 2022 15:08:20 -0500 Subject: [PATCH 04/39] Code cleanup --- sdk/helper/ocsp/client.go | 115 ++++++----------------------------- sdk/helper/ocsp/ocsp_test.go | 82 ++++++------------------- 2 files changed, 39 insertions(+), 158 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index e5f5614f8605..4daf9739d2f7 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -24,9 +24,7 @@ import ( "net" "net/http" "net/url" - "os" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -66,31 +64,13 @@ const ( const ocspCacheKey = "ocsp_cache" const ( - // defaultOCSPCacheServerTimeout is the total timeout for OCSP cache server. - defaultOCSPCacheServerTimeout = 5 * time.Second - // defaultOCSPResponderTimeout is the total timeout for OCSP responder. defaultOCSPResponderTimeout = 10 * time.Second ) const ( - cacheFileBaseName = "ocsp_response_cache.json" // cacheExpire specifies cache data expiration time in seconds. - cacheExpire = float64(24 * 60 * 60) - cacheServerURL = "http://ocsp.snowflakecomputing.com" - cacheServerEnabledEnv = "SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED" - cacheServerURLEnv = "SF_OCSP_RESPONSE_CACHE_SERVER_URL" - cacheDirEnv = "SF_OCSP_RESPONSE_CACHE_DIR" - ocspRetryURLEnv = "SF_OCSP_RESPONSE_RETRY_URL" -) - -const ( - ocspTestInjectValidityErrorEnv = "SF_OCSP_TEST_INJECT_VALIDITY_ERROR" - ocspTestInjectUnknownStatusEnv = "SF_OCSP_TEST_INJECT_UNKNOWN_STATUS" - ocspTestResponseCacheServerTimeoutEnv = "SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT" - ocspTestResponderTimeoutEnv = "SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT" - ocspTestResponderURLEnv = "SF_OCSP_TEST_RESPONDER_URL" - ocspTestNoOCSPURLEnv = "SF_OCSP_TEST_NO_OCSP_RESPONDER_URL" + cacheExpire = float64(24 * 60 * 60) ) const ( @@ -131,18 +111,10 @@ const ( ocspStatusRevoked ocspStatusCode = -2 ocspStatusUnknown ocspStatusCode = -3 ocspStatusOthers ocspStatusCode = -4 - ocspNoServer ocspStatusCode = -5 - ocspFailedParseOCSPHost ocspStatusCode = -6 - ocspFailedComposeRequest ocspStatusCode = -7 - ocspFailedDecomposeRequest ocspStatusCode = -8 - ocspFailedSubmit ocspStatusCode = -9 - ocspFailedResponse ocspStatusCode = -10 - ocspFailedExtractResponse ocspStatusCode = -11 - ocspFailedParseResponse ocspStatusCode = -12 - ocspInvalidValidity ocspStatusCode = -13 - ocspMissedCache ocspStatusCode = -14 - ocspCacheExpired ocspStatusCode = -15 - ocspFailedDecodeResponse ocspStatusCode = -16 + ocspFailedDecomposeRequest ocspStatusCode = -5 + ocspInvalidValidity ocspStatusCode = -6 + ocspMissedCache ocspStatusCode = -7 + ocspCacheExpired ocspStatusCode = -8 ) // copied from crypto/ocsp.go @@ -216,10 +188,6 @@ func isInValidityRange(currTime, thisUpdate, nextUpdate time.Time) bool { return true } -func isTestInvalidValidity() bool { - return strings.EqualFold(os.Getenv(ocspTestInjectValidityErrorEnv), "true") -} - func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { r, err := ocsp.ParseRequest(ocspReq) if err != nil { @@ -295,9 +263,6 @@ func decodeCertIDKey(k *certIDKey) (string, error) { } func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Certificate) (*ocspStatus, error) { - if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { - return &ocspStatus{code: ocspNoServer}, nil - } c.ocspResponseCacheLock.RLock() gotValueFromCache := c.ocspResponseCache[*encodedCertID] c.ocspResponseCacheLock.RUnlock() @@ -325,11 +290,11 @@ func validateOCSP(ocspRes *ocsp.Response) (*ocspStatus, error) { if ocspRes == nil { return nil, errors.New("OCSP Response is nil") } - if isTestInvalidValidity() || !isInValidityRange(curTime, ocspRes.ThisUpdate, ocspRes.NextUpdate) { - return nil, fmt.Errorf("invalid validity: producedAt: %v, thisUpdate: %v, nextUpdate: %v", ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate) - } - if isTestUnknownStatus() { - ocspRes.Status = ocsp.Unknown + if !isInValidityRange(curTime, ocspRes.ThisUpdate, ocspRes.NextUpdate) { + return &ocspStatus{ + code: ocspInvalidValidity, + err: fmt.Errorf("invalid validity: producedAt: %v, thisUpdate: %v, nextUpdate: %v", ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate), + }, nil } return returnOCSPStatus(ocspRes), nil } @@ -358,49 +323,6 @@ func returnOCSPStatus(ocspRes *ocsp.Response) *ocspStatus { } } -func isTestUnknownStatus() bool { - return strings.EqualFold(os.Getenv(ocspTestInjectUnknownStatusEnv), "true") -} - -func (c *Client) checkOCSPCacheServer( - ctx context.Context, - client clientInterface, - req requestFunc, - ocspServerHost *url.URL, - totalTimeout time.Duration) ( - cacheContent *map[string][]interface{}, - ocspS *ocspStatus, err error) { - var respd map[string][]interface{} - - request, err := req("GET", ocspServerHost.Hostname(), nil) - if err != nil { - return nil, nil, err - } - res, err := client.Do(request) - if err != nil { - return nil, nil, err - } // newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout).execute() - - defer res.Body.Close() - c.Logger().Debug("StatusCode from OCSP Cache Server", "statusCode", res.StatusCode) - if res.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status) - } - c.Logger().Debug("reading contents") - - dec := json.NewDecoder(res.Body) - for { - if err := dec.Decode(&respd); err == io.EOF { - break - } else if err != nil { - return nil, nil, err - } - } - return &respd, &ocspStatus{ - code: ocspSuccess, - }, nil -} - // retryOCSP is the second level of retry method if the returned contents are corrupted. It often happens with OCSP // serer and retry helps. func (c *Client) retryOCSP( @@ -695,8 +617,6 @@ func (c *Client) extractOCSPCacheResponseValue(cacheValue *ocspCachedResponse, s return validateOCSP(r) } -const storageValueKey = "backingStore" - // writeOCSPCache writes a OCSP Response cache func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) error { c.Logger().Debug("writing OCSP Response cache file") @@ -744,16 +664,16 @@ func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) err if err != nil { return err } - var time float64 + var ts float64 if jn, ok := v[0].(json.Number); ok { - time, err = jn.Float64() + ts, err = jn.Float64() if err != nil { return err } } else { - time = v[0].(float64) + ts = v[0].(float64) } - c.ocspResponseCache[*key] = &ocspCachedResponse{time: time, resp: v[1].(string)} + c.ocspResponseCache[*key] = &ocspCachedResponse{time: ts, resp: v[1].(string)} } return nil } @@ -818,8 +738,11 @@ func (c *Client) WriteCache(ctx context.Context, storage logical.Storage) error c.ocspResponseCacheLock.Lock() defer c.ocspResponseCacheLock.Unlock() if c.cacheUpdated { - return c.writeOCSPCache(ctx, storage) - c.cacheUpdated = false + err := c.writeOCSPCache(ctx, storage) + if err == nil { + c.cacheUpdated = false + } + return err } return nil } diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 7b00092e3df1..eb8d7246d20e 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -17,7 +17,6 @@ import ( "net" "net/http" "net/url" - "os" "testing" "time" @@ -25,10 +24,6 @@ import ( ) func TestOCSP(t *testing.T) { - cacheServerEnabled := []string{ - "true", - "false", - } targetURL := []string{ "https://sfcdev1.blob.core.windows.net/", "https://sfctest0.snowflakecomputing.com/", @@ -41,34 +36,29 @@ func TestOCSP(t *testing.T) { c.NewTransport(nil, nil), } - for _, enabled := range cacheServerEnabled { - for _, tgt := range targetURL { - _ = os.Setenv(cacheServerEnabledEnv, enabled) - //_ = os.Remove(cacheFileName) // clear cache file - c.ocspResponseCache = make(map[certIDKey]*ocspCachedResponse) - for _, tr := range transports { - c := &http.Client{ - Transport: tr, - Timeout: 30 * time.Second, - } - req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) - if err != nil { - t.Fatalf("fail to create a request. err: %v", err) - } - res, err := c.Do(req) - if err != nil { - t.Fatalf("failed to GET contents. err: %v", err) - } - defer res.Body.Close() - _, err = ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("failed to read content body for %v", tgt) - } - + for _, tgt := range targetURL { + c.ocspResponseCache = make(map[certIDKey]*ocspCachedResponse) + for _, tr := range transports { + c := &http.Client{ + Transport: tr, + Timeout: 30 * time.Second, } + req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("fail to create a request. err: %v", err) + } + res, err := c.Do(req) + if err != nil { + t.Fatalf("failed to GET contents. err: %v", err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("failed to read content body for %v", tgt) + } + } } - _ = os.Unsetenv(cacheServerEnabledEnv) } type tcValidityRange struct { @@ -353,38 +343,6 @@ func TestOCSPRetry(t *testing.T) { } } -func TestOCSPCacheServerRetry(t *testing.T) { - c := New(testLogFactory) - dummyOCSPHost := &url.URL{ - Scheme: "https", - Host: "dummyOCSPHost", - } - client := &fakeHTTPClient{ - cnt: 3, - success: true, - body: []byte{1, 2, 3}, - logger: hclog.New(hclog.DefaultOptions), - t: t, - } - res, _, err := c.checkOCSPCacheServer( - context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second) - if err == nil { - t.Errorf("should fail: %v", res) - } - client = &fakeHTTPClient{ - cnt: 30, - success: true, - body: []byte{1, 2, 3}, - logger: hclog.New(hclog.DefaultOptions), - t: t, - } - res, _, err = c.checkOCSPCacheServer( - context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second) - if err == nil { - t.Errorf("should fail: %v", res) - } -} - type tcCanEarlyExit struct { results []*ocspStatus resultLen int From 0b9f7855dbf95676d36d2e8b75f6c32a3bcd0aeb Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 9 Sep 2022 15:12:11 -0500 Subject: [PATCH 05/39] Fix unit tests --- sdk/helper/ocsp/ocsp_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index eb8d7246d20e..f8ca74a43956 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -219,7 +219,7 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { // invalid validity c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(currentTime - 1000), actualOcspResponse} ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil) - if err == nil { + if err == nil && isValidOCSPStatus(ost.code) { t.Fatalf("should have failed.") } } @@ -227,7 +227,7 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { func TestUnitValidateOCSP(t *testing.T) { ocspRes := &ocsp.Response{} ost, err := validateOCSP(ocspRes) - if err == nil { + if err == nil && isValidOCSPStatus(ost.code) { t.Fatalf("should have failed.") } From fb65fd5cebf0120f0bca8bc80c8218e6426bda7c Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Mon, 12 Sep 2022 12:12:44 -0500 Subject: [PATCH 06/39] Use an LRU cache, and only persist up to 1000 of the most recently used values to stay under the storage entry limit --- sdk/helper/ocsp/client.go | 113 ++++++++++++++++++++++------------- sdk/helper/ocsp/ocsp_test.go | 32 ++-------- 2 files changed, 77 insertions(+), 68 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 4daf9739d2f7..a73316c2b3f7 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "github.com/hashicorp/go-hclog" + lru "github.com/hashicorp/golang-lru" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" "golang.org/x/crypto/ocsp" @@ -70,7 +71,9 @@ const ( const ( // cacheExpire specifies cache data expiration time in seconds. - cacheExpire = float64(24 * 60 * 60) + cacheExpire = float64(24 * 60 * 60) + cacheSize = 10000 + persistedCacheSize = 1000 ) const ( @@ -79,8 +82,11 @@ const ( ) type ocspCachedResponse struct { - time float64 - resp string + time float64 + producedAt float64 + thisUpdate float64 + nextUpdate float64 + status ocspStatusCode } type Client struct { @@ -88,7 +94,7 @@ type Client struct { caRoot map[string]*x509.Certificate // certPOol includes the CA certificates. certPool *x509.CertPool - ocspResponseCache map[certIDKey]*ocspCachedResponse + ocspResponseCache *lru.TwoQueueCache ocspResponseCacheLock sync.RWMutex // cacheUpdated is true if the memory cache is updated cacheUpdated bool @@ -264,10 +270,14 @@ func decodeCertIDKey(k *certIDKey) (string, error) { func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Certificate) (*ocspStatus, error) { c.ocspResponseCacheLock.RLock() - gotValueFromCache := c.ocspResponseCache[*encodedCertID] + var cacheValue *ocspCachedResponse + v, ok := c.ocspResponseCache.Get(*encodedCertID) + if ok { + cacheValue = v.(*ocspCachedResponse) + } c.ocspResponseCacheLock.RUnlock() - status, err := c.extractOCSPCacheResponseValue(gotValueFromCache, subject, issuer) + status, err := c.extractOCSPCacheResponseValue(cacheValue, subject, issuer) if err != nil { return nil, err } @@ -279,7 +289,7 @@ func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issue func (c *Client) deleteOCSPCache(encodedCertID *certIDKey) { c.ocspResponseCacheLock.Lock() - delete(c.ocspResponseCache, *encodedCertID) + c.ocspResponseCache.Remove(*encodedCertID) c.cacheUpdated = true c.ocspResponseCacheLock.Unlock() } @@ -397,7 +407,7 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. } var ret *ocspStatus - var ocspResBytes []byte + var ocspRes *ocsp.Response for _, ocspHost := range ocspHosts { u, err := url.Parse(ocspHost) if err != nil { @@ -417,9 +427,8 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. Timeout: timeout, Transport: newInsecureOcspTransport(extraCas), } - var ocspRes *ocsp.Response var ocspS *ocspStatus - ocspRes, ocspResBytes, ocspS, err = c.retryOCSP( + ocspRes, _, ocspS, err = c.retryOCSP( ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer) if err != nil { return nil, err @@ -439,9 +448,15 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. if !isValidOCSPStatus(ret.code) { return ret, nil } - v := ocspCachedResponse{time: float64(time.Now().UTC().Unix()), resp: base64.StdEncoding.EncodeToString(ocspResBytes)} + v := ocspCachedResponse{ + time: float64(time.Now().UTC().Unix()), + producedAt: float64(ocspRes.ProducedAt.UTC().Unix()), + thisUpdate: float64(ocspRes.ThisUpdate.UTC().Unix()), + nextUpdate: float64(ocspRes.NextUpdate.UTC().Unix()), + } + c.ocspResponseCacheLock.Lock() - c.ocspResponseCache[*encodedCertID] = &v + c.ocspResponseCache.Add(encodedCertID, &v) c.cacheUpdated = true c.ocspResponseCacheLock.Unlock() return ret, nil @@ -599,35 +614,38 @@ func (c *Client) extractOCSPCacheResponseValue(cacheValue *ocspCachedResponse, s }, nil } - var err error - var r *ocsp.Response - var b []byte - b, err = base64.StdEncoding.DecodeString(cacheValue.resp) - if err != nil { - return nil, fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err) - - } - // check the revocation status here - r, err = ocsp.ParseResponse(b, issuer) - if err != nil { - c.Logger().Warn("the second cache element is not a valid OCSP Response. Ignored.", "subject", subjectName) - return nil, fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err) - } - - return validateOCSP(r) + return validateOCSP(&ocsp.Response{ + ProducedAt: time.Unix(int64(cacheValue.producedAt), 0).UTC(), + ThisUpdate: time.Unix(int64(cacheValue.thisUpdate), 0).UTC(), + NextUpdate: time.Unix(int64(cacheValue.nextUpdate), 0).UTC(), + Status: int(cacheValue.status), + }) } // writeOCSPCache writes a OCSP Response cache func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) error { c.Logger().Debug("writing OCSP Response cache file") + t := time.Now() m := make(map[string][]interface{}) - for k, entry := range c.ocspResponseCache { - cacheKeyInBase64, err := decodeCertIDKey(&k) - if err != nil { - return err + keys := c.ocspResponseCache.Keys() + if len(keys) > persistedCacheSize { + keys = keys[:persistedCacheSize] + } + for _, k := range keys { + e, ok := c.ocspResponseCache.Get(k) + if ok { + entry := e.(*ocspCachedResponse) + // Don't store if expired + if isInValidityRange(t, time.Unix(int64(entry.thisUpdate), 0), time.Unix(int64(entry.nextUpdate), 0)) { + key := k.(certIDKey) + cacheKeyInBase64, err := decodeCertIDKey(&key) + if err != nil { + return err + } + m[cacheKeyInBase64] = []interface{}{entry.status, entry.time, entry.producedAt, entry.thisUpdate, entry.nextUpdate} + } } - m[cacheKeyInBase64] = []interface{}{entry.time, entry.resp} } v, err := jsonutil.EncodeJSONAndCompress(m, nil) @@ -664,24 +682,35 @@ func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) err if err != nil { return err } - var ts float64 - if jn, ok := v[0].(json.Number); ok { - ts, err = jn.Float64() - if err != nil { - return err + var times [4]float64 + for i, t := range v[1:] { + if jn, ok := t.(json.Number); ok { + times[i], err = jn.Float64() + if err != nil { + return err + } + } else { + times[i] = t.(float64) } - } else { - ts = v[0].(float64) } - c.ocspResponseCache[*key] = &ocspCachedResponse{time: ts, resp: v[1].(string)} + + c.ocspResponseCache.Add(*key, &ocspCachedResponse{ + status: ocspStatusCode(v[0].(int)), + time: times[0], + producedAt: times[1], + thisUpdate: times[2], + nextUpdate: times[3], + }) } + return nil } func New(logFactory func() hclog.Logger) *Client { + cache, _ := lru.New2Q(cacheSize) c := Client{ caRoot: make(map[string]*x509.Certificate), - ocspResponseCache: make(map[certIDKey]*ocspCachedResponse), + ocspResponseCache: cache, logFactory: logFactory, } diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index f8ca74a43956..e3007efbfb3e 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -8,10 +8,10 @@ import ( "crypto" "crypto/tls" "crypto/x509" - "encoding/base64" "errors" "fmt" "github.com/hashicorp/go-hclog" + lru "github.com/hashicorp/golang-lru" "io" "io/ioutil" "net" @@ -37,7 +37,7 @@ func TestOCSP(t *testing.T) { } for _, tgt := range targetURL { - c.ocspResponseCache = make(map[certIDKey]*ocspCachedResponse) + c.ocspResponseCache, _ = lru.New2Q(10) for _, tr := range transports { c := &http.Client{ Transport: tr, @@ -175,9 +175,8 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { IssuerKeyHash: "dummy1", SerialNumber: "dummy1", } - b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE")) currentTime := float64(time.Now().UTC().Unix()) - c.ocspResponseCache[dummyKey0] = &ocspCachedResponse{currentTime, b64Key} + c.ocspResponseCache.Add(dummyKey0, &ocspCachedResponse{time: currentTime}) subject := &x509.Certificate{} issuer := &x509.Certificate{} ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer) @@ -188,7 +187,7 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code) } // old timestamp - c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(1395054952), b64Key} + c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(1395054952)}) ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) if err != nil { t.Fatal(err) @@ -196,28 +195,9 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { if ost.code != ocspCacheExpired { t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code) } - // future timestamp - c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(1805054952), b64Key} - ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) - if err == nil { - t.Fatalf("should have failed.") - } - // actual OCSP but it fails to parse, because an invalid issuer certificate is given. - actualOcspResponse := "MIIB0woBAKCCAcwwggHIBgkrBgEFBQcwAQEEggG5MIIBtTCBnqIWBBSxPsNpA/i/RwHUmCYaCALvY2QrwxgPMjAxNz" + - "A1MTYyMjAwMDBaMHMwcTBJMAkGBSsOAwIaBQAEFN+qEuMosQlBk+KfQoLOR0BClVijBBSxPsNpA/i/RwHUmCYaCALvY2QrwwIQBOHnp" + - "Nxc8vNtwCtCuF0Vn4AAGA8yMDE3MDUxNjIyMDAwMFqgERgPMjAxNzA1MjMyMjAwMDBaMA0GCSqGSIb3DQEBCwUAA4IBAQCuRGwqQsKy" + - "IAAGHgezTfG0PzMYgGD/XRDhU+2i08WTJ4Zs40Lu88cBeRXWF3iiJSpiX3/OLgfI7iXmHX9/sm2SmeNWc0Kb39bk5Lw1jwezf8hcI9+" + - "mZHt60vhUgtgZk21SsRlTZ+S4VXwtDqB1Nhv6cnSnfrL2A9qJDZS2ltPNOwebWJnznDAs2dg+KxmT2yBXpHM1kb0EOolWvNgORbgIgB" + - "koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + - "Jb4xI/OTXIrRA5TsT4cCioE" - // issuer is not a true issuer certificate - c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(currentTime - 1000), actualOcspResponse} - ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) - if err == nil { - t.Fatalf("should have failed.") - } + // invalid validity - c.ocspResponseCache[dummyKey] = &ocspCachedResponse{float64(currentTime - 1000), actualOcspResponse} + c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(currentTime - 1000)}) ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil) if err == nil && isValidOCSPStatus(ost.code) { t.Fatalf("should have failed.") From 5d768c013136fe9f006b75e611c6d2882121b25d Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Mon, 12 Sep 2022 12:35:50 -0500 Subject: [PATCH 07/39] Fix caching, add fail open mode parameter to cert auth roles --- builtin/credential/cert/path_certs.go | 15 +++++-- builtin/credential/cert/path_login.go | 22 ++++++---- sdk/helper/ocsp/client.go | 60 +++++++++++++-------------- sdk/helper/ocsp/ocsp_test.go | 8 ++-- 4 files changed, 59 insertions(+), 46 deletions(-) diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index b1d9e53470fc..de9cd9d157fc 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -63,6 +63,11 @@ Must be x509 PEM encoded.`, Description: `A comma-separated list of OCSP server addresses. If unset, the OCSP server is determined from the AuthorityInformationAccess extension on the certificate being inspected.`, }, + "ocsp_fail_open": { + Type: framework.TypeBool, + Default: false, + Description: "If set to true, if an OCSP revocation cannot be made successfully, login will proceed rather. If false, failing to get an OCSP status fails the request.", + }, "allowed_names": { Type: framework.TypeCommaStringSlice, Description: `A comma-separated list of names. @@ -315,12 +320,15 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr if ocspEnabledRaw, ok := d.GetOk("ocsp_enabled"); ok { cert.OcspEnabled = ocspEnabledRaw.(bool) } - if displayNameRaw, ok := d.GetOk("display_name"); ok { - cert.DisplayName = displayNameRaw.(string) - } if ocspServerOverrides, ok := d.GetOk("ocsp_servers_override"); ok { cert.OcspServersOverride = ocspServerOverrides.([]string) } + if ocspFailOpen, ok := d.GetOk("ocsp_fail_open"); ok { + cert.OcspFailOpen = ocspFailOpen.(bool) + } + if displayNameRaw, ok := d.GetOk("display_name"); ok { + cert.DisplayName = displayNameRaw.(string) + } if allowedNamesRaw, ok := d.GetOk("allowed_names"); ok { cert.AllowedNames = allowedNamesRaw.([]string) } @@ -451,6 +459,7 @@ type CertEntry struct { OcspCaCertificates string OcspEnabled bool OcspServersOverride []string + OcspFailOpen bool DisplayName string Policies []string TTL time.Duration diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 868d7f40b9bb..85845068b6c3 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -10,6 +10,7 @@ import ( "encoding/pem" "errors" "fmt" + "github.com/hashicorp/vault/sdk/helper/ocsp" "strings" "github.com/hashicorp/vault/sdk/framework" @@ -225,7 +226,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d } // Load the trusted certificates and other details - roots, trusted, trustedNonCAs, ocspServersOverride := b.loadTrustedCerts(ctx, req.Storage, certName) + roots, trusted, trustedNonCAs, ocspServersOverride, ocspFailureMode := b.loadTrustedCerts(ctx, req.Storage, certName) // Get the list of full chains matching the connection and validates the // certificate itself @@ -247,7 +248,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d // Check for client cert being explicitly listed in the config (and matching other constraints) if tCert.SerialNumber.Cmp(clientCert.SerialNumber) == 0 && bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) { - matches, err := b.matchesConstraints(ctx, clientCert, trustedNonCA.Certificates, trustedNonCA, extraCas, ocspServersOverride) + matches, err := b.matchesConstraints(ctx, clientCert, trustedNonCA.Certificates, trustedNonCA, extraCas, ocspServersOverride, ocspFailureMode) if err != nil { return nil, nil, err } @@ -271,7 +272,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d for _, chain := range trustedChains { // For each root chain that we matched for _, cCert := range chain { // For each cert in the matched chain if tCert.Equal(cCert) { // ParsedCert intersects with matched chain - match, err := b.matchesConstraints(ctx, clientCert, chain, trust, extraCas, ocspServersOverride) // validate client cert + matched chain against the config + match, err := b.matchesConstraints(ctx, clientCert, chain, trust, extraCas, ocspServersOverride, ocspFailureMode) // validate client cert + matched chain against the config if err != nil { return nil, nil, err } @@ -295,7 +296,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d } func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certificate, trustedChain []*x509.Certificate, - config *ParsedCert, extraCas []*x509.Certificate, ocspServersOverride []string) (bool, error) { + config *ParsedCert, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode ocsp.FailOpenMode) (bool, error) { soFar := !b.checkForChainInCRLs(trustedChain) && b.matchesNames(clientCert, config) && b.matchesCommonName(clientCert, config) && @@ -305,7 +306,7 @@ func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certi b.matchesOrganizationalUnits(clientCert, config) && b.matchesCertificateExtensions(clientCert, config) if config.Entry.OcspEnabled { - ocspGood, err := b.checkForChainInOCSP(ctx, trustedChain, extraCas, ocspServersOverride) + ocspGood, err := b.checkForChainInOCSP(ctx, trustedChain, extraCas, ocspServersOverride, ocspFailureMode) if err != nil { return false, err } @@ -502,7 +503,7 @@ func (b *backend) certificateExtensionsMetadata(clientCert *x509.Certificate, co } // loadTrustedCerts is used to load all the trusted certificates from the backend -func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert, ocspServersOverride []string) { +func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert, ocspServersOverride []string, ocspFailureMode ocsp.FailOpenMode) { pool = x509.NewCertPool() trusted = make([]*ParsedCert, 0) trustedNonCAs = make([]*ParsedCert, 0) @@ -555,15 +556,20 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, }) } ocspServersOverride = entry.OcspServersOverride + if entry.OcspFailOpen { + ocspFailureMode = ocsp.FailOpenTrue + } else { + ocspFailureMode = ocsp.FailOpenFalse + } } return } -func (b *backend) checkForChainInOCSP(ctx context.Context, chain []*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string) (bool, error) { +func (b *backend) checkForChainInOCSP(ctx context.Context, chain []*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode ocsp.FailOpenMode) (bool, error) { if b.ocspDisabled || len(chain) < 2 { return true, nil } - err := b.ocspClient.VerifyPeerCertificate(ctx, [][]*x509.Certificate{chain}, extraCas, ocspServersOverride) + err := b.ocspClient.VerifyPeerCertificate(ctx, [][]*x509.Certificate{chain}, extraCas, ocspServersOverride, ocspFailureMode) if err != nil { return false, nil } diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index a73316c2b3f7..660855eaafc7 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -27,13 +27,12 @@ import ( "net/url" "strconv" "sync" - "sync/atomic" "time" ) -// OCSPFailOpenMode is OCSP fail open mode. OCSPFailOpenTrue by default and may +// FailOpenMode is OCSP fail open mode. FailOpenTrue by default and may // set to ocspModeFailClosed for fail closed mode -type OCSPFailOpenMode uint32 +type FailOpenMode uint32 type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) @@ -49,11 +48,11 @@ const ( ) const ( - ocspFailOpenNotSet OCSPFailOpenMode = iota - // OCSPFailOpenTrue represents OCSP fail open mode. - OCSPFailOpenTrue - // OCSPFailOpenFalse represents OCSP fail closed mode. - OCSPFailOpenFalse + ocspFailOpenNotSet FailOpenMode = iota + // FailOpenTrue represents OCSP fail open mode. + FailOpenTrue + // FailOpenFalse represents OCSP fail closed mode. + FailOpenFalse ) const ( @@ -99,9 +98,6 @@ type Client struct { // cacheUpdated is true if the memory cache is updated cacheUpdated bool logFactory func() hclog.Logger - - // OCSP fail open mode - ocspFailOpen OCSPFailOpenMode } type ocspStatusCode int @@ -342,10 +338,7 @@ func (c *Client) retryOCSP( ocspHost *url.URL, headers map[string]string, reqBody []byte, - issuer *x509.Certificate) ( - ocspRes *ocsp.Response, - ocspResBytes []byte, - ocspS *ocspStatus, err error) { + issuer *x509.Certificate) (ocspRes *ocsp.Response, ocspResBytes []byte, ocspS *ocspStatus, err error) { request, err := req("POST", ocspHost.String(), bytes.NewBuffer(reqBody)) if err != nil { @@ -456,7 +449,7 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. } c.ocspResponseCacheLock.Lock() - c.ocspResponseCache.Add(encodedCertID, &v) + c.ocspResponseCache.Add(*encodedCertID, &v) c.cacheUpdated = true c.ocspResponseCacheLock.Unlock() return ret, nil @@ -467,7 +460,7 @@ func isValidOCSPStatus(status ocspStatusCode) bool { } // VerifyPeerCertificate verifies all of certificate revocation status -func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string) (err error) { +func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) (err error) { for i := 0; i < len(verifiedChains); i++ { // Certificate signed by Root CA. This should be one before the last in the Certificate Chain numberOfNoneRootCerts := len(verifiedChains[i]) - 1 @@ -485,7 +478,7 @@ func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]* if err != nil { return err } - if r := c.canEarlyExitForOCSP(results, numberOfNoneRootCerts); r != nil { + if r := c.canEarlyExitForOCSP(results, numberOfNoneRootCerts, ocspFailureMode); r != nil { return r.err } } @@ -493,9 +486,9 @@ func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]* return nil } -func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int) *ocspStatus { +func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, ocspFailureMode FailOpenMode) *ocspStatus { msg := "" - if atomic.LoadUint32((*uint32)(&c.ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { + if ocspFailureMode == FailOpenFalse { // Fail closed. any error is returned to stop connection for _, r := range results { if r.err != nil { @@ -522,10 +515,7 @@ func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int) *ocsp } if len(msg) > 0 { c.Logger().Warn( - "WARNING!!! Using fail-open to connect. Driver is connecting to an "+ - "HTTPS endpoint without OCSP based Certificate Revocation checking "+ - "as it could not obtain a valid OCSP Response to use from the CA OCSP "+ - "responder", "detail", msg[1:]) + "OCSP is set to fail-open, and could not retrieve OCSP based revocation checking but proceeding.", "detail", msg[1:]) } return nil } @@ -582,9 +572,9 @@ func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains, ext } // verifyPeerCertificateSerial verifies the certificate revocation status in serial. -func (c *Client) verifyPeerCertificateSerial(extraCas []*x509.Certificate, ocspServersOverride []string) func(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { +func (c *Client) verifyPeerCertificateSerial(extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) func(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { return func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { - return c.VerifyPeerCertificate(context.TODO(), verifiedChains, extraCas, ocspServersOverride) + return c.VerifyPeerCertificate(context.TODO(), verifiedChains, extraCas, ocspServersOverride, ocspFailureMode) } } @@ -624,7 +614,7 @@ func (c *Client) extractOCSPCacheResponseValue(cacheValue *ocspCachedResponse, s // writeOCSPCache writes a OCSP Response cache func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) error { - c.Logger().Debug("writing OCSP Response cache file") + c.Logger().Debug("writing OCSP Response cache entry") t := time.Now() m := make(map[string][]interface{}) @@ -693,9 +683,19 @@ func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) err times[i] = t.(float64) } } + var status int + if jn, ok := v[0].(json.Number); ok { + s, err := jn.Int64() + if err != nil { + return err + } + status = int(s) + } else { + status = v[0].(int) + } c.ocspResponseCache.Add(*key, &ocspCachedResponse{ - status: ocspStatusCode(v[0].(int)), + status: ocspStatusCode(status), time: times[0], producedAt: times[1], thisUpdate: times[2], @@ -747,11 +747,11 @@ func newInsecureOcspTransport(extraCas []*x509.Certificate) *http.Transport { } // NewTransport includes the certificate revocation check with OCSP in sequential. -func (c *Client) NewTransport(extraCas []*x509.Certificate, ocspServersOverride []string) *http.Transport { +func (c *Client) NewTransport(extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) *http.Transport { return &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: c.certPool, - VerifyPeerCertificate: c.verifyPeerCertificateSerial(extraCas, ocspServersOverride), + VerifyPeerCertificate: c.verifyPeerCertificateSerial(extraCas, ocspServersOverride, ocspFailureMode), }, MaxIdleConns: 10, IdleConnTimeout: 30 * time.Minute, diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index e3007efbfb3e..0e78048f6d5f 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -33,7 +33,7 @@ func TestOCSP(t *testing.T) { c := New(testLogFactory) transports := []*http.Transport{ newInsecureOcspTransport(nil), - c.NewTransport(nil, nil), + c.NewTransport(nil, nil, FailOpenFalse), } for _, tgt := range targetURL { @@ -412,17 +412,15 @@ func TestCanEarlyExitForOCSP(t *testing.T) { } c := New(testLogFactory) for idx, tt := range testcases { - c.ocspFailOpen = OCSPFailOpenTrue expectedLen := len(tt.results) if tt.resultLen > 0 { expectedLen = tt.resultLen } - r := c.canEarlyExitForOCSP(tt.results, expectedLen) + r := c.canEarlyExitForOCSP(tt.results, expectedLen, FailOpenTrue) if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) { t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r) } - c.ocspFailOpen = OCSPFailOpenFalse - r = c.canEarlyExitForOCSP(tt.results, expectedLen) + r = c.canEarlyExitForOCSP(tt.results, expectedLen, FailOpenFalse) if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) { t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r) } From 6e0063365d542943fbff7a0a8092d380c36df987 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Mon, 12 Sep 2022 12:37:26 -0500 Subject: [PATCH 08/39] reduce logging --- sdk/helper/ocsp/client.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 660855eaafc7..066bb920fde5 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -359,12 +359,10 @@ func (c *Client) retryOCSP( if res.StatusCode != http.StatusOK { return nil, nil, nil, fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status) } - c.Logger().Debug("reading contents") ocspResBytes, err = ioutil.ReadAll(res.Body) if err != nil { return nil, nil, nil, err } - c.Logger().Debug("parsing OCSP response") ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) if err != nil { return nil, nil, nil, err @@ -377,8 +375,6 @@ func (c *Client) retryOCSP( // getRevocationStatus checks the certificate revocation status for subject using issuer certificate. func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string) (*ocspStatus, error) { - c.Logger().Debug("get-revocation-status", "subject", subject.Subject, "issuer", issuer.Subject) - status, ocspReq, encodedCertID, err := c.ValidateWithCache(subject, issuer) if err != nil { return nil, err @@ -389,8 +385,7 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. if ocspReq == nil || encodedCertID == nil { return status, nil } - c.Logger().Debug("cache missed") - c.Logger().Debug("OCSP: ", "server", subject.OCSPServer) + c.Logger().Debug("cache missed", "server", subject.OCSPServer) if len(subject.OCSPServer) == 0 && len(ocspServersOverride) == 0 { return nil, fmt.Errorf("no OCSP responder URL: subject: %v", subject.Subject) } @@ -614,8 +609,7 @@ func (c *Client) extractOCSPCacheResponseValue(cacheValue *ocspCachedResponse, s // writeOCSPCache writes a OCSP Response cache func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) error { - c.Logger().Debug("writing OCSP Response cache entry") - + c.Logger().Debug("writing OCSP Response cache") t := time.Now() m := make(map[string][]interface{}) keys := c.ocspResponseCache.Keys() @@ -651,7 +645,7 @@ func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) er // readOCSPCache reads a OCSP Response cache from storage func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) error { - c.Logger().Debug("reading OCSP Response cache entry") + c.Logger().Debug("reading OCSP Response cache") entry, err := storage.Get(ctx, ocspCacheKey) if err != nil { From 44d808d7b7b326a60b63c73181707b41f4195a22 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Mon, 12 Sep 2022 14:45:58 -0500 Subject: [PATCH 09/39] Add the retry client and GET then POST logic --- sdk/helper/ocsp/client.go | 63 +++++++++++++++++++++++------------- sdk/helper/ocsp/ocsp_test.go | 24 ++++++++------ 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 066bb920fde5..5e13d5c41dc1 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -15,11 +15,11 @@ import ( "errors" "fmt" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" lru "github.com/hashicorp/golang-lru" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" "golang.org/x/crypto/ocsp" - "io" "io/ioutil" "math/big" "net" @@ -34,10 +34,10 @@ import ( // set to ocspModeFailClosed for fail closed mode type FailOpenMode uint32 -type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) +type requestFunc func(method, urlStr string, body interface{}) (*retryablehttp.Request, error) type clientInterface interface { - Do(req *http.Request) (*http.Response, error) + Do(req *retryablehttp.Request) (*http.Response, error) } const ( @@ -340,22 +340,41 @@ func (c *Client) retryOCSP( reqBody []byte, issuer *x509.Certificate) (ocspRes *ocsp.Response, ocspResBytes []byte, ocspS *ocspStatus, err error) { - request, err := req("POST", ocspHost.String(), bytes.NewBuffer(reqBody)) - if err != nil { - return nil, nil, nil, err - } - if request != nil { - request = request.WithContext(ctx) - for k, v := range headers { - request.Header[k] = append(request.Header[k], v) + origHost := *ocspHost + doRequest := func(request *retryablehttp.Request) (*http.Response, error) { + if err != nil { + return nil, err + } + if request != nil { + request = request.WithContext(ctx) + for k, v := range headers { + request.Header[k] = append(request.Header[k], v) + } + } + res, err := client.Do(request) + if err != nil { + return nil, err } + c.Logger().Debug("StatusCode from OCSP Server:", "statusCode", res.StatusCode) + return res, err } - res, err := client.Do(request) - if err != nil { + + ocspHost.Path = ocspHost.Path + "/" + base64.StdEncoding.EncodeToString(reqBody) + var res *http.Response + request, err := req("GET", ocspHost.String(), nil) + if res, err = doRequest(request); err != nil { return nil, nil, nil, err + } else { + defer res.Body.Close() + } + if res.StatusCode == http.StatusMethodNotAllowed { + request, err = req("POST", origHost.String(), bytes.NewBuffer(reqBody)) + if res, err = doRequest(request); err != nil { + return nil, nil, nil, err + } else { + defer res.Body.Close() + } } - defer res.Body.Close() - c.Logger().Debug("StatusCode from OCSP Server:", "statusCode", res.StatusCode) if res.StatusCode != http.StatusOK { return nil, nil, nil, fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status) } @@ -411,13 +430,13 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. headers[httpHeaderHost] = hostname timeout := defaultOCSPResponderTimeout - ocspClient := &http.Client{ - Timeout: timeout, - Transport: newInsecureOcspTransport(extraCas), - } + ocspClient := retryablehttp.NewClient() + ocspClient.HTTPClient.Timeout = timeout + ocspClient.HTTPClient.Transport = newInsecureOcspTransport(extraCas) + var ocspS *ocspStatus ocspRes, _, ocspS, err = c.retryOCSP( - ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer) + ctx, ocspClient, retryablehttp.NewRequest, u, headers, ocspReq, issuer) if err != nil { return nil, err } @@ -455,7 +474,7 @@ func isValidOCSPStatus(status ocspStatusCode) bool { } // VerifyPeerCertificate verifies all of certificate revocation status -func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) (err error) { +func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) error { for i := 0; i < len(verifiedChains); i++ { // Certificate signed by Root CA. This should be one before the last in the Certificate Chain numberOfNoneRootCerts := len(verifiedChains[i]) - 1 @@ -510,7 +529,7 @@ func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, ocspF } if len(msg) > 0 { c.Logger().Warn( - "OCSP is set to fail-open, and could not retrieve OCSP based revocation checking but proceeding.", "detail", msg[1:]) + "OCSP is set to fail-open, and could not retrieve OCSP based revocation checking but proceeding.", "detail", msg) } return nil } diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 0e78048f6d5f..4b198d24c5c1 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" lru "github.com/hashicorp/golang-lru" "io" "io/ioutil" @@ -434,15 +435,16 @@ func testLogFactory() hclog.Logger { } type fakeHTTPClient struct { - cnt int // number of retry - success bool // return success after retry in cnt times - timeout bool // timeout - body []byte // return body - t *testing.T - logger hclog.Logger + cnt int // number of retry + success bool // return success after retry in cnt times + timeout bool // timeout + body []byte // return body + t *testing.T + logger hclog.Logger + redirected bool } -func (c *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { +func (c *fakeHTTPClient) Do(_ *retryablehttp.Request) (*http.Response, error) { c.cnt-- if c.cnt < 0 { c.cnt = 0 @@ -450,7 +452,11 @@ func (c *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { c.t.Log("fakeHTTPClient.cnt", c.cnt) var retcode int - if c.success && c.cnt == 0 { + if !c.redirected { + c.redirected = true + c.cnt++ + retcode = 405 + } else if c.success && c.cnt == 1 { retcode = 200 } else { if c.timeout { @@ -499,6 +505,6 @@ func (b *fakeResponseBody) Close() error { return nil } -func fakeRequestFunc(_, _ string, _ io.Reader) (*http.Request, error) { +func fakeRequestFunc(_, _ string, _ interface{}) (*retryablehttp.Request, error) { return nil, nil } From 2b73de707996830fa07049132c0a6724caeec7af Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Mon, 12 Sep 2022 16:31:02 -0500 Subject: [PATCH 10/39] Drop persisted cache, make cache size configurable, allow for parallel testing of multiple servers --- builtin/credential/cert/backend.go | 34 +++--- builtin/credential/cert/path_certs.go | 20 ++++ builtin/credential/cert/path_config.go | 5 +- builtin/credential/cert/path_login.go | 26 +++-- sdk/helper/ocsp/client.go | 147 +++++++++++++++++-------- sdk/helper/ocsp/ocsp_test.go | 86 +++++++++++++-- 6 files changed, 239 insertions(+), 79 deletions(-) diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index adffe2c3062d..a2b72e6d877e 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -16,10 +16,13 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, if err := b.Setup(ctx, conf); err != nil { return nil, err } - - if err := b.ocspClient.ReadCache(ctx, conf.StorageView); err != nil { + bConf, err := b.Config(ctx, conf.StorageView) + if err != nil { return nil, err } + if conf != nil { + b.initOCSPClient(bConf.OcspCacheSize) + } return b, nil } @@ -39,16 +42,12 @@ func Backend() *backend { pathCerts(&b), pathCRLs(&b), }, - AuthRenew: b.pathLoginRenew, - Invalidate: b.invalidate, - BackendType: logical.TypeCredential, - PeriodicFunc: b.periodFunc, + AuthRenew: b.pathLoginRenew, + Invalidate: b.invalidate, + BackendType: logical.TypeCredential, } b.crlUpdateMutex = &sync.RWMutex{} - b.ocspClient = ocsp.New(func() hclog.Logger { - return b.Logger() - }) return &b } @@ -56,10 +55,11 @@ type backend struct { *framework.Backend MapCertId *framework.PathMap - crls map[string]CRLInfo - ocspDisabled bool - crlUpdateMutex *sync.RWMutex - ocspClient *ocsp.Client + crls map[string]CRLInfo + ocspDisabled bool + crlUpdateMutex *sync.RWMutex + ocspClientMutex sync.RWMutex + ocspClient *ocsp.Client } func (b *backend) invalidate(_ context.Context, key string) { @@ -71,8 +71,12 @@ func (b *backend) invalidate(_ context.Context, key string) { } } -func (b *backend) periodFunc(ctx context.Context, request *logical.Request) error { - return b.ocspClient.WriteCache(ctx, request.Storage) +func (b *backend) initOCSPClient(cacheSize int) { + b.ocspClientMutex.Lock() + defer b.ocspClientMutex.Unlock() + b.ocspClient = ocsp.New(func() hclog.Logger { + return b.Logger() + }, cacheSize) } const backendHelp = ` diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index de9cd9d157fc..ae1acae4a3de 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -13,6 +13,8 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +const defaultOCSPCacheSize = 100 + func pathListCerts(b *backend) *framework.Path { return &framework.Path{ Pattern: "certs/?", @@ -68,6 +70,16 @@ from the AuthorityInformationAccess extension on the certificate being inspected Default: false, Description: "If set to true, if an OCSP revocation cannot be made successfully, login will proceed rather. If false, failing to get an OCSP status fails the request.", }, + "ocsp_query_all_servers": { + Type: framework.TypeBool, + Default: false, + Description: "If set to true, rather than accepting the first successful OCSP response, query all servers and consider the certificate valid only if all servers agree.", + }, + "ocsp_cache_size": { + Type: framework.TypeInt, + Default: defaultOCSPCacheSize, + Description: "The size of the OCSP response cache.", + }, "allowed_names": { Type: framework.TypeCommaStringSlice, Description: `A comma-separated list of names. @@ -326,6 +338,12 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr if ocspFailOpen, ok := d.GetOk("ocsp_fail_open"); ok { cert.OcspFailOpen = ocspFailOpen.(bool) } + if ocspQueryAll, ok := d.GetOk("ocsp_query_all_servers"); ok { + cert.OcspQueryAllServers = ocspQueryAll.(bool) + } + if ocspCacheSize, ok := d.GetOk("ocsp_cache_size"); ok { + cert.OcspCacheSize = ocspCacheSize.(int) + } if displayNameRaw, ok := d.GetOk("display_name"); ok { cert.DisplayName = displayNameRaw.(string) } @@ -460,6 +478,8 @@ type CertEntry struct { OcspEnabled bool OcspServersOverride []string OcspFailOpen bool + OcspCacheSize int + OcspQueryAllServers bool DisplayName string Policies []string TTL time.Duration diff --git a/builtin/credential/cert/path_config.go b/builtin/credential/cert/path_config.go index 9cc17f3a6aaf..0702ff97f356 100644 --- a/builtin/credential/cert/path_config.go +++ b/builtin/credential/cert/path_config.go @@ -34,10 +34,11 @@ func pathConfig(b *backend) *framework.Path { func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { disableBinding := data.Get("disable_binding").(bool) enableIdentityAliasMetadata := data.Get("enable_identity_alias_metadata").(bool) - + cacheSize := data.Get("ocsp_cache_size").(int) entry, err := logical.StorageEntryJSON("config", config{ DisableBinding: disableBinding, EnableIdentityAliasMetadata: enableIdentityAliasMetadata, + OcspCacheSize: cacheSize, }) if err != nil { return nil, err @@ -46,6 +47,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat if err := req.Storage.Put(ctx, entry); err != nil { return nil, err } + b.initOCSPClient(cacheSize) return nil, nil } @@ -85,4 +87,5 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error type config struct { DisableBinding bool `json:"disable_binding"` EnableIdentityAliasMetadata bool `json:"enable_identity_alias_metadata"` + OcspCacheSize int `json:"ocsp_cache_size"` } diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 85845068b6c3..3906004b4cd5 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -226,7 +226,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d } // Load the trusted certificates and other details - roots, trusted, trustedNonCAs, ocspServersOverride, ocspFailureMode := b.loadTrustedCerts(ctx, req.Storage, certName) + roots, trusted, trustedNonCAs, verifyConf := b.loadTrustedCerts(ctx, req.Storage, certName) // Get the list of full chains matching the connection and validates the // certificate itself @@ -248,7 +248,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d // Check for client cert being explicitly listed in the config (and matching other constraints) if tCert.SerialNumber.Cmp(clientCert.SerialNumber) == 0 && bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) { - matches, err := b.matchesConstraints(ctx, clientCert, trustedNonCA.Certificates, trustedNonCA, extraCas, ocspServersOverride, ocspFailureMode) + matches, err := b.matchesConstraints(ctx, clientCert, trustedNonCA.Certificates, trustedNonCA, verifyConf) if err != nil { return nil, nil, err } @@ -272,7 +272,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d for _, chain := range trustedChains { // For each root chain that we matched for _, cCert := range chain { // For each cert in the matched chain if tCert.Equal(cCert) { // ParsedCert intersects with matched chain - match, err := b.matchesConstraints(ctx, clientCert, chain, trust, extraCas, ocspServersOverride, ocspFailureMode) // validate client cert + matched chain against the config + match, err := b.matchesConstraints(ctx, clientCert, chain, trust, verifyConf) // validate client cert + matched chain against the config if err != nil { return nil, nil, err } @@ -296,7 +296,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d } func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certificate, trustedChain []*x509.Certificate, - config *ParsedCert, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode ocsp.FailOpenMode) (bool, error) { + config *ParsedCert, conf *ocsp.VerifyConfig) (bool, error) { soFar := !b.checkForChainInCRLs(trustedChain) && b.matchesNames(clientCert, config) && b.matchesCommonName(clientCert, config) && @@ -306,7 +306,7 @@ func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certi b.matchesOrganizationalUnits(clientCert, config) && b.matchesCertificateExtensions(clientCert, config) if config.Entry.OcspEnabled { - ocspGood, err := b.checkForChainInOCSP(ctx, trustedChain, extraCas, ocspServersOverride, ocspFailureMode) + ocspGood, err := b.checkForChainInOCSP(ctx, trustedChain, conf) if err != nil { return false, err } @@ -503,7 +503,7 @@ func (b *backend) certificateExtensionsMetadata(clientCert *x509.Certificate, co } // loadTrustedCerts is used to load all the trusted certificates from the backend -func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert, ocspServersOverride []string, ocspFailureMode ocsp.FailOpenMode) { +func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert, conf *ocsp.VerifyConfig) { pool = x509.NewCertPool() trusted = make([]*ParsedCert, 0) trustedNonCAs = make([]*ParsedCert, 0) @@ -520,6 +520,7 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, } } + conf = &ocsp.VerifyConfig{} for _, name := range names { entry, err := b.Cert(ctx, storage, strings.TrimPrefix(name, "cert/")) if err != nil { @@ -555,21 +556,24 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, Certificates: parsed, }) } - ocspServersOverride = entry.OcspServersOverride + conf.OcspServersOverride = append(conf.OcspServersOverride, entry.OcspServersOverride...) if entry.OcspFailOpen { - ocspFailureMode = ocsp.FailOpenTrue + conf.OcspFailureMode = ocsp.FailOpenTrue } else { - ocspFailureMode = ocsp.FailOpenFalse + conf.OcspFailureMode = ocsp.FailOpenFalse } + conf.QueryAllServers = conf.QueryAllServers || entry.OcspQueryAllServers } return } -func (b *backend) checkForChainInOCSP(ctx context.Context, chain []*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode ocsp.FailOpenMode) (bool, error) { +func (b *backend) checkForChainInOCSP(ctx context.Context, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) { if b.ocspDisabled || len(chain) < 2 { return true, nil } - err := b.ocspClient.VerifyPeerCertificate(ctx, [][]*x509.Certificate{chain}, extraCas, ocspServersOverride, ocspFailureMode) + b.ocspClientMutex.RLock() + defer b.ocspClientMutex.RUnlock() + err := b.ocspClient.VerifyPeerCertificate(ctx, [][]*x509.Certificate{chain}, conf) if err != nil { return false, nil } diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 5e13d5c41dc1..f23ff7587f4a 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -11,14 +11,11 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/base64" - "encoding/json" "errors" "fmt" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-retryablehttp" lru "github.com/hashicorp/golang-lru" - "github.com/hashicorp/vault/sdk/helper/jsonutil" - "github.com/hashicorp/vault/sdk/logical" "golang.org/x/crypto/ocsp" "io/ioutil" "math/big" @@ -45,6 +42,8 @@ const ( httpHeaderAccept = "accept" httpHeaderContentLength = "Content-Length" httpHeaderHost = "Host" + ocspRequestContentType = "application/ocsp-request" + ocspResponseContentType = "application/ocsp-response" ) const ( @@ -70,9 +69,7 @@ const ( const ( // cacheExpire specifies cache data expiration time in seconds. - cacheExpire = float64(24 * 60 * 60) - cacheSize = 10000 - persistedCacheSize = 1000 + cacheExpire = float64(24 * 60 * 60) ) const ( @@ -91,7 +88,7 @@ type ocspCachedResponse struct { type Client struct { // caRoot includes the CA certificates. caRoot map[string]*x509.Certificate - // certPOol includes the CA certificates. + // certPool includes the CA certificates. certPool *x509.CertPool ocspResponseCache *lru.TwoQueueCache ocspResponseCacheLock sync.RWMutex @@ -159,7 +156,7 @@ func (c *Client) getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto return hash } } - c.Logger().Error("no valid hash algorithm is found for the oid. Falling back to SHA1", "target", target) + //no valid hash algorithm is found for the oid. Falling back to SHA1 return crypto.SHA1 } @@ -393,7 +390,7 @@ func (c *Client) retryOCSP( } // getRevocationStatus checks the certificate revocation status for subject using issuer certificate. -func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string) (*ocspStatus, error) { +func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, conf *VerifyConfig) (*ocspStatus, error) { status, ocspReq, encodedCertID, err := c.ValidateWithCache(subject, issuer) if err != nil { return nil, err @@ -405,17 +402,21 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. return status, nil } c.Logger().Debug("cache missed", "server", subject.OCSPServer) - if len(subject.OCSPServer) == 0 && len(ocspServersOverride) == 0 { + if len(subject.OCSPServer) == 0 && len(conf.OcspServersOverride) == 0 { return nil, fmt.Errorf("no OCSP responder URL: subject: %v", subject.Subject) } ocspHosts := subject.OCSPServer - if len(ocspServersOverride) > 0 { - ocspHosts = ocspServersOverride + if len(conf.OcspServersOverride) > 0 { + ocspHosts = conf.OcspServersOverride } - var ret *ocspStatus - var ocspRes *ocsp.Response - for _, ocspHost := range ocspHosts { + var wg sync.WaitGroup + + ocspStatuses := make([]*ocspStatus, len(ocspHosts)) + ocspResponses := make([]*ocsp.Response, len(ocspHosts)) + errors := make([]error, len(ocspHosts)) + + for i, ocspHost := range ocspHosts { u, err := url.Parse(ocspHost) if err != nil { return nil, err @@ -424,34 +425,70 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. hostname := u.Hostname() headers := make(map[string]string) - headers[httpHeaderContentType] = "application/ocsp-request" - headers[httpHeaderAccept] = "application/ocsp-response" + headers[httpHeaderContentType] = ocspRequestContentType + headers[httpHeaderAccept] = ocspResponseContentType headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq)) headers[httpHeaderHost] = hostname timeout := defaultOCSPResponderTimeout ocspClient := retryablehttp.NewClient() ocspClient.HTTPClient.Timeout = timeout - ocspClient.HTTPClient.Transport = newInsecureOcspTransport(extraCas) + ocspClient.HTTPClient.Transport = newInsecureOcspTransport(conf.ExtraCas) - var ocspS *ocspStatus - ocspRes, _, ocspS, err = c.retryOCSP( - ctx, ocspClient, retryablehttp.NewRequest, u, headers, ocspReq, issuer) - if err != nil { - return nil, err - } - if ocspS.code != ocspSuccess { - return ocspS, nil - } + doRequest := func() error { + if conf.QueryAllServers { + defer wg.Done() + } + ocspRes, _, ocspS, err := c.retryOCSP( + ctx, ocspClient, retryablehttp.NewRequest, u, headers, ocspReq, issuer) + ocspResponses[i] = ocspRes + if err != nil { + errors[i] = err + return err + } + if ocspS.code != ocspSuccess { + ocspStatuses[i] = ocspS + return nil + } - ret, err = validateOCSP(ocspRes) - if err != nil { - return nil, err + ret, err := validateOCSP(ocspRes) + if err != nil { + errors[i] = err + return err + } + if isValidOCSPStatus(ret.code) { + ocspStatuses[i] = ret + } + return nil + } + if conf.QueryAllServers { + wg.Add(1) + go doRequest() + } else { + err = doRequest() + if err == nil { + break + } } - if isValidOCSPStatus(ret.code) { + } + if conf.QueryAllServers { + wg.Wait() + } + // Good by default + var ret *ocspStatus = &ocspStatus{code: ocspStatusGood} + ocspRes := ocspResponses[0] + for i, _ := range ocspHosts { + if errors[i] != nil { + return nil, errors[i] + } else if ocspStatuses[i] != nil && (!conf.QueryAllServers || ocspStatuses[i].code != ocspStatusGood) { + ret = ocspStatuses[i] + ocspRes = ocspResponses[i] break + } else { + } } + if !isValidOCSPStatus(ret.code) { return ret, nil } @@ -473,8 +510,15 @@ func isValidOCSPStatus(status ocspStatusCode) bool { return status == ocspStatusGood || status == ocspStatusRevoked || status == ocspStatusUnknown } +type VerifyConfig struct { + ExtraCas []*x509.Certificate + OcspServersOverride []string + OcspFailureMode FailOpenMode + QueryAllServers bool +} + // VerifyPeerCertificate verifies all of certificate revocation status -func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) error { +func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, conf *VerifyConfig) error { for i := 0; i < len(verifiedChains); i++ { // Certificate signed by Root CA. This should be one before the last in the Certificate Chain numberOfNoneRootCerts := len(verifiedChains[i]) - 1 @@ -488,11 +532,11 @@ func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]* verifiedChains[i] = append(verifiedChains[i], rca) numberOfNoneRootCerts++ } - results, err := c.GetAllRevocationStatus(ctx, verifiedChains[i], extraCas, ocspServersOverride) + results, err := c.GetAllRevocationStatus(ctx, verifiedChains[i], conf) if err != nil { return err } - if r := c.canEarlyExitForOCSP(results, numberOfNoneRootCerts, ocspFailureMode); r != nil { + if r := c.canEarlyExitForOCSP(results, numberOfNoneRootCerts, conf); r != nil { return r.err } } @@ -500,9 +544,9 @@ func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]* return nil } -func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, ocspFailureMode FailOpenMode) *ocspStatus { +func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, conf *VerifyConfig) *ocspStatus { msg := "" - if ocspFailureMode == FailOpenFalse { + if conf.OcspFailureMode == FailOpenFalse { // Fail closed. any error is returned to stop connection for _, r := range results { if r.err != nil { @@ -566,7 +610,7 @@ func (c *Client) ValidateWithCache(subject, issuer *x509.Certificate) (*ocspStat return status, ocspReq, encodedCertID, nil } -func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains, extraCas []*x509.Certificate, ocspServersOverride []string) ([]*ocspStatus, error) { +func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate, conf *VerifyConfig) ([]*ocspStatus, error) { _, err := c.ValidateWithCacheForAllCertificates(verifiedChains) if err != nil { return nil, err @@ -574,7 +618,7 @@ func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains, ext n := len(verifiedChains) - 1 results := make([]*ocspStatus, n) for j := 0; j < n; j++ { - results[j], err = c.getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1], extraCas, ocspServersOverride) + results[j], err = c.getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1], conf) if err != nil { return nil, err } @@ -586,9 +630,9 @@ func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains, ext } // verifyPeerCertificateSerial verifies the certificate revocation status in serial. -func (c *Client) verifyPeerCertificateSerial(extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) func(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { +func (c *Client) verifyPeerCertificateSerial(conf *VerifyConfig) func(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { return func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { - return c.VerifyPeerCertificate(context.TODO(), verifiedChains, extraCas, ocspServersOverride, ocspFailureMode) + return c.VerifyPeerCertificate(context.TODO(), verifiedChains, conf) } } @@ -626,6 +670,7 @@ func (c *Client) extractOCSPCacheResponseValue(cacheValue *ocspCachedResponse, s }) } +/* // writeOCSPCache writes a OCSP Response cache func (c *Client) writeOCSPCache(ctx context.Context, storage logical.Storage) error { c.Logger().Debug("writing OCSP Response cache") @@ -718,8 +763,9 @@ func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) err return nil } +*/ -func New(logFactory func() hclog.Logger) *Client { +func New(logFactory func() hclog.Logger, cacheSize int) *Client { cache, _ := lru.New2Q(cacheSize) c := Client{ caRoot: make(map[string]*x509.Certificate), @@ -760,11 +806,21 @@ func newInsecureOcspTransport(extraCas []*x509.Certificate) *http.Transport { } // NewTransport includes the certificate revocation check with OCSP in sequential. -func (c *Client) NewTransport(extraCas []*x509.Certificate, ocspServersOverride []string, ocspFailureMode FailOpenMode) *http.Transport { +func (c *Client) NewTransport(conf *VerifyConfig) *http.Transport { + rootCAs := c.certPool + if rootCAs == nil { + rootCAs, _ = x509.SystemCertPool() + } + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + for _, c := range conf.ExtraCas { + rootCAs.AddCert(c) + } return &http.Transport{ TLSClientConfig: &tls.Config{ - RootCAs: c.certPool, - VerifyPeerCertificate: c.verifyPeerCertificateSerial(extraCas, ocspServersOverride, ocspFailureMode), + RootCAs: rootCAs, + VerifyPeerCertificate: c.verifyPeerCertificateSerial(conf), }, MaxIdleConns: 10, IdleConnTimeout: 30 * time.Minute, @@ -776,6 +832,7 @@ func (c *Client) NewTransport(extraCas []*x509.Certificate, ocspServersOverride } } +/* func (c *Client) WriteCache(ctx context.Context, storage logical.Storage) error { c.ocspResponseCacheLock.Lock() defer c.ocspResponseCacheLock.Unlock() @@ -794,7 +851,7 @@ func (c *Client) ReadCache(ctx context.Context, storage logical.Storage) error { defer c.ocspResponseCacheLock.Unlock() return c.readOCSPCache(ctx, storage) } - +*/ /* Apache License Version 2.0, January 2004 diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 4b198d24c5c1..7c34ac3c96e7 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -8,6 +8,7 @@ import ( "crypto" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "github.com/hashicorp/go-hclog" @@ -31,10 +32,13 @@ func TestOCSP(t *testing.T) { "https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64", } - c := New(testLogFactory) + conf := VerifyConfig{ + OcspFailureMode: FailOpenFalse, + } + c := New(testLogFactory, 10) transports := []*http.Transport{ newInsecureOcspTransport(nil), - c.NewTransport(nil, nil, FailOpenFalse), + c.NewTransport(&conf), } for _, tgt := range targetURL { @@ -62,6 +66,50 @@ func TestOCSP(t *testing.T) { } } +func TestMultiOCSP(t *testing.T) { + targetURL := []string{ + "https://localhost:8200/v1/pki/ocsp", + "https://localhost:8200/v1/pki/ocsp", + "https://localhost:8200/v1/pki/ocsp", + } + + b, _ := pem.Decode([]byte(vaultCert)) + caCert, _ := x509.ParseCertificate(b.Bytes) + conf := VerifyConfig{ + OcspFailureMode: FailOpenFalse, + QueryAllServers: true, + OcspServersOverride: targetURL, + ExtraCas: []*x509.Certificate{caCert}, + } + c := New(testLogFactory, 10) + transports := []*http.Transport{ + newInsecureOcspTransport(conf.ExtraCas), + c.NewTransport(&conf), + } + + tgt := "https://localhost:8200/v1/pki/ca/pem" + c.ocspResponseCache, _ = lru.New2Q(10) + for _, tr := range transports { + c := &http.Client{ + Transport: tr, + Timeout: 30 * time.Second, + } + req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("fail to create a request. err: %v", err) + } + res, err := c.Do(req) + if err != nil { + t.Fatalf("failed to GET contents. err: %v", err) + } + defer res.Body.Close() + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("failed to read content body for %v", tgt) + } + } +} + type tcValidityRange struct { thisTime time.Time nextTime time.Time @@ -163,7 +211,7 @@ func TestUnitEncodeCertIDGood(t *testing.T) { } func TestUnitCheckOCSPResponseCache(t *testing.T) { - c := New(testLogFactory) + c := New(testLogFactory, 10) dummyKey0 := certIDKey{ HashAlgorithm: crypto.SHA1, NameHash: "dummy0", @@ -286,7 +334,7 @@ func getCert(addr string) []*x509.Certificate { } func TestOCSPRetry(t *testing.T) { - c := New(testLogFactory) + c := New(testLogFactory, 10) certs := getCert("s3-us-west-2.amazonaws.com:443") dummyOCSPHost := &url.URL{ Scheme: "https", @@ -411,17 +459,17 @@ func TestCanEarlyExitForOCSP(t *testing.T) { retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, }, } - c := New(testLogFactory) + c := New(testLogFactory, 10) for idx, tt := range testcases { expectedLen := len(tt.results) if tt.resultLen > 0 { expectedLen = tt.resultLen } - r := c.canEarlyExitForOCSP(tt.results, expectedLen, FailOpenTrue) + r := c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenTrue}) if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) { t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r) } - r = c.canEarlyExitForOCSP(tt.results, expectedLen, FailOpenFalse) + r = c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenFalse}) if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) { t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r) } @@ -508,3 +556,27 @@ func (b *fakeResponseBody) Close() error { func fakeRequestFunc(_, _ string, _ interface{}) (*retryablehttp.Request, error) { return nil, nil } + +const vaultCert = `-----BEGIN CERTIFICATE----- +MIIDuTCCAqGgAwIBAgIUA6VeVD1IB5rXcCZRAqPO4zr/GAMwDQYJKoZIhvcNAQEL +BQAwcjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAlZBMREwDwYDVQQHDAhTb21lQ2l0 +eTESMBAGA1UECgwJTXlDb21wYW55MRMwEQYDVQQLDApNeURpdmlzaW9uMRowGAYD +VQQDDBF3d3cuY29uaHVnZWNvLmNvbTAeFw0yMjA5MDcxOTA1MzdaFw0yNDA5MDYx +OTA1MzdaMHIxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJWQTERMA8GA1UEBwwIU29t +ZUNpdHkxEjAQBgNVBAoMCU15Q29tcGFueTETMBEGA1UECwwKTXlEaXZpc2lvbjEa +MBgGA1UEAwwRd3d3LmNvbmh1Z2Vjby5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQDL9qzEXi4PIafSAqfcwcmjujFvbG1QZbI8swxnD+w8i4ufAQU5 +LDmvMrGo3ZbhJ0mCihYmFxpjhRdP2raJQ9TysHlPXHtDRpr9ckWTKBz2oIfqVtJ2 +qzteQkWCkDAO7kPqzgCFsMeoMZeONRkeGib0lEzQAbW/Rqnphg8zVVkyQ71DZ7Pc +d5WkC2E28kKcSramhWfVFpxG3hSIrLOX2esEXteLRzKxFPf+gi413JZFKYIWrebP +u5t0++MLNpuX322geoki4BWMjQsd47XILmxZ4aj33ScZvdrZESCnwP76hKIxg9mO +lMxrqSWKVV5jHZrElSEj9LYJgDO1Y6eItn7hAgMBAAGjRzBFMAsGA1UdDwQEAwIE +MDATBgNVHSUEDDAKBggrBgEFBQcDATAhBgNVHREEGjAYggtleGFtcGxlLmNvbYIJ +bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQA5dPdf5SdtMwe2uSspO/EuWqbM +497vMQBW1Ey8KRKasJjhvOVYMbe7De5YsnW4bn8u5pl0zQGF4hEtpmifAtVvziH/ +K+ritQj9VVNbLLCbFcg+b0kfjt4yrDZ64vWvIeCgPjG1Kme8gdUUWgu9dOud5gdx +qg/tIFv4TRS/eIIymMlfd9owOD3Ig6S5fy4NaAJFAwXf8+3Rzuc+e7JSAPgAufjh +tOTWinxvoiOLuYwo9CyGgq4qKBFsrY0aE0gdA7oTQkpbEbo2EbqiWUl/PTCl1Y4Z +nSZ0n+4q9QC9RLrWwYTwh838d5RVLUst2mBKSA+vn7YkqmBJbdBC6nkd7n7H +-----END CERTIFICATE----- +` From 0f9cda6d0bf63f8df6231d5ee41f610df4004f4f Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Mon, 12 Sep 2022 16:34:08 -0500 Subject: [PATCH 11/39] dead code --- sdk/helper/ocsp/client.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index f23ff7587f4a..fd523c35340c 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -484,8 +484,6 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. ret = ocspStatuses[i] ocspRes = ocspResponses[i] break - } else { - } } From fca24ee249e77e4726d652064c2f3963a2e88c8c Mon Sep 17 00:00:00 2001 From: Scott Miller Date: Tue, 13 Sep 2022 11:17:22 -0500 Subject: [PATCH 12/39] Update builtin/credential/cert/path_certs.go Co-authored-by: Alexander Scheel --- builtin/credential/cert/path_certs.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index ae1acae4a3de..1e4fee0ac83d 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -68,7 +68,7 @@ from the AuthorityInformationAccess extension on the certificate being inspected "ocsp_fail_open": { Type: framework.TypeBool, Default: false, - Description: "If set to true, if an OCSP revocation cannot be made successfully, login will proceed rather. If false, failing to get an OCSP status fails the request.", + Description: "If set to true, if an OCSP revocation cannot be made successfully, login will proceed rather than failing. If false, failing to get an OCSP status fails the request.", }, "ocsp_query_all_servers": { Type: framework.TypeBool, From 59a3769acaf3876ac5058cd5bc7f3917c5198a19 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 13 Sep 2022 11:27:32 -0500 Subject: [PATCH 13/39] Hook invalidate to reinit the ocsp cache size --- builtin/credential/cert/backend.go | 17 ++++++++++++++--- builtin/credential/cert/path_login.go | 8 +++++++- sdk/helper/ocsp/client.go | 8 ++------ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index a2b72e6d877e..bf35cf28abdb 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -20,8 +20,8 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, if err != nil { return nil, err } - if conf != nil { - b.initOCSPClient(bConf.OcspCacheSize) + if bConf != nil { + b.updatedConfig(bConf) } return b, nil } @@ -56,10 +56,11 @@ type backend struct { MapCertId *framework.PathMap crls map[string]CRLInfo - ocspDisabled bool + ocspEnabled bool crlUpdateMutex *sync.RWMutex ocspClientMutex sync.RWMutex ocspClient *ocsp.Client + configUpdated bool } func (b *backend) invalidate(_ context.Context, key string) { @@ -68,6 +69,8 @@ func (b *backend) invalidate(_ context.Context, key string) { b.crlUpdateMutex.Lock() defer b.crlUpdateMutex.Unlock() b.crls = nil + case key == "config": + b.configUpdated = true } } @@ -79,6 +82,14 @@ func (b *backend) initOCSPClient(cacheSize int) { }, cacheSize) } +func (b *backend) updatedConfig(config *config) error { + if config != nil { + b.initOCSPClient(config.OcspCacheSize) + } + b.configUpdated = false + return nil +} + const backendHelp = ` The "cert" credential provider allows authentication using TLS client certificates. A client connects to Vault and uses diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 3906004b4cd5..ad287fa8308f 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -82,6 +82,9 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra if err != nil { return nil, err } + if b.configUpdated { + b.updatedConfig(config) + } var matched *ParsedCert if verifyResp, resp, err := b.verifyCredentials(ctx, req, data); err != nil { @@ -155,6 +158,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f if err != nil { return nil, err } + if b.configUpdated { + b.updatedConfig(config) + } if !config.DisableBinding { var matched *ParsedCert @@ -568,7 +574,7 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, } func (b *backend) checkForChainInOCSP(ctx context.Context, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) { - if b.ocspDisabled || len(chain) < 2 { + if !b.ocspEnabled || len(chain) < 2 { return true, nil } b.ocspClientMutex.RLock() diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index fd523c35340c..cb55898c7b1d 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -17,7 +17,7 @@ import ( "github.com/hashicorp/go-retryablehttp" lru "github.com/hashicorp/golang-lru" "golang.org/x/crypto/ocsp" - "io/ioutil" + "io" "math/big" "net" "net/http" @@ -172,10 +172,6 @@ func durationMax(a, b time.Duration) time.Duration { return b } -func durationMin(a, b time.Duration) time.Duration { - return durationMax(b, a) -} - // isInValidityRange checks the validity func isInValidityRange(currTime, thisUpdate, nextUpdate time.Time) bool { if currTime.Sub(thisUpdate.Add(-maxClockSkew)) < 0 { @@ -375,7 +371,7 @@ func (c *Client) retryOCSP( if res.StatusCode != http.StatusOK { return nil, nil, nil, fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status) } - ocspResBytes, err = ioutil.ReadAll(res.Body) + ocspResBytes, err = io.ReadAll(res.Body) if err != nil { return nil, nil, nil, err } From ea1feb4650ee417c9ef9b8d949e7abbfc4944f34 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 13 Sep 2022 11:29:37 -0500 Subject: [PATCH 14/39] locking --- builtin/credential/cert/backend.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index bf35cf28abdb..f0ada42f691f 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -70,19 +70,22 @@ func (b *backend) invalidate(_ context.Context, key string) { defer b.crlUpdateMutex.Unlock() b.crls = nil case key == "config": + // Is this really necessary? + b.ocspClientMutex.Lock() + defer b.ocspClientMutex.Unlock() b.configUpdated = true } } func (b *backend) initOCSPClient(cacheSize int) { - b.ocspClientMutex.Lock() - defer b.ocspClientMutex.Unlock() b.ocspClient = ocsp.New(func() hclog.Logger { return b.Logger() }, cacheSize) } func (b *backend) updatedConfig(config *config) error { + b.ocspClientMutex.Lock() + defer b.ocspClientMutex.Unlock() if config != nil { b.initOCSPClient(config.OcspCacheSize) } From e7d790da23d620d4a689c43693bd5eb4ed6979e7 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 13 Sep 2022 12:56:04 -0500 Subject: [PATCH 15/39] Conditionally init the ocsp client --- builtin/credential/cert/backend.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index f0ada42f691f..77a34ac879b6 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -87,7 +87,11 @@ func (b *backend) updatedConfig(config *config) error { b.ocspClientMutex.Lock() defer b.ocspClientMutex.Unlock() if config != nil { - b.initOCSPClient(config.OcspCacheSize) + if b.ocspEnabled { + b.initOCSPClient(config.OcspCacheSize) + } else { + b.ocspClient = nil + } } b.configUpdated = false return nil From 75244c912435c1b44bea7f711f5934303bfaccbd Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 13 Sep 2022 13:01:12 -0500 Subject: [PATCH 16/39] Remove cache size config from cert configs, it's a backend global --- builtin/credential/cert/path_certs.go | 4 ---- builtin/credential/cert/path_config.go | 8 +++++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index 1e4fee0ac83d..c63c6e4caf99 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -341,9 +341,6 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr if ocspQueryAll, ok := d.GetOk("ocsp_query_all_servers"); ok { cert.OcspQueryAllServers = ocspQueryAll.(bool) } - if ocspCacheSize, ok := d.GetOk("ocsp_cache_size"); ok { - cert.OcspCacheSize = ocspCacheSize.(int) - } if displayNameRaw, ok := d.GetOk("display_name"); ok { cert.DisplayName = displayNameRaw.(string) } @@ -478,7 +475,6 @@ type CertEntry struct { OcspEnabled bool OcspServersOverride []string OcspFailOpen bool - OcspCacheSize int OcspQueryAllServers bool DisplayName string Policies []string diff --git a/builtin/credential/cert/path_config.go b/builtin/credential/cert/path_config.go index 0702ff97f356..c837b79f4399 100644 --- a/builtin/credential/cert/path_config.go +++ b/builtin/credential/cert/path_config.go @@ -35,11 +35,12 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat disableBinding := data.Get("disable_binding").(bool) enableIdentityAliasMetadata := data.Get("enable_identity_alias_metadata").(bool) cacheSize := data.Get("ocsp_cache_size").(int) - entry, err := logical.StorageEntryJSON("config", config{ + config := config{ DisableBinding: disableBinding, EnableIdentityAliasMetadata: enableIdentityAliasMetadata, OcspCacheSize: cacheSize, - }) + } + entry, err := logical.StorageEntryJSON("config", config) if err != nil { return nil, err } @@ -47,7 +48,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat if err := req.Storage.Put(ctx, entry); err != nil { return nil, err } - b.initOCSPClient(cacheSize) + b.updatedConfig(&config) return nil, nil } @@ -60,6 +61,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f data := map[string]interface{}{ "disable_binding": cfg.DisableBinding, "enable_identity_alias_metadata": cfg.EnableIdentityAliasMetadata, + "ocsp_cache_size": cfg.OcspCacheSize, } return &logical.Response{ From 68c330857995acbdff7f468577c91a0ae61bfb41 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 13 Sep 2022 13:02:19 -0500 Subject: [PATCH 17/39] Add field --- builtin/credential/cert/path_config.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/builtin/credential/cert/path_config.go b/builtin/credential/cert/path_config.go index c837b79f4399..68205e134630 100644 --- a/builtin/credential/cert/path_config.go +++ b/builtin/credential/cert/path_config.go @@ -22,6 +22,11 @@ func pathConfig(b *backend) *framework.Path { Default: false, Description: `If set, metadata of the certificate including the metadata corresponding to allowed_metadata_extensions will be stored in the alias. Defaults to false.`, }, + "ocsp_cache_size": { + Type: framework.TypeInt, + Default: 100, + Description: `The size of the in memory OCSP response cache, shared by all configured certs`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ From acc2c2979bfdbf3ecd5a7dfbaf37576a6687e4a0 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 13 Sep 2022 15:19:05 -0500 Subject: [PATCH 18/39] Remove strangely complex validity logic --- sdk/helper/ocsp/client.go | 29 ++------------------ sdk/helper/ocsp/ocsp_test.go | 53 ------------------------------------ 2 files changed, 3 insertions(+), 79 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index cb55898c7b1d..614a4a46cdeb 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -72,11 +72,6 @@ const ( cacheExpire = float64(24 * 60 * 60) ) -const ( - tolerableValidityRatio = 100 // buffer for certificate revocation update time - maxClockSkew = 900 * time.Second // buffer for clock skew -) - type ocspCachedResponse struct { time float64 producedAt float64 @@ -160,27 +155,9 @@ func (c *Client) getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto return crypto.SHA1 } -// calcTolerableValidity returns the maximum validity buffer -func calcTolerableValidity(thisUpdate, nextUpdate time.Time) time.Duration { - return durationMax(nextUpdate.Sub(thisUpdate)/tolerableValidityRatio, maxClockSkew) -} - -func durationMax(a, b time.Duration) time.Duration { - if a > b { - return a - } - return b -} - // isInValidityRange checks the validity -func isInValidityRange(currTime, thisUpdate, nextUpdate time.Time) bool { - if currTime.Sub(thisUpdate.Add(-maxClockSkew)) < 0 { - return false - } - if nextUpdate.Add(calcTolerableValidity(thisUpdate, nextUpdate)).Sub(currTime) < 0 { - return false - } - return true +func isInValidityRange(currTime, nextUpdate time.Time) bool { + return !currTime.After(nextUpdate) } func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { @@ -289,7 +266,7 @@ func validateOCSP(ocspRes *ocsp.Response) (*ocspStatus, error) { if ocspRes == nil { return nil, errors.New("OCSP Response is nil") } - if !isInValidityRange(curTime, ocspRes.ThisUpdate, ocspRes.NextUpdate) { + if !isInValidityRange(curTime, ocspRes.NextUpdate) { return &ocspStatus{ code: ocspInvalidValidity, err: fmt.Errorf("invalid validity: producedAt: %v, thisUpdate: %v, nextUpdate: %v", ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate), diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 7c34ac3c96e7..19a9f0a9faba 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -110,59 +110,6 @@ func TestMultiOCSP(t *testing.T) { } } -type tcValidityRange struct { - thisTime time.Time - nextTime time.Time - ret bool -} - -func TestUnitIsInValidityRange(t *testing.T) { - currentTime := time.Now() - testcases := []tcValidityRange{ - { - // basic tests - thisTime: currentTime.Add(-100 * time.Second), - nextTime: currentTime.Add(maxClockSkew), - ret: true, - }, - { - // on the border - thisTime: currentTime.Add(maxClockSkew), - nextTime: currentTime.Add(maxClockSkew), - ret: true, - }, - { - // 1 earlier late - thisTime: currentTime.Add(maxClockSkew + 1*time.Second), - nextTime: currentTime.Add(maxClockSkew), - ret: false, - }, - { - // on the border - thisTime: currentTime.Add(-maxClockSkew), - nextTime: currentTime.Add(-maxClockSkew), - ret: true, - }, - { - // around the border - thisTime: currentTime.Add(-24*time.Hour - 40*time.Second), - nextTime: currentTime.Add(-24*time.Hour/time.Duration(100) - 40*time.Second), - ret: false, - }, - { - // on the border - thisTime: currentTime.Add(-48*time.Hour - 29*time.Minute), - nextTime: currentTime.Add(-48 * time.Hour / time.Duration(100)), - ret: true, - }, - } - for _, tc := range testcases { - if tc.ret != isInValidityRange(currentTime, tc.thisTime, tc.nextTime) { - t.Fatalf("failed to check validity. should be: %v, currentTime: %v, thisTime: %v, nextTime: %v", tc.ret, currentTime, tc.thisTime, tc.nextTime) - } - } -} - func TestUnitEncodeCertIDGood(t *testing.T) { targetURLs := []string{ "faketestaccount.snowflakecomputing.com:443", From 81e8d6aca30574cc30004126de9cb9829be58c85 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 13 Sep 2022 17:10:27 -0500 Subject: [PATCH 19/39] Address more feedback --- builtin/credential/cert/path_certs.go | 7 --- builtin/credential/cert/path_config.go | 28 +++++++---- builtin/credential/cert/path_login.go | 6 +-- sdk/helper/ocsp/client.go | 64 ++++++++++++++++++++++---- 4 files changed, 78 insertions(+), 27 deletions(-) diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index c63c6e4caf99..ff186ad73d75 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -13,8 +13,6 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -const defaultOCSPCacheSize = 100 - func pathListCerts(b *backend) *framework.Path { return &framework.Path{ Pattern: "certs/?", @@ -75,11 +73,6 @@ from the AuthorityInformationAccess extension on the certificate being inspected Default: false, Description: "If set to true, rather than accepting the first successful OCSP response, query all servers and consider the certificate valid only if all servers agree.", }, - "ocsp_cache_size": { - Type: framework.TypeInt, - Default: defaultOCSPCacheSize, - Description: "The size of the OCSP response cache.", - }, "allowed_names": { Type: framework.TypeCommaStringSlice, Description: `A comma-separated list of names. diff --git a/builtin/credential/cert/path_config.go b/builtin/credential/cert/path_config.go index 68205e134630..53c7878c289e 100644 --- a/builtin/credential/cert/path_config.go +++ b/builtin/credential/cert/path_config.go @@ -8,6 +8,8 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +const maxCacheSize = 100000 + func pathConfig(b *backend) *framework.Path { return &framework.Path{ Pattern: "config", @@ -37,13 +39,23 @@ func pathConfig(b *backend) *framework.Path { } func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - disableBinding := data.Get("disable_binding").(bool) - enableIdentityAliasMetadata := data.Get("enable_identity_alias_metadata").(bool) - cacheSize := data.Get("ocsp_cache_size").(int) - config := config{ - DisableBinding: disableBinding, - EnableIdentityAliasMetadata: enableIdentityAliasMetadata, - OcspCacheSize: cacheSize, + config, err := b.Config(ctx, req.Storage) + if err != nil { + return nil, err + } + + if disableBindingRaw, ok := data.GetOk("disable_binding"); ok { + config.DisableBinding = disableBindingRaw.(bool) + } + if enableIdentityAliasMetadataRaw, ok := data.GetOk("enable_identity_alias_metadata"); ok { + config.EnableIdentityAliasMetadata = enableIdentityAliasMetadataRaw.(bool) + } + if cacheSizeRaw, ok := data.GetOk("ocsp_cache_size"); ok { + cacheSize := cacheSizeRaw.(int) + if cacheSize < 2 || cacheSize > maxCacheSize { + return logical.ErrorResponse("invalid cache size, must be >= 2 and <= %d", maxCacheSize), nil + } + config.OcspCacheSize = cacheSize } entry, err := logical.StorageEntryJSON("config", config) if err != nil { @@ -53,7 +65,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat if err := req.Storage.Put(ctx, entry); err != nil { return nil, err } - b.updatedConfig(&config) + b.updatedConfig(config) return nil, nil } diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index ad287fa8308f..4c67ffc51aba 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -312,7 +312,7 @@ func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certi b.matchesOrganizationalUnits(clientCert, config) && b.matchesCertificateExtensions(clientCert, config) if config.Entry.OcspEnabled { - ocspGood, err := b.checkForChainInOCSP(ctx, trustedChain, conf) + ocspGood, err := b.checkForCertInOCSP(ctx, clientCert, trustedChain, conf) if err != nil { return false, err } @@ -573,13 +573,13 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, return } -func (b *backend) checkForChainInOCSP(ctx context.Context, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) { +func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certificate, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) { if !b.ocspEnabled || len(chain) < 2 { return true, nil } b.ocspClientMutex.RLock() defer b.ocspClientMutex.RUnlock() - err := b.ocspClient.VerifyPeerCertificate(ctx, [][]*x509.Certificate{chain}, conf) + err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[0], conf) if err != nil { return false, nil } diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 614a4a46cdeb..1c36a199b1ce 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-retryablehttp" lru "github.com/hashicorp/golang-lru" + "github.com/hashicorp/vault/sdk/helper/certutil" "golang.org/x/crypto/ocsp" "io" "math/big" @@ -23,6 +24,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "sync" "time" ) @@ -157,7 +159,7 @@ func (c *Client) getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto // isInValidityRange checks the validity func isInValidityRange(currTime, nextUpdate time.Time) bool { - return !currTime.After(nextUpdate) + return !nextUpdate.IsZero() && !currTime.After(nextUpdate) } func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { @@ -362,8 +364,8 @@ func (c *Client) retryOCSP( }, nil } -// getRevocationStatus checks the certificate revocation status for subject using issuer certificate. -func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, conf *VerifyConfig) (*ocspStatus, error) { +// GetRevocationStatus checks the certificate revocation status for subject using issuer certificate. +func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, conf *VerifyConfig) (*ocspStatus, error) { status, ocspReq, encodedCertID, err := c.ValidateWithCache(subject, issuer) if err != nil { return nil, err @@ -450,16 +452,38 @@ func (c *Client) getRevocationStatus(ctx context.Context, subject, issuer *x509. // Good by default var ret *ocspStatus = &ocspStatus{code: ocspStatusGood} ocspRes := ocspResponses[0] + var firstError error + foundRevocation := false for i, _ := range ocspHosts { if errors[i] != nil { - return nil, errors[i] - } else if ocspStatuses[i] != nil && (!conf.QueryAllServers || ocspStatuses[i].code != ocspStatusGood) { - ret = ocspStatuses[i] - ocspRes = ocspResponses[i] - break + if firstError == nil { + firstError = errors[i] + } + } else if ocspStatuses[i] != nil { + switch ocspStatuses[i].code { + case ocspStatusRevoked: + foundRevocation = true + ret = ocspStatuses[i] + ocspRes = ocspResponses[i] + break + case ocspStatusGood: + //continue + case ocspStatusUnknown: + if !conf.QueryAllServers { + // We may want to use this as the overall result + ret = ocspStatuses[i] + ocspRes = ocspResponses[i] + } + } } } + // If no server reported the cert revoked, but we did have an error, report it + if !foundRevocation && firstError != nil { + return nil, firstError + } + // otherwise ret should contain a response for the overall request + if !isValidOCSPStatus(ret.code) { return ret, nil } @@ -488,6 +512,28 @@ type VerifyConfig struct { QueryAllServers bool } +// VerifyLeafCertificate verifies just the subject against it's direct issuer +func (c *Client) VerifyLeafCertificate(ctx context.Context, subject, issuer *x509.Certificate, conf *VerifyConfig) error { + results, err := c.GetRevocationStatus(ctx, subject, issuer, conf) + if err != nil { + return err + } + if results.code == ocspStatusGood { + return nil + } else { + serial := issuer.SerialNumber + serialHex := strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":")) + if results.code == ocspStatusRevoked { + return fmt.Errorf("certificate with serial number %s has been revoked", serialHex) + } else if conf.OcspFailureMode == FailOpenFalse { + return fmt.Errorf("unknown OCSP status for cert with serial number %s", strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":"))) + } else { + c.Logger().Warn("could not validate OCSP status for cert, but continuing in fail open mode", "serial", serialHex) + } + } + return nil +} + // VerifyPeerCertificate verifies all of certificate revocation status func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate, conf *VerifyConfig) error { for i := 0; i < len(verifiedChains); i++ { @@ -589,7 +635,7 @@ func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains []*x n := len(verifiedChains) - 1 results := make([]*ocspStatus, n) for j := 0; j < n; j++ { - results[j], err = c.getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1], conf) + results[j], err = c.GetRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1], conf) if err != nil { return nil, err } From ca7e8c0d4a3923e280623c08781bc5ea6bf306a5 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Wed, 14 Sep 2022 09:40:02 -0500 Subject: [PATCH 20/39] Rework error returning logic --- sdk/helper/ocsp/client.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 1c36a199b1ce..3a4763352a23 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -453,7 +453,6 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. var ret *ocspStatus = &ocspStatus{code: ocspStatusGood} ocspRes := ocspResponses[0] var firstError error - foundRevocation := false for i, _ := range ocspHosts { if errors[i] != nil { if firstError == nil { @@ -462,7 +461,6 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. } else if ocspStatuses[i] != nil { switch ocspStatuses[i].code { case ocspStatusRevoked: - foundRevocation = true ret = ocspStatuses[i] ocspRes = ocspResponses[i] break @@ -479,7 +477,7 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. } // If no server reported the cert revoked, but we did have an error, report it - if !foundRevocation && firstError != nil { + if (ret == nil || ret.code == ocspStatusUnknown) && firstError != nil { return nil, firstError } // otherwise ret should contain a response for the overall request From b5f2b03eb28e16cc88908ef588707d7352110c73 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Wed, 14 Sep 2022 10:23:43 -0500 Subject: [PATCH 21/39] More edge cases --- sdk/helper/ocsp/client.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 3a4763352a23..dac6ec3f935a 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -465,9 +465,13 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. ocspRes = ocspResponses[i] break case ocspStatusGood: - //continue + // Use this response only if we + if ret == nil { + ret = ocspStatuses[i] + ocspRes = ocspResponses[i] + } case ocspStatusUnknown: - if !conf.QueryAllServers { + if !conf.QueryAllServers && ret == nil { // We may want to use this as the overall result ret = ocspStatuses[i] ocspRes = ocspResponses[i] From 6ba87f9a494d5c66b63c4165c27b824f0e77b28f Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Wed, 14 Sep 2022 10:28:01 -0500 Subject: [PATCH 22/39] MORE edge cases --- sdk/helper/ocsp/client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index dac6ec3f935a..505eec7eddd9 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -450,7 +450,7 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. wg.Wait() } // Good by default - var ret *ocspStatus = &ocspStatus{code: ocspStatusGood} + var ret *ocspStatus ocspRes := ocspResponses[0] var firstError error for i, _ := range ocspHosts { @@ -465,13 +465,13 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. ocspRes = ocspResponses[i] break case ocspStatusGood: - // Use this response only if we - if ret == nil { + // Use this response only if we don't have a status already, or if what we have was unknown + if ret == nil || ret.code == ocspStatusUnknown { ret = ocspStatuses[i] ocspRes = ocspResponses[i] } case ocspStatusUnknown: - if !conf.QueryAllServers && ret == nil { + if ret == nil { // We may want to use this as the overall result ret = ocspStatuses[i] ocspRes = ocspResponses[i] From 51db45fdf8a9125ca7b56a2632b96ec196d81bfb Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Thu, 3 Nov 2022 11:33:29 -0500 Subject: [PATCH 23/39] Add a test matrix with a builtin responder --- builtin/credential/cert/backend.go | 24 +- builtin/credential/cert/backend_test.go | 57 +++-- builtin/credential/cert/path_certs.go | 110 ++++----- builtin/credential/cert/path_config.go | 8 +- builtin/credential/cert/path_login.go | 7 +- builtin/credential/cert/path_login_test.go | 109 ++++++++- builtin/credential/cert/responder.go | 265 +++++++++++++++++++++ sdk/helper/ocsp/client.go | 22 +- sdk/helper/ocsp/ocsp_test.go | 6 +- 9 files changed, 501 insertions(+), 107 deletions(-) create mode 100644 builtin/credential/cert/responder.go diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index b2c72d7f3498..b777bc175dda 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -11,9 +11,9 @@ import ( "time" "github.com/hashicorp/go-hclog" - "github.com/hashicorp/vault/sdk/helper/ocsp" "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/ocsp" "github.com/hashicorp/vault/sdk/logical" ) @@ -66,7 +66,6 @@ type backend struct { MapCertId *framework.PathMap crls map[string]CRLInfo - ocspEnabled bool crlUpdateMutex *sync.RWMutex ocspClientMutex sync.RWMutex ocspClient *ocsp.Client @@ -96,13 +95,7 @@ func (b *backend) initOCSPClient(cacheSize int) { func (b *backend) updatedConfig(config *config) error { b.ocspClientMutex.Lock() defer b.ocspClientMutex.Unlock() - if config != nil { - if b.ocspEnabled { - b.initOCSPClient(config.OcspCacheSize) - } else { - b.ocspClient = nil - } - } + b.initOCSPClient(config.OcspCacheSize) b.configUpdated = false return nil } @@ -141,6 +134,19 @@ func (b *backend) updateCRLs(ctx context.Context, req *logical.Request) error { return errs.ErrorOrNil() } +func (b *backend) storeConfig(ctx context.Context, storage logical.Storage, config *config) error { + entry, err := logical.StorageEntryJSON("config", config) + if err != nil { + return err + } + + if err := storage.Put(ctx, entry); err != nil { + return err + } + b.updatedConfig(config) + return nil +} + const backendHelp = ` The "cert" credential provider allows authentication using TLS client certificates. A client connects to Vault and uses diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 062fc156bc7a..2c5bae6aed03 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -1062,12 +1062,13 @@ func TestBackend_CRLs(t *testing.T) { } func testFactory(t *testing.T) logical.Backend { + storage := &logical.InmemStorage{} b, err := Factory(context.Background(), &logical.BackendConfig{ System: &logical.StaticSystemView{ DefaultLeaseTTLVal: 1000 * time.Second, MaxLeaseTTLVal: 1800 * time.Second, }, - StorageView: &logical.InmemStorage{}, + StorageView: storage, }) if err != nil { t.Fatalf("error: %s", err) @@ -1475,7 +1476,7 @@ func TestBackend_mixed_constraints(t *testing.T) { testAccStepCert(t, "3invalid", ca, "foo", allowed{names: "invalid"}, false), testAccStepLogin(t, connState), // Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match - testAccStepLoginWithName(t, connState, "2matching"), + testAccStepLoginWithName(t, connState, "2matching", false), testAccStepLoginWithNameInvalid(t, connState, "3invalid"), }, }) @@ -1719,16 +1720,22 @@ func testAccStepReadConfig(t *testing.T, conf config, connState tls.ConnectionSt } func testAccStepLogin(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep { - return testAccStepLoginWithName(t, connState, "") + return testAccStepLoginWithName(t, connState, "", false) } -func testAccStepLoginWithName(t *testing.T, connState tls.ConnectionState, certName string) logicaltest.TestStep { +func testAccStepLoginWithName(t *testing.T, connState tls.ConnectionState, certName string, errExpected bool) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "login", Unauthenticated: true, ConnState: &connState, Check: func(resp *logical.Response) error { + if errExpected { + if !resp.IsError() { + t.Fatalf("expected error") + } + return nil + } if resp.Auth.TTL != 1000*time.Second { t.Fatalf("bad lease length: %#v", resp.Auth) } @@ -1743,6 +1750,7 @@ func testAccStepLoginWithName(t *testing.T, connState tls.ConnectionState, certN Data: map[string]interface{}{ "name": certName, }, + ErrorOk: errExpected, } } @@ -1893,27 +1901,34 @@ type allowed struct { metadata_ext string // allowed metadata extensions to add to identity alias } -func testAccStepCert( - t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool, -) logicaltest.TestStep { +func testAccStepCert(t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool) logicaltest.TestStep { + return testAccStepCertWithExtraParams(t, name, cert, policies, testData, expectError, nil) +} + +func testAccStepCertWithExtraParams(t *testing.T, name string, cert []byte, policies string, testData allowed, expectError bool, extraParams map[string]interface{}) logicaltest.TestStep { + data := map[string]interface{}{ + "certificate": string(cert), + "policies": policies, + "display_name": name, + "allowed_names": testData.names, + "allowed_common_names": testData.common_names, + "allowed_dns_sans": testData.dns, + "allowed_email_sans": testData.emails, + "allowed_uri_sans": testData.uris, + "allowed_organizational_units": testData.organizational_units, + "required_extensions": testData.ext, + "allowed_metadata_extensions": testData.metadata_ext, + "lease": 1000, + "ocsp_enabled": true, + } + for k, v := range extraParams { + data[k] = v + } return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "certs/" + name, ErrorOk: expectError, - Data: map[string]interface{}{ - "certificate": string(cert), - "policies": policies, - "display_name": name, - "allowed_names": testData.names, - "allowed_common_names": testData.common_names, - "allowed_dns_sans": testData.dns, - "allowed_email_sans": testData.emails, - "allowed_uri_sans": testData.uris, - "allowed_organizational_units": testData.organizational_units, - "required_extensions": testData.ext, - "allowed_metadata_extensions": testData.metadata_ext, - "lease": 1000, - }, + Data: data, Check: func(resp *logical.Response) error { if resp == nil && expectError { return fmt.Errorf("expected error but received nil") diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index 17b206f52a9c..7184435ef56e 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -4,10 +4,11 @@ import ( "context" "crypto/x509" "fmt" - "github.com/hashicorp/go-sockaddr" "strings" "time" + "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" @@ -429,63 +430,63 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr return logical.ErrorResponse("failed to parse certificate"), nil } - // If the certificate is not a CA cert, then ensure that x509.ExtKeyUsageClientAuth is set - if !parsed[0].IsCA && parsed[0].ExtKeyUsage != nil { - var clientAuth bool - for _, usage := range parsed[0].ExtKeyUsage { - if usage == x509.ExtKeyUsageClientAuth || usage == x509.ExtKeyUsageAny { - clientAuth = true - break - } - } - if !clientAuth { - return logical.ErrorResponse("nonCA certificates should have TLS client authentication set as an extended key usage"), nil - } - } - - // Store it - entry, err := logical.StorageEntryJSON("cert/"+name, cert) - if err != nil { - return nil, err - } - if err := req.Storage.Put(ctx, entry); err != nil { - return nil, err - } - - if len(resp.Warnings) == 0 { - return nil, nil - } - - return &resp, nil + // If the certificate is not a CA cert, then ensure that x509.ExtKeyUsageClientAuth is set + if !parsed[0].IsCA && parsed[0].ExtKeyUsage != nil { + var clientAuth bool + for _, usage := range parsed[0].ExtKeyUsage { + if usage == x509.ExtKeyUsageClientAuth || usage == x509.ExtKeyUsageAny { + clientAuth = true + break + } + } + if !clientAuth { + return logical.ErrorResponse("nonCA certificates should have TLS client authentication set as an extended key usage"), nil + } } -type CertEntry struct { - tokenutil.TokenParams - - Name string - Certificate string - DisplayName string - Policies []string - TTL time.Duration - MaxTTL time.Duration - Period time.Duration - AllowedNames []string - AllowedCommonNames []string - AllowedDNSSANs []string - AllowedEmailSANs []string - AllowedURISANs []string - AllowedOrganizationalUnits []string - RequiredExtensions []string - AllowedMetadataExtensions []string - BoundCIDRs []*sockaddr.SockAddrMarshaler - - OcspCaCertificates string - OcspEnabled bool - OcspServersOverride []string - OcspFailOpen bool - OcspQueryAllServers bool + // Store it + entry, err := logical.StorageEntryJSON("cert/"+name, cert) + if err != nil { + return nil, err + } + if err := req.Storage.Put(ctx, entry); err != nil { + return nil, err } + if len(resp.Warnings) == 0 { + return nil, nil + } + + return &resp, nil +} + +type CertEntry struct { + tokenutil.TokenParams + + Name string + Certificate string + DisplayName string + Policies []string + TTL time.Duration + MaxTTL time.Duration + Period time.Duration + AllowedNames []string + AllowedCommonNames []string + AllowedDNSSANs []string + AllowedEmailSANs []string + AllowedURISANs []string + AllowedOrganizationalUnits []string + RequiredExtensions []string + AllowedMetadataExtensions []string + BoundCIDRs []*sockaddr.SockAddrMarshaler + + OcspCaCertificates string + OcspEnabled bool + OcspServersOverride []string + OcspFailOpen bool + OcspQueryAllServers bool +} + const pathCertHelpSyn = ` Manage trusted certificates used for authentication. ` @@ -498,3 +499,4 @@ Deleting a certificate will not revoke auth for prior authenticated connections. To do this, do a revoke on "login". If you don't need to revoke login immediately, then the next renew will cause the lease to expire. +` diff --git a/builtin/credential/cert/path_config.go b/builtin/credential/cert/path_config.go index 53c7878c289e..c08992af15c4 100644 --- a/builtin/credential/cert/path_config.go +++ b/builtin/credential/cert/path_config.go @@ -57,15 +57,9 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat } config.OcspCacheSize = cacheSize } - entry, err := logical.StorageEntryJSON("config", config) - if err != nil { - return nil, err - } - - if err := req.Storage.Put(ctx, entry); err != nil { + if err := b.storeConfig(ctx, req.Storage, config); err != nil { return nil, err } - b.updatedConfig(config) return nil, nil } diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index bbfd9652ac6b..359396888f66 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -10,9 +10,10 @@ import ( "encoding/pem" "errors" "fmt" - "github.com/hashicorp/vault/sdk/helper/ocsp" "strings" + "github.com/hashicorp/vault/sdk/helper/ocsp" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/policyutil" @@ -587,12 +588,12 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, } func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certificate, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) { - if !b.ocspEnabled || len(chain) < 2 { + if len(chain) < 2 { return true, nil } b.ocspClientMutex.RLock() defer b.ocspClientMutex.RUnlock() - err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[0], conf) + err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[len(chain)-1], conf) if err != nil { return false, nil } diff --git a/builtin/credential/cert/path_login_test.go b/builtin/credential/cert/path_login_test.go index a01ec981663f..2ca0a70b31f7 100644 --- a/builtin/credential/cert/path_login_test.go +++ b/builtin/credential/cert/path_login_test.go @@ -4,21 +4,39 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "fmt" "io/ioutil" "math/big" mathrand "math/rand" "net" + "net/http" "os" "path/filepath" "strings" "testing" "time" + "github.com/hashicorp/vault/sdk/helper/certutil" + + "golang.org/x/crypto/ocsp" + logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical" "github.com/hashicorp/vault/sdk/logical" ) +const ocspPort = 31808 + +var source InMemorySource + +func TestMain(m *testing.M) { + source = make(InMemorySource) + go func() { + http.ListenAndServe(fmt.Sprintf("localhost:%d", ocspPort), NewResponder(source, nil)) + }() + m.Run() +} + func TestCert_RoleResolve(t *testing.T) { certTemplate := &x509.Certificate{ Subject: pkix.Name{ @@ -52,7 +70,7 @@ func TestCert_RoleResolve(t *testing.T) { CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{dns: "example.com"}, false), - testAccStepLoginWithName(t, connState, "web"), + testAccStepLoginWithName(t, connState, "web", false), testAccStepResolveRoleWithName(t, connState, "web"), }, }) @@ -109,7 +127,7 @@ func TestCert_RoleResolveWithoutProvidingCertName(t *testing.T) { CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{dns: "example.com"}, false), - testAccStepLoginWithName(t, connState, "web"), + testAccStepLoginWithName(t, connState, "web", false), testAccStepResolveRoleWithEmptyDataMap(t, connState, "web"), }, }) @@ -192,8 +210,93 @@ func TestCert_RoleResolve_RoleDoesNotExist(t *testing.T) { CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{dns: "example.com"}, false), - testAccStepLoginWithName(t, connState, "web"), + testAccStepLoginWithName(t, connState, "web", false), testAccStepResolveRoleExpectRoleResolutionToFail(t, connState, "notweb"), }, }) } + +func TestCert_RoleResolveOCSP(t *testing.T) { + cases := []struct { + name string + failOpen bool + certStatus int + errExpected bool + }{ + {"failFalseGoodCert", false, ocsp.Good, false}, + {"failFalseRevokedCert", false, ocsp.Revoked, true}, + {"failFalseUnknownCert", false, ocsp.Unknown, true}, + {"failTrueGoodCert", true, ocsp.Good, false}, + {"failTrueRevokedCert", true, ocsp.Revoked, true}, + {"failTrueUnknownCert", true, ocsp.Unknown, false}, + } + certTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "example.com", + }, + DNSNames: []string{"example.com"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + OCSPServer: []string{fmt.Sprintf("http://localhost:%d", ocspPort)}, + } + tempDir, connState, err := generateTestCertAndConnState(t, certTemplate) + if tempDir != "" { + defer os.RemoveAll(tempDir) + } + if err != nil { + t.Fatalf("error testing connection state: %v", err) + } + ca, err := ioutil.ReadFile(filepath.Join(tempDir, "ca_cert.pem")) + if err != nil { + t.Fatalf("err: %v", err) + } + + issuer := parsePEM(ca) + pkf, err := ioutil.ReadFile(filepath.Join(tempDir, "ca_key.pem")) + if err != nil { + t.Fatalf("err: %v", err) + } + pk, err := certutil.ParsePEMBundle(string(pkf)) + if err != nil { + t.Fatalf("err: %v", err) + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + resp, err := ocsp.CreateResponse(issuer[0], issuer[0], ocsp.Response{ + Status: c.certStatus, + SerialNumber: certTemplate.SerialNumber, + ProducedAt: time.Now(), + ThisUpdate: time.Now(), + NextUpdate: time.Now().Add(time.Hour), + }, pk.PrivateKey) + if err != nil { + t.Fatal(err) + } + source[certTemplate.SerialNumber.String()] = resp + + b := testFactory(t) + b.(*backend).ocspClient.ClearCache() + logicaltest.Test(t, logicaltest.TestCase{ + CredentialBackend: b, + Steps: []logicaltest.TestStep{ + testAccStepCertWithExtraParams(t, "web", ca, "foo", allowed{dns: "example.com"}, false, + map[string]interface{}{"ocsp_fail_open": c.failOpen}), + testAccStepLoginWithName(t, connState, "web", c.errExpected), + testAccStepResolveRoleWithName(t, connState, "web"), + }, + }) + }) + } +} + +func serialFromBigInt(serial *big.Int) string { + return strings.TrimSpace(certutil.GetHexFormatted(serial.Bytes(), ":")) +} diff --git a/builtin/credential/cert/responder.go b/builtin/credential/cert/responder.go new file mode 100644 index 000000000000..5afc14b35a66 --- /dev/null +++ b/builtin/credential/cert/responder.go @@ -0,0 +1,265 @@ +// Package ocsp implements an OCSP responder based on a generic storage backend. +// It provides a couple of sample implementations. +// Because OCSP responders handle high query volumes, we have to be careful +// about how much logging we do. Error-level logs are reserved for problems +// internal to the server, that can be fixed by an administrator. Any type of +// incorrect input from a user should be logged and Info or below. For things +// that are logged on every request, Debug is the appropriate level. +package cert + +import ( + "crypto" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "golang.org/x/crypto/ocsp" +) + +var ( + malformedRequestErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x01} + internalErrorErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x02} + tryLaterErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x03} + sigRequredErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x05} + unauthorizedErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x06} + + // ErrNotFound indicates the request OCSP response was not found. It is used to + // indicate that the responder should reply with unauthorizedErrorResponse. + ErrNotFound = errors.New("Request OCSP Response not found") +) + +// Source represents the logical source of OCSP responses, i.e., +// the logic that actually chooses a response based on a request. In +// order to create an actual responder, wrap one of these in a Responder +// object and pass it to http.Handle. By default the Responder will set +// the headers Cache-Control to "max-age=(response.NextUpdate-now), public, no-transform, must-revalidate", +// Last-Modified to response.ThisUpdate, Expires to response.NextUpdate, +// ETag to the SHA256 hash of the response, and Content-Type to +// application/ocsp-response. If you want to override these headers, +// or set extra headers, your source should return a http.Header +// with the headers you wish to set. If you don't want to set any +// extra headers you may return nil instead. +type Source interface { + Response(*ocsp.Request) ([]byte, http.Header, error) +} + +// An InMemorySource is a map from serialNumber -> der(response) +type InMemorySource map[string][]byte + +// Response looks up an OCSP response to provide for a given request. +// InMemorySource looks up a response purely based on serial number, +// without regard to what issuer the request is asking for. +func (src InMemorySource) Response(request *ocsp.Request) ([]byte, http.Header, error) { + response, present := src[request.SerialNumber.String()] + if !present { + return nil, nil, ErrNotFound + } + return response, nil, nil +} + +// Stats is a basic interface that allows users to record information +// about returned responses +type Stats interface { + ResponseStatus(ocsp.ResponseStatus) +} + +// A Responder object provides the HTTP logic to expose a +// Source of OCSP responses. +type Responder struct { + Source Source + stats Stats +} + +// NewResponder instantiates a Responder with the give Source. +func NewResponder(source Source, stats Stats) *Responder { + return &Responder{ + Source: source, + stats: stats, + } +} + +func overrideHeaders(response http.ResponseWriter, headers http.Header) { + for k, v := range headers { + if len(v) == 1 { + response.Header().Set(k, v[0]) + } else if len(v) > 1 { + response.Header().Del(k) + for _, e := range v { + response.Header().Add(k, e) + } + } + } +} + +// hashToString contains mappings for the only hash functions +// x/crypto/ocsp supports +var hashToString = map[crypto.Hash]string{ + crypto.SHA1: "SHA1", + crypto.SHA256: "SHA256", + crypto.SHA384: "SHA384", + crypto.SHA512: "SHA512", +} + +// A Responder can process both GET and POST requests. The mapping +// from an OCSP request to an OCSP response is done by the Source; +// the Responder simply decodes the request, and passes back whatever +// response is provided by the source. +// Note: The caller must use http.StripPrefix to strip any path components +// (including '/') on GET requests. +// Do not use this responder in conjunction with http.NewServeMux, because the +// default handler will try to canonicalize path components by changing any +// strings of repeated '/' into a single '/', which will break the base64 +// encoding. +func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Request) { + // By default we set a 'max-age=0, no-cache' Cache-Control header, this + // is only returned to the client if a valid authorized OCSP response + // is not found or an error is returned. If a response if found the header + // will be altered to contain the proper max-age and modifiers. + response.Header().Add("Cache-Control", "max-age=0, no-cache") + // Read response from request + var requestBody []byte + var err error + switch request.Method { + case "GET": + base64Request, err := url.QueryUnescape(request.URL.Path) + if err != nil { + fmt.Printf("Error decoding URL: %s", request.URL.Path) + response.WriteHeader(http.StatusBadRequest) + return + } + // url.QueryUnescape not only unescapes %2B escaping, but it additionally + // turns the resulting '+' into a space, which makes base64 decoding fail. + // So we go back afterwards and turn ' ' back into '+'. This means we + // accept some malformed input that includes ' ' or %20, but that's fine. + base64RequestBytes := []byte(base64Request) + for i := range base64RequestBytes { + if base64RequestBytes[i] == ' ' { + base64RequestBytes[i] = '+' + } + } + // In certain situations a UA may construct a request that has a double + // slash between the host name and the base64 request body due to naively + // constructing the request URL. In that case strip the leading slash + // so that we can still decode the request. + if len(base64RequestBytes) > 0 && base64RequestBytes[0] == '/' { + base64RequestBytes = base64RequestBytes[1:] + } + requestBody, err = base64.StdEncoding.DecodeString(string(base64RequestBytes)) + if err != nil { + fmt.Printf("Error decoding base64 from URL: %s", string(base64RequestBytes)) + response.WriteHeader(http.StatusBadRequest) + return + } + case "POST": + requestBody, err = ioutil.ReadAll(request.Body) + if err != nil { + fmt.Printf("Problem reading body of POST: %s", err) + response.WriteHeader(http.StatusBadRequest) + return + } + default: + response.WriteHeader(http.StatusMethodNotAllowed) + return + } + b64Body := base64.StdEncoding.EncodeToString(requestBody) + fmt.Printf("Received OCSP request: %s", b64Body) + + // All responses after this point will be OCSP. + // We could check for the content type of the request, but that + // seems unnecessariliy restrictive. + response.Header().Add("Content-Type", "application/ocsp-response") + + // Parse response as an OCSP request + // XXX: This fails if the request contains the nonce extension. + // We don't intend to support nonces anyway, but maybe we + // should return unauthorizedRequest instead of malformed. + ocspRequest, err := ocsp.ParseRequest(requestBody) + if err != nil { + fmt.Printf("Error decoding request body: %s", b64Body) + response.WriteHeader(http.StatusBadRequest) + response.Write(malformedRequestErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.Malformed) + } + return + } + + // Look up OCSP response from source + ocspResponse, headers, err := rs.Source.Response(ocspRequest) + if err != nil { + if err == ErrNotFound { + fmt.Printf("No response found for request: serial %x, request body %s", + ocspRequest.SerialNumber, b64Body) + response.Write(unauthorizedErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.Unauthorized) + } + return + } + fmt.Printf("Error retrieving response for request: serial %x, request body %s, error: %s", + ocspRequest.SerialNumber, b64Body, err) + response.WriteHeader(http.StatusInternalServerError) + response.Write(internalErrorErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.InternalError) + } + return + } + + parsedResponse, err := ocsp.ParseResponse(ocspResponse, nil) + if err != nil { + fmt.Printf("Error parsing response for serial %x: %s", + ocspRequest.SerialNumber, err) + response.Write(internalErrorErrorResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.InternalError) + } + return + } + + // Write OCSP response to response + response.Header().Add("Last-Modified", parsedResponse.ThisUpdate.Format(time.RFC1123)) + response.Header().Add("Expires", parsedResponse.NextUpdate.Format(time.RFC1123)) + now := time.Now() + maxAge := 0 + if now.Before(parsedResponse.NextUpdate) { + maxAge = int(parsedResponse.NextUpdate.Sub(now) / time.Second) + } else { + // TODO(#530): we want max-age=0 but this is technically an authorized OCSP response + // (despite being stale) and 5019 forbids attaching no-cache + maxAge = 0 + } + response.Header().Set( + "Cache-Control", + fmt.Sprintf( + "max-age=%d, public, no-transform, must-revalidate", + maxAge, + ), + ) + responseHash := sha256.Sum256(ocspResponse) + response.Header().Add("ETag", fmt.Sprintf("\"%X\"", responseHash)) + + if headers != nil { + overrideHeaders(response, headers) + } + + // RFC 7232 says that a 304 response must contain the above + // headers if they would also be sent for a 200 for the same + // request, so we have to wait until here to do this + if etag := request.Header.Get("If-None-Match"); etag != "" { + if etag == fmt.Sprintf("\"%X\"", responseHash) { + response.WriteHeader(http.StatusNotModified) + return + } + } + response.WriteHeader(http.StatusOK) + response.Write(ocspResponse) + if rs.stats != nil { + rs.stats.ResponseStatus(ocsp.Success) + } +} diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 505eec7eddd9..f31dc05ac91c 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -13,11 +13,6 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/hashicorp/go-hclog" - "github.com/hashicorp/go-retryablehttp" - lru "github.com/hashicorp/golang-lru" - "github.com/hashicorp/vault/sdk/helper/certutil" - "golang.org/x/crypto/ocsp" "io" "math/big" "net" @@ -27,6 +22,12 @@ import ( "strings" "sync" "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" + lru "github.com/hashicorp/golang-lru" + "github.com/hashicorp/vault/sdk/helper/certutil" + "golang.org/x/crypto/ocsp" ) // FailOpenMode is OCSP fail open mode. FailOpenTrue by default and may @@ -147,13 +148,17 @@ func getOIDFromHashAlgorithm(target crypto.Hash) (asn1.ObjectIdentifier, error) return nil, fmt.Errorf("no valid OID is found for the hash algorithm: %v", target) } +func (c *Client) ClearCache() { + c.ocspResponseCache.Purge() +} + func (c *Client) getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto.Hash { for hash, oid := range hashOIDs { if oid.Equal(target.Algorithm) { return hash } } - //no valid hash algorithm is found for the oid. Falling back to SHA1 + // no valid hash algorithm is found for the oid. Falling back to SHA1 return crypto.SHA1 } @@ -453,7 +458,7 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. var ret *ocspStatus ocspRes := ocspResponses[0] var firstError error - for i, _ := range ocspHosts { + for i := range ocspHosts { if errors[i] != nil { if firstError == nil { firstError = errors[i] @@ -785,6 +790,9 @@ func (c *Client) readOCSPCache(ctx context.Context, storage logical.Storage) err */ func New(logFactory func() hclog.Logger, cacheSize int) *Client { + if cacheSize < 100 { + cacheSize = 100 + } cache, _ := lru.New2Q(cacheSize) c := Client{ caRoot: make(map[string]*x509.Certificate), diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 19a9f0a9faba..45ee45d3611f 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -11,9 +11,6 @@ import ( "encoding/pem" "errors" "fmt" - "github.com/hashicorp/go-hclog" - "github.com/hashicorp/go-retryablehttp" - lru "github.com/hashicorp/golang-lru" "io" "io/ioutil" "net" @@ -22,6 +19,9 @@ import ( "testing" "time" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" + lru "github.com/hashicorp/golang-lru" "golang.org/x/crypto/ocsp" ) From 0642d7e4a41c0756b49885d28e0df1153dbd4cea Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Thu, 3 Nov 2022 11:41:55 -0500 Subject: [PATCH 24/39] changelog --- changelog/17093.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changelog/17093.txt diff --git a/changelog/17093.txt b/changelog/17093.txt new file mode 100644 index 000000000000..a51f3de8ff4b --- /dev/null +++ b/changelog/17093.txt @@ -0,0 +1,3 @@ +```release-note:improvement +auth/cert: Add configurable support for validating client certs with OCSP. +``` \ No newline at end of file From 0ff724c490307b807007cb79df40b32f4eb26009 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 4 Nov 2022 09:34:51 -0500 Subject: [PATCH 25/39] Use an atomic for configUpdated --- builtin/credential/cert/backend.go | 10 ++++------ builtin/credential/cert/path_login.go | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/builtin/credential/cert/backend.go b/builtin/credential/cert/backend.go index b777bc175dda..53ebc9d74834 100644 --- a/builtin/credential/cert/backend.go +++ b/builtin/credential/cert/backend.go @@ -8,6 +8,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-hclog" @@ -69,7 +70,7 @@ type backend struct { crlUpdateMutex *sync.RWMutex ocspClientMutex sync.RWMutex ocspClient *ocsp.Client - configUpdated bool + configUpdated atomic.Bool } func (b *backend) invalidate(_ context.Context, key string) { @@ -79,10 +80,7 @@ func (b *backend) invalidate(_ context.Context, key string) { defer b.crlUpdateMutex.Unlock() b.crls = nil case key == "config": - // Is this really necessary? - b.ocspClientMutex.Lock() - defer b.ocspClientMutex.Unlock() - b.configUpdated = true + b.configUpdated.Store(true) } } @@ -96,7 +94,7 @@ func (b *backend) updatedConfig(config *config) error { b.ocspClientMutex.Lock() defer b.ocspClientMutex.Unlock() b.initOCSPClient(config.OcspCacheSize) - b.configUpdated = false + b.configUpdated.Store(false) return nil } diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 359396888f66..75ffed0982b6 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -83,7 +83,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra if err != nil { return nil, err } - if b.configUpdated { + if b.configUpdated.Load() { b.updatedConfig(config) } @@ -166,7 +166,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f if err != nil { return nil, err } - if b.configUpdated { + if b.configUpdated.Load() { b.updatedConfig(config) } From 7e50dd0f293308d1cc3f67bfacc3386cd85c7a97 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 4 Nov 2022 11:13:55 -0500 Subject: [PATCH 26/39] Actually use ocsp_enabled, and bind to a random port for testing --- builtin/credential/cert/backend_test.go | 1 - builtin/credential/cert/path_login.go | 17 ++++++++++------- builtin/credential/cert/path_login_test.go | 18 +++++++++++++++--- sdk/helper/ocsp/client.go | 1 + 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 2c5bae6aed03..d0504df25133 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -1919,7 +1919,6 @@ func testAccStepCertWithExtraParams(t *testing.T, name string, cert []byte, poli "required_extensions": testData.ext, "allowed_metadata_extensions": testData.metadata_ext, "lease": 1000, - "ocsp_enabled": true, } for k, v := range extraParams { data[k] = v diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 75ffed0982b6..df39b52224ae 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -576,19 +576,22 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, Certificates: parsed, }) } - conf.OcspServersOverride = append(conf.OcspServersOverride, entry.OcspServersOverride...) - if entry.OcspFailOpen { - conf.OcspFailureMode = ocsp.FailOpenTrue - } else { - conf.OcspFailureMode = ocsp.FailOpenFalse + if entry.OcspEnabled { + conf.OcspEnabled = true + conf.OcspServersOverride = append(conf.OcspServersOverride, entry.OcspServersOverride...) + if entry.OcspFailOpen { + conf.OcspFailureMode = ocsp.FailOpenTrue + } else { + conf.OcspFailureMode = ocsp.FailOpenFalse + } + conf.QueryAllServers = conf.QueryAllServers || entry.OcspQueryAllServers } - conf.QueryAllServers = conf.QueryAllServers || entry.OcspQueryAllServers } return } func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certificate, chain []*x509.Certificate, conf *ocsp.VerifyConfig) (bool, error) { - if len(chain) < 2 { + if !conf.OcspEnabled || len(chain) < 2 { return true, nil } b.ocspClientMutex.RLock() diff --git a/builtin/credential/cert/path_login_test.go b/builtin/credential/cert/path_login_test.go index 2ca0a70b31f7..02bf083ad2aa 100644 --- a/builtin/credential/cert/path_login_test.go +++ b/builtin/credential/cert/path_login_test.go @@ -1,6 +1,7 @@ package cert import ( + "context" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -25,15 +26,26 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -const ocspPort = 31808 +var ocspPort int var source InMemorySource func TestMain(m *testing.M) { source = make(InMemorySource) + + listener, err := net.Listen("tcp", ":0") + if err != nil { + return + } + ocspPort = listener.Addr().(*net.TCPAddr).Port + srv := &http.Server{ + Addr: "localhost:0", + Handler: NewResponder(source, nil), + } go func() { - http.ListenAndServe(fmt.Sprintf("localhost:%d", ocspPort), NewResponder(source, nil)) + srv.Serve(listener) }() + defer srv.Shutdown(context.Background()) m.Run() } @@ -288,7 +300,7 @@ func TestCert_RoleResolveOCSP(t *testing.T) { CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepCertWithExtraParams(t, "web", ca, "foo", allowed{dns: "example.com"}, false, - map[string]interface{}{"ocsp_fail_open": c.failOpen}), + map[string]interface{}{"ocsp_enabled": true, "ocsp_fail_open": c.failOpen}), testAccStepLoginWithName(t, connState, "web", c.errExpected), testAccStepResolveRoleWithName(t, connState, "web"), }, diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index f31dc05ac91c..5ff32ad79780 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -513,6 +513,7 @@ func isValidOCSPStatus(status ocspStatusCode) bool { } type VerifyConfig struct { + OcspEnabled bool ExtraCas []*x509.Certificate OcspServersOverride []string OcspFailureMode FailOpenMode From 222bc583aeb092f901a32f9e81172bac5aa7074f Mon Sep 17 00:00:00 2001 From: Scott Miller Date: Mon, 7 Nov 2022 09:34:16 -0600 Subject: [PATCH 27/39] Update builtin/credential/cert/path_login.go Co-authored-by: Alexander Scheel --- builtin/credential/cert/path_login.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index df39b52224ae..3e0453f88a54 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -596,7 +596,7 @@ func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certi } b.ocspClientMutex.RLock() defer b.ocspClientMutex.RUnlock() - err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[len(chain)-1], conf) + err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[1], conf) if err != nil { return false, nil } From f1921348276aee6bcad2e8370119c1c4bfb03bd0 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 8 Nov 2022 10:27:27 -0600 Subject: [PATCH 28/39] Refactor unit tests --- builtin/credential/cert/backend_test.go | 13 ++---- builtin/credential/cert/path_login.go | 2 +- builtin/credential/cert/path_login_test.go | 47 +++++++++++++++++++--- 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index d0504df25133..9764cf608e42 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -1476,7 +1476,7 @@ func TestBackend_mixed_constraints(t *testing.T) { testAccStepCert(t, "3invalid", ca, "foo", allowed{names: "invalid"}, false), testAccStepLogin(t, connState), // Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match - testAccStepLoginWithName(t, connState, "2matching", false), + testAccStepLoginWithName(t, connState, "2matching"), testAccStepLoginWithNameInvalid(t, connState, "3invalid"), }, }) @@ -1720,22 +1720,16 @@ func testAccStepReadConfig(t *testing.T, conf config, connState tls.ConnectionSt } func testAccStepLogin(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep { - return testAccStepLoginWithName(t, connState, "", false) + return testAccStepLoginWithName(t, connState, "") } -func testAccStepLoginWithName(t *testing.T, connState tls.ConnectionState, certName string, errExpected bool) logicaltest.TestStep { +func testAccStepLoginWithName(t *testing.T, connState tls.ConnectionState, certName string) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "login", Unauthenticated: true, ConnState: &connState, Check: func(resp *logical.Response) error { - if errExpected { - if !resp.IsError() { - t.Fatalf("expected error") - } - return nil - } if resp.Auth.TTL != 1000*time.Second { t.Fatalf("bad lease length: %#v", resp.Auth) } @@ -1750,7 +1744,6 @@ func testAccStepLoginWithName(t *testing.T, connState tls.ConnectionState, certN Data: map[string]interface{}{ "name": certName, }, - ErrorOk: errExpected, } } diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 3e0453f88a54..df39b52224ae 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -596,7 +596,7 @@ func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certi } b.ocspClientMutex.RLock() defer b.ocspClientMutex.RUnlock() - err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[1], conf) + err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[len(chain)-1], conf) if err != nil { return false, nil } diff --git a/builtin/credential/cert/path_login_test.go b/builtin/credential/cert/path_login_test.go index 02bf083ad2aa..e4c9aca8d04c 100644 --- a/builtin/credential/cert/path_login_test.go +++ b/builtin/credential/cert/path_login_test.go @@ -82,7 +82,7 @@ func TestCert_RoleResolve(t *testing.T) { CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{dns: "example.com"}, false), - testAccStepLoginWithName(t, connState, "web", false), + testAccStepLoginWithName(t, connState, "web"), testAccStepResolveRoleWithName(t, connState, "web"), }, }) @@ -139,7 +139,7 @@ func TestCert_RoleResolveWithoutProvidingCertName(t *testing.T) { CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{dns: "example.com"}, false), - testAccStepLoginWithName(t, connState, "web", false), + testAccStepLoginWithName(t, connState, "web"), testAccStepResolveRoleWithEmptyDataMap(t, connState, "web"), }, }) @@ -189,6 +189,34 @@ func testAccStepResolveRoleExpectRoleResolutionToFail(t *testing.T, connState tl } } +func testAccStepResolveRoleOCSPFail(t *testing.T, connState tls.ConnectionState, certName string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ResolveRoleOperation, + Path: "login", + Unauthenticated: true, + ConnState: &connState, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if resp == nil && !resp.IsError() { + t.Fatalf("Response was not an error: resp:%#v", resp) + } + + errString, ok := resp.Data["error"].(string) + if !ok { + t.Fatal("Error not part of response.") + } + + if !strings.Contains(errString, "no chain matching") { + t.Fatalf("Error was not due to OCSP failure. Error: %s", errString) + } + return nil + }, + Data: map[string]interface{}{ + "name": certName, + }, + } +} + func TestCert_RoleResolve_RoleDoesNotExist(t *testing.T) { certTemplate := &x509.Certificate{ Subject: pkix.Name{ @@ -222,7 +250,7 @@ func TestCert_RoleResolve_RoleDoesNotExist(t *testing.T) { CredentialBackend: testFactory(t), Steps: []logicaltest.TestStep{ testAccStepCert(t, "web", ca, "foo", allowed{dns: "example.com"}, false), - testAccStepLoginWithName(t, connState, "web", false), + testAccStepLoginWithName(t, connState, "web"), testAccStepResolveRoleExpectRoleResolutionToFail(t, connState, "notweb"), }, }) @@ -296,13 +324,22 @@ func TestCert_RoleResolveOCSP(t *testing.T) { b := testFactory(t) b.(*backend).ocspClient.ClearCache() + var resolveStep logicaltest.TestStep + var loginStep logicaltest.TestStep + if c.errExpected { + loginStep = testAccStepLoginWithNameInvalid(t, connState, "web") + resolveStep = testAccStepResolveRoleOCSPFail(t, connState, "web") + } else { + loginStep = testAccStepLoginWithName(t, connState, "web") + resolveStep = testAccStepResolveRoleWithName(t, connState, "web") + } logicaltest.Test(t, logicaltest.TestCase{ CredentialBackend: b, Steps: []logicaltest.TestStep{ testAccStepCertWithExtraParams(t, "web", ca, "foo", allowed{dns: "example.com"}, false, map[string]interface{}{"ocsp_enabled": true, "ocsp_fail_open": c.failOpen}), - testAccStepLoginWithName(t, connState, "web", c.errExpected), - testAccStepResolveRoleWithName(t, connState, "web"), + loginStep, + resolveStep, }, }) }) From b8fb6edfecb6ff49563ce0f95b0d4f1007584b1b Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 8 Nov 2022 12:05:42 -0600 Subject: [PATCH 29/39] Add status to cache --- builtin/credential/cert/path_login.go | 2 +- builtin/credential/cert/path_login_test.go | 2 +- builtin/logical/aws/iam_policies_test.go | 2 +- sdk/helper/ocsp/client.go | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index df39b52224ae..3e0453f88a54 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -596,7 +596,7 @@ func (b *backend) checkForCertInOCSP(ctx context.Context, clientCert *x509.Certi } b.ocspClientMutex.RLock() defer b.ocspClientMutex.RUnlock() - err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[len(chain)-1], conf) + err := b.ocspClient.VerifyLeafCertificate(ctx, clientCert, chain[1], conf) if err != nil { return false, nil } diff --git a/builtin/credential/cert/path_login_test.go b/builtin/credential/cert/path_login_test.go index e4c9aca8d04c..36f43380de6e 100644 --- a/builtin/credential/cert/path_login_test.go +++ b/builtin/credential/cert/path_login_test.go @@ -197,7 +197,7 @@ func testAccStepResolveRoleOCSPFail(t *testing.T, connState tls.ConnectionState, ConnState: &connState, ErrorOk: true, Check: func(resp *logical.Response) error { - if resp == nil && !resp.IsError() { + if resp == nil || !resp.IsError() { t.Fatalf("Response was not an error: resp:%#v", resp) } diff --git a/builtin/logical/aws/iam_policies_test.go b/builtin/logical/aws/iam_policies_test.go index ddba67f6b8bd..5e8ae6feb6f0 100644 --- a/builtin/logical/aws/iam_policies_test.go +++ b/builtin/logical/aws/iam_policies_test.go @@ -207,7 +207,7 @@ func Test_combinePolicyDocuments(t *testing.T) { `{"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "NotAction": "ec2:DescribeAvailabilityZones", "Resource": "*"}]}`, }, expectedOutput: `{"Version": "2012-10-17","Statement":[{"Effect": "Allow","NotAction": "ec2:DescribeAvailabilityZones", "Resource": "*"}]}`, - expectedErr: false, + expectedErr: false, }, { description: "one blank policy", diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 5ff32ad79780..b0f8ae5b0155 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -495,6 +495,7 @@ func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509. return ret, nil } v := ocspCachedResponse{ + status: ret.code, time: float64(time.Now().UTC().Unix()), producedAt: float64(ocspRes.ProducedAt.UTC().Unix()), thisUpdate: float64(ocspRes.ThisUpdate.UTC().Unix()), From 9933c3bea8a3efc917401ec587b9ac95549511d8 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 8 Nov 2022 13:09:08 -0600 Subject: [PATCH 30/39] Make some functions private --- sdk/helper/ocsp/client.go | 43 +++++---------------------------------- 1 file changed, 5 insertions(+), 38 deletions(-) diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index b0f8ae5b0155..5151f7606046 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -124,7 +124,6 @@ type certID struct { // cache key type certIDKey struct { - HashAlgorithm crypto.Hash NameHash string IssuerKeyHash string SerialNumber string @@ -178,7 +177,6 @@ func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { // encode CertID, used as a key in the cache encodedCertID := &certIDKey{ - r.HashAlgorithm, base64.StdEncoding.EncodeToString(r.IssuerNameHash), base64.StdEncoding.EncodeToString(r.IssuerKeyHash), r.SerialNumber.String(), @@ -204,43 +202,12 @@ func (c *Client) encodeCertIDKey(certIDKeyBase64 string) (*certIDKey, error) { return nil, err } return &certIDKey{ - c.getHashAlgorithmFromOID(cid.HashAlgorithm), base64.StdEncoding.EncodeToString(cid.NameHash), base64.StdEncoding.EncodeToString(cid.IssuerKeyHash), cid.SerialNumber.String(), }, nil } -func decodeCertIDKey(k *certIDKey) (string, error) { - serialNumber := new(big.Int) - serialNumber.SetString(k.SerialNumber, 10) - nameHash, err := base64.StdEncoding.DecodeString(k.NameHash) - if err != nil { - return "", err - } - issuerKeyHash, err := base64.StdEncoding.DecodeString(k.IssuerKeyHash) - if err != nil { - return "", err - } - hashAlgoOid, err := getOIDFromHashAlgorithm(k.HashAlgorithm) - if err != nil { - return "", err - } - encodedCertID, err := asn1.Marshal(certID{ - pkix.AlgorithmIdentifier{ - Algorithm: hashAlgoOid, - Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, - }, - nameHash, - issuerKeyHash, - serialNumber, - }) - if err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(encodedCertID), nil -} - func (c *Client) checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Certificate) (*ocspStatus, error) { c.ocspResponseCacheLock.RLock() var cacheValue *ocspCachedResponse @@ -371,7 +338,7 @@ func (c *Client) retryOCSP( // GetRevocationStatus checks the certificate revocation status for subject using issuer certificate. func (c *Client) GetRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate, conf *VerifyConfig) (*ocspStatus, error) { - status, ocspReq, encodedCertID, err := c.ValidateWithCache(subject, issuer) + status, ocspReq, encodedCertID, err := c.validateWithCache(subject, issuer) if err != nil { return nil, err } @@ -604,12 +571,12 @@ func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, conf return nil } -func (c *Client) ValidateWithCacheForAllCertificates(verifiedChains []*x509.Certificate) (bool, error) { +func (c *Client) validateWithCacheForAllCertificates(verifiedChains []*x509.Certificate) (bool, error) { n := len(verifiedChains) - 1 for j := 0; j < n; j++ { subject := verifiedChains[j] issuer := verifiedChains[j+1] - status, _, _, err := c.ValidateWithCache(subject, issuer) + status, _, _, err := c.validateWithCache(subject, issuer) if err != nil { return false, err } @@ -620,7 +587,7 @@ func (c *Client) ValidateWithCacheForAllCertificates(verifiedChains []*x509.Cert return true, nil } -func (c *Client) ValidateWithCache(subject, issuer *x509.Certificate) (*ocspStatus, []byte, *certIDKey, error) { +func (c *Client) validateWithCache(subject, issuer *x509.Certificate) (*ocspStatus, []byte, *certIDKey, error) { ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) if err != nil { return nil, nil, nil, fmt.Errorf("failed to create OCSP request from the certificates: %v", err) @@ -637,7 +604,7 @@ func (c *Client) ValidateWithCache(subject, issuer *x509.Certificate) (*ocspStat } func (c *Client) GetAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate, conf *VerifyConfig) ([]*ocspStatus, error) { - _, err := c.ValidateWithCacheForAllCertificates(verifiedChains) + _, err := c.validateWithCacheForAllCertificates(verifiedChains) if err != nil { return nil, err } From 6912f06439b3824cf2f3cb53672488208e055326 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 8 Nov 2022 15:16:11 -0600 Subject: [PATCH 31/39] Rename for testing, and attribute --- .../cert/{responder.go => test_responder.go} | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) rename builtin/credential/cert/{responder.go => test_responder.go} (87%) diff --git a/builtin/credential/cert/responder.go b/builtin/credential/cert/test_responder.go similarity index 87% rename from builtin/credential/cert/responder.go rename to builtin/credential/cert/test_responder.go index 5afc14b35a66..c37b0c124a45 100644 --- a/builtin/credential/cert/responder.go +++ b/builtin/credential/cert/test_responder.go @@ -5,6 +5,9 @@ // internal to the server, that can be fixed by an administrator. Any type of // incorrect input from a user should be logged and Info or below. For things // that are logged on every request, Debug is the appropriate level. +// +// From https://github.com/cloudflare/cfssl/blob/master/ocsp/responder.go + package cert import ( @@ -263,3 +266,30 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques rs.stats.ResponseStatus(ocsp.Success) } } + +/* +Copyright (c) 2014 CloudFlare Inc. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ From 40219080aabd5237136f9da30441dec5c90f6187 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 8 Nov 2022 15:22:56 -0600 Subject: [PATCH 32/39] Up to date gofumpt --- builtin/credential/cert/path_login.go | 3 ++- builtin/logical/aws/iam_policies_test.go | 2 +- sdk/helper/ocsp/client.go | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 3e0453f88a54..3ef6cb11a544 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -316,7 +316,8 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d } func (b *backend) matchesConstraints(ctx context.Context, clientCert *x509.Certificate, trustedChain []*x509.Certificate, - config *ParsedCert, conf *ocsp.VerifyConfig) (bool, error) { + config *ParsedCert, conf *ocsp.VerifyConfig, +) (bool, error) { soFar := !b.checkForChainInCRLs(trustedChain) && b.matchesNames(clientCert, config) && b.matchesCommonName(clientCert, config) && diff --git a/builtin/logical/aws/iam_policies_test.go b/builtin/logical/aws/iam_policies_test.go index 5e8ae6feb6f0..ddba67f6b8bd 100644 --- a/builtin/logical/aws/iam_policies_test.go +++ b/builtin/logical/aws/iam_policies_test.go @@ -207,7 +207,7 @@ func Test_combinePolicyDocuments(t *testing.T) { `{"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "NotAction": "ec2:DescribeAvailabilityZones", "Resource": "*"}]}`, }, expectedOutput: `{"Version": "2012-10-17","Statement":[{"Effect": "Allow","NotAction": "ec2:DescribeAvailabilityZones", "Resource": "*"}]}`, - expectedErr: false, + expectedErr: false, }, { description: "one blank policy", diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 5151f7606046..ffd4bb74cf00 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -282,8 +282,8 @@ func (c *Client) retryOCSP( ocspHost *url.URL, headers map[string]string, reqBody []byte, - issuer *x509.Certificate) (ocspRes *ocsp.Response, ocspResBytes []byte, ocspS *ocspStatus, err error) { - + issuer *x509.Certificate, +) (ocspRes *ocsp.Response, ocspResBytes []byte, ocspS *ocspStatus, err error) { origHost := *ocspHost doRequest := func(request *retryablehttp.Request) (*http.Response, error) { if err != nil { From 8143803dad8e8ec1aa08e3a380f0e139ae063be2 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 8 Nov 2022 15:46:01 -0600 Subject: [PATCH 33/39] remove hash from key, and disable the vault dependent unit test --- sdk/helper/ocsp/ocsp_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 45ee45d3611f..1ad59e920303 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -160,13 +160,11 @@ func TestUnitEncodeCertIDGood(t *testing.T) { func TestUnitCheckOCSPResponseCache(t *testing.T) { c := New(testLogFactory, 10) dummyKey0 := certIDKey{ - HashAlgorithm: crypto.SHA1, NameHash: "dummy0", IssuerKeyHash: "dummy0", SerialNumber: "dummy0", } dummyKey := certIDKey{ - HashAlgorithm: crypto.SHA1, NameHash: "dummy1", IssuerKeyHash: "dummy1", SerialNumber: "dummy1", From e91c2cf210e410662278aeb163c53b8d1c2227fb Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Thu, 10 Nov 2022 14:23:38 -0600 Subject: [PATCH 34/39] Comment out TestMultiOCSP --- sdk/helper/ocsp/ocsp_test.go | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 1ad59e920303..cf85cebdb7bd 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -4,25 +4,15 @@ package ocsp import ( "bytes" - "context" - "crypto" - "crypto/tls" "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "io" "io/ioutil" - "net" "net/http" - "net/url" "testing" "time" - "github.com/hashicorp/go-hclog" - "github.com/hashicorp/go-retryablehttp" - lru "github.com/hashicorp/golang-lru" "golang.org/x/crypto/ocsp" + + lru "github.com/hashicorp/golang-lru" ) func TestOCSP(t *testing.T) { @@ -66,7 +56,10 @@ func TestOCSP(t *testing.T) { } } +/** +// Used for development, requires an active Vault with PKI setup func TestMultiOCSP(t *testing.T) { + targetURL := []string{ "https://localhost:8200/v1/pki/ocsp", "https://localhost:8200/v1/pki/ocsp", @@ -109,6 +102,7 @@ func TestMultiOCSP(t *testing.T) { } } } +*/ func TestUnitEncodeCertIDGood(t *testing.T) { targetURLs := []string{ From ff60c7dddfca97a0f15f7a56b1f44afc67de2925 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Tue, 15 Nov 2022 10:20:58 -0600 Subject: [PATCH 35/39] imports --- sdk/helper/ocsp/ocsp_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index cf85cebdb7bd..4072568df283 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -4,12 +4,21 @@ package ocsp import ( "bytes" + "crypto" + "crypto/tls" "crypto/x509" + "fmt" + "io" "io/ioutil" + "net" "net/http" + "net/url" "testing" "time" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" + "golang.org/x/crypto/ocsp" lru "github.com/hashicorp/golang-lru" From 65bb99acb7e11a3bbca810967b1878ebd4e30a17 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 18 Nov 2022 11:54:09 -0600 Subject: [PATCH 36/39] more imports --- sdk/helper/ocsp/ocsp_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/helper/ocsp/ocsp_test.go b/sdk/helper/ocsp/ocsp_test.go index 4072568df283..2f3f1976d2a8 100644 --- a/sdk/helper/ocsp/ocsp_test.go +++ b/sdk/helper/ocsp/ocsp_test.go @@ -4,9 +4,11 @@ package ocsp import ( "bytes" + "context" "crypto" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "io/ioutil" @@ -18,10 +20,8 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-retryablehttp" - - "golang.org/x/crypto/ocsp" - lru "github.com/hashicorp/golang-lru" + "golang.org/x/crypto/ocsp" ) func TestOCSP(t *testing.T) { From 9a2acb34d972784786421dd0e8ff67702104c5a7 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 18 Nov 2022 13:45:51 -0600 Subject: [PATCH 37/39] Address semgrep results --- builtin/credential/cert/test_responder.go | 20 +++++++++++--------- sdk/helper/ocsp/client.go | 12 +++++++++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/builtin/credential/cert/test_responder.go b/builtin/credential/cert/test_responder.go index c37b0c124a45..2f0459f2ed4f 100644 --- a/builtin/credential/cert/test_responder.go +++ b/builtin/credential/cert/test_responder.go @@ -19,6 +19,7 @@ import ( "io/ioutil" "net/http" "net/url" + "testing" "time" "golang.org/x/crypto/ocsp" @@ -74,6 +75,7 @@ type Stats interface { // A Responder object provides the HTTP logic to expose a // Source of OCSP responses. type Responder struct { + t *testing.T Source Source stats Stats } @@ -118,7 +120,7 @@ var hashToString = map[crypto.Hash]string{ // default handler will try to canonicalize path components by changing any // strings of repeated '/' into a single '/', which will break the base64 // encoding. -func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Request) { +func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Request) { // By default we set a 'max-age=0, no-cache' Cache-Control header, this // is only returned to the client if a valid authorized OCSP response // is not found or an error is returned. If a response if found the header @@ -131,7 +133,7 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques case "GET": base64Request, err := url.QueryUnescape(request.URL.Path) if err != nil { - fmt.Printf("Error decoding URL: %s", request.URL.Path) + rs.t.Log("Error decoding URL:", request.URL.Path) response.WriteHeader(http.StatusBadRequest) return } @@ -154,14 +156,14 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques } requestBody, err = base64.StdEncoding.DecodeString(string(base64RequestBytes)) if err != nil { - fmt.Printf("Error decoding base64 from URL: %s", string(base64RequestBytes)) + rs.t.Log("Error decoding base64 from URL", string(base64RequestBytes)) response.WriteHeader(http.StatusBadRequest) return } case "POST": requestBody, err = ioutil.ReadAll(request.Body) if err != nil { - fmt.Printf("Problem reading body of POST: %s", err) + rs.t.Log("Problem reading body of POST", err) response.WriteHeader(http.StatusBadRequest) return } @@ -170,7 +172,7 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques return } b64Body := base64.StdEncoding.EncodeToString(requestBody) - fmt.Printf("Received OCSP request: %s", b64Body) + rs.t.Log("Received OCSP request", b64Body) // All responses after this point will be OCSP. // We could check for the content type of the request, but that @@ -183,7 +185,7 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques // should return unauthorizedRequest instead of malformed. ocspRequest, err := ocsp.ParseRequest(requestBody) if err != nil { - fmt.Printf("Error decoding request body: %s", b64Body) + rs.t.Log("Error decoding request body", b64Body) response.WriteHeader(http.StatusBadRequest) response.Write(malformedRequestErrorResponse) if rs.stats != nil { @@ -196,7 +198,7 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques ocspResponse, headers, err := rs.Source.Response(ocspRequest) if err != nil { if err == ErrNotFound { - fmt.Printf("No response found for request: serial %x, request body %s", + rs.t.Log("No response found for request: serial %x, request body %s", ocspRequest.SerialNumber, b64Body) response.Write(unauthorizedErrorResponse) if rs.stats != nil { @@ -204,7 +206,7 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques } return } - fmt.Printf("Error retrieving response for request: serial %x, request body %s, error: %s", + rs.t.Log("Error retrieving response for request: serial %x, request body %s, error", ocspRequest.SerialNumber, b64Body, err) response.WriteHeader(http.StatusInternalServerError) response.Write(internalErrorErrorResponse) @@ -216,7 +218,7 @@ func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Reques parsedResponse, err := ocsp.ParseResponse(ocspResponse, nil) if err != nil { - fmt.Printf("Error parsing response for serial %x: %s", + rs.t.Log("Error parsing response for serial %x", ocspRequest.SerialNumber, err) response.Write(internalErrorErrorResponse) if rs.stats != nil { diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index ffd4bb74cf00..5bd01b3e7b6a 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -306,14 +306,20 @@ func (c *Client) retryOCSP( ocspHost.Path = ocspHost.Path + "/" + base64.StdEncoding.EncodeToString(reqBody) var res *http.Response request, err := req("GET", ocspHost.String(), nil) - if res, err = doRequest(request); err != nil { + if err != nil { + return nil, nil, nil, err + } + if res, err := doRequest(request); err != nil { return nil, nil, nil, err } else { defer res.Body.Close() } if res.StatusCode == http.StatusMethodNotAllowed { - request, err = req("POST", origHost.String(), bytes.NewBuffer(reqBody)) - if res, err = doRequest(request); err != nil { + request, err := req("POST", origHost.String(), bytes.NewBuffer(reqBody)) + if err != nil { + return nil, nil, nil, err + } + if res, err := doRequest(request); err != nil { return nil, nil, nil, err } else { defer res.Body.Close() From 899f36dfd399982341242b9ed57275c19b3e7cea Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Fri, 18 Nov 2022 13:51:59 -0600 Subject: [PATCH 38/39] Attempt to pass some sort of logging to test_responder --- builtin/credential/cert/backend_test.go | 10 ++++---- builtin/credential/cert/path_certs.go | 2 +- builtin/credential/cert/path_login.go | 4 +-- builtin/credential/cert/path_login_test.go | 9 ++++++- builtin/credential/cert/test_responder.go | 30 ++++++++++++---------- sdk/helper/ocsp/client.go | 2 +- 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 9764cf608e42..75386562e83c 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -1475,7 +1475,7 @@ func TestBackend_mixed_constraints(t *testing.T) { testAccStepCert(t, "2matching", ca, "foo", allowed{names: "*.example.com,whatever"}, false), testAccStepCert(t, "3invalid", ca, "foo", allowed{names: "invalid"}, false), testAccStepLogin(t, connState), - // Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match + // Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn'log match testAccStepLoginWithName(t, connState, "2matching"), testAccStepLoginWithNameInvalid(t, connState, "3invalid"), }, @@ -1556,7 +1556,7 @@ func TestBackend_validCIDR(t *testing.T) { if cidrsResult[0].String() != boundCIDRs[0] || cidrsResult[1].String() != boundCIDRs[1] { - t.Fatalf("bound_cidrs couldn't be set correctly, EXPECTED: %v, ACTUAL: %v", boundCIDRs, cidrsResult) + t.Fatalf("bound_cidrs couldn'log be set correctly, EXPECTED: %v, ACTUAL: %v", boundCIDRs, cidrsResult) } loginReq := &logical.Request{ @@ -1633,7 +1633,7 @@ func TestBackend_invalidCIDR(t *testing.T) { Connection: &logical.Connection{ConnState: &connState}, } - // override the remote address with an IPV4 that isn't authorized + // override the remote address with an IPV4 that isn'log authorized loginReq.Connection.RemoteAddr = "127.0.0.1/8" _, err = b.HandleRequest(context.Background(), loginReq) @@ -1711,7 +1711,7 @@ func testAccStepReadConfig(t *testing.T, conf config, connState tls.ConnectionSt } if b != conf.EnableIdentityAliasMetadata { - t.Fatalf("bad: expected enable_identity_alias_metadata to be %t, got %t", conf.EnableIdentityAliasMetadata, b) + t.Fatalf("bad: expected enable_identity_alias_metadata to be %log, got %log", conf.EnableIdentityAliasMetadata, b) } return nil @@ -2207,7 +2207,7 @@ func Test_Renew(t *testing.T) { t.Fatalf("expected a period value of %s in the response, got: %s", period, resp.Auth.Period) } - // Delete CA, make sure we can't renew + // Delete CA, make sure we can'log renew resp, err = b.pathCertDelete(context.Background(), req, fd) if err != nil { t.Fatal(err) diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index 7184435ef56e..065829da9a50 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -496,7 +496,7 @@ This endpoint allows you to create, read, update, and delete trusted certificate that are allowed to authenticate. Deleting a certificate will not revoke auth for prior authenticated connections. -To do this, do a revoke on "login". If you don't need to revoke login immediately, +To do this, do a revoke on "login". If you don'log need to revoke login immediately, then the next renew will cause the lease to expire. ` diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 3ef6cb11a544..36144791b5a8 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -480,7 +480,7 @@ func (b *backend) matchesCertificateExtensions(clientCert *x509.Certificate, con asn1.Unmarshal(ext.Value, &parsedValue) clientExtMap[ext.Id.String()] = parsedValue } - // If any of the required extensions don't match the constraint fails + // If any of the required extensions don'log match the constraint fails for _, requiredExt := range config.Entry.RequiredExtensions { reqExt := strings.SplitN(requiredExt, ":", 2) clientExtValue, clientExtValueOk := clientExtMap[reqExt[0]] @@ -549,7 +549,7 @@ func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, continue } if entry == nil { - // This could happen when the certName was provided and the cert doesn't exist, + // This could happen when the certName was provided and the cert doesn'log exist, // or just if between the LIST and the GET the cert was deleted. continue } diff --git a/builtin/credential/cert/path_login_test.go b/builtin/credential/cert/path_login_test.go index 36f43380de6e..f69444270f39 100644 --- a/builtin/credential/cert/path_login_test.go +++ b/builtin/credential/cert/path_login_test.go @@ -30,6 +30,12 @@ var ocspPort int var source InMemorySource +type testLogger struct{} + +func (t *testLogger) Log(args ...any) { + fmt.Printf("%v", args) +} + func TestMain(m *testing.M) { source = make(InMemorySource) @@ -37,10 +43,11 @@ func TestMain(m *testing.M) { if err != nil { return } + ocspPort = listener.Addr().(*net.TCPAddr).Port srv := &http.Server{ Addr: "localhost:0", - Handler: NewResponder(source, nil), + Handler: NewResponder(&testLogger{}, source, nil), } go func() { srv.Serve(listener) diff --git a/builtin/credential/cert/test_responder.go b/builtin/credential/cert/test_responder.go index 2f0459f2ed4f..1c7c75b2ff33 100644 --- a/builtin/credential/cert/test_responder.go +++ b/builtin/credential/cert/test_responder.go @@ -19,7 +19,6 @@ import ( "io/ioutil" "net/http" "net/url" - "testing" "time" "golang.org/x/crypto/ocsp" @@ -46,7 +45,7 @@ var ( // ETag to the SHA256 hash of the response, and Content-Type to // application/ocsp-response. If you want to override these headers, // or set extra headers, your source should return a http.Header -// with the headers you wish to set. If you don't want to set any +// with the headers you wish to set. If you don'log want to set any // extra headers you may return nil instead. type Source interface { Response(*ocsp.Request) ([]byte, http.Header, error) @@ -72,19 +71,24 @@ type Stats interface { ResponseStatus(ocsp.ResponseStatus) } +type logger interface { + Log(args ...any) +} + // A Responder object provides the HTTP logic to expose a // Source of OCSP responses. type Responder struct { - t *testing.T + log logger Source Source stats Stats } // NewResponder instantiates a Responder with the give Source. -func NewResponder(source Source, stats Stats) *Responder { +func NewResponder(t logger, source Source, stats Stats) *Responder { return &Responder{ Source: source, stats: stats, + log: t, } } @@ -133,7 +137,7 @@ func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Reque case "GET": base64Request, err := url.QueryUnescape(request.URL.Path) if err != nil { - rs.t.Log("Error decoding URL:", request.URL.Path) + rs.log.Log("Error decoding URL:", request.URL.Path) response.WriteHeader(http.StatusBadRequest) return } @@ -156,14 +160,14 @@ func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Reque } requestBody, err = base64.StdEncoding.DecodeString(string(base64RequestBytes)) if err != nil { - rs.t.Log("Error decoding base64 from URL", string(base64RequestBytes)) + rs.log.Log("Error decoding base64 from URL", string(base64RequestBytes)) response.WriteHeader(http.StatusBadRequest) return } case "POST": requestBody, err = ioutil.ReadAll(request.Body) if err != nil { - rs.t.Log("Problem reading body of POST", err) + rs.log.Log("Problem reading body of POST", err) response.WriteHeader(http.StatusBadRequest) return } @@ -172,7 +176,7 @@ func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Reque return } b64Body := base64.StdEncoding.EncodeToString(requestBody) - rs.t.Log("Received OCSP request", b64Body) + rs.log.Log("Received OCSP request", b64Body) // All responses after this point will be OCSP. // We could check for the content type of the request, but that @@ -181,11 +185,11 @@ func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Reque // Parse response as an OCSP request // XXX: This fails if the request contains the nonce extension. - // We don't intend to support nonces anyway, but maybe we + // We don'log intend to support nonces anyway, but maybe we // should return unauthorizedRequest instead of malformed. ocspRequest, err := ocsp.ParseRequest(requestBody) if err != nil { - rs.t.Log("Error decoding request body", b64Body) + rs.log.Log("Error decoding request body", b64Body) response.WriteHeader(http.StatusBadRequest) response.Write(malformedRequestErrorResponse) if rs.stats != nil { @@ -198,7 +202,7 @@ func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Reque ocspResponse, headers, err := rs.Source.Response(ocspRequest) if err != nil { if err == ErrNotFound { - rs.t.Log("No response found for request: serial %x, request body %s", + rs.log.Log("No response found for request: serial %x, request body %s", ocspRequest.SerialNumber, b64Body) response.Write(unauthorizedErrorResponse) if rs.stats != nil { @@ -206,7 +210,7 @@ func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Reque } return } - rs.t.Log("Error retrieving response for request: serial %x, request body %s, error", + rs.log.Log("Error retrieving response for request: serial %x, request body %s, error", ocspRequest.SerialNumber, b64Body, err) response.WriteHeader(http.StatusInternalServerError) response.Write(internalErrorErrorResponse) @@ -218,7 +222,7 @@ func (rs *Responder) ServeHTTP(response http.ResponseWriter, request *http.Reque parsedResponse, err := ocsp.ParseResponse(ocspResponse, nil) if err != nil { - rs.t.Log("Error parsing response for serial %x", + rs.log.Log("Error parsing response for serial %x", ocspRequest.SerialNumber, err) response.Write(internalErrorErrorResponse) if rs.stats != nil { diff --git a/sdk/helper/ocsp/client.go b/sdk/helper/ocsp/client.go index 5bd01b3e7b6a..e54fdeface46 100644 --- a/sdk/helper/ocsp/client.go +++ b/sdk/helper/ocsp/client.go @@ -309,7 +309,7 @@ func (c *Client) retryOCSP( if err != nil { return nil, nil, nil, err } - if res, err := doRequest(request); err != nil { + if res, err = doRequest(request); err != nil { return nil, nil, nil, err } else { defer res.Body.Close() From ec3d5ff8f26dbc31937a9d28895ce0c4dd6f5af7 Mon Sep 17 00:00:00 2001 From: "Scott G. Miller" Date: Mon, 21 Nov 2022 10:17:35 -0600 Subject: [PATCH 39/39] fix overzealous search&replace --- builtin/credential/cert/backend_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 75386562e83c..9764cf608e42 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -1475,7 +1475,7 @@ func TestBackend_mixed_constraints(t *testing.T) { testAccStepCert(t, "2matching", ca, "foo", allowed{names: "*.example.com,whatever"}, false), testAccStepCert(t, "3invalid", ca, "foo", allowed{names: "invalid"}, false), testAccStepLogin(t, connState), - // Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn'log match + // Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match testAccStepLoginWithName(t, connState, "2matching"), testAccStepLoginWithNameInvalid(t, connState, "3invalid"), }, @@ -1556,7 +1556,7 @@ func TestBackend_validCIDR(t *testing.T) { if cidrsResult[0].String() != boundCIDRs[0] || cidrsResult[1].String() != boundCIDRs[1] { - t.Fatalf("bound_cidrs couldn'log be set correctly, EXPECTED: %v, ACTUAL: %v", boundCIDRs, cidrsResult) + t.Fatalf("bound_cidrs couldn't be set correctly, EXPECTED: %v, ACTUAL: %v", boundCIDRs, cidrsResult) } loginReq := &logical.Request{ @@ -1633,7 +1633,7 @@ func TestBackend_invalidCIDR(t *testing.T) { Connection: &logical.Connection{ConnState: &connState}, } - // override the remote address with an IPV4 that isn'log authorized + // override the remote address with an IPV4 that isn't authorized loginReq.Connection.RemoteAddr = "127.0.0.1/8" _, err = b.HandleRequest(context.Background(), loginReq) @@ -1711,7 +1711,7 @@ func testAccStepReadConfig(t *testing.T, conf config, connState tls.ConnectionSt } if b != conf.EnableIdentityAliasMetadata { - t.Fatalf("bad: expected enable_identity_alias_metadata to be %log, got %log", conf.EnableIdentityAliasMetadata, b) + t.Fatalf("bad: expected enable_identity_alias_metadata to be %t, got %t", conf.EnableIdentityAliasMetadata, b) } return nil @@ -2207,7 +2207,7 @@ func Test_Renew(t *testing.T) { t.Fatalf("expected a period value of %s in the response, got: %s", period, resp.Auth.Period) } - // Delete CA, make sure we can'log renew + // Delete CA, make sure we can't renew resp, err = b.pathCertDelete(context.Background(), req, fd) if err != nil { t.Fatal(err)