Skip to content

Commit

Permalink
core[patch]:Fix Incorrect listeners parameters for Runnable.with_list…
Browse files Browse the repository at this point in the history
  • Loading branch information
liugddx authored May 13, 2024
1 parent b0f5a47 commit a156aac
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 6 deletions.
14 changes: 8 additions & 6 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4266,9 +4266,10 @@ def _invoke(
config: RunnableConfig,
**kwargs: Any,
) -> List[Output]:
return self.bound.batch(
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
)
configs = [
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
]
return self.bound.batch(inputs, configs, **kwargs)

def invoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
Expand All @@ -4282,9 +4283,10 @@ async def _ainvoke(
config: RunnableConfig,
**kwargs: Any,
) -> List[Output]:
return await self.bound.abatch(
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
)
configs = [
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
]
return await self.bound.abatch(inputs, configs, **kwargs)

async def ainvoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
Expand Down
61 changes: 61 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5508,3 +5508,64 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:

chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
assert chunks == [{"foo": "a"}, {"foo": "n"}]


def test_listeners() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run

def fake_chain(inputs: dict) -> dict:
return {**inputs, "key": "extra"}

shared_state = {}
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}

def on_start(run: Run) -> None:
shared_state[run.id] = {"inputs": run.inputs}

def on_end(run: Run) -> None:
shared_state[run.id]["outputs"] = run.inputs

chain = (
RunnableLambda(fake_chain)
.with_listeners(on_end=on_end, on_start=on_start)
.map()
)

data = [{"name": "one"}, {"name": "two"}]
chain.invoke(data, config={"max_concurrency": 1})
assert len(shared_state) == 2
assert value1 in shared_state.values(), "Value not found in the dictionary."
assert value2 in shared_state.values(), "Value not found in the dictionary."


async def test_listeners_async() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run

def fake_chain(inputs: dict) -> dict:
return {**inputs, "key": "extra"}

shared_state = {}
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}

def on_start(run: Run) -> None:
shared_state[run.id] = {"inputs": run.inputs}

def on_end(run: Run) -> None:
shared_state[run.id]["outputs"] = run.inputs

chain: Runnable = (
RunnableLambda(fake_chain)
.with_listeners(on_end=on_end, on_start=on_start)
.map()
)

data = [{"name": "one"}, {"name": "two"}]
await chain.ainvoke(data, config={"max_concurrency": 1})

assert len(shared_state) == 2
assert value1 in shared_state.values(), "Value not found in the dictionary."
assert value2 in shared_state.values(), "Value not found in the dictionary."

0 comments on commit a156aac

Please sign in to comment.