diff --git a/api/fetch_ignores.go b/api/fetch_ignores.go index 35a1cf360..1404854a9 100644 --- a/api/fetch_ignores.go +++ b/api/fetch_ignores.go @@ -2,7 +2,6 @@ package api import ( "encoding/json" - "os" ignoretypes "github.com/bearer/bearer/internal/util/ignore/types" ) @@ -20,7 +19,7 @@ type CloudIgnorePayload struct { PullRequestNumber string `json:"pull_request_number,omitempty"` } -func (api *API) FetchIgnores(fullname string, localIgnores []string) (*CloudIgnoreData, error) { +func (api *API) FetchIgnores(fullname string, pullRequestNumber string, localIgnores []string) (*CloudIgnoreData, error) { endpoint := Endpoints.FetchIgnores bytes, err := api.makeRequest(endpoint.Route, endpoint.HttpMethod, @@ -29,7 +28,7 @@ func (api *API) FetchIgnores(fullname string, localIgnores []string) (*CloudIgno Data: CloudIgnorePayload{ Project: fullname, LocalIgnores: localIgnores, - PullRequestNumber: os.Getenv("PR_NUMBER"), + PullRequestNumber: pullRequestNumber, }, }) if err != nil { diff --git a/internal/commands/artifact/run.go b/internal/commands/artifact/run.go index 885add9f8..934d18957 100644 --- a/internal/commands/artifact/run.go +++ b/internal/commands/artifact/run.go @@ -231,7 +231,7 @@ func (r *runner) scanBaseBranch( return result, nil } -func getIgnoredFingerprints(client *api.API, settings settings.Config, gitContext *gitrepository.Context) ( +func getIgnoredFingerprints(client *api.API, settings settings.Config, gitContext *gitrepository.Context, pullRequestNumber string) ( useCloudIgnores bool, ignoredFingerprints map[string]ignoretypes.IgnoredFingerprint, staleIgnoredFingerprintIds []string, @@ -246,6 +246,7 @@ func getIgnoredFingerprints(client *api.API, settings settings.Config, gitContex useCloudIgnores, ignoredFingerprints, staleIgnoredFingerprintIds, err = ignore.GetIgnoredFingerprintsFromCloud( client, gitContext.FullName, + pullRequestNumber, localIgnoredFingerprints, ) if err != nil { @@ -309,6 +310,7 @@ func Run(ctx context.Context, opts flagtypes.Options) (err error) { opts.GeneralOptions.Client, scanSettings, gitContext, + opts.PullRequestNumber, ) if err != nil { return err diff --git a/internal/flag/repository_flags.go b/internal/flag/repository_flags.go index d56396cf4..f065f4308 100644 --- a/internal/flag/repository_flags.go +++ b/internal/flag/repository_flags.go @@ -112,31 +112,32 @@ var ( DisableInConfig: true, Hide: true, }) + PullRequestNumberFlag = RepositoryFlagGroup.add(flagtypes.Flag{ + Name: "pull-request-number", + ConfigName: "repository.pull-request-number", + Value: "", + Usage: "Used when fetching branch level ignores for a PR/MR", + EnvironmentVariables: []string{ + "PR_NUMBER", // github + "CI_MERGE_REQUEST_ID", //gitlab + }, + DisableInConfig: true, + Hide: true, + }) ) -type RepositoryOptions struct { - OriginURL string - Branch string - Commit string - DefaultBranch string - DiffBaseBranch string - DiffBaseCommit string - GithubToken string - GithubRepository string - GithubAPIURL string -} - func (repositoryFlagGroup) SetOptions(options *flagtypes.Options, args []string) error { options.RepositoryOptions = flagtypes.RepositoryOptions{ - OriginURL: getString(RepositoryURLFlag), - Branch: getString(BranchFlag), - Commit: getString(CommitFlag), - DefaultBranch: getString(DefaultBranchFlag), - DiffBaseBranch: getString(DiffBaseBranchFlag), - DiffBaseCommit: getString(DiffBaseCommitFlag), - GithubToken: getString(GithubTokenFlag), - GithubRepository: getString(GithubRepositoryFlag), - GithubAPIURL: getString(GithubAPIURLFlag), + OriginURL: getString(RepositoryURLFlag), + Branch: getString(BranchFlag), + Commit: getString(CommitFlag), + DefaultBranch: getString(DefaultBranchFlag), + DiffBaseBranch: getString(DiffBaseBranchFlag), + DiffBaseCommit: getString(DiffBaseCommitFlag), + GithubToken: getString(GithubTokenFlag), + GithubRepository: getString(GithubRepositoryFlag), + GithubAPIURL: getString(GithubAPIURLFlag), + PullRequestNumber: getString(PullRequestNumberFlag), } return nil diff --git a/internal/flag/repository_flags_test.go b/internal/flag/repository_flags_test.go index 1eaa15f8e..7af935df9 100644 --- a/internal/flag/repository_flags_test.go +++ b/internal/flag/repository_flags_test.go @@ -285,3 +285,38 @@ func Test_getRepositoryGithubAPIURLFlag(t *testing.T) { RunFlagTests(testCases, t) } + +func Test_getPullRequestNumberFlag(t *testing.T) { + testCases := []TestCase{ + { + name: "Repository PullRequestNumber. Default", + flag: PullRequestNumberFlag, + flagValue: "", + want: nil, + }, + { + name: "Repository PullRequestNumber. PR_NUMBER env", + flag: PullRequestNumberFlag, + env: Env{ + key: "PR_NUMBER", + value: "42", + }, + want: []string{ + string("42"), + }, + }, + { + name: "Repository PullRequestNumber. CI_MERGE_REQUEST_ID env", + flag: PullRequestNumberFlag, + env: Env{ + key: "CI_MERGE_REQUEST_ID", + value: "24", + }, + want: []string{ + string("24"), + }, + }, + } + + RunFlagTests(testCases, t) +} diff --git a/internal/flag/types/types.go b/internal/flag/types/types.go index 09067e3e9..7af9d413c 100644 --- a/internal/flag/types/types.go +++ b/internal/flag/types/types.go @@ -92,15 +92,16 @@ type ReportOptions struct { } type RepositoryOptions struct { - OriginURL string - Branch string - Commit string - DefaultBranch string - DiffBaseBranch string - DiffBaseCommit string - GithubToken string - GithubRepository string - GithubAPIURL string + OriginURL string + Branch string + Commit string + DefaultBranch string + DiffBaseBranch string + DiffBaseCommit string + GithubToken string + GithubRepository string + GithubAPIURL string + PullRequestNumber string } // GlobalOptions defines flags and other configuration parameters for all the subcommands diff --git a/internal/util/ignore/ignore.go b/internal/util/ignore/ignore.go index 8d0a9d5b5..0ac3b45a7 100644 --- a/internal/util/ignore/ignore.go +++ b/internal/util/ignore/ignore.go @@ -49,6 +49,7 @@ func GetIgnoredFingerprints(filePath string, target *string) (ignoredFingerprint func GetIgnoredFingerprintsFromCloud( client *api.API, fullname string, + pullRequestNumber string, localIgnores map[string]types.IgnoredFingerprint, ) ( useCloudIgnores bool, @@ -56,7 +57,8 @@ func GetIgnoredFingerprintsFromCloud( staleIgnoredFingerprintIds []string, err error, ) { - data, err := client.FetchIgnores(fullname, maps.Keys(localIgnores)) + + data, err := client.FetchIgnores(fullname, pullRequestNumber, maps.Keys(localIgnores)) if err != nil { return useCloudIgnores, ignoredFingerprints, staleIgnoredFingerprintIds, err }