core: Use Blockbuster to detect blocking calls in asyncio during tests #29043

Merged: 7 commits, Jan 31, 2025
30 changes: 27 additions & 3 deletions libs/core/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions libs/core/pyproject.toml
@@ -118,6 +118,7 @@ grandalf = "^0.8"
responses = "^0.25.0"
pytest-socket = "^0.7.0"
pytest-xdist = "^3.6.1"
blockbuster = "~1.5.11"
[[tool.poetry.group.test.dependencies.numpy]]
version = "^1.24.0"
python = "<3.12"
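For context: Blockbuster works by patching well-known blocking calls (time.sleep, file I/O, socket operations, and so on) so that they raise an error when invoked while an asyncio event loop is running. Below is a minimal sketch of the behavior this new test dependency provides; it assumes blockbuster_ctx() accepts a no-argument form and that a detected call raises blockbuster.BlockingError (only blockbuster_ctx itself appears in this PR's diff):

import asyncio
import time

import pytest
from blockbuster import BlockingError, blockbuster_ctx  # BlockingError name is assumed


async def _bad_coroutine() -> None:
    # time.sleep() blocks the running event loop; Blockbuster should intercept it.
    time.sleep(0.1)


def test_blocking_call_is_detected() -> None:
    with blockbuster_ctx():  # assumption: the no-argument form is valid
        with pytest.raises(BlockingError):
            asyncio.run(_bad_coroutine())

The autouse fixture added to conftest.py below applies the same context manager to every unit test in langchain_core, with a short allowlist of call sites that are known (and accepted) to block.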
29 changes: 28 additions & 1 deletion libs/core/tests/unit_tests/conftest.py
@@ -1,14 +1,41 @@
"""Configuration for unit tests."""

from collections.abc import Sequence
from collections.abc import Iterator, Sequence
from importlib import util
from uuid import UUID

import pytest
from blockbuster import BlockBuster, blockbuster_ctx
from pytest import Config, Function, Parser
from pytest_mock import MockerFixture


@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[BlockBuster]:
with blockbuster_ctx("langchain_core") as bb:
for func in ["os.stat", "os.path.abspath"]:
(
bb.functions[func]
.can_block_in("langchain_core/_api/internal.py", "is_caller_internal")
.can_block_in("langchain_core/runnables/base.py", "__repr__")
@cbornet (Collaborator, Author) commented:

RunnableLambda's __repr__ calls get_lambda_source which is blocking. It should probably be cached.
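One possible shape for that caching, as a sketch only (not part of this PR); it assumes get_lambda_source is importable from langchain_core.runnables.utils:

from functools import lru_cache

from langchain_core.runnables.utils import get_lambda_source  # import path assumed

# Memoize the source lookup so repeated __repr__ calls don't re-read the
# source file each time; callables are hashable, so they work as cache keys.
cached_get_lambda_source = lru_cache(maxsize=None)(get_lambda_source)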

.can_block_in(
"langchain_core/beta/runnables/context.py", "aconfig_with_context"
@cbornet (Collaborator, Author) commented on Jan 30, 2025:

Will be dealt with in another PR.

A collaborator replied:

Thank you!

)
)

for func in ["os.stat", "io.TextIOWrapper.read"]:
bb.functions[func].can_block_in(
"langsmith/client.py", "_default_retry_config"
@cbornet (Collaborator, Author) commented on Jan 30, 2025:

Probably something to be fixed in LangSmith.

)

for bb_function in bb.functions.values():
bb_function.can_block_in(
"freezegun/api.py", "_get_cached_module_attributes"
)

yield bb


def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
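The allowlist pattern in the fixture above generalizes: each entry names a patched blocking function plus a (file, function) pair where that call is tolerated. A hedged sketch of adding one more exception; mypkg/cache.py and warm_cache are placeholder names, not call sites touched by this PR:

from blockbuster import blockbuster_ctx

with blockbuster_ctx("langchain_core") as bb:
    # Tolerate os.stat only when it is called from warm_cache in mypkg/cache.py;
    # everywhere else it still raises inside a running event loop.
    bb.functions["os.stat"].can_block_in("mypkg/cache.py", "warm_cache")
    # ... run the asyncio code under test here ...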
7 changes: 6 additions & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -191,7 +191,12 @@ async def on_llm_new_token(
model = GenericFakeChatModel(messages=infinite_cycle)
tokens: list[str] = []
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
results = [
chunk
async for chunk in model.astream(
"meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}
)
]
assert results == [
_any_id_ai_message_chunk(content="hello"),
_any_id_ai_message_chunk(content=" "),
59 changes: 37 additions & 22 deletions libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -18,6 +18,7 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers import LogStreamCallbackHandler
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import collect_runs
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
@@ -303,39 +304,48 @@ def _stream(


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming(
def test_disable_streaming(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = StreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"

expected = "invoke" if disable_streaming is True else "stream"
assert next(model.stream([])).content == expected
async for c in model.astream([]):
assert c.content == expected
break
assert (
model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
== expected
)

expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
[], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}]
).content
== expected
)


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_async(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = StreamingModel(disable_streaming=disable_streaming)
assert (await model.ainvoke([])).content == "invoke"

expected = "invoke" if disable_streaming is True else "stream"
async for c in model.astream([]):
assert c.content == expected
break
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == expected

expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
async for c in model.astream([], tools=[{}]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
).content
== expected
)
assert (
await model.ainvoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
@@ -344,26 +354,31 @@ async def test_disable_streaming(


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_no_streaming_model(
def test_disable_streaming_no_streaming_model(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = NoStreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
assert next(model.stream([])).content == "invoke"
async for c in model.astream([]):
assert c.content == "invoke"
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
== "invoke"
)
assert next(model.stream([], tools=[{}])).content == "invoke"


@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_no_streaming_model_async(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = NoStreamingModel(disable_streaming=disable_streaming)
assert (await model.ainvoke([])).content == "invoke"
async for c in model.astream([]):
assert c.content == "invoke"
break
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == "invoke"
assert next(model.stream([], tools=[{}])).content == "invoke"
async for c in model.astream([], tools=[{}]):
assert c.content == "invoke"
break
@@ -1,11 +1,20 @@
import time
from typing import Optional as Optional

import pytest
from blockbuster import BlockBuster

from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter


@pytest.fixture(autouse=True)
def deactivate_blockbuster(blockbuster: BlockBuster) -> None:
# Deactivate BlockBuster to not disturb the rate limiter timings
blockbuster.deactivate()


def test_rate_limit_invoke() -> None:
"""Add rate limiter."""
model = GenericFakeChatModel(
67 changes: 30 additions & 37 deletions libs/core/tests/unit_tests/prompts/test_chat.py
@@ -1,5 +1,3 @@
import base64
import tempfile
import warnings
from pathlib import Path
from typing import Any, Union, cast
@@ -727,44 +725,39 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
"""Verify that we cannot pass `path` for an image as a variable."""
in_mem = "base64mem"
in_file_data = "base64file01"

with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
temp_file.write(base64.b64decode(in_file_data))
@cbornet (Collaborator, Author) commented:

It's useless to use a real file here.

temp_file.flush()

template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{in_mem}",
},
{
"type": "image_url",
"image_url": {"path": "{file_path}"},
},
],
),
]
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{in_mem}",
},
{
"type": "image_url",
"image_url": {"path": "{file_path}"},
},
],
),
]
)
with pytest.raises(ValueError):
template.format_messages(
name="R2D2",
in_mem=in_mem,
file_path="some/path",
)
with pytest.raises(ValueError):
template.format_messages(
name="R2D2",
in_mem=in_mem,
file_path=temp_file.name,
)

with pytest.raises(ValueError):
await template.aformat_messages(
name="R2D2",
in_mem=in_mem,
file_path=temp_file.name,
)
with pytest.raises(ValueError):
await template.aformat_messages(
name="R2D2",
in_mem=in_mem,
file_path="some/path",
)


def test_messages_placeholder() -> None:
17 changes: 12 additions & 5 deletions libs/core/tests/unit_tests/runnables/test_context.py
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Callable, NamedTuple, Union

import pytest
@@ -330,19 +331,26 @@ def seq_naive_rag_scoped() -> Runnable:


@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables(
def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert runnable.invoke(cases[0].input) == cases[0].output
assert await runnable.ainvoke(cases[1].input) == cases[1].output
assert runnable.batch([case.input for case in cases]) == [
case.output for case in cases
]
assert add(runnable.stream(cases[0].input)) == cases[0].output


@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert await runnable.ainvoke(cases[1].input) == cases[1].output
assert await runnable.abatch([case.input for case in cases]) == [
case.output for case in cases
]
assert add(runnable.stream(cases[0].input)) == cases[0].output
assert await aadd(runnable.astream(cases[1].input)) == cases[1].output


@@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None:
"prompt": Context.getter("prompt"),
}
)

chunks = list(chain.stream({"foo": "foo", "bar": "bar"}))
chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
for c in chunks:
assert c in achunks