diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1097805..2f54353 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,7 +3,7 @@
 All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
-and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
 
@@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ### Fixed
 
 - Removed bulk transactions extension from app.py
-- Fixed pagination bug so pagination functions
+- Fixed pagination issue with MongoDB. Fixes [#1](https://github.com/Healy-Hyperspatial/stac-fastapi-mongo/issues/1)
 
 ## [v3.0.0]
 
@@ -27,5 +27,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ### Fixed
 
-[Unreleased]:
+
+[Unreleased]:
 [v3.0.0]:
diff --git a/stac_fastapi/mongo/database_logic.py b/stac_fastapi/mongo/database_logic.py
index c7357d0..0444013 100644
--- a/stac_fastapi/mongo/database_logic.py
+++ b/stac_fastapi/mongo/database_logic.py
@@ -554,28 +554,37 @@ async def execute_search(
         query = {"$and": search.filters} if search and search.filters else {}
         print("Query: ", query)
+
         if collection_ids:
             query["collection"] = {"$in": collection_ids}
 
         sort_criteria = sort if sort else [("id", 1)]  # Default sort
+
         try:
             if token:
                 last_id = decode_token(token)
-                query["id"] = {"$gt": last_id}
+                skip_count = int(last_id)
+            else:
+                skip_count = 0
 
-            cursor = collection.find(query).sort(sort_criteria).limit(limit + 1)
+            cursor = (
+                collection.find(query)
+                .sort(sort_criteria)
+                .skip(skip_count)
+                .limit(limit + 1)
+            )
             items = await cursor.to_list(length=limit + 1)
 
             next_token = None
             if len(items) > limit:
-                next_token = base64.urlsafe_b64encode(
-                    str(items[-1]["id"]).encode()
-                ).decode()
+                next_skip = skip_count + limit
+                next_token = base64.urlsafe_b64encode(str(next_skip).encode()).decode()
                 items = items[:-1]
 
             maybe_count = None
             if not token:
                 maybe_count = await collection.count_documents(query)
+
             return items, maybe_count, next_token
         except PyMongoError as e:
             print(f"Database operation failed: {e}")
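The change above swaps keyset pagination (resuming at `id > last_id`) for offset pagination: the token is now a base64-encoded skip count rather than the last-seen item id, so it works with arbitrary `sort_criteria`. A minimal sketch of the token round-trip, assuming `decode_token` (defined elsewhere in the module) simply reverses the encoding used in `execute_search`:

```python
import base64


def encode_token(skip_count: int) -> str:
    """Encode the running offset as an opaque, URL-safe pagination token."""
    return base64.urlsafe_b64encode(str(skip_count).encode()).decode()


def decode_token(token: str) -> str:
    """Recover the offset string from a token produced by encode_token."""
    return base64.urlsafe_b64decode(token.encode()).decode()


# Page 1 of a limit-10 search ends at offset 10, so its "next" token
# encodes 10; presenting that token back resumes the find() at item 11.
token = encode_token(0 + 10)
assert int(decode_token(token)) == 10
```

The trade-off: `skip` rescans past the discarded documents on every page, and results can shift if items are inserted or deleted between requests, whereas the keyset approach it replaces was stable under concurrent writes but only supported the default `id` sort.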
diff --git a/stac_fastapi/tests/resources/test_item.py b/stac_fastapi/tests/resources/test_item.py
index 4fae107..17eab57 100644
--- a/stac_fastapi/tests/resources/test_item.py
+++ b/stac_fastapi/tests/resources/test_item.py
@@ -553,106 +553,121 @@ async def test_get_missing_item_collection(app_client):
     assert resp.status_code == 404
 
 
-@pytest.mark.skip(
-    reason="Pagination is not working in mongo, setting the limit doesn't limit the number of results. You can keep going to the next result."
-)
 @pytest.mark.asyncio
 async def test_pagination_item_collection(app_client, ctx, txn_client):
     """Test item collection pagination links (paging extension)"""
-    ids = [ctx.item["id"]]
+    # Initialize a list to store the expected item IDs
+    expected_item_ids = [ctx.item["id"]]
 
-    # Ingest 5 items
+    # Ingest 5 items in addition to the default test-item
     for _ in range(5):
         ctx.item["id"] = str(uuid.uuid4())
         await create_item(txn_client, item=ctx.item)
-        ids.append(ctx.item["id"])
+        expected_item_ids.append(ctx.item["id"])
 
-    # Paginate through all 6 items with a limit of 1 (expecting 7 requests)
+    # Paginate through all items with a limit of 1 (expecting 6 requests)
     page = await app_client.get(
         f"/collections/{ctx.item['collection']}/items", params={"limit": 1}
     )
-    item_ids = []
-    idx = 0
-    for idx in range(100):
+    retrieved_item_ids = []
+    request_count = 0
+    for _ in range(100):
+        request_count += 1
         page_data = page.json()
         next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"]))
         if not next_link:
-            assert not page_data["features"]
+            # Ensure that the last page contains features
+            assert page_data["features"]
             break
 
+        # Assert that each page contains only one feature
         assert len(page_data["features"]) == 1
-        item_ids.append(page_data["features"][0]["id"])
+        retrieved_item_ids.append(page_data["features"][0]["id"])
 
+        # Extract the next page URL
         href = next_link[0]["href"][len("http://test-server") :]
         page = await app_client.get(href)
 
-    assert idx == len(ids)
+    # Assert that the number of requests made is equal to the total number of items ingested
+    assert request_count == len(expected_item_ids)
 
-    # Confirm we have paginated through all items
-    assert not set(item_ids) - set(ids)
+    # Confirm we have paginated through all items by comparing the expected and retrieved item IDs
+    assert not set(retrieved_item_ids) - set(expected_item_ids)
 
 
-@pytest.mark.skip(reason="fix pagination in mongo")
 @pytest.mark.asyncio
 async def test_pagination_post(app_client, ctx, txn_client):
     """Test POST pagination (paging extension)"""
-    ids = [ctx.item["id"]]
+    # Initialize a list to store the expected item IDs
+    expected_item_ids = [ctx.item["id"]]
 
-    # Ingest 5 items
+    # Ingest 5 items in addition to the default test-item
     for _ in range(5):
         ctx.item["id"] = str(uuid.uuid4())
-        await create_item(txn_client, ctx.item)
-        ids.append(ctx.item["id"])
+        await create_item(txn_client, ctx.item)
+        expected_item_ids.append(ctx.item["id"])
 
-    # Paginate through all 5 items with a limit of 1 (expecting 5 requests)
-    request_body = {"ids": ids, "limit": 1}
+    # Prepare the initial request body with item IDs and a limit of 1
+    request_body = {"ids": expected_item_ids, "limit": 1}
+
+    # Perform the initial POST request to start pagination
    page = await app_client.post("/search", json=request_body)
-    idx = 0
-    item_ids = []
+
+    retrieved_item_ids = []
+    request_count = 0
     for _ in range(100):
-        idx += 1
+        request_count += 1
         page_data = page.json()
+
+        # Extract the next link from the page data
         next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"]))
+
+        # If there is no next link, exit the loop
         if not next_link:
             break
 
-        item_ids.append(page_data["features"][0]["id"])
+        # Retrieve the ID of the first item on the current page and add it to the list
+        retrieved_item_ids.append(page_data["features"][0]["id"])
 
-        # Merge request bodies
+        # Update the request body with the parameters from the next link
         request_body.update(next_link[0]["body"])
+
+        # Perform the next POST request using the updated request body
         page = await app_client.post("/search", json=request_body)
 
-    # Our limit is 1, so we expect len(ids) number of requests before we run out of pages
-    assert idx == len(ids) + 1
+    # Our limit is 1, so we expect len(expected_item_ids) number of requests before we run out of pages
+    assert request_count == len(expected_item_ids)
 
-    # Confirm we have paginated through all items
-    assert not set(item_ids) - set(ids)
+    # Confirm we have paginated through all items by comparing the expected and retrieved item IDs
+    assert not set(retrieved_item_ids) - set(expected_item_ids)
 
 
-@pytest.mark.skip(reason="fix pagination in mongo")
 @pytest.mark.asyncio
 async def test_pagination_token_idempotent(app_client, ctx, txn_client):
     """Test that pagination tokens are idempotent (paging extension)"""
-    ids = [ctx.item["id"]]
+    # Initialize a list to store the expected item IDs
+    expected_item_ids = [ctx.item["id"]]
 
-    # Ingest 5 items
+    # Ingest 5 items in addition to the default test-item
     for _ in range(5):
         ctx.item["id"] = str(uuid.uuid4())
-        await create_item(txn_client, ctx.item)
-        ids.append(ctx.item["id"])
+        await create_item(txn_client, ctx.item)
+        expected_item_ids.append(ctx.item["id"])
 
-    page = await app_client.get("/search", params={"ids": ",".join(ids), "limit": 3})
+    # Perform the initial GET request to start pagination with a limit of 3
+    page = await app_client.get(
+        "/search", params={"ids": ",".join(expected_item_ids), "limit": 3}
+    )
     page_data = page.json()
     next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"]))
 
+    # Extract the pagination token from the next link
+    pagination_token = parse_qs(urlparse(next_link[0]["href"]).query)
+
     # Confirm token is idempotent
-    resp1 = await app_client.get(
-        "/search", params=parse_qs(urlparse(next_link[0]["href"]).query)
-    )
-    resp2 = await app_client.get(
-        "/search", params=parse_qs(urlparse(next_link[0]["href"]).query)
-    )
+    resp1 = await app_client.get("/search", params=pagination_token)
+    resp2 = await app_client.get("/search", params=pagination_token)
 
     resp1_data = resp1.json()
     resp2_data = resp2.json()
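Because a token decodes to an absolute offset, replaying it always issues the same `skip`/`limit` query, which is the property `test_pagination_token_idempotent` exercises by sending the same next-link parameters twice. A hypothetical client-side walk mirroring the GET tests, assuming an `httpx.AsyncClient`-style `client` such as the tests' `app_client` fixture (base URL `http://test-server`):

```python
import httpx


async def fetch_all_items(client: httpx.AsyncClient, collection_id: str) -> list:
    """Follow rel="next" links until the server stops issuing a token."""
    features = []
    url = f"/collections/{collection_id}/items?limit=1"
    while True:
        resp = await client.get(url)
        data = resp.json()
        features.extend(data["features"])
        next_links = [link for link in data["links"] if link["rel"] == "next"]
        if not next_links:
            return features
        # The token inside the next link encodes an absolute offset, so
        # requesting the same href twice always returns the same page.
        url = next_links[0]["href"][len("http://test-server"):]
```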