diff --git a/CHANGELOG.md b/CHANGELOG.md index a88e32a1..7931b027 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +### Added + - Queryables landing page and collection links when the Filter Extension is enabled [#267](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/267) + ### Changed - Updated stac-fastapi libraries to v3.0.0a1 [#265](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/265) diff --git a/data_loader.py b/data_loader.py index 1cccffd5..7d157e40 100644 --- a/data_loader.py +++ b/data_loader.py @@ -22,12 +22,17 @@ def load_collection(base_url, collection_id, data_dir): collection["id"] = collection_id try: resp = requests.post(f"{base_url}/collections", json=collection) - if resp.status_code == 200: + if resp.status_code == 200 or resp.status_code == 201: click.echo(f"Status code: {resp.status_code}") click.echo(f"Added collection: {collection['id']}") elif resp.status_code == 409: click.echo(f"Status code: {resp.status_code}") click.echo(f"Collection: {collection['id']} already exists") + else: + click.echo(f"Status code: {resp.status_code}") + click.echo( + f"Error writing {collection['id']} collection. Message: {resp.text}" + ) except requests.ConnectionError: click.secho("Failed to connect", fg="red", err=True) diff --git a/sample_data/collection.json b/sample_data/collection.json index dd68234d..bafd3ea2 100644 --- a/sample_data/collection.json +++ b/sample_data/collection.json @@ -1,6 +1,7 @@ { "id":"sentinel-s2-l2a-cogs-test", "stac_version":"1.0.0", + "type": "Collection", "description":"Sentinel-2a and Sentinel-2b imagery, processed to Level 2A (Surface Reflectance) and converted to Cloud-Optimized GeoTIFFs", "links":[ {"rel":"self","href":"https://earth-search.aws.element84.com/v0/collections/sentinel-s2-l2a-cogs"}, diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index 5469bf10..984b2e0a 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -153,6 +153,19 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: conformance_classes=self.conformance_classes(), extension_schemas=[], ) + + if self.extension_is_enabled("FilterExtension"): + landing_page["links"].append( + { + # TODO: replace this with Relations.queryables.value, + "rel": "queryables", + # TODO: replace this with MimeTypes.jsonschema, + "type": "application/schema+json", + "title": "Queryables", + "href": urljoin(base_url, "queryables"), + } + ) + collections = await self.all_collections(request=kwargs["request"]) for collection in collections["collections"]: landing_page["links"].append( @@ -205,7 +218,7 @@ async def all_collections(self, **kwargs) -> stac_types.Collections: token = request.query_params.get("token") collections, next_token = await self.database.get_all_collections( - token=token, limit=limit, base_url=base_url + token=token, limit=limit, request=request ) links = [ @@ -239,10 +252,12 @@ async def get_collection( Raises: NotFoundError: If the collection with the given id cannot be found in the database. """ - base_url = str(kwargs["request"].base_url) + request = kwargs["request"] collection = await self.database.find_collection(collection_id=collection_id) return self.collection_serializer.db_to_stac( - collection=collection, base_url=base_url + collection=collection, + request=request, + extensions=[type(ext).__name__ for ext in self.extensions], ) async def item_collection( @@ -748,12 +763,14 @@ async def create_collection( ConflictError: If the collection already exists. """ collection = collection.model_dump(mode="json") - base_url = str(kwargs["request"].base_url) - collection = self.database.collection_serializer.stac_to_db( - collection, base_url - ) + request = kwargs["request"] + collection = self.database.collection_serializer.stac_to_db(collection, request) await self.database.create_collection(collection=collection) - return CollectionSerializer.db_to_stac(collection, base_url) + return CollectionSerializer.db_to_stac( + collection, + request, + extensions=[type(ext).__name__ for ext in self.database.extensions], + ) @overrides async def update_collection( @@ -780,16 +797,18 @@ async def update_collection( """ collection = collection.model_dump(mode="json") - base_url = str(kwargs["request"].base_url) + request = kwargs["request"] - collection = self.database.collection_serializer.stac_to_db( - collection, base_url - ) + collection = self.database.collection_serializer.stac_to_db(collection, request) await self.database.update_collection( collection_id=collection_id, collection=collection ) - return CollectionSerializer.db_to_stac(collection, base_url) + return CollectionSerializer.db_to_stac( + collection, + request, + extensions=[type(ext).__name__ for ext in self.database.extensions], + ) @overrides async def delete_collection( diff --git a/stac_fastapi/core/stac_fastapi/core/models/links.py b/stac_fastapi/core/stac_fastapi/core/models/links.py index 725dc5c0..7a12b1c4 100644 --- a/stac_fastapi/core/stac_fastapi/core/models/links.py +++ b/stac_fastapi/core/stac_fastapi/core/models/links.py @@ -107,6 +107,39 @@ async def get_links( return links +@attr.s +class CollectionLinks(BaseLinks): + """Create inferred links specific to collections.""" + + collection_id: str = attr.ib() + extensions: List[str] = attr.ib(default=attr.Factory(list)) + + def link_parent(self) -> Dict[str, Any]: + """Create the `parent` link.""" + return dict(rel=Relations.parent, type=MimeTypes.json.value, href=self.base_url) + + def link_items(self) -> Dict[str, Any]: + """Create the `items` link.""" + return dict( + rel="items", + type=MimeTypes.geojson.value, + href=urljoin(self.base_url, f"collections/{self.collection_id}/items"), + ) + + def link_queryables(self) -> Dict[str, Any]: + """Create the `queryables` link.""" + if "FilterExtension" in self.extensions: + return dict( + rel="queryables", + type=MimeTypes.json.value, + href=urljoin( + self.base_url, f"collections/{self.collection_id}/queryables" + ), + ) + else: + return None + + @attr.s class PagingLinks(BaseLinks): """Create links for paging.""" diff --git a/stac_fastapi/core/stac_fastapi/core/serializers.py b/stac_fastapi/core/stac_fastapi/core/serializers.py index ba588025..9b0d36d4 100644 --- a/stac_fastapi/core/stac_fastapi/core/serializers.py +++ b/stac_fastapi/core/stac_fastapi/core/serializers.py @@ -1,13 +1,15 @@ """Serializers.""" import abc from copy import deepcopy -from typing import Any +from typing import Any, List, Optional import attr +from starlette.requests import Request from stac_fastapi.core.datetime_utils import now_to_rfc3339_str +from stac_fastapi.core.models.links import CollectionLinks from stac_fastapi.types import stac as stac_types -from stac_fastapi.types.links import CollectionLinks, ItemLinks, resolve_links +from stac_fastapi.types.links import ItemLinks, resolve_links @attr.s @@ -109,29 +111,34 @@ class CollectionSerializer(Serializer): @classmethod def stac_to_db( - cls, collection: stac_types.Collection, base_url: str + cls, collection: stac_types.Collection, request: Request ) -> stac_types.Collection: """ Transform STAC Collection to database-ready STAC collection. Args: stac_data: the STAC Collection object to be transformed - base_url: the base URL for the STAC API + starlette.requests.Request: the API request Returns: stac_types.Collection: The database-ready STAC Collection object. """ collection = deepcopy(collection) - collection["links"] = resolve_links(collection.get("links", []), base_url) + collection["links"] = resolve_links( + collection.get("links", []), str(request.base_url) + ) return collection @classmethod - def db_to_stac(cls, collection: dict, base_url: str) -> stac_types.Collection: + def db_to_stac( + cls, collection: dict, request: Request, extensions: Optional[List[str]] = [] + ) -> stac_types.Collection: """Transform database model to STAC collection. Args: collection (dict): The collection data in dictionary form, extracted from the database. - base_url (str): The base URL for the collection. + starlette.requests.Request: the API request + extensions: A list of the extension class names (`ext.__name__`) or all enabled STAC API extensions. Returns: stac_types.Collection: The STAC collection object. @@ -157,13 +164,13 @@ def db_to_stac(cls, collection: dict, base_url: str) -> stac_types.Collection: # Create the collection links using CollectionLinks collection_links = CollectionLinks( - collection_id=collection_id, base_url=base_url + collection_id=collection_id, request=request, extensions=extensions ).create_links() # Add any additional links from the collection dictionary original_links = collection.get("links") if original_links: - collection_links += resolve_links(original_links, base_url) + collection_links += resolve_links(original_links, str(request.base_url)) collection["links"] = collection_links # Return the stac_types.Collection object diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py index c0d4aaea..6a5ee006 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py @@ -59,6 +59,8 @@ filter_extension, ] +database_logic.extensions = [type(ext).__name__ for ext in extensions] + post_request_model = create_post_request_model(extensions) api = StacApi( diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index ddb6648b..a4b40325 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -7,6 +7,7 @@ import attr from elasticsearch_dsl import Q, Search +from starlette.requests import Request from elasticsearch import exceptions, helpers # type: ignore from stac_fastapi.core.extensions import filter @@ -312,10 +313,12 @@ class DatabaseLogic: default=CollectionSerializer ) + extensions: List[str] = attr.ib(default=attr.Factory(list)) + """CORE LOGIC""" async def get_all_collections( - self, token: Optional[str], limit: int, base_url: str + self, token: Optional[str], limit: int, request: Request ) -> Tuple[List[Dict[str, Any]], Optional[str]]: """Retrieve a list of all collections from Elasticsearch, supporting pagination. @@ -342,7 +345,7 @@ async def get_all_collections( hits = response["hits"]["hits"] collections = [ self.collection_serializer.db_to_stac( - collection=hit["_source"], base_url=base_url + collection=hit["_source"], request=request, extensions=self.extensions ) for hit in hits ] diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py index 4cd38c20..d06b0f29 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py @@ -59,6 +59,8 @@ filter_extension, ] +database_logic.extensions = [type(ext).__name__ for ext in extensions] + post_request_model = create_post_request_model(extensions) api = StacApi( diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index 5a320d8f..841d5e27 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -10,6 +10,7 @@ from opensearchpy.exceptions import TransportError from opensearchpy.helpers.query import Q from opensearchpy.helpers.search import Search +from starlette.requests import Request from stac_fastapi.core import serializers from stac_fastapi.core.extensions import filter @@ -333,10 +334,12 @@ class DatabaseLogic: default=serializers.CollectionSerializer ) + extensions: List[str] = attr.ib(default=attr.Factory(list)) + """CORE LOGIC""" async def get_all_collections( - self, token: Optional[str], limit: int, base_url: str + self, token: Optional[str], limit: int, request: Request ) -> Tuple[List[Dict[str, Any]], Optional[str]]: """ Retrieve a list of all collections from Opensearch, supporting pagination. @@ -366,7 +369,7 @@ async def get_all_collections( hits = response["hits"]["hits"] collections = [ self.collection_serializer.db_to_stac( - collection=hit["_source"], base_url=base_url + collection=hit["_source"], request=request, extensions=self.extensions ) for hit in hits ] diff --git a/stac_fastapi/tests/conftest.py b/stac_fastapi/tests/conftest.py index 21380494..619a257c 100644 --- a/stac_fastapi/tests/conftest.py +++ b/stac_fastapi/tests/conftest.py @@ -58,6 +58,7 @@ def __init__(self, item, collection): class MockRequest: base_url = "http://test-server" + url = "http://test-server/test" query_params = {} def __init__( diff --git a/stac_fastapi/tests/extensions/test_filter.py b/stac_fastapi/tests/extensions/test_filter.py index edff5c1a..8f4fa5ee 100644 --- a/stac_fastapi/tests/extensions/test_filter.py +++ b/stac_fastapi/tests/extensions/test_filter.py @@ -8,6 +8,35 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +@pytest.mark.asyncio +async def test_filter_extension_landing_page_link(app_client, ctx): + resp = await app_client.get("/") + assert resp.status_code == 200 + + resp_json = resp.json() + keys = [link["rel"] for link in resp_json["links"]] + + assert "queryables" in keys + + +@pytest.mark.asyncio +async def test_filter_extension_collection_link(app_client, load_test_data): + """Test creation and deletion of a collection""" + test_collection = load_test_data("test_collection.json") + test_collection["id"] = "test" + + resp = await app_client.post("/collections", json=test_collection) + assert resp.status_code == 201 + + resp = await app_client.get(f"/collections/{test_collection['id']}") + resp_json = resp.json() + keys = [link["rel"] for link in resp_json["links"]] + assert "queryables" in keys + + resp = await app_client.delete(f"/collections/{test_collection['id']}") + assert resp.status_code == 204 + + @pytest.mark.asyncio async def test_search_filters_post(app_client, ctx):