Skip to content

Commit

Permalink
collection aggregation, removed redundant filter tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesfisher-geo committed Jul 9, 2024
1 parent 1dcc41a commit 38dd41c
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 227 deletions.
99 changes: 66 additions & 33 deletions stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import attr
import orjson
from fastapi import HTTPException, Request
from fastapi import HTTPException, Path, Request
from pygeofilter.backends.cql2_json import to_cql2
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
from stac_pydantic.shared import BBox
from typing_extensions import Annotated

from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
from stac_fastapi.core.base_settings import ApiBaseSettings
Expand Down Expand Up @@ -39,7 +40,11 @@
class EsAggregationExtensionGetRequest(
AggregationExtensionGetRequest, FilterExtensionGetRequest
):
"""Add implementation specific query parameters to AggregationExtensionGetRequest for aggrgeation precision."""
"""Implementation specific query parameters for aggregation precision."""

collection_id: Optional[
Annotated[str, Path(description="Collection ID")]
] = attr.ib(default=None)

centroid_geohash_grid_frequency_precision: Optional[int] = attr.ib(default=None)
centroid_geohex_grid_frequency_precision: Optional[int] = attr.ib(default=None)
Expand All @@ -51,7 +56,7 @@ class EsAggregationExtensionGetRequest(
class EsAggregationExtensionPostRequest(
AggregationExtensionPostRequest, FilterExtensionPostRequest
):
"""Add implementation specific query parameters to AggregationExtensionPostRequest for aggrgeation precision."""
"""Implementation specific query parameters for aggregation precision."""

centroid_geohash_grid_frequency_precision: Optional[int] = None
centroid_geohex_grid_frequency_precision: Optional[int] = None
Expand Down Expand Up @@ -130,7 +135,6 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
"""Get the available aggregations for a catalog or collection defined in the STAC JSON. If no aggregations, default aggregations are used."""
request: Request = kwargs["request"]
base_url = str(request.base_url)

links = [{"rel": "root", "type": "application/json", "href": base_url}]

if collection_id is not None:
Expand All @@ -149,13 +153,13 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
},
]
)
if self.database.check_collection_exists(collection_id):
if await self.database.check_collection_exists(collection_id) is None:
collection = await self.database.find_collection(collection_id)
aggregations = collection.get(
"aggregations", self.DEFAULT_AGGREGATIONS.copy()
)
else:
raise IndexError("Collection does not exist")
raise IndexError(f"Collection {collection_id} does not exist")
else:
links.append(
{
Expand Down Expand Up @@ -257,14 +261,12 @@ def frequency_agg(self, es_aggs, name, data_type):

def metric_agg(self, es_aggs, name, data_type):
"""Format an aggregation for a metric aggregation."""
if name == "datetime_min" or name == "datetime_max":
value = datetime_to_str(
datetime.fromtimestamp(es_aggs.get(name, {}).get("value") / 1e3)
)
else:
value = es_aggs.get(name, {}).get("value_as_string") or es_aggs.get(
name, {}
).get("value")
value = es_aggs.get(name, {}).get("value_as_string") or es_aggs.get(
name, {}
).get("value")
# ES 7.x does not return datetimes with a 'value_as_string' field
if "datetime" in name and isinstance(value, float):
value = datetime_to_str(datetime.fromtimestamp(value / 1e3))
return Aggregation(
name=name,
data_type=data_type,
Expand Down Expand Up @@ -306,8 +308,10 @@ def format_datetime(dt):

async def aggregate(
self,
# collection_id: Optional[str] = None,
aggregate_request: Optional[EsAggregationExtensionPostRequest] = None,
collection_id: Optional[
Annotated[str, Path(description="Collection ID")]
] = None,
collections: Optional[List[str]] = [],
datetime: Optional[DateTimeType] = None,
intersects: Optional[str] = None,
Expand All @@ -326,9 +330,11 @@ async def aggregate(
"""Get aggregations from the database."""
request: Request = kwargs["request"]
base_url = str(request.base_url)
path = request.url.path
search = self.database.make_search()

if aggregate_request is None:

base_args = {
"collections": collections,
"ids": ids,
Expand All @@ -341,6 +347,9 @@ async def aggregate(
"geometry_geotile_grid_frequency_precision": geometry_geotile_grid_frequency_precision,
}

if collection_id:
collections = [str(collection_id)]

if intersects:
base_args["intersects"] = orjson.loads(unquote_plus(intersects))

Expand All @@ -359,18 +368,29 @@ async def aggregate(
filter_lang = match.group(1)
else:
filter_lang = "cql2-text"

if filter:
base_args["filter"] = self.get_filter(filter, filter_lang)

aggregate_request = EsAggregationExtensionPostRequest(**base_args)
else:
# Workaround for optional path param in POST requests
if "collections" in path:
collection_id = path.split("/")[2]

filter_lang = "cql2-json"
if aggregate_request.filter:
aggregate_request.filter = self.get_filter(
aggregate_request.filter, filter_lang
)

if collection_id:
if aggregate_request.collections:
raise HTTPException(
status_code=400,
detail="Cannot query multiple collections when executing '/collections/<collection_id>/aggregate'. Use '/aggregate' and the collections field instead",
)
else:
aggregate_request.collections = [collection_id]

if (
aggregate_request.aggregations is None
or aggregate_request.aggregations == []
Expand Down Expand Up @@ -495,8 +515,10 @@ async def aggregate(
aggs = []
if db_response:
result_aggs = db_response.get("aggregations", {})

for agg in supported_aggregations + self.GEO_POINT_AGGREGATIONS:
for agg in {
frozenset(item.items()): item
for item in supported_aggregations + self.GEO_POINT_AGGREGATIONS
}.values():
if agg["name"] in aggregate_request.aggregations:
if agg["name"].endswith("_frequency"):
aggs.append(
Expand All @@ -508,23 +530,34 @@ async def aggregate(
aggs.append(
self.metric_agg(result_aggs, agg["name"], agg["data_type"])
)

links = [
{
"rel": "self",
"type": "application/json",
"href": urljoin(base_url, "aggregate"),
},
{"rel": "root", "type": "application/json", "href": base_url},
]
# if collection_endpoint:
# links.append(
# {
# "rel": "collection",
# "type": "application/json",
# "href": collection_endpoint,
# }
# )

if collection_id:
collection_endpoint = urljoin(base_url, f"collections/{collection_id}")
links.extend(
[
{
"rel": "collection",
"type": "application/json",
"href": collection_endpoint,
},
{
"rel": "self",
"type": "application/json",
"href": urljoin(collection_endpoint, "aggregate"),
},
]
)
else:
links.append(
{
"rel": "self",
"type": "application/json",
"href": urljoin(base_url, "aggregate"),
}
)
results = AggregationCollection(
type="AggregationCollection", aggregations=aggs, links=links
)
Expand Down
Loading

0 comments on commit 38dd41c

Please sign in to comment.