Paginated search queries now don't return a token on the last page (#243)

**Related Issue(s):**

- #242

**Merge dependencies:**

- #241

**Description:**
- Paginated search queries now don't return a token on the last page (see the sketch below).
- Made some fixes to the respective tests. In particular,
  `test_pagination_token_idempotent` had an indentation issue.
- Improved `execute_search` to make use of
  `es_response["hits"]["total"]["value"]`.
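
How the look-ahead works, as a minimal standalone sketch (the `build_next_token` helper and the literal sort values are illustrative, not part of the codebase): the backend now requests `limit + 1` documents, returns only the first `limit`, and issues a token only when the extra hit proves another page exists.

```python
from base64 import urlsafe_b64encode


def build_next_token(hits: list, limit: int):
    """Issue a pagination token only when results exist beyond this page."""
    if len(hits) <= limit:
        return None  # no look-ahead hit came back, so this is the last page
    # Encode the sort values of the last item actually returned (index limit - 1)
    # so the next query can resume from it with search_after.
    sort_array = hits[limit - 1].get("sort")
    if not sort_array:
        return None
    return urlsafe_b64encode(",".join(str(x) for x in sort_array).encode()).decode()


hits = [
    {"_source": {"id": "a"}, "sort": [1715126400000, "a"]},
    {"_source": {"id": "b"}, "sort": [1715126400001, "b"]},
    {"_source": {"id": "c"}, "sort": [1715126400002, "c"]},  # the look-ahead hit
]
assert build_next_token(hits, limit=2) is not None  # a third hit exists -> token
assert build_next_token(hits[:2], limit=2) is None  # exactly one page -> no token
```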

**PR Checklist:**

- [x] Code is formatted and linted (run `pre-commit run --all-files`)
- [x] Tests pass (run `make test`)
- [x] Documentation has been updated to reflect changes, if applicable
- [x] Changes are added to the changelog

---------

Co-authored-by: Jonathan Healy <[email protected]>
pedro-cf and jonhealy1 authored May 8, 2024
1 parent c5c96c9 commit 55dd87e
Showing 4 changed files with 56 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
```diff
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 
 ### Fixed
 
+- Fixed issue where paginated search queries would return a `next_token` on the last page [#243](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/243)
 - Fixed issue where searches return an empty `links` array [#241](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/241)
 
 ## [v2.4.0]
```
33 changes: 21 additions & 12 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
```diff
@@ -8,6 +8,7 @@
 import attr
 from elasticsearch_dsl import Q, Search
 
+import stac_fastapi.types.search
 from elasticsearch import exceptions, helpers  # type: ignore
 from stac_fastapi.core.extensions import filter
 from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer
@@ -552,21 +553,26 @@ async def execute_search(
             NotFoundError: If the collections specified in `collection_ids` do not exist.
         """
         search_after = None
+
         if token:
             search_after = urlsafe_b64decode(token.encode()).decode().split(",")
 
         query = search.query.to_dict() if search.query else None
 
         index_param = indices(collection_ids)
 
+        max_result_window = stac_fastapi.types.search.Limit.le
+
+        size_limit = min(limit + 1, max_result_window)
+
         search_task = asyncio.create_task(
             self.client.search(
                 index=index_param,
                 ignore_unavailable=ignore_unavailable,
                 query=query,
                 sort=sort or DEFAULT_SORT,
                 search_after=search_after,
-                size=limit,
+                size=size_limit,
             )
         )
 
@@ -584,24 +590,27 @@ async def execute_search(
             raise NotFoundError(f"Collections '{collection_ids}' do not exist")
 
         hits = es_response["hits"]["hits"]
-        items = (hit["_source"] for hit in hits)
+        items = (hit["_source"] for hit in hits[:limit])
 
         next_token = None
-        if hits and (sort_array := hits[-1].get("sort")):
-            next_token = urlsafe_b64encode(
-                ",".join([str(x) for x in sort_array]).encode()
-            ).decode()
-
-        # (1) count should not block returning results, so don't wait for it to be done
-        # (2) don't cancel the task so that it will populate the ES cache for subsequent counts
-        maybe_count = None
+        if len(hits) > limit and limit < max_result_window:
+            if hits and (sort_array := hits[limit - 1].get("sort")):
+                next_token = urlsafe_b64encode(
+                    ",".join([str(x) for x in sort_array]).encode()
+                ).decode()
+
+        matched = (
+            es_response["hits"]["total"]["value"]
+            if es_response["hits"]["total"]["relation"] == "eq"
+            else None
+        )
         if count_task.done():
             try:
-                maybe_count = count_task.result().get("count")
+                matched = count_task.result().get("count")
             except Exception as e:
                 logger.error(f"Count task failed: {e}")
 
-        return items, maybe_count, next_token
+        return items, matched, next_token
 
     """ TRANSACTION LOGIC """
```
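A note on the new `matched` logic above: Elasticsearch reports `hits.total` as a value plus a `relation`, where `"eq"` means the value is exact and `"gte"` only guarantees a lower bound (total tracking is capped, by default at 10,000 hits). The diff therefore trusts `total.value` only when the relation is `"eq"`, and otherwise leaves `matched` unset until the separate count task has finished. A minimal sketch of that decision, with `resolve_matched` as an illustrative stand-in rather than the module API:

```python
def resolve_matched(es_response: dict, count_result):
    """Prefer the exact total from the search response, else a finished count."""
    total = es_response["hits"]["total"]
    matched = total["value"] if total["relation"] == "eq" else None
    if count_result is not None:  # stands in for count_task.done()
        matched = count_result.get("count", matched)
    return matched


# "eq": the reported total is exact and can be used directly.
assert resolve_matched({"hits": {"total": {"value": 42, "relation": "eq"}}}, None) == 42
# "gte": only a lower bound, so matched stays unknown until the count arrives.
assert resolve_matched({"hits": {"total": {"value": 10000, "relation": "gte"}}}, None) is None
assert resolve_matched(
    {"hits": {"total": {"value": 10000, "relation": "gte"}}}, {"count": 12345}
) == 12345
```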
37 changes: 25 additions & 12 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
```diff
@@ -11,6 +11,7 @@
 from opensearchpy.helpers.query import Q
 from opensearchpy.helpers.search import Search
 
+import stac_fastapi.types.search
 from stac_fastapi.core import serializers
 from stac_fastapi.core.extensions import filter
 from stac_fastapi.core.utilities import bbox2polygon
@@ -582,19 +583,28 @@ async def execute_search(
         query = search.query.to_dict() if search.query else None
         if query:
             search_body["query"] = query
+
         search_after = None
+
         if token:
             search_after = urlsafe_b64decode(token.encode()).decode().split(",")
         if search_after:
             search_body["search_after"] = search_after
+
         search_body["sort"] = sort if sort else DEFAULT_SORT
+
         index_param = indices(collection_ids)
+
+        max_result_window = stac_fastapi.types.search.Limit.le
+
+        size_limit = min(limit + 1, max_result_window)
+
         search_task = asyncio.create_task(
             self.client.search(
                 index=index_param,
                 ignore_unavailable=ignore_unavailable,
                 body=search_body,
-                size=limit,
+                size=size_limit,
             )
         )
 
@@ -612,24 +622,27 @@ async def execute_search(
             raise NotFoundError(f"Collections '{collection_ids}' do not exist")
 
         hits = es_response["hits"]["hits"]
-        items = (hit["_source"] for hit in hits)
+        items = (hit["_source"] for hit in hits[:limit])
 
         next_token = None
-        if hits and (sort_array := hits[-1].get("sort")):
-            next_token = urlsafe_b64encode(
-                ",".join([str(x) for x in sort_array]).encode()
-            ).decode()
-
-        # (1) count should not block returning results, so don't wait for it to be done
-        # (2) don't cancel the task so that it will populate the ES cache for subsequent counts
-        maybe_count = None
+        if len(hits) > limit and limit < max_result_window:
+            if hits and (sort_array := hits[limit - 1].get("sort")):
+                next_token = urlsafe_b64encode(
+                    ",".join([str(x) for x in sort_array]).encode()
+                ).decode()
+
+        matched = (
+            es_response["hits"]["total"]["value"]
+            if es_response["hits"]["total"]["relation"] == "eq"
+            else None
+        )
         if count_task.done():
             try:
-                maybe_count = count_task.result().get("count")
+                matched = count_task.result().get("count")
             except Exception as e:
                 logger.error(f"Count task failed: {e}")
 
-        return items, maybe_count, next_token
+        return items, matched, next_token
 
     """ TRANSACTION LOGIC """
```
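Both backends clamp the look-ahead with `size_limit = min(limit + 1, max_result_window)`, where `max_result_window` comes from `stac_fastapi.types.search.Limit.le` (assumed here to be 10,000, mirroring the Elasticsearch/OpenSearch `index.max_result_window` default). At the boundary there is no room for the extra document, and the `limit < max_result_window` guard keeps a token from being issued; a toy check under that assumption:

```python
MAX_RESULT_WINDOW = 10_000  # assumed value of stac_fastapi.types.search.Limit.le


def effective_size(limit: int) -> int:
    """Request one extra hit for look-ahead without exceeding the window."""
    return min(limit + 1, MAX_RESULT_WINDOW)


assert effective_size(10) == 11  # normal case: page size plus one look-ahead hit
# At the window there is no room to look ahead; the `limit < max_result_window`
# guard in the diff suppresses the next token in this case.
assert effective_size(10_000) == 10_000
```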
24 changes: 9 additions & 15 deletions stac_fastapi/tests/resources/test_item.py
```diff
@@ -492,12 +492,9 @@ async def test_item_search_temporal_window_timezone_get(app_client, ctx):
         "datetime": f"{datetime_to_str(item_date_before)}/{datetime_to_str(item_date_after)}",
     }
     resp = await app_client.get("/search", params=params)
-    resp_json = resp.json()
-    next_link = next(link for link in resp_json["links"] if link["rel"] == "next")[
-        "href"
-    ]
-    resp = await app_client.get(next_link)
     assert resp.status_code == 200
+    resp_json = resp.json()
+    assert resp_json["features"][0]["id"] == test_item["id"]
 
 
 @pytest.mark.asyncio
@@ -632,18 +629,17 @@ async def test_pagination_item_collection(app_client, ctx, txn_client):
         await create_item(txn_client, item=ctx.item)
         ids.append(ctx.item["id"])
 
-    # Paginate through all 6 items with a limit of 1 (expecting 7 requests)
+    # Paginate through all 6 items with a limit of 1 (expecting 6 requests)
     page = await app_client.get(
         f"/collections/{ctx.item['collection']}/items", params={"limit": 1}
     )
 
     item_ids = []
-    idx = 0
-    for idx in range(100):
+    for idx in range(1, 100):
         page_data = page.json()
         next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"]))
         if not next_link:
-            assert not page_data["features"]
+            assert idx == 6
             break
 
         assert len(page_data["features"]) == 1
@@ -672,10 +668,8 @@ async def test_pagination_post(app_client, ctx, txn_client):
     # Paginate through all 5 items with a limit of 1 (expecting 5 requests)
     request_body = {"ids": ids, "limit": 1}
     page = await app_client.post("/search", json=request_body)
-    idx = 0
     item_ids = []
-    for _ in range(100):
-        idx += 1
+    for idx in range(1, 100):
         page_data = page.json()
         next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"]))
         if not next_link:
@@ -688,7 +682,7 @@ async def test_pagination_post(app_client, ctx, txn_client):
         page = await app_client.post("/search", json=request_body)
 
     # Our limit is 1, so we expect len(ids) number of requests before we run out of pages
-    assert idx == len(ids) + 1
+    assert idx == len(ids)
 
     # Confirm we have paginated through all items
     assert not set(item_ids) - set(ids)
@@ -702,8 +696,8 @@ async def test_pagination_token_idempotent(app_client, ctx, txn_client):
     # Ingest 5 items
     for _ in range(5):
         ctx.item["id"] = str(uuid.uuid4())
-    await create_item(txn_client, ctx.item)
-    ids.append(ctx.item["id"])
+        await create_item(txn_client, ctx.item)
+        ids.append(ctx.item["id"])
 
     page = await app_client.get("/search", params={"ids": ",".join(ids), "limit": 3})
     page_data = page.json()
```
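With the fix, clients can treat a missing `next` link as the definitive end of results, with no trailing empty page. A hedged usage sketch (the `httpx` client and the `http://localhost:8080` base URL are assumptions, not part of the PR):

```python
import httpx


def fetch_all_items(base_url: str = "http://localhost:8080") -> list:
    """Follow `next` links until the server stops returning them."""
    features = []
    url, params = f"{base_url}/search", {"limit": 100}
    while url:
        resp = httpx.get(url, params=params)
        resp.raise_for_status()
        page = resp.json()
        features.extend(page["features"])
        next_link = next((lk for lk in page["links"] if lk["rel"] == "next"), None)
        # The last page now carries no `next` link, so the loop ends here.
        url, params = (next_link["href"], None) if next_link else (None, None)
    return features
```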
