feat: draft summarization nodes #3540

Draft · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions core/quivr_core/brain/brain.py
@@ -530,6 +530,8 @@ async def ask_streaming(
if rag_pipeline is None:
rag_pipeline = QuivrQARAGLangGraph

logger.info(f"LLLL Using vector db : {self.vector_db}")

rag_instance = rag_pipeline(
retrieval_config=retrieval_config, llm=llm, vector_store=self.vector_db
)
15 changes: 9 additions & 6 deletions core/quivr_core/llm_tools/llm_tools.py
@@ -1,23 +1,26 @@
from typing import Dict, Any, Type, Union
from typing import Any, Dict, Type, Union

from quivr_core.llm_tools.entity import ToolWrapper

from quivr_core.llm_tools.web_search_tools import (
WebSearchTools,
)

from quivr_core.llm_tools.other_tools import (
OtherTools,
)
from quivr_core.llm_tools.summarization_tool import (
SummarizationTools,
)
from quivr_core.llm_tools.web_search_tools import (
WebSearchTools,
)

TOOLS_CATEGORIES = {
WebSearchTools.name: WebSearchTools,
SummarizationTools.name: SummarizationTools,
OtherTools.name: OtherTools,
}

# Register all ToolsList enums
TOOLS_LISTS = {
**{tool.value: tool for tool in WebSearchTools.tools},
**{tool.value: tool for tool in SummarizationTools.tools},
**{tool.value: tool for tool in OtherTools.tools},
}

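A quick way for reviewers to poke at the new registration from a REPL; this is only a sketch, assuming the registry resolves the same enum key used in register_tool (see summarization_tool.py below), and the ChatOpenAI model is a placeholder rather than part of this PR:

from langchain_openai import ChatOpenAI

from quivr_core.llm_tools.llm_tools import TOOLS_CATEGORIES, TOOLS_LISTS

llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder chat model, not part of this PR

# The new category is reachable by name, and its map_reduce tool by value.
summarization = TOOLS_CATEGORIES["Summarization"]
print([tool.value for tool in summarization.tools])  # ['map_reduce']
print("map_reduce" in TOOLS_LISTS)  # True

# Build the category's default tool (map_reduce) through its factory.
wrapper = summarization.create_tool(summarization.default_tool, {"llm": llm})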
186 changes: 186 additions & 0 deletions core/quivr_core/llm_tools/summarization_tool.py
@@ -0,0 +1,186 @@
import operator
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, TypedDict

from langchain.chains.combine_documents.reduce import (
acollapse_docs,
split_list_of_docs,
)
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

from quivr_core.llm_tools.entity import ToolRegistry, ToolsCategory, ToolWrapper

# This code is widely inspired by https://python.langchain.com/docs/tutorials/summarization/

default_map_prompt = "Write a concise summary of the following:\\n\\n{context}"

default_reduce_prompt = """
The following is a set of summaries:
{docs}
Take these and distill them into a final, consolidated summary
of the main themes.
"""


class SummaryToolsList(str, Enum):
MAPREDUCE = "map_reduce"


# This will be the overall state of the main graph.
# It will contain the input document contents, corresponding
# summaries, and a final summary.
class OverallState(TypedDict):
# Notice here we use the operator.add
# This is because we want to combine all the summaries we generate
# from individual nodes back into one list - this is essentially
# the "reduce" part
contents: List[str]
summaries: Annotated[list, operator.add]
collapsed_summaries: List[Document]
final_summary: str


# This will be the state of the node that we will "map" all
# documents to in order to generate summaries
class SummaryState(TypedDict):
content: str


class MPSummarizeTool:
def __init__(
self,
llm: BaseChatModel,
token_max: int = 8000,
map_prompt: str = default_map_prompt,
reduce_prompt: str = default_reduce_prompt,
):
self.map_prompt = ChatPromptTemplate.from_messages([("system", map_prompt)])
self.reduce_prompt = ChatPromptTemplate([("human", reduce_prompt)])
self.llm = llm
self.token_max = token_max

def length_function(self, documents: List[Document]) -> int:
"""Get number of tokens for input contents."""
return sum(self.llm.get_num_tokens(doc.page_content) for doc in documents)

# Here we generate a summary, given a document
async def generate_summary(self, state: SummaryState):
prompt = self.map_prompt.invoke(state["content"])
response = await self.llm.ainvoke(prompt)
return {"summaries": [response.content]}

# Here we define the logic to map out over the documents
# We will use this as an edge in the graph
def map_summaries(self, state: OverallState):
# We will return a list of `Send` objects
# Each `Send` object consists of the name of a node in the graph
# as well as the state to send to that node
return [
Send("generate_summary", {"content": content})
for content in state["contents"]
]

def collect_summaries(self, state: OverallState):
return {
"collapsed_summaries": [Document(summary) for summary in state["summaries"]]
}

async def _reduce(self, input: list) -> str:
prompt = self.reduce_prompt.invoke(input)
response = await self.llm.ainvoke(prompt)
return response.content

# Add node to collapse summaries
async def collapse_summaries(self, state: OverallState):
doc_lists = split_list_of_docs(
state["collapsed_summaries"], self.length_function, self.token_max
)
results = []
for doc_list in doc_lists:
results.append(await acollapse_docs(doc_list, self._reduce))

return {"collapsed_summaries": results}

# This represents a conditional edge in the graph that determines
# if we should collapse the summaries or not
def should_collapse(
self,
state: OverallState,
) -> Literal["collapse_summaries", "generate_final_summary"]:
num_tokens = self.length_function(state["collapsed_summaries"])
if num_tokens > self.token_max:
return "collapse_summaries"
else:
return "generate_final_summary"

# Here we will generate the final summary
async def generate_final_summary(self, state: OverallState):
response = await self._reduce(state["collapsed_summaries"])
return {"final_summary": response}

def build(self):
summary_graph = StateGraph(OverallState)

summary_graph.add_node("generate_summary", self.generate_summary)
summary_graph.add_node("collect_summaries", self.collect_summaries)
summary_graph.add_node("collapse_summaries", self.collapse_summaries)
summary_graph.add_node("generate_final_summary", self.generate_final_summary)

# Edges:
summary_graph.add_conditional_edges(
START, self.map_summaries, ["generate_summary"]
)
summary_graph.add_edge("generate_summary", "collect_summaries")
summary_graph.add_conditional_edges("collect_summaries", self.should_collapse)
summary_graph.add_conditional_edges("collapse_summaries", self.should_collapse)
summary_graph.add_edge("generate_final_summary", END)

return summary_graph.compile()


def create_summary_tool(config: Dict[str, Any]):
summary_tool = MPSummarizeTool(
llm=config.get("llm"),
token_max=config.get("token_max", 8000),
map_prompt=config.get("map_prompt", default_map_prompt),
reduce_prompt=config.get("reduce_prompt", default_reduce_prompt),
)

def format_input(task: str) -> Dict[str, Any]:
return {"query": task}

def format_output(response: Any) -> List[Document]:
return [
Document(
page_content=d["content"],
)
for d in response
]

return ToolWrapper(summary_tool, format_input, format_output)


# Initialize the registry and register tools
summarization_tool_registry = ToolRegistry()
summarization_tool_registry.register_tool(
SummaryToolsList.MAPREDUCE, create_summary_tool
)


def create_summarization_tool(tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
return summarization_tool_registry.create_tool(tool_name, config)


SummarizationTools = ToolsCategory(
name="Summarization",
description="Tools for summarizing documents",
tools=[SummaryToolsList.MAPREDUCE],
default_tool=SummaryToolsList.MAPREDUCE,
create_tool=create_summarization_tool,
)
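A minimal sketch of exercising the map-reduce graph directly, independent of the ToolWrapper plumbing; the ChatOpenAI model and the dummy chunk texts are placeholders, not part of this PR:

import asyncio

from langchain_openai import ChatOpenAI

from quivr_core.llm_tools.summarization_tool import MPSummarizeTool


async def main() -> None:
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # any BaseChatModel works here
    graph = MPSummarizeTool(llm=llm, token_max=8000).build()

    # OverallState takes the raw chunk texts under "contents"; the START edge
    # fans out one generate_summary call per entry, then the reduce side
    # collapses the partial summaries and produces final_summary.
    state = await graph.ainvoke({"contents": ["first chunk of text ...", "second chunk ..."]})
    print(state["final_summary"])


asyncio.run(main())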
20 changes: 11 additions & 9 deletions core/quivr_core/rag/quivr_rag_langgraph.py
Original file line number Diff line number Diff line change
@@ -24,13 +24,12 @@
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.vectorstores import VectorStore
from langfuse.callback import CallbackHandler
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Send
from pydantic import BaseModel, Field

from langfuse.callback import CallbackHandler

from quivr_core.llm import LLMEndpoint
from quivr_core.llm_tools.llm_tools import LLMToolFactory
from quivr_core.rag.entities.chat import ChatHistory
@@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

# Replace each question with its condensed version
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_definition(task_id, response.content)

return {**state, "tasks": tasks}
@@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_completion(task_id, response.is_task_completable)
if not response.is_task_completable and response.tool:
tasks.set_tool(task_id, response.tool)
@@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = tool_wrapper.format_output(response)
_docs = self.filter_chunks_by_relevance(_docs)
tasks.set_docs(task_id, _docs)
Expand Down Expand Up @@ -634,7 +633,6 @@ async def retrieve(self, state: AgentState) -> AgentState:
compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=base_retriever
)

# Prepare the async tasks for all questions
async_jobs = []
for task_id in tasks.ids:
@@ -652,7 +650,7 @@
task_ids = [task[1] for task in async_jobs] if async_jobs else []

# Process responses and associate docs with tasks
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
tasks.set_docs(task_id, _docs) # Associate docs with the specific task

@@ -715,7 +713,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

_n = []
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
_n.append(len(_docs))
tasks.set_docs(task_id, _docs)
@@ -737,6 +735,10 @@

return {**state, "tasks": tasks}

def retrieve_all_chunks_from_file(self, file_id: UUID) -> List[Document]:
retriever = self.get_retriever()
return retriever.get_by_ids(ids=[file_id])

def _sort_docs_by_relevance(self, docs: List[Document]) -> List[Document]:
return sorted(
docs,
@@ -1012,7 +1014,7 @@ def invoke_structured_output(
)
return structured_llm.invoke(prompt)
except openai.BadRequestError:
structured_llm = self.llm_endpoint._llm.with_structured_output(output_class)
structured_llm = self.llm_endpoint._llm.with_structured_output(output_class)
return structured_llm.invoke(prompt)

def _build_rag_prompt_inputs(
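A side note on the repeated strict=False additions above: they keep the long-standing zip behaviour (silently truncating to the shorter iterable) while making that choice explicit now that Python 3.10+ offers strict=True. A standalone illustration, not quivr code:

responses = ["a", "b"]
task_ids = [1, 2, 3]

# strict=False truncates to the shorter input, so a missing response never raises.
print(list(zip(responses, task_ids, strict=False)))  # [('a', 1), ('b', 2)]

# strict=True would surface the length mismatch instead.
try:
    list(zip(responses, task_ids, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is longer than argument 1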
40 changes: 40 additions & 0 deletions core/tests/summarization_config.yaml
@@ -0,0 +1,40 @@
ingestion_config:
parser_config:
megaparse_config:
strategy: "fast"
pdf_parser: "unstructured"
splitter_config:
chunk_size: 400
chunk_overlap: 100


retrieval_config:
workflow_config:
name: "Summarizer"
available_tools:
- "Summarization"
nodes:
- name: "START"
edges: ["retrieve_all_chunks_from_file"]

- name: "retrieve_all_chunks_from_file"
edges: ["tool"]

- name: "run_tool"
edges: ["END"]

llm_config:
# The LLM supplier to use
supplier: "openai"

# The model to use for the LLM for the given supplier
model: "gpt-3.5-turbo-0125"

# Maximum number of tokens to pass to the LLM
# as a context to generate the answer
max_context_tokens: 2000

# Maximum number of tokens the LLM can generate for the answer
max_output_tokens: 2000

temperature: 0.7
streaming: true
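For context, a sketch of how this workflow config might be driven end to end. Brain.from_files and ask are existing quivr_core entry points, but whether RetrievalConfig.from_yaml accepts this exact nested file (ingestion_config alongside retrieval_config) is an assumption, and the document path is a placeholder:

from quivr_core import Brain
from quivr_core.rag.entities.config import RetrievalConfig

# Assumption: from_yaml picks up the retrieval_config section of this file;
# if not, split that section into its own YAML first.
retrieval_config = RetrievalConfig.from_yaml("core/tests/summarization_config.yaml")

brain = Brain.from_files(name="summarizer", file_paths=["./my_document.pdf"])
answer = brain.ask("Summarize this document", retrieval_config=retrieval_config)
print(answer.answer)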
2 changes: 1 addition & 1 deletion core/tests/test_quivr_rag.py
@@ -1,9 +1,9 @@
from uuid import uuid4

import pytest
from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.chat import ChatHistory
from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.models import ParsedRAGChunkResponse, RAGResponseMetadata
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
