diff --git a/tenant/resolver.go b/tenant/resolver.go index 35e95b1c8..9a01d6322 100644 --- a/tenant/resolver.go +++ b/tenant/resolver.go @@ -16,7 +16,18 @@ import ( // //nolint:revive func TenantID(ctx context.Context) (string, error) { - orgIDs, err := TenantIDs(ctx) + //lint:ignore faillint wrapper around upstream method + orgID, err := user.ExtractOrgID(ctx) + if err != nil { + return "", err + } + if !strings.Contains(orgID, tenantIDsSeparator) { + if err := ValidTenantID(orgID); err != nil { + return "", err + } + return orgID, nil + } + orgIDs, err := tenantIDsFromString(orgID) if err != nil { return "", err } @@ -42,6 +53,10 @@ func TenantIDs(ctx context.Context) ([]string, error) { return nil, err } + return tenantIDsFromString(orgID) +} + +func tenantIDsFromString(orgID string) ([]string, error) { orgIDs := strings.Split(orgID, tenantIDsSeparator) for _, id := range orgIDs { if err := ValidTenantID(id); err != nil { diff --git a/tenant/tenant_test.go b/tenant/tenant_test.go index bc3a60b54..f0b6b9e15 100644 --- a/tenant/tenant_test.go +++ b/tenant/tenant_test.go @@ -1,9 +1,12 @@ package tenant import ( + "context" "strings" "testing" + "github.com/grafana/dskit/user" + "github.com/stretchr/testify/assert" ) @@ -48,3 +51,23 @@ func TestValidTenantIDs(t *testing.T) { }) } } + +func BenchmarkTenantID(b *testing.B) { + singleCtx := context.Background() + singleCtx = user.InjectOrgID(singleCtx, "tenant-a") + multiCtx := context.Background() + multiCtx = user.InjectOrgID(multiCtx, "tenant-a|tenant-b|tenant-c") + + b.ResetTimer() + b.ReportAllocs() + b.Run("single", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = TenantID(singleCtx) + } + }) + b.Run("multi", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = TenantID(multiCtx) + } + }) +}