From b9f200645718d72c7cb6b51289084d11dc1a47eb Mon Sep 17 00:00:00 2001 From: Graham Krizek Date: Fri, 29 Oct 2021 20:43:24 -0500 Subject: [PATCH] config+lnd+cert: Add support in lnd for encrypting the TLS private key. This commit adds support in lnd to encrypt the TLS private key on disk with the wallet's seed. This obviously causes issues when the wallet is locked. So for the WalletUnlocker RPC we generate ephemeral TLS certificates with the key stored in memory. This feature is enabled with the --tlsencryptkey flag. --- cert/selfsigned.go | 33 ++++--- cert/selfsigned_test.go | 121 ++++++++++++++++++++++-- cert/tls.go | 63 +++++++++++- config.go | 1 + lnd.go | 205 ++++++++++++++++++++++++++++++++++------ server.go | 26 ++++- server_test.go | 4 +- 7 files changed, 400 insertions(+), 53 deletions(-) diff --git a/cert/selfsigned.go b/cert/selfsigned.go index a0ae23a71e..a97414cdf2 100644 --- a/cert/selfsigned.go +++ b/cert/selfsigned.go @@ -197,7 +197,7 @@ func IsOutdated(cert *x509.Certificate, tlsExtraIPs, // https://github.com/btcsuite/btcutil func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, tlsExtraDomains []string, tlsDisableAutofill bool, - certValidity time.Duration) error { + certValidity time.Duration) ([]byte, []byte, error) { now := time.Now() validUntil := now.Add(certValidity) @@ -210,7 +210,7 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, // Generate a serial number that's below the serialNumberLimit. serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return fmt.Errorf("failed to generate serial number: %s", err) + return nil, nil, fmt.Errorf("failed to generate serial number: %s", err) } // Get all DNS names and IP addresses to use when creating the @@ -218,13 +218,13 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, host, dnsNames := dnsNames(tlsExtraDomains, tlsDisableAutofill) ipAddresses, err := ipAddresses(tlsExtraIPs, tlsDisableAutofill) if err != nil { - return err + return nil, nil, err } // Generate a private key for the certificate. priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - return err + return nil, nil, err } // Construct the certificate template. @@ -250,35 +250,40 @@ func GenCertPair(org, certFile, keyFile string, tlsExtraIPs, derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - return fmt.Errorf("failed to create certificate: %v", err) + return nil, nil, fmt.Errorf("failed to create certificate: %v", err) } certBuf := &bytes.Buffer{} err = pem.Encode(certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) if err != nil { - return fmt.Errorf("failed to encode certificate: %v", err) + return nil, nil, fmt.Errorf("failed to encode certificate: %v", err) } keybytes, err := x509.MarshalECPrivateKey(priv) if err != nil { - return fmt.Errorf("unable to encode privkey: %v", err) + return nil, nil, fmt.Errorf("unable to encode privkey: %v", err) } keyBuf := &bytes.Buffer{} err = pem.Encode(keyBuf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keybytes}) if err != nil { - return fmt.Errorf("failed to encode private key: %v", err) + return nil, nil, fmt.Errorf("failed to encode private key: %v", err) } // Write cert and key files. - if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil { - return err + if certFile != "" { + if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil { + return nil, nil, err + } } - if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil { - os.Remove(certFile) - return err + + if keyFile != "" { + if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil { + os.Remove(certFile) + return nil, nil, err + } } - return nil + return certBuf.Bytes(), keyBuf.Bytes(), nil } diff --git a/cert/selfsigned_test.go b/cert/selfsigned_test.go index dd9953e2a3..35418e423d 100644 --- a/cert/selfsigned_test.go +++ b/cert/selfsigned_test.go @@ -16,6 +16,15 @@ const ( var ( extraIPs = []string{"1.1.1.1", "123.123.123.1", "199.189.12.12"} extraDomains = []string{"home", "and", "away"} + privKeyBytes = [32]byte{ + 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, + 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, + 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, + 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, + } + + privKey, _ = btcec.PrivKeyFromBytes(btcec.S256(), + privKeyBytes[:]) ) // TestIsOutdatedCert checks that we'll consider the TLS certificate outdated @@ -26,11 +35,12 @@ func TestIsOutdatedCert(t *testing.T) { t.Fatal(err) } + keyRing := &mock.SecretKeyRing{} certPath := tempDir + "/tls.cert" keyPath := tempDir + "/tls.key" // Generate TLS files with two extra IPs and domains. - err = cert.GenCertPair( + _, _, err = cert.GenCertPair( "lnd autogenerated cert", certPath, keyPath, extraIPs[:2], extraDomains[:2], false, testTLSCertDuration, ) @@ -42,8 +52,17 @@ func TestIsOutdatedCert(t *testing.T) { // number of IPs and domains. for numIPs := 1; numIPs <= len(extraIPs); numIPs++ { for numDomains := 1; numDomains <= len(extraDomains); numDomains++ { + certBytes, err := ioutil.ReadFile(certPath) + if err != nil { + t.Fatal(err) + } + keyBytes, err := ioutil.ReadFile(keyPath) + if err != nil { + t.Fatal(err) + } + _, parsedCert, err := cert.LoadCert( - certPath, keyPath, + certBytes, keyBytes, ) if err != nil { t.Fatal(err) @@ -81,18 +100,27 @@ func TestIsOutdatedPermutation(t *testing.T) { t.Fatal(err) } + keyRing := &mock.SecretKeyRing{} certPath := tempDir + "/tls.cert" keyPath := tempDir + "/tls.key" // Generate TLS files from the IPs and domains. - err = cert.GenCertPair( + _, _, err = cert.GenCertPair( "lnd autogenerated cert", certPath, keyPath, extraIPs[:], extraDomains[:], false, testTLSCertDuration, ) if err != nil { t.Fatal(err) } - _, parsedCert, err := cert.LoadCert(certPath, keyPath) + certBytes, err := ioutil.ReadFile(certPath) + if err != nil { + t.Fatal(err) + } + keyBytes, err := ioutil.ReadFile(keyPath) + if err != nil { + t.Fatal(err) + } + _, parsedCert, err := cert.LoadCert(certBytes, keyBytes) if err != nil { t.Fatal(err) } @@ -148,11 +176,12 @@ func TestTLSDisableAutofill(t *testing.T) { t.Fatal(err) } + keyRing := &mock.SecretKeyRing{} certPath := tempDir + "/tls.cert" keyPath := tempDir + "/tls.key" // Generate TLS files with two extra IPs and domains and no interface IPs. - err = cert.GenCertPair( + _, _, err = cert.GenCertPair( "lnd autogenerated cert", certPath, keyPath, extraIPs[:2], extraDomains[:2], true, testTLSCertDuration, ) @@ -161,8 +190,19 @@ func TestTLSDisableAutofill(t *testing.T) { "unable to generate tls certificate pair", ) + // Read certs from disk + certBytes, err := ioutil.ReadFile(certPath) + if err != nil { + t.Fatal(err) + } + keyBytes, err := ioutil.ReadFile(keyPath) + if err != nil { + t.Fatal(err) + } + + // Load the certificate _, parsedCert, err := cert.LoadCert( - certPath, keyPath, + certBytes, keyBytes, ) require.NoError( t, err, @@ -195,3 +235,72 @@ func TestTLSDisableAutofill(t *testing.T) { "TLS Certificate was not marked as outdated when it should be", ) } + +// TestTlsConfig tests to ensure we can generate a TLS Config from +// a tls cert and tls key. +func TestTlsConfig(t *testing.T) { + tempDir, err := ioutil.TempDir("", "certtest") + if err != nil { + t.Fatal(err) + } + + certPath := tempDir + "/tls.cert" + keyPath := tempDir + "/tls.key" + keyRing := &mock.SecretKeyRing{} + + // Generate TLS files with an extra IP and domain. + _, _, err = cert.GenCertPair( + "lnd autogenerated cert", certPath, keyPath, []string{extraIPs[0]}, + []string{extraDomains[0]}, false, cert.DefaultAutogenValidity, + ) + if err != nil { + t.Fatal(err) + } + + // Read certs from disk + certBytes, err := ioutil.ReadFile(certPath) + if err != nil { + t.Fatal(err) + } + keyBytes, err := ioutil.ReadFile(keyPath) + if err != nil { + t.Fatal(err) + } + + // Load the certificate + certData, parsedCert, err := cert.LoadCert( + certBytes, keyBytes, + ) + if err != nil { + t.Fatal(err) + } + + // Check to make sure the IP and domain are in the cert + var foundDomain bool + var foundIp bool + for _, domain := range parsedCert.DNSNames { + if domain == extraDomains[0] { + foundDomain = true + break + } + } + for _, ip := range parsedCert.IPAddresses { + if ip.String() == extraIPs[0] { + foundIp = true + break + } + } + if !foundDomain || !foundIp { + t.Fatal(fmt.Errorf("Did not find required information inside "+ + "of TLS Certificate. foundDomain: %v, foundIp: %v", + foundDomain, foundIp)) + } + + // Create TLS Config + tlsCfg := cert.TLSConfFromCert(certData) + + if len(tlsCfg.Certificates) != 1 { + t.Fatal(fmt.Errorf("Found incorrect number of TLS certificates "+ + "in TLS Config: %v", len(tlsCfg.Certificates))) + } +} diff --git a/cert/tls.go b/cert/tls.go index a8783158e1..6d90f2896d 100644 --- a/cert/tls.go +++ b/cert/tls.go @@ -3,6 +3,8 @@ package cert import ( "crypto/tls" "crypto/x509" + "io/ioutil" + "sync" ) var ( @@ -24,17 +26,36 @@ var ( } ) +type TlsReloader struct { + certMu sync.RWMutex + cert *tls.Certificate +} + +// GetCertBytesFromPath reads the TLS certificate and key files at the given +// certPath and keyPath and returns the file bytes. +func GetCertBytesFromPath(certPath, keyPath string) (certBytes, keyBytes []byte, err error) { + certBytes, err = ioutil.ReadFile(certPath) + if err != nil { + return nil, nil, err + } + keyBytes, err = ioutil.ReadFile(keyPath) + if err != nil { + return nil, nil, err + } + return certBytes, keyBytes, nil +} + // LoadCert loads a certificate and its corresponding private key from the PEM -// files indicated and returns the certificate in the two formats it is most +// bytes indicated and returns the certificate in the two formats it is most // commonly used. -func LoadCert(certPath, keyPath string) (tls.Certificate, *x509.Certificate, +func LoadCert(certBytes, keyBytes []byte) (tls.Certificate, *x509.Certificate, error) { // The certData returned here is just a wrapper around the PEM blocks // loaded from the file. The PEM is not yet fully parsed but a basic // check is performed that the certificate and private key actually // belong together. - certData, err := tls.LoadX509KeyPair(certPath, keyPath) + certData, err := tls.X509KeyPair(certBytes, keyBytes) if err != nil { return tls.Certificate{}, nil, err } @@ -58,3 +79,39 @@ func TLSConfFromCert(certData tls.Certificate) *tls.Config { MinVersion: tls.VersionTLS12, } } + +// NewTLSReloader is used to create a new TLS Reloader that will be used +// to update the TLS certificate without restarting the server. +func NewTLSReloader(certBytes, keyBytes []byte) (*TlsReloader, error) { + result := &TlsReloader{} + cert, _, err := LoadCert(certBytes, keyBytes) + if err != nil { + return nil, err + } + result.cert = &cert + return result, nil +} + +// AttemptReload will make an attempt to update the TLS certificate +// and key used by the server. +func (tlsr *TlsReloader) AttemptReload(certBytes, keyBytes []byte) error { + newCert, _, err := LoadCert(certBytes, keyBytes) + if err != nil { + return err + } + tlsr.certMu.Lock() + defer tlsr.certMu.Unlock() + tlsr.cert = &newCert + return nil +} + +// GetCertificateFunc is used in the server's TLS configuration to +// determine the correct TLS certificate to server on a request. +func (tlsr *TlsReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) ( + *tls.Certificate, error) { + return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + tlsr.certMu.RLock() + defer tlsr.certMu.RUnlock() + return tlsr.cert, nil + } +} diff --git a/config.go b/config.go index c3f26cd69f..c2f03d8c82 100644 --- a/config.go +++ b/config.go @@ -244,6 +244,7 @@ type Config struct { TLSAutoRefresh bool `long:"tlsautorefresh" description:"Re-generate TLS certificate and key if the IPs or domains are changed"` TLSDisableAutofill bool `long:"tlsdisableautofill" description:"Do not include the interface IPs or the system hostname in TLS certificate, use first --tlsextradomain as Common Name instead, if set"` TLSCertDuration time.Duration `long:"tlscertduration" description:"The duration for which the auto-generated TLS certificate will be valid for"` + TLSEncryptKey bool `long:"tlsencryptkey" description:"Automatically encrypts the TLS private key and generates ephemeral TLS key pairs when the wallet is locked or not initialized"` NoMacaroons bool `long:"no-macaroons" description:"Disable macaroon authentication, can only be used if server is not listening on a public interface."` AdminMacPath string `long:"adminmacaroonpath" description:"Path to write the admin macaroon for lnd's RPC and REST services if it doesn't exist"` diff --git a/lnd.go b/lnd.go index 13d7fa2540..23c40e1b9a 100644 --- a/lnd.go +++ b/lnd.go @@ -5,8 +5,10 @@ package lnd import ( + "bytes" "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "io/ioutil" @@ -35,6 +37,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/macaroons" @@ -217,8 +220,18 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, return mkErr("error initializing DBs: %v", err) } + // The real KeyRing isn't available until after the wallet is unlocked, + // but we need one now if --tlsencryptkey is true. Because we aren't + // encrypting anything here it can be an empty KeyRing. + var emptyKeyRing keychain.KeyRing + // Only process macaroons if --no-macaroons isn't set. - serverOpts, restDialOpts, restListen, cleanUp, err := getTLSConfig(cfg) + serverOpts, + restDialOpts, + restListen, + cleanUp, + tlsReloader, + err := getTLSConfig(cfg, emptyKeyRing) if err != nil { return mkErr("unable to load TLS credentials: %v", err) } @@ -528,6 +541,39 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, } defer atplManager.Stop() + // If --tlsencryptkey is set, we previously generated a throwaway TLSConfig + // Now we want to remove that and load the persistent TLSConfig + // The wallet is unlocked at this point so we can use the real KeyRing + if cfg.TLSEncryptKey { + tmpCertPath := cfg.TLSCertPath + ".tmp" + err = os.Remove(tmpCertPath) + if err != nil { + ltndLog.Warn("unable to delete temp cert at %v", tmpCertPath) + } + + // Ensure the persistent TLS credentials are created + _, _, _, _, _, err = getTLSConfig(cfg, activeChainControl.KeyRing) + if err != nil { + err := fmt.Errorf("unable to load TLS credentials: %v", err) + ltndLog.Error(err) + return err + } + certBytes, keyBytes, err := cert.GetCertBytesFromPath(cfg.TLSCertPath, cfg.TLSKeyPath) + if err != nil { + return err + } + reader := bytes.NewReader(keyBytes) + keyBytes, err = lnencrypt.DecryptPayloadFromReader(reader, activeChainControl.KeyRing) + if err != nil { + return err + } + // Switch the server's TLS certificate to the persistent one + err = tlsReloader.AttemptReload(certBytes, keyBytes) + if err != nil { + return err + } + } + // Now we have created all dependencies necessary to populate and // start the RPC server. err = rpcServer.addDeps( @@ -621,28 +667,125 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // getTLSConfig returns a TLS configuration for the gRPC server and credentials // and a proxy destination for the REST reverse proxy. -func getTLSConfig(cfg *Config) ([]grpc.ServerOption, []grpc.DialOption, - func(net.Addr) (net.Listener, error), func(), error) { +func getTLSConfig(cfg *Config, keyRing keychain.KeyRing) ( + []grpc.ServerOption, []grpc.DialOption, + func(net.Addr) (net.Listener, error), func(), *cert.TlsReloader, error) { + + // If TLS Key Encryption is on but the keyring is empty then this + // is a temporary certificate. + var emptyKeyRing keychain.KeyRing + var certData tls.Certificate + var parsedCert *x509.Certificate + var certBytes []byte + var keyBytes []byte + var err error + var certPath string + if cfg.TLSEncryptKey && (keyRing == emptyKeyRing) { + + rpcsLog.Infof("Generating ephemeral TLS certificates...") + tmpValidity := 24 * time.Hour + // Append .tmp to the end of the cert for differentiation. + tmpCertPath := cfg.TLSCertPath + ".tmp" + certPath = tmpCertPath + // Pass in a blank string for the key path so the + // function doesn't write them to disk. + certBytes, keyBytes, err = cert.GenCertPair( + "lnd temporary autogenerated cert", tmpCertPath, + "", cfg.TLSExtraIPs, cfg.TLSExtraDomains, + cfg.TLSDisableAutofill, tmpValidity, + ) + if err != nil { + return nil, nil, nil, nil, nil, err + } + rpcsLog.Infof("Done generating ephemeral TLS certificates") - // Ensure we create TLS key and certificate if they don't exist. - if !fileExists(cfg.TLSCertPath) && !fileExists(cfg.TLSKeyPath) { - rpcsLog.Infof("Generating TLS certificates...") - err := cert.GenCertPair( - "lnd autogenerated cert", cfg.TLSCertPath, - cfg.TLSKeyPath, cfg.TLSExtraIPs, cfg.TLSExtraDomains, - cfg.TLSDisableAutofill, cfg.TLSCertDuration, + certData, parsedCert, err = cert.LoadCert( + certBytes, keyBytes, ) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } - rpcsLog.Infof("Done generating TLS certificates") - } + } else { + + // Ensure we create TLS key and certificate if they don't exist. + if !fileExists(cfg.TLSCertPath) && !fileExists(cfg.TLSKeyPath) { + rpcsLog.Infof("Generating TLS certificates...") + certBytes, keyBytes, err = cert.GenCertPair( + "lnd autogenerated cert", cfg.TLSCertPath, + "", cfg.TLSExtraIPs, cfg.TLSExtraDomains, + cfg.TLSDisableAutofill, cfg.TLSCertDuration, + ) + if err != nil { + return nil, nil, nil, nil, nil, err + } - certData, parsedCert, err := cert.LoadCert( - cfg.TLSCertPath, cfg.TLSKeyPath, - ) - if err != nil { - return nil, nil, nil, nil, err + if cfg.TLSEncryptKey { + keyBuf := bytes.NewBuffer(keyBytes) + var b bytes.Buffer + err = lnencrypt.EncryptPayloadToWriter(*keyBuf, &b, keyRing) + if err != nil { + return nil, nil, nil, nil, nil, err + } + if err = ioutil.WriteFile(cfg.TLSKeyPath, b.Bytes(), + 0600); err != nil { + return nil, nil, nil, nil, nil, err + } + } else { + keyBuf := bytes.NewBuffer(keyBytes) + if err = ioutil.WriteFile(cfg.TLSKeyPath, keyBuf.Bytes(), + 0600); err != nil { + return nil, nil, nil, nil, nil, err + } + } + rpcsLog.Infof("Done generating TLS certificates") + } else { + rpcsLog.Info("Gettings the existing cert data") + certBytes, keyBytes, err = cert.GetCertBytesFromPath(cfg.TLSCertPath, cfg.TLSKeyPath) + if err != nil { + return nil, nil, nil, nil, nil, err + } + } + + // We check to see if the private key is encrypted or plaintext. + // If it's encrypted we need to try to decrypt it so we can use it + // in the gRPC server. + privateKeyPrefix := []byte("-----BEGIN EC PRIVATE KEY-----") + if !bytes.HasPrefix(keyBytes, privateKeyPrefix) { + rpcsLog.Info("Saw that the private key is encrypted") + // If the private key is encrypted but the user didn't pass + // --tlsencryptkey we error out. This is because the wallet is not + // unlocked yet and we don't have access to the keys yet for decrypt. + if !cfg.TLSEncryptKey { + return nil, nil, nil, nil, nil, fmt.Errorf("it appears the TLS key is " + + "encrypted but you didn't pass the --tlsencryptkey flag. " + + "Please restart lnd with the --tlsencryptkey flag or delete " + + "the TLS files for regeneration") + } + reader := bytes.NewReader(keyBytes) + keyBytes, err = lnencrypt.DecryptPayloadFromReader(reader, keyRing) + if err != nil { + return nil, nil, nil, nil, nil, err + } + } else if cfg.TLSEncryptKey { + // If the user requests an encrypted key but the key is in plaintext + // we encrypt the key before writing to disk. + keyBuf := bytes.NewBuffer(keyBytes) + var b bytes.Buffer + err = lnencrypt.EncryptPayloadToWriter(*keyBuf, &b, keyRing) + if err != nil { + return nil, nil, nil, nil, nil, err + } + if err = ioutil.WriteFile(cfg.TLSKeyPath, b.Bytes(), 0600); err != nil { + return nil, nil, nil, nil, nil, err + } + } + certPath = cfg.TLSCertPath + certData, parsedCert, err = cert.LoadCert( + certBytes, keyBytes, + ) + if err != nil { + return nil, nil, nil, nil, nil, err + } } // We check whether the certificate we have on disk match the IPs and @@ -656,7 +799,7 @@ func getTLSConfig(cfg *Config) ([]grpc.ServerOption, []grpc.DialOption, cfg.TLSExtraDomains, cfg.TLSDisableAutofill, ) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } } @@ -668,39 +811,45 @@ func getTLSConfig(cfg *Config) ([]grpc.ServerOption, []grpc.DialOption, err := os.Remove(cfg.TLSCertPath) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } err = os.Remove(cfg.TLSKeyPath) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } rpcsLog.Infof("Renewing TLS certificates...") - err = cert.GenCertPair( + certBytes, keyBytes, err = cert.GenCertPair( "lnd autogenerated cert", cfg.TLSCertPath, cfg.TLSKeyPath, cfg.TLSExtraIPs, cfg.TLSExtraDomains, cfg.TLSDisableAutofill, cfg.TLSCertDuration, ) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } rpcsLog.Infof("Done renewing TLS certificates") // Reload the certificate data. certData, _, err = cert.LoadCert( - cfg.TLSCertPath, cfg.TLSKeyPath, + certBytes, keyBytes, ) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } } + tlsr, err := cert.NewTLSReloader(certBytes, keyBytes) + if err != nil { + return nil, nil, nil, nil, nil, err + } + tlsCfg := cert.TLSConfFromCert(certData) + tlsCfg.GetCertificate = tlsr.GetCertificateFunc() - restCreds, err := credentials.NewClientTLSFromFile(cfg.TLSCertPath, "") + restCreds, err := credentials.NewClientTLSFromFile(certPath, "") if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } // If Let's Encrypt is enabled, instantiate autocert to request/renew @@ -787,7 +936,7 @@ func getTLSConfig(cfg *Config) ([]grpc.ServerOption, []grpc.DialOption, return lncfg.TLSListenOnAddress(addr, tlsCfg) } - return serverOpts, restDialOpts, restListen, cleanUp, nil + return serverOpts, restDialOpts, restListen, cleanUp, tlsr, nil } // fileExists reports whether the named file or directory exists. diff --git a/server.go b/server.go index 61687c42f8..4d67918a9f 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "fmt" "image/color" + "io/ioutil" "math/big" prand "math/rand" "net" @@ -47,6 +48,7 @@ import ( "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" @@ -1554,8 +1556,30 @@ func (s *server) createLivenessMonitor(cfg *Config, cc *chainreg.ChainControl) { tlsHealthCheck := healthcheck.NewObservation( "tls", func() error { + + var emptyKeyRing keychain.KeyRing + certBytes, err := ioutil.ReadFile(cfg.TLSCertPath) + if err != nil { + return err + } + keyBytes, err := ioutil.ReadFile(cfg.TLSKeyPath) + if err != nil { + return err + } + + // If key encryption is set, then decrypt the file. + // We don't need to do a file type check here because GenCertPair + // has been ran with the same value for cfg.TLSEncryptKey. + if cfg.TLSEncryptKey { + reader := bytes.NewReader(keyBytes) + keyBytes, err = lnencrypt.DecryptPayloadFromReader(reader, emptyKeyRing) + if err != nil { + return err + } + } + _, parsedCert, err := cert.LoadCert( - cfg.TLSCertPath, cfg.TLSKeyPath, + certBytes, keyBytes, ) if err != nil { return err diff --git a/server_test.go b/server_test.go index 4ac7adec7b..b5665005e9 100644 --- a/server_test.go +++ b/server_test.go @@ -20,6 +20,7 @@ import ( "time" "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/lntest/mock" ) func TestParseHexColor(t *testing.T) { @@ -123,7 +124,8 @@ func TestTLSAutoRegeneration(t *testing.T) { TLSCertDuration: 42 * time.Hour, RPCListeners: rpcListeners, } - _, _, _, cleanUp, err := getTLSConfig(cfg) + keyRing := &mock.SecretKeyRing{} + _, _, _, cleanUp, _, err := getTLSConfig(cfg, keyRing) if err != nil { t.Fatalf("couldn't retrieve TLS config") }