Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes as_chain #167

Merged
merged 4 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion django_ai_assistant/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

DEFAULTS = {
"INIT_API_FN": "django_ai_assistant.api.views.init_api",
"USE_LANGGRAPH": False,
"CAN_CREATE_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_UPDATE_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
Expand Down
194 changes: 19 additions & 175 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,11 @@
import re
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages,
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
)
from langchain.tools import StructuredTool
from langchain_core.chat_history import (
BaseChatMessageHistory,
InMemoryChatMessageHistory,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
Expand All @@ -38,19 +29,15 @@
RetrieverOutput,
)
from langchain_core.runnables import (
ConfigurableFieldSpec,
Runnable,
RunnableBranch,
RunnablePassthrough,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from django_ai_assistant.conf import app_settings
from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
Expand Down Expand Up @@ -90,8 +77,6 @@ class AIAssistant(abc.ABC): # noqa: F821
"""Whether the assistant uses RAG (Retrieval-Augmented Generation) or not.\n
Defaults to `False`.
When True, the assistant will use a retriever to get documents to provide as context to the LLM.
For this to work, the `instructions` should contain a placeholder for the context,
which is `{context}` by default.
Additionally, the assistant class should implement the `get_retriever` method to return
the retriever to use."""
_user: Any | None
Expand Down Expand Up @@ -266,58 +251,6 @@ def get_model_kwargs(self) -> dict[str, Any]:
"""
return {}

def get_prompt_template(self) -> ChatPromptTemplate:
    """Get the `ChatPromptTemplate` for the Langchain chain to use.\n
    The system prompt comes from the `get_instructions` method.\n
    The template includes placeholders for the instructions, chat `{history}`, user `{input}`,
    and `{agent_scratchpad}`, all which are necessary for the chain to work properly.\n
    The chat history is filled by the chain using the message history from `get_message_history`.\n
    If the assistant uses RAG, the instructions should contain a placeholder
    for the context, which is `{context}` by default, defined by the `get_context_placeholder` method.

    Returns:
        ChatPromptTemplate: The chat prompt template for the Langchain chain.

    Raises:
        AIAssistantMisconfiguredError: If `has_rag=True` and the instructions do not
            contain the braced context placeholder (e.g. `{context}`).
    """
    instructions = self.get_instructions()
    context_placeholder = self.get_context_placeholder()
    # Look for the braced template variable, not just the bare word: instructions
    # that merely mention "context" in prose do not provide a slot for the documents.
    braced_placeholder = f"{{{context_placeholder}}}"
    if self.has_rag and braced_placeholder not in instructions:
        raise AIAssistantMisconfiguredError(
            # The trailing space keeps the two message fragments from running together.
            f"{self.__class__.__name__} has_rag=True "
            f"but does not have a {braced_placeholder} placeholder in instructions."
        )

    return ChatPromptTemplate.from_messages(
        [
            ("system", instructions),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )

@with_cast_id
def get_message_history(self, thread_id: Any | None) -> BaseChatMessageHistory:
    """Return the chat message history backing the given `thread_id`.\n
    The Langchain chain calls this method to obtain the thread messages for the
    assistant, which fill the `history` placeholder of `get_prompt_template`.\n

    Args:
        thread_id (Any | None): The thread ID for the chat message history.
            When `None`, an in-memory chat message history is returned instead.

    Returns:
        BaseChatMessageHistory: The chat message history instance for the given `thread_id`.
    """

    # Imported lazily inside the method: Django may not be loaded yet elsewhere.
    from django_ai_assistant.langchain.chat_message_histories import (
        DjangoChatMessageHistory,
    )

    if thread_id is not None:
        return DjangoChatMessageHistory(thread_id)
    return InMemoryChatMessageHistory()

def get_llm(self) -> BaseChatModel:
"""Get the Langchain LLM instance for the assistant.
By default, this uses the OpenAI implementation.\n
Expand Down Expand Up @@ -368,15 +301,6 @@ def get_document_prompt(self) -> PromptTemplate:
"""
return DEFAULT_DOCUMENT_PROMPT

def get_context_placeholder(self) -> str:
    """Return the name of the RAG context placeholder used in the prompt when `has_rag=True`.\n
    The default is `"context"`. Subclasses may override this method to pick another name.

    Returns:
        str: the RAG context placeholder to use in the prompt.
    """
    placeholder_name = "context"
    return placeholder_name

def get_retriever(self) -> BaseRetriever:
"""Get the RAG retriever to use for fetching documents.\n
Must be implemented by subclasses when `has_rag=True`.\n
Expand Down Expand Up @@ -464,15 +388,23 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
Returns:
the compiled graph
"""
# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere.
# DjangoChatMessageHistory was used in the context of langchain, now that we are using
# langgraph this can be further simplified by just porting the add_messages logic.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this tracked? So we don't need the DjangoChatMessageHistory wrapper anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it seems like it could be replaced by one simple method that retrieves the message history and another that adds new messages. I created an issue for this: #169.

from django_ai_assistant.langchain.chat_message_histories import (
DjangoChatMessageHistory,
)

message_history = DjangoChatMessageHistory(thread_id) if thread_id else None

llm = self.get_llm()
tools = self.get_tools()
llm_with_tools = llm.bind_tools(tools) if tools else llm
message_history = self.get_message_history(thread_id)

def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
result = add_messages(left, right)

if thread_id:
if message_history:
messages_to_store = [
m
for m in result
Expand Down Expand Up @@ -513,7 +445,7 @@ def retriever(state: AgentState):
}

def history(state: AgentState):
history = message_history.messages if thread_id else []
history = message_history.messages if message_history else []
return {"messages": [*history, HumanMessage(content=state["input"])]}

def agent(state: AgentState):
Expand Down Expand Up @@ -558,113 +490,25 @@ def record_response(state: AgentState):

return workflow.compile()

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
    """Build the Langchain chain that runs the assistant.\n
    The chain is an agent with chat history and tool calling, plus RAG when `has_rag=True`.\n
    It is assembled from the other `get_*` methods of the assistant, so prefer overriding
    those to customize behavior. Override `as_chain` itself only for lower-level changes.

    The chain input is a dictionary with the key `"input"` containing the user message.\n
    The chain output is a dictionary with the key `"output"` containing the assistant response,
    along with the key `"history"` containing the previous chat history.

    Args:
        thread_id (Any | None): The thread ID for the chat message history.
            If `None`, an in-memory chat message history is used.

    Returns:
        Runnable[dict, dict]: The Langchain chain for the assistant.
    """
    # Based on:
    # - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/
    # - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/
    # TODO: use langgraph instead?
    model = self.get_llm()
    raw_tools = self.get_tools()
    prompt_template = self.get_prompt_template()
    toolset = cast(Sequence[BaseTool], raw_tools)
    # Only bind tools when there are any; otherwise use the plain LLM.
    model_with_tools = model.bind_tools(toolset) if toolset else model

    # First stage, based on create_tool_calling_agent: turn the intermediate
    # steps into tool messages for the `agent_scratchpad` placeholder.
    pipeline = RunnablePassthrough.assign(
        agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"])
    ).with_config(run_name="format_to_tool_messages")

    if self.has_rag:
        # based on create_retrieval_chain:
        history_aware_retriever = self.get_history_aware_retriever()
        retrieve_stage = RunnablePassthrough.assign(
            docs=history_aware_retriever.with_config(run_name="retrieve_documents"),
        )
        pipeline = pipeline | retrieve_stage

        # based on create_stuff_documents_chain:
        separator = self.get_document_separator()
        doc_prompt = self.get_document_prompt()
        placeholder_key = self.get_context_placeholder()
        context_assignments = {
            placeholder_key: lambda x: separator.join(
                format_document(doc, doc_prompt) for doc in x["docs"]
            )
        }
        format_stage = RunnablePassthrough.assign(**context_assignments).with_config(
            run_name="format_input_docs"
        )
        pipeline = pipeline | format_stage

    pipeline = pipeline | prompt_template | model_with_tools | ToolsAgentOutputParser()

    executor = AgentExecutor(
        agent=pipeline,  # pyright: ignore[reportArgumentType]
        tools=toolset,
    )

    thread_id_spec = ConfigurableFieldSpec(
        id="thread_id",  # must match get_message_history kwarg
        annotation=int,
        name="Thread ID",
        description="Unique identifier for the chat thread / conversation / session.",
        default=None,
        is_shared=True,
    )
    history_wrapped_agent = RunnableWithMessageHistory(
        executor,  # pyright: ignore[reportArgumentType]
        get_session_history=self.get_message_history,
        input_messages_key="input",
        history_messages_key="history",
        history_factory_config=[thread_id_spec],
    )

    # Pin the thread ID into the runnable config so callers do not have to pass it.
    return history_wrapped_agent.with_config(
        {"configurable": {"thread_id": thread_id}},
        run_name="agent_with_chat_history",
    )

@with_cast_id
def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
"""Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n
"""Invoke the assistant Langchain graph with the given arguments and keyword arguments.\n
This is the lower-level method to run the assistant.\n
The chain is created by the `as_chain` method.\n
The graph is created by the `as_graph` method.\n

Args:
*args: Positional arguments to pass to the chain.
*args: Positional arguments to pass to the graph.
Make sure to include a `dict` like `{"input": "user message"}`.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
**kwargs: Keyword arguments to pass to the chain.
**kwargs: Keyword arguments to pass to the graph.

Returns:
dict: The output of the assistant chain,
dict: The output of the assistant graph,
structured like `{"output": "assistant response", "history": ...}`.
"""
if app_settings.USE_LANGGRAPH:
chain = self.as_graph(thread_id)
else:
chain = self.as_chain(thread_id)
return chain.invoke(*args, **kwargs)
graph = self.as_graph(thread_id)
return graph.invoke(*args, **kwargs)

@with_cast_id
def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
Expand All @@ -675,7 +519,7 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
message (str): The user message to pass to the assistant.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
**kwargs: Additional keyword arguments to pass to the chain.
**kwargs: Additional keyword arguments to pass to the graph.

Returns:
str: The assistant response to the user message.
Expand Down
2 changes: 1 addition & 1 deletion django_ai_assistant/helpers/use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def create_message(
content (Any): Message content, usually a string
request (HttpRequest | None): Current request, if any
Returns:
dict: The output of the assistant chain,
dict: The output of the assistant,
structured like `{"output": "assistant response", "history": ...}`
Raises:
AIUserNotAllowedError: If user is not allowed to create messages in the thread
Expand Down
Loading