Semantic Chunking of Content Files (#2005)
* adding deps for semantic chunker

* conforming to langchain embedding interface and adding ability to toggle semantic chunker

* fixing tests

* combine recursive and semantic chunker to stay within chunk size

* fixing tests

* updating defaults

* adding test and fixes

* doc updates
shanbady authored Jan 31, 2025
1 parent c5c98a8 commit 5f58f93
Showing 9 changed files with 246 additions and 54 deletions.
30 changes: 28 additions & 2 deletions main/settings.py
@@ -849,9 +849,35 @@ def get_all_config_keys():
 AI_MAX_BUDGET = get_float(name="AI_MAX_BUDGET", default=0.05)
 AI_ANON_LIMIT_MULTIPLIER = get_float(name="AI_ANON_LIMIT_MULTIPLIER", default=10.0)
 CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = get_int(
-    name="CONTENT_FILE_EMBEDDING_CHUNK_SIZE", default=None
+    name="CONTENT_FILE_EMBEDDING_CHUNK_SIZE", default=512
 )
 CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = get_int(
     name="CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP",
-    default=200,  # default that the tokenizer uses
+    default=0,  # default that the tokenizer uses
 )
+CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = get_bool(
+    name="CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED", default=False
+)
+
+SEMANTIC_CHUNKING_CONFIG = {
+    "buffer_size": get_int(
+        # Number of sentences to combine.
+        name="SEMANTIC_CHUNKING_BUFFER_SIZE",
+        default=1,
+    ),
+    "breakpoint_threshold_type": get_string(
+        # 'percentile', 'standard_deviation', 'interquartile', or 'gradient'
+        name="SEMANTIC_CHUNKING_BREAKPOINT_THRESHOLD_TYPE",
+        default="percentile",
+    ),
+    "breakpoint_threshold_amount": get_float(
+        # value we use for breakpoint_threshold_type to filter outliers
+        name="SEMANTIC_CHUNKING_BREAKPOINT_THRESHOLD_AMOUNT",
+        default=None,
+    ),
+    "number_of_chunks": get_int(
+        # number of chunks to consider for merging
+        name="SEMANTIC_CHUNKING_NUMBER_OF_CHUNKS",
+        default=None,
+    ),
+}
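
Note: the four SEMANTIC_CHUNKING_CONFIG keys map one-to-one onto constructor arguments of LangChain's SemanticChunker. Below is a minimal sketch of how the toggle and config could be wired together, per the commit note about combining the recursive and semantic chunkers; the actual _chunk_documents in vector_search/utils.py is not rendered in this diff, so its signature and names here are assumptions.

# Hypothetical sketch only; the real _chunk_documents is not shown in this diff.
from django.conf import settings
from langchain_experimental.text_splitter import SemanticChunker
from langchain_text_splitters import RecursiveCharacterTextSplitter


def _chunk_documents(encoder, texts, metadatas):
    # The recursive splitter enforces the hard size ceiling
    # (CONTENT_FILE_EMBEDDING_CHUNK_SIZE, now defaulting to 512).
    recursive_splitter = RecursiveCharacterTextSplitter(
        chunk_size=settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE,
        chunk_overlap=settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP,
    )
    if not settings.CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED:
        return recursive_splitter.create_documents(texts, metadatas=metadatas)
    # Semantic pass first: split where embedding similarity between
    # adjacent sentence groups drops past the configured breakpoint.
    semantic_chunker = SemanticChunker(
        embeddings=encoder, **settings.SEMANTIC_CHUNKING_CONFIG
    )
    docs = semantic_chunker.create_documents(texts, metadatas=metadatas)
    # Then re-split any semantic chunk that still exceeds the ceiling.
    return recursive_splitter.split_documents(docs)

Chaining the two splitters this way keeps semantic boundaries where possible without ever emitting an oversized chunk.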
97 changes: 92 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -89,6 +89,8 @@ tiktoken = "^0.8.0"
 llama-index = "^0.12.6"
 llama-index-llms-openai = "^0.3.12"
 llama-index-agent-openai = "^0.4.1"
+langchain-experimental = "^0.3.4"
+langchain-openai = "^0.3.2"
 
 
 [tool.poetry.group.dev.dependencies]
32 changes: 28 additions & 4 deletions vector_search/conftest.py
@@ -14,10 +14,10 @@ class DummyEmbedEncoder(BaseEncoder):
     def __init__(self, model_name="dummy-embedding"):
         self.model_name = model_name
 
-    def encode(self, text: str) -> list:  # noqa: ARG002
+    def embed(self, text: str) -> list:  # noqa: ARG002
         return np.random.random((10, 1))
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return np.random.random((10, len(texts)))
 
 
@@ -32,13 +32,37 @@ def _use_test_qdrant_settings(settings, mocker):
     settings.QDRANT_HOST = "https://test"
     settings.QDRANT_BASE_COLLECTION_NAME = "test"
     settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = 0
+    settings.CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = False
     mock_qdrant = mocker.patch("qdrant_client.QdrantClient")
+    mocker.patch("vector_search.utils.SemanticChunker")
+
     mock_qdrant.scroll.return_value = [
         [],
         None,
     ]
-    get_text_splitter_patch = mocker.patch("vector_search.utils._get_text_splitter")
-    get_text_splitter_patch.return_value = RecursiveCharacterTextSplitter()
+    get_text_splitter_patch = mocker.patch("vector_search.utils._chunk_documents")
+    get_text_splitter_patch.return_value = (
+        RecursiveCharacterTextSplitter().create_documents(
+            texts=["test document"],
+            metadatas=[
+                {
+                    "run_title": "",
+                    "platform": "",
+                    "offered_by": "",
+                    "run_readable_id": "",
+                    "resource_readable_id": "",
+                    "content_type": "",
+                    "file_extension": "",
+                    "content_feature_type": "",
+                    "course_number": "",
+                    "file_type": "",
+                    "description": "",
+                    "key": "",
+                    "url": "",
+                }
+            ],
+        )
+    )
     mock_qdrant.count.return_value = CountResult(count=10)
     mocker.patch(
         "vector_search.utils.qdrant_client",
28 changes: 18 additions & 10 deletions vector_search/encoders/base.py
@@ -1,9 +1,12 @@
 from abc import ABC, abstractmethod
 
+from langchain_core.embeddings import Embeddings
+
 
-class BaseEncoder(ABC):
+class BaseEncoder(Embeddings, ABC):
     """
-    Base encoder class
+    Base encoder class.
+    Conforms to the langchain Embeddings interface
     """
 
     def model_short_name(self):
@@ -17,21 +20,26 @@ def model_short_name(self):
             model_name = split_model_name[1]
         return model_name
 
-    def encode(self, text):
+    def embed(self, text):
         """
         Embed a single text
         """
-        return next(iter(self.encode_batch([text])))
+        return next(iter(self.embed_documents([text])))
+
+    def dim(self):
+        """
+        Return the dimension of the embeddings
+        """
+        return len(self.embed("test"))
 
     @abstractmethod
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
+    def embed_documents(self, documents):
         """
-        Embed multiple texts
+        Embed a list of documents
         """
-        return [self.encode(text) for text in texts]
 
-    def dim(self):
+    def embed_query(self, query):
         """
-        Return the dimension of the embeddings
+        Embed a query
         """
-        return len(self.encode("test"))
+        return self.embed(query)
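
Note: with BaseEncoder now subclassing langchain_core's Embeddings (embed_documents plus the embed_query wrapper above), any concrete encoder plugs directly into LangChain components that expect an embedding model. A short usage sketch; FastEmbedEncoder is the assumed name of the class in vector_search/encoders/fastembed.py:

from langchain_experimental.text_splitter import SemanticChunker

# Assumed class name for the fastembed-backed encoder.
from vector_search.encoders.fastembed import FastEmbedEncoder

encoder = FastEmbedEncoder()
# SemanticChunker calls encoder.embed_documents() on sentence groups
# to find the similarity breakpoints.
chunker = SemanticChunker(embeddings=encoder)
chunks = chunker.split_text(
    "One sentence about calculus. Another about calculus. Now one about bread."
)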
4 changes: 2 additions & 2 deletions vector_search/encoders/fastembed.py
@@ -12,8 +12,8 @@ def __init__(self, model_name="BAAI/bge-small-en-v1.5"):
         self.model_name = model_name
         self.model = TextEmbedding(model_name=model_name, lazy_load=True)
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
-        return self.model.embed(texts)
+    def embed_documents(self, documents):
+        return list(self.model.embed(documents))
 
     def dim(self):
         """
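
Note: the list() wrapper matters because fastembed's TextEmbedding.embed() returns a lazy generator of numpy arrays, while the Embeddings contract expects a concrete list of vectors. A quick illustration using the encoder's default model:

from fastembed import TextEmbedding

model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5", lazy_load=True)
vectors = list(model.embed(["hello world"]))  # materialize the generator
print(len(vectors), len(vectors[0]))  # 1 vector, 384 dimensions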
4 changes: 2 additions & 2 deletions vector_search/encoders/litellm.py
@@ -24,8 +24,8 @@ def __init__(self, model_name="text-embedding-3-small"):
             msg = f"Model {model_name} not found in tiktoken. defaulting to None"
             log.warning(msg)
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
-        return [result["embedding"] for result in self.get_embedding(texts)["data"]]
+    def embed_documents(self, documents):
+        return [result["embedding"] for result in self.get_embedding(documents)["data"]]
 
     def get_embedding(self, texts):
         if settings.LITELLM_CUSTOM_PROVIDER and settings.LITELLM_API_BASE: