diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c3f4a61..b552a17f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Changed - Elasticsearch drivers from 7.17.9 to 8.11.0 [#169](https://github.com/stac-utils/stac-fastapi-elasticsearch/pull/169) +- Collection update endpoint no longer deletes all sub items [#177](https://github.com/stac-utils/stac-fastapi-elasticsearch/pull/177) ### Fixed diff --git a/docker-compose.yml b/docker-compose.yml index db3352fb..03698654 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,7 +14,7 @@ services: - RELOAD=true - ENVIRONMENT=local - WEB_CONCURRENCY=10 - - ES_HOST=172.17.0.1 + - ES_HOST=elasticsearch - ES_PORT=9200 - ES_USE_SSL=false - ES_VERIFY_CERTS=false @@ -32,6 +32,7 @@ services: elasticsearch: container_name: es-container image: docker.elastic.co/elasticsearch/elasticsearch:${ELASTICSEARCH_VERSION:-8.11.0} + hostname: elasticsearch environment: ES_JAVA_OPTS: -Xms512m -Xmx1g volumes: diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py index 8634d3b9..10cf95e9 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py @@ -39,6 +39,15 @@ def _es_config() -> Dict[str, Any]: if (u := os.getenv("ES_USER")) and (p := os.getenv("ES_PASS")): config["http_auth"] = (u, p) + if api_key := os.getenv("ES_API_KEY"): + if isinstance(config["headers"], dict): + headers = {**config["headers"], "x-api-key": api_key} + + else: + config["headers"] = {"x-api-key": api_key} + + config["headers"] = headers + return config diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index 12cc6b2c..a78d4707 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ 
b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -457,9 +457,9 @@ async def post_search( ) if search_request.query: - for (field_name, expr) in search_request.query.items(): + for field_name, expr in search_request.query.items(): field = "properties__" + field_name - for (op, value) in expr.items(): + for op, value in expr.items(): search = self.database.apply_stacql_filter( search=search, op=op, field=field, value=value ) @@ -660,8 +660,11 @@ async def update_collection( Update a collection. This method updates an existing collection in the database by first finding - the collection by its id, then deleting the old version, and finally creating - a new version of the updated collection. The updated collection is then returned. + the collection by the id given in the keyword argument `collection_id`. + If no `collection_id` is given the id of the given collection object is used. + If the object and keyword collection ids don't match the sub items + collection id is updated else the items are left unchanged. + The updated collection is then returned. Args: collection: A STAC collection that needs to be updated. 
async def update_collection(
    self, collection_id: str, collection: Collection, refresh: bool = False
):
    """Update a collection in the database, renaming it if its id changed.

    When ``collection["id"]`` equals ``collection_id`` the stored collection
    document is overwritten in place and the items are not touched. When the
    two ids differ, the collection is created under its new id, every item
    is reindexed from the old items index into the new one (rewriting the
    item ``_id`` and its ``collection`` field), and the old collection is
    deleted.

    Args:
        collection_id (str): The ID under which the collection is currently
            stored.
        collection (Collection): The collection document to store; its
            ``"id"`` key is the ID the collection will have afterwards.
        refresh (bool): Forwarded as the ``refresh`` argument of every
            write issued by this method. Defaults to False.

    Raises:
        NotFoundError: Expected to propagate from ``find_collection`` when no
            collection with ``collection_id`` exists (its body is not shown
            here -- confirm against that method).
    """
    # Fail early if the collection to update does not exist.
    await self.find_collection(collection_id=collection_id)

    new_id = collection["id"]

    if collection_id == new_id:
        # Same id: overwrite the collection document, leave items alone.
        await self.client.index(
            index=COLLECTIONS_INDEX,
            id=collection_id,
            document=collection,
            refresh=refresh,
        )
        return

    # Id changed: create the new collection first so the destination exists
    # before items are copied into it.
    await self.create_collection(collection, refresh=refresh)

    await self.client.reindex(
        body={
            "dest": {"index": f"{ITEMS_INDEX_PREFIX}{new_id}"},
            "source": {"index": f"{ITEMS_INDEX_PREFIX}{collection_id}"},
            "script": {
                "lang": "painless",
                # The ids are passed as script params rather than being
                # interpolated into the source, so an id containing a quote
                # cannot break (or inject into) the script.
                "source": (
                    "ctx._id = ctx._id.replace(params.old_id, params.new_id); "
                    "ctx._source.collection = params.new_id;"
                ),
                "params": {"old_id": collection_id, "new_id": new_id},
            },
        },
        wait_for_completion=True,
        refresh=refresh,
    )

    # Remove the old collection only after its items have been copied.
    # `refresh` is forwarded so the deletion is as visible as the writes above.
    await self.delete_collection(collection_id, refresh=refresh)
@pytest.mark.asyncio
async def test_update_collection_id(
    core_client,
    txn_client,
    load_test_data: Callable,
):
    """Check that changing a collection's id moves the collection and its item.

    After the update, the old collection id and the item under it must raise
    NotFoundError, while both are retrievable under the new id.
    """
    new_collection_id = "new-test-collection"
    collection_data = load_test_data("test_collection.json")
    item_data = load_test_data("test_item.json")

    # Seed one collection holding one item.
    await txn_client.create_collection(collection_data, request=MockRequest)
    await txn_client.create_item(
        collection_id=collection_data["id"],
        item=item_data,
        request=MockRequest,
        refresh=True,
    )

    # Remember the stored id, then rename the collection payload.
    old_collection_id = collection_data["id"]
    collection_data["id"] = new_collection_id

    rename_request = MockRequest(
        query_params={
            "collection_id": old_collection_id,
            "limit": "10",
        }
    )
    await txn_client.update_collection(
        collection=collection_data,
        request=rename_request,
        refresh=True,
    )

    # The collection is gone under the old id and present under the new one.
    with pytest.raises(NotFoundError):
        await core_client.get_collection(old_collection_id, request=MockRequest)

    renamed = await core_client.get_collection(
        collection_data["id"], request=MockRequest
    )
    assert renamed["id"] == new_collection_id

    # Likewise for the item it contained.
    with pytest.raises(NotFoundError):
        await core_client.get_item(
            item_id=item_data["id"],
            collection_id=old_collection_id,
            request=MockRequest,
        )

    moved_item = await core_client.get_item(
        item_id=item_data["id"],
        collection_id=collection_data["id"],
        request=MockRequest,
        refresh=True,
    )

    assert moved_item["id"] == item_data["id"]
    assert moved_item["collection"] == new_collection_id

    await txn_client.delete_collection(collection_data["id"])
a/stac_fastapi/elasticsearch/tests/conftest.py b/stac_fastapi/elasticsearch/tests/conftest.py index fa093af2..13956329 100644 --- a/stac_fastapi/elasticsearch/tests/conftest.py +++ b/stac_fastapi/elasticsearch/tests/conftest.py @@ -39,6 +39,7 @@ def __init__(self, item, collection): class MockRequest: base_url = "http://test-server" + query_params = {} def __init__( self, @@ -50,7 +51,7 @@ def __init__( self.method = method self.url = url self.app = app - self.query_params = query_params or {} + self.query_params = query_params class TestSettings(AsyncElasticsearchSettings):