Skip to content

Commit

Permalink
all: minimize use of proxy.DefaultClient
Browse files Browse the repository at this point in the history
Plumb an explicit proxy client through as many functions as possible,
using the default client only in tests and top-level code.

This will allow us to identify and clean up tests that use real proxy
calls and should use mocks.

Change-Id: Ibd6423ea77c2007424c4539fe25f78c5b1f4764a
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/524135
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Damien Neil <[email protected]>
  • Loading branch information
tatianab committed Sep 11, 2023
1 parent 1aeac8c commit 24e908f
Show file tree
Hide file tree
Showing 18 changed files with 112 additions and 103 deletions.
4 changes: 3 additions & 1 deletion all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/vulndb/internal/cveschema5"
"golang.org/x/vulndb/internal/osvutils"
"golang.org/x/vulndb/internal/proxy"
"golang.org/x/vulndb/internal/report"
)

Expand Down Expand Up @@ -74,6 +75,7 @@ func TestLintReports(t *testing.T) {
// Map from aliases (CVEs/GHSAS) to report paths, used to check for duplicate aliases.
aliases := make(map[string]string)
sort.Strings(reports)
pc := proxy.DefaultClient
for _, filename := range reports {
t.Run(filename, func(t *testing.T) {
r, err := report.Read(filename)
Expand All @@ -83,7 +85,7 @@ func TestLintReports(t *testing.T) {
if err := r.CheckFilename(filename); err != nil {
t.Error(err)
}
lints := r.Lint()
lints := r.Lint(pc)
if len(lints) > 0 {
t.Errorf(strings.Join(lints, "\n"))
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/issue/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"golang.org/x/vulndb/internal/ghsa"
"golang.org/x/vulndb/internal/gitrepo"
"golang.org/x/vulndb/internal/issues"
"golang.org/x/vulndb/internal/proxy"
"golang.org/x/vulndb/internal/report"
"golang.org/x/vulndb/internal/worker"
)
Expand Down Expand Up @@ -131,11 +132,12 @@ func constructIssue(ctx context.Context, c *issues.Client, ghsaClient *ghsa.Clie
if err != nil {
return err
}
pc := proxy.DefaultClient
for _, sa := range ghsas {
for _, id := range sa.Identifiers {
ids = append(ids, id.Value)
}
body, err := worker.CreateGHSABody(sa, allReports)
body, err := worker.CreateGHSABody(sa, allReports, pc)
if err != nil {
return err
}
Expand Down
42 changes: 23 additions & 19 deletions cmd/vulnreport/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,19 @@ func main() {
}

ghsaClient := ghsa.NewClient(ctx, *githubToken)
pc := proxy.DefaultClient
var cmdFunc func(context.Context, string) error
switch cmd {
case "lint":
cmdFunc = lint
case "commit":
cmdFunc = func(ctx context.Context, name string) error { return commit(ctx, name, ghsaClient, *force) }
cmdFunc = func(ctx context.Context, name string) error { return commit(ctx, name, ghsaClient, pc, *force) }
case "cve":
cmdFunc = func(ctx context.Context, name string) error { return cveCmd(ctx, name) }
case "fix":
cmdFunc = func(ctx context.Context, name string) error { return fix(ctx, name, ghsaClient, *force) }
cmdFunc = func(ctx context.Context, name string) error { return fix(ctx, name, ghsaClient, pc, *force) }
case "osv":
cmdFunc = osvCmd
cmdFunc = func(ctx context.Context, name string) error { return osvCmd(ctx, name, pc) }
case "set-dates":
repo, err := gitrepo.Open(ctx, ".")
if err != nil {
Expand Down Expand Up @@ -269,6 +270,7 @@ func parseArgsToGithubIDs(args []string, existingByIssue map[int]*report.Report)
type createCfg struct {
ghsaClient *ghsa.Client
issuesClient *issues.Client
proxyClient *proxy.Client
existingByFile map[string]*report.Report
existingByIssue map[int]*report.Report
allowClosed bool
Expand Down Expand Up @@ -319,6 +321,7 @@ func setupCreate(ctx context.Context, args []string) ([]int, *createCfg, error)
return githubIDs, &createCfg{
issuesClient: issues.NewClient(ctx, &issues.Config{Owner: owner, Repo: repoName, Token: *githubToken}),
ghsaClient: ghsa.NewClient(ctx, *githubToken),
proxyClient: proxy.DefaultClient,
existingByFile: existingByFile,
existingByIssue: existingByIssue,
allowClosed: *closedOk,
Expand All @@ -327,7 +330,7 @@ func setupCreate(ctx context.Context, args []string) ([]int, *createCfg, error)

func createReport(ctx context.Context, cfg *createCfg, iss *issues.Issue) (r *report.Report, err error) {
defer derrors.Wrap(&err, "createReport(%d)", iss.Number)
parsed, err := parseGithubIssue(iss, cfg.allowClosed)
parsed, err := parseGithubIssue(iss, cfg.proxyClient, cfg.allowClosed)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -483,20 +486,20 @@ func newReport(ctx context.Context, cfg *createCfg, parsed *parsedIssue) (*repor
if err != nil {
return nil, err
}
r = ghsa.ToReport(parsed.id)
r = ghsa.ToReport(parsed.id, cfg.proxyClient)
} else {
ghsa, err := cfg.ghsaClient.FetchGHSA(ctx, parsed.ghsas[0])
if err != nil {
return nil, err
}
r = report.GHSAToReport(ghsa, parsed.modulePath)
r = report.GHSAToReport(ghsa, parsed.modulePath, cfg.proxyClient)
}
case len(parsed.cves) > 0:
cve, err := cvelistrepo.FetchCVE(ctx, loadCVERepo(ctx), parsed.cves[0])
if err != nil {
return nil, err
}
r = report.CVEToReport(cve, parsed.modulePath)
r = report.CVEToReport(cve, parsed.modulePath, cfg.proxyClient)
default:
r = &report.Report{}
}
Expand Down Expand Up @@ -524,7 +527,7 @@ type parsedIssue struct {
excluded report.ExcludedReason
}

func parseGithubIssue(iss *issues.Issue, allowClosed bool) (*parsedIssue, error) {
func parseGithubIssue(iss *issues.Issue, pc *proxy.Client, allowClosed bool) (*parsedIssue, error) {
parsed := &parsedIssue{
id: iss.NewGoID(),
}
Expand Down Expand Up @@ -557,7 +560,7 @@ func parseGithubIssue(iss *issues.Issue, allowClosed bool) (*parsedIssue, error)
// Remove backslashes.
path := strings.ReplaceAll(strings.TrimSuffix(p, ":"), "\"", "")
// Find the underlying module if this is a package path.
if module := proxy.FindModule(parsed.modulePath); module != "" {
if module := pc.FindModule(parsed.modulePath); module != "" {
parsed.modulePath = module
} else {
parsed.modulePath = path
Expand Down Expand Up @@ -717,11 +720,11 @@ func lint(_ context.Context, filename string) (err error) {
defer derrors.Wrap(&err, "lint(%q)", filename)
infolog.Printf("lint %s\n", filename)

_, err = report.ReadAndLint(filename)
_, err = report.ReadAndLint(filename, proxy.DefaultClient)
return err
}

func fix(ctx context.Context, filename string, ghsaClient *ghsa.Client, force bool) (err error) {
func fix(ctx context.Context, filename string, ghsaClient *ghsa.Client, pc *proxy.Client, force bool) (err error) {
defer derrors.Wrap(&err, "fix(%q)", filename)
infolog.Printf("fix %s\n", filename)

Expand All @@ -741,10 +744,10 @@ func fix(ctx context.Context, filename string, ghsaClient *ghsa.Client, force bo
}
}()

if lints := r.Lint(); force || len(lints) > 0 {
r.Fix()
if lints := r.Lint(pc); force || len(lints) > 0 {
r.Fix(pc)
}
if lints := r.Lint(); len(lints) > 0 {
if lints := r.Lint(pc); len(lints) > 0 {
warnlog.Printf("%s still has lint errors after fix:\n\t- %s", filename, strings.Join(lints, "\n\t- "))
}

Expand Down Expand Up @@ -931,9 +934,10 @@ func findExportedSymbols(m *report.Module, p *report.Package) (_ []string, err e
return newslice, nil
}

func osvCmd(_ context.Context, filename string) (err error) {
func osvCmd(_ context.Context, filename string, pc *proxy.Client) (err error) {
defer derrors.Wrap(&err, "osv(%q)", filename)
r, err := report.ReadAndLint(filename)

r, err := report.ReadAndLint(filename, pc)
if err != nil {
return err
}
Expand Down Expand Up @@ -975,15 +979,15 @@ func writeCVE(r *report.Report) error {
return database.WriteJSON(r.CVEFilename(), cve, true)
}

func commit(ctx context.Context, filename string, ghsaClient *ghsa.Client, force bool) (err error) {
func commit(ctx context.Context, filename string, ghsaClient *ghsa.Client, pc *proxy.Client, force bool) (err error) {
defer derrors.Wrap(&err, "commit(%q)", filename)

// Clean up the report file and lint the result.
// Stop if there any problems.
if err := fix(ctx, filename, ghsaClient, force); err != nil {
if err := fix(ctx, filename, ghsaClient, pc, force); err != nil {
return err
}
r, err := report.ReadAndLint(filename)
r, err := report.ReadAndLint(filename, pc)
if err != nil {
return err
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/worker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"golang.org/x/vulndb/internal/ghsa"
"golang.org/x/vulndb/internal/gitrepo"
"golang.org/x/vulndb/internal/issues"
"golang.org/x/vulndb/internal/proxy"
"golang.org/x/vulndb/internal/report"
"golang.org/x/vulndb/internal/worker"
"golang.org/x/vulndb/internal/worker/log"
Expand Down Expand Up @@ -262,7 +263,8 @@ func createIssuesCommand(ctx context.Context) error {
if err != nil {
return err
}
return worker.CreateIssues(ctx, cfg.Store, client, allReports, *limit)
pc := proxy.DefaultClient
return worker.CreateIssues(ctx, cfg.Store, client, pc, allReports, *limit)
}

func showCommand(ctx context.Context, ids []string) error {
Expand Down
13 changes: 7 additions & 6 deletions internal/genericosv/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ import (
"golang.org/x/vulndb/internal/cveschema5"
"golang.org/x/vulndb/internal/ghsa"
"golang.org/x/vulndb/internal/osv"
"golang.org/x/vulndb/internal/proxy"
"golang.org/x/vulndb/internal/report"
"golang.org/x/vulndb/internal/version"
)

// ToReport converts OSV into a Go Report with the given ID.
func (osv *Entry) ToReport(goID string) *report.Report {
func (osv *Entry) ToReport(goID string, pc *proxy.Client) *report.Report {
r := &report.Report{
ID: goID,
Summary: osv.Summary,
Expand All @@ -45,10 +46,10 @@ func (osv *Entry) ToReport(goID string) *report.Report {
for _, ref := range osv.References {
r.References = append(r.References, convertRef(ref))
}
r.Modules = affectedToModules(osv.Affected, addNote)
r.Modules = affectedToModules(osv.Affected, addNote, pc)
r.Credits = convertCredits(osv.Credits)
r.Fix()
if lints := r.Lint(); len(lints) > 0 {
r.Fix(pc)
if lints := r.Lint(pc); len(lints) > 0 {
slices.Sort(lints)
for _, lint := range lints {
addNote(fmt.Sprintf("lint: %s", lint))
Expand All @@ -59,7 +60,7 @@ func (osv *Entry) ToReport(goID string) *report.Report {

type addNoteFunc func(string)

func affectedToModules(as []osvschema.Affected, addNote addNoteFunc) []*report.Module {
func affectedToModules(as []osvschema.Affected, addNote addNoteFunc, pc *proxy.Client) []*report.Module {
var modules []*report.Module
for _, a := range as {
if a.Package.Ecosystem != osvschema.EcosystemGo {
Expand All @@ -73,7 +74,7 @@ func affectedToModules(as []osvschema.Affected, addNote addNoteFunc) []*report.M
}

for _, m := range modules {
m.FixVersions()
m.FixVersions(pc)
}

sortModules(modules)
Expand Down
11 changes: 6 additions & 5 deletions internal/genericosv/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ var (
func TestToReport(t *testing.T) {
if *realProxy {
defer func() {
err := updateProxyResponses()
err := updateProxyResponses(proxy.DefaultClient)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -67,7 +67,7 @@ func TestToReport(t *testing.T) {
t.Fatal(err)
}

got := osv.ToReport("GO-TEST-ID")
got := osv.ToReport("GO-TEST-ID", proxy.DefaultClient)
yamlFile := filepath.Join(testYAMLDir, ghsaID+".yaml")
if *update {
if err := got.Write(yamlFile); err != nil {
Expand All @@ -91,6 +91,7 @@ func TestToReport(t *testing.T) {

// TODO(https://go.dev/issues/61769): unskip test cases as we add features.
func TestAffectedToModules(t *testing.T) {
pc := proxy.DefaultClient
for _, tc := range []struct {
desc string
in []osvschema.Affected
Expand Down Expand Up @@ -272,7 +273,7 @@ func TestAffectedToModules(t *testing.T) {
addNote := func(note string) {
gotNotes = append(gotNotes, note)
}
got := affectedToModules(tc.in, addNote)
got := affectedToModules(tc.in, addNote, pc)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("affectedToModules() mismatch (-want +got)\n%s", diff)
}
Expand Down Expand Up @@ -310,8 +311,8 @@ func setupMockProxy(t *testing.T) error {
}

// Write proxy responses for this run to testdata/proxy.json.
func updateProxyResponses() error {
responses, err := json.MarshalIndent(proxy.Responses(), "", "\t")
func updateProxyResponses(pc *proxy.Client) error {
responses, err := json.MarshalIndent(pc.Responses(), "", "\t")
if err != nil {
return err
}
Expand Down
30 changes: 5 additions & 25 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ import (
"golang.org/x/vulndb/internal/version"
)

// TODO(https://go.dev/issues/60275): Replace this with a function that
// returns a new instance of a default client.
var DefaultClient *Client

const ProxyURL = "https://proxy.golang.org"

// Client is a client for reading from the proxy.
//
// It uses a simple in-memory cache that does not expire,
Expand All @@ -39,7 +43,7 @@ type Client struct {
}

func init() {
proxyURL := "https://proxy.golang.org"
proxyURL := ProxyURL
if proxy, ok := os.LookupEnv("GOPROXY"); ok {
proxyURL = proxy
}
Expand Down Expand Up @@ -110,10 +114,6 @@ func (c *Client) lookup(urlSuffix string) ([]byte, error) {
return b, nil
}

func CanonicalModulePath(path, version string) (string, error) {
return DefaultClient.CanonicalModulePath(path, version)
}

func (c *Client) CanonicalModulePath(path, version string) (_ string, err error) {
escapedPath, err := module.EscapePath(path)
if err != nil {
Expand All @@ -137,10 +137,6 @@ func (c *Client) CanonicalModulePath(path, version string) (_ string, err error)
return m.Module.Mod.Path, nil
}

func CanonicalModuleVersion(path, ver string) (_ string, err error) {
return DefaultClient.CanonicalModuleVersion(path, ver)
}

// CanonicalModuleVersion returns the canonical version string (with no leading "v" prefix)
// for the given module path and version string.
func (c *Client) CanonicalModuleVersion(path, ver string) (_ string, err error) {
Expand All @@ -163,10 +159,6 @@ func (c *Client) CanonicalModuleVersion(path, ver string) (_ string, err error)
return version.TrimPrefix(v), nil
}

func Latest(path string) (string, error) {
return DefaultClient.Latest(path)
}

// Latest returns the latest version of the module, with no leading "v"
// prefix.
func (c *Client) Latest(path string) (string, error) {
Expand All @@ -189,10 +181,6 @@ func (c *Client) Latest(path string) (string, error) {
return version.TrimPrefix(ver), nil
}

func Versions(path string) ([]string, error) {
return DefaultClient.Versions(path)
}

// Versions returns a list of module versions (with no leading "v" prefix),
// sorted in ascending order.
func (c *Client) Versions(path string) ([]string, error) {
Expand All @@ -217,10 +205,6 @@ func (c *Client) Versions(path string) ([]string, error) {
return vs, nil
}

func FindModule(path string) string {
return DefaultClient.FindModule(path)
}

// FindModule returns the longest directory prefix of path that
// is a module, or "" if no such prefix is found.
func (c *Client) FindModule(modPath string) string {
Expand All @@ -238,10 +222,6 @@ func (c *Client) FindModule(modPath string) string {
return ""
}

func Responses() map[string]*Response {
return DefaultClient.Responses()
}

// Responses returns a map from endpoints to the latest response received for each endpoint.
//
// Intended for testing: the output can be passed to NewTestClient to create a mock client
Expand Down
Loading

0 comments on commit 24e908f

Please sign in to comment.