diff --git a/README.md b/README.md index b7d731eb..59c0da52 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,7 @@ print("Value: " + str(flag_value)) | ✅ | [Domains](#domains) | Logically bind clients with providers. | | ✅ | [Eventing](#eventing) | React to state changes in the provider or flag management system. | | ✅ | [Shutdown](#shutdown) | Gracefully clean up a provider during application shutdown. | +| ️️⚠️ | [Transaction Context Propagation](#transaction-context-propagation) | Set a specific [evaluation context](/docs/reference/concepts/evaluation-context) for a transaction (e.g. an HTTP request or a thread) | | ✅ | [Extending](#extending) | Extend OpenFeature with custom providers and hooks. | Implemented: ✅ | In-progress: ⚠️ | Not implemented yet: ❌ diff --git a/openfeature/api.py b/openfeature/api.py index c04d423e..29866e4d 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -12,6 +12,10 @@ from openfeature.provider import FeatureProvider from openfeature.provider._registry import provider_registry from openfeature.provider.metadata import Metadata +from openfeature.transaction_context import ( + NoopTransactionContextPropagator, + TransactionContextPropagator, +) __all__ = [ "get_client", @@ -26,12 +30,19 @@ "shutdown", "add_handler", "remove_handler", + "set_transaction_context_propagator", + "set_transaction_context", + "get_transaction_context", ] _evaluation_context = EvaluationContext() _hooks: typing.List[Hook] = [] +_transaction_context_propagator: TransactionContextPropagator = ( + NoopTransactionContextPropagator() +) + def get_client( domain: typing.Optional[str] = None, version: typing.Optional[str] = None @@ -94,3 +105,17 @@ def add_handler(event: ProviderEvent, handler: EventHandler) -> None: def remove_handler(event: ProviderEvent, handler: EventHandler) -> None: _event_support.remove_global_handler(event, handler) + + +def set_transaction_context_propagator( + propagator: TransactionContextPropagator, +) -> None: + _transaction_context_propagator = propagator + + +def set_transaction_context(context: EvaluationContext) -> None: + _transaction_context_propagator.set_transaction_context(context) + + +def get_transaction_context() -> EvaluationContext: + return _transaction_context_propagator.get_transaction_context() diff --git a/openfeature/client.py b/openfeature/client.py index 0429911a..f4152af1 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -335,9 +335,10 @@ def evaluate_flag_details( # noqa: PLR0915 ) invocation_context = invocation_context.merge(ctx2=evaluation_context) - # Requirement 3.2.2 merge: API.context->client.context->invocation.context + # Requirement 3.2.3 merge: API.context->transaction.context->client.context->invocation.context merged_context = ( api.get_evaluation_context() + .merge(api.get_transaction_context()) .merge(self.context) .merge(invocation_context) ) diff --git a/openfeature/transaction_context.py b/openfeature/transaction_context.py new file mode 100644 index 00000000..978ef477 --- /dev/null +++ b/openfeature/transaction_context.py @@ -0,0 +1,37 @@ +import typing +from contextvars import ContextVar + +from openfeature.evaluation_context import EvaluationContext + +__all__ = [ + "TransactionContextPropagator", + "NoopTransactionContextPropagator", + "ContextVarTransactionContextPropagator", +] + + +class TransactionContextPropagator(typing.Protocol): + def get_transaction_context(self) -> EvaluationContext: ... + + def set_transaction_context(self, context: EvaluationContext) -> None: ... + + +class NoopTransactionContextPropagator(TransactionContextPropagator): + def get_transaction_context(self) -> EvaluationContext: + return EvaluationContext() + + def set_transaction_context(self, context: EvaluationContext) -> None: + pass + + +class ContextVarTransactionContextPropagator(TransactionContextPropagator): + def __init__(self) -> None: + self._contextvar = ContextVar( + "transaction_context", default=EvaluationContext() + ) + + def get_transaction_context(self) -> EvaluationContext: + return self._contextvar.get() + + def set_transaction_context(self, context: EvaluationContext) -> None: + self._contextvar.set(context) diff --git a/tests/test_transaction_context.py b/tests/test_transaction_context.py new file mode 100644 index 00000000..7d0ed7f2 --- /dev/null +++ b/tests/test_transaction_context.py @@ -0,0 +1,28 @@ +from concurrent.futures import ThreadPoolExecutor + +from openfeature.evaluation_context import EvaluationContext +from openfeature.transaction_context import ContextVarTransactionContextPropagator + + +def test_contextvar_transaction_context_propagator(): + propagator = ContextVarTransactionContextPropagator() + + context = propagator.get_transaction_context() + assert isinstance(context, EvaluationContext) + + context = EvaluationContext(targeting_key="foo", attributes={"key": "value"}) + propagator.set_transaction_context(context) + transaction_context = propagator.get_transaction_context() + + assert transaction_context.targeting_key == "foo" + assert transaction_context.attributes == {"key": "value"} + + def thread_fn(): + thread_context = propagator.get_transaction_context() + assert thread_context.targeting_key is None + assert thread_context.attributes == {} + + with ThreadPoolExecutor() as executor: + future = executor.submit(thread_fn) + + future.result()