diff --git a/README.md b/README.md index 5d4cd896..17e38eea 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,15 @@ api.set_provider(NoOpProvider()) open_feature_client = api.get_client() ``` +`set_provider()` is non-blocking: it registers the provider immediately and runs initialization in a background thread. +Flag evaluations during the initialization window return the default value with a `PROVIDER_NOT_READY` error code. +Use `set_provider_and_wait()` if you need to ensure the provider is ready before proceeding: + +```python +# blocks until the provider is initialized (or raises on failure) +api.set_provider_and_wait(NoOpProvider()) +``` + In some situations, it may be beneficial to register multiple providers in the same application. This is possible using [domains](#domains), which is covered in more detail below. diff --git a/openfeature/api.py b/openfeature/api.py index 817104ab..4585e50e 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -33,6 +33,7 @@ "remove_handler", "set_evaluation_context", "set_provider", + "set_provider_and_wait", "set_transaction_context", "set_transaction_context_propagator", "shutdown", @@ -52,6 +53,13 @@ def set_provider(provider: FeatureProvider, domain: str | None = None) -> None: provider_registry.set_provider(domain, provider) +def set_provider_and_wait(provider: FeatureProvider, domain: str | None = None) -> None: + if domain is None: + provider_registry.set_default_provider(provider, wait_for_init=True) + else: + provider_registry.set_provider(domain, provider, wait_for_init=True) + + def clear_providers() -> None: provider_registry.clear_providers() _event_support.clear() diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index bf8fa9a8..d0718783 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -1,3 +1,5 @@ +import threading + from openfeature._event_support import run_handlers_for_provider from openfeature.evaluation_context import EvaluationContext, get_evaluation_context from openfeature.event import ( @@ -13,77 +15,144 @@ class ProviderRegistry: _default_provider: FeatureProvider _providers: dict[str, FeatureProvider] _provider_status: dict[FeatureProvider, ProviderStatus] + _lock: threading.RLock def __init__(self) -> None: + self._lock = threading.RLock() self._default_provider = NoOpProvider() self._providers = {} self._provider_status = { self._default_provider: ProviderStatus.READY, } - def set_provider(self, domain: str, provider: FeatureProvider) -> None: + def set_provider( + self, domain: str, provider: FeatureProvider, wait_for_init: bool = False + ) -> None: if provider is None: raise GeneralError(error_message="No provider") if domain is None: raise GeneralError(error_message="No domain") - providers = self._providers - if domain in providers: - old_provider = providers[domain] - del providers[domain] - if ( - old_provider != self._default_provider - and old_provider not in providers.values() - ): - self._shutdown_provider(old_provider) - if provider != self._default_provider and provider not in providers.values(): - self._initialize_provider(provider) - providers[domain] = provider + + old_provider: FeatureProvider | None = None + needs_init = False + with self._lock: + old_provider = self._providers.get(domain) + self._providers[domain] = provider + already_bound = provider is self._default_provider or any( + p is provider for d, p in self._providers.items() if d != domain + ) + if not already_bound: + needs_init = True + self._provider_status[provider] = ProviderStatus.NOT_READY + + if needs_init: + self._initialize_provider(provider, wait_for_init=wait_for_init) + + # old-provider shutdown is always async so a hanging shutdown() cannot + # block set_provider. + if old_provider is not None and old_provider is not provider: + self._shutdown_if_unused(old_provider) def get_provider(self, domain: str | None) -> FeatureProvider: if domain is None: return self._default_provider return self._providers.get(domain, self._default_provider) - def set_default_provider(self, provider: FeatureProvider) -> None: + def set_default_provider( + self, provider: FeatureProvider, wait_for_init: bool = False + ) -> None: if provider is None: raise GeneralError(error_message="No provider") - if ( - self._default_provider - and self._default_provider not in self._providers.values() - ): - self._shutdown_provider(self._default_provider) - self._default_provider = provider - if self._default_provider not in self._providers.values(): - self._initialize_provider(provider) + old_provider: FeatureProvider | None = None + needs_init = False + with self._lock: + old_provider = self._default_provider + self._default_provider = provider + if ( + provider is not old_provider + and provider not in self._providers.values() + ): + needs_init = True + self._provider_status[provider] = ProviderStatus.NOT_READY + + if needs_init: + self._initialize_provider(provider, wait_for_init=wait_for_init) + + if old_provider is not None and old_provider is not provider: + self._shutdown_if_unused(old_provider) def get_default_provider(self) -> FeatureProvider: return self._default_provider def clear_providers(self) -> None: self.shutdown() - self._providers.clear() - self._default_provider = NoOpProvider() - self._provider_status = { - self._default_provider: ProviderStatus.READY, - } + with self._lock: + self._providers.clear() + self._default_provider = NoOpProvider() + self._provider_status = { + self._default_provider: ProviderStatus.READY, + } def shutdown(self) -> None: - for provider in {self._default_provider, *self._providers.values()}: + with self._lock: + providers = {self._default_provider, *self._providers.values()} + + for provider in providers: self._shutdown_provider(provider) def _get_evaluation_context(self) -> EvaluationContext: return get_evaluation_context() - def _initialize_provider(self, provider: FeatureProvider) -> None: + def _initialize_provider( + self, provider: FeatureProvider, wait_for_init: bool + ) -> None: provider.attach(self.dispatch_event) + if not hasattr(provider, "initialize"): + # nothing async to do; dispatch READY synchronously. + self.dispatch_event( + provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() + ) + return + if wait_for_init: + self._run_initialize(provider, raise_on_error=True) + return + + thread = threading.Thread( + target=self._run_initialize, + args=(provider,), + kwargs={"raise_on_error": False}, + daemon=True, + ) + thread.start() + + def _run_initialize( + self, provider: FeatureProvider, raise_on_error: bool = False + ) -> None: try: - if hasattr(provider, "initialize"): - provider.initialize(self._get_evaluation_context()) + provider.initialize(self._get_evaluation_context()) + # stale init: provider was replaced/shut down during initialize(); drop event. + # Check active registration, not _provider_status, since replaced providers + # remain in _provider_status until async shutdown pops them. + with self._lock: + if ( + provider is not self._default_provider + and provider not in self._providers.values() + ): + return self.dispatch_event( provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() ) except Exception as err: + # stale init: provider was replaced/shut down during initialize(); drop event. + # Check active registration, not _provider_status, since replaced providers + # remain in _provider_status until async shutdown pops them. + with self._lock: + if ( + provider is not self._default_provider + and provider not in self._providers.values() + ): + return error_code = ( err.error_code if isinstance(err, OpenFeatureError) @@ -97,12 +166,29 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: error_code=error_code, ), ) + if raise_on_error: + raise + + def _shutdown_if_unused(self, provider: FeatureProvider) -> None: + # only shut down if no longer referenced. shutdown runs on a daemon + # thread so a hanging shutdown() cannot block the caller. + with self._lock: + if provider is self._default_provider: + return + if provider in self._providers.values(): + return + + thread = threading.Thread( + target=self._shutdown_provider, args=(provider,), daemon=True + ) + thread.start() def _shutdown_provider(self, provider: FeatureProvider) -> None: try: if hasattr(provider, "shutdown"): provider.shutdown() - del self._provider_status[provider] + with self._lock: + self._provider_status.pop(provider, None) except Exception as err: self.dispatch_event( provider, @@ -132,17 +218,18 @@ def _update_provider_status( event: ProviderEvent, details: ProviderEventDetails, ) -> None: - if event == ProviderEvent.PROVIDER_READY: - self._provider_status[provider] = ProviderStatus.READY - elif event == ProviderEvent.PROVIDER_STALE: - self._provider_status[provider] = ProviderStatus.STALE - elif event == ProviderEvent.PROVIDER_ERROR: - status = ( - ProviderStatus.FATAL - if details.error_code == ErrorCode.PROVIDER_FATAL - else ProviderStatus.ERROR - ) - self._provider_status[provider] = status + with self._lock: + if event == ProviderEvent.PROVIDER_READY: + self._provider_status[provider] = ProviderStatus.READY + elif event == ProviderEvent.PROVIDER_STALE: + self._provider_status[provider] = ProviderStatus.STALE + elif event == ProviderEvent.PROVIDER_ERROR: + status = ( + ProviderStatus.FATAL + if details.error_code == ErrorCode.PROVIDER_FATAL + else ProviderStatus.ERROR + ) + self._provider_status[provider] = status provider_registry = ProviderRegistry() diff --git a/tests/conftest.py b/tests/conftest.py index 1f0a7982..495634c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,5 +15,5 @@ def clear_providers(): @pytest.fixture() def no_op_provider_client(): - api.set_provider(NoOpProvider()) + api.set_provider_and_wait(NoOpProvider()) return api.get_client() diff --git a/tests/features/steps/metadata_steps.py b/tests/features/steps/metadata_steps.py index 0154a9f0..bed87d17 100644 --- a/tests/features/steps/metadata_steps.py +++ b/tests/features/steps/metadata_steps.py @@ -1,13 +1,13 @@ from behave import given, then -from openfeature.api import get_client, set_provider +from openfeature.api import get_client, set_provider_and_wait from openfeature.provider.in_memory_provider import InMemoryProvider from tests.features.data import IN_MEMORY_FLAGS @given("a stable provider") def step_impl_stable_provider(context): - set_provider(InMemoryProvider(IN_MEMORY_FLAGS)) + set_provider_and_wait(InMemoryProvider(IN_MEMORY_FLAGS)) context.client = get_client() diff --git a/tests/features/steps/steps.py b/tests/features/steps/steps.py index 5d9d38fd..9b699331 100644 --- a/tests/features/steps/steps.py +++ b/tests/features/steps/steps.py @@ -4,7 +4,7 @@ from behave import given, then, when -from openfeature.api import get_client, set_provider +from openfeature.api import get_client, set_provider_and_wait from openfeature.client import OpenFeatureClient from openfeature.evaluation_context import EvaluationContext from openfeature.exception import ErrorCode @@ -28,13 +28,13 @@ def step_impl_resolved_should_be(context, flag_type, key, expected_reason): @given("a provider is registered with cache disabled") def step_impl_provider_without_cache(context): - set_provider(InMemoryProvider(IN_MEMORY_FLAGS)) + set_provider_and_wait(InMemoryProvider(IN_MEMORY_FLAGS)) context.client = get_client() @given("a provider is registered") def step_impl_provider(context): - set_provider(InMemoryProvider(IN_MEMORY_FLAGS)) + set_provider_and_wait(InMemoryProvider(IN_MEMORY_FLAGS)) context.client = get_client() diff --git a/tests/provider/test_registry.py b/tests/provider/test_registry.py index b5e10503..c90bd70f 100644 --- a/tests/provider/test_registry.py +++ b/tests/provider/test_registry.py @@ -1,8 +1,10 @@ +import threading +import time from unittest.mock import Mock import pytest -from openfeature.exception import GeneralError +from openfeature.exception import GeneralError, ProviderFatalError from openfeature.provider import ProviderStatus from openfeature.provider._registry import ProviderRegistry from openfeature.provider.no_op_provider import NoOpProvider @@ -67,8 +69,8 @@ def test_registering_provider_for_first_time_initializes_it(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain1", provider) - registry.set_provider("domain2", provider) + registry.set_provider("domain1", provider, wait_for_init=True) + registry.set_provider("domain2", provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -103,7 +105,7 @@ def test_setting_default_provider_initializes_it(): registry = ProviderRegistry() provider = Mock() - registry.set_default_provider(provider) + registry.set_default_provider(provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -114,8 +116,8 @@ def test_registering_provider_as_default_then_domain_only_initializes_once(): registry = ProviderRegistry() provider = Mock() - registry.set_default_provider(provider) - registry.set_provider("domain", provider) + registry.set_default_provider(provider, wait_for_init=True) + registry.set_provider("domain", provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -126,8 +128,8 @@ def test_registering_provider_as_domain_then_default_only_initializes_once(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain", provider) - registry.set_default_provider(provider) + registry.set_provider("domain", provider, wait_for_init=True) + registry.set_default_provider(provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -191,7 +193,7 @@ def test_initializing_provider_sets_status_ready(): assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY - registry.set_provider("domain", provider) + registry.set_provider("domain", provider, wait_for_init=True) provider.initialize.assert_called_once() assert registry.get_provider_status(provider) == ProviderStatus.READY @@ -203,7 +205,7 @@ def test_shutting_down_provider_sets_status_not_ready(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain", provider) + registry.set_provider("domain", provider, wait_for_init=True) assert registry.get_provider_status(provider) == ProviderStatus.READY registry.shutdown() @@ -216,8 +218,8 @@ def test_clearing_registry_resets_providers_and_default(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain", provider) - registry.set_default_provider(provider) + registry.set_provider("domain", provider, wait_for_init=True) + registry.set_default_provider(provider, wait_for_init=True) registry.clear_providers() @@ -228,3 +230,149 @@ def test_clearing_registry_resets_providers_and_default(): provider.initialize.assert_called_once() provider.shutdown.assert_called_once() + + +def test_set_provider_returns_before_initialization_completes(): + """Test that set_provider (non-blocking) returns before initialize finishes.""" + + registry = ProviderRegistry() + init_started = threading.Event() + init_may_proceed = threading.Event() + provider = Mock() + + def slow_initialize(ctx): + init_started.set() + init_may_proceed.wait() + + provider.initialize.side_effect = slow_initialize + + registry.set_provider("domain", provider) + + assert init_started.wait(timeout=2), "initialize was never called in background" + assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY + + init_may_proceed.set() # unblock the background thread + + +def test_set_provider_and_wait_blocks_until_ready(): + """Test that set_provider with wait_for_init=True blocks until READY.""" + + registry = ProviderRegistry() + initialized = threading.Event() + provider = Mock() + + def tracking_initialize(ctx): + initialized.set() + + provider.initialize.side_effect = tracking_initialize + + registry.set_provider("domain", provider, wait_for_init=True) + + assert initialized.is_set() + assert registry.get_provider_status(provider) == ProviderStatus.READY + + +def test_set_provider_and_wait_reraises_on_error(): + """Test that set_provider with wait_for_init=True re-raises initialization errors.""" + registry = ProviderRegistry() + provider = Mock() + provider.initialize.side_effect = ProviderFatalError() + + with pytest.raises(ProviderFatalError): + registry.set_provider("domain", provider, wait_for_init=True) + + +def test_concurrent_set_provider_for_same_provider_initializes_once(): + """Concurrent set_provider calls for different domains using the same + provider instance must only initialize the provider once.""" + + registry = ProviderRegistry() + init_count = 0 + start_gate = threading.Event() + + def slow_initialize(ctx): + nonlocal init_count + # widen the window in which two threads can both observe "not bound" + start_gate.wait(timeout=2) + init_count += 1 + + provider = Mock() + provider.initialize.side_effect = slow_initialize + + def call(domain): + registry.set_provider(domain, provider, wait_for_init=True) + + t1 = threading.Thread(target=call, args=("d1",)) + t2 = threading.Thread(target=call, args=("d2",)) + t1.start() + t2.start() + start_gate.set() + t1.join(timeout=5) + t2.join(timeout=5) + + assert init_count == 1 + + +def test_provider_replaced_during_async_init_does_not_set_ready_status(): + """If a provider is replaced while its async initialize is still running, + the late PROVIDER_READY event must not resurrect its status.""" + + registry = ProviderRegistry() + init_started = threading.Event() + init_may_proceed = threading.Event() + + slow_provider = Mock() + + def slow_initialize(ctx): + init_started.set() + init_may_proceed.wait(timeout=2) + + slow_provider.initialize.side_effect = slow_initialize + + registry.set_provider("domain", slow_provider) + assert init_started.wait(timeout=2) + + # replace with a different provider before the slow init finishes + replacement = Mock() + registry.set_provider("domain", replacement, wait_for_init=True) + + # now let the slow init complete + init_may_proceed.set() + # give the background thread a moment to attempt its (stale) dispatch + time.sleep(0.1) + + # stale event must not have set READY on the replaced provider + assert registry.get_provider_status(slow_provider) == ProviderStatus.NOT_READY + assert registry.get_provider_status(replacement) == ProviderStatus.READY + + +def test_set_provider_does_not_block_on_hanging_old_shutdown(): + """If the previously-registered provider's shutdown() hangs, a subsequent + set_provider call must not be blocked by it.""" + + registry = ProviderRegistry() + + hanging = Mock() + hang = threading.Event() + hanging.shutdown.side_effect = lambda: hang.wait(timeout=5) + + replacement = Mock() + + registry.set_provider("domain", hanging, wait_for_init=True) + + completed = threading.Event() + + def replace(): + registry.set_provider("domain", replacement, wait_for_init=True) + completed.set() + + threading.Thread(target=replace, daemon=True).start() + + # the swap+init of replacement must complete even though `hanging.shutdown` + # is still blocked. + assert completed.wait(timeout=2), ( + "set_provider was blocked by old provider's hanging shutdown()" + ) + + # release the hung shutdown so the test can clean up + hang.set() diff --git a/tests/test_api.py b/tests/test_api.py index cacdf694..b7945cbb 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,3 +1,4 @@ +import threading from unittest.mock import MagicMock import pytest @@ -14,6 +15,7 @@ remove_handler, set_evaluation_context, set_provider, + set_provider_and_wait, shutdown, ) from openfeature.evaluation_context import EvaluationContext @@ -69,7 +71,7 @@ def test_should_invoke_provider_initialize_function_on_newly_registered_provider # When set_evaluation_context(evaluation_context) - set_provider(provider) + set_provider_and_wait(provider) # Then provider.initialize.assert_called_with(evaluation_context) @@ -170,10 +172,10 @@ def test_should_provide_a_function_to_bind_provider_through_domain(): def test_should_not_initialize_provider_already_bound_to_another_domain(): # Given provider = MagicMock(spec=FeatureProvider) - set_provider(provider, "foo") + set_provider_and_wait(provider, "foo") # When - set_provider(provider, "bar") + set_provider_and_wait(provider, "bar") # Then provider.initialize.assert_called_once() @@ -326,7 +328,7 @@ def test_add_remove_event_handler(): def test_handlers_attached_to_provider_already_in_associated_state_should_run_immediately(): # Given provider = NoOpProvider() - set_provider(provider) + set_provider_and_wait(provider) spy = MagicMock() # When @@ -345,7 +347,7 @@ def test_provider_ready_handlers_run_if_provider_initialize_function_terminates_ spy.reset_mock() # reset the mock to avoid counting the immediate call on subscribe # When - set_provider(provider) + set_provider_and_wait(provider) # Then spy.provider_ready.assert_called_once() @@ -360,7 +362,8 @@ def test_provider_error_handlers_run_if_provider_initialize_function_terminates_ add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error) # When - set_provider(provider) + with pytest.raises(ProviderFatalError): + set_provider_and_wait(provider) # Then spy.provider_error.assert_called_once() @@ -369,7 +372,7 @@ def test_provider_error_handlers_run_if_provider_initialize_function_terminates_ def test_provider_status_is_updated_after_provider_emits_event(): # Given provider = NoOpProvider() - set_provider(provider) + set_provider_and_wait(provider) client = get_client() # When @@ -393,3 +396,103 @@ def test_provider_status_is_updated_after_provider_emits_event(): provider.emit_provider_ready(ProviderEventDetails()) # Then assert client.get_provider_status() == ProviderStatus.READY + + +# Non-blocking set_provider tests + + +def test_set_provider_returns_before_initialization_completes(): + # Given: a provider whose initialize blocks until signalled + init_started = threading.Event() + init_may_proceed = threading.Event() + + provider = MagicMock(spec=FeatureProvider) + + def slow_initialize(ctx): + init_started.set() + init_may_proceed.wait() + + provider.initialize.side_effect = slow_initialize + + # When + set_provider(provider) + + # Then: set_provider returned before initialize completed (we reached this line + # while the background thread is still blocked inside initialize) + assert init_started.wait(timeout=2), "initialize was never called" + init_may_proceed.set() # unblock the background thread + + +def test_provider_status_is_not_ready_during_async_initialization(): + # Given: a provider whose initialize blocks until signalled + init_may_proceed = threading.Event() + provider = MagicMock(spec=FeatureProvider) + + def slow_initialize(ctx): + init_may_proceed.wait() + + provider.initialize.side_effect = slow_initialize + + # When + set_provider(provider) + client = get_client() + + # Then: status is NOT_READY while init is still running + assert client.get_provider_status() == ProviderStatus.NOT_READY + + # Cleanup: let the background thread finish + init_may_proceed.set() + + +def test_set_provider_and_wait_blocks_until_initialization_completes(): + # Given + initialized = threading.Event() + provider = MagicMock(spec=FeatureProvider) + + def slow_initialize(ctx): + initialized.set() + + provider.initialize.side_effect = slow_initialize + + # When + set_provider_and_wait(provider) + + # Then: initialize was called before set_provider_and_wait returned + assert initialized.is_set() + assert get_client().get_provider_status() == ProviderStatus.READY + + +def test_set_provider_and_wait_reraises_on_failure(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.initialize.side_effect = ProviderFatalError() + + # When / Then + with pytest.raises(ProviderFatalError): + set_provider_and_wait(provider) + + +def test_set_provider_swallows_error_and_emits_provider_error_event(): + # Given + provider = MagicMock(spec=FeatureProvider) + error_fired = threading.Event() + + def failing_initialize(ctx): + raise ProviderFatalError() + + provider.initialize.side_effect = failing_initialize + + spy = MagicMock() + + def on_error(details): + spy.on_error(details) + error_fired.set() + + add_handler(ProviderEvent.PROVIDER_ERROR, on_error) + + # When: non-blocking set_provider — must not raise + set_provider(provider) + + # Then: error event fired, exception was not propagated + assert error_fired.wait(timeout=2), "PROVIDER_ERROR event was never fired" + spy.on_error.assert_called_once()