
Return run IDs from invoke/batch endpoints #148

Merged 7 commits, Nov 1, 2023

38 changes: 38 additions & 0 deletions langserve/schema.py
@@ -1,3 +1,5 @@
from typing import List

try:
    from pydantic.v1 import BaseModel
except ImportError:
@@ -21,3 +23,39 @@ class CustomUserType(BaseModel):
    the server will keep the decoded type as a pydantic model instead
    of converting it into a dict.
    """


class SharedResponseMetadata(BaseModel):
    """
    Any response metadata should inherit from this class. Response metadata
    carries non-output data that may be useful to some clients but can be
    ignored by most; for example, the run_ids associated with each run
    kicked off by the associated request.

    SharedResponseMetadata represents metadata that is shared across all
    outputs in a single LangServe response.
    """

    pass


class SingletonResponseMetadata(SharedResponseMetadata):
    """
    Response metadata for single input/output LangServe responses.
    """

    # The parent run id for a given request
    run_id: str


class BatchResponseMetadata(SharedResponseMetadata):
    """
    Response metadata for batched input/output LangServe responses.
    """

    # The parent run id for each request in the batch, in the same
    # order in which the inputs were received
    run_ids: List[str]
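
To make the wire format concrete, here is a minimal sketch of what these models serialize to (the UUID strings are made-up placeholders, and it assumes langserve is importable):

# Illustrative only: construct the new metadata models and inspect their
# serialized form. The UUIDs below are placeholder values, not real run IDs.
from langserve.schema import BatchResponseMetadata, SingletonResponseMetadata

single = SingletonResponseMetadata(run_id="5f3c0dba-7f0c-4d1e-9a0a-1b2c3d4e5f6a")
print(single.dict())
# {'run_id': '5f3c0dba-7f0c-4d1e-9a0a-1b2c3d4e5f6a'}

batch = BatchResponseMetadata(
    run_ids=[
        "5f3c0dba-7f0c-4d1e-9a0a-1b2c3d4e5f6a",
        "9a8b7c6d-5e4f-4a3b-8c2d-1e0f9a8b7c6d",
    ]
)
print(batch.dict())
# {'run_ids': ['5f3c0dba-...', '9a8b7c6d-...']}
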
29 changes: 28 additions & 1 deletion langserve/server.py
@@ -34,7 +34,11 @@

from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict
from langserve.lzstring import LZString
from langserve.schema import CustomUserType
from langserve.schema import (
    BatchResponseMetadata,
    CustomUserType,
    SingletonResponseMetadata,
)

try:
    from pydantic.v1 import BaseModel, create_model

@@ -292,6 +296,23 @@ def _with_validation_error_translation() -> Generator[None, None, None]:
        raise RequestValidationError(e.errors(), body=e.model)


def _get_base_run_id_as_str(
    event_aggregator: AsyncEventAggregatorCallback,
) -> str:
    """
    Uses `event_aggregator` to determine the base run ID for a given run.
    Returns the run_id as a string; raises an AssertionError if no run_id
    can be found.
    """
    # The first run in the callback_events list corresponds to the
    # overall trace for the request
    if event_aggregator.callback_events and event_aggregator.callback_events[0].get(
        "run_id"
    ):
        return str(event_aggregator.callback_events[0].get("run_id"))
    else:
        raise AssertionError("No run_id found for the given run")


# PUBLIC API


@@ -494,6 +515,9 @@ async def invoke(
            # Callbacks are scrubbed and exceptions are converted to a
            # serializable format before being returned in the response.
            callback_events=callback_events,
            metadata=SingletonResponseMetadata(
                run_id=_get_base_run_id_as_str(event_aggregator)
            ),
        )

    @app.post(
@@ -585,6 +609,9 @@ async def batch(
        return BatchResponse(
            output=well_known_lc_serializer.dumpd(output),
            callback_events=callback_events,
            metadata=BatchResponseMetadata(
                run_ids=[_get_base_run_id_as_str(agg) for agg in aggregators]
            ),
        )

    @app.post(
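For context on how a client consumes this, here is a hypothetical client-side sketch (not part of this PR; the URL and payload are assumptions) that reads the parent run_id out of an /invoke response:

# Hypothetical client: call a locally running LangServe app and pull the
# parent run id out of the response metadata added by this PR.
import requests

response = requests.post(
    "http://localhost:8000/invoke",
    json={"input": "hello"},
)
body = response.json()
run_id = body["metadata"]["run_id"]  # parent run id for this request
print(f"Server traced this request as run {run_id}")
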
21 changes: 21 additions & 0 deletions langserve/validation.py
@@ -27,6 +27,8 @@
    RunInfo,
)

from langserve.schema import BatchResponseMetadata, SingletonResponseMetadata

try:
    from pydantic.v1 import BaseModel, Field, create_model
except ImportError:

@@ -207,6 +209,16 @@ def create_invoke_response_model(
            List[CallbackEvent],
            Field(..., description="Callback events generated by the server side."),
        ),
        metadata=(
            SingletonResponseMetadata,
            Field(
                ...,
                description=(
                    "Metadata about the response that may be useful to "
                    "specific clients"
                ),
            ),
        ),
    )
    invoke_response_type.update_forward_refs()
    return invoke_response_type
@@ -241,6 +253,15 @@ def create_batch_response_model(
                ),
            ),
        ),
        metadata=(
            BatchResponseMetadata,
            Field(
                ...,
                description=(
                    "Metadata about the response that may be useful to specific clients"
                ),
            ),
        ),
    )
    batch_response_type.update_forward_refs()
    return batch_response_type
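The validation code above leans on pydantic's create_model to build response classes at runtime with `metadata` as a required field. As a self-contained sketch of that pattern (the names here are illustrative, not langserve's):

# Minimal sketch of the dynamic-model pattern used in validation.py,
# mirroring the try/except import used above.
try:
    from pydantic.v1 import BaseModel, Field, create_model
except ImportError:
    from pydantic import BaseModel, Field, create_model


class ExampleMetadata(BaseModel):
    run_id: str


ExampleResponse = create_model(
    "ExampleResponse",
    output=(str, Field(..., description="The run output.")),
    metadata=(ExampleMetadata, Field(..., description="Response metadata.")),
)

resp = ExampleResponse(output="ok", metadata={"run_id": "abc"})
print(resp.dict())  # {'output': 'ok', 'metadata': {'run_id': 'abc'}}
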
29 changes: 29 additions & 0 deletions tests/unit_tests/test_server_client.py
@@ -1,6 +1,7 @@
"""Test the server and client together."""
import asyncio
import json
import uuid
from asyncio import AbstractEventLoop
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Dict, Iterator, List, Optional, Union
@@ -1438,3 +1439,31 @@ async def test_using_router() -> None:
    )

    app.include_router(router)


def _is_valid_uuid(uuid_as_str: str) -> bool:
    try:
        uuid.UUID(str(uuid_as_str))
        return True
    except ValueError:
        return False


@pytest.mark.asyncio
async def test_invoke_returns_run_id(app: FastAPI) -> None:
    """Test that /invoke returns a valid run ID in the response metadata."""
    async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
        response = await async_client.post("/invoke", json={"input": 1})
        run_id = response.json()["metadata"]["run_id"]
        assert _is_valid_uuid(run_id)
Contributor (author): These tests are somewhat weak, since they don't actually match up the proper run ID; they just assert that some run ID was returned. Does anyone know how to mock out the generation of run_ids for a given runnable run? If I could do that, the tests would be much stronger, but I'm not sure how to reach into the underlying infrastructure like that.

Collaborator: There's a fake tracer (FakeTracer). Check whether it does what's needed; if not, we could use mock.patch (though that's brittle).

Collaborator: Took a quick look at the code; FakeTracer won't help here. I think the current test is fine.
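
For reference, the mock.patch idea floated above might look roughly like this (an untested, editorial sketch; the right patch target depends on where uuid4 is imported, which is exactly why the approach is brittle):

# Sketch only: pin uuid4 to a fixed value so the run id in the response
# can be asserted exactly. If the traced code does `from uuid import uuid4`,
# patching `uuid.uuid4` will not affect the already-bound name, so the
# correct patch target is module-specific.
import uuid
from unittest.mock import patch

FIXED_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")

with patch("uuid.uuid4", return_value=FIXED_ID):
    ...  # issue the /invoke request here
    # assert response.json()["metadata"]["run_id"] == str(FIXED_ID)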



@pytest.mark.asyncio
async def test_batch_returns_run_id(app: FastAPI) -> None:
    """Test that /batch returns one valid run ID per input in the metadata."""
    async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
        response = await async_client.post("/batch", json={"inputs": [1, 2]})
        run_ids = response.json()["metadata"]["run_ids"]
        assert len(run_ids) == 2
        for run_id in run_ids:
            assert _is_valid_uuid(run_id)
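
As a closing note on the motivation: with run IDs surfaced to clients, a caller can associate downstream signals with the exact server-side trace. A hedged sketch of that follow-up use, assuming the langsmith SDK and its create_feedback method (neither is part of this diff):

# Hypothetical follow-up: attach user feedback to a returned run id.
from langsmith import Client

run_id = "5f3c0dba-7f0c-4d1e-9a0a-1b2c3d4e5f6a"  # from response["metadata"]["run_id"]
client = Client()  # assumes LANGCHAIN_API_KEY is set in the environment
client.create_feedback(run_id, "user_score", score=1)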