From 7f04bcfda6bd766cd7135a8c3a5847b1526f97c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Fabianski?= Date: Mon, 18 Dec 2023 14:25:05 +0100 Subject: [PATCH] tests: improve test coverage for flags --- internal/flag/options.go | 36 ++-- internal/flag/options_test.go | 46 +--- internal/flag/repository_flags_test.go | 287 +++++++++++++++++++++++++ internal/flag/scan_flags.go | 9 +- internal/flag/test_helper.go | 48 +++++ 5 files changed, 371 insertions(+), 55 deletions(-) create mode 100644 internal/flag/repository_flags_test.go create mode 100644 internal/flag/test_helper.go diff --git a/internal/flag/options.go b/internal/flag/options.go index 397a998d5..42fb71539 100644 --- a/internal/flag/options.go +++ b/internal/flag/options.go @@ -18,6 +18,8 @@ import ( type Flags []flagtypes.FlagGroup +const envPrefix = "bearer" + var ErrInvalidScannerReportCombination = errors.New("invalid scanner argument; privacy report requires sast scanner") type flagGroupBase struct { @@ -46,6 +48,26 @@ func addFlag(cmd *cobra.Command, flag *flagtypes.Flag) { } } +func BindViper(flag *flagtypes.Flag) error { + arguments := append( + []string{ + flag.ConfigName, + strings.ToUpper( + strings.Join( + []string{ + envPrefix, + strings.ReplaceAll(flag.Name, "-", "_"), + }, + "_", + ), + ), + }, + flag.EnvironmentVariables..., + ) + + return viper.BindEnv(arguments...) +} + func bind(cmd *cobra.Command, flag *flagtypes.Flag) error { if flag == nil { return nil @@ -59,19 +81,7 @@ func bind(cmd *cobra.Command, flag *flagtypes.Flag) error { return err } - viper.AutomaticEnv() - viper.SetEnvPrefix("bearer") - viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_", ".", "_")) - arguments := append( - []string{flag.ConfigName}, - flag.EnvironmentVariables..., - ) - - if err := viper.BindEnv(arguments...); err != nil { - return err - } - - return nil + return BindViper(flag) } func argsToMap(flag *flagtypes.Flag) map[string]bool { diff --git a/internal/flag/options_test.go b/internal/flag/options_test.go index ce4d1f281..66bc0a892 100644 --- a/internal/flag/options_test.go +++ b/internal/flag/options_test.go @@ -2,32 +2,18 @@ package flag import ( "testing" - - flagtypes "github.com/bearer/bearer/internal/flag/types" - "github.com/spf13/viper" - "github.com/stretchr/testify/assert" ) func Test_getStringSlice(t *testing.T) { - type env struct { - key string - value string - } - tests := []struct { - name string - flag *flagtypes.Flag - flagValue interface{} - env env - want []string - }{ + testCases := []TestCase{ { - name: "happy path. Empty value", + name: "Empty value", flag: ScannerFlag, flagValue: "", want: nil, }, { - name: "happy path. String value", + name: "String value", flag: ScannerFlag, flagValue: "sast,secrets", want: []string{ @@ -36,7 +22,7 @@ func Test_getStringSlice(t *testing.T) { }, }, { - name: "happy path. Slice value", + name: "Slice value", flag: ScannerFlag, flagValue: []string{ "sast", @@ -48,11 +34,11 @@ func Test_getStringSlice(t *testing.T) { }, }, { - name: "happy path. Env value", + name: "Env value", flag: ScannerFlag, - env: env{ + env: Env{ key: "BEARER_SCANNER", - value: "secrets,sast", + value: "sast,secrets", }, want: []string{ string(ScannerSAST), @@ -61,21 +47,5 @@ func Test_getStringSlice(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.env.key == "" { - viper.Set(tt.flag.ConfigName, tt.flagValue) - } else { - // err := viper.BindEnv(tt.flag.ConfigName, tt.env.key) - // assert.NoError(t, err) - - t.Setenv(tt.env.key, tt.env.value) - } - - sl := getStringSlice(tt.flag) - assert.Equal(t, tt.want, sl) - - viper.Reset() - }) - } + RunFlagTests(testCases, t) } diff --git a/internal/flag/repository_flags_test.go b/internal/flag/repository_flags_test.go new file mode 100644 index 000000000..1eaa15f8e --- /dev/null +++ b/internal/flag/repository_flags_test.go @@ -0,0 +1,287 @@ +package flag + +import ( + "testing" +) + +func Test_getRepositoryURLFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository URL. Default", + flag: RepositoryURLFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository URL. ORIGIN_URL env", + flag: RepositoryURLFlag, + env: Env{ + key: "ORIGIN_URL", + value: "https://example.com", + }, + want: []string{ + string("https://example.com"), + }, + }, + { + name: "Repository URL. CI_REPOSITORY_URL env", + flag: RepositoryURLFlag, + env: Env{ + key: "CI_REPOSITORY_URL", + value: "https://example.com", + }, + want: []string{ + string("https://example.com"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryBranchFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository Branch. Default", + flag: BranchFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository Branch. CURRENT_BRANCH env", + flag: BranchFlag, + env: Env{ + key: "CURRENT_BRANCH", + value: "main", + }, + want: []string{ + string("main"), + }, + }, + { + name: "Repository Branch. CI_COMMIT_REF_NAME env", + flag: BranchFlag, + env: Env{ + key: "CI_COMMIT_REF_NAME", + value: "main", + }, + want: []string{ + string("main"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryCommitFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository Commit. Default", + flag: CommitFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository Commit. SHA env", + flag: CommitFlag, + env: Env{ + key: "SHA", + value: "abc123", + }, + want: []string{ + string("abc123"), + }, + }, + { + name: "Repository Commit. CI_COMMIT_SHA env", + flag: CommitFlag, + env: Env{ + key: "CI_COMMIT_SHA", + value: "abc123", + }, + want: []string{ + string("abc123"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryDefaultBranchFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository Default Branch. Default", + flag: DefaultBranchFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository Default Branch. DEFAULT_BRANCH env", + flag: DefaultBranchFlag, + env: Env{ + key: "DEFAULT_BRANCH", + value: "main", + }, + want: []string{ + string("main"), + }, + }, + { + name: "Repository Default Branch. CI_DEFAULT_BRANCH env", + flag: DefaultBranchFlag, + env: Env{ + key: "CI_DEFAULT_BRANCH", + value: "main", + }, + want: []string{ + string("main"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryDiffBaseBranchFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository Diff Base Branch. Default", + flag: DiffBaseBranchFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository Diff Base Branch. DIFF_BASE_BRANCH env", + flag: DiffBaseBranchFlag, + env: Env{ + key: "DIFF_BASE_BRANCH", + value: "main", + }, + want: []string{ + string("main"), + }, + }, + { + name: "Repository Diff Base Branch. CI_MERGE_REQUEST_TARGET_BRANCH_NAME env", + flag: DiffBaseBranchFlag, + env: Env{ + key: "CI_MERGE_REQUEST_TARGET_BRANCH_NAME", + value: "main", + }, + want: []string{ + string("main"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryDiffBaseCommitFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository Diff Base Commit. Default", + flag: DiffBaseCommitFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository Diff Base Commit. DIFF_BASE_COMMIT env", + flag: DiffBaseCommitFlag, + env: Env{ + key: "DIFF_BASE_COMMIT", + value: "abc123", + }, + want: []string{ + string("abc123"), + }, + }, + { + name: "Repository Diff Base Commit. CI_MERGE_REQUEST_DIFF_BASE_SHA env", + flag: DiffBaseCommitFlag, + env: Env{ + key: "CI_MERGE_REQUEST_DIFF_BASE_SHA", + value: "abc123", + }, + want: []string{ + string("abc123"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryGithubTokenFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository GithubTokenFlag. Default", + flag: GithubTokenFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository GithubTokenFlag. GITHUB_TOKEN env", + flag: GithubTokenFlag, + env: Env{ + key: "GITHUB_TOKEN", + value: "abc123", + }, + want: []string{ + string("abc123"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryGithubRepositoryFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository GithubRepositoryFlag. Default", + flag: GithubRepositoryFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository GithubRepositoryFlag. GITHUB_REPOSITORY env", + flag: GithubRepositoryFlag, + env: Env{ + key: "GITHUB_REPOSITORY", + value: "Bearer/bearer", + }, + want: []string{ + string("Bearer/bearer"), + }, + }, + } + + RunFlagTests(testCases, t) +} + +func Test_getRepositoryGithubAPIURLFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository GithubAPIURLFlag. Default", + flag: GithubAPIURLFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository GithubAPIURLFlag. GITHUB_API_URL env", + flag: GithubAPIURLFlag, + env: Env{ + key: "GITHUB_API_URL", + value: "https://github.com/bearer/bearer", + }, + want: []string{ + string("https://github.com/bearer/bearer"), + }, + }, + } + + RunFlagTests(testCases, t) +} diff --git a/internal/flag/scan_flags.go b/internal/flag/scan_flags.go index 965b9020f..4c21a74b5 100644 --- a/internal/flag/scan_flags.go +++ b/internal/flag/scan_flags.go @@ -89,10 +89,11 @@ var ( Usage: "Specify directories paths that contain .yaml files with external rules configuration", }) ScannerFlag = ScanFlagGroup.add(flagtypes.Flag{ - Name: "scanner", - ConfigName: "scan.scanner", - Value: []string{ScannerSAST}, - Usage: "Specify which scanner to use e.g. --scanner=secrets, --scanner=secrets,sast", + Name: "scanner", + ConfigName: "scan.scanner", + Value: []string{ScannerSAST}, + Usage: "Specify which scanner to use e.g. --scanner=secrets, --scanner=secrets,sast", + EnvironmentVariables: []string{"SCANNER"}, }) ParallelFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "parallel", diff --git a/internal/flag/test_helper.go b/internal/flag/test_helper.go new file mode 100644 index 000000000..cb3fd8fc8 --- /dev/null +++ b/internal/flag/test_helper.go @@ -0,0 +1,48 @@ +package flag + +import ( + "testing" + + flagtypes "github.com/bearer/bearer/internal/flag/types" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +type Env struct { + key string + value string +} + +type TestCase struct { + name string + flag *flagtypes.Flag + flagValue interface{} + env Env + want []string +} + +func RunFlagTest(testCase TestCase, t *testing.T) { + t.Run(testCase.name, func(t *testing.T) { + if testCase.env.key == "" { + viper.Set(testCase.flag.ConfigName, testCase.flagValue) + } else { + err := BindViper(testCase.flag) + if err != nil { + assert.NoError(t, err) + } + + t.Setenv(testCase.env.key, testCase.env.value) + } + + sl := getStringSlice(testCase.flag) + assert.Equal(t, testCase.want, sl) + + viper.Reset() + }) +} + +func RunFlagTests(tests []TestCase, t *testing.T) { + for _, tt := range tests { + RunFlagTest(tt, t) + } +}