diff --git a/CHANGES.md b/CHANGES.md
index b0ab199..c2b4ee4 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -5,7 +5,9 @@
 * add `/collections/{collection_id}/items/{item_id}/assets/{asset_id}` optional endpoints (`TITILER_PGSTAC_API_ENABLE_ASSETS_ENDPOINTS=TRUE|FALSE`)
 * add `/external` optional endpoints (`TITILER_PGSTAC_API_ENABLE_EXTERNAL_DATASET_ENDPOINTS=TRUE|FALSE`)
 * add `cachecontrol_exclude_paths` attribute in `ApiSettings` to let users decide if some path should not have cache-control headers (defaults is to exclude `/list`)
-* Add PgstacSettings such that the user can provide their own default settings for PgSTAC search
+* add `PgstacSettings` such that the user can provide their own default settings for PgSTAC search
+* add check for pgstac `read-only` mode and raise `ReadOnlyPgSTACError` error when trying to write to the pgstac instance
+* add `/pgstac` endpoint in the application (when `TITILER_PGSTAC_API_DEBUG=TRUE`)
 
 ## 1.3.1 (2024-08-01)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index c177e48..4b58541 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -59,41 +59,53 @@ def database(postgresql_proc):
         password="a2Vw:yk=)CdSis[fek]tW=/o",
     ) as jan:
         connection = f"postgresql://{jan.user}:{quote(jan.password)}@{jan.host}:{jan.port}/{jan.dbname}"
-        # make sure the DB is set to use UTC
         with psycopg.connect(connection) as conn:
             with conn.cursor() as cur:
                 cur.execute(f"ALTER DATABASE {jan.dbname} SET TIMEZONE='UTC';")
 
-        print("Running to PgSTAC migration...")
-        with PgstacDB(dsn=connection) as db:
-            migrator = Migrate(db)
-            version = migrator.run_migration()
-            assert version
-            print(f"PgSTAC version: {version}")
-
-            print("Load items and collection into PgSTAC")
-            loader = Loader(db=db)
-            loader.load_collections(collection)
-            loader.load_collections(collection_maxar)
-            loader.load_items(items)
-
-        # Make sure we have 1 collection and 163 items in pgstac
-        with psycopg.connect(connection) as conn:
-            with conn.cursor() as cur:
-                cur.execute("SELECT COUNT(*) FROM pgstac.collections")
-                val = cur.fetchone()[0]
-                assert val == 2
+        yield jan
 
-                cur.execute("SELECT COUNT(*) FROM pgstac.items")
-                val = cur.fetchone()[0]
-                assert val == 163
 
-        yield jan
+@pytest.fixture(scope="session")
+def pgstac(database):
+    """Create PgSTAC fixture."""
+    connection = f"postgresql://{database.user}:{quote(database.password)}@{database.host}:{database.port}/{database.dbname}"
+
+    # Clear PgSTAC
+    with psycopg.connect(connection) as conn:
+        with conn.cursor() as cur:
+            cur.execute("DROP SCHEMA IF EXISTS pgstac CASCADE;")
+
+    print("Running PgSTAC migration...")
+    with PgstacDB(dsn=connection) as db:
+        migrator = Migrate(db)
+        version = migrator.run_migration()
+        assert version
+        print(f"PgSTAC version: {version}")
+
+        print("Load items and collections into PgSTAC")
+        loader = Loader(db=db)
+        loader.load_collections(collection)
+        loader.load_collections(collection_maxar)
+        loader.load_items(items)
+
+    # Make sure we have 2 collections and 163 items in pgstac
+    with psycopg.connect(connection) as conn:
+        with conn.cursor() as cur:
+            cur.execute("SELECT COUNT(*) FROM pgstac.collections")
+            val = cur.fetchone()[0]
+            assert val == 2
+
+            cur.execute("SELECT COUNT(*) FROM pgstac.items")
+            val = cur.fetchone()[0]
+            assert val == 163
+
+    yield connection
 
 
 @pytest.fixture(autouse=True)
-def app(database, monkeypatch):
+def app(pgstac, monkeypatch):
     """Create app with connection to the pytest database."""
     monkeypatch.setenv("AWS_ACCESS_KEY_ID", "jqt")
     monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "rde")
@@ -104,10 +116,7 @@ def app(database, monkeypatch):
     monkeypatch.setenv("TITILER_PGSTAC_API_DEBUG", "TRUE")
     monkeypatch.setenv("TITILER_PGSTAC_API_ENABLE_ASSETS_ENDPOINTS", "TRUE")
    monkeypatch.setenv("TITILER_PGSTAC_API_ENABLE_EXTERNAL_DATASET_ENDPOINTS", "TRUE")
-    monkeypatch.setenv(
-        "DATABASE_URL",
-        f"postgresql://{database.user}:{quote(database.password)}@{database.host}:{database.port}/{database.dbname}",
-    )
+    monkeypatch.setenv("DATABASE_URL", pgstac)
 
     from titiler.pgstac.main import app
 
diff --git a/tests/test_items.py b/tests/test_items.py
index fae82e3..501be74 100644
--- a/tests/test_items.py
+++ b/tests/test_items.py
@@ -4,13 +4,13 @@
 
 import pystac
 
-from titiler.pgstac.dependencies import get_stac_item
-
 from .conftest import mock_rasterio_open
 
 
 def test_get_stac_item(app):
     """test get_stac_item."""
+    from titiler.pgstac.dependencies import get_stac_item
+
     item = get_stac_item(
         app.app.state.dbpool, "noaa-emergency-response", "20200307aC0853900w361030"
     )
diff --git a/tests/test_readonly.py b/tests/test_readonly.py
new file mode 100644
index 0000000..36e85dd
--- /dev/null
+++ b/tests/test_readonly.py
@@ -0,0 +1,201 @@
+"""test read-only pgstac instance."""
+
+import os
+from contextlib import asynccontextmanager
+from urllib.parse import quote_plus as quote
+
+import psycopg
+import pytest
+from fastapi import FastAPI
+from psycopg.rows import class_row, dict_row
+from pypgstac.db import PgstacDB
+from pypgstac.load import Loader
+from pypgstac.migrate import Migrate
+from starlette.requests import Request
+from starlette.testclient import TestClient
+
+from titiler.pgstac.errors import ReadOnlyPgSTACError
+from titiler.pgstac.model import Metadata, PgSTACSearch, Search
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "fixtures")
+collection = os.path.join(DATA_DIR, "noaa-emergency-response.json")
+items = os.path.join(DATA_DIR, "noaa-eri-nashville2020.json")
+
+
+@pytest.fixture(
+    params=[
+        True,
+        False,
+    ],
+    scope="session",
+)
+def pgstac_ro(request, database):
+    """Create PgSTAC fixture."""
+    read_only = request.param
+
+    connection = f"postgresql://{database.user}:{quote(database.password)}@{database.host}:{database.port}/{database.dbname}"
+    with psycopg.connect(connection) as conn:
+        with conn.cursor() as cur:
+            cur.execute("DROP SCHEMA IF EXISTS pgstac CASCADE;")
+
+    with PgstacDB(dsn=connection) as db:
+        migrator = Migrate(db)
+        version = migrator.run_migration()
+        assert version
+
+        loader = Loader(db=db)
+        loader.load_collections(collection)
+        loader.load_items(items)
+
+    # register one search
+    with psycopg.connect(
+        connection,
+        options="-c search_path=pgstac,public -c application_name=pgstac",
+    ) as conn:
+        with conn.cursor(row_factory=dict_row) as cursor:
+            search = PgSTACSearch(collections=["noaa-emergency-response"])
+            metadata = Metadata(name="noaa-emergency-response")
+            cursor.row_factory = class_row(Search)
+            cursor.execute(
+                "SELECT * FROM search_query(%s, _metadata => %s);",
+                (
+                    search.model_dump_json(by_alias=True, exclude_none=True),
+                    metadata.model_dump_json(exclude_none=True),
+                ),
+            )
+            cursor.fetchone()
+
+    if read_only:
+        with psycopg.connect(connection) as conn:
+            with conn.cursor() as cur:
+                cur.execute(
+                    "UPDATE pgstac.pgstac_settings SET value = true WHERE name = 'readonly';"
+                )
+
+    yield connection, read_only
+
+
+@pytest.fixture(autouse=True)
+def app_ro(pgstac_ro, monkeypatch):
+    """create app fixture."""
+    monkeypatch.setenv("TITILER_PGSTAC_CACHE_DISABLE", "TRUE")
+
+    dsn, ro = pgstac_ro
+
+    from titiler.pgstac.db import close_db_connection, connect_to_db
+    from titiler.pgstac.dependencies import CollectionIdParams, SearchIdParams
+    from titiler.pgstac.extensions import searchInfoExtension
+    from titiler.pgstac.factory import (
+        MosaicTilerFactory,
+        add_search_list_route,
+        add_search_register_route,
+    )
+    from titiler.pgstac.settings import PostgresSettings
+
+    postgres_settings = PostgresSettings(database_url=dsn)
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        """FastAPI Lifespan."""
+        # Create Connection Pool
+        await connect_to_db(app, settings=postgres_settings)
+        yield
+        # Close the Connection Pool
+        await close_db_connection(app)
+
+    app = FastAPI(lifespan=lifespan)
+    app.state.readonly = ro
+
+    @app.get("/pgstac")
+    def pgstac_info(request: Request):
+        """Retrieve PgSTAC Info."""
+        with request.app.state.dbpool.connection() as conn:
+            with conn.cursor(row_factory=dict_row) as cursor:
+                cursor.execute("SELECT pgstac.readonly()")
+                pgstac_readonly = cursor.fetchone()["readonly"]
+
+        return {
+            "pgstac_readonly": pgstac_readonly,
+        }
+
+    searches = MosaicTilerFactory(
+        path_dependency=SearchIdParams,
+        router_prefix="/searches/{search_id}",
+        extensions=[
+            searchInfoExtension(),
+        ],
+    )
+    app.include_router(searches.router, prefix="/searches/{search_id}")
+    add_search_register_route(
+        app,
+        prefix="/searches",
+        tile_dependencies=[
+            searches.layer_dependency,
+            searches.dataset_dependency,
+            searches.pixel_selection_dependency,
+            searches.process_dependency,
+            searches.rescale_dependency,
+            searches.colormap_dependency,
+            searches.render_dependency,
+            searches.pgstac_dependency,
+            searches.reader_dependency,
+            searches.backend_dependency,
+        ],
+    )
+    add_search_list_route(app, prefix="/searches", tags=["STAC Search"])
+
+    collection = MosaicTilerFactory(
+        path_dependency=CollectionIdParams,
+        router_prefix="/collections/{collection_id}",
+        extensions=[
+            searchInfoExtension(),
+        ],
+    )
+    app.include_router(collection.router, prefix="/collections/{collection_id}")
+
+    with TestClient(app) as app:
+        yield app, ro
+
+
+def test_pgstac_config(app_ro):
+    """should return pgstac read-only info."""
+    client, ro = app_ro
+
+    response = client.get("/pgstac")
+    assert response.status_code == 200
+    assert response.json()["pgstac_readonly"] == ro
+
+
+def test_searches_ro(app_ro):
+    """Register Search should only work for non-read-only pgstac."""
+    client, ro = app_ro
+
+    response = client.get("/searches/list", params={"limit": 1})
+    assert response.status_code == 200
+    resp = response.json()
+    assert resp["context"]["matched"] == 1
+    search_id = resp["searches"][0]["search"]["hash"]
+
+    response = client.get(f"/searches/{search_id}/info")
+    assert response.status_code == 200
+
+    if ro:
+        with pytest.raises(ReadOnlyPgSTACError):
+            client.post("/searches/register", json={"collections": ["collection"]})
+    else:
+        response = client.post(
+            "/searches/register", json={"collections": ["collection"]}
+        )
+        assert response.status_code == 200
+
+
+def test_collections_ro(app_ro):
+    """collections should only work for non-read-only pgstac."""
+    client, ro = app_ro
+
+    if ro:
+        with pytest.raises(ReadOnlyPgSTACError):
+            client.get("/collections/noaa-emergency-response/info")
+    else:
+        response = client.get("/collections/noaa-emergency-response/info")
+        assert response.status_code == 200
diff --git a/titiler/pgstac/dependencies.py b/titiler/pgstac/dependencies.py
index d91f224..45921b1 100644
--- a/titiler/pgstac/dependencies.py
+++ b/titiler/pgstac/dependencies.py
@@ -18,6 +18,7 @@
 from titiler.core.dependencies import DefaultDependency
 
 from titiler.pgstac import model
+from titiler.pgstac.errors import ReadOnlyPgSTACError
 from titiler.pgstac.settings import CacheSettings, RetrySettings
 from titiler.pgstac.utils import retry
 
@@ -100,6 +101,14 @@ def get_collection_id(pool: ConnectionPool, collection_id: str) -> str:  # noqa:
 
             metadata.defaults = renders
 
+            # TODO: adapt Mosaic Backend to accept Search object directly
+            # we technically don't need to register the search request for /collections
+            cursor.execute("SELECT pgstac.readonly()")
+            if cursor.fetchone()["readonly"]:
+                raise ReadOnlyPgSTACError(
+                    "PgSTAC instance is set to `read-only`, cannot register search query."
+                )
+
             cursor.row_factory = class_row(model.Search)
             cursor.execute(
                 "SELECT * FROM search_query(%s, _metadata => %s);",
diff --git a/titiler/pgstac/errors.py b/titiler/pgstac/errors.py
new file mode 100644
index 0000000..71ac769
--- /dev/null
+++ b/titiler/pgstac/errors.py
@@ -0,0 +1,12 @@
+"""titiler.pgstac errors."""
+
+from starlette import status
+
+from titiler.core.errors import TilerError
+
+
+class ReadOnlyPgSTACError(TilerError):
+    """Cannot Write to PgSTAC Database."""
+
+
+PGSTAC_STATUS_CODES = {ReadOnlyPgSTACError: status.HTTP_500_INTERNAL_SERVER_ERROR}
diff --git a/titiler/pgstac/factory.py b/titiler/pgstac/factory.py
index 4d88bf4..d2c6a80 100644
--- a/titiler/pgstac/factory.py
+++ b/titiler/pgstac/factory.py
@@ -25,7 +25,7 @@
 from fastapi.dependencies.utils import get_dependant, request_params_to_args
 from geojson_pydantic import Feature, FeatureCollection
 from psycopg import sql
-from psycopg.rows import class_row
+from psycopg.rows import class_row, dict_row
 from pydantic import conint
 from rio_tiler.constants import MAX_THREADS, WGS84_CRS
 from rio_tiler.mosaic.methods.base import MosaicMethodBase
@@ -62,6 +62,7 @@
     SearchParams,
     TmsTileParams,
 )
+from titiler.pgstac.errors import ReadOnlyPgSTACError
 from titiler.pgstac.mosaic import PGSTACBackend
 
 MOSAIC_THREADS = int(os.getenv("MOSAIC_CONCURRENCY", MAX_THREADS))
@@ -1001,7 +1002,14 @@ def register_search(request: Request, search_query=Depends(search_dependency)):
         search, metadata = search_query
 
         with request.app.state.dbpool.connection() as conn:
-            with conn.cursor(row_factory=class_row(model.Search)) as cursor:
+            with conn.cursor(row_factory=dict_row) as cursor:
+                cursor.execute("SELECT pgstac.readonly()")
+                if cursor.fetchone()["readonly"]:
+                    raise ReadOnlyPgSTACError(
+                        "PgSTAC instance is set to `read-only`, cannot register search query."
+                    )
+
+                cursor.row_factory = class_row(model.Search)
                 cursor.execute(
                     "SELECT * FROM search_query(%s, _metadata => %s);",
                     (
diff --git a/titiler/pgstac/main.py b/titiler/pgstac/main.py
index 84858ce..05e2fef 100644
--- a/titiler/pgstac/main.py
+++ b/titiler/pgstac/main.py
@@ -38,6 +38,7 @@
     ItemIdParams,
     SearchIdParams,
 )
+from titiler.pgstac.errors import PGSTAC_STATUS_CODES
 from titiler.pgstac.extensions import searchInfoExtension
 from titiler.pgstac.factory import (
     MosaicTilerFactory,
@@ -95,6 +96,7 @@ async def lifespan(app: FastAPI):
 
 add_exception_handlers(app, DEFAULT_STATUS_CODES)
 add_exception_handlers(app, MOSAIC_STATUS_CODES)
+add_exception_handlers(app, PGSTAC_STATUS_CODES)
 
 
 # Set all CORS enabled origins
@@ -142,6 +144,22 @@ async def get_collection(request: Request, collection_id: str = Path()):
                 r = cursor.fetchone()
                 return r.get("get_collection") or {}
 
+    @app.get("/pgstac", include_in_schema=False, tags=["DEBUG"])
+    def pgstac_info(request: Request) -> Dict:
+        """Retrieve PgSTAC Info."""
+        with request.app.state.dbpool.connection() as conn:
+            with conn.cursor(row_factory=dict_row) as cursor:
+                cursor.execute("SELECT pgstac.readonly()")
+                pgstac_readonly = cursor.fetchone()["readonly"]
+
+                cursor.execute("SELECT pgstac.get_version();")
+                pgstac_version = cursor.fetchone()["get_version"]
+
+        return {
+            "pgstac_version": pgstac_version,
+            "pgstac_readonly": pgstac_readonly,
+        }
+
 
 ###############################################################################
 # STAC Search Endpoints