From 33ddeb7f28ac47847d31a60e6096493f39a968ad Mon Sep 17 00:00:00 2001
From: jonhealy1
Date: Wed, 31 Jan 2024 23:19:39 +0800
Subject: [PATCH] remove core.py from es folder

---
 .../stac_fastapi/elasticsearch/core.py        | 851 ------------------
 .../elasticsearch/models/links.py             | 138 ---
 .../tests/resources/test_item.py              |  10 +-
 3 files changed, 8 insertions(+), 991 deletions(-)
 delete mode 100644 stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py
 delete mode 100644 stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py

diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py
deleted file mode 100644
index 12cc6b2c..00000000
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py
+++ /dev/null
@@ -1,851 +0,0 @@
-"""Item crud client."""
-import logging
-import re
-from base64 import urlsafe_b64encode
-from datetime import datetime as datetime_type
-from datetime import timezone
-from typing import Any, Dict, List, Optional, Set, Type, Union
-from urllib.parse import unquote_plus, urljoin
-
-import attr
-import orjson
-import stac_pydantic
-from fastapi import HTTPException, Request
-from overrides import overrides
-from pydantic import ValidationError
-from pygeofilter.backends.cql2_json import to_cql2
-from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
-from stac_pydantic.links import Relations
-from stac_pydantic.shared import MimeTypes
-
-from stac_fastapi.elasticsearch import serializers
-from stac_fastapi.elasticsearch.config import ElasticsearchSettings
-from stac_fastapi.elasticsearch.database_logic import DatabaseLogic
-from stac_fastapi.elasticsearch.models.links import PagingLinks
-from stac_fastapi.elasticsearch.serializers import CollectionSerializer, ItemSerializer
-from stac_fastapi.elasticsearch.session import Session
-from stac_fastapi.extensions.third_party.bulk_transactions import (
-    BaseBulkTransactionsClient,
-    BulkTransactionMethod,
-    Items,
-)
-from stac_fastapi.types import stac as stac_types
-from stac_fastapi.types.config import Settings
-from stac_fastapi.types.core import (
-    AsyncBaseCoreClient,
-    AsyncBaseFiltersClient,
-    AsyncBaseTransactionsClient,
-)
-from stac_fastapi.types.links import CollectionLinks
-from stac_fastapi.types.search import BaseSearchPostRequest
-from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
-
-logger = logging.getLogger(__name__)
-
-NumType = Union[float, int]
-
-
-@attr.s
-class CoreClient(AsyncBaseCoreClient):
-    """Client for core endpoints defined by the STAC specification.
-
-    This class is a implementation of `AsyncBaseCoreClient` that implements the core endpoints
-    defined by the STAC specification. It uses the `DatabaseLogic` class to interact with the
-    database, and `ItemSerializer` and `CollectionSerializer` to convert between STAC objects and
-    database records.
-
-    Attributes:
-        session (Session): A requests session instance to be used for all HTTP requests.
-        item_serializer (Type[serializers.ItemSerializer]): A serializer class to be used to convert
-            between STAC items and database records.
-        collection_serializer (Type[serializers.CollectionSerializer]): A serializer class to be
-            used to convert between STAC collections and database records.
-        database (DatabaseLogic): An instance of the `DatabaseLogic` class that is used to interact
-            with the database.
-    """
-
-    session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
-    item_serializer: Type[serializers.ItemSerializer] = attr.ib(
-        default=serializers.ItemSerializer
-    )
-    collection_serializer: Type[serializers.CollectionSerializer] = attr.ib(
-        default=serializers.CollectionSerializer
-    )
-    database = DatabaseLogic()
-
-    @overrides
-    async def all_collections(self, **kwargs) -> Collections:
-        """Read all collections from the database.
-
-        Returns:
-            Collections: A `Collections` object containing all the collections in the database and
-                links to various resources.
-
-        Raises:
-            Exception: If any error occurs while reading the collections from the database.
-        """
-        request: Request = kwargs["request"]
-        base_url = str(kwargs["request"].base_url)
-
-        limit = (
-            int(request.query_params["limit"])
-            if "limit" in request.query_params
-            else 10
-        )
-        token = (
-            request.query_params["token"] if "token" in request.query_params else None
-        )
-
-        hits = await self.database.get_all_collections(limit=limit, token=token)
-
-        next_search_after = None
-        next_link = None
-        if len(hits) == limit:
-            last_hit = hits[-1]
-            next_search_after = last_hit["sort"]
-            next_token = urlsafe_b64encode(
-                ",".join(map(str, next_search_after)).encode()
-            ).decode()
-            paging_links = PagingLinks(next=next_token, request=request)
-            next_link = paging_links.link_next()
-
-        links = [
-            {
-                "rel": Relations.root.value,
-                "type": MimeTypes.json,
-                "href": base_url,
-            },
-            {
-                "rel": Relations.parent.value,
-                "type": MimeTypes.json,
-                "href": base_url,
-            },
-            {
-                "rel": Relations.self.value,
-                "type": MimeTypes.json,
-                "href": urljoin(base_url, "collections"),
-            },
-        ]
-
-        if next_link:
-            links.append(next_link)
-
-        return Collections(
-            collections=[
-                self.collection_serializer.db_to_stac(c["_source"], base_url=base_url)
-                for c in hits
-            ],
-            links=links,
-        )
-
-    @overrides
-    async def get_collection(self, collection_id: str, **kwargs) -> Collection:
-        """Get a collection from the database by its id.
-
-        Args:
-            collection_id (str): The id of the collection to retrieve.
-            kwargs: Additional keyword arguments passed to the API call.
-
-        Returns:
-            Collection: A `Collection` object representing the requested collection.
-
-        Raises:
-            NotFoundError: If the collection with the given id cannot be found in the database.
-        """
-        base_url = str(kwargs["request"].base_url)
-        collection = await self.database.find_collection(collection_id=collection_id)
-        return self.collection_serializer.db_to_stac(collection, base_url)
-
-    @overrides
-    async def item_collection(
-        self,
-        collection_id: str,
-        bbox: Optional[List[NumType]] = None,
-        datetime: Union[str, datetime_type, None] = None,
-        limit: int = 10,
-        token: str = None,
-        **kwargs,
-    ) -> ItemCollection:
-        """Read items from a specific collection in the database.
-
-        Args:
-            collection_id (str): The identifier of the collection to read items from.
-            bbox (Optional[List[NumType]]): The bounding box to filter items by.
-            datetime (Union[str, datetime_type, None]): The datetime range to filter items by.
-            limit (int): The maximum number of items to return. The default value is 10.
-            token (str): A token used for pagination.
-            request (Request): The incoming request.
-
-        Returns:
-            ItemCollection: An `ItemCollection` object containing the items from the specified collection that meet
-                the filter criteria and links to various resources.
-
-        Raises:
-            HTTPException: If the specified collection is not found.
-            Exception: If any error occurs while reading the items from the database.
-        """
-        request: Request = kwargs["request"]
-        base_url = str(request.base_url)
-
-        collection = await self.get_collection(
-            collection_id=collection_id, request=request
-        )
-        collection_id = collection.get("id")
-        if collection_id is None:
-            raise HTTPException(status_code=404, detail="Collection not found")
-
-        search = self.database.make_search()
-        search = self.database.apply_collections_filter(
-            search=search, collection_ids=[collection_id]
-        )
-
-        if datetime:
-            datetime_search = self._return_date(datetime)
-            search = self.database.apply_datetime_filter(
-                search=search, datetime_search=datetime_search
-            )
-
-        if bbox:
-            bbox = [float(x) for x in bbox]
-            if len(bbox) == 6:
-                bbox = [bbox[0], bbox[1], bbox[3], bbox[4]]
-
-            search = self.database.apply_bbox_filter(search=search, bbox=bbox)
-
-        items, maybe_count, next_token = await self.database.execute_search(
-            search=search,
-            limit=limit,
-            sort=None,
-            token=token,  # type: ignore
-            collection_ids=[collection_id],
-        )
-
-        items = [
-            self.item_serializer.db_to_stac(item, base_url=base_url) for item in items
-        ]
-
-        context_obj = None
-        if self.extension_is_enabled("ContextExtension"):
-            context_obj = {
-                "returned": len(items),
-                "limit": limit,
-            }
-            if maybe_count is not None:
-                context_obj["matched"] = maybe_count
-
-        links = []
-        if next_token:
-            links = await PagingLinks(request=request, next=next_token).get_links()
-
-        return ItemCollection(
-            type="FeatureCollection",
-            features=items,
-            links=links,
-            context=context_obj,
-        )
-
-    @overrides
-    async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
-        """Get an item from the database based on its id and collection id.
-
-        Args:
-            collection_id (str): The ID of the collection the item belongs to.
-            item_id (str): The ID of the item to be retrieved.
-
-        Returns:
-            Item: An `Item` object representing the requested item.
-
-        Raises:
-            Exception: If any error occurs while getting the item from the database.
-            NotFoundError: If the item does not exist in the specified collection.
-        """
-        base_url = str(kwargs["request"].base_url)
-        item = await self.database.get_one_item(
-            item_id=item_id, collection_id=collection_id
-        )
-        return self.item_serializer.db_to_stac(item, base_url)
-
-    @staticmethod
-    def _return_date(interval_str):
-        """
-        Convert a date interval string into a dictionary for filtering search results.
-
-        The date interval string should be formatted as either a single date or a range of dates separated
-        by "/". The date format should be ISO-8601 (YYYY-MM-DDTHH:MM:SSZ). If the interval string is a
-        single date, it will be converted to a dictionary with a single "eq" key whose value is the date in
-        the ISO-8601 format. If the interval string is a range of dates, it will be converted to a
-        dictionary with "gte" (greater than or equal to) and "lte" (less than or equal to) keys. If the
-        interval string is a range of dates with ".." instead of "/", the start and end dates will be
-        assigned default values to encompass the entire possible date range.
-
-        Args:
-            interval_str (str): The date interval string to be converted.
-
-        Returns:
-            dict: A dictionary representing the date interval for use in filtering search results.
-        """
-        intervals = interval_str.split("/")
-        if len(intervals) == 1:
-            datetime = f"{intervals[0][0:19]}Z"
-            return {"eq": datetime}
-        else:
-            start_date = intervals[0]
-            end_date = intervals[1]
-            if ".." not in intervals:
-                start_date = f"{start_date[0:19]}Z"
-                end_date = f"{end_date[0:19]}Z"
-            elif start_date != "..":
-                start_date = f"{start_date[0:19]}Z"
-                end_date = "2200-12-01T12:31:12Z"
-            elif end_date != "..":
-                start_date = "1900-10-01T00:00:00Z"
-                end_date = f"{end_date[0:19]}Z"
-            else:
-                start_date = "1900-10-01T00:00:00Z"
-                end_date = "2200-12-01T12:31:12Z"
-
-            return {"lte": end_date, "gte": start_date}
-
-    async def get_search(
-        self,
-        request: Request,
-        collections: Optional[List[str]] = None,
-        ids: Optional[List[str]] = None,
-        bbox: Optional[List[NumType]] = None,
-        datetime: Optional[Union[str, datetime_type]] = None,
-        limit: Optional[int] = 10,
-        query: Optional[str] = None,
-        token: Optional[str] = None,
-        fields: Optional[List[str]] = None,
-        sortby: Optional[str] = None,
-        intersects: Optional[str] = None,
-        filter: Optional[str] = None,
-        filter_lang: Optional[str] = None,
-        **kwargs,
-    ) -> ItemCollection:
-        """Get search results from the database.
-
-        Args:
-            collections (Optional[List[str]]): List of collection IDs to search in.
-            ids (Optional[List[str]]): List of item IDs to search for.
-            bbox (Optional[List[NumType]]): Bounding box to search in.
-            datetime (Optional[Union[str, datetime_type]]): Filter items based on the datetime field.
-            limit (Optional[int]): Maximum number of results to return.
-            query (Optional[str]): Query string to filter the results.
-            token (Optional[str]): Access token to use when searching the catalog.
-            fields (Optional[List[str]]): Fields to include or exclude from the results.
-            sortby (Optional[str]): Sorting options for the results.
-            intersects (Optional[str]): GeoJSON geometry to search in.
-            kwargs: Additional parameters to be passed to the API.
-
-        Returns:
-            ItemCollection: Collection of `Item` objects representing the search results.
-
-        Raises:
-            HTTPException: If any error occurs while searching the catalog.
-        """
-        base_args = {
-            "collections": collections,
-            "ids": ids,
-            "bbox": bbox,
-            "limit": limit,
-            "token": token,
-            "query": orjson.loads(query) if query else query,
-        }
-
-        # this is borrowed from stac-fastapi-pgstac
-        # Kludgy fix because using factory does not allow alias for filter-lan
-        query_params = str(request.query_params)
-        if filter_lang is None:
-            match = re.search(r"filter-lang=([a-z0-9-]+)", query_params, re.IGNORECASE)
-            if match:
-                filter_lang = match.group(1)
-
-        if datetime:
-            base_args["datetime"] = datetime
-
-        if intersects:
-            base_args["intersects"] = orjson.loads(unquote_plus(intersects))
-
-        if sortby:
-            sort_param = []
-            for sort in sortby:
-                sort_param.append(
-                    {
-                        "field": sort[1:],
-                        "direction": "desc" if sort[0] == "-" else "asc",
-                    }
-                )
-            print(sort_param)
-            base_args["sortby"] = sort_param
-
-        if filter:
-            if filter_lang == "cql2-json":
-                base_args["filter-lang"] = "cql2-json"
-                base_args["filter"] = orjson.loads(unquote_plus(filter))
-            else:
-                base_args["filter-lang"] = "cql2-json"
-                base_args["filter"] = orjson.loads(to_cql2(parse_cql2_text(filter)))
-
-        if fields:
-            includes = set()
-            excludes = set()
-            for field in fields:
-                if field[0] == "-":
-                    excludes.add(field[1:])
-                elif field[0] == "+":
-                    includes.add(field[1:])
-                else:
-                    includes.add(field)
-            base_args["fields"] = {"include": includes, "exclude": excludes}
-
-        # Do the request
-        try:
-            search_request = self.post_request_model(**base_args)
-        except ValidationError:
-            raise HTTPException(status_code=400, detail="Invalid parameters provided")
-        resp = await self.post_search(search_request=search_request, request=request)
-
-        return resp
-
-    async def post_search(
-        self, search_request: BaseSearchPostRequest, request: Request
-    ) -> ItemCollection:
-        """
-        Perform a POST search on the catalog.
-
-        Args:
-            search_request (BaseSearchPostRequest): Request object that includes the parameters for the search.
-            kwargs: Keyword arguments passed to the function.
-
-        Returns:
-            ItemCollection: A collection of items matching the search criteria.
-
-        Raises:
-            HTTPException: If there is an error with the cql2_json filter.
-        """
-        base_url = str(request.base_url)
-
-        search = self.database.make_search()
-
-        if search_request.ids:
-            search = self.database.apply_ids_filter(
-                search=search, item_ids=search_request.ids
-            )
-
-        if search_request.collections:
-            search = self.database.apply_collections_filter(
-                search=search, collection_ids=search_request.collections
-            )
-
-        if search_request.datetime:
-            datetime_search = self._return_date(search_request.datetime)
-            search = self.database.apply_datetime_filter(
-                search=search, datetime_search=datetime_search
-            )
-
-        if search_request.bbox:
-            bbox = search_request.bbox
-            if len(bbox) == 6:
-                bbox = [bbox[0], bbox[1], bbox[3], bbox[4]]
-
-            search = self.database.apply_bbox_filter(search=search, bbox=bbox)
-
-        if search_request.intersects:
-            search = self.database.apply_intersects_filter(
-                search=search, intersects=search_request.intersects
-            )
-
-        if search_request.query:
-            for (field_name, expr) in search_request.query.items():
-                field = "properties__" + field_name
-                for (op, value) in expr.items():
-                    search = self.database.apply_stacql_filter(
-                        search=search, op=op, field=field, value=value
-                    )
-
-        # only cql2_json is supported here
-        if hasattr(search_request, "filter"):
-            cql2_filter = getattr(search_request, "filter", None)
-            try:
-                search = self.database.apply_cql2_filter(search, cql2_filter)
-            except Exception as e:
-                raise HTTPException(
-                    status_code=400, detail=f"Error with cql2_json filter: {e}"
-                )
-
-        sort = None
-        if search_request.sortby:
-            sort = self.database.populate_sort(search_request.sortby)
-
-        limit = 10
-        if search_request.limit:
-            limit = search_request.limit
-
-        items, maybe_count, next_token = await self.database.execute_search(
-            search=search,
-            limit=limit,
-            token=search_request.token,  # type: ignore
-            sort=sort,
-            collection_ids=search_request.collections,
-        )
-
-        items = [
-            self.item_serializer.db_to_stac(item, base_url=base_url) for item in items
-        ]
-
-        if self.extension_is_enabled("FieldsExtension"):
-            if search_request.query is not None:
-                query_include: Set[str] = set(
-                    [
-                        k if k in Settings.get().indexed_fields else f"properties.{k}"
-                        for k in search_request.query.keys()
-                    ]
-                )
-                if not search_request.fields.include:
-                    search_request.fields.include = query_include
-                else:
-                    search_request.fields.include.union(query_include)
-
-            filter_kwargs = search_request.fields.filter_fields
-
-            items = [
-                orjson.loads(
-                    stac_pydantic.Item(**feat).json(**filter_kwargs, exclude_unset=True)
-                )
-                for feat in items
-            ]
-
-        context_obj = None
-        if self.extension_is_enabled("ContextExtension"):
-            context_obj = {
-                "returned": len(items),
-                "limit": limit,
-            }
-            if maybe_count is not None:
-                context_obj["matched"] = maybe_count
-
-        links = []
-        if next_token:
-            links = await PagingLinks(request=request, next=next_token).get_links()
-
-        return ItemCollection(
-            type="FeatureCollection",
-            features=items,
-            links=links,
-            context=context_obj,
-        )
-
-
-@attr.s
-class TransactionsClient(AsyncBaseTransactionsClient):
-    """Transactions extension specific CRUD operations."""
-
-    session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
-    database = DatabaseLogic()
-
-    @overrides
-    async def create_item(
-        self, collection_id: str, item: stac_types.Item, **kwargs
-    ) -> stac_types.Item:
-        """Create an item in the collection.
-
-        Args:
-            collection_id (str): The id of the collection to add the item to.
-            item (stac_types.Item): The item to be added to the collection.
-            kwargs: Additional keyword arguments.
-
-        Returns:
-            stac_types.Item: The created item.
-
-        Raises:
-            NotFound: If the specified collection is not found in the database.
-            ConflictError: If the item in the specified collection already exists.
-
-        """
-        base_url = str(kwargs["request"].base_url)
-
-        # If a feature collection is posted
-        if item["type"] == "FeatureCollection":
-            bulk_client = BulkTransactionsClient()
-            processed_items = [
-                bulk_client.preprocess_item(item, base_url, BulkTransactionMethod.INSERT) for item in item["features"]  # type: ignore
-            ]
-
-            await self.database.bulk_async(
-                collection_id, processed_items, refresh=kwargs.get("refresh", False)
-            )
-
-            return None  # type: ignore
-        else:
-            item = await self.database.prep_create_item(item=item, base_url=base_url)
-            await self.database.create_item(item, refresh=kwargs.get("refresh", False))
-            return item
-
-    @overrides
-    async def update_item(
-        self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs
-    ) -> stac_types.Item:
-        """Update an item in the collection.
-
-        Args:
-            collection_id (str): The ID of the collection the item belongs to.
-            item_id (str): The ID of the item to be updated.
-            item (stac_types.Item): The new item data.
-            kwargs: Other optional arguments, including the request object.
-
-        Returns:
-            stac_types.Item: The updated item object.
-
-        Raises:
-            NotFound: If the specified collection is not found in the database.
-
-        """
-        base_url = str(kwargs["request"].base_url)
-        now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z")
-        item["properties"]["updated"] = now
-
-        await self.database.check_collection_exists(collection_id)
-        await self.delete_item(item_id=item_id, collection_id=collection_id)
-        await self.create_item(collection_id=collection_id, item=item, **kwargs)
-
-        return ItemSerializer.db_to_stac(item, base_url)
-
-    @overrides
-    async def delete_item(
-        self, item_id: str, collection_id: str, **kwargs
-    ) -> stac_types.Item:
-        """Delete an item from a collection.
-
-        Args:
-            item_id (str): The identifier of the item to delete.
-            collection_id (str): The identifier of the collection that contains the item.
-
-        Returns:
-            Optional[stac_types.Item]: The deleted item, or `None` if the item was successfully deleted.
-        """
-        await self.database.delete_item(item_id=item_id, collection_id=collection_id)
-        return None  # type: ignore
-
-    @overrides
-    async def create_collection(
-        self, collection: stac_types.Collection, **kwargs
-    ) -> stac_types.Collection:
-        """Create a new collection in the database.
-
-        Args:
-            collection (stac_types.Collection): The collection to be created.
-            kwargs: Additional keyword arguments.
-
-        Returns:
-            stac_types.Collection: The created collection object.
-
-        Raises:
-            ConflictError: If the collection already exists.
-        """
-        base_url = str(kwargs["request"].base_url)
-        collection_links = CollectionLinks(
-            collection_id=collection["id"], base_url=base_url
-        ).create_links()
-        collection["links"] = collection_links
-        await self.database.create_collection(collection=collection)
-
-        return CollectionSerializer.db_to_stac(collection, base_url)
-
-    @overrides
-    async def update_collection(
-        self, collection: stac_types.Collection, **kwargs
-    ) -> stac_types.Collection:
-        """
-        Update a collection.
-
-        This method updates an existing collection in the database by first finding
-        the collection by its id, then deleting the old version, and finally creating
-        a new version of the updated collection. The updated collection is then returned.
-
-        Args:
-            collection: A STAC collection that needs to be updated.
-            kwargs: Additional keyword arguments.
-
-        Returns:
-            A STAC collection that has been updated in the database.
-
-        """
-        base_url = str(kwargs["request"].base_url)
-
-        await self.database.find_collection(collection_id=collection["id"])
-        await self.delete_collection(collection["id"])
-        await self.create_collection(collection, **kwargs)
-
-        return CollectionSerializer.db_to_stac(collection, base_url)
-
-    @overrides
-    async def delete_collection(
-        self, collection_id: str, **kwargs
-    ) -> stac_types.Collection:
-        """
-        Delete a collection.
-
-        This method deletes an existing collection in the database.
-
-        Args:
-            collection_id (str): The identifier of the collection that contains the item.
-            kwargs: Additional keyword arguments.
-
-        Returns:
-            None.
-
-        Raises:
-            NotFoundError: If the collection doesn't exist.
-        """
-        await self.database.delete_collection(collection_id=collection_id)
-        return None  # type: ignore
-
-
-@attr.s
-class BulkTransactionsClient(BaseBulkTransactionsClient):
-    """A client for posting bulk transactions to a Postgres database.
-
-    Attributes:
-        session: An instance of `Session` to use for database connection.
-        database: An instance of `DatabaseLogic` to perform database operations.
-    """
-
-    session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
-    database = DatabaseLogic()
-
-    def __attrs_post_init__(self):
-        """Create es engine."""
-        settings = ElasticsearchSettings()
-        self.client = settings.create_client
-
-    def preprocess_item(
-        self, item: stac_types.Item, base_url, method: BulkTransactionMethod
-    ) -> stac_types.Item:
-        """Preprocess an item to match the data model.
-
-        Args:
-            item: The item to preprocess.
-            base_url: The base URL of the request.
-            method: The bulk transaction method.
-
-        Returns:
-            The preprocessed item.
-        """
-        exist_ok = method == BulkTransactionMethod.UPSERT
-        return self.database.sync_prep_create_item(
-            item=item, base_url=base_url, exist_ok=exist_ok
-        )
-
-    @overrides
-    def bulk_item_insert(
-        self, items: Items, chunk_size: Optional[int] = None, **kwargs
-    ) -> str:
-        """Perform a bulk insertion of items into the database using Elasticsearch.
-
-        Args:
-            items: The items to insert.
-            chunk_size: The size of each chunk for bulk processing.
-            **kwargs: Additional keyword arguments, such as `request` and `refresh`.
-
-        Returns:
-            A string indicating the number of items successfully added.
-        """
-        request = kwargs.get("request")
-        if request:
-            base_url = str(request.base_url)
-        else:
-            base_url = ""
-
-        processed_items = [
-            self.preprocess_item(item, base_url, items.method)
-            for item in items.items.values()
-        ]
-
-        # not a great way to get the collection_id-- should be part of the method signature
-        collection_id = processed_items[0]["collection"]
-
-        self.database.bulk_sync(
-            collection_id, processed_items, refresh=kwargs.get("refresh", False)
-        )
-
-        return f"Successfully added {len(processed_items)} Items."
-
-
-@attr.s
-class EsAsyncBaseFiltersClient(AsyncBaseFiltersClient):
-    """Defines a pattern for implementing the STAC filter extension."""
-
-    # todo: use the ES _mapping endpoint to dynamically find what fields exist
-    async def get_queryables(
-        self, collection_id: Optional[str] = None, **kwargs
-    ) -> Dict[str, Any]:
-        """Get the queryables available for the given collection_id.
-
-        If collection_id is None, returns the intersection of all
-        queryables over all collections.
-
-        This base implementation returns a blank queryable schema. This is not allowed
-        under OGC CQL but it is allowed by the STAC API Filter Extension
-
-        https://github.com/radiantearth/stac-api-spec/tree/master/fragments/filter#queryables
-
-        Args:
-            collection_id (str, optional): The id of the collection to get queryables for.
-            **kwargs: additional keyword arguments
-
-        Returns:
-            Dict[str, Any]: A dictionary containing the queryables for the given collection.
-        """
-        return {
-            "$schema": "https://json-schema.org/draft/2019-09/schema",
-            "$id": "https://stac-api.example.com/queryables",
-            "type": "object",
-            "title": "Queryables for Example STAC API",
-            "description": "Queryable names for the example STAC API Item Search filter.",
-            "properties": {
-                "id": {
-                    "description": "ID",
-                    "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/item.json#/definitions/core/allOf/2/properties/id",
-                },
-                "collection": {
-                    "description": "Collection",
-                    "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/item.json#/definitions/core/allOf/2/then/properties/collection",
-                },
-                "geometry": {
-                    "description": "Geometry",
-                    "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/item.json#/definitions/core/allOf/1/oneOf/0/properties/geometry",
-                },
-                "datetime": {
-                    "description": "Acquisition Timestamp",
-                    "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/datetime.json#/properties/datetime",
-                },
-                "created": {
-                    "description": "Creation Timestamp",
-                    "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/datetime.json#/properties/created",
-                },
-                "updated": {
-                    "description": "Creation Timestamp",
-                    "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/datetime.json#/properties/updated",
-                },
-                "cloud_cover": {
-                    "description": "Cloud Cover",
-                    "$ref": "https://stac-extensions.github.io/eo/v1.0.0/schema.json#/definitions/fields/properties/eo:cloud_cover",
-                },
-                "cloud_shadow_percentage": {
-                    "description": "Cloud Shadow Percentage",
-                    "title": "Cloud Shadow Percentage",
-                    "type": "number",
-                    "minimum": 0,
-                    "maximum": 100,
-                },
-                "nodata_pixel_percentage": {
-                    "description": "No Data Pixel Percentage",
-                    "title": "No Data Pixel Percentage",
-                    "type": "number",
-                    "minimum": 0,
-                    "maximum": 100,
-                },
-            },
-            "additionalProperties": True,
-        }
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py
deleted file mode 100644
index 3941a149..00000000
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py
+++ /dev/null
@@ -1,138 +0,0 @@
-"""link helpers."""
-
-from typing import Any, Dict, List, Optional
-from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse
-
-import attr
-from stac_pydantic.links import Relations
-from stac_pydantic.shared import MimeTypes
-from starlette.requests import Request
-
-# Copied from pgstac links
-
-# These can be inferred from the item/collection, so they aren't included in the database
-# Instead they are dynamically generated when querying the database using the classes defined below
-INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"]
-
-
-def merge_params(url: str, newparams: Dict) -> str:
-    """Merge url parameters."""
-    u = urlparse(url)
-    params = parse_qs(u.query)
-    params.update(newparams)
-    param_string = unquote(urlencode(params, True))
-
-    href = ParseResult(
-        scheme=u.scheme,
-        netloc=u.netloc,
-        path=u.path,
-        params=u.params,
-        query=param_string,
-        fragment=u.fragment,
-    ).geturl()
-    return href
-
-
-@attr.s
-class BaseLinks:
-    """Create inferred links common to collections and items."""
-
-    request: Request = attr.ib()
-
-    @property
-    def base_url(self):
-        """Get the base url."""
-        return str(self.request.base_url)
-
-    @property
-    def url(self):
-        """Get the current request url."""
-        return str(self.request.url)
-
-    def resolve(self, url):
-        """Resolve url to the current request url."""
-        return urljoin(str(self.base_url), str(url))
-
-    def link_self(self) -> Dict:
-        """Return the self link."""
-        return dict(rel=Relations.self.value, type=MimeTypes.json.value, href=self.url)
-
-    def link_root(self) -> Dict:
-        """Return the catalog root."""
-        return dict(
-            rel=Relations.root.value, type=MimeTypes.json.value, href=self.base_url
-        )
-
-    def create_links(self) -> List[Dict[str, Any]]:
-        """Return all inferred links."""
-        links = []
-        for name in dir(self):
-            if name.startswith("link_") and callable(getattr(self, name)):
-                link = getattr(self, name)()
-                if link is not None:
-                    links.append(link)
-        return links
-
-    async def get_links(
-        self, extra_links: Optional[List[Dict[str, Any]]] = None
-    ) -> List[Dict[str, Any]]:
-        """
-        Generate all the links.
-
-        Get the links object for a stac resource by iterating through
-        available methods on this class that start with link_.
-        """
-        # TODO: Pass request.json() into function so this doesn't need to be coroutine
-        if self.request.method == "POST":
-            self.request.postbody = await self.request.json()
-        # join passed in links with generated links
-        # and update relative paths
-        links = self.create_links()
-
-        if extra_links:
-            # For extra links passed in,
-            # add links modified with a resolved href.
-            # Drop any links that are dynamically
-            # determined by the server (e.g. self, parent, etc.)
-            # Resolving the href allows for relative paths
-            # to be stored in pgstac and for the hrefs in the
-            # links of response STAC objects to be resolved
-            # to the request url.
-            links += [
-                {**link, "href": self.resolve(link["href"])}
-                for link in extra_links
-                if link["rel"] not in INFERRED_LINK_RELS
-            ]
-
-        return links
-
-
-@attr.s
-class PagingLinks(BaseLinks):
-    """Create links for paging."""
-
-    next: Optional[str] = attr.ib(kw_only=True, default=None)
-
-    def link_next(self) -> Optional[Dict[str, Any]]:
-        """Create link for next page."""
-        if self.next is not None:
-            method = self.request.method
-            if method == "GET":
-                href = merge_params(self.url, {"token": self.next})
-                link = dict(
-                    rel=Relations.next.value,
-                    type=MimeTypes.json.value,
-                    method=method,
-                    href=href,
-                )
-                return link
-            if method == "POST":
-                return {
-                    "rel": Relations.next,
-                    "type": MimeTypes.json,
-                    "method": method,
-                    "href": f"{self.request.url}",
-                    "body": {**self.request.postbody, "token": self.next},
                }
-
-        return None
diff --git a/stac_fastapi/elasticsearch/tests/resources/test_item.py b/stac_fastapi/elasticsearch/tests/resources/test_item.py
index 5b382873..c63be048 100644
--- a/stac_fastapi/elasticsearch/tests/resources/test_item.py
+++ b/stac_fastapi/elasticsearch/tests/resources/test_item.py
@@ -12,7 +12,8 @@
 from geojson_pydantic.geometries import Polygon
 from pystac.utils import datetime_to_str
 
-from stac_fastapi.elasticsearch.core import CoreClient
+from stac_fastapi.core.core import CoreClient
+from stac_fastapi.elasticsearch.database_logic import DatabaseLogic
 from stac_fastapi.elasticsearch.datetime_utils import now_to_rfc3339_str
 from stac_fastapi.types.core import LandingPageMixin
 
@@ -23,6 +24,9 @@ def rfc3339_str_to_datetime(s: str) -> datetime:
     return ciso8601.parse_rfc3339(s)
 
 
+database_logic = DatabaseLogic()
+
+
 @pytest.mark.asyncio
 async def test_create_and_delete_item(app_client, ctx, txn_client):
     """Test creation and deletion of a single item (transactions extension)"""
@@ -773,7 +777,9 @@ async def test_conformance_classes_configurable():
     # Update environment to avoid key error on client instantiation
     os.environ["READER_CONN_STRING"] = "testing"
     os.environ["WRITER_CONN_STRING"] = "testing"
-    client = CoreClient(base_conformance_classes=["this is a test"])
+    client = CoreClient(
+        database=database_logic, base_conformance_classes=["this is a test"]
+    )
     assert client.conformance_classes()[0] == "this is a test"