Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rate limit implementation #303

Merged
merged 10 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Added

- Added `datetime_frequency_interval` parameter for `datetime_frequency` aggregation. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294)
- Added rate limiting functionality with configurable limits using environment variable `STAC_FASTAPI_RATE_LIMIT`, example: `500/minute`. [#303](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/303)

### Changed

Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,4 +383,8 @@ Available aggregations are:
- geometry_geohash_grid_frequency ([geohash grid](https://opensearch.org/docs/latest/aggregations/bucket/geohash-grid/) on Item.geometry)
- geometry_geotile_grid_frequency ([geotile grid](https://opensearch.org/docs/latest/aggregations/bucket/geotile-grid/) on Item.geometry)

Support for additional fields and new aggregations can be added in the associated `database_logic.py` file.
Support for additional fields and new aggregations can be added in the associated `database_logic.py` file.

## Rate Limiting

Rate limiting is an optional security feature that controls API request frequency on a remote address basis. It's enabled by setting the `STAC_FASTAPI_RATE_LIMIT` environment variable, e.g., `500/minute`. This limits each client to 500 requests per minute, helping prevent abuse and maintain API stability. Implementation examples are available in the [examples/rate_limit](examples/rate_limit) directory.
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ services:
- ES_USE_SSL=false
- ES_VERIFY_CERTS=false
- BACKEND=opensearch
- STAC_FASTAPI_RATE_LIMIT=200/minute
ports:
- "8082:8082"
volumes:
Expand Down
94 changes: 94 additions & 0 deletions examples/rate_limit/docker-compose.rate_limit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
version: '3.9'

services:
app-elasticsearch:
container_name: stac-fastapi-es
image: stac-utils/stac-fastapi-es
restart: always
build:
context: .
dockerfile: dockerfiles/Dockerfile.dev.es
environment:
- STAC_FASTAPI_TITLE=stac-fastapi-elasticsearch
- STAC_FASTAPI_DESCRIPTION=A STAC FastAPI with an Elasticsearch backend
- STAC_FASTAPI_VERSION=2.1
- APP_HOST=0.0.0.0
- APP_PORT=8080
- RELOAD=true
- ENVIRONMENT=local
- WEB_CONCURRENCY=10
- ES_HOST=elasticsearch
- ES_PORT=9200
- ES_USE_SSL=false
- ES_VERIFY_CERTS=false
- BACKEND=elasticsearch
- STAC_FASTAPI_RATE_LIMIT=500/minute
ports:
- "8080:8080"
volumes:
- ./stac_fastapi:/app/stac_fastapi
- ./scripts:/app/scripts
- ./esdata:/usr/share/elasticsearch/data
depends_on:
- elasticsearch
command:
bash -c "./scripts/wait-for-it-es.sh es-container:9200 && python -m stac_fastapi.elasticsearch.app"

app-opensearch:
container_name: stac-fastapi-os
image: stac-utils/stac-fastapi-os
restart: always
build:
context: .
dockerfile: dockerfiles/Dockerfile.dev.os
environment:
- STAC_FASTAPI_TITLE=stac-fastapi-opensearch
- STAC_FASTAPI_DESCRIPTION=A STAC FastAPI with an Opensearch backend
- STAC_FASTAPI_VERSION=3.0.0a2
- APP_HOST=0.0.0.0
- APP_PORT=8082
- RELOAD=true
- ENVIRONMENT=local
- WEB_CONCURRENCY=10
- ES_HOST=opensearch
- ES_PORT=9202
- ES_USE_SSL=false
- ES_VERIFY_CERTS=false
- BACKEND=opensearch
- STAC_FASTAPI_RATE_LIMIT=200/minute
ports:
- "8082:8082"
volumes:
- ./stac_fastapi:/app/stac_fastapi
- ./scripts:/app/scripts
- ./osdata:/usr/share/opensearch/data
depends_on:
- opensearch
command:
bash -c "./scripts/wait-for-it-es.sh os-container:9202 && python -m stac_fastapi.opensearch.app"

elasticsearch:
container_name: es-container
image: docker.elastic.co/elasticsearch/elasticsearch:${ELASTICSEARCH_VERSION:-8.11.0}
hostname: elasticsearch
environment:
ES_JAVA_OPTS: -Xms512m -Xmx1g
volumes:
- ./elasticsearch/config/elasticsearch.yml:/usr/share/elasticsearch/config/elasticsearch.yml
- ./elasticsearch/snapshots:/usr/share/elasticsearch/snapshots
ports:
- "9200:9200"

opensearch:
container_name: os-container
image: opensearchproject/opensearch:${OPENSEARCH_VERSION:-2.11.1}
hostname: opensearch
environment:
- discovery.type=single-node
- plugins.security.disabled=true
- OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m
volumes:
- ./opensearch/config/opensearch.yml:/usr/share/opensearch/config/opensearch.yml
- ./opensearch/snapshots:/usr/share/opensearch/snapshots
ports:
- "9202:9202"
1 change: 1 addition & 0 deletions stac_fastapi/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"pygeofilter==0.2.1",
"typing_extensions==4.8.0",
"jsonschema",
"slowapi==0.1.9",
]

setup(
Expand Down
44 changes: 44 additions & 0 deletions stac_fastapi/core/stac_fastapi/core/rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Rate limiting middleware."""

import logging
import os
from typing import Optional

from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address

logger = logging.getLogger(__name__)


def get_limiter(key_func=get_remote_address):
"""Create and return a Limiter instance for rate limiting."""
return Limiter(key_func=key_func)


def setup_rate_limit(
app: FastAPI, rate_limit: Optional[str] = None, key_func=get_remote_address
):
"""Set up rate limiting middleware."""
RATE_LIMIT = rate_limit or os.getenv("STAC_FASTAPI_RATE_LIMIT")

if not RATE_LIMIT:
logger.info("Rate limiting is disabled")
return

logger.info(f"Setting up rate limit with RATE_LIMIT={RATE_LIMIT}")

limiter = get_limiter(key_func)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

@app.middleware("http")
@limiter.limit(RATE_LIMIT)
async def rate_limit_middleware(request: Request, call_next):
response = await call_next(request)
return response

logger.info("Rate limit setup complete")
4 changes: 4 additions & 0 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EsAsyncAggregationClient,
)
from stac_fastapi.core.extensions.fields import FieldsExtension
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies
from stac_fastapi.core.session import Session
from stac_fastapi.elasticsearch.config import ElasticsearchSettings
Expand Down Expand Up @@ -97,6 +98,9 @@
app = api.app
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")

# Add rate limit
setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT"))


@app.on_event("startup")
async def _startup_event() -> None:
Expand Down
4 changes: 4 additions & 0 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EsAsyncAggregationClient,
)
from stac_fastapi.core.extensions.fields import FieldsExtension
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies
from stac_fastapi.core.session import Session
from stac_fastapi.extensions.core import (
Expand Down Expand Up @@ -97,6 +98,9 @@
app = api.app
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")

# Add rate limit
setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT"))


@app.on_event("startup")
async def _startup_event() -> None:
Expand Down
4 changes: 2 additions & 2 deletions stac_fastapi/tests/basic_auth/test_basic_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def test_get_search_not_authenticated(app_client_basic_auth, ctx):

@pytest.mark.asyncio
async def test_post_search_authenticated(app_client_basic_auth, ctx):
"""Test protected endpoint [POST /search] with reader auhtentication"""
"""Test protected endpoint [POST /search] with reader authentication"""
if not os.getenv("BASIC_AUTH"):
pytest.skip()
params = {"id": ctx.item["id"]}
Expand All @@ -34,7 +34,7 @@ async def test_post_search_authenticated(app_client_basic_auth, ctx):
async def test_delete_resource_anonymous(
app_client_basic_auth,
):
"""Test protected endpoint [DELETE /collections/{collection_id}] without auhtentication"""
"""Test protected endpoint [DELETE /collections/{collection_id}] without authentication"""
if not os.getenv("BASIC_AUTH"):
pytest.skip()

Expand Down
60 changes: 60 additions & 0 deletions stac_fastapi/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
EsAggregationExtensionPostRequest,
EsAsyncAggregationClient,
)
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies

if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch":
Expand Down Expand Up @@ -246,6 +247,65 @@ async def app_client(app):
yield c


@pytest_asyncio.fixture(scope="session")
async def app_rate_limit():
settings = AsyncSettings()

aggregation_extension = AggregationExtension(
client=EsAsyncAggregationClient(
database=database, session=None, settings=settings
)
)
aggregation_extension.POST = EsAggregationExtensionPostRequest
aggregation_extension.GET = EsAggregationExtensionGetRequest

search_extensions = [
TransactionExtension(
client=TransactionsClient(
database=database, session=None, settings=settings
),
settings=settings,
),
SortExtension(),
FieldsExtension(),
QueryExtension(),
TokenPaginationExtension(),
FilterExtension(),
FreeTextExtension(),
]

extensions = [aggregation_extension] + search_extensions

post_request_model = create_post_request_model(search_extensions)

app = StacApi(
settings=settings,
client=CoreClient(
database=database,
session=None,
extensions=extensions,
post_request_model=post_request_model,
),
extensions=extensions,
search_get_request_model=create_get_request_model(search_extensions),
search_post_request_model=post_request_model,
).app

# Set up rate limit
setup_rate_limit(app, rate_limit="2/minute")

return app


@pytest_asyncio.fixture(scope="session")
async def app_client_rate_limit(app_rate_limit):
await create_index_templates()
await create_collection_index()

async with AsyncClient(app=app_rate_limit, base_url="http://test-server") as c:
yield c


@pytest_asyncio.fixture(scope="session")
async def app_basic_auth():

Expand Down
38 changes: 38 additions & 0 deletions stac_fastapi/tests/rate_limit/test_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

import pytest
from httpx import AsyncClient
from slowapi.errors import RateLimitExceeded

logger = logging.getLogger(__name__)


@pytest.mark.asyncio
async def test_rate_limit(app_client_rate_limit: AsyncClient, ctx):
expected_status_codes = [200, 200, 429, 429, 429]

for i, expected_status_code in enumerate(expected_status_codes):
try:
response = await app_client_rate_limit.get("/collections")
status_code = response.status_code
except RateLimitExceeded:
status_code = 429

logger.info(f"Request {i+1}: Status code {status_code}")
assert (
status_code == expected_status_code
), f"Expected status code {expected_status_code}, but got {status_code}"


@pytest.mark.asyncio
async def test_rate_limit_no_limit(app_client: AsyncClient, ctx):
expected_status_codes = [200, 200, 200, 200, 200]

for i, expected_status_code in enumerate(expected_status_codes):
response = await app_client.get("/collections")
status_code = response.status_code

logger.info(f"Request {i+1}: Status code {status_code}")
assert (
status_code == expected_status_code
), f"Expected status code {expected_status_code}, but got {status_code}"