diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index f08a5cb04e..246a1e28ab 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -679,9 +679,6 @@ func TestFetchJWTSVID(t *testing.T) { } defer l.Close() - ca, cakey := createCA(t, trustDomain) - baseSVID, baseSVIDKey := createSVID(t, ca, cakey, "spiffe://"+trustDomain+"/agent", 1*time.Hour) - fetchResp := &node.FetchJWTSVIDResponse{} apiHandler := newMockNodeAPIHandler(&mockNodeAPIHandlerConfig{ @@ -693,6 +690,9 @@ func TestFetchJWTSVID(t *testing.T) { }, svidTTL: 200, }) + + baseSVID, baseSVIDKey := apiHandler.newSVID("spiffe://"+trustDomain+"/spire/agent/join_token/abcd", 1*time.Hour) + apiHandler.start() defer apiHandler.stop() @@ -1097,7 +1097,7 @@ func (h *mockNodeAPIHandler) getGRPCServerConfig(hello *tls.ClientHelloInfo) (*t roots.AddCert(h.ca()) c := &tls.Config{ - ClientAuth: tls.RequestClientCert, + ClientAuth: tls.VerifyClientCertIfGiven, Certificates: certs, ClientCAs: roots, } @@ -1106,21 +1106,26 @@ func (h *mockNodeAPIHandler) getGRPCServerConfig(hello *tls.ClientHelloInfo) (*t } func (h *mockNodeAPIHandler) getCertFromCtx(ctx context.Context) (certificate *x509.Certificate, err error) { - ctxPeer, ok := peer.FromContext(ctx) if !ok { - return nil, errors.New("It was not posible to extract peer from request") + return nil, errors.New("no peer information") } tlsInfo, ok := ctxPeer.AuthInfo.(credentials.TLSInfo) if !ok { - return nil, errors.New("It was not posible to extract AuthInfo from request") + return nil, errors.New("no TLS auth info for peer") } - if len(tlsInfo.State.PeerCertificates) == 0 { - return nil, errors.New("PeerCertificates was empty") + if len(tlsInfo.State.VerifiedChains) == 0 { + return nil, errors.New("no verified client certificate presented by peer") + } + chain := tlsInfo.State.VerifiedChains[0] + if len(chain) == 0 { + // this shouldn't be possible with the tls package, but we should be + // defensive. + return nil, errors.New("verified client chain is missing certificates") } - return tlsInfo.State.PeerCertificates[0], nil + return chain[0], nil } func createTempDir(t *testing.T) string { diff --git a/pkg/server/endpoints/endpoints.go b/pkg/server/endpoints/endpoints.go index fb08f9183c..76e2a396c2 100644 --- a/pkg/server/endpoints/endpoints.go +++ b/pkg/server/endpoints/endpoints.go @@ -200,7 +200,7 @@ func (e *endpoints) getGRPCServerConfig(ctx context.Context) func(*tls.ClientHel // an SVID. In order to include the bootstrap endpoint // in the same server as the rest of the Node API, // request but don't require a client certificate - ClientAuth: tls.RequestClientCert, + ClientAuth: tls.VerifyClientCertIfGiven, Certificates: certs, ClientCAs: roots, diff --git a/pkg/server/endpoints/endpoints_test.go b/pkg/server/endpoints/endpoints_test.go index f282080fce..f04c491c9a 100644 --- a/pkg/server/endpoints/endpoints_test.go +++ b/pkg/server/endpoints/endpoints_test.go @@ -8,13 +8,16 @@ import ( "errors" "net" "net/url" + "strings" + "sync" "testing" "time" - "github.com/imkira/go-observer" + observer "github.com/imkira/go-observer" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/spire/pkg/common/bundleutil" "github.com/spiffe/spire/pkg/server/svid" + "github.com/spiffe/spire/proto/common" "github.com/spiffe/spire/proto/server/datastore" "github.com/spiffe/spire/test/fakes/fakedatastore" "github.com/spiffe/spire/test/fakes/fakeservercatalog" @@ -152,7 +155,7 @@ func (s *EndpointsTestSuite) TestGetGRPCServerConfig() { tlsConfig, err := s.e.getGRPCServerConfig(ctx)(nil) require.NoError(s.T(), err) - s.Assert().Equal(tls.RequestClientCert, tlsConfig.ClientAuth) + s.Assert().Equal(tls.VerifyClientCertIfGiven, tlsConfig.ClientAuth) s.Assert().Equal(certs, tlsConfig.Certificates) s.Assert().Equal(pool, tlsConfig.ClientCAs) } @@ -221,3 +224,107 @@ func (s *EndpointsTestSuite) configureBundle() ([]tls.Certificate, *x509.CertPoo }, }, caPool } + +func (s *EndpointsTestSuite) TestClientCertificateVerification() { + caTmpl, err := util.NewCATemplate("example.org") + s.Require().NoError(err) + caCert, caKey, err := util.SelfSign(caTmpl) + s.Require().NoError(err) + + serverTmpl, err := util.NewSVIDTemplate("spiffe://example.org/server") + s.Require().NoError(err) + serverTmpl.DNSNames = []string{"just-for-validation"} + serverCert, serverKey, err := util.Sign(serverTmpl, caCert, caKey) + s.Require().NoError(err) + + clientTmpl, err := util.NewSVIDTemplate("spiffe://example.org/agent") + s.Require().NoError(err) + clientCert, clientKey, err := util.Sign(clientTmpl, caCert, caKey) + s.Require().NoError(err) + + otherCaTmpl, err := util.NewCATemplate("example.org") + s.Require().NoError(err) + otherCaCert, otherCaKey, err := util.SelfSign(otherCaTmpl) + s.Require().NoError(err) + + otherClientTmpl, err := util.NewSVIDTemplate("spiffe://example.org/agent") + s.Require().NoError(err) + otherClientCert, otherClientKey, err := util.Sign(otherClientTmpl, otherCaCert, otherCaKey) + s.Require().NoError(err) + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(caCert) + + // set the trust bundle and plumb a CA certificate + _, err = s.ds.CreateBundle(context.Background(), &datastore.CreateBundleRequest{ + Bundle: &common.Bundle{ + TrustDomainId: "spiffe://example.org", + RootCas: []*common.Certificate{ + {DerBytes: caCert.Raw}, + }, + }, + }) + s.Require().NoError(err) + s.svidState.Update(svid.State{ + SVID: []*x509.Certificate{serverCert}, + Key: serverKey, + }) + + var wg sync.WaitGroup + defer wg.Wait() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + wg.Add(1) + go func() { + defer wg.Done() + s.e.ListenAndServe(ctx) + }() + + // This helper function attempts a TLS connection to the gRPC server. It + // uses the supplied client certificate, if any. It gives up the 2 seconds + // for the server to start listening, which is generous. Any non-dial + // related errors (i.e. TLS handshake failures) are returned. + try := func(cert *tls.Certificate) error { + tlsConfig := &tls.Config{ + RootCAs: rootCAs, + // this override is just so we don't have to set up spiffe peer + // validation of the server by the client, which is outside the + // scope of this test. + ServerName: "just-for-validation", + } + if cert != nil { + tlsConfig.Certificates = append(tlsConfig.Certificates, *cert) + } + for i := 0; i < 20; i++ { + conn, err := tls.Dial("tcp", "127.0.0.1:8000", tlsConfig) + if err != nil { + if strings.HasPrefix(err.Error(), "dial") { + time.Sleep(time.Millisecond * 100) + continue + } + return err + } + conn.Close() + return nil + } + s.FailNow("unable to connect to server within 2 seconds") + return errors.New("unreachable") + } + + err = try(nil) + s.Require().NoError(err, "client should be allowed if no cert presented") + + err = try(&tls.Certificate{ + Certificate: [][]byte{clientCert.Raw}, + PrivateKey: clientKey, + }) + s.Require().NoError(err, "client should be allowed if proper cert presented") + + err = try(&tls.Certificate{ + Certificate: [][]byte{otherClientCert.Raw}, + PrivateKey: otherClientKey, + }) + s.Require().Error(err, "client should NOT be allowed if cert presented is not trusted") +} diff --git a/pkg/server/endpoints/node/handler.go b/pkg/server/endpoints/node/handler.go index cc71d7af47..a6ab788958 100644 --- a/pkg/server/endpoints/node/handler.go +++ b/pkg/server/endpoints/node/handler.go @@ -560,21 +560,26 @@ func (h *Handler) getAttestResponse(ctx context.Context, } func (h *Handler) getCertFromCtx(ctx context.Context) (certificate *x509.Certificate, err error) { - ctxPeer, ok := peer.FromContext(ctx) if !ok { - return nil, errors.New("It was not posible to extract peer from request") + return nil, errors.New("no peer information") } tlsInfo, ok := ctxPeer.AuthInfo.(credentials.TLSInfo) if !ok { - return nil, errors.New("It was not posible to extract AuthInfo from request") + return nil, errors.New("no TLS auth info for peer") } - if len(tlsInfo.State.PeerCertificates) == 0 { - return nil, errors.New("PeerCertificates was empty") + if len(tlsInfo.State.VerifiedChains) == 0 { + return nil, errors.New("no verified client certificate presented by peer") + } + chain := tlsInfo.State.VerifiedChains[0] + if len(chain) == 0 { + // this shouldn't be possible with the tls package, but we should be + // defensive. + return nil, errors.New("verified client chain is missing certificates") } - return tlsInfo.State.PeerCertificates[0], nil + return chain[0], nil } func (h *Handler) signCSRs(ctx context.Context, diff --git a/pkg/server/endpoints/node/handler_test.go b/pkg/server/endpoints/node/handler_test.go index 7dc484a418..572d55dbff 100644 --- a/pkg/server/endpoints/node/handler_test.go +++ b/pkg/server/endpoints/node/handler_test.go @@ -27,11 +27,11 @@ import ( "github.com/spiffe/spire/test/fakes/fakeserverca" "github.com/spiffe/spire/test/fakes/fakeservercatalog" "github.com/spiffe/spire/test/fakes/fakeupstreamca" - "github.com/spiffe/spire/test/mock/proto/api/node" - "github.com/spiffe/spire/test/mock/proto/server/datastore" - "github.com/spiffe/spire/test/mock/proto/server/nodeattestor" - "github.com/spiffe/spire/test/mock/proto/server/noderesolver" - "github.com/spiffe/spire/test/mock/server/ca" + mock_node "github.com/spiffe/spire/test/mock/proto/api/node" + mock_datastore "github.com/spiffe/spire/test/mock/proto/server/datastore" + mock_nodeattestor "github.com/spiffe/spire/test/mock/proto/server/nodeattestor" + mock_noderesolver "github.com/spiffe/spire/test/mock/proto/server/noderesolver" + mock_ca "github.com/spiffe/spire/test/mock/server/ca" "github.com/spiffe/spire/test/util" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -732,7 +732,7 @@ func getFakePeer() *peer.Peer { parsedCert := loadCertFromPEM("base_cert.pem") state := tls.ConnectionState{ - PeerCertificates: []*x509.Certificate{parsedCert}, + VerifiedChains: [][]*x509.Certificate{{parsedCert}}, } addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") diff --git a/test/util/cert_generation.go b/test/util/cert_generation.go index fccd9dcc55..6e9f38e1e6 100644 --- a/test/util/cert_generation.go +++ b/test/util/cert_generation.go @@ -9,9 +9,8 @@ import ( "crypto/x509/pkix" "math/big" mrand "math/rand" + "net/url" "time" - - "github.com/spiffe/go-spiffe/uri" ) // NewSVIDTemplate returns a default SVID template with the specified SPIFFE ID. Must @@ -136,18 +135,11 @@ func defaultCATemplate() *x509.Certificate { // Create an x509 extension with the URI SAN of the given SPIFFE ID, and set it onto // the referenced certificate func addSpiffeExtension(spiffeID string, cert *x509.Certificate) error { - uriSANs, err := uri.MarshalUriSANs([]string{spiffeID}) + u, err := url.Parse(spiffeID) if err != nil { return err } - - ext := []pkix.Extension{{ - Id: uri.OidExtensionSubjectAltName, - Value: uriSANs, - Critical: true, - }} - - cert.ExtraExtensions = ext + cert.URIs = append(cert.URIs, u) return nil }