From 31a76bb49538e49f7180580e71b45a889a289674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Fabianski?= Date: Mon, 18 Dec 2023 10:06:02 +0100 Subject: [PATCH] chore: move flag types into own folder --- internal/classification/db/db.go | 7 +- internal/classification/schema/schema.go | 4 +- internal/commands/artifact/run.go | 7 +- .../process/filelist/filelist_test.go | 8 +- .../commands/process/gitrepository/context.go | 20 +-- internal/commands/process/settings/rules.go | 10 +- .../commands/process/settings/settings.go | 7 +- internal/flag/flags.go | 1 + internal/flag/general_flags.go | 25 ++-- internal/flag/ignore_add_flags.go | 14 +- internal/flag/ignore_migrate_flags.go | 8 +- internal/flag/ignore_show_flags.go | 8 +- internal/flag/options.go | 86 +++-------- internal/flag/options_test.go | 81 ++++++++++ internal/flag/report_flags.go | 17 ++- internal/flag/repository_flags.go | 24 +-- internal/flag/rule_flags.go | 12 +- internal/flag/scan_flags.go | 73 +++++---- internal/flag/types/types.go | 138 ++++++++++++++++++ internal/flag/worker_flags.go | 12 +- .../report/output/privacy/privacy_test.go | 16 +- .../report/output/security/security_test.go | 24 +-- .../report/output/stats/gocloc_detector.go | 6 +- .../detectors/testhelper/testhelper.go | 5 +- internal/version_check/version_check.go | 3 +- 25 files changed, 403 insertions(+), 213 deletions(-) create mode 100644 internal/flag/flags.go create mode 100644 internal/flag/options_test.go create mode 100644 internal/flag/types/types.go diff --git a/internal/classification/db/db.go b/internal/classification/db/db.go index a7268dd9d..9e1815304 100644 --- a/internal/classification/db/db.go +++ b/internal/classification/db/db.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/tangzero/inflector" ) @@ -141,11 +142,11 @@ func DefaultWithMapping(subjectMappingPath string) DefaultDB { return defaultDB("", subjectMappingPath) } -func DefaultWithContext(context flag.Context) DefaultDB { +func DefaultWithContext(context flagtypes.Context) DefaultDB { return defaultDB(context, "") } -func defaultDB(context flag.Context, subjectMappingPath string) DefaultDB { +func defaultDB(context flagtypes.Context, subjectMappingPath string) DefaultDB { dataCategories := defaultDataCategories(context) categories := map[string]DataCategory{} for _, category := range dataCategories { @@ -189,7 +190,7 @@ func defaultRecipes() []Recipe { return recipes } -func defaultDataCategories(context flag.Context) []DataCategory { +func defaultDataCategories(context flagtypes.Context) []DataCategory { skipHealthContext := true if context == flag.Health { skipHealthContext = false diff --git a/internal/classification/schema/schema.go b/internal/classification/schema/schema.go index 57d5a453f..e3530cf8b 100644 --- a/internal/classification/schema/schema.go +++ b/internal/classification/schema/schema.go @@ -3,7 +3,7 @@ package schema import ( "regexp" - "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/classification/db" "github.com/bearer/bearer/internal/report/detectors" @@ -40,7 +40,7 @@ type Config struct { DataTypes []db.DataType DataTypeClassificationPatterns []db.DataTypeClassificationPattern KnownPersonObjectPatterns []db.KnownPersonObjectPattern - Context flag.Context + Context flagtypes.Context } func New(config Config) *Classifier { diff --git a/internal/commands/artifact/run.go b/internal/commands/artifact/run.go index acd3cbf15..885add9f8 100644 --- a/internal/commands/artifact/run.go +++ b/internal/commands/artifact/run.go @@ -23,6 +23,7 @@ import ( "github.com/bearer/bearer/internal/commands/process/orchestrator/work" "github.com/bearer/bearer/internal/commands/process/settings" "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/report/basebranchfindings" reportoutput "github.com/bearer/bearer/internal/report/output" "github.com/bearer/bearer/internal/report/output/stats" @@ -56,7 +57,7 @@ type Runner interface { // ReportPath returns the filename of the report ReportPath() string // Scan gathers the findings - Scan(ctx context.Context, opts flag.Options) ([]files.File, *basebranchfindings.Findings, error) + Scan(ctx context.Context, opts flagtypes.Options) ([]files.File, *basebranchfindings.Findings, error) // Report a writes a report Report(files []files.File, baseBranchFindings *basebranchfindings.Findings) (bool, error) } @@ -139,7 +140,7 @@ func (r *runner) CacheUsed() bool { return r.reuseDetection } -func (r *runner) Scan(ctx context.Context, opts flag.Options) ([]files.File, *basebranchfindings.Findings, error) { +func (r *runner) Scan(ctx context.Context, opts flagtypes.Options) ([]files.File, *basebranchfindings.Findings, error) { if r.reuseDetection { return nil, nil, nil } @@ -260,7 +261,7 @@ func getIgnoredFingerprints(client *api.API, settings settings.Config, gitContex } // Run performs artifact scanning -func Run(ctx context.Context, opts flag.Options) (err error) { +func Run(ctx context.Context, opts flagtypes.Options) (err error) { targetPath, err := file.CanonicalPath(opts.Target) if err != nil { return fmt.Errorf("failed to get absolute target: %w", err) diff --git a/internal/commands/process/filelist/filelist_test.go b/internal/commands/process/filelist/filelist_test.go index 7f0d9c303..73fb4bf57 100644 --- a/internal/commands/process/filelist/filelist_test.go +++ b/internal/commands/process/filelist/filelist_test.go @@ -7,7 +7,7 @@ import ( "github.com/bearer/bearer/internal/commands/process/filelist" "github.com/bearer/bearer/internal/commands/process/filelist/files" "github.com/bearer/bearer/internal/commands/process/settings" - "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/hhatto/gocloc" "github.com/stretchr/testify/assert" ) @@ -71,7 +71,7 @@ func TestFileList(t *testing.T) { Input: input{ projectPath: filepath.Join("testdata", "happy_path", "skip"), config: settings.Config{ - Scan: flag.ScanOptions{ + Scan: flagtypes.ScanOptions{ SkipPath: []string{"users/admin.go"}, }, Worker: settings.WorkerOptions{ @@ -94,7 +94,7 @@ func TestFileList(t *testing.T) { Input: input{ projectPath: filepath.Join("testdata", "happy_path", "skip"), config: settings.Config{ - Scan: flag.ScanOptions{ + Scan: flagtypes.ScanOptions{ SkipPath: []string{"users"}, }, Worker: settings.WorkerOptions{ @@ -110,7 +110,7 @@ func TestFileList(t *testing.T) { Input: input{ projectPath: filepath.Join("testdata", "happy_path", "skip"), config: settings.Config{ - Scan: flag.ScanOptions{ + Scan: flagtypes.ScanOptions{ SkipPath: []string{"users"}, }, Worker: settings.WorkerOptions{ diff --git a/internal/commands/process/gitrepository/context.go b/internal/commands/process/gitrepository/context.go index 30731ca1c..7ebffb7ac 100644 --- a/internal/commands/process/gitrepository/context.go +++ b/internal/commands/process/gitrepository/context.go @@ -13,7 +13,7 @@ import ( "golang.org/x/oauth2" "gopkg.in/yaml.v3" - "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/git" ) @@ -35,7 +35,7 @@ type Context struct { HasUncommittedChanges bool } -func NewContext(options *flag.Options) (*Context, error) { +func NewContext(options *flagtypes.Options) (*Context, error) { if options.IgnoreGit { return nil, nil } @@ -118,7 +118,7 @@ func NewContext(options *flag.Options) (*Context, error) { return context, nil } -func getBranch(options *flag.Options, currentBranch string) string { +func getBranch(options *flagtypes.Options, currentBranch string) string { if options.Branch != "" { return options.Branch } @@ -126,7 +126,7 @@ func getBranch(options *flag.Options, currentBranch string) string { return currentBranch } -func getDefaultBranch(options *flag.Options, rootDir string) (string, error) { +func getDefaultBranch(options *flagtypes.Options, rootDir string) (string, error) { if options.DefaultBranch != "" { return options.DefaultBranch, nil } @@ -134,7 +134,7 @@ func getDefaultBranch(options *flag.Options, rootDir string) (string, error) { return git.GetDefaultBranch(rootDir) } -func getBaseBranch(options *flag.Options, defaultBranch string) (string, error) { +func getBaseBranch(options *flagtypes.Options, defaultBranch string) (string, error) { if !options.Diff { return "", nil } @@ -154,7 +154,7 @@ func getBaseBranch(options *flag.Options, defaultBranch string) (string, error) ) } -func getCommitHash(options *flag.Options, currentCommitHash string) string { +func getCommitHash(options *flagtypes.Options, currentCommitHash string) string { if options.Commit != "" { return options.Commit } @@ -163,7 +163,7 @@ func getCommitHash(options *flag.Options, currentCommitHash string) string { } func getBaseCommitHash( - options *flag.Options, + options *flagtypes.Options, rootDir string, baseBranch string, currentCommitHash string, @@ -210,7 +210,7 @@ func getBaseCommitHash( ) } -func lookupBaseCommitHashFromGithub(options *flag.Options, baseBranch string, currentCommitHash string) (string, error) { +func lookupBaseCommitHashFromGithub(options *flagtypes.Options, baseBranch string, currentCommitHash string) (string, error) { if options.GithubToken == "" || options.GithubRepository == "" { return "", nil } @@ -245,7 +245,7 @@ func lookupBaseCommitHashFromGithub(options *flag.Options, baseBranch string, cu return *comparison.MergeBaseCommit.SHA, nil } -func getOriginURL(options *flag.Options, rootDir string) (string, error) { +func getOriginURL(options *flagtypes.Options, rootDir string) (string, error) { if options.OriginURL != "" { return options.OriginURL, nil } @@ -253,7 +253,7 @@ func getOriginURL(options *flag.Options, rootDir string) (string, error) { return git.GetOriginURL(rootDir) } -func newGithubClient(options *flag.Options) (*github.Client, error) { +func newGithubClient(options *flagtypes.Options) (*github.Client, error) { tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: options.GithubToken}) httpClient := oauth2.NewClient(context.Background(), tokenSource) client := github.NewClient(httpClient) diff --git a/internal/commands/process/settings/rules.go b/internal/commands/process/settings/rules.go index 6c6fa3d24..bab3ed4a3 100644 --- a/internal/commands/process/settings/rules.go +++ b/internal/commands/process/settings/rules.go @@ -10,7 +10,7 @@ import ( "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" - "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/report/customdetectors" "github.com/bearer/bearer/internal/util/output" "github.com/bearer/bearer/internal/util/set" @@ -45,7 +45,7 @@ func GetSupportedRuleLanguages() map[string]bool { func loadRules( externalRuleDirs []string, - options flag.RuleOptions, + options flagtypes.RuleOptions, versionMeta *version_check.VersionMeta, force bool, ) ( @@ -93,7 +93,7 @@ func loadRules( func loadRuleDefinitionsFromRemote( definitions map[string]RuleDefinition, - options flag.RuleOptions, + options flagtypes.RuleOptions, versionMeta *version_check.VersionMeta, ) { if options.DisableDefaultRules { @@ -310,7 +310,7 @@ func getSanitizers(definition *RuleDefinition) set.Set[string] { } func validateRuleOptionIDs( - options flag.RuleOptions, + options flagtypes.RuleOptions, definitions map[string]RuleDefinition, builtInDefinitions map[string]RuleDefinition, ) error { @@ -341,7 +341,7 @@ func validateRuleOptionIDs( return nil } -func getEnabledRules(options flag.RuleOptions, definitions map[string]RuleDefinition, rules map[string]struct{}) map[string]struct{} { +func getEnabledRules(options flagtypes.RuleOptions, definitions map[string]RuleDefinition, rules map[string]struct{}) map[string]struct{} { enabledRules := make(map[string]struct{}) for ruleId := range rules { diff --git a/internal/commands/process/settings/settings.go b/internal/commands/process/settings/settings.go index 33d06aa0f..aa91a2a1f 100644 --- a/internal/commands/process/settings/settings.go +++ b/internal/commands/process/settings/settings.go @@ -11,6 +11,7 @@ import ( "github.com/bearer/bearer/api" "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/util/ignore" ignoretypes "github.com/bearer/bearer/internal/util/ignore/types" "github.com/bearer/bearer/internal/util/output" @@ -49,8 +50,8 @@ type WorkerOptions struct { type Config struct { Client *api.API Worker WorkerOptions `mapstructure:"worker" json:"worker" yaml:"worker"` - Scan flag.ScanOptions `mapstructure:"scan" json:"scan" yaml:"scan"` - Report flag.ReportOptions `mapstructure:"report" json:"report" yaml:"report"` + Scan flagtypes.ScanOptions `mapstructure:"scan" json:"scan" yaml:"scan"` + Report flagtypes.ReportOptions `mapstructure:"report" json:"report" yaml:"report"` IgnoredFingerprints map[string]ignoretypes.IgnoredFingerprint `mapstructure:"ignored_fingerprints" json:"ignored_fingerprints" yaml:"ignored_fingerprints"` StaleIgnoredFingerprintIds []string `mapstructure:"stale_ignored_fingerprint_ids" json:"stale_ignored_fingerprint_ids" yaml:"stale_ignored_fingerprint_ids"` CloudIgnoresUsed bool `mapstructure:"cloud_ignores_used" json:"cloud_ignores_used" yaml:"cloud_ignores_used"` @@ -318,7 +319,7 @@ func defaultWorkerOptions() WorkerOptions { } } -func FromOptions(opts flag.Options, versionMeta *version_check.VersionMeta) (Config, error) { +func FromOptions(opts flagtypes.Options, versionMeta *version_check.VersionMeta) (Config, error) { policies := DefaultPolicies() workerOptions := defaultWorkerOptions() result, err := loadRules( diff --git a/internal/flag/flags.go b/internal/flag/flags.go new file mode 100644 index 000000000..02d20d554 --- /dev/null +++ b/internal/flag/flags.go @@ -0,0 +1 @@ +package flag diff --git a/internal/flag/general_flags.go b/internal/flag/general_flags.go index 5d3bd7fb8..254d3d1eb 100644 --- a/internal/flag/general_flags.go +++ b/internal/flag/general_flags.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/bearer/bearer/api" + flagtypes "github.com/bearer/bearer/internal/flag/types" pointer "github.com/bearer/bearer/internal/util/pointers" "github.com/rs/zerolog/log" ) @@ -20,7 +21,7 @@ type generalFlagGroup struct{ flagGroupBase } var GeneralFlagGroup = &generalFlagGroup{flagGroupBase{name: "General"}} var ( - HostFlag = GeneralFlagGroup.add(Flag{ + HostFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "host", ConfigName: "host", Value: "my.bearer.sh", @@ -29,7 +30,7 @@ var ( Hide: true, }) - APIKeyFlag = GeneralFlagGroup.add(Flag{ + APIKeyFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "api-key", ConfigName: "api-key", Value: "", @@ -38,7 +39,7 @@ var ( Hide: true, }) - ConfigFileFlag = GeneralFlagGroup.add(Flag{ + ConfigFileFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "config-file", ConfigName: "config-file", Value: "bearer.yml", @@ -46,21 +47,21 @@ var ( DisableInConfig: true, }) - DisableVersionCheckFlag = GeneralFlagGroup.add(Flag{ + DisableVersionCheckFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "disable-version-check", ConfigName: "disable-version-check", Value: false, Usage: "Disable Bearer version checking", }) - NoColorFlag = GeneralFlagGroup.add(Flag{ + NoColorFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "no-color", ConfigName: "report.no-color", Value: false, Usage: "Disable color in output", }) - IgnoreFileFlag = GeneralFlagGroup.add(Flag{ + IgnoreFileFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "ignore-file", ConfigName: "ignore-file", Value: "bearer.ignore", @@ -68,7 +69,7 @@ var ( DisableInConfig: true, }) - DebugFlag = GeneralFlagGroup.add(Flag{ + DebugFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "debug", ConfigName: "debug", Value: false, @@ -76,14 +77,14 @@ var ( DisableInConfig: true, }) - LogLevelFlag = GeneralFlagGroup.add(Flag{ + LogLevelFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "log-level", ConfigName: "log-level", Value: "info", Usage: "Set log level (error, info, debug, trace)", }) - DebugProfileFlag = GeneralFlagGroup.add(Flag{ + DebugProfileFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "debug-profile", ConfigName: "debug-profile", Value: false, @@ -92,7 +93,7 @@ var ( DisableInConfig: true, }) - IgnoreGitFlag = GeneralFlagGroup.add(Flag{ + IgnoreGitFlag = GeneralFlagGroup.add(flagtypes.Flag{ Name: "ignore-git", ConfigName: "ignore-git", Value: false, @@ -115,7 +116,7 @@ type GeneralOptions struct { IgnoreGit bool `mapstructure:"ignore-git" json:"ignore-git" yaml:"ignore-git"` } -func (generalFlagGroup) SetOptions(options *Options, args []string) error { +func (generalFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { var client *api.API apiKey := getString(APIKeyFlag) if apiKey != "" { @@ -139,7 +140,7 @@ func (generalFlagGroup) SetOptions(options *Options, args []string) error { logLevel = DebugLogLevel } - options.GeneralOptions = GeneralOptions{ + options.GeneralOptions = flagtypes.GeneralOptions{ Client: client, ConfigFile: getString(ConfigFileFlag), DisableVersionCheck: getBool(DisableVersionCheckFlag), diff --git a/internal/flag/ignore_add_flags.go b/internal/flag/ignore_add_flags.go index 4d97ce3a9..b9027deca 100644 --- a/internal/flag/ignore_add_flags.go +++ b/internal/flag/ignore_add_flags.go @@ -1,11 +1,13 @@ package flag +import flagtypes "github.com/bearer/bearer/internal/flag/types" + type ignoreAddFlagGroup struct{ flagGroupBase } var IgnoreAddFlagGroup = &ignoreAddFlagGroup{flagGroupBase{name: "Ignore Add"}} var ( - AuthorFlag = IgnoreAddFlagGroup.add(Flag{ + AuthorFlag = IgnoreAddFlagGroup.add(flagtypes.Flag{ Name: "author", ConfigName: "ignore_add.author", Shorthand: "a", @@ -13,21 +15,21 @@ var ( Usage: "Add author information to this ignored finding. (default output of \"git config user.name\")", }) - CommentFlag = IgnoreAddFlagGroup.add(Flag{ + CommentFlag = IgnoreAddFlagGroup.add(flagtypes.Flag{ Name: "comment", ConfigName: "ignore_add.comment", Value: FormatEmpty, Usage: "Add a comment to this ignored finding.", }) - FalsePositiveFlag = IgnoreAddFlagGroup.add(Flag{ + FalsePositiveFlag = IgnoreAddFlagGroup.add(flagtypes.Flag{ Name: "false-positive", ConfigName: "ignore_add.false-positive", Value: false, Usage: "Mark an this ignored finding as false positive.", }) - IgnoreAddForceFlag = IgnoreAddFlagGroup.add(Flag{ + IgnoreAddForceFlag = IgnoreAddFlagGroup.add(flagtypes.Flag{ Name: "force", ConfigName: "ignore_add.force", Value: false, @@ -42,8 +44,8 @@ type IgnoreAddOptions struct { Force bool `mapstructure:"ignore_add_force" json:"ignore_add_force" yaml:"ignore_add_force"` } -func (ignoreAddFlagGroup) SetOptions(options *Options, args []string) error { - options.IgnoreAddOptions = IgnoreAddOptions{ +func (ignoreAddFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { + options.IgnoreAddOptions = flagtypes.IgnoreAddOptions{ Author: getString(AuthorFlag), Comment: getString(CommentFlag), FalsePositive: getBool(FalsePositiveFlag), diff --git a/internal/flag/ignore_migrate_flags.go b/internal/flag/ignore_migrate_flags.go index 7ef851784..b326d85a1 100644 --- a/internal/flag/ignore_migrate_flags.go +++ b/internal/flag/ignore_migrate_flags.go @@ -1,11 +1,13 @@ package flag +import flagtypes "github.com/bearer/bearer/internal/flag/types" + type ignoreMigrateFlagGroup struct{ flagGroupBase } var IgnoreMigrateFlagGroup = &ignoreMigrateFlagGroup{flagGroupBase{name: "Ignore Migrate"}} var ( - IgnoreMigrateForceFlag = IgnoreMigrateFlagGroup.add(Flag{ + IgnoreMigrateForceFlag = IgnoreMigrateFlagGroup.add(flagtypes.Flag{ Name: "force", ConfigName: "ignore_migrate.force", Value: false, @@ -17,8 +19,8 @@ type IgnoreMigrateOptions struct { Force bool `mapstructure:"ignore_migrate_force" json:"ignore_migrate_force" yaml:"ignore_migrate_force"` } -func (ignoreMigrateFlagGroup) SetOptions(options *Options, args []string) error { - options.IgnoreMigrateOptions = IgnoreMigrateOptions{ +func (ignoreMigrateFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { + options.IgnoreMigrateOptions = flagtypes.IgnoreMigrateOptions{ Force: getBool(IgnoreMigrateForceFlag), } diff --git a/internal/flag/ignore_show_flags.go b/internal/flag/ignore_show_flags.go index 966555be6..89876fd16 100644 --- a/internal/flag/ignore_show_flags.go +++ b/internal/flag/ignore_show_flags.go @@ -1,11 +1,13 @@ package flag +import flagtypes "github.com/bearer/bearer/internal/flag/types" + type ignoreShowFlagGroup struct{ flagGroupBase } var IgnoreShowFlagGroup = &ignoreShowFlagGroup{flagGroupBase{name: "Ignore Show"}} var ( - AllFlag = IgnoreShowFlagGroup.add(Flag{ + AllFlag = IgnoreShowFlagGroup.add(flagtypes.Flag{ Name: "all", ConfigName: "ignore_show.all", Value: false, @@ -17,8 +19,8 @@ type IgnoreShowOptions struct { All bool `mapstructure:"all" json:"all" yaml:"all"` } -func (ignoreShowFlagGroup) SetOptions(options *Options, args []string) error { - options.IgnoreShowOptions = IgnoreShowOptions{ +func (ignoreShowFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { + options.IgnoreShowOptions = flagtypes.IgnoreShowOptions{ All: getBool(AllFlag), } diff --git a/internal/flag/options.go b/internal/flag/options.go index d18280da7..397a998d5 100644 --- a/internal/flag/options.go +++ b/internal/flag/options.go @@ -11,69 +11,21 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/types" "github.com/bearer/bearer/internal/util/set" ) -var ErrInvalidScannerReportCombination = errors.New("invalid scanner argument; privacy report requires sast scanner") - -type Flag struct { - // Name is for CLI flag and environment variable. - // If this field is empty, it will be available only in config file. - Name string - - // ConfigName is a key in config file. It is also used as a key of viper. - ConfigName string - - // Shorthand is a shorthand letter. - Shorthand string - - // Value is the default value. It must be filled to determine the flag type. - Value interface{} - - // Usage explains how to use the flag. - Usage string - - // DisableInConfig represents if flag should be present in config - DisableInConfig bool - - // Do not show flag in the helper - Hide bool +type Flags []flagtypes.FlagGroup - // Deprecated represents if the flag is deprecated - Deprecated bool - - // Additional environment variables to read the value from, in addition to the default - EnvironmentVariables []string -} +var ErrInvalidScannerReportCombination = errors.New("invalid scanner argument; privacy report requires sast scanner") type flagGroupBase struct { name string - flags []*Flag -} - -type FlagGroup interface { - Name() string - Flags() []*Flag - SetOptions(options *Options, args []string) error -} - -type Flags []FlagGroup - -// Options holds all the runtime configuration -type Options struct { - ReportOptions - RuleOptions - ScanOptions - RepositoryOptions - GeneralOptions - IgnoreAddOptions - IgnoreShowOptions - IgnoreMigrateOptions - WorkerOptions + flags []*flagtypes.Flag } -func addFlag(cmd *cobra.Command, flag *Flag) { +func addFlag(cmd *cobra.Command, flag *flagtypes.Flag) { if flag == nil || flag.Name == "" { return } @@ -94,7 +46,7 @@ func addFlag(cmd *cobra.Command, flag *Flag) { } } -func bind(cmd *cobra.Command, flag *Flag) error { +func bind(cmd *cobra.Command, flag *flagtypes.Flag) error { if flag == nil { return nil } else if flag.Name == "" { @@ -122,7 +74,7 @@ func bind(cmd *cobra.Command, flag *Flag) error { return nil } -func argsToMap(flag *Flag) map[string]bool { +func argsToMap(flag *flagtypes.Flag) map[string]bool { strSlice := getStringSlice(flag) result := make(map[string]bool) @@ -133,7 +85,7 @@ func argsToMap(flag *Flag) map[string]bool { return result } -func getString(flag *Flag) string { +func getString(flag *flagtypes.Flag) string { if flag == nil { return "" } @@ -141,7 +93,7 @@ func getString(flag *Flag) string { return viper.GetString(flag.ConfigName) } -func getStringSlice(flag *Flag) []string { +func getStringSlice(flag *flagtypes.Flag) []string { if flag == nil { return nil } @@ -159,21 +111,21 @@ func getStringSlice(flag *Flag) []string { return v } -func getBool(flag *Flag) bool { +func getBool(flag *flagtypes.Flag) bool { if flag == nil { return false } return viper.GetBool(flag.ConfigName) } -func getDuration(flag *Flag) time.Duration { +func getDuration(flag *flagtypes.Flag) time.Duration { if flag == nil { return 0 } return viper.GetDuration(flag.ConfigName) } -func getInteger(flag *Flag) int { +func getInteger(flag *flagtypes.Flag) int { if flag == nil { return -1 } @@ -181,7 +133,7 @@ func getInteger(flag *Flag) int { return viper.GetInt(flag.ConfigName) } -func getSeverities(flag *Flag) set.Set[string] { +func getSeverities(flag *flagtypes.Flag) set.Set[string] { result := set.New[string]() for _, value := range getStringSlice(flag) { @@ -195,7 +147,7 @@ func getSeverities(flag *Flag) set.Set[string] { return result } -func (f *flagGroupBase) add(flag Flag) *Flag { +func (f *flagGroupBase) add(flag flagtypes.Flag) *flagtypes.Flag { f.flags = append(f.flags, &flag) return &flag } @@ -204,7 +156,7 @@ func (f *flagGroupBase) Name() string { return f.name } -func (f *flagGroupBase) Flags() []*Flag { +func (f *flagGroupBase) Flags() []*flagtypes.Flag { return f.flags } @@ -264,18 +216,18 @@ func (f Flags) bind(cmd *cobra.Command, supportIgnoreConfig bool) error { return nil } -func (f Flags) ToOptions(args []string) (Options, error) { +func (f Flags) ToOptions(args []string) (flagtypes.Options, error) { // var err error - options := Options{} + options := flagtypes.Options{} for _, group := range f { if err := group.SetOptions(&options, args); err != nil { - return Options{}, fmt.Errorf("%s flags error: %w", group.Name(), err) + return flagtypes.Options{}, fmt.Errorf("%s flags error: %w", group.Name(), err) } } if options.ReportOptions.Report == "privacy" && !slices.Contains(options.ScanOptions.Scanner, "sast") { - return Options{}, ErrInvalidScannerReportCombination + return flagtypes.Options{}, ErrInvalidScannerReportCombination } return options, nil diff --git a/internal/flag/options_test.go b/internal/flag/options_test.go new file mode 100644 index 000000000..ce4d1f281 --- /dev/null +++ b/internal/flag/options_test.go @@ -0,0 +1,81 @@ +package flag + +import ( + "testing" + + flagtypes "github.com/bearer/bearer/internal/flag/types" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func Test_getStringSlice(t *testing.T) { + type env struct { + key string + value string + } + tests := []struct { + name string + flag *flagtypes.Flag + flagValue interface{} + env env + want []string + }{ + { + name: "happy path. Empty value", + flag: ScannerFlag, + flagValue: "", + want: nil, + }, + { + name: "happy path. String value", + flag: ScannerFlag, + flagValue: "sast,secrets", + want: []string{ + string(ScannerSAST), + string(ScannerSecrets), + }, + }, + { + name: "happy path. Slice value", + flag: ScannerFlag, + flagValue: []string{ + "sast", + "secrets", + }, + want: []string{ + string(ScannerSAST), + string(ScannerSecrets), + }, + }, + { + name: "happy path. Env value", + flag: ScannerFlag, + env: env{ + key: "BEARER_SCANNER", + value: "secrets,sast", + }, + want: []string{ + string(ScannerSAST), + string(ScannerSecrets), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.env.key == "" { + viper.Set(tt.flag.ConfigName, tt.flagValue) + } else { + // err := viper.BindEnv(tt.flag.ConfigName, tt.env.key) + // assert.NoError(t, err) + + t.Setenv(tt.env.key, tt.env.value) + } + + sl := getStringSlice(tt.flag) + assert.Equal(t, tt.want, sl) + + viper.Reset() + }) + } +} diff --git a/internal/flag/report_flags.go b/internal/flag/report_flags.go index aa0219546..bd3ffe543 100644 --- a/internal/flag/report_flags.go +++ b/internal/flag/report_flags.go @@ -4,6 +4,7 @@ import ( "errors" "strings" + flagtypes "github.com/bearer/bearer/internal/flag/types" globaltypes "github.com/bearer/bearer/internal/types" "github.com/bearer/bearer/internal/util/set" sliceutil "github.com/bearer/bearer/internal/util/slices" @@ -42,38 +43,38 @@ type reportFlagGroup struct{ flagGroupBase } var ReportFlagGroup = &reportFlagGroup{flagGroupBase{name: "Report"}} var ( - FormatFlag = ReportFlagGroup.add(Flag{ + FormatFlag = ReportFlagGroup.add(flagtypes.Flag{ Name: "format", ConfigName: "report.format", Shorthand: "f", Value: FormatEmpty, Usage: "Specify report format (json, yaml, sarif, gitlab-sast, rdjson, html)", }) - ReportFlag = ReportFlagGroup.add(Flag{ + ReportFlag = ReportFlagGroup.add(flagtypes.Flag{ Name: "report", ConfigName: "report.report", Value: ReportSecurity, Usage: "Specify the type of report (security, privacy, dataflow).", }) - OutputFlag = ReportFlagGroup.add(Flag{ + OutputFlag = ReportFlagGroup.add(flagtypes.Flag{ Name: "output", ConfigName: "report.output", Value: "", Usage: "Specify the output path for the report.", }) - SeverityFlag = ReportFlagGroup.add(Flag{ + SeverityFlag = ReportFlagGroup.add(flagtypes.Flag{ Name: "severity", ConfigName: "report.severity", Value: strings.Join(globaltypes.Severities, ","), Usage: "Specify which severities are included in the report.", }) - FailOnSeverityFlag = ReportFlagGroup.add(Flag{ + FailOnSeverityFlag = ReportFlagGroup.add(flagtypes.Flag{ Name: "fail-on-severity", ConfigName: "report.fail-on-severity", Value: strings.Join(sliceutil.Except(globaltypes.Severities, globaltypes.LevelWarning), ","), Usage: "Specify which severities cause the report to fail. Works in conjunction with --exit-code.", }) - ExcludeFingerprintFlag = ReportFlagGroup.add(Flag{ + ExcludeFingerprintFlag = ReportFlagGroup.add(flagtypes.Flag{ Name: "exclude-fingerprint", ConfigName: "report.exclude-fingerprint", Value: []string{}, @@ -93,7 +94,7 @@ type ReportOptions struct { ExcludeFingerprint map[string]bool `mapstructure:"exclude_fingerprints" json:"exclude_fingerprints" yaml:"exclude_fingerprints"` } -func (reportFlagGroup) SetOptions(options *Options, args []string) error { +func (reportFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { invalidFormat := ErrInvalidFormatDefault report := getString(ReportFlag) switch report { @@ -147,7 +148,7 @@ func (reportFlagGroup) SetOptions(options *Options, args []string) error { excludeFingerprintsMapping[fingerprint] = true } - options.ReportOptions = ReportOptions{ + options.ReportOptions = flagtypes.ReportOptions{ Format: format, Report: report, Output: getString(OutputFlag), diff --git a/internal/flag/repository_flags.go b/internal/flag/repository_flags.go index ae483efd5..d56396cf4 100644 --- a/internal/flag/repository_flags.go +++ b/internal/flag/repository_flags.go @@ -1,11 +1,13 @@ package flag +import flagtypes "github.com/bearer/bearer/internal/flag/types" + type repositoryFlagGroup struct{ flagGroupBase } var RepositoryFlagGroup = &repositoryFlagGroup{flagGroupBase{name: "Repository"}} var ( - RepositoryURLFlag = RepositoryFlagGroup.add(Flag{ + RepositoryURLFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "repository-url", ConfigName: "repository.url", Value: "", @@ -17,7 +19,7 @@ var ( DisableInConfig: true, Hide: true, }) - BranchFlag = RepositoryFlagGroup.add(Flag{ + BranchFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "branch", ConfigName: "repository.branch", Value: "", @@ -29,7 +31,7 @@ var ( DisableInConfig: true, Hide: true, }) - CommitFlag = RepositoryFlagGroup.add(Flag{ + CommitFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "commit", ConfigName: "repository.commit", Value: "", @@ -41,7 +43,7 @@ var ( DisableInConfig: true, Hide: true, }) - DefaultBranchFlag = RepositoryFlagGroup.add(Flag{ + DefaultBranchFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "default-branch", ConfigName: "repository.default-branch", Value: "", @@ -53,7 +55,7 @@ var ( DisableInConfig: true, Hide: true, }) - DiffBaseBranchFlag = RepositoryFlagGroup.add(Flag{ + DiffBaseBranchFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "diff-base-branch", ConfigName: "repository.diff-base-branch", Value: "", @@ -65,7 +67,7 @@ var ( DisableInConfig: true, Hide: true, }) - DiffBaseCommitFlag = RepositoryFlagGroup.add(Flag{ + DiffBaseCommitFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "diff-base-commit", ConfigName: "repository.diff-base-commit", Value: "", @@ -77,7 +79,7 @@ var ( DisableInConfig: true, Hide: true, }) - GithubTokenFlag = RepositoryFlagGroup.add(Flag{ + GithubTokenFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "github-token", ConfigName: "repository.github-token", Value: "", @@ -88,7 +90,7 @@ var ( DisableInConfig: true, Hide: true, }) - GithubRepositoryFlag = RepositoryFlagGroup.add(Flag{ + GithubRepositoryFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "github-repository", ConfigName: "repository.github-repository", Value: "", @@ -99,7 +101,7 @@ var ( DisableInConfig: true, Hide: true, }) - GithubAPIURLFlag = RepositoryFlagGroup.add(Flag{ + GithubAPIURLFlag = RepositoryFlagGroup.add(flagtypes.Flag{ Name: "github-api-url", ConfigName: "repository.github-api-url", Value: "", @@ -124,8 +126,8 @@ type RepositoryOptions struct { GithubAPIURL string } -func (repositoryFlagGroup) SetOptions(options *Options, args []string) error { - options.RepositoryOptions = RepositoryOptions{ +func (repositoryFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { + options.RepositoryOptions = flagtypes.RepositoryOptions{ OriginURL: getString(RepositoryURLFlag), Branch: getString(BranchFlag), Commit: getString(CommitFlag), diff --git a/internal/flag/rule_flags.go b/internal/flag/rule_flags.go index ee4330078..89dc0f052 100644 --- a/internal/flag/rule_flags.go +++ b/internal/flag/rule_flags.go @@ -1,23 +1,25 @@ package flag +import flagtypes "github.com/bearer/bearer/internal/flag/types" + type ruleFlagGroup struct{ flagGroupBase } var RuleFlagGroup = &ruleFlagGroup{flagGroupBase{name: "Rule"}} var ( - DisableDefaultRulesFlag = RuleFlagGroup.add(Flag{ + DisableDefaultRulesFlag = RuleFlagGroup.add(flagtypes.Flag{ Name: "disable-default-rules", ConfigName: "rule.disable-default-rules", Value: false, Usage: "Disables all default and built-in rules.", }) - SkipRuleFlag = RuleFlagGroup.add(Flag{ + SkipRuleFlag = RuleFlagGroup.add(flagtypes.Flag{ Name: "skip-rule", ConfigName: "rule.skip-rule", Value: []string{}, Usage: "Specify the comma-separated ids of the rules you would like to skip. Runs all other rules.", }) - OnlyRuleFlag = RuleFlagGroup.add(Flag{ + OnlyRuleFlag = RuleFlagGroup.add(flagtypes.Flag{ Name: "only-rule", ConfigName: "rule.only-rule", Value: []string{}, @@ -31,8 +33,8 @@ type RuleOptions struct { OnlyRule map[string]bool `mapstructure:"only-rule" json:"only-rule" yaml:"only-rule"` } -func (ruleFlagGroup) SetOptions(options *Options, args []string) error { - options.RuleOptions = RuleOptions{ +func (ruleFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { + options.RuleOptions = flagtypes.RuleOptions{ DisableDefaultRules: getBool(DisableDefaultRulesFlag), SkipRule: argsToMap(SkipRuleFlag), OnlyRule: argsToMap(OnlyRuleFlag), diff --git a/internal/flag/scan_flags.go b/internal/flag/scan_flags.go index 2c631e201..965b9020f 100644 --- a/internal/flag/scan_flags.go +++ b/internal/flag/scan_flags.go @@ -6,14 +6,13 @@ import ( "strings" "time" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/spf13/viper" ) -type Context string - const ( - Health Context = "health" - Empty Context = "" + Health flagtypes.Context = "health" + Empty flagtypes.Context = "" ScannerSAST = "sast" ScannerSecrets = "secrets" @@ -29,85 +28,85 @@ type scanFlagGroup struct{ flagGroupBase } var ScanFlagGroup = &scanFlagGroup{flagGroupBase{name: "Scan"}} var ( - SkipPathFlag = ScanFlagGroup.add(Flag{ + SkipPathFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "skip-path", ConfigName: "scan.skip-path", Value: []string{}, Usage: "Specify the comma separated files and directories to skip. Supports * syntax, e.g. --skip-path users/*.go,users/admin.sql", }) - DisableDomainResolutionFlag = ScanFlagGroup.add(Flag{ + DisableDomainResolutionFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "disable-domain-resolution", ConfigName: "scan.disable-domain-resolution", Value: true, Usage: "Do not attempt to resolve detected domains during classification", }) - DomainResolutionTimeoutFlag = ScanFlagGroup.add(Flag{ + DomainResolutionTimeoutFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "domain-resolution-timeout", ConfigName: "scan.domain-resolution-timeout", Value: 3 * time.Second, Usage: "Set timeout when attempting to resolve detected domains during classification, e.g. --domain-resolution-timeout=3s", }) - InternalDomainsFlag = ScanFlagGroup.add(Flag{ + InternalDomainsFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "internal-domains", ConfigName: "scan.internal-domains", Value: []string{}, Usage: "Define regular expressions for better classification of private or unreachable domains e.g. --internal-domains=\".*.my-company.com,private.sh\"", }) - ContextFlag = ScanFlagGroup.add(Flag{ + ContextFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "context", ConfigName: "scan.context", Value: "", Usage: "Expand context of schema classification e.g., --context=health, to include data types particular to health", }) - DataSubjectMappingFlag = ScanFlagGroup.add(Flag{ + DataSubjectMappingFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "data-subject-mapping", ConfigName: "scan.data_subject_mapping", Value: "", Usage: "Override default data subject mapping by providing a path to a custom mapping JSON file", }) - QuietFlag = ScanFlagGroup.add(Flag{ + QuietFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "quiet", ConfigName: "scan.quiet", Value: false, Usage: "Suppress non-essential messages", }) - HideProgressBarFlag = ScanFlagGroup.add(Flag{ + HideProgressBarFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "hide-progress-bar", ConfigName: "scan.hide_progress_bar", Value: false, Usage: "Hide progress bar from output", }) - ForceFlag = ScanFlagGroup.add(Flag{ + ForceFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "force", ConfigName: "scan.force", Value: false, Usage: "Disable the cache and runs the detections again", }) - ExternalRuleDirFlag = ScanFlagGroup.add(Flag{ + ExternalRuleDirFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "external-rule-dir", ConfigName: "scan.external-rule-dir", Value: []string{}, Usage: "Specify directories paths that contain .yaml files with external rules configuration", }) - ScannerFlag = ScanFlagGroup.add(Flag{ + ScannerFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "scanner", ConfigName: "scan.scanner", Value: []string{ScannerSAST}, Usage: "Specify which scanner to use e.g. --scanner=secrets, --scanner=secrets,sast", }) - ParallelFlag = ScanFlagGroup.add(Flag{ + ParallelFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "parallel", ConfigName: "scan.parallel", Value: 0, Usage: "Specify the amount of parallelism to use during the scan", }) - ExitCodeFlag = ScanFlagGroup.add(Flag{ + ExitCodeFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "exit-code", ConfigName: "scan.exit-code", Value: -1, Usage: "Force a given exit code for the scan command. Set this to 0 (success) to always return a success exit code despite any findings from the scan.", }) - DiffFlag = ScanFlagGroup.add(Flag{ + DiffFlag = ScanFlagGroup.add(flagtypes.Flag{ Name: "diff", ConfigName: "scan.diff", Value: false, @@ -117,24 +116,24 @@ var ( ) type ScanOptions struct { - Target string `mapstructure:"target" json:"target" yaml:"target"` - SkipPath []string `mapstructure:"skip-path" json:"skip-path" yaml:"skip-path"` - DisableDomainResolution bool `mapstructure:"disable-domain-resolution" json:"disable-domain-resolution" yaml:"disable-domain-resolution"` - DomainResolutionTimeout time.Duration `mapstructure:"domain-resolution-timeout" json:"domain-resolution-timeout" yaml:"domain-resolution-timeout"` - InternalDomains []string `mapstructure:"internal-domains" json:"internal-domains" yaml:"internal-domains"` - Context Context `mapstructure:"context" json:"context" yaml:"context"` - DataSubjectMapping string `mapstructure:"data_subject_mapping" json:"data_subject_mapping" yaml:"data_subject_mapping"` - Quiet bool `mapstructure:"quiet" json:"quiet" yaml:"quiet"` - HideProgressBar bool `mapstructure:"hide_progress_bar" json:"hide_progress_bar" yaml:"hide_progress_bar"` - Force bool `mapstructure:"force" json:"force" yaml:"force"` - ExternalRuleDir []string `mapstructure:"external-rule-dir" json:"external-rule-dir" yaml:"external-rule-dir"` - Scanner []string `mapstructure:"scanner" json:"scanner" yaml:"scanner"` - Parallel int `mapstructure:"parallel" json:"parallel" yaml:"parallel"` - ExitCode int `mapstructure:"exit-code" json:"exit-code" yaml:"exit-code"` - Diff bool `mapstructure:"diff" json:"diff" yaml:"diff"` + Target string `mapstructure:"target" json:"target" yaml:"target"` + SkipPath []string `mapstructure:"skip-path" json:"skip-path" yaml:"skip-path"` + DisableDomainResolution bool `mapstructure:"disable-domain-resolution" json:"disable-domain-resolution" yaml:"disable-domain-resolution"` + DomainResolutionTimeout time.Duration `mapstructure:"domain-resolution-timeout" json:"domain-resolution-timeout" yaml:"domain-resolution-timeout"` + InternalDomains []string `mapstructure:"internal-domains" json:"internal-domains" yaml:"internal-domains"` + Context flagtypes.Context `mapstructure:"context" json:"context" yaml:"context"` + DataSubjectMapping string `mapstructure:"data_subject_mapping" json:"data_subject_mapping" yaml:"data_subject_mapping"` + Quiet bool `mapstructure:"quiet" json:"quiet" yaml:"quiet"` + HideProgressBar bool `mapstructure:"hide_progress_bar" json:"hide_progress_bar" yaml:"hide_progress_bar"` + Force bool `mapstructure:"force" json:"force" yaml:"force"` + ExternalRuleDir []string `mapstructure:"external-rule-dir" json:"external-rule-dir" yaml:"external-rule-dir"` + Scanner []string `mapstructure:"scanner" json:"scanner" yaml:"scanner"` + Parallel int `mapstructure:"parallel" json:"parallel" yaml:"parallel"` + ExitCode int `mapstructure:"exit-code" json:"exit-code" yaml:"exit-code"` + Diff bool `mapstructure:"diff" json:"diff" yaml:"diff"` } -func (scanFlagGroup) SetOptions(options *Options, args []string) error { +func (scanFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { var target string if len(args) == 1 { target = args[0] @@ -160,7 +159,7 @@ func (scanFlagGroup) SetOptions(options *Options, args []string) error { // DIFF_BASE_BRANCH is used for backwards compatibilty diff := getBool(DiffFlag) || os.Getenv("DIFF_BASE_BRANCH") != "" - options.ScanOptions = ScanOptions{ + options.ScanOptions = flagtypes.ScanOptions{ SkipPath: getStringSlice(SkipPathFlag), DisableDomainResolution: getBool(DisableDomainResolutionFlag), DomainResolutionTimeout: getDuration(DomainResolutionTimeoutFlag), @@ -181,11 +180,11 @@ func (scanFlagGroup) SetOptions(options *Options, args []string) error { return nil } -func getContext(flag *Flag) Context { +func getContext(flag *flagtypes.Flag) flagtypes.Context { if flag == nil { return "" } flagStr := strings.ToLower(getString(flag)) - return Context(flagStr) + return flagtypes.Context(flagStr) } diff --git a/internal/flag/types/types.go b/internal/flag/types/types.go new file mode 100644 index 000000000..09067e3e9 --- /dev/null +++ b/internal/flag/types/types.go @@ -0,0 +1,138 @@ +package types + +import ( + "time" + + "github.com/bearer/bearer/api" + "github.com/bearer/bearer/internal/util/set" +) + +type Flag struct { + // Name is for CLI flag and environment variable. + // If this field is empty, it will be available only in config file. + Name string + + // ConfigName is a key in config file. It is also used as a key of viper. + ConfigName string + + // Shorthand is a shorthand letter. + Shorthand string + + // Value is the default value. It must be filled to determine the flag type. + Value interface{} + + // Usage explains how to use the flag. + Usage string + + // DisableInConfig represents if flag should be present in config + DisableInConfig bool + + // Do not show flag in the helper + Hide bool + + // Deprecated represents if the flag is deprecated + Deprecated bool + + // Additional environment variables to read the value from, in addition to the default + EnvironmentVariables []string +} + +type FlagGroup interface { + Name() string + Flags() []*Flag + SetOptions(options *Options, args []string) error +} + +type Context string + +// Options holds all the runtime configuration +type Options struct { + ReportOptions + RuleOptions + ScanOptions + RepositoryOptions + GeneralOptions + IgnoreAddOptions + IgnoreShowOptions + IgnoreMigrateOptions + WorkerOptions +} + +type ScanOptions struct { + Target string `mapstructure:"target" json:"target" yaml:"target"` + SkipPath []string `mapstructure:"skip-path" json:"skip-path" yaml:"skip-path"` + DisableDomainResolution bool `mapstructure:"disable-domain-resolution" json:"disable-domain-resolution" yaml:"disable-domain-resolution"` + DomainResolutionTimeout time.Duration `mapstructure:"domain-resolution-timeout" json:"domain-resolution-timeout" yaml:"domain-resolution-timeout"` + InternalDomains []string `mapstructure:"internal-domains" json:"internal-domains" yaml:"internal-domains"` + Context Context `mapstructure:"context" json:"context" yaml:"context"` + DataSubjectMapping string `mapstructure:"data_subject_mapping" json:"data_subject_mapping" yaml:"data_subject_mapping"` + Quiet bool `mapstructure:"quiet" json:"quiet" yaml:"quiet"` + HideProgressBar bool `mapstructure:"hide_progress_bar" json:"hide_progress_bar" yaml:"hide_progress_bar"` + Force bool `mapstructure:"force" json:"force" yaml:"force"` + ExternalRuleDir []string `mapstructure:"external-rule-dir" json:"external-rule-dir" yaml:"external-rule-dir"` + Scanner []string `mapstructure:"scanner" json:"scanner" yaml:"scanner"` + Parallel int `mapstructure:"parallel" json:"parallel" yaml:"parallel"` + ExitCode int `mapstructure:"exit-code" json:"exit-code" yaml:"exit-code"` + Diff bool `mapstructure:"diff" json:"diff" yaml:"diff"` +} + +type RuleOptions struct { + DisableDefaultRules bool `mapstructure:"disable-default-rules" json:"disable-default-rules" yaml:"disable-default-rules"` + SkipRule map[string]bool `mapstructure:"skip-rule" json:"skip-rule" yaml:"skip-rule"` + OnlyRule map[string]bool `mapstructure:"only-rule" json:"only-rule" yaml:"only-rule"` +} + +type ReportOptions struct { + Format string `mapstructure:"format" json:"format" yaml:"format"` + Report string `mapstructure:"report" json:"report" yaml:"report"` + Output string `mapstructure:"output" json:"output" yaml:"output"` + Severity set.Set[string] `mapstructure:"severity" json:"severity" yaml:"severity"` + FailOnSeverity set.Set[string] `mapstructure:"fail-on-severity" json:"fail-on-severity" yaml:"fail-on-severity"` + ExcludeFingerprint map[string]bool `mapstructure:"exclude_fingerprints" json:"exclude_fingerprints" yaml:"exclude_fingerprints"` +} + +type RepositoryOptions struct { + OriginURL string + Branch string + Commit string + DefaultBranch string + DiffBaseBranch string + DiffBaseCommit string + GithubToken string + GithubRepository string + GithubAPIURL string +} + +// GlobalOptions defines flags and other configuration parameters for all the subcommands +type GeneralOptions struct { + ConfigFile string `json:"config_file" yaml:"config_file"` + Client *api.API + DisableVersionCheck bool + NoColor bool `mapstructure:"no_color" json:"no_color" yaml:"no_color"` + IgnoreFile string `mapstructure:"ignore_file" json:"ignore_file" yaml:"ignore_file"` + Debug bool `mapstructure:"debug" json:"debug" yaml:"debug"` + LogLevel string `mapstructure:"log-level" json:"log-level" yaml:"log-level"` + DebugProfile bool + IgnoreGit bool `mapstructure:"ignore-git" json:"ignore-git" yaml:"ignore-git"` +} + +type IgnoreAddOptions struct { + Author string `mapstructure:"author" json:"author" yaml:"author"` + Comment string `mapstructure:"comment" json:"comment" yaml:"comment"` + FalsePositive bool `mapstructure:"false_positive" json:"false_positive" yaml:"false_positive"` + Force bool `mapstructure:"ignore_add_force" json:"ignore_add_force" yaml:"ignore_add_force"` +} + +type IgnoreShowOptions struct { + All bool `mapstructure:"all" json:"all" yaml:"all"` +} + +type IgnoreMigrateOptions struct { + Force bool `mapstructure:"ignore_migrate_force" json:"ignore_migrate_force" yaml:"ignore_migrate_force"` +} + +type WorkerOptions struct { + ParentProcessID int + WorkerID string `mapstructure:"worker-id" json:"worker-id" yaml:"worker-id"` + Port string `mapstructure:"port" json:"port" yaml:"port"` +} diff --git a/internal/flag/worker_flags.go b/internal/flag/worker_flags.go index f3d115077..8509318a4 100644 --- a/internal/flag/worker_flags.go +++ b/internal/flag/worker_flags.go @@ -1,23 +1,25 @@ package flag +import flagtypes "github.com/bearer/bearer/internal/flag/types" + type workerFlagGroup struct{ flagGroupBase } var WorkerFlagGroup = &workerFlagGroup{flagGroupBase{name: "Worker"}} var ( - ParentProcessIDFlag = WorkerFlagGroup.add(Flag{ + ParentProcessIDFlag = WorkerFlagGroup.add(flagtypes.Flag{ Name: "parent-process-id", ConfigName: "worker.parent-process-id", Value: -1, }) - PortFlag = WorkerFlagGroup.add(Flag{ + PortFlag = WorkerFlagGroup.add(flagtypes.Flag{ Name: "port", ConfigName: "worker.port", Shorthand: "p", Value: "", Usage: "Set the server's listening port.", }) - WorkerIDFlag = WorkerFlagGroup.add(Flag{ + WorkerIDFlag = WorkerFlagGroup.add(flagtypes.Flag{ Name: "worker-id", ConfigName: "worker.id", Value: "", @@ -37,8 +39,8 @@ type WorkerOptions struct { Port string `mapstructure:"port" json:"port" yaml:"port"` } -func (workerFlagGroup) SetOptions(options *Options, args []string) error { - options.WorkerOptions = WorkerOptions{ +func (workerFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { + options.WorkerOptions = flagtypes.WorkerOptions{ ParentProcessID: getInteger(ParentProcessIDFlag), Port: getString(PortFlag), WorkerID: getString(WorkerIDFlag), diff --git a/internal/report/output/privacy/privacy_test.go b/internal/report/output/privacy/privacy_test.go index 0020ac5a4..cb0ca1125 100644 --- a/internal/report/output/privacy/privacy_test.go +++ b/internal/report/output/privacy/privacy_test.go @@ -6,7 +6,7 @@ import ( "github.com/bradleyjkemp/cupaloy" "github.com/bearer/bearer/internal/commands/process/settings" - "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/report/output/dataflow/types" "github.com/bearer/bearer/internal/report/output/privacy" "github.com/bearer/bearer/internal/report/output/testhelper" @@ -16,7 +16,7 @@ import ( ) func TestBuildCsvString(t *testing.T) { - config, err := generateConfig(flag.ReportOptions{Report: "privacy"}) + config, err := generateConfig(flagtypes.ReportOptions{Report: "privacy"}) config.Rules = map[string]*settings.Rule{ "ruby_third_parties_sentry": testhelper.RubyThirdPartiesSentryRule(), } @@ -37,7 +37,7 @@ func TestBuildCsvString(t *testing.T) { } func TestAddReportData(t *testing.T) { - config, err := generateConfig(flag.ReportOptions{Report: "privacy"}) + config, err := generateConfig(flagtypes.ReportOptions{Report: "privacy"}) config.Rules = map[string]*settings.Rule{ "ruby_third_parties_sentry": testhelper.RubyThirdPartiesSentryRule(), } @@ -56,14 +56,14 @@ func TestAddReportData(t *testing.T) { cupaloy.SnapshotT(t, output.PrivacyReport) } -func generateConfig(reportOptions flag.ReportOptions) (settings.Config, error) { - opts := flag.Options{ - ScanOptions: flag.ScanOptions{ +func generateConfig(reportOptions flagtypes.ReportOptions) (settings.Config, error) { + opts := flagtypes.Options{ + ScanOptions: flagtypes.ScanOptions{ Scanner: []string{"sast"}, }, - RuleOptions: flag.RuleOptions{}, + RuleOptions: flagtypes.RuleOptions{}, ReportOptions: reportOptions, - GeneralOptions: flag.GeneralOptions{}, + GeneralOptions: flagtypes.GeneralOptions{}, } meta := &version_check.VersionMeta{ diff --git a/internal/report/output/security/security_test.go b/internal/report/output/security/security_test.go index 0b52ba64f..7c1320f9b 100644 --- a/internal/report/output/security/security_test.go +++ b/internal/report/output/security/security_test.go @@ -9,7 +9,7 @@ import ( "github.com/bearer/bearer/internal/commands/process/filelist/files" "github.com/bearer/bearer/internal/commands/process/settings" - "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/git" "github.com/bearer/bearer/internal/report/basebranchfindings" "github.com/bearer/bearer/internal/report/schema" @@ -25,7 +25,7 @@ import ( ) func TestBuildReportString(t *testing.T) { - config, err := generateConfig(flag.ReportOptions{Report: "security"}) + config, err := generateConfig(flagtypes.ReportOptions{Report: "security"}) // set rule version config.BearerRulesVersion = "TEST" @@ -59,7 +59,7 @@ func TestBuildReportString(t *testing.T) { } func TestNoRulesBuildReportString(t *testing.T) { - config, err := generateConfig(flag.ReportOptions{Report: "security"}) + config, err := generateConfig(flagtypes.ReportOptions{Report: "security"}) // set rule version config.BearerRulesVersion = "TEST" config.Rules = map[string]*settings.Rule{} @@ -88,7 +88,7 @@ func TestNoRulesBuildReportString(t *testing.T) { } func TestAddReportData(t *testing.T) { - config, err := generateConfig(flag.ReportOptions{Report: "security"}) + config, err := generateConfig(flagtypes.ReportOptions{Report: "security"}) config.Rules = map[string]*settings.Rule{ "ruby_lang_ssl_verification": testhelper.RubyLangSSLVerificationRule(), @@ -113,7 +113,7 @@ func TestAddReportDataWithSeverity(t *testing.T) { severity := set.New[string]() severity.Add(globaltypes.LevelCritical) - config, err := generateConfig(flag.ReportOptions{ + config, err := generateConfig(flagtypes.ReportOptions{ Report: "security", Severity: severity, }) @@ -157,7 +157,7 @@ func TestAddReportDataWithFailOnSeverity(t *testing.T) { severity.Add(test.Severity) } - config, err := generateConfig(flag.ReportOptions{ + config, err := generateConfig(flagtypes.ReportOptions{ Report: "security", Severity: severity, FailOnSeverity: failOnSeverity, @@ -195,7 +195,7 @@ func TestCalculateSeverity(t *testing.T) { } func TestFingerprintIsStableWithBaseBranchFindings(t *testing.T) { - config, err := generateConfig(flag.ReportOptions{Report: "security"}) + config, err := generateConfig(flagtypes.ReportOptions{Report: "security"}) if err != nil { t.Fatalf("failed to generate config:%s", err) } @@ -287,7 +287,7 @@ func TestFingerprintIsStableWithBaseBranchFindings(t *testing.T) { assert.Equal(t, fullScanFinding.Fingerprint, diffFinding.Fingerprint) } -func generateConfig(reportOptions flag.ReportOptions) (settings.Config, error) { +func generateConfig(reportOptions flagtypes.ReportOptions) (settings.Config, error) { if reportOptions.Severity == nil { reportOptions.Severity = set.New[string]() reportOptions.Severity.AddAll(globaltypes.Severities) @@ -301,13 +301,13 @@ func generateConfig(reportOptions flag.ReportOptions) (settings.Config, error) { reportOptions.FailOnSeverity.Add(globaltypes.LevelLow) } - opts := flag.Options{ - ScanOptions: flag.ScanOptions{ + opts := flagtypes.Options{ + ScanOptions: flagtypes.ScanOptions{ Scanner: []string{"sast"}, }, - RuleOptions: flag.RuleOptions{}, + RuleOptions: flagtypes.RuleOptions{}, ReportOptions: reportOptions, - GeneralOptions: flag.GeneralOptions{}, + GeneralOptions: flagtypes.GeneralOptions{}, } meta := &version_check.VersionMeta{ diff --git a/internal/report/output/stats/gocloc_detector.go b/internal/report/output/stats/gocloc_detector.go index 9a4369591..f8283191c 100644 --- a/internal/report/output/stats/gocloc_detector.go +++ b/internal/report/output/stats/gocloc_detector.go @@ -3,14 +3,14 @@ package stats import ( "time" - "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/util/output" "github.com/hhatto/gocloc" "github.com/schollz/progressbar/v3" ) -func GoclocDetectorOutput(path string, opts flag.Options) (*gocloc.Result, error) { +func GoclocDetectorOutput(path string, opts flagtypes.Options) (*gocloc.Result, error) { clocOpts := gocloc.NewClocOptions() clocOpts.SkipDuplicated = true output.StdErrLog("Analyzing codebase") @@ -29,7 +29,7 @@ func GoclocDetectorOutput(path string, opts flag.Options) (*gocloc.Result, error return processor.Analyze([]string{path}) } -func hideProgress(opts flag.Options) bool { +func hideProgress(opts flagtypes.Options) bool { return opts.ScanOptions.HideProgressBar || opts.ScanOptions.Quiet || opts.Debug } diff --git a/internal/scanner/detectors/testhelper/testhelper.go b/internal/scanner/detectors/testhelper/testhelper.go index 03b846856..0286998e5 100644 --- a/internal/scanner/detectors/testhelper/testhelper.go +++ b/internal/scanner/detectors/testhelper/testhelper.go @@ -12,6 +12,7 @@ import ( "github.com/bearer/bearer/internal/classification" "github.com/bearer/bearer/internal/commands/process/settings" "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/scanner/ast" "github.com/bearer/bearer/internal/scanner/ast/query" "github.com/bearer/bearer/internal/scanner/ast/traversalstrategy" @@ -40,10 +41,10 @@ func RunTest( t.Run(name, func(tt *testing.T) { classifier, err := classification.NewClassifier(&classification.Config{ Config: settings.Config{ - Scan: flag.ScanOptions{ + Scan: flagtypes.ScanOptions{ DisableDomainResolution: true, DomainResolutionTimeout: 0, - Context: flag.Context(flag.Empty), + Context: flagtypes.Context(flag.Empty), }, }, }) diff --git a/internal/version_check/version_check.go b/internal/version_check/version_check.go index e8bd664db..4a4f831ea 100644 --- a/internal/version_check/version_check.go +++ b/internal/version_check/version_check.go @@ -8,6 +8,7 @@ import ( "github.com/bearer/bearer/cmd/bearer/build" "github.com/bearer/bearer/internal/flag" + flagtypes "github.com/bearer/bearer/internal/flag/types" "github.com/bearer/bearer/internal/util/output" ) @@ -26,7 +27,7 @@ type BinaryVersionMeta struct { Message string } -func GetScanVersionMeta(ctx context.Context, options flag.Options, languages []string) (meta *VersionMeta, err error) { +func GetScanVersionMeta(ctx context.Context, options flagtypes.Options, languages []string) (meta *VersionMeta, err error) { if options.RuleOptions.DisableDefaultRules && options.GeneralOptions.DisableVersionCheck { log.Debug().Msg("skipping version API call as check and default rules both disabled")