Skip to content

Commit

Permalink
Use Blockbuster to detect blocking calls in asyncio during tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 6, 2025
1 parent edbe7d5 commit 7ce4b18
Show file tree
Hide file tree
Showing 16 changed files with 343 additions and 165 deletions.
6 changes: 4 additions & 2 deletions libs/core/langchain_core/beta/runnables/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _config_with_context(
return patch_config(config, configurable=context_funcs)


def aconfig_with_context(
async def aconfig_with_context(
config: RunnableConfig,
steps: list[Runnable],
) -> RunnableConfig:
Expand All @@ -134,7 +134,9 @@ def aconfig_with_context(
Returns:
The patched runnable config.
"""
return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
return await asyncio.to_thread(
_config_with_context, config, steps, _asetter, _agetter, asyncio.Event
)


def config_with_context(
Expand Down
6 changes: 3 additions & 3 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3037,7 +3037,7 @@ async def ainvoke(
from langchain_core.beta.runnables.context import aconfig_with_context

# setup callbacks and context
config = aconfig_with_context(ensure_config(config), self.steps)
config = await aconfig_with_context(ensure_config(config), self.steps)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
Expand Down Expand Up @@ -3214,7 +3214,7 @@ async def abatch(

# setup callbacks and context
configs = [
aconfig_with_context(c, self.steps)
await aconfig_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [
Expand Down Expand Up @@ -3364,7 +3364,7 @@ async def _atransform(
from langchain_core.beta.runnables.context import aconfig_with_context

steps = [self.first] + self.middle + [self.last]
config = aconfig_with_context(config, self.steps)
config = await aconfig_with_context(config, self.steps)

# stream the last steps
# transform the input stream of each step with the next
Expand Down
28 changes: 26 additions & 2 deletions libs/core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
responses = "^0.25.0"
pytest-socket = "^0.7.0"
blockbuster = "~1.5.8"
[[tool.poetry.group.test.dependencies.numpy]]
version = "^1.24.0"
python = "<3.12"
Expand Down
17 changes: 17 additions & 0 deletions libs/core/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,27 @@
from uuid import UUID

import pytest
from blockbuster import blockbuster_ctx
from pytest import Config, Function, Parser
from pytest_mock import MockerFixture


@pytest.fixture(autouse=True)
def blockbuster(request):
with blockbuster_ctx() as bb:
for func in ["os.stat", "os.path.abspath"]:
bb.functions[func].can_block_in(
"langchain_core/_api/internal.py", "is_caller_internal"
)

for func in ["os.stat", "io.TextIOWrapper.read"]:
bb.functions[func].can_block_in(
"langsmith/client.py", "_default_retry_config"
)

yield bb


def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
Expand Down
7 changes: 6 additions & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,12 @@ async def on_llm_new_token(
model = GenericFakeChatModel(messages=infinite_cycle)
tokens: list[str] = []
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
results = [
chunk
async for chunk in model.astream(
"meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}
)
]
assert results == [
_any_id_ai_message_chunk(content="hello"),
_any_id_ai_message_chunk(content=" "),
Expand Down
59 changes: 37 additions & 22 deletions libs/core/tests/unit_tests/language_models/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers import LogStreamCallbackHandler
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import collect_runs
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
Expand Down Expand Up @@ -304,39 +305,48 @@ def _stream(


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming(
def test_disable_streaming(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = StreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"

expected = "invoke" if disable_streaming is True else "stream"
assert next(model.stream([])).content == expected
async for c in model.astream([]):
assert c.content == expected
break
assert (
model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
== expected
)

expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
[], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}]
).content
== expected
)


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_async(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = StreamingModel(disable_streaming=disable_streaming)
assert (await model.ainvoke([])).content == "invoke"

expected = "invoke" if disable_streaming is True else "stream"
async for c in model.astream([]):
assert c.content == expected
break
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == expected

expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
async for c in model.astream([], tools=[{}]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
).content
== expected
)
assert (
await model.ainvoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
Expand All @@ -345,26 +355,31 @@ async def test_disable_streaming(


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_no_streaming_model(
def test_disable_streaming_no_streaming_model(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = NoStreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
assert next(model.stream([])).content == "invoke"
async for c in model.astream([]):
assert c.content == "invoke"
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
== "invoke"
)
assert next(model.stream([], tools=[{}])).content == "invoke"


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_no_streaming_model_async(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = NoStreamingModel(disable_streaming=disable_streaming)
assert (await model.ainvoke([])).content == "invoke"
async for c in model.astream([]):
assert c.content == "invoke"
break
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == "invoke"
assert next(model.stream([], tools=[{}])).content == "invoke"
async for c in model.astream([], tools=[{}]):
assert c.content == "invoke"
break
67 changes: 30 additions & 37 deletions libs/core/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import base64
import tempfile
import warnings
from pathlib import Path
from typing import Any, Union, cast
Expand Down Expand Up @@ -722,44 +720,39 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
"""Verify that we cannot pass `path` for an image as a variable."""
in_mem = "base64mem"
in_file_data = "base64file01"

with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
temp_file.write(base64.b64decode(in_file_data))
temp_file.flush()

template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{in_mem}",
},
{
"type": "image_url",
"image_url": {"path": "{file_path}"},
},
],
),
]
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{in_mem}",
},
{
"type": "image_url",
"image_url": {"path": "{file_path}"},
},
],
),
]
)
with pytest.raises(ValueError):
template.format_messages(
name="R2D2",
in_mem=in_mem,
file_path="some/path",
)
with pytest.raises(ValueError):
template.format_messages(
name="R2D2",
in_mem=in_mem,
file_path=temp_file.name,
)

with pytest.raises(ValueError):
await template.aformat_messages(
name="R2D2",
in_mem=in_mem,
file_path=temp_file.name,
)
with pytest.raises(ValueError):
await template.aformat_messages(
name="R2D2",
in_mem=in_mem,
file_path="some/path",
)


def test_messages_placeholder() -> None:
Expand Down
17 changes: 12 additions & 5 deletions libs/core/tests/unit_tests/runnables/test_context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Callable, NamedTuple, Union

import pytest
Expand Down Expand Up @@ -330,19 +331,26 @@ def seq_naive_rag_scoped() -> Runnable:


@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables(
def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert runnable.invoke(cases[0].input) == cases[0].output
assert await runnable.ainvoke(cases[1].input) == cases[1].output
assert runnable.batch([case.input for case in cases]) == [
case.output for case in cases
]
assert add(runnable.stream(cases[0].input)) == cases[0].output


@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert await runnable.ainvoke(cases[1].input) == cases[1].output
assert await runnable.abatch([case.input for case in cases]) == [
case.output for case in cases
]
assert add(runnable.stream(cases[0].input)) == cases[0].output
assert await aadd(runnable.astream(cases[1].input)) == cases[1].output


Expand Down Expand Up @@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None:
"prompt": Context.getter("prompt"),
}
)

chunks = list(chain.stream({"foo": "foo", "bar": "bar"}))
chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
for c in chunks:
assert c in achunks
Expand Down
Loading

0 comments on commit 7ce4b18

Please sign in to comment.