From 4f71482f06221dee2f22fdf75156991623bfead8 Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Thu, 2 Jan 2025 10:53:44 +0000 Subject: [PATCH] feat: Support `--exclusive-start-key` option for `ensure_identity_traits_blanks` (#4941) --- .../commands/ensure_identity_traits_blanks.py | 29 +++++++++---- api/environments/dynamodb/wrappers/base.py | 41 ++++++++++++++++--- .../dynamodb/wrappers/environment_wrapper.py | 4 +- .../edge_api/identities/conftest.py | 2 +- .../identities/test_edge_identity_viewset.py | 2 +- .../edge_api/test_unit_edge_api_commands.py | 22 ++++++++++ 6 files changed, 84 insertions(+), 16 deletions(-) diff --git a/api/edge_api/management/commands/ensure_identity_traits_blanks.py b/api/edge_api/management/commands/ensure_identity_traits_blanks.py index 53eabc3d1345..b8e34acd7580 100644 --- a/api/edge_api/management/commands/ensure_identity_traits_blanks.py +++ b/api/edge_api/management/commands/ensure_identity_traits_blanks.py @@ -1,27 +1,42 @@ +import json +from argparse import ArgumentParser from typing import Any +import structlog from django.core.management import BaseCommand -from structlog import get_logger -from structlog.stdlib import BoundLogger from environments.dynamodb import DynamoIdentityWrapper identity_wrapper = DynamoIdentityWrapper() +logger: structlog.BoundLogger = structlog.get_logger() LOG_COUNT_EVERY = 100_000 class Command(BaseCommand): - def handle(self, *args: Any, **options: Any) -> None: + def add_arguments(self, parser: ArgumentParser) -> None: + parser.add_argument( + "--exclusive-start-key", + dest="exclusive_start_key", + type=str, + default="", + help="Exclusive start key in valid JSON", + ) + + def handle(self, *args: Any, exclusive_start_key: str, **options: Any) -> None: total_count = identity_wrapper.table.item_count - scanned_count = 0 - fixed_count = 0 + scanned_count = scanned_percentage = fixed_count = 0 + + log: structlog.BoundLogger = logger.bind(total_count=total_count) + + kwargs = {} + if 
exclusive_start_key: + kwargs["ExclusiveStartKey"] = json.loads(exclusive_start_key) - log: BoundLogger = get_logger(total_count=total_count) log.info("started") - for identity_document in identity_wrapper.query_get_all_items(): + for identity_document in identity_wrapper.scan_iter_all_items(**kwargs): should_write_identity_document = False if identity_traits_data := identity_document.get("identity_traits"): diff --git a/api/environments/dynamodb/wrappers/base.py b/api/environments/dynamodb/wrappers/base.py index 9358ee62d4bd..7f8d8471d007 100644 --- a/api/environments/dynamodb/wrappers/base.py +++ b/api/environments/dynamodb/wrappers/base.py @@ -5,10 +5,19 @@ import boto3 import boto3.dynamodb.types from botocore.config import Config +from sentry_sdk import set_context # TODO @kgustyr: Replace with OTel if typing.TYPE_CHECKING: from mypy_boto3_dynamodb.service_resource import Table + from mypy_boto3_dynamodb.type_defs import ( + QueryOutputTableTypeDef, + ScanOutputTableTypeDef, + TableAttributeValueTypeDef, + ) + DynamoDBOutput = QueryOutputTableTypeDef | ScanOutputTableTypeDef + + P = typing.ParamSpec("P") # Avoid `decimal.Rounded` when reading large numbers # See https://github.com/boto/boto3/issues/2500 @@ -40,14 +49,20 @@ def get_table(self) -> typing.Optional["Table"]: def is_enabled(self) -> bool: return self.table is not None - def query_get_all_items(self, **kwargs: dict) -> typing.Generator[dict, None, None]: - if kwargs: - response_getter = partial(self.table.query, **kwargs) - else: - response_getter = partial(self.table.scan) + def _iter_all_items( + self, + response_getter_method: "typing.Callable[[P], DynamoDBOutput]", + **kwargs: "P.kwargs", + ) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]: + response_getter = partial(response_getter_method, **kwargs) + set_context( + "dynamodb", + {"table_name": self.table_name, **kwargs}, + ) while True: query_response = response_getter() + for item in query_response["Items"]: yield 
item @@ -56,3 +71,19 @@ def query_get_all_items(self, **kwargs: dict) -> typing.Generator[dict, None, No break response_getter.keywords["ExclusiveStartKey"] = last_evaluated_key + set_context( + "dynamodb", + {"table_name": self.table_name, **response_getter.keywords}, + ) + + def scan_iter_all_items( + self, + **kwargs: typing.Any, + ) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]: + return self._iter_all_items(self.table.scan, **kwargs) + + def query_iter_all_items( + self, + **kwargs: typing.Any, + ) -> typing.Generator[dict[str, "TableAttributeValueTypeDef"], None, None]: + return self._iter_all_items(self.table.query, **kwargs) diff --git a/api/environments/dynamodb/wrappers/environment_wrapper.py b/api/environments/dynamodb/wrappers/environment_wrapper.py index dadf070b7ff7..ff3579b17d67 100644 --- a/api/environments/dynamodb/wrappers/environment_wrapper.py +++ b/api/environments/dynamodb/wrappers/environment_wrapper.py @@ -69,7 +69,7 @@ def get_identity_overrides_by_environment_id( ) -> typing.List[dict[str, Any]]: try: return list( - self.query_get_all_items( + self.query_iter_all_items( KeyConditionExpression=Key(ENVIRONMENTS_V2_PARTITION_KEY).eq( str(environment_id), ) @@ -122,7 +122,7 @@ def delete_environment(self, environment_id: int): "ProjectionExpression": "document_key", } with self.table.batch_writer() as writer: - for item in self.query_get_all_items(**query_kwargs): + for item in self.query_iter_all_items(**query_kwargs): writer.delete_item( Key={ ENVIRONMENTS_V2_PARTITION_KEY: environment_id, diff --git a/api/tests/integration/edge_api/identities/conftest.py b/api/tests/integration/edge_api/identities/conftest.py index 3781cec6d76d..391b233c939e 100644 --- a/api/tests/integration/edge_api/identities/conftest.py +++ b/api/tests/integration/edge_api/identities/conftest.py @@ -34,7 +34,7 @@ def identity_overrides_v2( edge_identity.save(admin_user) return [ item["document_key"] - for item in 
dynamodb_wrapper_v2.query_get_all_items( + for item in dynamodb_wrapper_v2.query_iter_all_items( KeyConditionExpression=Key("environment_id").eq( str(dynamo_enabled_environment) ), diff --git a/api/tests/integration/edge_api/identities/test_edge_identity_viewset.py b/api/tests/integration/edge_api/identities/test_edge_identity_viewset.py index b479c86ef9ce..69e886c1b93f 100644 --- a/api/tests/integration/edge_api/identities/test_edge_identity_viewset.py +++ b/api/tests/integration/edge_api/identities/test_edge_identity_viewset.py @@ -163,7 +163,7 @@ def test_delete_identity( KeyConditionExpression=Key("identity_uuid").eq(identity_uuid), )["Count"] assert not list( - dynamodb_wrapper_v2.query_get_all_items( + dynamodb_wrapper_v2.query_iter_all_items( KeyConditionExpression=Key("environment_id").eq( str(dynamo_enabled_environment) ) diff --git a/api/tests/unit/edge_api/test_unit_edge_api_commands.py b/api/tests/unit/edge_api/test_unit_edge_api_commands.py index 347eee434dcb..ee58d5120324 100644 --- a/api/tests/unit/edge_api/test_unit_edge_api_commands.py +++ b/api/tests/unit/edge_api/test_unit_edge_api_commands.py @@ -218,3 +218,25 @@ def test_ensure_identity_traits_blanks__logs_expected( "total_count": 11, }, ] + + +def test_ensure_identity_traits_blanks__exclusive_start_key__calls_expected( + flagsmith_identities_table: "Table", + mocker: "MockerFixture", +) -> None: + # Given + exclusive_start_key = '{"composite_key":"test_hello"}' + expected_kwargs = {"ExclusiveStartKey": {"composite_key": "test_hello"}} + + identity_wrapper_mock = mocker.patch( + "edge_api.management.commands.ensure_identity_traits_blanks.identity_wrapper" + ) + + # When + call_command( + "ensure_identity_traits_blanks", + exclusive_start_key=exclusive_start_key, + ) + + # Then + identity_wrapper_mock.scan_iter_all_items.assert_called_once_with(**expected_kwargs)