diff --git a/README.md b/README.md index 909ee20e..7653bc0a 100644 --- a/README.md +++ b/README.md @@ -99,16 +99,17 @@ print("Value: " + str(flag_value)) ## 🌟 Features -| Status | Features | Description | -| ------ | ------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------- | -| ✅ | [Providers](#providers) | Integrate with a commercial, open source, or in-house feature management tool. | -| ✅ | [Targeting](#targeting) | Contextually-aware flag evaluation using [evaluation context](https://openfeature.dev/docs/reference/concepts/evaluation-context). | -| ✅ | [Hooks](#hooks) | Add functionality to various stages of the flag evaluation life-cycle. | -| ✅ | [Logging](#logging) | Integrate with popular logging packages. | -| ✅ | [Domains](#domains) | Logically bind clients with providers. | -| ✅ | [Eventing](#eventing) | React to state changes in the provider or flag management system. | -| ✅ | [Shutdown](#shutdown) | Gracefully clean up a provider during application shutdown. | -| ✅ | [Extending](#extending) | Extend OpenFeature with custom providers and hooks. | +| Status | Features | Description | +|--------|---------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------| +| ✅ | [Providers](#providers) | Integrate with a commercial, open source, or in-house feature management tool. | +| ✅ | [Targeting](#targeting) | Contextually-aware flag evaluation using [evaluation context](https://openfeature.dev/docs/reference/concepts/evaluation-context). | +| ✅ | [Hooks](#hooks) | Add functionality to various stages of the flag evaluation life-cycle. | +| ✅ | [Logging](#logging) | Integrate with popular logging packages. | +| ✅ | [Domains](#domains) | Logically bind clients with providers. | +| ✅ | [Eventing](#eventing) | React to state changes in the provider or flag management system. | +| ✅ | [Shutdown](#shutdown) | Gracefully clean up a provider during application shutdown. | +| ✅ | [Transaction Context Propagation](#transaction-context-propagation) | Set a specific [evaluation context](/docs/reference/concepts/evaluation-context) for a transaction (e.g. an HTTP request or a thread) | +| ✅ | [Extending](#extending) | Extend OpenFeature with custom providers and hooks. | Implemented: ✅ | In-progress: ⚠️ | Not implemented yet: ❌ @@ -235,6 +236,86 @@ def on_provider_ready(event_details: EventDetails): client.add_handler(ProviderEvent.PROVIDER_READY, on_provider_ready) ``` +### Transaction Context Propagation + +Transaction context is a container for transaction-specific evaluation context (e.g. user id, user agent, IP). +Transaction context can be set where specific data is available (e.g. an auth service or request handler) and by using the transaction context propagator it will automatically be applied to all flag evaluations within a transaction (e.g. a request or thread). + +You can implement a different transaction context propagator by implementing the `TransactionContextPropagator` class exported by the OpenFeature SDK. +In most cases you can use `ContextVarsTransactionContextPropagator` as it works for `threads` and `asyncio` using [Context Variables](https://peps.python.org/pep-0567/). + +The following example shows a **multithreaded** Flask application using transaction context propagation to propagate the request ip and user id into request scoped transaction context. + +```python +from flask import Flask, request +from openfeature import api +from openfeature.evaluation_context import EvaluationContext +from openfeature.transaction_context import ContextVarsTransactionContextPropagator + +# Initialize the Flask app +app = Flask(__name__) + +# Set the transaction context propagator +api.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + +# Middleware to set the transaction context +# You can call api.set_transaction_context anywhere you have information, +# you want to have available in the code-paths below the current one. +@app.before_request +def set_request_transaction_context(): + ip = request.headers.get("X-Forwarded-For", request.remote_addr) + user_id = request.headers.get("User-Id") # Assuming we're getting the user ID from a header + evaluation_context = EvaluationContext(targeting_key=user_id, attributes={"ipAddress": ip}) + api.set_transaction_context(evaluation_context) + +def create_response() -> str: + # This method can be anywhere in our code. + # The feature flag evaluation will automatically contain the transaction context merged with other context + new_response = api.get_client().get_string_value("response-message", "Hello User!") + return f"Message from server: {new_response}" + +# Example route where we use the transaction context +@app.route('/greeting') +def some_endpoint(): + return create_response() +``` + +This also works for asyncio based implementations e.g. FastApi as seen in the following example: + +```python +from fastapi import FastAPI, Request +from openfeature import api +from openfeature.evaluation_context import EvaluationContext +from openfeature.transaction_context import ContextVarsTransactionContextPropagator + +# Initialize the FastAPI app +app = FastAPI() + +# Set the transaction context propagator +api.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + +# Middleware to set the transaction context +@app.middleware("http") +async def set_request_transaction_context(request: Request, call_next): + ip = request.headers.get("X-Forwarded-For", request.client.host) + user_id = request.headers.get("User-Id") # Assuming we're getting the user ID from a header + evaluation_context = EvaluationContext(targeting_key=user_id, attributes={"ipAddress": ip}) + api.set_transaction_context(evaluation_context) + response = await call_next(request) + return response + +def create_response() -> str: + # This method can be located anywhere in our code. + # The feature flag evaluation will automatically include the transaction context merged with other context. + new_response = api.get_client().get_string_value("response-message", "Hello User!") + return f"Message from server: {new_response}" + +# Example route where we use the transaction context +@app.get('/greeting') +async def some_endpoint(): + return create_response() +``` + ### Shutdown The OpenFeature API provides a shutdown function to perform a cleanup of all registered providers. This should only be called when your application is in the process of shutting down. diff --git a/openfeature/api.py b/openfeature/api.py index c95d10ac..c7d29c48 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -12,6 +12,10 @@ from openfeature.provider import FeatureProvider from openfeature.provider._registry import provider_registry from openfeature.provider.metadata import Metadata +from openfeature.transaction_context import TransactionContextPropagator +from openfeature.transaction_context.no_op_transaction_context_propagator import ( + NoOpTransactionContextPropagator, +) __all__ = [ "get_client", @@ -20,6 +24,9 @@ "get_provider_metadata", "get_evaluation_context", "set_evaluation_context", + "set_transaction_context_propagator", + "get_transaction_context", + "set_transaction_context", "add_hooks", "clear_hooks", "get_hooks", @@ -29,6 +36,9 @@ ] _evaluation_context = EvaluationContext() +_evaluation_transaction_context_propagator: TransactionContextPropagator = ( + NoOpTransactionContextPropagator() +) _hooks: typing.List[Hook] = [] @@ -68,6 +78,24 @@ def set_evaluation_context(evaluation_context: EvaluationContext) -> None: _evaluation_context = evaluation_context +def set_transaction_context_propagator( + transaction_context_propagator: TransactionContextPropagator, +) -> None: + global _evaluation_transaction_context_propagator + _evaluation_transaction_context_propagator = transaction_context_propagator + + +def get_transaction_context() -> EvaluationContext: + return _evaluation_transaction_context_propagator.get_transaction_context() + + +def set_transaction_context(evaluation_context: EvaluationContext) -> None: + global _evaluation_transaction_context_propagator + _evaluation_transaction_context_propagator.set_transaction_context( + evaluation_context + ) + + def add_hooks(hooks: typing.List[Hook]) -> None: global _hooks _hooks = _hooks + hooks diff --git a/openfeature/client.py b/openfeature/client.py index 9e4518ec..1edfca63 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -335,9 +335,10 @@ def evaluate_flag_details( # noqa: PLR0915 ) invocation_context = invocation_context.merge(ctx2=evaluation_context) - # Requirement 3.2.2 merge: API.context->client.context->invocation.context + # Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context merged_context = ( api.get_evaluation_context() + .merge(api.get_transaction_context()) .merge(self.context) .merge(invocation_context) ) diff --git a/openfeature/transaction_context/__init__.py b/openfeature/transaction_context/__init__.py new file mode 100644 index 00000000..e97fd36f --- /dev/null +++ b/openfeature/transaction_context/__init__.py @@ -0,0 +1,11 @@ +from openfeature.transaction_context.context_var_transaction_context_propagator import ( + ContextVarsTransactionContextPropagator, +) +from openfeature.transaction_context.transaction_context_propagator import ( + TransactionContextPropagator, +) + +__all__ = [ + "TransactionContextPropagator", + "ContextVarsTransactionContextPropagator", +] diff --git a/openfeature/transaction_context/context_var_transaction_context_propagator.py b/openfeature/transaction_context/context_var_transaction_context_propagator.py new file mode 100644 index 00000000..1abc04fa --- /dev/null +++ b/openfeature/transaction_context/context_var_transaction_context_propagator.py @@ -0,0 +1,18 @@ +from contextvars import ContextVar + +from openfeature.evaluation_context import EvaluationContext +from openfeature.transaction_context.transaction_context_propagator import ( + TransactionContextPropagator, +) + + +class ContextVarsTransactionContextPropagator(TransactionContextPropagator): + _transaction_context_var: ContextVar[EvaluationContext] = ContextVar( + "transaction_context", default=EvaluationContext() + ) + + def get_transaction_context(self) -> EvaluationContext: + return self._transaction_context_var.get() + + def set_transaction_context(self, transaction_context: EvaluationContext) -> None: + self._transaction_context_var.set(transaction_context) diff --git a/openfeature/transaction_context/no_op_transaction_context_propagator.py b/openfeature/transaction_context/no_op_transaction_context_propagator.py new file mode 100644 index 00000000..22a5a3f1 --- /dev/null +++ b/openfeature/transaction_context/no_op_transaction_context_propagator.py @@ -0,0 +1,12 @@ +from openfeature.evaluation_context import EvaluationContext +from openfeature.transaction_context.transaction_context_propagator import ( + TransactionContextPropagator, +) + + +class NoOpTransactionContextPropagator(TransactionContextPropagator): + def get_transaction_context(self) -> EvaluationContext: + return EvaluationContext() + + def set_transaction_context(self, transaction_context: EvaluationContext) -> None: + pass diff --git a/openfeature/transaction_context/transaction_context_propagator.py b/openfeature/transaction_context/transaction_context_propagator.py new file mode 100644 index 00000000..9a54367b --- /dev/null +++ b/openfeature/transaction_context/transaction_context_propagator.py @@ -0,0 +1,11 @@ +import typing + +from openfeature.evaluation_context import EvaluationContext + + +class TransactionContextPropagator(typing.Protocol): + def get_transaction_context(self) -> EvaluationContext: ... + + def set_transaction_context( + self, transaction_context: EvaluationContext + ) -> None: ... diff --git a/pyproject.toml b/pyproject.toml index 40887d9d..6d37bd88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "behave", "coverage[toml]>=6.5", "pytest", + "pytest-asyncio" ] [tool.hatch.envs.default.scripts] diff --git a/tests/test_client.py b/tests/test_client.py index b51c460c..7f0ca461 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,8 +5,10 @@ import pytest +from openfeature import api from openfeature.api import add_hooks, clear_hooks, get_client, set_provider from openfeature.client import OpenFeatureClient +from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails from openfeature.exception import ErrorCode, OpenFeatureError from openfeature.flag_evaluation import FlagResolutionDetails, Reason @@ -14,6 +16,7 @@ from openfeature.provider import FeatureProvider, ProviderStatus from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider from openfeature.provider.no_op_provider import NoOpProvider +from openfeature.transaction_context import ContextVarsTransactionContextPropagator @pytest.mark.parametrize( @@ -384,3 +387,47 @@ def emit_events_task(): f2 = executor.submit(emit_events_task) f1.result() f2.result() + + +def test_client_should_merge_contexts(): + api.clear_hooks() + api.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + + provider = NoOpProvider() + provider.resolve_boolean_details = MagicMock(wraps=provider.resolve_boolean_details) + api.set_provider(provider) + + # Global evaluation context + global_context = EvaluationContext( + targeting_key="global", attributes={"global_attr": "global_value"} + ) + api.set_evaluation_context(global_context) + + # Transaction context + transaction_context = EvaluationContext( + targeting_key="transaction", + attributes={"transaction_attr": "transaction_value"}, + ) + api.set_transaction_context(transaction_context) + + # Client-specific context + client_context = EvaluationContext( + targeting_key="client", attributes={"client_attr": "client_value"} + ) + client = OpenFeatureClient(domain=None, version=None, context=client_context) + + # Invocation-specific context + invocation_context = EvaluationContext( + targeting_key="invocation", attributes={"invocation_attr": "invocation_value"} + ) + client.get_boolean_details("flag", False, invocation_context) + + # Retrieve the call arguments + args, kwargs = provider.resolve_boolean_details.call_args + flag_key, default_value, context = args + + assert context.targeting_key == "invocation" # Last one in the merge chain + assert context.attributes["global_attr"] == "global_value" + assert context.attributes["transaction_attr"] == "transaction_value" + assert context.attributes["client_attr"] == "client_value" + assert context.attributes["invocation_attr"] == "invocation_value" diff --git a/tests/test_transaction_context.py b/tests/test_transaction_context.py new file mode 100644 index 00000000..1da57c28 --- /dev/null +++ b/tests/test_transaction_context.py @@ -0,0 +1,175 @@ +import asyncio +import threading +from unittest.mock import MagicMock + +import pytest + +from openfeature.api import ( + get_transaction_context, + set_transaction_context, + set_transaction_context_propagator, +) +from openfeature.evaluation_context import EvaluationContext +from openfeature.transaction_context import ( + ContextVarsTransactionContextPropagator, + TransactionContextPropagator, +) +from openfeature.transaction_context.no_op_transaction_context_propagator import ( + NoOpTransactionContextPropagator, +) + + +# Test cases +def test_should_return_default_evaluation_context_with_noop_propagator(): + # Given + set_transaction_context_propagator(NoOpTransactionContextPropagator()) + + # When + context = get_transaction_context() + + # Then + assert isinstance(context, EvaluationContext) + assert context.attributes == {} + + +def test_should_set_and_get_custom_transaction_context(): + # Given + set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + evaluation_context = EvaluationContext("custom_key", {"attr1": "val1"}) + + # When + set_transaction_context(evaluation_context) + + # Then + context = get_transaction_context() + assert context.targeting_key == "custom_key" + assert context.attributes == {"attr1": "val1"} + + +def test_should_override_propagator_and_reset_context(): + # Given + custom_propagator = MagicMock(spec=TransactionContextPropagator) + default_context = EvaluationContext() + + set_transaction_context_propagator(custom_propagator) + + # When + set_transaction_context_propagator(NoOpTransactionContextPropagator()) + + # Then + assert get_transaction_context() == default_context + + +def test_should_call_set_transaction_context_on_propagator(): + # Given + custom_propagator = MagicMock(spec=TransactionContextPropagator) + evaluation_context = EvaluationContext("custom_key", {"attr1": "val1"}) + set_transaction_context_propagator(custom_propagator) + + # When + set_transaction_context(evaluation_context) + + # Then + custom_propagator.set_transaction_context.assert_called_with(evaluation_context) + + +def test_should_return_default_context_with_noop_propagator_set(): + # Given + noop_propagator = NoOpTransactionContextPropagator() + + set_transaction_context_propagator(noop_propagator) + + # When + context = get_transaction_context() + + # Then + assert context == EvaluationContext() + + +def test_should_propagate_event_when_context_set(): + # Given + custom_propagator = ContextVarsTransactionContextPropagator() + set_transaction_context_propagator(custom_propagator) + evaluation_context = EvaluationContext("custom_key", {"attr1": "val1"}) + + # When + set_transaction_context(evaluation_context) + + # Then + assert ( + custom_propagator._transaction_context_var.get().targeting_key == "custom_key" + ) + assert custom_propagator._transaction_context_var.get().attributes == { + "attr1": "val1" + } + + +def test_context_vars_transaction_context_propagator_multiple_threads(): + # Given + context_var_propagator = ContextVarsTransactionContextPropagator() + set_transaction_context_propagator(context_var_propagator) + + number_of_threads = 3 + barrier = threading.Barrier(number_of_threads) + + def thread_func(context_value, result_list, index): + context = EvaluationContext( + f"context_{context_value}", {"thread": context_value} + ) + set_transaction_context(context) + barrier.wait() + result_list[index] = get_transaction_context() + + results = [None] * number_of_threads + threads = [] + + # When + for i in range(3): + thread = threading.Thread(target=thread_func, args=(i, results, i)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Then + for i in range(3): + assert results[i].targeting_key == f"context_{i}" + assert results[i].attributes == {"thread": i} + + +@pytest.mark.asyncio +async def test_context_vars_transaction_context_propagator_asyncio(): + # Given + context_var_propagator = ContextVarsTransactionContextPropagator() + set_transaction_context_propagator(context_var_propagator) + + number_of_tasks = 3 + event = asyncio.Event() + ready_count = 0 + + async def async_func(context_value, results, index): + nonlocal ready_count + context = EvaluationContext( + f"context_{context_value}", {"async": context_value} + ) + set_transaction_context(context) + + ready_count += 1 # Increment the ready count + if ready_count == number_of_tasks: + event.set() # Set the event when all tasks are ready + + await event.wait() # Wait for the event to be set + results[index] = get_transaction_context() + + # Placeholder for results + results = [None] * number_of_tasks + + # When + tasks = [async_func(i, results, i) for i in range(number_of_tasks)] + await asyncio.gather(*tasks) + + # Then + for i in range(number_of_tasks): + assert results[i].targeting_key == f"context_{i}" + assert results[i].attributes == {"async": i}