Skip to content

Commit

Permalink
Set up mTLS in tests (#87)
Browse files Browse the repository at this point in the history
* Refactor certificate generation in tests

The process of generating certificates for testing has been refactored. The responsibility of creating the certificates has been moved from the lifecycle test to the provider suite setup. This change simplifies the lifecycle test and makes it easier to manage certificates across different tests. Additionally, a new method 'WithCerts' was introduced in TestHost interface replacing 'CreateCertBundle'.

* The CAs are swapped but basically everything else is wired up

* mTLS actually?

* Rename CA in most of the spots where it matters

* Fix comment
  • Loading branch information
UnstoppableMango authored Aug 3, 2024
1 parent 1302582 commit 849c812
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 52 deletions.
20 changes: 10 additions & 10 deletions provider/cmd/provisioner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import (
)

var (
address string
network string
caFile string
certFile string
keyFile string
verbose bool
address string
network string
clientCaFile string
certFile string
keyFile string
verbose bool
)

var rootCmd = &cobra.Command{
Expand All @@ -41,14 +41,14 @@ var rootCmd = &cobra.Command{
log.Debug("creating provisioner")
provisioner := p.New(lis,
p.WithLogger(log),
p.WithOptionalCertificates(caFile, certFile, keyFile),
p.WithOptionalCertificates(clientCaFile, certFile, keyFile),
)

log.Info("serving",
"network", network,
"address", address,
"verbose", verbose,
"caFile", caFile,
"clientCaFile", clientCaFile,
"certFile", certFile,
"keyFile", keyFile,
)
Expand All @@ -62,10 +62,10 @@ func main() {
rootCmd.Flags().StringVar(&network, "network", "tcp", "Must be a valid `net.Listen()` network. i.e. \"tcp\", \"tcp4\", \"tcp6\", \"unix\" or \"unixpacket\"")
rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Log verbosity")

rootCmd.Flags().StringVar(&caFile, "ca-file", "", "The path to the certificate authority file")
rootCmd.Flags().StringVar(&clientCaFile, "client-ca-file", "", "The path to the certificate authority file")
rootCmd.Flags().StringVar(&certFile, "cert-file", "", "The path to the server certificate file")
rootCmd.Flags().StringVar(&keyFile, "key-file", "", "The path to the server private key file")
rootCmd.MarkFlagsRequiredTogether("ca-file", "cert-file", "key-file")
rootCmd.MarkFlagsRequiredTogether("client-ca-file", "cert-file", "key-file")

if err := rootCmd.Execute(); err != nil {
fmt.Printf("failed to execute: %s\n", err)
Expand Down
50 changes: 42 additions & 8 deletions provider/pkg/provider/config.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,62 @@
package provider

import (
"strings"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)

type Config struct {
Address string `pulumi:"address"`
Port string `pulumi:"port,optional"`
CaPem string `pulumi:"caPem,optional"`
CertPem string `pulumi:"certPem,optional"`
KeyPem string `pulumi:"keyPem,optional"`
}

func (c Config) NewGrpcClient() (*grpc.ClientConn, error) {
parts := []string{}
if c.Address != "" {
parts = append(parts, c.Address)
}
target := c.Address
if c.Port != "" {
parts = append(parts, c.Port)
target = target + ":" + c.Port
}

creds, err := c.TransportCredentials()
if err != nil {
return nil, err
}

target := strings.Join(parts, ":")
return grpc.NewClient(target,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithTransportCredentials(creds),
)
}

func (c Config) TransportCredentials() (credentials.TransportCredentials, error) {
if c.CaPem == "" && c.CertPem == "" && c.KeyPem == "" {
return insecure.NewCredentials(), nil
}

if c.CaPem != "" && c.CertPem != "" && c.KeyPem != "" {
cert, err := tls.X509KeyPair([]byte(c.CertPem), []byte(c.KeyPem))
if err != nil {
return nil, fmt.Errorf("failed to parse X509 key pair: %w", err)
}

ca := x509.NewCertPool()
if ok := ca.AppendCertsFromPEM([]byte(c.CaPem)); !ok {
return nil, errors.New("failed to append ca cert")
}

return credentials.NewTLS(&tls.Config{
ServerName: "provisioner",
Certificates: []tls.Certificate{cert},
RootCAs: ca,
}), nil
}

return nil, errors.New("caPem, certPem, and keyPem must all be set together")
}
2 changes: 1 addition & 1 deletion provider/pkg/provisioner/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func LoadCertificates(caPath, certPath, keyPath string) (*tls.Config, error) {
if err != nil {
return nil, fmt.Errorf("failed reading ca file: %w", err)
}
if ok := ca.AppendCertsFromPEM(caData); ok {
if ok := ca.AppendCertsFromPEM(caData); !ok {
return nil, fmt.Errorf("unable to append ca data from file %s", caPath)
}

Expand Down
6 changes: 1 addition & 5 deletions tests/lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,14 @@ var _ = Describe("Command Resources", func() {
_, err := provisioner.Exec(ctx, "mkdir", "-p", work)
Expect(err).NotTo(HaveOccurred())

By("generating certificates")
certs, err := provisioner.CreateCertBundle(ctx, "lifecycle", work)
Expect(err).NotTo(HaveOccurred())

By("fetching provisioner connection details")
addr, port, err := provisioner.ConnectionDetails(ctx)
Expect(err).NotTo(HaveOccurred())

By("configuring the provider")
err = util.ConfigureProvider(server).
WithProvisioner(addr, port).
WithCerts(certs).
WithCerts(provisioner.Ca(), clientCerts.Cert).
Configure()

Expect(err).NotTo(HaveOccurred())
Expand Down
9 changes: 8 additions & 1 deletion tests/provider_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
var (
provisioner util.TestProvisioner
sshServer util.SshServer
clientCerts *util.CertBundle
)

func TestProvider(t *testing.T) {
Expand All @@ -22,8 +23,14 @@ func TestProvider(t *testing.T) {
}

var _ = BeforeSuite(func(ctx context.Context) {
var err error

By("generating client certs")
clientCerts, err = util.NewCertBundle("ca", "pulumi")
Expect(err).NotTo(HaveOccurred())

By("creating a provisioner")
prov, err := util.NewProvisioner("6969", os.Stdout)
prov, err := util.NewProvisioner("6969", clientCerts.Ca, os.Stdout)
Expect(err).NotTo(HaveOccurred())

By("starting the provisioner")
Expand Down
20 changes: 20 additions & 0 deletions tests/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package tests

import (
"github.com/mdelapenya/tlscert"
)

type CertBundle struct {
Ca *tlscert.Certificate
Cert *tlscert.Certificate
}

func ServerCerts() (*CertBundle, error) {
ca := tlscert.SelfSignedCA("test-ca")

req := tlscert.NewRequest("test-cert")
req.Parent = ca
cert := tlscert.SelfSignedFromRequest(req)

return &CertBundle{ca, cert}, nil
}
24 changes: 5 additions & 19 deletions tests/util/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type TestHost interface {
Ip(context.Context) (string, error)
ReadFile(context.Context, string) ([]byte, error)
WriteFile(context.Context, string, []byte) error
CreateCertBundle(context.Context, string, string) (*HostCerts, error)
WithCerts(context.Context, *CertBundle) (*HostCerts, error)

Start(context.Context) error
Stop(context.Context) error
Expand Down Expand Up @@ -105,28 +105,14 @@ func (h *host) Ip(ctx context.Context) (string, error) {
return ctr.ContainerIP(ctx)
}

// CreateCertBundle implemnts TestHost.
func (h *host) CreateCertBundle(ctx context.Context, name string, dir string) (*HostCerts, error) {
ctr, err := h.ensureContainer(ctx)
if err != nil {
return nil, err
}

host, err := ctr.Host(ctx)
if err != nil {
return nil, fmt.Errorf("retriving host: %w", err)
}

_, err = h.Exec(ctx, "mkdir", "--parents", dir)
// WithCerts implemnts TestHost.
func (h *host) WithCerts(ctx context.Context, bundle *CertBundle) (*HostCerts, error) {
dir := "/etc/baremetal"
_, err := h.Exec(ctx, "mkdir", "--parents", dir)
if err != nil {
return nil, fmt.Errorf("creating cert directory: %w", err)
}

bundle, err := NewCertBundle(host, name)
if err != nil {
return nil, fmt.Errorf("creating cert bundle: %w", err)
}

caFile := path.Join(dir, "ca.pem")
certFile := path.Join(dir, "cert.pem")
keyFile := path.Join(dir, "key.pem")
Expand Down
11 changes: 6 additions & 5 deletions tests/util/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/blang/semver"
"github.com/mdelapenya/tlscert"
p "github.com/pulumi/pulumi-go-provider"
"github.com/pulumi/pulumi-go-provider/integration"
"github.com/pulumi/pulumi/sdk/v3/go/common/resource"
Expand All @@ -17,7 +18,7 @@ import (
type ProviderBuilder interface {
Configure() error
WithProvisioner(address, port string) ProviderBuilder
WithCerts(*HostCerts) ProviderBuilder
WithCerts(*tlscert.Certificate, *tlscert.Certificate) ProviderBuilder
}

func NewServer() integration.Server {
Expand Down Expand Up @@ -61,12 +62,12 @@ func (c *configureBuilder) Configure() error {
}

// WithCerts implements ProviderBuilder.
func (c *configureBuilder) WithCerts(certs *HostCerts) ProviderBuilder {
func (c *configureBuilder) WithCerts(ca *tlscert.Certificate, cert *tlscert.Certificate) ProviderBuilder {
args := c.Args.Mappable()
maps.Copy(args, map[string]interface{}{
"caPath": certs.CaPath,
"certPath": certs.CertPath,
"keyPath": certs.KeyPath,
"caPem": string(ca.Bytes),
"certPem": string(cert.Bytes),
"keyPem": string(cert.KeyBytes),
})

c.Args = resource.NewPropertyMapFromMap(args)
Expand Down
46 changes: 43 additions & 3 deletions tests/util/provisioner.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package util

import (
"bytes"
"context"
"fmt"
"io"
"os"
"path"

"github.com/mdelapenya/tlscert"
tc "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
Expand All @@ -19,29 +21,62 @@ const (
type TestProvisioner interface {
TestHost

Ca() *tlscert.Certificate
ConnectionDetails(context.Context) (address, port string, err error)
}

type provisioner struct {
host
port string
port string
bundle *CertBundle
}

func NewProvisioner(port string, logger io.Writer) (TestProvisioner, error) {
func NewProvisioner(
port string,
clientCa *tlscert.Certificate,
logger io.Writer,
) (TestProvisioner, error) {
cwd, err := os.Getwd()
if err != nil {
return nil, err
}

certs, err := NewCertBundle("ca", "provisioner")
if err != nil {
return nil, err
}

certDir := "/etc/baremetal/pki"
clientCaPath := path.Join(certDir, "client-ca.pem")
certPath := path.Join(certDir, "cert.pem")
keyPath := path.Join(certDir, "key.pem")

req := tc.GenericContainerRequest{
ContainerRequest: tc.ContainerRequest{
FromDockerfile: tc.FromDockerfile{
Context: path.Clean(path.Join(cwd, "..")),
Dockerfile: path.Join("provider", "cmd", "provisioner", "Dockerfile"),
},
Files: []tc.ContainerFile{
{
ContainerFilePath: clientCaPath,
Reader: bytes.NewReader(clientCa.Bytes),
},
{
ContainerFilePath: certPath,
Reader: bytes.NewReader(certs.Cert.Bytes),
},
{
ContainerFilePath: keyPath,
Reader: bytes.NewReader(certs.Cert.KeyBytes),
},
},
Cmd: []string{
"--network", defaultProtocol,
"--address", fmt.Sprintf("%s:%s", "0.0.0.0", port),
"--client-ca-file", clientCaPath,
"--cert-file", certPath,
"--key-file", keyPath,
"--verbose",
},
ExposedPorts: []string{port},
Expand All @@ -52,7 +87,12 @@ func NewProvisioner(port string, logger io.Writer) (TestProvisioner, error) {
},
}

return &provisioner{host{req, nil}, port}, nil
return &provisioner{host{req, nil}, port, certs}, nil
}

// CertBundle implements TestProvisioner.
func (p *provisioner) Ca() *tlscert.Certificate {
return p.bundle.Ca
}

// ConnectionDetails implements TestProvisioner.
Expand Down

0 comments on commit 849c812

Please sign in to comment.