refactor: unified old retrieve API (#573)
Unify the `/admin/retrieve/documents` and `/admin/embedding_retrieve` API implementations.
Mini256 authored Jan 8, 2025
1 parent 3903ad9 commit a0cf8ac
Showing 8 changed files with 241 additions and 202 deletions.
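
For orientation, a minimal sketch (not part of this commit) of how the unified retrieve service is called after this change; every name is taken from the `backend/app/rag/retrieve.py` diff below, and the session and question arguments are placeholders.

```python
# Minimal sketch of the unified retrieve API after this commit; all names come
# from backend/app/rag/retrieve.py in the diff below, the arguments are placeholders.
from app.rag.retrieve import retrieve_service  # module-level ChatEngineBasedRetrieveService


def retrieve_examples(session, question: str):
    # Document-level retrieval: retrieves chunks, then maps them back to documents.
    documents = retrieve_service.chat_engine_retrieve_documents(
        session,
        question=question,
        top_k=5,
        chat_engine_name="default",
    )
    # Chunk-level retrieval: delegates to ChatEngineBasedRetriever under the hood.
    chunks = retrieve_service.chat_engine_retrieve_chunks(
        session,
        question=question,
        top_k=5,
        chat_engine_name="default",
    )
    return documents, chunks
```

Both endpoints now share this one service, parameterized by `chat_engine_name`, instead of constructing a per-request `RetrieveService`.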
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
.idea
.vscode
.env
.ruff_cache

redis-data
data
3 changes: 2 additions & 1 deletion backend/app/api/admin_routes/models.py
@@ -38,9 +38,10 @@ class ChatEngineDescriptor(BaseModel):
is_default: bool


class RetrieveRequest(BaseModel):
class ChatEngineBasedRetrieveRequest(BaseModel):
query: str
chat_engine: Optional[str] = "default"
top_k: Optional[int] = 5
similarity_top_k: Optional[int] = None
oversampling_factor: Optional[int] = 5
enable_kg_enchance_query_refine: Optional[bool] = True
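
A hypothetical request body for the route that accepts this model (`embedding_search`, further down in this diff); the field names come from the renamed `ChatEngineBasedRetrieveRequest` model above, while the values are illustrative defaults only.

```python
# Illustrative payload for ChatEngineBasedRetrieveRequest; only the field names
# are taken from the model above, the values are example defaults.
payload = {
    "query": "example question",                  # required
    "chat_engine": "default",                     # optional, defaults to "default"
    "top_k": 5,                                   # optional, defaults to 5
    "similarity_top_k": None,                     # optional
    "oversampling_factor": 5,                     # optional, defaults to 5
    "enable_kg_enchance_query_refine": True,      # optional, defaults to True
}
```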
29 changes: 18 additions & 11 deletions backend/app/api/admin_routes/retrieve.py
@@ -3,10 +3,10 @@

from fastapi import APIRouter
from app.models import Document
from app.api.admin_routes.models import RetrieveRequest
from app.api.admin_routes.models import ChatEngineBasedRetrieveRequest
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.retrieve import RetrieveService
from llama_index.core.schema import NodeWithScore
from app.rag.retrieve import retrieve_service

from app.exceptions import InternalServerError, KBNotFound

@@ -23,14 +23,17 @@ async def retrieve_documents(
top_k: Optional[int] = 5,
similarity_top_k: Optional[int] = None,
oversampling_factor: Optional[int] = 5,
enable_kg_enchance_query_refine: Optional[bool] = True,
) -> List[Document]:
try:
retrieve_service = RetrieveService(session, chat_engine)
return retrieve_service.retrieve(
question,
return retrieve_service.chat_engine_retrieve_documents(
session,
question=question,
top_k=top_k,
chat_engine_name=chat_engine,
similarity_top_k=similarity_top_k,
oversampling_factor=oversampling_factor,
enable_kg_enchance_query_refine=enable_kg_enchance_query_refine,
)
except KBNotFound as e:
raise e
@@ -48,14 +51,17 @@ async def embedding_retrieve(
top_k: Optional[int] = 5,
similarity_top_k: Optional[int] = None,
oversampling_factor: Optional[int] = 5,
enable_kg_enchance_query_refine: Optional[bool] = True,
) -> List[NodeWithScore]:
try:
retrieve_service = RetrieveService(session, chat_engine)
return retrieve_service.embedding_retrieve(
question,
return retrieve_service.chat_engine_retrieve_chunks(
session,
question=question,
top_k=top_k,
chat_engine_name=chat_engine,
similarity_top_k=similarity_top_k,
oversampling_factor=oversampling_factor,
enable_kg_enchance_query_refine=enable_kg_enchance_query_refine,
)
except KBNotFound as e:
raise e
@@ -68,15 +74,16 @@ async def embedding_retrieve(
async def embedding_search(
session: SessionDep,
user: CurrentSuperuserDep,
request: RetrieveRequest,
request: ChatEngineBasedRetrieveRequest,
) -> List[NodeWithScore]:
try:
retrieve_service = RetrieveService(session, request.chat_engine)
return retrieve_service.embedding_retrieve(
return retrieve_service.chat_engine_retrieve_chunks(
session,
request.query,
top_k=request.top_k,
similarity_top_k=request.similarity_top_k,
oversampling_factor=request.oversampling_factor,
enable_kg_enchance_query_refine=request.enable_kg_enchance_query_refine,
)
except KBNotFound as e:
raise e
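
For context, a hypothetical client call against the documents endpoint named in the commit message. The host, URL prefix, HTTP method, and auth scheme below are assumptions not shown in this diff; only the endpoint path and parameter names come from the commit message and the route code above.

```python
# Hypothetical example only: base URL, HTTP method, and auth scheme are assumptions;
# the endpoint path and parameter names come from this commit.
import httpx

response = httpx.get(
    "http://localhost:8000/admin/retrieve/documents",       # assumed host and prefix
    params={
        "question": "example question",                      # illustrative question
        "chat_engine": "default",
        "top_k": 5,
        "enable_kg_enchance_query_refine": True,
    },
    headers={"Authorization": "Bearer <superuser token>"},   # assumed auth scheme
    timeout=30,
)
response.raise_for_status()
documents = response.json()
```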
4 changes: 2 additions & 2 deletions backend/app/rag/chat_config.py
@@ -1,6 +1,6 @@
import os
import logging
from typing import Optional
from typing import Dict, Optional

import dspy
from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS
@@ -82,7 +82,7 @@ class KnowledgeGraphOption(BaseModel):
include_meta: bool = True
with_degree: bool = False
using_intent_search: bool = True
relationship_meta_filters: Optional[dict] = None
relationship_meta_filters: Optional[Dict] = None


class ExternalChatEngine(BaseModel):
@@ -417,7 +417,7 @@ def retrieve_with_weight(
with_degree: bool = False,
with_chunks: bool = True,
# experimental feature to filter relationships based on meta, can be removed in the future
relationship_meta_filters: Dict = {},
relationship_meta_filters: dict = {},
session: Optional[Session] = None,
) -> Tuple[list, list, list]:
if not embedding:
229 changes: 45 additions & 184 deletions backend/app/rag/retrieve.py
@@ -1,214 +1,75 @@
import logging
from typing import List, Optional, Type

from llama_index.core import VectorStoreIndex
from typing import List, Optional
from llama_index.core.schema import NodeWithScore
from sqlmodel import Session, select
from sqlmodel import Session

from app.models import (
Document as DBDocument,
Chunk as DBChunk,
Entity as DBEntity,
Relationship as DBRelationship,
)
from app.models.chunk import get_kb_chunk_model
from app.models.entity import get_kb_entity_model
from app.rag.chat import get_prompt_by_jinja2_template
from app.rag.chat_config import ChatEngineConfig, get_default_embed_model
from app.models.relationship import get_kb_relationship_model
from app.models.patch.sql_model import SQLModel
from app.rag.knowledge_base.config import get_kb_embed_model
from app.rag.knowledge_graph import KnowledgeGraphIndex
from app.rag.knowledge_graph.graph_store import TiDBGraphStore
from app.rag.vector_store.tidb_vector_store import TiDBVectorStore
from app.rag.retrievers.LegacyChatEngineRetriever import (
ChatEngineBasedRetriever,
)
from app.repositories.chunk import ChunkRepo
from app.repositories.knowledge_base import knowledge_base_repo
from app.models.chunk import get_kb_chunk_model
from app.rag.chat_config import ChatEngineConfig


logger = logging.getLogger(__name__)


class RetrieveService:
_chunk_model: Type[SQLModel] = DBChunk
_entity_model: Type[SQLModel] = DBEntity
_relationship_model: Type[SQLModel] = DBRelationship

def __init__(
class ChatEngineBasedRetrieveService:
def chat_engine_retrieve_documents(
self,
db_session: Session,
engine_name: str = "default",
) -> None:
self.db_session = db_session
self.engine_name = engine_name

self.chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
self.db_chat_engine = self.chat_engine_config.get_db_chat_engine()

if self.chat_engine_config.knowledge_base:
# TODO: Support multiple knowledge base retrieve.
linked_knowledge_base = (
self.chat_engine_config.knowledge_base.linked_knowledge_base
)
kb = knowledge_base_repo.must_get(db_session, linked_knowledge_base.id)
self._chunk_model = get_kb_chunk_model(kb)
self._entity_model = get_kb_entity_model(kb)
self._relationship_model = get_kb_relationship_model(kb)
self._embed_model = get_kb_embed_model(self.db_session, kb)
else:
self._embed_model = get_default_embed_model(self.db_session)

def retrieve(
self,
question: str,
top_k: int = 10,
top_k: int = 5,
chat_engine_name: str = "default",
similarity_top_k: Optional[int] = None,
oversampling_factor: int = 5,
oversampling_factor: Optional[int] = None,
enable_kg_enchance_query_refine: bool = False,
) -> List[DBDocument]:
"""
Retrieve the related documents based on the user question.
Args:
question: The question from client.
top_k: The number of top-k CHUNKS, and retrieve its related documents.
Which means it might return less than top_k documents.
Returns:
A list of related documents.
"""
_fast_llm = self.chat_engine_config.get_fast_llama_llm(self.db_session)
_fast_dspy_lm = self.chat_engine_config.get_fast_dspy_lm(self.db_session)

# 1. Retrieve entities, relations, and chunks from the knowledge graph
kg_config = self.chat_engine_config.knowledge_graph
if kg_config.enabled:
graph_store = TiDBGraphStore(
dspy_lm=_fast_dspy_lm,
session=self.db_session,
embed_model=self._embed_model,
entity_db_model=self._entity_model,
relationship_db_model=self._relationship_model,
)
graph_index: KnowledgeGraphIndex = KnowledgeGraphIndex.from_existing(
dspy_lm=_fast_dspy_lm,
kg_store=graph_store,
)

if kg_config.using_intent_search:
sub_queries = graph_index.intent_analyze(question)
result = graph_index.graph_semantic_search(
sub_queries, include_meta=True
)

graph_knowledges = get_prompt_by_jinja2_template(
self.chat_engine_config.llm.intent_graph_knowledge,
sub_queries=result["queries"],
)
graph_knowledges_context = graph_knowledges.template
else:
entities, relations, chunks = graph_index.retrieve_with_weight(
question,
[],
depth=kg_config.depth,
include_meta=kg_config.include_meta,
with_degree=kg_config.with_degree,
with_chunks=False,
)
graph_knowledges = get_prompt_by_jinja2_template(
self.chat_engine_config.llm.normal_graph_knowledge,
entities=entities,
relationships=relations,
)
graph_knowledges_context = graph_knowledges.template
else:
entities, relations, chunks = [], [], []
graph_knowledges_context = ""

# 2. Refine the user question using graph information and chat history
refined_question = _fast_llm.predict(
get_prompt_by_jinja2_template(
self.chat_engine_config.llm.condense_question_prompt,
graph_knowledges=graph_knowledges_context,
question=question,
),
)

# 3. Retrieve the related chunks from the vector store
# 4. Rerank after the retrieval

# Vector Index
vector_store = TiDBVectorStore(
session=self.db_session,
chunk_db_model=self._chunk_model,
chat_engine_config = ChatEngineConfig.load_from_db(db_session, chat_engine_name)
if not chat_engine_config.knowledge_base:
raise Exception("Chat engine does not configured with knowledge base")

nodes = self.chat_engine_retrieve_chunks(
db_session,
question=question,
top_k=top_k,
chat_engine_name=chat_engine_name,
similarity_top_k=similarity_top_k,
oversampling_factor=oversampling_factor,
)
vector_index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=self._embed_model,
)

# Node postprocessors
metadata_filter = self.chat_engine_config.get_metadata_filter()
reranker = self.chat_engine_config.get_reranker(self.db_session, top_n=top_k)
if reranker:
node_postprocessors = [metadata_filter, reranker]
else:
node_postprocessors = [metadata_filter]

# Retriever Engine
retrieve_engine = vector_index.as_retriever(
similarity_top_k=similarity_top_k or top_k,
filters=metadata_filter.filters,
node_postprocessors=node_postprocessors,
enable_kg_enchance_query_refine=enable_kg_enchance_query_refine,
)

node_list: List[NodeWithScore] = retrieve_engine.retrieve(refined_question)
source_documents = self._get_source_documents(node_list)
linked_knowledge_base = chat_engine_config.knowledge_base.linked_knowledge_base
kb = knowledge_base_repo.must_get(db_session, linked_knowledge_base.id)
chunk_model = get_kb_chunk_model(kb)
chunk_repo = ChunkRepo(chunk_model)
chunk_ids = [node.node.node_id for node in nodes]

return source_documents
return chunk_repo.get_documents_by_chunk_ids(db_session, chunk_ids)

def embedding_retrieve(
def chat_engine_retrieve_chunks(
self,
db_session: Session,
question: str,
top_k: int = 10,
top_k: int = 5,
chat_engine_name: str = "default",
similarity_top_k: Optional[int] = None,
oversampling_factor: int = 5,
oversampling_factor: Optional[int] = None,
enable_kg_enchance_query_refine: bool = False,
) -> List[NodeWithScore]:
# Vector Index
vector_store = TiDBVectorStore(
session=self.db_session,
chunk_db_model=self._chunk_model,
retriever = ChatEngineBasedRetriever(
db_session=db_session,
engine_name=chat_engine_name,
top_k=top_k,
similarity_top_k=similarity_top_k,
oversampling_factor=oversampling_factor,
enable_kg_enchance_query_refine=enable_kg_enchance_query_refine,
)
vector_index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=self._embed_model,
)
return retriever.retrieve(question)

# Node postprocessors
metadata_filter = self.chat_engine_config.get_metadata_filter()
reranker = self.chat_engine_config.get_reranker(self.db_session, top_n=top_k)
if reranker:
node_postprocessors = [metadata_filter, reranker]
else:
node_postprocessors = [metadata_filter]

# Retriever Engine
retrieve_engine = vector_index.as_retriever(
node_postprocessors=node_postprocessors,
similarity_top_k=similarity_top_k or top_k,
)

node_list: List[NodeWithScore] = retrieve_engine.retrieve(question)
return node_list

def _get_source_documents(self, node_list: List[NodeWithScore]) -> List[DBDocument]:
source_nodes_ids = [s_n.node_id for s_n in node_list]
stmt = select(DBDocument).where(
DBDocument.id.in_(
select(
self._chunk_model.document_id,
).where(
self._chunk_model.id.in_(source_nodes_ids),
)
),
)

return list(self.db_session.exec(stmt).all())
retrieve_service = ChatEngineBasedRetrieveService()