Skip to content
This repository has been archived by the owner on Jun 25, 2024. It is now read-only.

Commit

Permalink
config+lnd+cert: Add support in lnd for encrypting the TLS private key.
Browse files Browse the repository at this point in the history
This commit adds support in lnd to encrypt the TLS private key on disk with the wallet's seed. This obviously causes issues when the wallet is locked. So for the WalletUnlocker RPC we generate ephemeral TLS certificates with the key stored in memory. This feature is enabled with the --tlsencryptkey flag.
  • Loading branch information
gkrizek committed Jan 26, 2022
1 parent 3c07ec7 commit b9f2006
Show file tree
Hide file tree
Showing 7 changed files with 400 additions and 53 deletions.
33 changes: 19 additions & 14 deletions cert/selfsigned.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs,
// https://github.com/btcsuite/btcutil
func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
tlsExtraDomains []string, tlsDisableAutofill bool,
certValidity time.Duration) error {
certValidity time.Duration) ([]byte, []byte, error) {

now := time.Now()
validUntil := now.Add(certValidity)
Expand All @@ -210,21 +210,21 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
// Generate a serial number that's below the serialNumberLimit.
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return fmt.Errorf("failed to generate serial number: %s", err)
return nil, nil, fmt.Errorf("failed to generate serial number: %s", err)
}

// Get all DNS names and IP addresses to use when creating the
// certificate.
host, dnsNames := dnsNames(tlsExtraDomains, tlsDisableAutofill)
ipAddresses, err := ipAddresses(tlsExtraIPs, tlsDisableAutofill)
if err != nil {
return err
return nil, nil, err
}

// Generate a private key for the certificate.
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
return nil, nil, err
}

// Construct the certificate template.
Expand All @@ -250,35 +250,40 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs,
derBytes, err := x509.CreateCertificate(rand.Reader, &template,
&template, &priv.PublicKey, priv)
if err != nil {
return fmt.Errorf("failed to create certificate: %v", err)
return nil, nil, fmt.Errorf("failed to create certificate: %v", err)
}

certBuf := &bytes.Buffer{}
err = pem.Encode(certBuf, &pem.Block{Type: "CERTIFICATE",
Bytes: derBytes})
if err != nil {
return fmt.Errorf("failed to encode certificate: %v", err)
return nil, nil, fmt.Errorf("failed to encode certificate: %v", err)
}

keybytes, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return fmt.Errorf("unable to encode privkey: %v", err)
return nil, nil, fmt.Errorf("unable to encode privkey: %v", err)
}
keyBuf := &bytes.Buffer{}
err = pem.Encode(keyBuf, &pem.Block{Type: "EC PRIVATE KEY",
Bytes: keybytes})
if err != nil {
return fmt.Errorf("failed to encode private key: %v", err)
return nil, nil, fmt.Errorf("failed to encode private key: %v", err)
}

// Write cert and key files.
if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil {
return err
if certFile != "" {
if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil {
return nil, nil, err
}
}
if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil {
os.Remove(certFile)
return err

if keyFile != "" {
if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil {
os.Remove(certFile)
return nil, nil, err
}
}

return nil
return certBuf.Bytes(), keyBuf.Bytes(), nil
}
121 changes: 115 additions & 6 deletions cert/selfsigned_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ const (
var (
extraIPs = []string{"1.1.1.1", "123.123.123.1", "199.189.12.12"}
extraDomains = []string{"home", "and", "away"}
privKeyBytes = [32]byte{
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9,
0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
}

privKey, _ = btcec.PrivKeyFromBytes(btcec.S256(),
privKeyBytes[:])
)

// TestIsOutdatedCert checks that we'll consider the TLS certificate outdated
Expand All @@ -26,11 +35,12 @@ func TestIsOutdatedCert(t *testing.T) {
t.Fatal(err)
}

keyRing := &mock.SecretKeyRing{}
certPath := tempDir + "/tls.cert"
keyPath := tempDir + "/tls.key"

// Generate TLS files with two extra IPs and domains.
err = cert.GenCertPair(
_, _, err = cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], false, testTLSCertDuration,
)
Expand All @@ -42,8 +52,17 @@ func TestIsOutdatedCert(t *testing.T) {
// number of IPs and domains.
for numIPs := 1; numIPs <= len(extraIPs); numIPs++ {
for numDomains := 1; numDomains <= len(extraDomains); numDomains++ {
certBytes, err := ioutil.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyBytes, err := ioutil.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}

_, parsedCert, err := cert.LoadCert(
certPath, keyPath,
certBytes, keyBytes,
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -81,18 +100,27 @@ func TestIsOutdatedPermutation(t *testing.T) {
t.Fatal(err)
}

keyRing := &mock.SecretKeyRing{}
certPath := tempDir + "/tls.cert"
keyPath := tempDir + "/tls.key"

// Generate TLS files from the IPs and domains.
err = cert.GenCertPair(
_, _, err = cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:],
extraDomains[:], false, testTLSCertDuration,
)
if err != nil {
t.Fatal(err)
}
_, parsedCert, err := cert.LoadCert(certPath, keyPath)
certBytes, err := ioutil.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyBytes, err := ioutil.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}
_, parsedCert, err := cert.LoadCert(certBytes, keyBytes)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -148,11 +176,12 @@ func TestTLSDisableAutofill(t *testing.T) {
t.Fatal(err)
}

keyRing := &mock.SecretKeyRing{}
certPath := tempDir + "/tls.cert"
keyPath := tempDir + "/tls.key"

// Generate TLS files with two extra IPs and domains and no interface IPs.
err = cert.GenCertPair(
_, _, err = cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, extraIPs[:2],
extraDomains[:2], true, testTLSCertDuration,
)
Expand All @@ -161,8 +190,19 @@ func TestTLSDisableAutofill(t *testing.T) {
"unable to generate tls certificate pair",
)

// Read certs from disk
certBytes, err := ioutil.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyBytes, err := ioutil.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}

// Load the certificate
_, parsedCert, err := cert.LoadCert(
certPath, keyPath,
certBytes, keyBytes,
)
require.NoError(
t, err,
Expand Down Expand Up @@ -195,3 +235,72 @@ func TestTLSDisableAutofill(t *testing.T) {
"TLS Certificate was not marked as outdated when it should be",
)
}

// TestTlsConfig tests to ensure we can generate a TLS Config from
// a tls cert and tls key.
func TestTlsConfig(t *testing.T) {
tempDir, err := ioutil.TempDir("", "certtest")
if err != nil {
t.Fatal(err)
}

certPath := tempDir + "/tls.cert"
keyPath := tempDir + "/tls.key"
keyRing := &mock.SecretKeyRing{}

// Generate TLS files with an extra IP and domain.
_, _, err = cert.GenCertPair(
"lnd autogenerated cert", certPath, keyPath, []string{extraIPs[0]},
[]string{extraDomains[0]}, false, cert.DefaultAutogenValidity,
)
if err != nil {
t.Fatal(err)
}

// Read certs from disk
certBytes, err := ioutil.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyBytes, err := ioutil.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}

// Load the certificate
certData, parsedCert, err := cert.LoadCert(
certBytes, keyBytes,
)
if err != nil {
t.Fatal(err)
}

// Check to make sure the IP and domain are in the cert
var foundDomain bool
var foundIp bool
for _, domain := range parsedCert.DNSNames {
if domain == extraDomains[0] {
foundDomain = true
break
}
}
for _, ip := range parsedCert.IPAddresses {
if ip.String() == extraIPs[0] {
foundIp = true
break
}
}
if !foundDomain || !foundIp {
t.Fatal(fmt.Errorf("Did not find required information inside "+
"of TLS Certificate. foundDomain: %v, foundIp: %v",
foundDomain, foundIp))
}

// Create TLS Config
tlsCfg := cert.TLSConfFromCert(certData)

if len(tlsCfg.Certificates) != 1 {
t.Fatal(fmt.Errorf("Found incorrect number of TLS certificates "+
"in TLS Config: %v", len(tlsCfg.Certificates)))
}
}
63 changes: 60 additions & 3 deletions cert/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cert
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"sync"
)

var (
Expand All @@ -24,17 +26,36 @@ var (
}
)

type TlsReloader struct {
certMu sync.RWMutex
cert *tls.Certificate
}

// GetCertBytesFromPath reads the TLS certificate and key files at the given
// certPath and keyPath and returns the file bytes.
func GetCertBytesFromPath(certPath, keyPath string) (certBytes, keyBytes []byte, err error) {
certBytes, err = ioutil.ReadFile(certPath)
if err != nil {
return nil, nil, err
}
keyBytes, err = ioutil.ReadFile(keyPath)
if err != nil {
return nil, nil, err
}
return certBytes, keyBytes, nil
}

// LoadCert loads a certificate and its corresponding private key from the PEM
// files indicated and returns the certificate in the two formats it is most
// bytes indicated and returns the certificate in the two formats it is most
// commonly used.
func LoadCert(certPath, keyPath string) (tls.Certificate, *x509.Certificate,
func LoadCert(certBytes, keyBytes []byte) (tls.Certificate, *x509.Certificate,
error) {

// The certData returned here is just a wrapper around the PEM blocks
// loaded from the file. The PEM is not yet fully parsed but a basic
// check is performed that the certificate and private key actually
// belong together.
certData, err := tls.LoadX509KeyPair(certPath, keyPath)
certData, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return tls.Certificate{}, nil, err
}
Expand All @@ -58,3 +79,39 @@ func TLSConfFromCert(certData tls.Certificate) *tls.Config {
MinVersion: tls.VersionTLS12,
}
}

// NewTLSReloader is used to create a new TLS Reloader that will be used
// to update the TLS certificate without restarting the server.
func NewTLSReloader(certBytes, keyBytes []byte) (*TlsReloader, error) {
result := &TlsReloader{}
cert, _, err := LoadCert(certBytes, keyBytes)
if err != nil {
return nil, err
}
result.cert = &cert
return result, nil
}

// AttemptReload will make an attempt to update the TLS certificate
// and key used by the server.
func (tlsr *TlsReloader) AttemptReload(certBytes, keyBytes []byte) error {
newCert, _, err := LoadCert(certBytes, keyBytes)
if err != nil {
return err
}
tlsr.certMu.Lock()
defer tlsr.certMu.Unlock()
tlsr.cert = &newCert
return nil
}

// GetCertificateFunc is used in the server's TLS configuration to
// determine the correct TLS certificate to server on a request.
func (tlsr *TlsReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (
*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
tlsr.certMu.RLock()
defer tlsr.certMu.RUnlock()
return tlsr.cert, nil
}
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ type Config struct {
TLSAutoRefresh bool `long:"tlsautorefresh" description:"Re-generate TLS certificate and key if the IPs or domains are changed"`
TLSDisableAutofill bool `long:"tlsdisableautofill" description:"Do not include the interface IPs or the system hostname in TLS certificate, use first --tlsextradomain as Common Name instead, if set"`
TLSCertDuration time.Duration `long:"tlscertduration" description:"The duration for which the auto-generated TLS certificate will be valid for"`
TLSEncryptKey bool `long:"tlsencryptkey" description:"Automatically encrypts the TLS private key and generates ephemeral TLS key pairs when the wallet is locked or not initialized"`

NoMacaroons bool `long:"no-macaroons" description:"Disable macaroon authentication, can only be used if server is not listening on a public interface."`
AdminMacPath string `long:"adminmacaroonpath" description:"Path to write the admin macaroon for lnd's RPC and REST services if it doesn't exist"`
Expand Down
Loading

0 comments on commit b9f2006

Please sign in to comment.