-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnodes.py
29 lines (24 loc) · 1.18 KB
/
nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from state import GraphState
from langchain_core.language_models import BaseChatModel
class Nodes:
    """LangGraph node callables: three parallel LLM generations plus a merge.

    Each node method takes the shared ``GraphState`` and returns a partial
    state update (a dict) for the graph runtime to merge back into the state.
    """

    def __init__(self, llm: BaseChatModel):
        """Store the chat model that every node method invokes.

        Args:
            llm: Any LangChain chat model implementing ``invoke``.
        """
        self.llm = llm

    def call_llm1(self, state: GraphState) -> dict:
        """First LLM call: generate a short joke about ``state["topic"]``."""
        msg = self.llm.invoke(f'Write a short joke about {state["topic"]}')
        return {"joke": msg.content}

    def call_llm2(self, state: GraphState) -> dict:
        """Second LLM call: generate a short story about ``state["topic"]``."""
        # Bug fix: prompt previously read "storyy" (typo sent to the model).
        msg = self.llm.invoke(f'Write a short story about: {state["topic"]}')
        return {"story": msg.content}

    def call_llm3(self, state: GraphState) -> dict:
        """Third LLM call: generate a short poem about ``state["topic"]``."""
        msg = self.llm.invoke(f'Write a short poem about: {state["topic"]}')
        return {"poem": msg.content}

    def aggregator(self, state: GraphState) -> dict:
        """Combine the story, joke, and poem outputs into one string.

        Note: the subscript keys use single quotes inside the double-quoted
        f-strings — reusing the same quote type only parses on Python >= 3.12
        (PEP 701), so the original form broke older interpreters.
        """
        combined = f"Here is a story, joke and poem about {state['topic']}\n\n"
        combined += f"Story:\n {state['story']}\n\n"
        combined += f"Joke:\n {state['joke']}\n\n"
        combined += f"Poem:\n {state['poem']}\n\n"
        return {"combined_output": combined}