diff --git a/api/environments/identities/models.py b/api/environments/identities/models.py index 80611c521860..9fe99df13df9 100644 --- a/api/environments/identities/models.py +++ b/api/environments/identities/models.py @@ -219,9 +219,11 @@ def generate_traits(self, trait_data_items, persist=False): identity=self, **Trait.generate_trait_value_data(trait_value), ) - trait_models.append(trait) - if not trait_data_item.get("transient"): + if trait_data_item.get("transient"): + trait.transient = True + else: trait_models_to_persist.append(trait) + trait_models.append(trait) if persist: Trait.objects.bulk_create(trait_models_to_persist) @@ -252,13 +254,13 @@ def update_traits( transient = trait_data_item.get("transient") if transient: - transient_traits.append( - Trait( - **Trait.generate_trait_value_data(trait_value), - trait_key=trait_key, - identity=self, - ) + trait = Trait( + **Trait.generate_trait_value_data(trait_value), + trait_key=trait_key, + identity=self, ) + trait.transient = True + transient_traits.append(trait) continue if trait_value is None: @@ -304,7 +306,7 @@ def update_traits( # See: https://github.com/Flagsmith/flagsmith/issues/370 Trait.objects.bulk_create(new_traits, ignore_conflicts=True) - # return the full list of traits for this identity by refreshing from the db + # return the full list of traits for this identity # override persisted traits by transient traits in case of key collisions return [ *{ diff --git a/api/environments/identities/traits/models.py b/api/environments/identities/traits/models.py index b1b5d70ed8f9..b21f686e3bf3 100644 --- a/api/environments/identities/traits/models.py +++ b/api/environments/identities/traits/models.py @@ -57,6 +57,14 @@ def natural_key(self): def trait_value(self): return self.get_trait_value() + @property + def transient(self) -> bool: + return getattr(self, "_transient", False) + + @transient.setter + def transient(self, transient: bool) -> None: + self._transient = transient + def get_trait_value(self): try: value_type = self.value_type diff --git a/api/environments/identities/traits/serializers.py b/api/environments/identities/traits/serializers.py index 47b938506a8e..e55640d23767 100644 --- a/api/environments/identities/traits/serializers.py +++ b/api/environments/identities/traits/serializers.py @@ -22,7 +22,7 @@ def get_trait_value(obj): class TraitSerializerBasic(serializers.ModelSerializer): trait_value = TraitValueField(allow_null=True) - transient = serializers.BooleanField(default=False, write_only=True) + transient = serializers.BooleanField(default=False) class Meta: model = Trait diff --git a/api/tests/integration/environments/identities/test_integration_identities.py b/api/tests/integration/environments/identities/test_integration_identities.py index 1e6a3fa81f82..445eedacfbb4 100644 --- a/api/tests/integration/environments/identities/test_integration_identities.py +++ b/api/tests/integration/environments/identities/test_integration_identities.py @@ -284,8 +284,9 @@ def test_get_feature_states_for_identity__transient_trait__segment_match_expecte url = reverse("api-v1:sdk-identities") # When - # flags are requested for a new transient identity + # flags are requested for a new identity # that matches the segment + # with a transient trait response = sdk_client.post( url, data=json.dumps( @@ -310,6 +311,20 @@ def test_get_feature_states_for_identity__transient_trait__segment_match_expecte # Then assert response.status_code == status.HTTP_200_OK response_json = response.json() + assert response_json["traits"] == [ + { + "id": mock.ANY, + "trait_key": segment_condition_property, + "trait_value": segment_condition_value, + "transient": True, + }, + { + "id": mock.ANY, + "trait_key": "persistent", + "trait_value": "trait value", + "transient": False, + }, + ] assert ( flag_data := next( ( @@ -322,3 +337,54 @@ def test_get_feature_states_for_identity__transient_trait__segment_match_expecte ) assert flag_data["enabled"] is True assert flag_data["feature_state_value"] == "segment override" + + +def test_get_feature_states_for_identity__transient_trait__existing_identity__return_expected( + sdk_client: APIClient, + identity_identifier: str, + identity: int, +) -> None: + # Given + url = reverse("api-v1:sdk-identities") + + # When + # flags are requested for an existing identity + # with a transient trait + response = sdk_client.post( + url, + data=json.dumps( + { + "identifier": identity_identifier, + "traits": [ + { + "trait_key": "transient", + "trait_value": "trait value", + "transient": True, + }, + { + "trait_key": "persistent", + "trait_value": "trait value", + }, + ], + } + ), + content_type="application/json", + ) + + # Then + assert response.status_code == status.HTTP_200_OK + response_json = response.json() + assert response_json["traits"] == [ + { + "id": mock.ANY, + "trait_key": "persistent", + "trait_value": "trait value", + "transient": False, + }, + { + "id": mock.ANY, + "trait_key": "transient", + "trait_value": "trait value", + "transient": True, + }, + ] diff --git a/api/tests/unit/environments/identities/test_unit_identities_views.py b/api/tests/unit/environments/identities/test_unit_identities_views.py index 55052511be85..bf9e030058d0 100644 --- a/api/tests/unit/environments/identities/test_unit_identities_views.py +++ b/api/tests/unit/environments/identities/test_unit_identities_views.py @@ -1176,13 +1176,17 @@ def test_post_identities__transient_traits__no_persistence( ) -> None: # Given identifier = "transient" - trait_key = "trait_key" + transient_trait_key = "trait_key" + non_transient_trait_key = "other" api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) url = reverse("api-v1:sdk-identities") data = { "identifier": identifier, - "traits": [{"trait_key": trait_key, "trait_value": "bar", "transient": True}], + "traits": [ + {"trait_key": transient_trait_key, "trait_value": "bar", "transient": True}, + {"trait_key": non_transient_trait_key, "trait_value": "value"}, + ], } # When @@ -1193,7 +1197,8 @@ def test_post_identities__transient_traits__no_persistence( # Then assert response.status_code == status.HTTP_200_OK assert Identity.objects.filter(identifier=identifier).exists() - assert not Trait.objects.filter(trait_key=trait_key).exists() + assert Trait.objects.filter(trait_key=non_transient_trait_key).exists() + assert not Trait.objects.filter(trait_key=transient_trait_key).exists() def test_user_with_view_identities_permission_can_retrieve_identity(