From 5f58f93ea739985c2a61b1dd35729a4eff825a58 Mon Sep 17 00:00:00 2001
From: Shankar Ambady
Date: Fri, 31 Jan 2025 10:43:53 -0500
Subject: [PATCH] Semantic Chunking of Content Files (#2005)

* adding deps for semantic chunker
* conforming to langchain embedding interface and adding ability to toggle semantic chunker
* fixing tests
* combine recursive and semantic chunker to stay within chunk size
* fixing tests
* updating defaults
* adding test and fixes
* doc updates
---
 main/settings.py                    | 30 ++++++++-
 poetry.lock                         | 97 +++++++++++++++++++++++++++--
 pyproject.toml                      |  2 +
 vector_search/conftest.py           | 32 ++++++++--
 vector_search/encoders/base.py      | 28 ++++++---
 vector_search/encoders/fastembed.py |  4 +-
 vector_search/encoders/litellm.py   |  4 +-
 vector_search/utils.py              | 49 ++++++++++-----
 vector_search/utils_test.py         | 54 ++++++++++++----
 9 files changed, 246 insertions(+), 54 deletions(-)

diff --git a/main/settings.py b/main/settings.py
index fe77a24cec..dce0a48e63 100644
--- a/main/settings.py
+++ b/main/settings.py
@@ -849,9 +849,35 @@ def get_all_config_keys():
 AI_MAX_BUDGET = get_float(name="AI_MAX_BUDGET", default=0.05)
 AI_ANON_LIMIT_MULTIPLIER = get_float(name="AI_ANON_LIMIT_MULTIPLIER", default=10.0)
 CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = get_int(
-    name="CONTENT_FILE_EMBEDDING_CHUNK_SIZE", default=None
+    name="CONTENT_FILE_EMBEDDING_CHUNK_SIZE", default=512
 )
 CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = get_int(
     name="CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP",
-    default=200,  # default that the tokenizer uses
+    default=0,  # no overlap between chunks by default
 )
+CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = get_bool(
+    name="CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED", default=False
+)
+
+SEMANTIC_CHUNKING_CONFIG = {
+    "buffer_size": get_int(
+        # Number of sentences to combine.
+        name="SEMANTIC_CHUNKING_BUFFER_SIZE",
+        default=1,
+    ),
+    "breakpoint_threshold_type": get_string(
+        # 'percentile', 'standard_deviation', 'interquartile', or 'gradient'
+        name="SEMANTIC_CHUNKING_BREAKPOINT_THRESHOLD_TYPE",
+        default="percentile",
+    ),
+    "breakpoint_threshold_amount": get_float(
+        # value we use for breakpoint_threshold_type to filter outliers
+        name="SEMANTIC_CHUNKING_BREAKPOINT_THRESHOLD_AMOUNT",
+        default=None,
+    ),
+    "number_of_chunks": get_int(
+        # number of chunks to consider for merging
+        name="SEMANTIC_CHUNKING_NUMBER_OF_CHUNKS",
+        default=None,
+    ),
+}
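
The SEMANTIC_CHUNKING_CONFIG keys above mirror the constructor arguments of
langchain_experimental's SemanticChunker, and the new boolean setting gates
whether it runs at all. A minimal sketch of how those values are consumed,
assuming any langchain-compatible embedding object (the OpenAIEmbeddings
instance below is only an illustrative choice, not what the app requires):

    # Sketch only: mapping the new settings onto SemanticChunker.
    from langchain_experimental.text_splitter import SemanticChunker
    from langchain_openai import OpenAIEmbeddings

    semantic_chunking_config = {
        "buffer_size": 1,                           # sentences grouped per comparison window
        "breakpoint_threshold_type": "percentile",  # or standard_deviation / interquartile / gradient
        "breakpoint_threshold_amount": None,        # fall back to the splitter's default cutoff
        "number_of_chunks": None,                   # let the splitter pick the breakpoints
    }

    chunker = SemanticChunker(OpenAIEmbeddings(), **semantic_chunking_config)
    documents = chunker.create_documents(["One long content file, as plain text ..."])
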
diff --git a/poetry.lock b/poetry.lock
index cbb0e40918..2f9656530c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2651,6 +2651,17 @@ http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
 zstd = ["zstandard (>=0.18.0)"]
 
+[[package]]
+name = "httpx-sse"
+version = "0.4.0"
+description = "Consume Server-Sent Event (SSE) messages with HTTPX."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
+    {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
+]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.27.0"
@@ -3119,26 +3130,82 @@ requests = ">=2,<3"
 SQLAlchemy = ">=1.4,<3"
 tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10"
 
+[[package]]
+name = "langchain-community"
+version = "0.3.13"
+description = "Community contributed LangChain integrations."
+optional = false
+python-versions = "<4.0,>=3.9"
+files = [
+    {file = "langchain_community-0.3.13-py3-none-any.whl", hash = "sha256:d6f623d59d44ea85b9cafac8ca6e9970c343bf892a9a3252b3f25e30df339be2"},
+    {file = "langchain_community-0.3.13.tar.gz", hash = "sha256:8abe05b4ab160018dbd368f7b4dc79ef6a66178a461b47c4ce401d430337e0c2"},
+]
+
+[package.dependencies]
+aiohttp = ">=3.8.3,<4.0.0"
+dataclasses-json = ">=0.5.7,<0.7"
+httpx-sse = ">=0.4.0,<0.5.0"
+langchain = ">=0.3.13,<0.4.0"
+langchain-core = ">=0.3.27,<0.4.0"
+langsmith = ">=0.1.125,<0.3"
+numpy = {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}
+pydantic-settings = ">=2.4.0,<3.0.0"
+PyYAML = ">=5.3"
+requests = ">=2,<3"
+SQLAlchemy = ">=1.4,<3"
+tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10"
+
 [[package]]
 name = "langchain-core"
-version = "0.3.28"
+version = "0.3.32"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = "<4.0,>=3.9"
 files = [
-    {file = "langchain_core-0.3.28-py3-none-any.whl", hash = "sha256:a02f81ca53a8eed757133797e5a602ca80c1324bbecb0c5d86ef7bd3d6625372"},
-    {file = "langchain_core-0.3.28.tar.gz", hash = "sha256:407f7607e6b3c0ebfd6094da95d39b701e22e59966698ef126799782953e7f2c"},
+    {file = "langchain_core-0.3.32-py3-none-any.whl", hash = "sha256:c050bd1e6dd556ae49073d338aca9dca08b7b55f4778ddce881a12224bc82a7e"},
+    {file = "langchain_core-0.3.32.tar.gz", hash = "sha256:4eb85d8428585e67a1766e29c6aa2f246c6329d97cb486e8d6f564ab0bd94a4f"},
 ]
 
 [package.dependencies]
 jsonpatch = ">=1.33,<2.0"
-langsmith = ">=0.1.125,<0.3"
+langsmith = ">=0.1.125,<0.4"
 packaging = ">=23.2,<25"
 pydantic = {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}
 PyYAML = ">=5.3"
 tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
 typing-extensions = ">=4.7"
 
+[[package]]
+name = "langchain-experimental"
+version = "0.3.4"
+description = "Building applications with LLMs through composability"
+optional = false
+python-versions = "<4.0,>=3.9"
+files = [
+    {file = "langchain_experimental-0.3.4-py3-none-any.whl", hash = "sha256:2e587306aea36b60fa5e5fc05dc7281bee9f60a806f0bf9d30916e0ee096af80"},
+    {file = "langchain_experimental-0.3.4.tar.gz", hash = "sha256:937c4259ee4a639c618d19acf0e2c5c2898ef127050346edc5655259aa281a21"},
+]
+
+[package.dependencies]
+langchain-community = ">=0.3.0,<0.4.0"
+langchain-core = ">=0.3.28,<0.4.0"
+
+[[package]]
+name = "langchain-openai"
+version = "0.3.2"
+description = "An integration package connecting OpenAI and LangChain"
+optional = false
+python-versions = "<4.0,>=3.9"
+files = [
+    {file = "langchain_openai-0.3.2-py3-none-any.whl", hash = "sha256:8674183805e26d3ae3f78cc44f79fe0b2066f61e2de0e7e18be3b86f0d3b2759"},
+    {file = "langchain_openai-0.3.2.tar.gz", hash = "sha256:c2c80ac0208eb7cefdef96f6353b00fa217979ffe83f0a21cc8666001df828c1"},
+]
+
+[package.dependencies]
+langchain-core = ">=0.3.31,<0.4.0"
+openai = ">=1.58.1,<2.0.0"
+tiktoken = ">=0.7,<1"
+
 [[package]]
 name = "langchain-text-splitters"
 version = "0.3.4"
@@ -5298,6 +5365,26 @@ files = [
 [package.dependencies]
 typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
 
+[[package]]
+name = "pydantic-settings"
+version = "2.7.1"
+description = "Settings management using Pydantic"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pydantic_settings-2.7.1-py3-none-any.whl", hash = "sha256:590be9e6e24d06db33a4262829edef682500ef008565a969c73d39d5f8bfb3fd"},
+    {file = "pydantic_settings-2.7.1.tar.gz", hash = "sha256:10c9caad35e64bfb3c2fbf70a078c0e25cc92499782e5200747f942a065dec93"},
+]
+
+[package.dependencies]
+pydantic = ">=2.7.0"
+python-dotenv = ">=0.21.0"
+
+[package.extras]
+azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"]
+toml = ["tomli (>=2.0.1)"]
+yaml = ["pyyaml (>=6.0.1)"]
+
 [[package]]
 name = "pygithub"
 version = "2.5.0"
@@ -7744,4 +7831,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "3.12.6"
-content-hash = "7fdc6223d18db72747764790ef7ee561547c3d67bc44a4dc7e08801a4300ef6f"
+content-hash = "2b428052f5b07e608004f893695f615c063a8b34a28c19ff359df8cbe955392f"
"sha256:10c9caad35e64bfb3c2fbf70a078c0e25cc92499782e5200747f942a065dec93"}, +] + +[package.dependencies] +pydantic = ">=2.7.0" +python-dotenv = ">=0.21.0" + +[package.extras] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pygithub" version = "2.5.0" @@ -7744,4 +7831,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "3.12.6" -content-hash = "7fdc6223d18db72747764790ef7ee561547c3d67bc44a4dc7e08801a4300ef6f" +content-hash = "2b428052f5b07e608004f893695f615c063a8b34a28c19ff359df8cbe955392f" diff --git a/pyproject.toml b/pyproject.toml index 4cccb53d4b..988e4a4283 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,8 @@ tiktoken = "^0.8.0" llama-index = "^0.12.6" llama-index-llms-openai = "^0.3.12" llama-index-agent-openai = "^0.4.1" +langchain-experimental = "^0.3.4" +langchain-openai = "^0.3.2" [tool.poetry.group.dev.dependencies] diff --git a/vector_search/conftest.py b/vector_search/conftest.py index 21fad20e3d..a95efb1552 100644 --- a/vector_search/conftest.py +++ b/vector_search/conftest.py @@ -14,10 +14,10 @@ class DummyEmbedEncoder(BaseEncoder): def __init__(self, model_name="dummy-embedding"): self.model_name = model_name - def encode(self, text: str) -> list: # noqa: ARG002 + def embed(self, text: str) -> list: # noqa: ARG002 return np.random.random((10, 1)) - def encode_batch(self, texts: list[str]) -> list[list[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: return np.random.random((10, len(texts))) @@ -32,13 +32,37 @@ def _use_test_qdrant_settings(settings, mocker): settings.QDRANT_HOST = "https://test" settings.QDRANT_BASE_COLLECTION_NAME = "test" settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = 0 + settings.CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = False mock_qdrant = mocker.patch("qdrant_client.QdrantClient") + mocker.patch("vector_search.utils.SemanticChunker") + mock_qdrant.scroll.return_value = [ [], None, ] - get_text_splitter_patch = mocker.patch("vector_search.utils._get_text_splitter") - get_text_splitter_patch.return_value = RecursiveCharacterTextSplitter() + get_text_splitter_patch = mocker.patch("vector_search.utils._chunk_documents") + get_text_splitter_patch.return_value = ( + RecursiveCharacterTextSplitter().create_documents( + texts=["test dociment"], + metadatas=[ + { + "run_title": "", + "platform": "", + "offered_by": "", + "run_readable_id": "", + "resource_readable_id": "", + "content_type": "", + "file_extension": "", + "content_feature_type": "", + "course_number": "", + "file_type": "", + "description": "", + "key": "", + "url": "", + } + ], + ) + ) mock_qdrant.count.return_value = CountResult(count=10) mocker.patch( "vector_search.utils.qdrant_client", diff --git a/vector_search/encoders/base.py b/vector_search/encoders/base.py index 4fe9b7269d..2ff4e9d844 100644 --- a/vector_search/encoders/base.py +++ b/vector_search/encoders/base.py @@ -1,9 +1,12 @@ from abc import ABC, abstractmethod +from langchain_core.embeddings import Embeddings -class BaseEncoder(ABC): + +class BaseEncoder(Embeddings, ABC): """ - Base encoder class + Base encoder class. 
diff --git a/vector_search/encoders/fastembed.py b/vector_search/encoders/fastembed.py
index ccaf21e089..dabc65ac33 100644
--- a/vector_search/encoders/fastembed.py
+++ b/vector_search/encoders/fastembed.py
@@ -12,8 +12,8 @@ def __init__(self, model_name="BAAI/bge-small-en-v1.5"):
         self.model_name = model_name
         self.model = TextEmbedding(model_name=model_name, lazy_load=True)
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
-        return self.model.embed(texts)
+    def embed_documents(self, documents):
+        return list(self.model.embed(documents))
 
     def dim(self):
         """
diff --git a/vector_search/encoders/litellm.py b/vector_search/encoders/litellm.py
index bb0df1e701..5dd81d04a7 100644
--- a/vector_search/encoders/litellm.py
+++ b/vector_search/encoders/litellm.py
@@ -24,8 +24,8 @@ def __init__(self, model_name="text-embedding-3-small"):
             msg = f"Model {model_name} not found in tiktoken. defaulting to None"
             log.warning(msg)
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
-        return [result["embedding"] for result in self.get_embedding(texts)["data"]]
+    def embed_documents(self, documents):
+        return [result["embedding"] for result in self.get_embedding(documents)["data"]]
 
     def get_embedding(self, texts):
         if settings.LITELLM_CUSTOM_PROVIDER and settings.LITELLM_API_BASE:
diff --git a/vector_search/utils.py b/vector_search/utils.py
index 26d47bd024..054ae4e1d4 100644
--- a/vector_search/utils.py
+++ b/vector_search/utils.py
@@ -2,7 +2,8 @@
 import uuid
 
 from django.conf import settings
-from langchain.text_splitter import TokenTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_experimental.text_splitter import SemanticChunker
 from qdrant_client import QdrantClient, models
 
 from learning_resources.models import LearningResource
@@ -173,25 +174,42 @@ def _process_resource_embeddings(serialized_resources):
         docs.append(
             f"{doc.get('title')} {doc.get('description')} {doc.get('full_description')}"
         )
-    embeddings = encoder.encode_batch(docs)
+    embeddings = encoder.embed_documents(docs)
     return points_generator(ids, metadata, embeddings, vector_name)
 
 
-def _get_text_splitter(encoder):
-    """
-    Get the text splitter to use based on the encoder
-    """
+def _chunk_documents(encoder, texts, metadatas):
+    # chunk the documents; use semantic chunking if enabled
     chunk_params = {
         "chunk_overlap": settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP,
     }
-    if hasattr(encoder, "token_encoding_name") and encoder.token_encoding_name:
-        chunk_params["encoding_name"] = encoder.token_encoding_name
     if settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE:
         chunk_params["chunk_size"] = settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE
-
-    # leverage tiktoken to ensure we stay within token limits
-    return TokenTextSplitter(**chunk_params)
+    if settings.LITELLM_TOKEN_ENCODING_NAME:
+        """
+        If we have a specific model encoding
+        we can use that to stay within token limits
+        """
+        chunk_params["encoding_name"] = settings.LITELLM_TOKEN_ENCODING_NAME
+        recursive_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            **chunk_params
+        )
+    else:
+        recursive_splitter = RecursiveCharacterTextSplitter(**chunk_params)
+
+    if settings.CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED:
+        """
+        If semantic chunking is enabled,
+        use the semantic chunker then recursive splitter
+        to stay within chunk size limits
+        """
+        return recursive_splitter.split_documents(
+            SemanticChunker(
+                encoder, **settings.SEMANTIC_CHUNKING_CONFIG
+            ).create_documents(texts=texts, metadatas=metadatas)
+        )
+    return recursive_splitter.create_documents(texts=texts, metadatas=metadatas)
 
 
 def _process_content_embeddings(serialized_content):
@@ -202,13 +220,10 @@ def _process_content_embeddings(serialized_content):
     client = qdrant_client()
     encoder = dense_encoder()
     vector_name = encoder.model_short_name()
-    text_splitter = _get_text_splitter(encoder)
     for doc in serialized_content:
         if not doc.get("content"):
             continue
-        split_docs = text_splitter.create_documents(
-            texts=[doc.get("content")], metadatas=[doc]
-        )
+        split_docs = _chunk_documents(encoder, [doc.get("content")], [doc])
         split_texts = [d.page_content for d in split_docs if d.page_content]
         resource_vector_point_id = vector_point_id(doc["resource_readable_id"])
         split_metadatas = [
@@ -245,7 +260,7 @@ def _process_content_embeddings(serialized_content):
             )
             for md in split_metadatas
         ]
-        split_embeddings = list(encoder.encode_batch(split_texts))
+        split_embeddings = list(encoder.embed_documents(split_texts))
         if len(split_embeddings) > 0:
             resource_points.append(
                 models.PointVectors(
@@ -351,7 +366,7 @@ def vector_search(
     search_result = client.query_points(
         collection_name=search_collection,
         using=encoder.model_short_name(),
-        query=encoder.encode(query_string),
+        query=encoder.embed_query(query_string),
         query_filter=search_filter,
         limit=limit,
         offset=offset,
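
When the flag is on, chunking now happens in two passes: SemanticChunker first
splits where adjacent sentences diverge in embedding space (so chunk sizes
vary), and the recursive splitter then re-splits anything that exceeds the
configured chunk size. A simplified, settings-free restatement of that
pipeline (it omits the tiktoken branch used when LITELLM_TOKEN_ENCODING_NAME
is set, and the parameter defaults mirror the new settings above):

    # Simplified sketch of _chunk_documents without Django settings.
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_experimental.text_splitter import SemanticChunker


    def chunk_documents(encoder, texts, metadatas, *, semantic=True, chunk_size=512, chunk_overlap=0):
        size_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if semantic:
            # Pass 1: split on embedding-similarity breakpoints.
            semantic_docs = SemanticChunker(encoder).create_documents(
                texts=texts, metadatas=metadatas
            )
            # Pass 2: enforce the size ceiling on any oversized semantic chunk.
            return size_splitter.split_documents(semantic_docs)
        return size_splitter.create_documents(texts=texts, metadatas=metadatas)
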
diff --git a/vector_search/utils_test.py b/vector_search/utils_test.py
index 97ccec8ede..fb7e83a3ec 100644
--- a/vector_search/utils_test.py
+++ b/vector_search/utils_test.py
@@ -12,7 +12,7 @@
 )
 from vector_search.encoders.utils import dense_encoder
 from vector_search.utils import (
-    _get_text_splitter,
+    _chunk_documents,
     create_qdrand_collections,
     embed_learning_resources,
     filter_existing_qdrant_points,
@@ -230,19 +230,49 @@ def test_qdrant_query_conditions(mocker):
     )
 
 
-def test_get_text_splitter(mocker):
+def test_document_chunker(mocker):
     """
     Test that the correct splitter is returned based on encoder
     """
     settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = None
+    settings.CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = True
+    settings.LITELLM_TOKEN_ENCODING_NAME = None
     encoder = dense_encoder()
     encoder.token_encoding_name = None
-    mocked_splitter = mocker.patch("vector_search.utils.TokenTextSplitter")
-    _get_text_splitter(encoder)
-    assert "encoding_name" not in mocked_splitter.mock_calls[0].kwargs
"encoding_name" not in mocked_splitter.mock_calls[0].kwargs - encoder.token_encoding_name = "cl100k_base" # noqa: S105 - _get_text_splitter(encoder) - assert "encoding_name" in mocked_splitter.mock_calls[1].kwargs + mocked_splitter = mocker.patch("vector_search.utils.RecursiveCharacterTextSplitter") + mocked_chunker = mocker.patch("vector_search.utils.SemanticChunker") + _chunk_documents(encoder, ["this is a test document"], [{}]) + + mocked_chunker.assert_called() + mocked_splitter.assert_called() + + settings.CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = False + + mocked_splitter = mocker.patch("vector_search.utils.RecursiveCharacterTextSplitter") + mocked_chunker = mocker.patch("vector_search.utils.SemanticChunker") + + _chunk_documents(encoder, ["this is a test document"], [{}]) + mocked_chunker.assert_not_called() + mocked_splitter.assert_called() + + +def test_document_chunker_tiktoken(mocker): + """ + Test that we use tiktoken if a token encoding is specified + """ + settings.LITELLM_TOKEN_ENCODING_NAME = None + encoder = dense_encoder() + encoder.token_encoding_name = None + mocked_splitter = mocker.patch( + "vector_search.utils.RecursiveCharacterTextSplitter.from_tiktoken_encoder" + ) + + _chunk_documents(encoder, ["this is a test document"], [{}]) + mocked_splitter.assert_not_called() + + settings.LITELLM_TOKEN_ENCODING_NAME = "test" # noqa: S105 + _chunk_documents(encoder, ["this is a test document"], [{}]) + mocked_splitter.assert_called() def test_text_splitter_chunk_size_override(mocker): @@ -253,11 +283,11 @@ def test_text_splitter_chunk_size_override(mocker): settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = chunk_size settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = chunk_size / 10 encoder = dense_encoder() - mocked_splitter = mocker.patch("vector_search.utils.TokenTextSplitter") + mocked_splitter = mocker.patch("vector_search.utils.RecursiveCharacterTextSplitter") encoder.token_encoding_name = "cl100k_base" # noqa: S105 - _get_text_splitter(encoder) + _chunk_documents(encoder, ["this is a test document"], [{}]) assert mocked_splitter.mock_calls[0].kwargs["chunk_size"] == 100 - mocked_splitter = mocker.patch("vector_search.utils.TokenTextSplitter") + mocked_splitter = mocker.patch("vector_search.utils.RecursiveCharacterTextSplitter") settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = None - _get_text_splitter(encoder) + _chunk_documents(encoder, ["this is a test document"], [{}]) assert "chunk_size" not in mocked_splitter.mock_calls[0].kwargs