diff --git a/internal/tenant/context.go b/internal/tenant/context.go index afd887dab..5f79f9270 100644 --- a/internal/tenant/context.go +++ b/internal/tenant/context.go @@ -10,6 +10,7 @@ import ( "github.com/sourcegraph/zoekt/internal/tenant/internal/enforcement" "github.com/sourcegraph/zoekt/internal/tenant/internal/tenanttype" + "github.com/sourcegraph/zoekt/internal/tenant/systemtenant" "github.com/sourcegraph/zoekt/trace" ) @@ -26,6 +27,10 @@ func FromContext(ctx context.Context) (*tenanttype.Tenant, error) { // Log logs the tenant ID to the trace. If tenant logging is enabled, it also // logs a stack trace to a pprof profile. func Log(ctx context.Context, tr *trace.Trace) { + if systemtenant.Is(ctx) { + tr.LazyPrintf("tenant: system") + return + } tnt, ok := tenanttype.GetTenant(ctx) if !ok { if profile := pprofMissingTenant(); profile != nil { diff --git a/internal/tenant/systemtenant/systemtenant.go b/internal/tenant/systemtenant/systemtenant.go index a7957e6b6..eb5b2c994 100644 --- a/internal/tenant/systemtenant/systemtenant.go +++ b/internal/tenant/systemtenant/systemtenant.go @@ -10,9 +10,11 @@ type contextKey int const systemTenantKey contextKey = iota -// UnsafeCtx is a context that allows queries across all tenants. Don't use this -// for user requests. -var UnsafeCtx = context.WithValue(context.Background(), systemTenantKey, systemTenantKey) +// WithUnsafeContext taints the context to allow queries across all tenants. +// Never use this for user requests. +func WithUnsafeContext(ctx context.Context) context.Context { + return context.WithValue(ctx, systemTenantKey, systemTenantKey) +} // Is returns true if the context has been marked to allow queries across all // tenants. diff --git a/internal/tenant/systemtenant/systemtenant_test.go b/internal/tenant/systemtenant/systemtenant_test.go index 4330d82c7..f92a512d9 100644 --- a/internal/tenant/systemtenant/systemtenant_test.go +++ b/internal/tenant/systemtenant/systemtenant_test.go @@ -7,9 +7,9 @@ import ( "github.com/stretchr/testify/require" ) -func TestSystemtenantRoundtrip(t *testing.T) { +func TestSystemTenantRoundTrip(t *testing.T) { if Is(context.Background()) { t.Fatal() } - require.True(t, Is(UnsafeCtx)) + require.True(t, Is(WithUnsafeContext(context.Background()))) } diff --git a/shards/shards.go b/shards/shards.go index 6f13d5ae5..4552096b2 100644 --- a/shards/shards.go +++ b/shards/shards.go @@ -1083,10 +1083,10 @@ func (s *shardedSearcher) getLoaded() loaded { func mkRankedShard(s zoekt.Searcher) *rankedShard { q := query.Const{Value: true} - // We need to use UnsafeCtx here, otherwise we cannot return a proper + // We need to use WithUnsafeContext here, otherwise we cannot return a proper // rankedShard. On the user request path we use selectRepoSet which relies on // rankedShard.repos being set. - result, err := s.List(systemtenant.UnsafeCtx, &q, nil) + result, err := s.List(systemtenant.WithUnsafeContext(context.Background()), &q, nil) if err != nil { log.Printf("[ERROR] mkRankedShard(%s): failed to cache repository list: %v", s, err) return &rankedShard{Searcher: s} diff --git a/web/server.go b/web/server.go index eb713ddfc..9e23af8b8 100644 --- a/web/server.go +++ b/web/server.go @@ -32,7 +32,9 @@ import ( "time" "github.com/grafana/regexp" + "github.com/sourcegraph/zoekt" + "github.com/sourcegraph/zoekt/internal/tenant/systemtenant" zjson "github.com/sourcegraph/zoekt/json" "github.com/sourcegraph/zoekt/query" ) @@ -206,7 +208,10 @@ func (s *Server) serveHealthz(w http.ResponseWriter, r *http.Request) { q := &query.Const{Value: true} opts := &zoekt.SearchOptions{ShardMaxMatchCount: 1, TotalMaxMatchCount: 1, MaxDocDisplayCount: 1} - result, err := s.Searcher.Search(r.Context(), q, opts) + // We need to use WithUnsafeContext here because we want to perform a full + // search returning results. The result of this search is not used for anything + // other than determining if the server is healthy. + result, err := s.Searcher.Search(systemtenant.WithUnsafeContext(r.Context()), q, opts) if err != nil { http.Error(w, fmt.Sprintf("not ready: %v", err), http.StatusInternalServerError) return