Skip to content

Commit

Permalink
Add free-text search extension (#227)
Browse files Browse the repository at this point in the history
**Description:**

Adding the free-text search extension. Related to
stac-utils/stac-fastapi#655

**PR Checklist:**

- [x] Code is formatted and linted (run `pre-commit run --all-files`)
- [x] Tests pass (run `make test`)
- [x] Documentation has been updated to reflect changes, if applicable
- [x] Changes are added to the changelog
  • Loading branch information
rhysrevans3 authored Aug 31, 2024
1 parent ace0c7a commit ae5588e
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Added
- Added support for FreeTextExtension. [#227](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/227)

### Changed
- Support escaped backslashes in CQL2 `LIKE` queries, and reject invalid (or incomplete) escape sequences. [#286](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/286)

Expand Down
13 changes: 13 additions & 0 deletions stac_fastapi/core/stac_fastapi/core/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Core client."""

import logging
from datetime import datetime as datetime_type
from datetime import timezone
Expand Down Expand Up @@ -456,6 +457,7 @@ async def get_search(
token: Optional[str] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
q: Optional[List[str]] = None,
intersects: Optional[str] = None,
filter: Optional[str] = None,
filter_lang: Optional[str] = None,
Expand All @@ -473,6 +475,7 @@ async def get_search(
token (Optional[str]): Access token to use when searching the catalog.
fields (Optional[List[str]]): Fields to include or exclude from the results.
sortby (Optional[str]): Sorting options for the results.
q (Optional[List[str]]): Free text query to filter the results.
intersects (Optional[str]): GeoJSON geometry to search in.
kwargs: Additional parameters to be passed to the API.
Expand All @@ -489,6 +492,7 @@ async def get_search(
"limit": limit,
"token": token,
"query": orjson.loads(query) if query else query,
"q": q,
}

if datetime:
Expand Down Expand Up @@ -599,6 +603,15 @@ async def post_search(
status_code=400, detail=f"Error with cql2_json filter: {e}"
)

if hasattr(search_request, "q"):
free_text_queries = getattr(search_request, "q", None)
try:
search = self.database.apply_free_text_filter(search, free_text_queries)
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Error with free text query: {e}"
)

sort = None
if search_request.sortby:
sort = self.database.populate_sort(search_request.sortby)
Expand Down
2 changes: 2 additions & 0 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from stac_fastapi.extensions.core import (
AggregationExtension,
FilterExtension,
FreeTextExtension,
SortExtension,
TokenPaginationExtension,
TransactionExtension,
Expand Down Expand Up @@ -71,6 +72,7 @@
SortExtension(),
TokenPaginationExtension(),
filter_extension,
FreeTextExtension(),
]

extensions = [aggregation_extension] + search_extensions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Database logic."""

import asyncio
import logging
import os
Expand Down Expand Up @@ -509,6 +510,17 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):

return search

@staticmethod
def apply_free_text_filter(search: Search, free_text_queries: Optional[List[str]]):
"""Database logic to perform query for search endpoint."""
if free_text_queries is not None:
free_text_query_string = '" OR properties.\\*:"'.join(free_text_queries)
search = search.query(
"query_string", query=f'properties.\\*:"{free_text_query_string}"'
)

return search

@staticmethod
def apply_cql2_filter(search: Search, _filter: Optional[Dict[str, Any]]):
"""
Expand Down
2 changes: 2 additions & 0 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from stac_fastapi.extensions.core import (
AggregationExtension,
FilterExtension,
FreeTextExtension,
SortExtension,
TokenPaginationExtension,
TransactionExtension,
Expand Down Expand Up @@ -71,6 +72,7 @@
SortExtension(),
TokenPaginationExtension(),
filter_extension,
FreeTextExtension(),
]

extensions = [aggregation_extension] + search_extensions
Expand Down
12 changes: 12 additions & 0 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Database logic."""

import asyncio
import logging
import os
Expand Down Expand Up @@ -426,6 +427,17 @@ def apply_collections_filter(search: Search, collection_ids: List[str]):
"""Database logic to search a list of STAC collection ids."""
return search.filter("terms", collection=collection_ids)

@staticmethod
def apply_free_text_filter(search: Search, free_text_queries: Optional[List[str]]):
"""Database logic to perform query for search endpoint."""
if free_text_queries is not None:
free_text_query_string = '" OR properties.\\*:"'.join(free_text_queries)
search = search.query(
"query_string", query=f'properties.\\*:"{free_text_query_string}"'
)

return search

@staticmethod
def apply_datetime_filter(search: Search, datetime_search):
"""Apply a filter to search based on datetime field.
Expand Down
4 changes: 4 additions & 0 deletions stac_fastapi/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
AggregationExtension,
FieldsExtension,
FilterExtension,
FreeTextExtension,
SortExtension,
TokenPaginationExtension,
TransactionExtension,
Expand Down Expand Up @@ -215,6 +216,7 @@ async def app():
QueryExtension(),
TokenPaginationExtension(),
FilterExtension(),
FreeTextExtension(),
]

extensions = [aggregation_extension] + search_extensions
Expand Down Expand Up @@ -301,6 +303,7 @@ async def app_basic_auth():
QueryExtension(),
TokenPaginationExtension(),
FilterExtension(),
FreeTextExtension(),
]

extensions = [aggregation_extension] + search_extensions
Expand Down Expand Up @@ -380,6 +383,7 @@ async def route_dependencies_app():
QueryExtension(),
TokenPaginationExtension(),
FilterExtension(),
FreeTextExtension(),
]

post_request_model = create_post_request_model(extensions)
Expand Down
42 changes: 42 additions & 0 deletions stac_fastapi/tests/resources/test_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,48 @@ async def test_item_search_properties_field(app_client):
assert len(resp_json["features"]) == 0


@pytest.mark.asyncio
async def test_item_search_free_text_extension(app_client, txn_client, ctx):
"""Test POST search indexed field with q parameter (free-text)"""
first_item = ctx.item

second_item = dict(first_item)
second_item["id"] = "second-item"
second_item["properties"]["ft_field1"] = "hello"

await create_item(txn_client, second_item)

params = {"q": ["hello"]}
resp = await app_client.post("/search", json=params)
assert resp.status_code == 200
resp_json = resp.json()
assert len(resp_json["features"]) == 1


@pytest.mark.asyncio
async def test_item_search_free_text_extension_or_query(app_client, txn_client, ctx):
"""Test POST search indexed field with q parameter with multiple terms (free-text)"""
first_item = ctx.item

second_item = dict(first_item)
second_item["id"] = "second-item"
second_item["properties"]["ft_field1"] = "hello"
second_item["properties"]["ft_field2"] = "world"

await create_item(txn_client, second_item)

third_item = dict(first_item)
third_item["id"] = "third-item"
third_item["properties"]["ft_field1"] = "world"
await create_item(txn_client, third_item)

params = {"q": ["hello", "world"]}
resp = await app_client.post("/search", json=params)
assert resp.status_code == 200
resp_json = resp.json()
assert len(resp_json["features"]) == 2


@pytest.mark.asyncio
async def test_item_search_get_query_extension(app_client, ctx):
"""Test GET search with JSONB query (query extension)"""
Expand Down

0 comments on commit ae5588e

Please sign in to comment.