diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 2c39ba53..678d8e8f 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -1398,6 +1398,7 @@ export class Pregel< ); } }, + signal: config.signal, }); } if (loop.status === "out_of_steps") { diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 37d5360a..d7081185 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -9429,6 +9429,44 @@ graph TD; const thirdState = await graph.getState(config); expect(thirdState.tasks).toHaveLength(0); }); + + it("should cancel when signal is aborted", async () => { + let oneCount = 0; + let twoCount = 0; + const graph = new StateGraph(MessagesAnnotation) + .addNode("one", async () => { + oneCount += 1; + await new Promise((resolve) => setTimeout(resolve, 100)); + return {}; + }) + .addNode("two", () => { + twoCount += 1; + throw new Error("Should not be called!"); + }) + .addEdge(START, "one") + .addEdge("one", "two") + .addEdge("two", END) + .compile(); + + const abortController = new AbortController(); + const config = { + signal: abortController.signal, + }; + + setTimeout(() => abortController.abort(), 10); + + await expect( + async () => + await graph.invoke( + { + messages: [], + }, + config + ) + ).rejects.toThrow("Aborted"); + expect(oneCount).toEqual(1); + expect(twoCount).toEqual(0); + }); } runPregelTests(() => new MemorySaverAssertImmutable());