From d60e2bbba2ca2a31e3e584ffa4b3be130514847e Mon Sep 17 00:00:00 2001 From: z3z1ma Date: Fri, 12 Apr 2024 23:05:31 -0700 Subject: [PATCH] chore: patch for dbt 1.5+ compat --- dbt_feature_flags/__init__.py | 2 +- dbt_feature_flags/base.py | 20 ++++++++++---------- dbt_feature_flags/patch.py | 18 +++++++++++++----- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/dbt_feature_flags/__init__.py b/dbt_feature_flags/__init__.py index b4dd59b..c35c240 100644 --- a/dbt_feature_flags/__init__.py +++ b/dbt_feature_flags/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.0" +__version__ = "0.5.2" diff --git a/dbt_feature_flags/base.py b/dbt_feature_flags/base.py index f7462e9..60f2dea 100644 --- a/dbt_feature_flags/base.py +++ b/dbt_feature_flags/base.py @@ -14,8 +14,8 @@ import abc import logging +import typing as t from functools import wraps -from typing import Any, Union, final class BaseFeatureFlagsClient(abc.ABC): @@ -28,41 +28,41 @@ class BaseFeatureFlagsClient(abc.ABC): def __init__(self) -> None: self._add_validators() - @final + @t.final def _add_validators(self): self.bool_variation = validate(types=(bool,))(self.bool_variation) self.string_variation = validate(types=(str,))(self.string_variation) self.number_variation = validate(types=(float, int))(self.number_variation) - self.json_variation = validate(types=(dict, list, None))(self.json_variation) + self.json_variation = validate(types=(dict, list, None))(self.json_variation) # type: ignore @abc.abstractmethod - def bool_variation(self, flag: str, default: Any) -> bool: + def bool_variation(self, flag: str, default: t.Any) -> bool: raise NotImplementedError( "Boolean feature flags are not implemented for this driver" ) @abc.abstractmethod - def string_variation(self, flag: str, default: Any) -> str: + def string_variation(self, flag: str, default: t.Any) -> str: raise NotImplementedError( "String feature flags are not implemented for this driver" ) @abc.abstractmethod - def number_variation(self, flag: str, default: Any) -> Union[float, int]: + def number_variation(self, flag: str, default: t.Any) -> t.Union[float, int]: raise NotImplementedError( "Number feature flags are not implemented for this driver" ) @abc.abstractmethod - def json_variation(self, flag: str, default: Any) -> Union[dict, list]: + def json_variation(self, flag: str, default: t.Any) -> t.Union[dict, list]: raise NotImplementedError( "JSON feature flags are not implemented for this driver" ) -def validate(types: Union[list, tuple]): +def validate(types: t.Tuple[t.Type[t.Any], ...]): def _validate(v, flag_name, func_name): - if not isinstance(v, types): + if not isinstance(v, tuple(types)): raise ValueError( f"Invalid return value for {func_name}({flag_name}...) feature flag call. Found type {type(v).__name__}." ) @@ -70,7 +70,7 @@ def _validate(v, flag_name, func_name): def _main(func): @wraps(func) - def _injected_validator(flag: str, default: Any = func.__defaults__[0]): + def _injected_validator(flag: str, default: t.Any = func.__defaults__[0]): if not isinstance(default, types): raise ValueError( f"Invalid default value: {default} for {func.__name__}({flag}...) feature flag call. Found type {type(default).__name__}." diff --git a/dbt_feature_flags/patch.py b/dbt_feature_flags/patch.py index 0e1537a..5658dae 100644 --- a/dbt_feature_flags/patch.py +++ b/dbt_feature_flags/patch.py @@ -17,10 +17,13 @@ import typing as t from enum import Enum from functools import wraps +from types import SimpleNamespace from dbt_feature_flags import base, harness, launchdarkly -_MOCK_CLIENT = object() +MockClient = t.NewType("MockClient", type(object())) + +_MOCK_CLIENT = t.cast(MockClient, object()) class SupportedProviders(str, Enum): @@ -34,7 +37,7 @@ def _is_truthy(value: str) -> bool: return value.lower() in ("1", "true", "yes") -def _get_client() -> base.BaseFeatureFlagsClient | _MOCK_CLIENT: +def _get_client() -> base.BaseFeatureFlagsClient | MockClient | None: """Return the user specified client. Valid implementations MUST inherit from BaseFeatureFlagsClient. @@ -90,16 +93,21 @@ def _wrapped( ctx["feature_flag_json"] = client.json_variation return fn(string, ctx, node, capture_macros, native) - _wrapped.status = "patched" + _wrapped.status = "patched" # type: ignore return _wrapped def patch_dbt_environment() -> None: """Patch dbt's jinja environment to include feature flag functions.""" + import dbt.flags from dbt.clients import jinja - jinja._get_rendered = jinja.get_rendered - jinja.get_rendered = get_rendered(jinja._get_rendered, _get_client()) + # small patch to make compatible with dbt 1.5+ + g_flags = getattr(dbt.flags, "get_flags", lambda: SimpleNamespace()) + g_flags().MACRO_DEBUGGING = False + + jinja._get_rendered = jinja.get_rendered # type: ignore + jinja.get_rendered = get_rendered(jinja._get_rendered, _get_client()) # type: ignore if __name__ == "__main__":