diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3484e8bba3c69..def410ba69396 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4266,9 +4266,10 @@ def _invoke( config: RunnableConfig, **kwargs: Any, ) -> List[Output]: - return self.bound.batch( - inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs - ) + configs = [ + patch_config(config, callbacks=run_manager.get_child()) for _ in inputs + ] + return self.bound.batch(inputs, configs, **kwargs) def invoke( self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any @@ -4282,9 +4283,10 @@ async def _ainvoke( config: RunnableConfig, **kwargs: Any, ) -> List[Output]: - return await self.bound.abatch( - inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs - ) + configs = [ + patch_config(config, callbacks=run_manager.get_child()) for _ in inputs + ] + return await self.bound.abatch(inputs, configs, **kwargs) async def ainvoke( self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index dfa75418c42f0..37e6bff2e3fe4 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5508,3 +5508,64 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] assert chunks == [{"foo": "a"}, {"foo": "n"}] + + +def test_listeners() -> None: + from langchain_core.runnables import RunnableLambda + from langchain_core.tracers.schemas import Run + + def fake_chain(inputs: dict) -> dict: + return {**inputs, "key": "extra"} + + shared_state = {} + value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} + value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} + + def on_start(run: Run) -> None: + shared_state[run.id] = {"inputs": run.inputs} + + def on_end(run: Run) -> None: + shared_state[run.id]["outputs"] = run.inputs + + chain = ( + RunnableLambda(fake_chain) + .with_listeners(on_end=on_end, on_start=on_start) + .map() + ) + + data = [{"name": "one"}, {"name": "two"}] + chain.invoke(data, config={"max_concurrency": 1}) + assert len(shared_state) == 2 + assert value1 in shared_state.values(), "Value not found in the dictionary." + assert value2 in shared_state.values(), "Value not found in the dictionary." + + +async def test_listeners_async() -> None: + from langchain_core.runnables import RunnableLambda + from langchain_core.tracers.schemas import Run + + def fake_chain(inputs: dict) -> dict: + return {**inputs, "key": "extra"} + + shared_state = {} + value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} + value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} + + def on_start(run: Run) -> None: + shared_state[run.id] = {"inputs": run.inputs} + + def on_end(run: Run) -> None: + shared_state[run.id]["outputs"] = run.inputs + + chain: Runnable = ( + RunnableLambda(fake_chain) + .with_listeners(on_end=on_end, on_start=on_start) + .map() + ) + + data = [{"name": "one"}, {"name": "two"}] + await chain.ainvoke(data, config={"max_concurrency": 1}) + + assert len(shared_state) == 2 + assert value1 in shared_state.values(), "Value not found in the dictionary." + assert value2 in shared_state.values(), "Value not found in the dictionary."