diff --git a/openfeature/flag_evaluation.py b/openfeature/flag_evaluation.py index 26b565ad..98adab4b 100644 --- a/openfeature/flag_evaluation.py +++ b/openfeature/flag_evaluation.py @@ -6,8 +6,9 @@ from openfeature._backports.strenum import StrEnum from openfeature.exception import ErrorCode -if typing.TYPE_CHECKING: # resolves a circular dependency in type annotations - from openfeature.hook import Hook +if typing.TYPE_CHECKING: # pragma: no cover + # resolves a circular dependency in type annotations + from openfeature.hook import Hook, HookHints class FlagType(StrEnum): @@ -48,7 +49,7 @@ class FlagEvaluationDetails(typing.Generic[T_co]): @dataclass class FlagEvaluationOptions: hooks: typing.List[Hook] = field(default_factory=list) - hook_hints: dict = field(default_factory=dict) + hook_hints: HookHints = field(default_factory=dict) U_co = typing.TypeVar("U_co", covariant=True) diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py index 13748aac..8cb1c14c 100644 --- a/openfeature/hook/__init__.py +++ b/openfeature/hook/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -from dataclasses import dataclass +from datetime import datetime from enum import Enum from typing import TYPE_CHECKING @@ -20,24 +20,53 @@ class HookType(Enum): ERROR = "error" -@dataclass class HookContext: - flag_key: str - flag_type: FlagType - default_value: typing.Any - evaluation_context: EvaluationContext - client_metadata: typing.Optional[ClientMetadata] = None - provider_metadata: typing.Optional[Metadata] = None + def __init__( # noqa: PLR0913 + self, + flag_key: str, + flag_type: FlagType, + default_value: typing.Any, + evaluation_context: EvaluationContext, + client_metadata: typing.Optional[ClientMetadata] = None, + provider_metadata: typing.Optional[Metadata] = None, + ): + self.flag_key = flag_key + self.flag_type = flag_type + self.default_value = default_value + self.evaluation_context = evaluation_context + self.client_metadata = client_metadata + self.provider_metadata = provider_metadata def __setattr__(self, key: str, value: typing.Any) -> None: - if hasattr(self, key) and key in ("flag_key", "flag_type", "default_value"): + if hasattr(self, key) and key in ( + "flag_key", + "flag_type", + "default_value", + "client_metadata", + "provider_metadata", + ): raise AttributeError(f"Attribute {key!r} is immutable") super().__setattr__(key, value) +# https://openfeature.dev/specification/sections/hooks/#requirement-421 +HookHints = typing.Mapping[ + str, + typing.Union[ + bool, + int, + float, + str, + datetime, + typing.List[typing.Any], + typing.Dict[str, typing.Any], + ], +] + + class Hook: def before( - self, hook_context: HookContext, hints: dict + self, hook_context: HookContext, hints: HookHints ) -> typing.Optional[EvaluationContext]: """ Runs before flag is resolved. @@ -54,7 +83,7 @@ def after( self, hook_context: HookContext, details: FlagEvaluationDetails[typing.Any], - hints: dict, + hints: HookHints, ) -> None: """ Runs after a flag is resolved. @@ -67,7 +96,7 @@ def after( pass def error( - self, hook_context: HookContext, exception: Exception, hints: dict + self, hook_context: HookContext, exception: Exception, hints: HookHints ) -> None: """ Run when evaluation encounters an error. Errors thrown will be swallowed. @@ -78,7 +107,7 @@ def error( """ pass - def finally_after(self, hook_context: HookContext, hints: dict) -> None: + def finally_after(self, hook_context: HookContext, hints: HookHints) -> None: """ Run after flag evaluation, including any error processing. This will always run. Errors will be swallowed. diff --git a/openfeature/hook/hook_support.py b/openfeature/hook/hook_support.py index 9bbfd492..349b25f3 100644 --- a/openfeature/hook/hook_support.py +++ b/openfeature/hook/hook_support.py @@ -4,7 +4,7 @@ from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagEvaluationDetails, FlagType -from openfeature.hook import Hook, HookContext, HookType +from openfeature.hook import Hook, HookContext, HookHints, HookType logger = logging.getLogger("openfeature") @@ -14,7 +14,7 @@ def error_hooks( hook_context: HookContext, exception: Exception, hooks: typing.List[Hook], - hints: typing.Optional[typing.Mapping] = None, + hints: typing.Optional[HookHints] = None, ) -> None: kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints} _execute_hooks( @@ -26,7 +26,7 @@ def after_all_hooks( flag_type: FlagType, hook_context: HookContext, hooks: typing.List[Hook], - hints: typing.Optional[typing.Mapping] = None, + hints: typing.Optional[HookHints] = None, ) -> None: kwargs = {"hook_context": hook_context, "hints": hints} _execute_hooks( @@ -39,7 +39,7 @@ def after_hooks( hook_context: HookContext, details: FlagEvaluationDetails[typing.Any], hooks: typing.List[Hook], - hints: typing.Optional[typing.Mapping] = None, + hints: typing.Optional[HookHints] = None, ) -> None: kwargs = {"hook_context": hook_context, "details": details, "hints": hints} _execute_hooks_unchecked( @@ -51,7 +51,7 @@ def before_hooks( flag_type: FlagType, hook_context: HookContext, hooks: typing.List[Hook], - hints: typing.Optional[typing.Mapping] = None, + hints: typing.Optional[HookHints] = None, ) -> EvaluationContext: kwargs = {"hook_context": hook_context, "hints": hints} executed_hooks = _execute_hooks_unchecked( diff --git a/tests/hook/test_hook_support.py b/tests/hook/test_hook_support.py index 69ceb8da..37e06eee 100644 --- a/tests/hook/test_hook_support.py +++ b/tests/hook/test_hook_support.py @@ -40,10 +40,14 @@ def test_hook_context_has_immutable_and_mutable_fields(): 4.1.3 - The "flag key", "flag type", and "default value" properties MUST be immutable. 4.1.4.1 - The evaluation context MUST be mutable only within the before hook. + 4.2.2.2 - The client "metadata" field in the "hook context" MUST be immutable. + 4.2.2.3 - The provider "metadata" field in the "hook context" MUST be immutable. """ # Given - hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, EvaluationContext()) + hook_context = HookContext( + "flag_key", FlagType.BOOLEAN, True, EvaluationContext(), ClientMetadata("name") + ) # When with pytest.raises(AttributeError): @@ -52,10 +56,12 @@ def test_hook_context_has_immutable_and_mutable_fields(): hook_context.flag_type = FlagType.STRING with pytest.raises(AttributeError): hook_context.default_value = "new_value" + with pytest.raises(AttributeError): + hook_context.client_metadata = ClientMetadata("new_name") + with pytest.raises(AttributeError): + hook_context.provider_metadata = Metadata("name") hook_context.evaluation_context = EvaluationContext("targeting_key") - hook_context.client_metadata = ClientMetadata("name") - hook_context.provider_metadata = Metadata("name") # Then assert hook_context.flag_key == "flag_key" @@ -63,7 +69,7 @@ def test_hook_context_has_immutable_and_mutable_fields(): assert hook_context.default_value is True assert hook_context.evaluation_context.targeting_key == "targeting_key" assert hook_context.client_metadata.name == "name" - assert hook_context.provider_metadata.name == "name" + assert hook_context.provider_metadata is None def test_error_hooks_run_error_method(mock_hook):