Skip to content

Commit

Permalink
Merge pull request #725 from luraproject/support_multiple_certs
Browse files Browse the repository at this point in the history
Support multiple certs
  • Loading branch information
kpacha authored Jul 4, 2024
2 parents cb94681 + fc8ed5c commit be5c8bd
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 42 deletions.
29 changes: 18 additions & 11 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,19 +303,26 @@ type Plugin struct {
Pattern string `mapstructure:"pattern"`
}

// TLSKeyPair contains a pair of public and private keys
type TLSKeyPair struct {
PublicKey string `mapstructure:"public_key"`
PrivateKey string `mapstructure:"private_key"`
}

// TLS defines the configuration params for enabling TLS (HTTPS & HTTP/2) at the router layer
type TLS struct {
IsDisabled bool `mapstructure:"disabled"`
PublicKey string `mapstructure:"public_key"`
PrivateKey string `mapstructure:"private_key"`
CaCerts []string `mapstructure:"ca_certs"`
MinVersion string `mapstructure:"min_version"`
MaxVersion string `mapstructure:"max_version"`
CurvePreferences []uint16 `mapstructure:"curve_preferences"`
PreferServerCipherSuites bool `mapstructure:"prefer_server_cipher_suites"`
CipherSuites []uint16 `mapstructure:"cipher_suites"`
EnableMTLS bool `mapstructure:"enable_mtls"`
DisableSystemCaPool bool `mapstructure:"disable_system_ca_pool"`
IsDisabled bool `mapstructure:"disabled"`
PublicKey string `mapstructure:"public_key"`
PrivateKey string `mapstructure:"private_key"`
CaCerts []string `mapstructure:"ca_certs"`
MinVersion string `mapstructure:"min_version"`
MaxVersion string `mapstructure:"max_version"`
CurvePreferences []uint16 `mapstructure:"curve_preferences"`
PreferServerCipherSuites bool `mapstructure:"prefer_server_cipher_suites"`
CipherSuites []uint16 `mapstructure:"cipher_suites"`
EnableMTLS bool `mapstructure:"enable_mtls"`
DisableSystemCaPool bool `mapstructure:"disable_system_ca_pool"`
Keys []TLSKeyPair `mapstructure:"keys"`
}

// ClientTLS defines the configuration params for an HTTP Client
Expand Down
31 changes: 20 additions & 11 deletions config/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ func (p *parseableServiceConfig) normalize() ServiceConfig {
EnableMTLS: p.TLS.EnableMTLS,
DisableSystemCaPool: p.TLS.DisableSystemCaPool,
}
for _, k := range p.TLS.Keys {
cfg.TLS.Keys = append(cfg.TLS.Keys, TLSKeyPair(k))
}
}
if p.ClientTLS != nil {
cfg.ClientTLS = &ClientTLS{
Expand Down Expand Up @@ -244,18 +247,24 @@ func (p *parseableServiceConfig) normalize() ServiceConfig {
return cfg
}

type parseableTLSKeyPair struct {
PublicKey string `json:"public_key"`
PrivateKey string `json:"private_key"`
}

type parseableTLS struct {
IsDisabled bool `json:"disabled"`
PublicKey string `json:"public_key"`
PrivateKey string `json:"private_key"`
CaCerts []string `json:"ca_certs"`
MinVersion string `json:"min_version"`
MaxVersion string `json:"max_version"`
CurvePreferences []uint16 `json:"curve_preferences"`
PreferServerCipherSuites bool `json:"prefer_server_cipher_suites"`
CipherSuites []uint16 `json:"cipher_suites"`
EnableMTLS bool `json:"enable_mtls"`
DisableSystemCaPool bool `json:"disable_system_ca_pool"`
IsDisabled bool `json:"disabled"`
PublicKey string `json:"public_key"`
PrivateKey string `json:"private_key"`
CaCerts []string `json:"ca_certs"`
MinVersion string `json:"min_version"`
MaxVersion string `json:"max_version"`
CurvePreferences []uint16 `json:"curve_preferences"`
PreferServerCipherSuites bool `json:"prefer_server_cipher_suites"`
CipherSuites []uint16 `json:"cipher_suites"`
EnableMTLS bool `json:"enable_mtls"`
DisableSystemCaPool bool `json:"disable_system_ca_pool"`
Keys []parseableTLSKeyPair `json:"keys"`
}

type parseableClientTLS struct {
Expand Down
27 changes: 23 additions & 4 deletions transport/http/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,33 @@ func RunServerWithLoggerFactory(l logging.Logger) func(context.Context, config.S
done <- s.ListenAndServe()
}()
} else {
if cfg.TLS.PublicKey == "" {
if len(cfg.TLS.PublicKey) > 0 || len(cfg.TLS.PrivateKey) > 0 {
cfg.TLS.Keys = append(cfg.TLS.Keys, config.TLSKeyPair{
PublicKey: cfg.TLS.PublicKey,
PrivateKey: cfg.TLS.PrivateKey,
})
}
if len(cfg.TLS.Keys) == 0 {
return ErrPublicKey
}
if cfg.TLS.PrivateKey == "" {
return ErrPrivateKey
for _, k := range cfg.TLS.Keys {
if k.PublicKey == "" {
return ErrPublicKey
}
if k.PrivateKey == "" {
return ErrPrivateKey
}
cert, err := tls.LoadX509KeyPair(k.PublicKey, k.PrivateKey)
if err != nil {
return err
}
s.TLSConfig.Certificates = append(s.TLSConfig.Certificates, cert)
}

go func() {
done <- s.ListenAndServeTLS(cfg.TLS.PublicKey, cfg.TLS.PrivateKey)
// since we already use the list of certificates in the config
// we do not need to specify the files for public and private key here
done <- s.ListenAndServeTLS("", "")
}()
}

Expand Down
106 changes: 106 additions & 0 deletions transport/http/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net"
"net/http"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -470,3 +471,108 @@ func h2cClient() *http.Client {
func newPort() int {
return 16666 + rand.Intn(40000)
}

func TestRunServer_MultipleTLS(t *testing.T) {
testKeysAreAvailable(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

port := newPort()

done := make(chan error)
go func() {
done <- RunServer(
ctx,
config.ServiceConfig{
Port: port,
TLS: &config.TLS{
CaCerts: []string{"ca.pem", "exampleca.pem"},
Keys: []config.TLSKeyPair{
config.TLSKeyPair{
PublicKey: "cert.pem",
PrivateKey: "key.pem",
},
config.TLSKeyPair{
PublicKey: "examplecert.pem",
PrivateKey: "examplekey.pem",
},
},
},
},
http.HandlerFunc(dummyHandler),
)
}()

client, err := httpsClient("cert.pem")
if err != nil {
t.Error(err)
return
}

<-time.After(100 * time.Millisecond)

resp, err := client.Get(fmt.Sprintf("https://localhost:%d", port))
if err != nil {
t.Error(err)
return
}
if resp.StatusCode != 200 {
t.Errorf("unexpected status code: %d", resp.StatusCode)
return
}

client, err = httpsClient("examplecert.pem")
if err != nil {
t.Error(err)
return
}
_, err = client.Get(fmt.Sprintf("https://127.0.0.1:%d", port))
// should fail, because it will be served with cert.pem
if err == nil || strings.Contains(err.Error(), "bad certificate") {
t.Error("expected to have 'bad certificate' error")
return
}

req, _ := http.NewRequest("GET", fmt.Sprintf("https://example.com:%d", port), http.NoBody)
overrideHostTransport(client)
resp, err = client.Do(req)
if err != nil {
t.Error(err)
return
}
if resp.StatusCode != 200 {
t.Errorf("unexpected status code: %d", resp.StatusCode)
return
}

cancel()
if err = <-done; err != nil {
t.Error(err)
}
}

// overrideHostTransport subtitutes the actual address that the request will
// connecto (overriding the dns resolution).
func overrideHostTransport(client *http.Client) {
t := http.DefaultTransport.(*http.Transport).Clone()
if client.Transport != nil {
if tt, ok := client.Transport.(*http.Transport); ok {
t = tt
}
}
myDialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}
t.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
overrideAddress := net.JoinHostPort("127.0.0.1", port)
return myDialer.DialContext(ctx, network, overrideAddress)
}
client.Transport = t
}
56 changes: 40 additions & 16 deletions transport/http/server/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,41 @@ import (
"time"
)

func init() {
if err := generateCerts(); err != nil {
log.Fatal(err.Error())
type certDef struct {
Prefix string
IPAddresses []string
DNSNames []string
}

func (c certDef) Org() string {
if c.Prefix == "" {
return "Acme Co"
}
return c.Prefix + " " + "Acme Co"
}

func generateCerts() error {
hosts := []string{"127.0.0.1", "::1", "localhost"}
func init() {
certs := []certDef{
certDef{
Prefix: "",
IPAddresses: []string{"127.0.0.1", "::1"},
DNSNames: []string{"localhost"},
},
certDef{
Prefix: "example",
IPAddresses: []string{"127.0.0.1"},
DNSNames: []string{"example.com"},
},
}

for _, cd := range certs {
if err := generateNamedCert(cd); err != nil {
log.Fatal(err.Error())
}
}
}

func generateNamedCert(hostCert certDef) error {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return fmt.Errorf("Failed to generate private key: %v", err)
Expand All @@ -44,23 +70,21 @@ func generateCerts() error {
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Acme Co"},
Organization: []string{hostCert.Org()},
},
NotBefore: notBefore,
NotAfter: notAfter,

NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
for _, strIP := range hostCert.IPAddresses {
if ip := net.ParseIP(strIP); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
template.DNSNames = append(template.DNSNames, hostCert.DNSNames...)

template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
Expand All @@ -75,9 +99,9 @@ func generateCerts() error {
return fmt.Errorf("Failed to create ca: %v", err)
}

serverCert := "cert.pem"
serverKey := "key.pem"
caCert := "ca.pem"
serverCert := hostCert.Prefix + "cert.pem"
serverKey := hostCert.Prefix + "key.pem"
caCert := hostCert.Prefix + "ca.pem"

certOut, err := os.Create(serverCert)
if err != nil {
Expand Down

0 comments on commit be5c8bd

Please sign in to comment.