Skip to content

Commit

Permalink
Exit with non-zero exit code on chunk source error (trufflesecurity#1286)
Browse files Browse the repository at this point in the history

* Exit with non-zero exit code on chunk source error

* Exit with a non-zero exit code whenever we hit an error getting
  chunks. Previously the error would be logged but trufflehog would exit
  with a 0 (success) status code.

* fix gcs test

---------

Co-authored-by: Dustin Decker <[email protected]>
Co-authored-by: ahrav <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2023
1 parent 7cefea6 commit da5301e
Show file tree
Hide file tree
Showing 13 changed files with 117 additions and 79 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ func run(state overseer.State) {
}
}
// asynchronously wait for scanning to finish and cleanup
go e.Finish(ctx)
go e.Finish(ctx, logFatal)

if !*jsonLegacy && !*jsonOut {
fmt.Fprintf(os.Stderr, "🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷\n\n")
Expand Down
10 changes: 5 additions & 5 deletions pkg/engine/circleci.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package engine

import (
"fmt"
"runtime"

"github.com/go-errors/errors"
Expand Down Expand Up @@ -38,14 +39,13 @@ func (e *Engine) ScanCircleCI(ctx context.Context, token string) error {
return errors.WrapPrefix(err, "failed to init Circle CI source", 0)
}

e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
err := circleSource.Chunks(ctx, e.ChunksChan())
if err != nil {
ctx.Logger().Error(err, "error scanning Circle CI")
return fmt.Errorf("error scanning CircleCI: %w", err)
}
}()
return nil
})
return nil
}
16 changes: 13 additions & 3 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

ahocorasick "github.com/petar-dambovaliev/aho-corasick"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
Expand All @@ -32,7 +33,7 @@ type Engine struct {
chunksScanned uint64
bytesScanned uint64
detectorAvgTime sync.Map
sourcesWg sync.WaitGroup
sourcesWg *errgroup.Group
workersWg sync.WaitGroup
// filterUnverified is used to reduce the number of unverified results.
// If there are multiple unverified results for the same chunk for the same detector,
Expand Down Expand Up @@ -123,6 +124,11 @@ func Start(ctx context.Context, options ...EngineOption) *Engine {
}
ctx.Logger().V(2).Info("engine started", "workers", e.concurrency)

sourcesWg, egCtx := errgroup.WithContext(ctx)
sourcesWg.SetLimit(e.concurrency)
e.sourcesWg = sourcesWg
ctx.SetParent(egCtx)

if len(e.decoders) == 0 {
e.decoders = decoders.DefaultDecoders()
}
Expand Down Expand Up @@ -188,10 +194,14 @@ func Start(ctx context.Context, options ...EngineOption) *Engine {
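// Finish waits for running sources to complete and workers to finish scanning
// chunks before closing their respective channels. If waiting on the sources
// yields an error, that error is passed to logFunc before the chunks channel is
// closed. Once Finish is called, no more sources may be scanned by the engine.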
func (e *Engine) Finish(ctx context.Context) {
func (e *Engine) Finish(ctx context.Context, logFunc func(error, string, ...any)) {
defer common.RecoverWithExit(ctx)
// wait for the sources to finish putting chunks onto the chunks channel
e.sourcesWg.Wait()
sourceErr := e.sourcesWg.Wait()
if sourceErr != nil {
logFunc(sourceErr, "error occurred while collecting chunks")
}

close(e.chunks)
// wait for the workers to finish processing all of the chunks and putting
// results onto the results channel
Expand Down
10 changes: 5 additions & 5 deletions pkg/engine/filesystem.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package engine

import (
"fmt"
"runtime"

"github.com/go-errors/errors"
Expand Down Expand Up @@ -37,14 +38,13 @@ func (e *Engine) ScanFileSystem(ctx context.Context, c sources.FilesystemConfig)
return errors.WrapPrefix(err, "could not init filesystem source", 0)
}
fileSystemSource.WithFilter(c.Filter)
e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
err := fileSystemSource.Chunks(ctx, e.ChunksChan())
if err != nil {
ctx.Logger().Error(err, "error scanning filesystem")
return fmt.Errorf("error scanning filesystem: %w", err)
}
}()
return nil
})
return nil
}
9 changes: 4 additions & 5 deletions pkg/engine/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ func (e *Engine) ScanGCS(ctx context.Context, c sources.GCSConfig) error {
return fmt.Errorf("failed to initialize GCS source: %w", err)
}

e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
if err := source.Chunks(ctx, e.ChunksChan()); err != nil {
ctx.Logger().Error(err, "could not scan GCS")
return fmt.Errorf("could not scan GCS: %w", err)
}
}()
return nil
})
return nil
}

Expand Down
23 changes: 21 additions & 2 deletions pkg/engine/gcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/decoders"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

Expand Down Expand Up @@ -55,12 +56,30 @@ func TestScanGCS(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
e := &Engine{}
err := e.ScanGCS(context.Background(), test.gcsConfig)
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()

e := Start(ctx,
WithConcurrency(1),
WithDecoders(decoders.DefaultDecoders()...),
WithDetectors(false, DefaultDetectors()...),
)
go func() {
resultCount := 0
for range e.ResultsChan() {
resultCount++
}
}()

err := e.ScanGCS(ctx, test.gcsConfig)
if err != nil && !test.wantErr && !strings.Contains(err.Error(), "googleapi: Error 400: Bad Request") {
t.Errorf("ScanGCS() got: %v, want: %v", err, nil)
return
}
logFatalFunc := func(_ error, _ string, _ ...any) {
t.Fatalf("error logging function should not have been called")
}
e.Finish(ctx, logFatalFunc)

if err == nil && test.wantErr {
t.Errorf("ScanGCS() got: %v, want: %v", err, "error")
Expand Down
9 changes: 4 additions & 5 deletions pkg/engine/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,13 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error {
"source_type", sourcespb.SourceType_SOURCE_TYPE_GIT.String(),
"source_name", "git",
)
e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
err := gitSource.ScanRepo(ctx, repo, c.RepoPath, scanOptions, e.ChunksChan())
if err != nil {
ctx.Logger().Error(err, "could not scan repo")
return fmt.Errorf("could not scan repo: %w", err)
}
}()
return nil
})
return nil
}
73 changes: 41 additions & 32 deletions pkg/engine/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,39 +53,45 @@ func TestGitEngine(t *testing.T) {
base: "2f251b8c1e72135a375b659951097ec7749d4af9",
},
} {
e := Start(ctx,
WithConcurrency(1),
WithDecoders(decoders.DefaultDecoders()...),
WithDetectors(false, DefaultDetectors()...),
)
cfg := sources.GitConfig{
RepoPath: path,
HeadRef: tTest.branch,
BaseRef: tTest.base,
MaxDepth: tTest.maxDepth,
Filter: tTest.filter,
}
if err := e.ScanGit(ctx, cfg); err != nil {
return
}
go e.Finish(ctx)
resultCount := 0
for result := range e.ResultsChan() {
switch meta := result.SourceMetadata.GetData().(type) {
case *source_metadatapb.MetaData_Git:
if tTest.expected[meta.Git.Commit].B != string(result.Raw) {
t.Errorf("%s: unexpected result. Got: %s, Expected: %s", tName, string(result.Raw), tTest.expected[meta.Git.Commit].B)
}
if tTest.expected[meta.Git.Commit].LineNumber != result.SourceMetadata.GetGit().Line {
t.Errorf("%s: unexpected line number. Got: %d, Expected: %d", tName, result.SourceMetadata.GetGit().Line, tTest.expected[meta.Git.Commit].LineNumber)
}
t.Run(tName, func(t *testing.T) {
e := Start(ctx,
WithConcurrency(1),
WithDecoders(decoders.DefaultDecoders()...),
WithDetectors(false, DefaultDetectors()...),
)
cfg := sources.GitConfig{
RepoPath: path,
HeadRef: tTest.branch,
BaseRef: tTest.base,
MaxDepth: tTest.maxDepth,
Filter: tTest.filter,
}
if err := e.ScanGit(ctx, cfg); err != nil {
return
}
resultCount++

}
if resultCount != len(tTest.expected) {
t.Errorf("%s: unexpected number of results. Got: %d, Expected: %d", tName, resultCount, len(tTest.expected))
}
logFatalFunc := func(_ error, _ string, _ ...any) {
t.Fatalf("error logging function should not have been called")
}
go e.Finish(ctx, logFatalFunc)
resultCount := 0
for result := range e.ResultsChan() {
switch meta := result.SourceMetadata.GetData().(type) {
case *source_metadatapb.MetaData_Git:
if tTest.expected[meta.Git.Commit].B != string(result.Raw) {
t.Errorf("%s: unexpected result. Got: %s, Expected: %s", tName, string(result.Raw), tTest.expected[meta.Git.Commit].B)
}
if tTest.expected[meta.Git.Commit].LineNumber != result.SourceMetadata.GetGit().Line {
t.Errorf("%s: unexpected line number. Got: %d, Expected: %d", tName, result.SourceMetadata.GetGit().Line, tTest.expected[meta.Git.Commit].LineNumber)
}
}
resultCount++

}
if resultCount != len(tTest.expected) {
t.Errorf("%s: unexpected number of results. Got: %d, Expected: %d", tName, resultCount, len(tTest.expected))
}
})
}
}

Expand Down Expand Up @@ -124,5 +130,8 @@ func BenchmarkGitEngine(b *testing.B) {
return
}
}
e.Finish(ctx)
logFatalFunc := func(_ error, _ string, _ ...any) {
b.Fatalf("error logging function should not have been called")
}
e.Finish(ctx, logFatalFunc)
}
11 changes: 6 additions & 5 deletions pkg/engine/github.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package engine

import (
"fmt"

gogit "github.com/go-git/go-git/v5"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
Expand Down Expand Up @@ -58,14 +60,13 @@ func (e *Engine) ScanGitHub(ctx context.Context, c sources.GithubConfig) error {
scanOptions := git.NewScanOptions(opts...)
source.WithScanOptions(scanOptions)

e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
err := source.Chunks(ctx, e.ChunksChan())
if err != nil {
ctx.Logger().Error(err, "could not scan github")
return fmt.Errorf("could not scan github: %w", err)
}
}()
return nil
})
return nil
}
9 changes: 4 additions & 5 deletions pkg/engine/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ func (e *Engine) ScanGitLab(ctx context.Context, c sources.GitlabConfig) error {
}
gitlabSource.WithScanOptions(scanOptions)

e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
err := gitlabSource.Chunks(ctx, e.ChunksChan())
if err != nil {
ctx.Logger().Error(err, "error scanning GitLab")
return fmt.Errorf("error scanning GitLab: %w", err)
}
}()
return nil
})
return nil
}
9 changes: 4 additions & 5 deletions pkg/engine/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,13 @@ func (e *Engine) ScanS3(ctx context.Context, c sources.S3Config) error {
return errors.WrapPrefix(err, "failed to init S3 source", 0)
}

e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
err := s3Source.Chunks(ctx, e.ChunksChan())
if err != nil {
ctx.Logger().Error(err, "error scanning S3")
return fmt.Errorf("error scanning S3: %w", err)
}
}()
return nil
})
return nil
}
10 changes: 5 additions & 5 deletions pkg/engine/syslog.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package engine

import (
"fmt"
"os"

"github.com/go-errors/errors"
Expand Down Expand Up @@ -53,14 +54,13 @@ func (e *Engine) ScanSyslog(ctx context.Context, c sources.SyslogConfig) error {
return err
}

e.sourcesWg.Add(1)
go func() {
e.sourcesWg.Go(func() error {
defer common.RecoverWithExit(ctx)
defer e.sourcesWg.Done()
err := source.Chunks(ctx, e.ChunksChan())
if err != nil {
ctx.Logger().Error(err, "could not scan syslog")
return fmt.Errorf("could not scan syslog: %w", err)
}
}()
return nil
})
return nil
}
5 changes: 4 additions & 1 deletion pkg/sources/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,10 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
})

if err != nil {
s.log.Error(err, "could not list objects in s3 bucket", "bucket", bucket)
return fmt.Errorf(
"could not list objects in s3 bucket: bucket %s: %w",
bucket,
err)
}
}
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")
Expand Down

0 comments on commit da5301e

Please sign in to comment.