diff --git a/main.go b/main.go index 53485a6e83fd..6211ddd8c9b3 100644 --- a/main.go +++ b/main.go @@ -476,7 +476,7 @@ func run(state overseer.State) { } } // asynchronously wait for scanning to finish and cleanup - go e.Finish(ctx) + go e.Finish(ctx, logFatal) if !*jsonLegacy && !*jsonOut { fmt.Fprintf(os.Stderr, "🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷\n\n") diff --git a/pkg/engine/circleci.go b/pkg/engine/circleci.go index 44f00cac63dd..fb665eda7efb 100644 --- a/pkg/engine/circleci.go +++ b/pkg/engine/circleci.go @@ -1,6 +1,7 @@ package engine import ( + "fmt" "runtime" "github.com/go-errors/errors" @@ -38,14 +39,13 @@ func (e *Engine) ScanCircleCI(ctx context.Context, token string) error { return errors.WrapPrefix(err, "failed to init Circle CI source", 0) } - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() err := circleSource.Chunks(ctx, e.ChunksChan()) if err != nil { - ctx.Logger().Error(err, "error scanning Circle CI") + return fmt.Errorf("error scanning CircleCI: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 38d67dd2673b..ac8afbc15660 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -10,6 +10,7 @@ import ( "time" ahocorasick "github.com/petar-dambovaliev/aho-corasick" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" "github.com/trufflesecurity/trufflehog/v3/pkg/common" @@ -32,7 +33,7 @@ type Engine struct { chunksScanned uint64 bytesScanned uint64 detectorAvgTime sync.Map - sourcesWg sync.WaitGroup + sourcesWg *errgroup.Group workersWg sync.WaitGroup // filterUnverified is used to reduce the number of unverified results. // If there are multiple unverified results for the same chunk for the same detector, @@ -123,6 +124,11 @@ func Start(ctx context.Context, options ...EngineOption) *Engine { } ctx.Logger().V(2).Info("engine started", "workers", e.concurrency) + sourcesWg, egCtx := errgroup.WithContext(ctx) + sourcesWg.SetLimit(e.concurrency) + e.sourcesWg = sourcesWg + ctx.SetParent(egCtx) + if len(e.decoders) == 0 { e.decoders = decoders.DefaultDecoders() } @@ -188,10 +194,14 @@ func Start(ctx context.Context, options ...EngineOption) *Engine { // Finish waits for running sources to complete and workers to finish scanning // chunks before closing their respective channels. Once Finish is called, no // more sources may be scanned by the engine. -func (e *Engine) Finish(ctx context.Context) { +func (e *Engine) Finish(ctx context.Context, logFunc func(error, string, ...any)) { defer common.RecoverWithExit(ctx) // wait for the sources to finish putting chunks onto the chunks channel - e.sourcesWg.Wait() + sourceErr := e.sourcesWg.Wait() + if sourceErr != nil { + logFunc(sourceErr, "error occurred while collecting chunks") + } + close(e.chunks) // wait for the workers to finish processing all of the chunks and putting // results onto the results channel diff --git a/pkg/engine/filesystem.go b/pkg/engine/filesystem.go index 46ac5408d662..7ad70626a5e4 100644 --- a/pkg/engine/filesystem.go +++ b/pkg/engine/filesystem.go @@ -1,6 +1,7 @@ package engine import ( + "fmt" "runtime" "github.com/go-errors/errors" @@ -37,14 +38,13 @@ func (e *Engine) ScanFileSystem(ctx context.Context, c sources.FilesystemConfig) return errors.WrapPrefix(err, "could not init filesystem source", 0) } fileSystemSource.WithFilter(c.Filter) - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() err := fileSystemSource.Chunks(ctx, e.ChunksChan()) if err != nil { - ctx.Logger().Error(err, "error scanning filesystem") + return fmt.Errorf("error scanning filesystem: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/engine/gcs.go b/pkg/engine/gcs.go index 75f3a44536f4..1ae24f901d9a 100644 --- a/pkg/engine/gcs.go +++ b/pkg/engine/gcs.go @@ -54,14 +54,13 @@ func (e *Engine) ScanGCS(ctx context.Context, c sources.GCSConfig) error { return fmt.Errorf("failed to initialize GCS source: %w", err) } - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() if err := source.Chunks(ctx, e.ChunksChan()); err != nil { - ctx.Logger().Error(err, "could not scan GCS") + return fmt.Errorf("could not scan GCS: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/engine/gcs_test.go b/pkg/engine/gcs_test.go index e66e7bd9297c..d1b225f91a67 100644 --- a/pkg/engine/gcs_test.go +++ b/pkg/engine/gcs_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/trufflesecurity/trufflehog/v3/pkg/context" + "github.com/trufflesecurity/trufflehog/v3/pkg/decoders" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) @@ -55,12 +56,30 @@ func TestScanGCS(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e := &Engine{} - err := e.ScanGCS(context.Background(), test.gcsConfig) + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + e := Start(ctx, + WithConcurrency(1), + WithDecoders(decoders.DefaultDecoders()...), + WithDetectors(false, DefaultDetectors()...), + ) + go func() { + resultCount := 0 + for range e.ResultsChan() { + resultCount++ + } + }() + + err := e.ScanGCS(ctx, test.gcsConfig) if err != nil && !test.wantErr && !strings.Contains(err.Error(), "googleapi: Error 400: Bad Request") { t.Errorf("ScanGCS() got: %v, want: %v", err, nil) return } + logFatalFunc := func(_ error, _ string, _ ...any) { + t.Fatalf("error logging function should not have been called") + } + e.Finish(ctx, logFatalFunc) if err == nil && test.wantErr { t.Errorf("ScanGCS() got: %v, want: %v", err, "error") diff --git a/pkg/engine/git.go b/pkg/engine/git.go index 9d7ecbb6e47d..370b4642a19b 100644 --- a/pkg/engine/git.go +++ b/pkg/engine/git.go @@ -66,14 +66,13 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error { "source_type", sourcespb.SourceType_SOURCE_TYPE_GIT.String(), "source_name", "git", ) - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() err := gitSource.ScanRepo(ctx, repo, c.RepoPath, scanOptions, e.ChunksChan()) if err != nil { - ctx.Logger().Error(err, "could not scan repo") + return fmt.Errorf("could not scan repo: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/engine/git_test.go b/pkg/engine/git_test.go index 95f2052d5e16..036172930fa5 100644 --- a/pkg/engine/git_test.go +++ b/pkg/engine/git_test.go @@ -53,39 +53,45 @@ func TestGitEngine(t *testing.T) { base: "2f251b8c1e72135a375b659951097ec7749d4af9", }, } { - e := Start(ctx, - WithConcurrency(1), - WithDecoders(decoders.DefaultDecoders()...), - WithDetectors(false, DefaultDetectors()...), - ) - cfg := sources.GitConfig{ - RepoPath: path, - HeadRef: tTest.branch, - BaseRef: tTest.base, - MaxDepth: tTest.maxDepth, - Filter: tTest.filter, - } - if err := e.ScanGit(ctx, cfg); err != nil { - return - } - go e.Finish(ctx) - resultCount := 0 - for result := range e.ResultsChan() { - switch meta := result.SourceMetadata.GetData().(type) { - case *source_metadatapb.MetaData_Git: - if tTest.expected[meta.Git.Commit].B != string(result.Raw) { - t.Errorf("%s: unexpected result. Got: %s, Expected: %s", tName, string(result.Raw), tTest.expected[meta.Git.Commit].B) - } - if tTest.expected[meta.Git.Commit].LineNumber != result.SourceMetadata.GetGit().Line { - t.Errorf("%s: unexpected line number. Got: %d, Expected: %d", tName, result.SourceMetadata.GetGit().Line, tTest.expected[meta.Git.Commit].LineNumber) - } + t.Run(tName, func(t *testing.T) { + e := Start(ctx, + WithConcurrency(1), + WithDecoders(decoders.DefaultDecoders()...), + WithDetectors(false, DefaultDetectors()...), + ) + cfg := sources.GitConfig{ + RepoPath: path, + HeadRef: tTest.branch, + BaseRef: tTest.base, + MaxDepth: tTest.maxDepth, + Filter: tTest.filter, + } + if err := e.ScanGit(ctx, cfg); err != nil { + return } - resultCount++ - } - if resultCount != len(tTest.expected) { - t.Errorf("%s: unexpected number of results. Got: %d, Expected: %d", tName, resultCount, len(tTest.expected)) - } + logFatalFunc := func(_ error, _ string, _ ...any) { + t.Fatalf("error logging function should not have been called") + } + go e.Finish(ctx, logFatalFunc) + resultCount := 0 + for result := range e.ResultsChan() { + switch meta := result.SourceMetadata.GetData().(type) { + case *source_metadatapb.MetaData_Git: + if tTest.expected[meta.Git.Commit].B != string(result.Raw) { + t.Errorf("%s: unexpected result. Got: %s, Expected: %s", tName, string(result.Raw), tTest.expected[meta.Git.Commit].B) + } + if tTest.expected[meta.Git.Commit].LineNumber != result.SourceMetadata.GetGit().Line { + t.Errorf("%s: unexpected line number. Got: %d, Expected: %d", tName, result.SourceMetadata.GetGit().Line, tTest.expected[meta.Git.Commit].LineNumber) + } + } + resultCount++ + + } + if resultCount != len(tTest.expected) { + t.Errorf("%s: unexpected number of results. Got: %d, Expected: %d", tName, resultCount, len(tTest.expected)) + } + }) } } @@ -124,5 +130,8 @@ func BenchmarkGitEngine(b *testing.B) { return } } - e.Finish(ctx) + logFatalFunc := func(_ error, _ string, _ ...any) { + b.Fatalf("error logging function should not have been called") + } + e.Finish(ctx, logFatalFunc) } diff --git a/pkg/engine/github.go b/pkg/engine/github.go index 4038cd4a1c95..79d34d76806f 100644 --- a/pkg/engine/github.go +++ b/pkg/engine/github.go @@ -1,6 +1,8 @@ package engine import ( + "fmt" + gogit "github.com/go-git/go-git/v5" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -58,14 +60,13 @@ func (e *Engine) ScanGitHub(ctx context.Context, c sources.GithubConfig) error { scanOptions := git.NewScanOptions(opts...) source.WithScanOptions(scanOptions) - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() err := source.Chunks(ctx, e.ChunksChan()) if err != nil { - ctx.Logger().Error(err, "could not scan github") + return fmt.Errorf("could not scan github: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/engine/gitlab.go b/pkg/engine/gitlab.go index f7fc95994fee..ab0f58d20e47 100644 --- a/pkg/engine/gitlab.go +++ b/pkg/engine/gitlab.go @@ -63,14 +63,13 @@ func (e *Engine) ScanGitLab(ctx context.Context, c sources.GitlabConfig) error { } gitlabSource.WithScanOptions(scanOptions) - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() err := gitlabSource.Chunks(ctx, e.ChunksChan()) if err != nil { - ctx.Logger().Error(err, "error scanning GitLab") + return fmt.Errorf("error scanning GitLab: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/engine/s3.go b/pkg/engine/s3.go index efae64eef47e..a2056010ae41 100644 --- a/pkg/engine/s3.go +++ b/pkg/engine/s3.go @@ -65,14 +65,13 @@ func (e *Engine) ScanS3(ctx context.Context, c sources.S3Config) error { return errors.WrapPrefix(err, "failed to init S3 source", 0) } - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() err := s3Source.Chunks(ctx, e.ChunksChan()) if err != nil { - ctx.Logger().Error(err, "error scanning S3") + return fmt.Errorf("error scanning S3: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/engine/syslog.go b/pkg/engine/syslog.go index d09ea7731c98..d93cfa973ae3 100644 --- a/pkg/engine/syslog.go +++ b/pkg/engine/syslog.go @@ -1,6 +1,7 @@ package engine import ( + "fmt" "os" "github.com/go-errors/errors" @@ -53,14 +54,13 @@ func (e *Engine) ScanSyslog(ctx context.Context, c sources.SyslogConfig) error { return err } - e.sourcesWg.Add(1) - go func() { + e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - defer e.sourcesWg.Done() err := source.Chunks(ctx, e.ChunksChan()) if err != nil { - ctx.Logger().Error(err, "could not scan syslog") + return fmt.Errorf("could not scan syslog: %w", err) } - }() + return nil + }) return nil } diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index 5481201ca5a8..2693169a7ac0 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -197,7 +197,10 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err }) if err != nil { - s.log.Error(err, "could not list objects in s3 bucket", "bucket", bucket) + return fmt.Errorf( + "could not list objects in s3 bucket: bucket %s: %w", + bucket, + err) } } s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")