From e0bd94f3cf6dfdbaaf02ae32d331e26e7c58ab3f Mon Sep 17 00:00:00 2001
From: rhysrevans3
Date: Thu, 29 Aug 2024 14:46:09 +0100
Subject: [PATCH] Adding patch endpoints to transactions extension to elasticsearch.

---
 .../stac_fastapi/core/base_database_logic.py  |  48 +++-
 stac_fastapi/core/stac_fastapi/core/core.py   | 106 ++++++++
 .../core/stac_fastapi/core/utilities.py       |  63 +++++
 .../elasticsearch/database_logic.py           | 226 +++++++++++++++++-
 4 files changed, 440 insertions(+), 3 deletions(-)

diff --git a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
index 0043cfb8..50d2062c 100644
--- a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
+++ b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
@@ -1,7 +1,7 @@
 """Base database logic."""
 
 import abc
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 
 class BaseDatabaseLogic(abc.ABC):
@@ -29,6 +29,30 @@ async def create_item(self, item: Dict, refresh: bool = False) -> None:
         """Create an item in the database."""
         pass
 
+    @abc.abstractmethod
+    async def merge_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        item: Dict,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch an item in the database following RFC 7396."""
+        pass
+
+    @abc.abstractmethod
+    async def json_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        operations: List,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch an item in the database following RFC 6902."""
+        pass
+
     @abc.abstractmethod
     async def delete_item(
         self, item_id: str, collection_id: str, refresh: bool = False
@@ -41,6 +65,28 @@ async def create_collection(self, collection: Dict, refresh: bool = False) -> No
         """Create a collection in the database."""
         pass
 
+    @abc.abstractmethod
+    async def merge_patch_collection(
+        self,
+        collection_id: str,
+        collection: Dict,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch a collection in the database following RFC 7396."""
+        pass
+
+    @abc.abstractmethod
+    async def json_patch_collection(
+        self,
+        collection_id: str,
+        operations: List,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch a collection in the database following RFC 6902."""
+        pass
+
     @abc.abstractmethod
     async def find_collection(self, collection_id: str) -> Dict:
         """Find a collection in the database."""
diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py
index 57f7c816..b390356d 100644
--- a/stac_fastapi/core/stac_fastapi/core/core.py
+++ b/stac_fastapi/core/stac_fastapi/core/core.py
@@ -1,4 +1,5 @@
 """Core client."""
+
 import logging
 import re
 from datetime import datetime as datetime_type
@@ -708,6 +709,58 @@ async def update_item(
 
         return ItemSerializer.db_to_stac(item, base_url)
 
+    @overrides
+    async def merge_patch_item(
+        self, collection_id: str, item_id: str, item: stac_types.PartialItem, **kwargs
+    ) -> Optional[stac_types.Item]:
+        """Patch an item in the collection following RFC 7396.
+
+        Args:
+            collection_id (str): The ID of the collection the item belongs to.
+            item_id (str): The ID of the item to be updated.
+            item (stac_types.PartialItem): The partial item data.
+            kwargs: Other optional arguments, including the request object.
+
+        Returns:
+            stac_types.Item: The patched item object.
+
+        """
+        item = await self.database.merge_patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            item=item,
+            base_url=str(kwargs["request"].base_url),
+        )
+        return ItemSerializer.db_to_stac(item, base_url=str(kwargs["request"].base_url))
+
+    @overrides
+    async def json_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        operations: List[stac_types.PatchOperation],
+        **kwargs,
+    ) -> Optional[stac_types.Item]:
+        """Patch an item in the collection following RFC 6902.
+
+        Args:
+            collection_id (str): The ID of the collection the item belongs to.
+            item_id (str): The ID of the item to be updated.
+            operations (List): List of operations to run on the item.
+            kwargs: Other optional arguments, including the request object.
+
+        Returns:
+            stac_types.Item: The patched item object.
+
+        """
+        item = await self.database.json_patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            base_url=str(kwargs["request"].base_url),
+            operations=operations,
+        )
+        return ItemSerializer.db_to_stac(item, base_url=str(kwargs["request"].base_url))
+
     @overrides
     async def delete_item(
         self, item_id: str, collection_id: str, **kwargs
@@ -788,6 +841,59 @@ async def update_collection(
             extensions=[type(ext).__name__ for ext in self.database.extensions],
         )
 
+    @overrides
+    async def merge_patch_collection(
+        self, collection_id: str, collection: stac_types.PartialCollection, **kwargs
+    ) -> Optional[stac_types.Collection]:
+        """Patch a collection following RFC 7396.
+
+        Args:
+            collection_id (str): The ID of the collection to patch.
+            collection (stac_types.PartialCollection): The partial collection data.
+            kwargs: Other optional arguments, including the request object.
+
+        Returns:
+            stac_types.Collection: The patched collection object.
+
+        """
+        collection = await self.database.merge_patch_collection(
+            collection_id=collection_id,
+            base_url=str(kwargs["request"].base_url),
+            collection=collection,
+        )
+
+        return CollectionSerializer.db_to_stac(
+            collection,
+            kwargs["request"],
+            extensions=[type(ext).__name__ for ext in self.database.extensions],
+        )
+
+    @overrides
+    async def json_patch_collection(
+        self, collection_id: str, operations: List[stac_types.PatchOperation], **kwargs
+    ) -> Optional[stac_types.Collection]:
+        """Patch a collection following RFC 6902.
+
+        Args:
+            collection_id (str): The ID of the collection to patch.
+            operations (List): List of operations to run on the collection.
+            kwargs: Other optional arguments, including the request object.
+
+        Returns:
+            stac_types.Collection: The patched collection object.
+
+        """
+        collection = await self.database.json_patch_collection(
+            collection_id=collection_id,
+            operations=operations,
+            base_url=str(kwargs["request"].base_url),
+        )
+        return CollectionSerializer.db_to_stac(
+            collection,
+            kwargs["request"],
+            extensions=[type(ext).__name__ for ext in self.database.extensions],
+        )
+
     @overrides
     async def delete_collection(
         self, collection_id: str, **kwargs
diff --git a/stac_fastapi/core/stac_fastapi/core/utilities.py b/stac_fastapi/core/stac_fastapi/core/utilities.py
index d8c69529..96880d68 100644
--- a/stac_fastapi/core/stac_fastapi/core/utilities.py
+++ b/stac_fastapi/core/stac_fastapi/core/utilities.py
@@ -3,6 +3,8 @@
 This module contains functions for transforming geospatial coordinates,
 such as converting bounding boxes to polygon representations.
""" + +import json from typing import Any, Dict, List, Optional, Set, Union from stac_fastapi.types.stac import Item @@ -133,3 +135,64 @@ def dict_deep_update(merge_to: Dict[str, Any], merge_from: Dict[str, Any]) -> No dict_deep_update(merge_to[k], merge_from[k]) else: merge_to[k] = v + + +def merge_to_operations(data: Dict) -> List: + """Convert merge operation to list of RF6902 operations. + + Args: + data: dictionary to convert. + + Returns: + List: list of RF6902 operations. + """ + operations = [] + + for key, value in data.copy().items(): + + if value is None: + operations.append({"op": "remove", "path": key}) + continue + + elif isinstance(value, dict): + nested_operations = merge_to_operations(value) + + for nested_operation in nested_operations: + nested_operation["path"] = f"{key}.{nested_operation['path']}" + operations.append(nested_operation) + + else: + operations.append({"op": "add", "path": key, "value": value}) + + return operations + + +def operations_to_script(operations: List) -> Dict: + """Convert list of operation to painless script. + + Args: + operations: List of RF6902 operations. + + Returns: + Dict: elasticsearch update script. + """ + source = "" + for operation in operations: + if operation["op"] in ["copy", "move"]: + source += ( + f"ctx._source.{operation['path']} = ctx._source.{operation['from']};" + ) + + if operation["op"] in ["remove", "move"]: + nest, partition, key = operation["path"].rpartition(".") + source += f"ctx._source.{nest + partition}remove('{key}');" + + if operation["op"] in ["add", "replace"]: + source += ( + f"ctx._source.{operation['path']} = {json.dumps(operation['value'])};" + ) + + return { + "source": source, + "lang": "painless", + } diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index a4b40325..6de5c424 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -1,4 +1,5 @@ """Database logic.""" + import asyncio import logging import os @@ -12,13 +13,24 @@ from elasticsearch import exceptions, helpers # type: ignore from stac_fastapi.core.extensions import filter from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer -from stac_fastapi.core.utilities import MAX_LIMIT, bbox2polygon +from stac_fastapi.core.utilities import ( + MAX_LIMIT, + bbox2polygon, + merge_to_operations, + operations_to_script, +) from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings from stac_fastapi.elasticsearch.config import ( ElasticsearchSettings as SyncElasticsearchSettings, ) from stac_fastapi.types.errors import ConflictError, NotFoundError -from stac_fastapi.types.stac import Collection, Item +from stac_fastapi.types.stac import ( + Collection, + Item, + PartialCollection, + PartialItem, + PatchOperation, +) logger = logging.getLogger(__name__) @@ -735,6 +747,129 @@ async def create_item(self, item: Item, refresh: bool = False): f"Item {item_id} in collection {collection_id} already exists" ) + async def merge_patch_item( + self, + collection_id: str, + item_id: str, + item: PartialItem, + base_url: str, + refresh: bool = True, + ) -> Item: + """Database logic for merge patching an item following RF7396. + + Args: + collection_id(str): Collection that item belongs to. + item_id(str): Id of item to be patched. + item (PartialItem): The partial item to be updated. 
+            base_url (str): The base URL used for constructing URLs for the item.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Raises:
+            NotFoundError: If the item does not exist in the database.
+
+        Returns:
+            Item: The patched item.
+        """
+        operations = merge_to_operations(item)
+
+        return await self.json_patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            operations=operations,
+            base_url=base_url,
+            refresh=refresh,
+        )
+
+    async def json_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        operations: List[PatchOperation],
+        base_url: str,
+        refresh: bool = True,
+    ) -> Item:
+        """Database logic for json patching an item following RFC 6902.
+
+        Args:
+            collection_id (str): Collection the item belongs to.
+            item_id (str): ID of the item to be patched.
+            operations (list): List of operations to run.
+            base_url (str): The base URL used for constructing URLs for the item.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Raises:
+            NotFoundError: If the item does not exist in the database.
+
+        Returns:
+            Item: The patched item.
+        """
+        new_item_id = None
+        new_collection_id = None
+        script_operations = []
+
+        for operation in operations:
+            if operation["op"] in ["add", "replace"]:
+                if (
+                    operation["path"] == "collection"
+                    and collection_id != operation["value"]
+                ):
+                    await self.check_collection_exists(collection_id=operation["value"])
+                    new_collection_id = operation["value"]
+
+                if operation["path"] == "id" and item_id != operation["value"]:
+                    new_item_id = operation["value"]
+
+                else:
+                    script_operations.append(operation)
+
+        script = operations_to_script(script_operations)
+
+        if not new_collection_id and not new_item_id:
+            await self.client.update(
+                index=index_by_collection_id(collection_id),
+                id=mk_item_id(item_id, collection_id),
+                script=script,
+                refresh=refresh,
+            )
+
+        if new_collection_id:
+            await self.client.reindex(
+                body={
+                    "dest": {"index": f"{ITEMS_INDEX_PREFIX}{new_collection_id}"},
+                    "source": {
+                        "index": f"{ITEMS_INDEX_PREFIX}{collection_id}",
+                        "query": {"term": {"id": {"value": item_id}}},
+                    },
+                    "script": {
+                        "lang": "painless",
+                        "source": (
+                            f"""ctx._id = ctx._id.replace('{collection_id}', '{new_collection_id}');"""
+                            f"""ctx._source.collection = '{new_collection_id}';"""
+                            + script["source"]
+                        ),
+                    },
+                },
+                wait_for_completion=True,
+                refresh=False,
+            )
+
+        item = await self.get_one_item(collection_id, item_id)
+
+        if new_item_id:
+            item["id"] = new_item_id
+            item = await self.prep_create_item(item=item, base_url=base_url)
+            await self.create_item(item=item, refresh=False)
+
+        if new_item_id or new_collection_id:
+
+            await self.delete_item(
+                item_id=item_id,
+                collection_id=collection_id,
+                refresh=refresh,
+            )
+
+        return item
+
     async def delete_item(
         self, item_id: str, collection_id: str, refresh: bool = False
     ):
@@ -859,6 +994,93 @@ async def update_collection(
             refresh=refresh,
         )
 
+    async def merge_patch_collection(
+        self,
+        collection_id: str,
+        collection: PartialCollection,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Collection:
+        """Database logic for merge patching a collection following RFC 7396.
+
+        Args:
+            collection_id (str): ID of the collection to be patched.
+            collection (PartialCollection): The partial collection containing
+                the fields to be updated.
+            base_url (str): The base URL used for constructing links.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Raises:
+            NotFoundError: If the collection does not exist in the database.
+
+        Returns:
+            Collection: The patched collection.
+        """
+        operations = merge_to_operations(collection)
+
+        return await self.json_patch_collection(
+            collection_id=collection_id,
+            operations=operations,
+            base_url=base_url,
+            refresh=refresh,
+        )
+
+    async def json_patch_collection(
+        self,
+        collection_id: str,
+        operations: List[PatchOperation],
+        base_url: str,
+        refresh: bool = True,
+    ) -> Collection:
+        """Database logic for json patching a collection following RFC 6902.
+
+        Args:
+            collection_id (str): ID of the collection to be patched.
+            operations (list): List of operations to run.
+            base_url (str): The base URL used for constructing links.
+            refresh (bool, optional): Refresh the index after performing the
+                operation. Defaults to True.
+
+        Raises:
+            NotFoundError: If the collection does not exist in the database.
+
+        Returns:
+            Collection: The patched collection.
+        """
+        new_collection_id = None
+        script_operations = []
+
+        for operation in operations:
+            if (
+                operation["op"] in ["add", "replace"]
+                and operation["path"] == "collection"
+                and collection_id != operation["value"]
+            ):
+                new_collection_id = operation["value"]
+
+            else:
+                script_operations.append(operation)
+
+        script = operations_to_script(script_operations)
+
+        if not new_collection_id:
+            await self.client.update(
+                index=COLLECTIONS_INDEX,
+                id=collection_id,
+                script=script,
+                refresh=refresh,
+            )
+
+        collection = await self.find_collection(collection_id)
+
+        if new_collection_id:
+            collection["id"] = new_collection_id
+            await self.update_collection(
+                collection_id=collection_id, collection=collection, refresh=False
+            )
+
+        return collection
+
     async def delete_collection(self, collection_id: str, refresh: bool = False):
         """Delete a collection from the database.
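
Note (not part of the patch): a minimal sketch of how the two new helpers in stac_fastapi.core.utilities compose, using hypothetical item fields. An RFC 7396 merge patch is first flattened into RFC 6902-style operations with dotted paths, and those operations are then rendered as the Painless script that json_patch_item and json_patch_collection pass to client.update.

    from stac_fastapi.core.utilities import merge_to_operations, operations_to_script

    # Hypothetical merge patch: set properties.gsd and remove properties.license.
    partial_item = {"properties": {"gsd": 10, "license": None}}

    # Flatten the merge patch into RFC 6902-style operations (dotted paths).
    operations = merge_to_operations(partial_item)
    # [{"op": "add", "path": "properties.gsd", "value": 10},
    #  {"op": "remove", "path": "properties.license"}]

    # Render the operations as an Elasticsearch Painless update script.
    script = operations_to_script(operations)
    # {"source": "ctx._source.properties.gsd = 10;"
    #            "ctx._source.properties.remove('license');",
    #  "lang": "painless"}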