diff --git a/management/cmd/management.go b/management/cmd/management.go index 78b1a8d631f..719d1a78c1a 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} - _, err := util.ReadJson(mgmtConfigPath, loadedConfig) + _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) if err != nil { return nil, err } diff --git a/util/file.go b/util/file.go index 8355488c98a..ecaecd22260 100644 --- a/util/file.go +++ b/util/file.go @@ -1,11 +1,15 @@ package util import ( + "bytes" "context" "encoding/json" + "fmt" "io" "os" "path/filepath" + "strings" + "text/template" log "github.com/sirupsen/logrus" ) @@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution +func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { + envVars := getEnvMap() + + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + bs, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + t, err := template.New("").Parse(string(bs)) + if err != nil { + return nil, fmt.Errorf("error parsing template: %v", err) + } + + var output bytes.Buffer + // Execute the template, substituting environment variables + err = t.Execute(&output, envVars) + if err != nil { + return nil, fmt.Errorf("error executing template: %v", err) + } + + err = json.Unmarshal(output.Bytes(), &res) + if err != nil { + return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err) + } + + return res, nil +} + +// getEnvMap Convert the output of os.Environ() to a map +func getEnvMap() map[string]string { + envMap := make(map[string]string) + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + return envMap +} + // CopyFileContents copies contents of the given src file to the dst file func CopyFileContents(src, dst string) (err error) { in, err := os.Open(src) diff --git a/util/file_suite_test.go b/util/file_suite_test.go new file mode 100644 index 00000000000..3de7db49bdd --- /dev/null +++ b/util/file_suite_test.go @@ -0,0 +1,126 @@ +package util_test + +import ( + "crypto/md5" + "encoding/hex" + "io" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/netbirdio/netbird/util" +) + +var _ = Describe("Client", func() { + + var ( + tmpDir string + ) + + type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int + } + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + Describe("Config", func() { + Context("in JSON format", func() { + It("should be written and read successfully", func() { + + m := make(map[string]string) + m["key1"] = "value1" + m["key2"] = "value2" + + arr := []string{"value1", "value2"} + + written := &TestConfig{ + SomeMap: m, + SomeArray: arr, + SomeField: 99, + } + + err := util.WriteJson(tmpDir+"/testconfig.json", written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) + Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) + Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) + Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + + }) + }) + }) + + Describe("Copying file contents", func() { + Context("from one file to another", func() { + It("should be successful", func() { + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := util.WriteJson(src, []string{"1", "2", "3"}) + Expect(err).NotTo(HaveOccurred()) + + err = util.CopyFileContents(src, dst) + Expect(err).NotTo(HaveOccurred()) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + Expect(err).NotTo(HaveOccurred()) + + dstFile, err := os.Open(dst) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashSrc, srcFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashDst, dstFile) + Expect(err).NotTo(HaveOccurred()) + + err = srcFile.Close() + Expect(err).NotTo(HaveOccurred()) + + err = dstFile.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) + }) + }) + }) + + Describe("Handle config file without full path", func() { + Context("config file handling", func() { + It("should be successful", func() { + written := &TestConfig{ + SomeField: 123, + } + cfgFile := "test_cfg.json" + defer os.Remove(cfgFile) + + err := util.WriteJson(cfgFile, written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(cfgFile, &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + }) + }) + }) +}) diff --git a/util/file_test.go b/util/file_test.go index 3de7db49bdd..1330e738e8d 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,126 +1,198 @@ -package util_test +package util import ( - "crypto/md5" - "encoding/hex" - "io" "os" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" + "reflect" + "strings" + "testing" ) -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int +func TestReadJsonWithEnvSub(t *testing.T) { + type Config struct { + CertFile string `json:"CertFile"` + Credentials string `json:"Credentials"` + NestedOption struct { + URL string `json:"URL"` + } `json:"NestedOption"` } - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) - - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { - - m := make(map[string]string) - m["key1"] = "value1" - m["key2"] = "value2" + type testCase struct { + name string + envVars map[string]string + jsonTemplate string + expectedResult Config + expectError bool + errorContains string + } - arr := []string{"value1", "value2"} + tests := []testCase{ + { + name: "All environment variables set", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://env.testing.com", + }, + }, + expectError: false, + }, + { + name: "Missing environment variable", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + // "URL" is intentionally missing + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "", + }, + }, + expectError: false, + }, + { + name: "Invalid JSON template", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, // Note the missing closing brace in "{{ .CREDENTIALS }" + expectedResult: Config{}, + expectError: true, + errorContains: "unexpected \"}\" in operand", + }, + { + name: "No substitutions", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "/etc/certs/cert.crt", + "Credentials": "admnlknflkdasdf", + "NestedOption" : { + "URL": "https://testing.com" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/cert.crt", + Credentials: "admnlknflkdasdf", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://testing.com", + }, + }, + expectError: false, + }, + { + name: "Should fail when Invalid characters in variables", + envVars: map[string]string{ + "CERT_FILE": `"/etc/certs/"cert".crt"`, + "CREDENTIALS": `env_credentia{ls}`, + "URL": `https://env.testing.com?param={{value}}`, + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: `"/etc/certs/"cert".crt"`, + Credentials: `env_credentia{ls}`, + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: `https://env.testing.com?param={{value}}`, + }, + }, + expectError: true, + }, + } - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for key, value := range tc.envVars { + t.Setenv(key, value) + } + + tempFile, err := os.CreateTemp("", "config*.json") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + defer func() { + err = os.Remove(tempFile.Name()) + if err != nil { + t.Logf("Failed to remove temp file: %v", err) } + }() - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) - - }) - }) - }) - - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) + _, err = tempFile.WriteString(tc.jsonTemplate) + if err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + err = tempFile.Close() + if err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) + var result Config - hashSrc := md5.New() - hashDst := md5.New() + _, err = ReadJsonWithEnvSub(tempFile.Name(), &result) - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, + if tc.expectError { + if err == nil { + t.Fatalf("Expected error but got none") } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) + if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err) + } + } else { + if err != nil { + t.Fatalf("ReadJsonWithEnvSub failed: %v", err) + } + if !reflect.DeepEqual(result, tc.expectedResult) { + t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult) + } + } }) - }) -}) + } +}