Skip to content

Commit

Permalink
feat: add context-value flag (#1448)
Browse files Browse the repository at this point in the history
- add the `--context-value` command line flag to pass arbitrary key
value pairs to the evaluation context

Signed-off-by: Aleksei Muratov <[email protected]>
  • Loading branch information
alemrtv authored Dec 5, 2024
1 parent f7dd1eb commit 7ca092e
Show file tree
Hide file tree
Showing 16 changed files with 181 additions and 54 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ uninstall:
lint:
go install -v github.com/golangci/golangci-lint/cmd/[email protected]
$(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --deadline=5m --timeout=5m $(module)/... || exit;)
lint-fix:
go install -v github.com/golangci/golangci-lint/cmd/[email protected]
$(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --fix --deadline=5m --timeout=5m $(module)/... || exit;)
install-mockgen:
go install go.uber.org/mock/[email protected]
mockgen: install-mockgen
Expand Down
1 change: 1 addition & 0 deletions core/pkg/service/iservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Configuration struct {
SocketPath string
CORS []string
Options []connect.HandlerOption
ContextValues map[string]any
}

/*
Expand Down
3 changes: 3 additions & 0 deletions docs/reference/flag-definitions.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ For example, when accessing flagd via HTTP, the POST body may look like this:

The evaluation context can be accessed in targeting rules using the `var` operation followed by the evaluation context property name.

The evaluation context can be appended by arbitrary key value pairs
via the `-X` command line flag.

| Description | Example |
| -------------------------------------------------------------- | ---------------------------------------------------- |
| Retrieve property from the evaluation context | `#!json { "var": "email" }` |
Expand Down
1 change: 1 addition & 0 deletions docs/reference/flagd-cli/flagd_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ flagd start [flags]
### Options

```
-X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default [])
-C, --cors-origin strings CORS allowed origins, * will allow all origins
-h, --help help for start
-z, --log-format string Set the logging format, e.g. console or json (default "console")
Expand Down
11 changes: 10 additions & 1 deletion flagd/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ const (
sourcesFlagName = "sources"
syncPortFlagName = "sync-port"
uriFlagName = "uri"
contextValueFlagName = "context-value"
)

func init() {
flags := startCmd.Flags()

// allows environment variables to use _ instead of -
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS
viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT
Expand Down Expand Up @@ -78,6 +78,8 @@ func init() {
flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector")
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+
"from disk")
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
"to the flag evaluation context")

_ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName))
_ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName))
Expand All @@ -95,6 +97,7 @@ func init() {
_ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName))
_ = viper.BindPFlag(syncPortFlagName, flags.Lookup(syncPortFlagName))
_ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName))
_ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName))
}

// startCmd represents the start command
Expand Down Expand Up @@ -139,6 +142,11 @@ var startCmd = &cobra.Command{
}
syncProviders = append(syncProviders, syncProvidersFromConfig...)

contextValuesToMap := make(map[string]any)
for k, v := range viper.GetStringMapString(contextValueFlagName) {
contextValuesToMap[k] = v
}

// Build Runtime -----------------------------------------------------------
rt, err := runtime.FromConfig(logger, Version, runtime.Config{
CORS: viper.GetStringSlice(corsFlagName),
Expand All @@ -156,6 +164,7 @@ var startCmd = &cobra.Command{
ServiceSocketPath: viper.GetString(socketPathFlagName),
SyncServicePort: viper.GetUint16(syncPortFlagName),
SyncProviders: syncProviders,
ContextValues: contextValuesToMap,
})
if err != nil {
rtLogger.Fatal(err.Error())
Expand Down
16 changes: 11 additions & 5 deletions flagd/pkg/runtime/from_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type Config struct {

SyncProviders []sync.SourceConfig
CORS []string

ContextValues map[string]any
}

// FromConfig builds a runtime from startup configurations
Expand Down Expand Up @@ -101,17 +103,20 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
ofrepService, err := ofrep.NewOfrepService(jsonEvaluator, config.CORS, ofrep.SvcConfiguration{
Logger: logger.WithFields(zap.String("component", "OFREPService")),
Port: config.OfrepServicePort,
})
},
config.ContextValues,
)
if err != nil {
return nil, fmt.Errorf("error creating ofrep service")
}

// flag sync service
flagSyncService, err := flagsync.NewSyncService(flagsync.SvcConfigurations{
Logger: logger.WithFields(zap.String("component", "FlagSyncService")),
Port: config.SyncServicePort,
Sources: sources,
Store: s,
Logger: logger.WithFields(zap.String("component", "FlagSyncService")),
Port: config.SyncServicePort,
Sources: sources,
Store: s,
ContextValues: config.ContextValues,
})
if err != nil {
return nil, fmt.Errorf("error creating sync service: %w", err)
Expand Down Expand Up @@ -145,6 +150,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
SocketPath: config.ServiceSocketPath,
CORS: config.CORS,
Options: options,
ContextValues: config.ContextValues,
},
SyncImpl: iSyncs,
}, nil
Expand Down
2 changes: 2 additions & 0 deletions flagd/pkg/service/flag-evaluation/connect_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval,
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
)

marshalOpts := WithJSON(
Expand All @@ -170,6 +171,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval,
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
)

_, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...)
Expand Down
45 changes: 34 additions & 11 deletions flagd/pkg/service/flag-evaluation/flag_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,24 @@ type OldFlagEvaluationService struct {
metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents
flagEvalTracer trace.Tracer
contextValues map[string]any
}

// NewOldFlagEvaluationService creates a OldFlagEvaluationService with provided parameters
func NewOldFlagEvaluationService(log *logger.Logger,
eval evaluator.IEvaluator, eventingCfg IEvents, metricsRecorder telemetry.IMetricsRecorder,
func NewOldFlagEvaluationService(
log *logger.Logger,
eval evaluator.IEvaluator,
eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
) *OldFlagEvaluationService {
svc := &OldFlagEvaluationService{
logger: log,
eval: eval,
metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagEvaluationService"),
contextValues: contextValues,
}

if metricsRecorder != nil {
Expand All @@ -65,12 +71,8 @@ func (s *OldFlagEvaluationService) ResolveAll(
res := &schemaV1.ResolveAllResponse{
Flags: make(map[string]*schemaV1.AnyFlag),
}
evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}

values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx)
values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
Expand Down Expand Up @@ -172,6 +174,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{})

err := resolve[bool](
sCtx,
s.logger,
Expand All @@ -180,6 +183,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
req.Msg.GetContext(),
&booleanResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -206,6 +210,7 @@ func (s *OldFlagEvaluationService) ResolveString(
req.Msg.GetContext(),
&stringResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -232,6 +237,7 @@ func (s *OldFlagEvaluationService) ResolveInt(
req.Msg.GetContext(),
&intResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -258,6 +264,7 @@ func (s *OldFlagEvaluationService) ResolveFloat(
req.Msg.GetContext(),
&floatResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -284,6 +291,7 @@ func (s *OldFlagEvaluationService) ResolveObject(
req.Msg.GetContext(),
&objectResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -293,21 +301,36 @@ func (s *OldFlagEvaluationService) ResolveObject(
return res, err
}

// mergeContexts combines values from the request context with the values from the config --context-values flag.
// Request context values have a higher priority.
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
merged := make(map[string]any)
for k, v := range reqCtx {
merged[k] = v
}
for k, v := range configFlagsCtx {
merged[k] = v
}
return merged
}

// resolve is a generic flag resolver
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
configContextValues map[string]any,
) error {
reqID := xid.New().String()
defer logger.ClearFields(reqID)

mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
logger.WriteFields(
reqID,
zap.String("flag-key", flagKey),
zap.Strings("context-keys", formatContextKeys(evaluationContext)),
zap.Strings("context-keys", formatContextKeys(mergedContext)),
)

var evalErrFormatted error
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, evaluationContext.AsMap())
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, mergedContext)
if evalErr != nil {
logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr))
reason = model.ErrorReason
Expand All @@ -329,9 +352,9 @@ func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver
return evalErrFormatted
}

func formatContextKeys(context *structpb.Struct) []string {
func formatContextKeys(context map[string]any) []string {
res := []string{}
for k := range context.AsMap() {
for k := range context {
res = append(res, k)
}
return res
Expand Down
11 changes: 11 additions & 0 deletions flagd/pkg/service/flag-evaluation/flag_evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func TestConnectService_ResolveAll(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
if err != nil && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -235,6 +236,7 @@ func TestFlag_Evaluation_ResolveBoolean(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -290,6 +292,7 @@ func BenchmarkFlag_Evaluation_ResolveBoolean(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -388,6 +391,7 @@ func TestFlag_Evaluation_ResolveString(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -443,6 +447,7 @@ func BenchmarkFlag_Evaluation_ResolveString(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -540,6 +545,7 @@ func TestFlag_Evaluation_ResolveFloat(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -595,6 +601,7 @@ func BenchmarkFlag_Evaluation_ResolveFloat(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -692,6 +699,7 @@ func TestFlag_Evaluation_ResolveInt(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -747,6 +755,7 @@ func BenchmarkFlag_Evaluation_ResolveInt(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -847,6 +856,7 @@ func TestFlag_Evaluation_ResolveObject(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)

outParsed, err := structpb.NewStruct(tt.evalFields.result)
Expand Down Expand Up @@ -910,6 +920,7 @@ func BenchmarkFlag_Evaluation_ResolveObject(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result)
Expand Down
Loading

0 comments on commit 7ca092e

Please sign in to comment.