From c512bdb008254badea612d9924b465c2b8e4c16a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 17 Oct 2023 15:25:13 +0100 Subject: [PATCH] Add endpoints scoped to a config hash encoded in url --- Makefile | 3 + langserve/lzstring.py | 433 +++++++++++++++++++++++++ langserve/server.py | 74 ++++- poetry.lock | 90 +++-- pyproject.toml | 1 + tests/unit_tests/test_server_client.py | 72 +++- tests/unit_tests/test_validation.py | 4 +- 7 files changed, 631 insertions(+), 46 deletions(-) create mode 100644 langserve/lzstring.py diff --git a/Makefile b/Makefile index 51de14d8..a9fdd330 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,9 @@ TEST_FILE ?= tests/unit_tests/ test: poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE) +test_watch: + poetry run ptw . -- $(TEST_FILE) + ###################### # LINTING AND FORMATTING ###################### diff --git a/langserve/lzstring.py b/langserve/lzstring.py new file mode 100644 index 00000000..a6c4dd58 --- /dev/null +++ b/langserve/lzstring.py @@ -0,0 +1,433 @@ +""" +Copyright © 2017 Marcel Dancak +This work is free. You can redistribute it and/or modify it under the +terms of the Do What The Fuck You Want To Public License, Version 2, +as published by Sam Hocevar. See the COPYING file for more details. 
# Adapted from https://github.com/marcel-dancak/lz-string-python/blob/master/lzstring.py
#
# LZ-based string compression in the "lz-string" format.  Text is turned into
# a stream of variable-width dictionary codes, and the resulting bit stream is
# packed into output symbols of 6, 15 or 16 bits depending on the codec.

from typing import Callable, Dict, List, Optional, Set

keyStrBase64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
keyStrUriSafe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-$"
# Cache of reverse lookup tables, keyed by alphabet string.
baseReverseDic = {}


class Object(object):
    """Plain attribute bag built from keyword arguments.

    No longer used inside this module; kept so existing imports keep working.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


def getBaseValue(alphabet, character):
    """Return the index of ``character`` inside ``alphabet``.

    The reverse table for each alphabet is built on first use and cached in
    ``baseReverseDic``.

    Raises:
        KeyError: if ``character`` does not occur in ``alphabet``.
    """
    try:
        lookup = baseReverseDic[alphabet]
    except KeyError:
        lookup = baseReverseDic[alphabet] = {
            symbol: index for index, symbol in enumerate(alphabet)
        }
    return lookup[character]


def _to_utf16_units(text: str) -> str:
    """Return ``text`` with every code point above U+FFFF split into surrogates.

    The literal encoding below stores at most 16 bits per character, so a
    wider code point would otherwise be truncated.  Text without such code
    points is returned unchanged.
    """
    if all(ord(ch) <= 0xFFFF for ch in text):
        return text
    raw = text.encode("utf-16-le", "surrogatepass")
    return "".join(chr(raw[i] | (raw[i + 1] << 8)) for i in range(0, len(raw), 2))


def _from_utf16_units(text: str) -> str:
    """Recombine surrogate pairs in ``text`` into single code points.

    Lone surrogates are passed through untouched.
    """
    if not any("\ud800" <= ch <= "\udbff" for ch in text):
        return text  # no high surrogate, so there is nothing to pair up
    raw = text.encode("utf-16-le", "surrogatepass")
    return raw.decode("utf-16-le", "surrogatepass")


class _BitWriter:
    """Pack single bits into output symbols of ``bits_per_char`` bits each."""

    def __init__(self, bits_per_char: int, get_char: Callable[[int], str]) -> None:
        self._bits_per_char = bits_per_char
        self._get_char = get_char
        self._chars: List[str] = []
        self._value = 0
        self._position = 0

    def write(self, value: int, count: int) -> None:
        """Append the ``count`` low bits of ``value``, least significant first."""
        for _ in range(count):
            self._value = (self._value << 1) | (value & 1)
            if self._position == self._bits_per_char - 1:
                self._position = 0
                self._chars.append(self._get_char(self._value))
                self._value = 0
            else:
                self._position += 1
            value >>= 1

    def finish(self) -> str:
        """Zero-pad the pending symbol, emit it and return all output.

        One symbol is always emitted, even when no bits are pending.
        """
        while True:
            self._value <<= 1
            if self._position == self._bits_per_char - 1:
                self._chars.append(self._get_char(self._value))
                return "".join(self._chars)
            self._position += 1


class _BitReader:
    """Read bits back out of a sequence of fixed-width symbols."""

    def __init__(self, reset_value: int, get_next_value: Callable[[int], int]) -> None:
        self._reset = reset_value  # mask selecting the top bit of a symbol
        self._next = get_next_value
        self._value = get_next_value(0)
        self._mask = reset_value
        self.index = 1  # index of the next symbol to fetch

    def read(self, count: int) -> int:
        """Return the next ``count`` bits as an integer (first bit = bit 0)."""
        bits = 0
        for shift in range(count):
            hit = self._value & self._mask
            self._mask >>= 1
            if self._mask == 0:
                # Symbol exhausted: fetch the next one straight away.
                self._mask = self._reset
                self._value = self._next(self.index)
                self.index += 1
            if hit:
                bits |= 1 << shift
        return bits


def _compress(uncompressed, bitsPerChar, getCharFromInt):
    """Compress ``uncompressed`` into symbols produced by ``getCharFromInt``.

    Args:
        uncompressed: text to compress; ``None`` yields ``""``.
        bitsPerChar: number of payload bits carried by each output symbol.
        getCharFromInt: maps a ``bitsPerChar``-bit integer to its output symbol.

    Returns:
        The compressed text.

    Code points above U+FFFF are encoded as UTF-16 surrogate pairs, because a
    literal carries at most 16 bits.
    """
    if uncompressed is None:
        return ""

    writer = _BitWriter(bitsPerChar, getCharFromInt)
    dictionary: Dict[str, int] = {}
    # Characters seen in the input whose literal has not been written yet.
    pending: Set[str] = set()
    dict_size = 3  # codes 0, 1 and 2 are reserved (8-bit, 16-bit, end)
    num_bits = 2
    enlarge_in = 2  # compensates for the first entry, which should not count
    w = ""

    def tick() -> None:
        """Count one dictionary slot; widen the code size when it runs out."""
        nonlocal enlarge_in, num_bits
        enlarge_in -= 1
        if enlarge_in == 0:
            enlarge_in = 1 << num_bits
            num_bits += 1

    def emit(phrase: str) -> None:
        """Write the code for ``phrase`` (a literal on its first occurrence)."""
        if phrase in pending:
            code_point = ord(phrase[0])
            if code_point < 256:
                writer.write(0, num_bits)
                writer.write(code_point, 8)
            else:
                writer.write(1, num_bits)
                writer.write(code_point, 16)
            tick()
            pending.discard(phrase)
        else:
            writer.write(dictionary[phrase], num_bits)
        tick()

    for ch in _to_utf16_units(uncompressed):
        if ch not in dictionary:
            dictionary[ch] = dict_size
            dict_size += 1
            pending.add(ch)

        wc = w + ch
        if wc in dictionary:
            w = wc
        else:
            emit(w)
            # Add wc to the dictionary.
            dictionary[wc] = dict_size
            dict_size += 1
            w = ch

    # Output the code for the last phrase.
    if w:
        emit(w)

    # Mark the end of the stream, then flush the last symbol.
    writer.write(2, num_bits)
    return writer.finish()


def _decompress(length, resetValue, getNextValue):
    """Decode a symbol stream produced by ``_compress``.

    Args:
        length: number of symbols available.
        resetValue: mask selecting the top payload bit of a symbol
            (``2 ** (bitsPerChar - 1)``).
        getNextValue: maps a symbol index to that symbol's integer value.

    Returns:
        The decoded text; ``""`` for an empty stream or one that ends before
        its end marker; ``None`` when the stream holds an invalid code.

    Whatever ``getNextValue`` raises for a bad index or symbol propagates.
    """
    reader = _BitReader(resetValue, getNextValue)

    kind = reader.read(2)
    if kind == 0:
        first = chr(reader.read(8))
    elif kind == 1:
        first = chr(reader.read(16))
    elif kind == 2:
        return ""
    else:
        # 3 is not a valid leading tag.
        return None

    dictionary = {0: 0, 1: 1, 2: 2, 3: first}
    dict_size = 4
    enlarge_in = 4
    num_bits = 3
    w = first
    result = [first]

    while True:
        if reader.index > length:
            return ""

        code = reader.read(num_bits)
        if code == 2:
            return _from_utf16_units("".join(result))
        if code in (0, 1):
            # Literal: register the new character, then treat it as a code.
            dictionary[dict_size] = chr(reader.read(8 if code == 0 else 16))
            code = dict_size
            dict_size += 1
            enlarge_in -= 1

        if enlarge_in == 0:
            enlarge_in = 1 << num_bits
            num_bits += 1

        if code in dictionary:
            entry = dictionary[code]
        elif code == dict_size:
            # The code refers to the phrase that is being defined right now.
            entry = w + w[0]
        else:
            return None
        result.append(entry)

        # Add w+entry[0] to the dictionary.
        dictionary[dict_size] = w + entry[0]
        dict_size += 1
        enlarge_in -= 1

        w = entry
        if enlarge_in == 0:
            enlarge_in = 1 << num_bits
            num_bits += 1


class LZString:
    """Namespace of codecs that differ only in their output alphabet.

    Every ``compress*`` method returns ``""`` for ``None``.  Every
    ``decompress*`` method returns ``""`` for ``None`` and ``None`` for ``""``.
    """

    @staticmethod
    def compress(uncompressed):
        """Compress into symbols carrying 16 bits each (arbitrary code units)."""
        return _compress(uncompressed, 16, chr)

    @staticmethod
    def compressToUTF16(uncompressed):
        """Compress into 15-bit symbols offset by 32, followed by one space."""
        if uncompressed is None:
            return ""
        return _compress(uncompressed, 15, lambda a: chr(a + 32)) + " "

    @staticmethod
    def compressToBase64(uncompressed):
        """Compress into the Base64 alphabet, padded with ``=`` to a multiple of 4."""
        if uncompressed is None:
            return ""
        res = _compress(uncompressed, 6, lambda a: keyStrBase64[a])
        # To produce valid Base64
        remainder = len(res) % 4
        if remainder > 0:
            res += "=" * (4 - remainder)
        return res

    @staticmethod
    def compressToEncodedURIComponent(uncompressed):
        """Compress into the ``keyStrUriSafe`` alphabet (6 bits per symbol)."""
        if uncompressed is None:
            return ""
        return _compress(uncompressed, 6, lambda a: keyStrUriSafe[a])

    @staticmethod
    def decompress(compressed):
        """Inverse of :meth:`compress`."""
        if compressed is None:
            return ""
        if compressed == "":
            return None
        return _decompress(len(compressed), 32768, lambda index: ord(compressed[index]))

    @staticmethod
    def decompressFromUTF16(compressed):
        """Inverse of :meth:`compressToUTF16`."""
        if compressed is None:
            return ""
        if compressed == "":
            return None
        return _decompress(
            len(compressed), 16384, lambda index: ord(compressed[index]) - 32
        )

    @staticmethod
    def decompressFromBase64(compressed):
        """Inverse of :meth:`compressToBase64`."""
        if compressed is None:
            return ""
        if compressed == "":
            return None
        return _decompress(
            len(compressed),
            32,
            lambda index: getBaseValue(keyStrBase64, compressed[index]),
        )

    @staticmethod
    def decompressFromEncodedURIComponent(compressed):
        """Inverse of :meth:`compressToEncodedURIComponent`.

        Spaces are read as ``+`` before decoding.
        """
        if compressed is None:
            return ""
        if compressed == "":
            return None
        compressed = compressed.replace(" ", "+")
        return _decompress(
            len(compressed),
            32,
            lambda index: getBaseValue(keyStrUriSafe, compressed[index]),
        )
@@ The main entry point is the `add_routes` function which adds the routes to an existing FastAPI app or APIRouter. """ +import json from inspect import isclass from typing import ( Any, @@ -17,13 +18,15 @@ Union, ) -from fastapi import Request +from fastapi import HTTPException, Request from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.load.serializable import Serializable from langchain.schema.runnable import Runnable +from langchain.schema.runnable.config import merge_configs from typing_extensions import Annotated from langserve.version import __version__ +from langserve.lzstring import LZString try: from pydantic.v1 import BaseModel, create_model @@ -47,10 +50,35 @@ APIRouter = FastAPI = Any -def _unpack_config(d: Union[BaseModel, Mapping], keys: Sequence[str]) -> Dict[str, Any]: - """Project the given keys from the given dict.""" - _d = d.dict() if isinstance(d, BaseModel) else d - return {k: _d[k] for k in keys if k in _d} +def _config_from_hash(config_hash: str) -> Dict[str, Any]: + try: + if not config_hash: + return {} + + uncompressed = LZString.decompressFromEncodedURIComponent(config_hash) + parsed = json.loads(uncompressed) + if isinstance(parsed, dict): + return parsed + else: + raise HTTPException(400, "Invalid config hash") + except Exception: + raise HTTPException(400, "Invalid config hash") + + +def _unpack_config( + *configs: Union[BaseModel, Mapping, str], keys: Sequence[str] +) -> Dict[str, Any]: + """Merge configs, and project the given keys from the merged dict.""" + config_dicts = [] + for config in configs: + if isinstance(config, str): + config_dicts.append(_config_from_hash(config)) + elif isinstance(config, BaseModel): + config_dicts.append(config.dict()) + else: + config_dicts.append(config) + config = merge_configs(*config_dicts) + return {k: config[k] for k in keys if k in config} def _unpack_input(validated_model: BaseModel) -> Any: @@ -217,18 +245,17 @@ def add_routes( InvokeResponse = 
create_invoke_response_model(model_namespace, output_type_) BatchResponse = create_batch_response_model(model_namespace, output_type_) - @app.post( - f"{namespace}/invoke", - response_model=InvokeResponse, - ) + @app.post(namespace + "/h{config_hash}/invoke", response_model=InvokeResponse) + @app.post(f"{namespace}/invoke", response_model=InvokeResponse) async def invoke( invoke_request: Annotated[InvokeRequest, InvokeRequest], request: Request, + config_hash: str = "", ) -> InvokeResponse: """Invoke the runnable with the given input and config.""" # Request is first validated using InvokeRequest which takes into account # config_keys as well as input_type. - config = _unpack_config(invoke_request.config, config_keys) + config = _unpack_config(config_hash, invoke_request.config, keys=config_keys) _add_tracing_info_to_metadata(config, request) output = await runnable.ainvoke( _unpack_input(invoke_request.input), config=config @@ -236,32 +263,36 @@ async def invoke( return InvokeResponse(output=simple_dumpd(output)) - # + @app.post(namespace + "/h{config_hash}/batch", response_model=BatchResponse) @app.post(f"{namespace}/batch", response_model=BatchResponse) async def batch( batch_request: Annotated[BatchRequest, BatchRequest], request: Request, + config_hash: str = "", ) -> BatchResponse: """Invoke the runnable with the given inputs and config.""" if isinstance(batch_request.config, list): config = [ - _unpack_config(config, config_keys) for config in batch_request.config + _unpack_config(config, keys=config_keys) + for config in batch_request.config ] for c in config: _add_tracing_info_to_metadata(c, request) else: - config = _unpack_config(batch_request.config, config_keys) + config = _unpack_config(config_hash, batch_request.config, keys=config_keys) _add_tracing_info_to_metadata(config, request) inputs = [_unpack_input(input_) for input_ in batch_request.inputs] output = await runnable.abatch(inputs, config=config) return 
BatchResponse(output=simple_dumpd(output)) + @app.post(namespace + "/h{config_hash}/stream") @app.post(f"{namespace}/stream") async def stream( stream_request: Annotated[StreamRequest, StreamRequest], request: Request, + config_hash: str = "", ) -> EventSourceResponse: """Invoke the runnable stream the output. @@ -298,7 +329,7 @@ async def stream( # config_keys as well as input_type. # After validation, the input is loaded using LangChain's load function. input_ = _unpack_input(stream_request.input) - config = _unpack_config(stream_request.config, config_keys) + config = _unpack_config(config_hash, stream_request.config, keys=config_keys) _add_tracing_info_to_metadata(config, request) async def _stream() -> AsyncIterator[dict]: @@ -312,10 +343,12 @@ async def _stream() -> AsyncIterator[dict]: return EventSourceResponse(_stream()) + @app.post(namespace + "/h{config_hash}/stream_log") @app.post(f"{namespace}/stream_log") async def stream_log( stream_log_request: Annotated[StreamLogRequest, StreamLogRequest], request: Request, + config_hash: str = "", ) -> EventSourceResponse: """Invoke the runnable stream_log the output. @@ -353,7 +386,9 @@ async def stream_log( # config_keys as well as input_type. # After validation, the input is loaded using LangChain's load function. 
input_ = _unpack_input(stream_log_request.input) - config = _unpack_config(stream_log_request.config, config_keys) + config = _unpack_config( + config_hash, stream_log_request.config, keys=config_keys + ) _add_tracing_info_to_metadata(config, request) async def _stream_log() -> AsyncIterator[dict]: @@ -397,17 +432,20 @@ async def _stream_log() -> AsyncIterator[dict]: return EventSourceResponse(_stream_log()) + @app.get(namespace + "/h{config_hash}/input_schema") @app.get(f"{namespace}/input_schema") - async def input_schema() -> Any: + async def input_schema(config_hash: str = "") -> Any: """Return the input schema of the runnable.""" return runnable.input_schema.schema() + @app.get(namespace + "/h{config_hash}/output_schema") @app.get(f"{namespace}/output_schema") - async def output_schema() -> Any: + async def output_schema(config_hash: str = "") -> Any: """Return the output schema of the runnable.""" return runnable.output_schema.schema() + @app.get(namespace + "/h{config_hash}/config_schema") @app.get(f"{namespace}/config_schema") - async def config_schema() -> Any: + async def config_schema(config_hash: str = "") -> Any: """Return the config schema of the runnable.""" return runnable.config_schema(include=config_keys).schema() diff --git a/poetry.lock b/poetry.lock index 7ab0c6db..fcad5958 100644 --- a/poetry.lock +++ b/poetry.lock @@ -780,6 +780,16 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "docopt" +version = "0.6.2" +description = "Pythonic argument parser, that will make you smile" +optional = false +python-versions = "*" +files = [ + {file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"}, +] + [[package]] name = "entrypoints" version = "0.4" @@ -960,7 +970,7 @@ files = [ {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"}, {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"}, {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"}, - {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"}, + {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"}, {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"}, {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"}, {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"}, @@ -970,7 +980,6 @@ files = [ {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"}, {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"}, {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"}, - {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"}, {file = 
"greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"}, @@ -1728,16 +1737,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2527,6 +2526,22 @@ files = [ [package.dependencies] pytest = ">=3.6.3" +[[package]] +name = "pytest-watch" +version = "4.2.0" +description = "Local continuous test runner with pytest and watchdog." 
+optional = false +python-versions = "*" +files = [ + {file = "pytest-watch-4.2.0.tar.gz", hash = "sha256:06136f03d5b361718b8d0d234042f7b2f203910d8568f63df2f866b547b3d4b9"}, +] + +[package.dependencies] +colorama = ">=0.3.3" +docopt = ">=0.4.0" +pytest = ">=2.6.4" +watchdog = ">=0.6.0" + [[package]] name = "python-dateutil" version = "2.8.2" @@ -2626,7 +2641,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2634,15 +2648,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = 
"PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2659,7 +2666,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = 
"PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2667,7 +2673,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3400,6 +3405,45 @@ dev = ["Cython (>=0.29.32,<0.30.0)", "Sphinx (>=4.1.2,<4.2.0)", "aiohttp", 
"flak docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=22.0.0,<22.1.0)", "pycodestyle (>=2.7.0,<2.8.0)"] +[[package]] +name = "watchdog" +version = "3.0.0" +description = "Filesystem events monitoring" +optional = false +python-versions = ">=3.7" +files = [ + {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:336adfc6f5cc4e037d52db31194f7581ff744b67382eb6021c868322e32eef41"}, + {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a70a8dcde91be523c35b2bf96196edc5730edb347e374c7de7cd20c43ed95397"}, + {file = "watchdog-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adfdeab2da79ea2f76f87eb42a3ab1966a5313e5a69a0213a3cc06ef692b0e96"}, + {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2b57a1e730af3156d13b7fdddfc23dea6487fceca29fc75c5a868beed29177ae"}, + {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7ade88d0d778b1b222adebcc0927428f883db07017618a5e684fd03b83342bd9"}, + {file = "watchdog-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7e447d172af52ad204d19982739aa2346245cc5ba6f579d16dac4bfec226d2e7"}, + {file = "watchdog-3.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9fac43a7466eb73e64a9940ac9ed6369baa39b3bf221ae23493a9ec4d0022674"}, + {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8ae9cda41fa114e28faf86cb137d751a17ffd0316d1c34ccf2235e8a84365c7f"}, + {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f70b4aa53bd743729c7475d7ec41093a580528b100e9a8c5b5efe8899592fc"}, + {file = "watchdog-3.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4f94069eb16657d2c6faada4624c39464f65c05606af50bb7902e036e3219be3"}, + {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = 
"sha256:7c5f84b5194c24dd573fa6472685b2a27cc5a17fe5f7b6fd40345378ca6812e3"}, + {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3aa7f6a12e831ddfe78cdd4f8996af9cf334fd6346531b16cec61c3b3c0d8da0"}, + {file = "watchdog-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:233b5817932685d39a7896b1090353fc8efc1ef99c9c054e46c8002561252fb8"}, + {file = "watchdog-3.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13bbbb462ee42ec3c5723e1205be8ced776f05b100e4737518c67c8325cf6100"}, + {file = "watchdog-3.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8f3ceecd20d71067c7fd4c9e832d4e22584318983cabc013dbf3f70ea95de346"}, + {file = "watchdog-3.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c9d8c8ec7efb887333cf71e328e39cffbf771d8f8f95d308ea4125bf5f90ba64"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0e06ab8858a76e1219e68c7573dfeba9dd1c0219476c5a44d5333b01d7e1743a"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:d00e6be486affb5781468457b21a6cbe848c33ef43f9ea4a73b4882e5f188a44"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:c07253088265c363d1ddf4b3cdb808d59a0468ecd017770ed716991620b8f77a"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:5113334cf8cf0ac8cd45e1f8309a603291b614191c9add34d33075727a967709"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:51f90f73b4697bac9c9a78394c3acbbd331ccd3655c11be1a15ae6fe289a8c83"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:ba07e92756c97e3aca0912b5cbc4e5ad802f4557212788e72a72a47ff376950d"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d429c2430c93b7903914e4db9a966c7f2b068dd2ebdd2fa9b9ce094c7d459f33"}, + {file = "watchdog-3.0.0-py3-none-win32.whl", hash = "sha256:3ed7c71a9dccfe838c2f0b6314ed0d9b22e77d268c67e015450a29036a81f60f"}, + {file = 
"watchdog-3.0.0-py3-none-win_amd64.whl", hash = "sha256:4c9956d27be0bb08fc5f30d9d0179a855436e655f046d288e2bcc11adfae893c"}, + {file = "watchdog-3.0.0-py3-none-win_ia64.whl", hash = "sha256:5d9f3a10e02d7371cd929b5d8f11e87d4bad890212ed3901f9b4d68767bee759"}, + {file = "watchdog-3.0.0.tar.gz", hash = "sha256:4d98a320595da7a7c5a18fc48cb633c2e73cda78f93cac2ef42d42bf609a33f9"}, +] + +[package.extras] +watchmedo = ["PyYAML (>=3.10)"] + [[package]] name = "watchfiles" version = "0.20.0" @@ -3778,4 +3822,4 @@ server = ["fastapi", "sse-starlette"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "5c10b2c8e0c87eb71bc9c9d56a009fb1c861cad22833c41cae33fee09cb860a9" +content-hash = "75348fb1a36b0e85c14030ea99bf75ac53fe9438178ffb589c353b262c11eef8" diff --git a/pyproject.toml b/pyproject.toml index 0c4c2628..212cbf77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ pytest-cov = "^4.0.0" pytest-asyncio = "^0.21.1" pytest-mock = "^3.11.1" pytest-socket = "^0.6.0" +pytest-watch = "^4.2.0" [tool.poetry.group.examples.dependencies] openai = "^0.28.0" diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index ef72ecdb..30a6da8d 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -1,8 +1,9 @@ """Test the server and client together.""" import asyncio +import json from asyncio import AbstractEventLoop from contextlib import asynccontextmanager -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import httpx import pytest @@ -13,12 +14,13 @@ from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.prompts import PromptTemplate from langchain.schema.messages import HumanMessage, SystemMessage -from langchain.schema.runnable import RunnablePassthrough +from langchain.schema.runnable import RunnablePassthrough, RunnableConfig from langchain.schema.runnable.base import RunnableLambda from 
langchain.schema.runnable.utils import ConfigurableField from pytest_mock import MockerFixture from langserve.client import RemoteRunnable +from langserve.lzstring import LZString from langserve.server import add_routes from tests.unit_tests.utils import FakeListLLM @@ -49,7 +51,27 @@ async def add_one_or_passthrough( runnable_lambda = RunnableLambda(func=add_one_or_passthrough) app = FastAPI() try: - add_routes(app, runnable_lambda) + add_routes(app, runnable_lambda, config_keys=["tags"]) + yield app + finally: + del app + + +@pytest.fixture() +def app_for_config(event_loop: AbstractEventLoop) -> FastAPI: + """A simple server that wraps a Runnable and exposes it as an API.""" + + async def return_config( + _: int, + config: RunnableConfig, + ) -> Dict[str, Any]: + """Add one to int or passthrough.""" + return {k: config[k] for k in config if k in ("tags", "configurable")} + + runnable_lambda = RunnableLambda(func=return_config) + app = FastAPI() + try: + add_routes(app, runnable_lambda, config_keys=["tags", "metadata"]) yield app finally: del app @@ -142,6 +164,50 @@ async def test_server_async(app: FastAPI) -> None: assert response.text == "event: data\r\ndata: 2\r\n\r\nevent: end\r\n\r\n" +@pytest.mark.asyncio +async def test_server_bound_async(app_for_config: FastAPI) -> None: + """Test the server directly via HTTP requests.""" + async_client = AsyncClient(app=app_for_config, base_url="http://localhost:9999") + config_hash = LZString.compressToEncodedURIComponent(json.dumps({"tags": ["test"]})) + + # Test invoke + response = await async_client.post( + f"/h{config_hash}/invoke", + json={"input": 1, "config": {"tags": ["another-one"]}}, + ) + assert response.status_code == 200 + assert response.json() == { + "output": { + "tags": ["test", "another-one"], + } + } + + # Test batch + response = await async_client.post( + f"/h{config_hash}/batch", + json={"inputs": [1], "config": {"tags": ["another-one"]}}, + ) + assert response.status_code == 200 + assert 
response.json() == { + "output": [ + { + "tags": ["test", "another-one"], + } + ] + } + + # Test stream + response = await async_client.post( + f"/h{config_hash}/stream", + json={"input": 1, "config": {"tags": ["another-one"]}}, + ) + assert response.status_code == 200 + assert ( + response.text + == """event: data\r\ndata: {"tags": ["test", "another-one"]}\r\n\r\nevent: end\r\n\r\n""" + ) + + def test_invoke(client: RemoteRunnable) -> None: """Test sync invoke.""" assert client.invoke(1) == 2 diff --git a/tests/unit_tests/test_validation.py b/tests/unit_tests/test_validation.py index 26208d49..f22588a7 100644 --- a/tests/unit_tests/test_validation.py +++ b/tests/unit_tests/test_validation.py @@ -155,7 +155,7 @@ def test_invoke_request_with_runnables() -> None: Model( input={"name": "bob"}, ).config, - [], + keys=[], ) == {} ) @@ -177,6 +177,6 @@ def test_invoke_request_with_runnables() -> None: "template": "goodbye {name}", } - assert _unpack_config(request.config, ["configurable"]) == { + assert _unpack_config(request.config, keys=["configurable"]) == { "configurable": {"template": "goodbye {name}"}, }