diff --git a/src/prefect/cache_policies.py b/src/prefect/cache_policies.py index af125d076354..61dfc4f91219 100644 --- a/src/prefect/cache_policies.py +++ b/src/prefect/cache_policies.py @@ -80,10 +80,12 @@ def compute_key( raise NotImplementedError def __sub__(self, other: str) -> "CachePolicy": + "No-op for all policies except Inputs and Compound" + + # for interface compatibility if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance] raise TypeError("Can only subtract strings from key policies.") - new = Inputs(exclude=[other]) - return CompoundCachePolicy(policies=[self, new]) + return self def __add__(self, other: "CachePolicy") -> "CachePolicy": # adding _None is a no-op @@ -214,8 +216,15 @@ def __add__(self, other: "CachePolicy") -> "CachePolicy": def __sub__(self, other: str) -> "CachePolicy": if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance] raise TypeError("Can only subtract strings from key policies.") - new = Inputs(exclude=[other]) - return CompoundCachePolicy(policies=[*self.policies, new]) + + inputs_policies = [p for p in self.policies if isinstance(p, Inputs)] + + if inputs_policies: + new = Inputs(exclude=[other]) + return CompoundCachePolicy(policies=[*self.policies, new]) + else: + # no dependency on inputs already + return self @dataclass diff --git a/tests/test_cache_policies.py b/tests/test_cache_policies.py index 3d02d5a7b904..8173e40c8fb3 100644 --- a/tests/test_cache_policies.py +++ b/tests/test_cache_policies.py @@ -93,12 +93,24 @@ def test_key_excludes_excluded_inputs(self): ) assert new_key == key - def test_subtraction_results_in_new_policy(self): + def test_subtraction_results_in_new_policy_for_inputs(self): policy = Inputs() new_policy = policy - "foo" assert policy != new_policy assert policy.exclude != new_policy.exclude + @pytest.mark.parametrize("policy", [RunId(), RunId() + TaskSource()]) + def test_subtraction_is_noop_for_non_inputs_policies(self, policy): + new_policy = policy - "foo" + assert policy is new_policy + assert policy.compute_key( + task_ctx=None, + inputs={"foo": 42, "y": "changing-value"}, + flow_parameters=None, + ) == policy.compute_key( + task_ctx=None, inputs={"foo": 42, "y": "changed"}, flow_parameters=None + ) + def test_excluded_can_be_manipulated_via_subtraction(self): policy = Inputs() - "y" assert policy.exclude == ["y"] @@ -129,15 +141,15 @@ def test_addition_creates_new_policies(self): assert policy != two assert policy.policies != two.policies - def test_subtraction_creates_new_policies(self): - policy = CompoundCachePolicy(policies=[]) + def test_subtraction_creates_new_policies_if_input_dependency(self): + policy = CompoundCachePolicy(policies=[Inputs()]) new_policy = policy - "foo" assert isinstance(new_policy, CompoundCachePolicy) assert policy != new_policy assert policy.policies != new_policy.policies def test_creation_via_subtraction(self): - one = RunId() + one = DEFAULT policy = one - "y" assert isinstance(policy, CompoundCachePolicy)