diff --git a/CHANGELOG.md b/CHANGELOG.md index 823f3c9e..1e4cbadf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +### Added + - Added support for FreeTextExtension. [#227](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/227) + ### Changed - Support escaped backslashes in CQL2 `LIKE` queries, and reject invalid (or incomplete) escape sequences. [#286](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/286) diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index df20c103..56afcbc8 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -1,4 +1,5 @@ """Core client.""" + import logging from datetime import datetime as datetime_type from datetime import timezone @@ -456,6 +457,7 @@ async def get_search( token: Optional[str] = None, fields: Optional[List[str]] = None, sortby: Optional[str] = None, + q: Optional[List[str]] = None, intersects: Optional[str] = None, filter: Optional[str] = None, filter_lang: Optional[str] = None, @@ -473,6 +475,7 @@ async def get_search( token (Optional[str]): Access token to use when searching the catalog. fields (Optional[List[str]]): Fields to include or exclude from the results. sortby (Optional[str]): Sorting options for the results. + q (Optional[List[str]]): Free text query to filter the results. intersects (Optional[str]): GeoJSON geometry to search in. kwargs: Additional parameters to be passed to the API. 
@staticmethod
def apply_free_text_filter(search: Search, free_text_queries: Optional[List[str]]):
    """Restrict a search to items whose properties match any free-text term.

    Each term becomes a quoted phrase clause against ``properties.*`` and the
    clauses are OR-ed together in a single ``query_string`` query.

    Args:
        search (Search): The search object to refine.
        free_text_queries (Optional[List[str]]): Terms taken from the ``q``
            parameter. ``None`` or an empty list leaves the search unchanged.

    Returns:
        Search: The search with the free-text query applied, or the search
        that was passed in when there are no terms.
    """
    if not free_text_queries:
        return search

    def _phrase_clause(term: str) -> str:
        # Terms come from the request, so neutralise the characters that
        # could end the quoted phrase early. Backslashes go first so the
        # escapes added for quotes are not themselves doubled.
        escaped = term.replace("\\", "\\\\").replace('"', '\\"')
        return f'properties.\\*:"{escaped}"'

    query = " OR ".join(_phrase_clause(term) for term in free_text_queries)
    return search.query("query_string", query=query)
@pytest.mark.asyncio
async def test_item_search_free_text_extension(app_client, txn_client, ctx):
    """Test POST search with a single free-text term in the q parameter."""
    from copy import deepcopy

    # Deep copy so that editing the nested "properties" dict does not also
    # modify the fixture item held by ctx.
    second_item = deepcopy(ctx.item)
    second_item["id"] = "second-item"
    second_item["properties"]["ft_field1"] = "hello"

    await create_item(txn_client, second_item)

    params = {"q": ["hello"]}
    resp = await app_client.post("/search", json=params)
    assert resp.status_code == 200
    resp_json = resp.json()
    assert len(resp_json["features"]) == 1
    assert resp_json["features"][0]["id"] == "second-item"


@pytest.mark.asyncio
async def test_item_search_free_text_extension_or_query(app_client, txn_client, ctx):
    """Test POST search with several free-text terms, which are OR-ed together."""
    from copy import deepcopy

    # Each item gets its own copy of "properties"; with a shallow copy the
    # fields written for one item would leak into the others.
    second_item = deepcopy(ctx.item)
    second_item["id"] = "second-item"
    second_item["properties"]["ft_field1"] = "hello"
    second_item["properties"]["ft_field2"] = "world"

    await create_item(txn_client, second_item)

    third_item = deepcopy(ctx.item)
    third_item["id"] = "third-item"
    third_item["properties"]["ft_field1"] = "world"
    await create_item(txn_client, third_item)

    params = {"q": ["hello", "world"]}
    resp = await app_client.post("/search", json=params)
    assert resp.status_code == 200
    resp_json = resp.json()
    assert len(resp_json["features"]) == 2
    matched_ids = {feature["id"] for feature in resp_json["features"]}
    assert matched_ids == {"second-item", "third-item"}