Merge pull request #309 from romartin/AAP-38368-2
AAP-38368: Fix sample prompt token calculation, which could cause downstream available-token calculation issues
tisnik authored Jan 23, 2025
2 parents 505ad72 + 4891926 commit e794248
Showing 5 changed files with 44 additions and 10 deletions.
5 changes: 5 additions & 0 deletions ols/src/prompts/prompt_generator.py
@@ -13,6 +13,11 @@
from ols.customize import prompts


def restructure_rag_context(text: str, model: str) -> str:
"""Restructure rag text by appending special characters.."""
return restructure_rag_context_post(restructure_rag_context_pre(text, model), model)


def restructure_rag_context_pre(text: str, model: str) -> str:
"""Restructure rag text - pre truncation."""
if ModelFamily.GRANITE in model:
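
For context, a minimal sketch of what the new wrapper composes, assuming the existing pre/post helpers wrap the text in model-specific markers; the tag strings and helper names below are illustrative placeholders, not taken from this repository:

# Illustrative sketch only: shows the pre/post composition performed by
# restructure_rag_context(). The tag strings are hypothetical.
def pre_sketch(text: str, model: str) -> str:
    # Pre-truncation step: prepend a document marker for Granite models.
    if "granite" in model:
        return "<doc>\n" + text
    return text

def post_sketch(text: str, model: str) -> str:
    # Post-truncation step: append a closing marker for Granite models.
    if "granite" in model:
        return text + "\n</doc>"
    return text

def composed_sketch(text: str, model: str) -> str:
    # Same shape as restructure_rag_context(text, model).
    return post_sketch(pre_sketch(text, model), model)
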
13 changes: 11 additions & 2 deletions ols/src/query_helpers/docs_summarizer.py
@@ -12,7 +12,11 @@
from ols.app.models.models import RagChunk, SummarizerResponse
from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters
from ols.customize import prompts, reranker
from ols.src.prompts.prompt_generator import GeneratePrompt
from ols.src.prompts.prompt_generator import (
GeneratePrompt,
restructure_history,
restructure_rag_context,
)
from ols.src.query_helpers.query_helper import QueryHelper
from ols.utils.token_handler import TokenHandler

@@ -80,7 +84,12 @@ def _prepare_prompt(
# Use sample text for context/history to get complete prompt
# instruction. This is used to calculate available tokens.
temp_prompt, temp_prompt_input = GeneratePrompt(
query, ["sample"], ["ai: sample"], self.system_prompt
            # The sample context/history must be restructured for the given
            # model so that the available-token calculation is correct.
query,
[restructure_rag_context("sample", self.model)],
[restructure_history("ai: sample", self.model)],
self._system_prompt,
).generate_prompt(self.model)

available_tokens = token_handler.calculate_and_check_available_tokens(
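
Why this matters: the temporary prompt built from the sample placeholders is what sizes the token budget. If the real context and history gain model-specific markers during restructuring but the sample text does not, the temporary prompt undercounts, and the available-token figure comes out too high. A back-of-the-envelope sketch, with all numbers hypothetical:

# Hypothetical numbers, for illustration only.
context_window_size = 8192
max_tokens_for_response = 1024
temp_prompt_tokens_plain = 60  # sample prompt without model-specific markers
marker_overhead = 12           # tokens the restructuring markers add

available_plain = context_window_size - max_tokens_for_response - temp_prompt_tokens_plain
available_marked = available_plain - marker_overhead

# Budgeting with available_plain leaves marker_overhead tokens unaccounted
# for, so the final (marked) prompt can overrun the context window;
# restructuring the sample text up front closes that gap.
print(available_plain, available_marked)  # 7108 7096
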
2 changes: 1 addition & 1 deletion ols/utils/token_handler.py
@@ -103,7 +103,7 @@ def calculate_and_check_available_tokens(
context_window_size - max_tokens_for_response - prompt_token_count
)

if available_tokens <= 0:
if available_tokens < 0:
limit = context_window_size - max_tokens_for_response
raise PromptTooLongError(
f"Prompt length {prompt_token_count} exceeds LLM "
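
The comparison change relaxes the boundary case: a prompt that exactly fills the window minus the response budget now yields zero available tokens and is accepted rather than rejected. Illustrative values:

# Boundary case after the change, with hypothetical sizes.
context_window_size = 4096
max_tokens_for_response = 512
prompt_token_count = 3584  # exactly context_window_size - max_tokens_for_response

available_tokens = context_window_size - max_tokens_for_response - prompt_token_count
assert available_tokens == 0  # accepted now; previously raised PromptTooLongError
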
8 changes: 2 additions & 6 deletions tests/unit/prompts/test_prompt_generator.py
@@ -17,8 +17,7 @@
from ols.src.prompts.prompt_generator import (
GeneratePrompt,
restructure_history,
restructure_rag_context_post,
restructure_rag_context_pre,
restructure_rag_context,
)

model = [GRANITE_13B_CHAT_V2, GPT35_TURBO]
@@ -33,10 +32,7 @@

def _restructure_prompt_input(rag_context, conversation_history, model):
"""Restructure prompt input."""
rag_formatted = [
restructure_rag_context_post(restructure_rag_context_pre(text, model), model)
for text in rag_context
]
rag_formatted = [restructure_rag_context(text, model) for text in rag_context]
history_formatted = [
restructure_history(history, model) for history in conversation_history
]
26 changes: 25 additions & 1 deletion tests/unit/query_helpers/test_docs_summarizer.py
@@ -1,7 +1,7 @@
"""Unit tests for DocsSummarizer class."""

import logging
from unittest.mock import ANY, patch
from unittest.mock import ANY, call, patch

import pytest

@@ -121,6 +121,30 @@ def test_summarize_truncation():
assert summary.history_truncated


@patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4)
@patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None))
def test_prepare_prompt_context():
"""Basic test for DocsSummarizer to check re-structuring of context for the 'temp' prompt."""
summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None))
question = "What's the ultimate question with answer 42?"
history = ["human: What is Kubernetes?"]
rag_index = MockLlamaIndex()

with patch(
"ols.src.query_helpers.docs_summarizer.restructure_rag_context",
return_value="patched_history",
) as restructure_rag_context:
summarizer.create_response(question, rag_index, history)
restructure_rag_context.assert_has_calls([call("sample", ANY)])

with patch(
"ols.src.query_helpers.docs_summarizer.restructure_history",
return_value="patched_history",
) as restructure_history:
summarizer.create_response(question, rag_index, history)
restructure_history.assert_has_calls([call("ai: sample", ANY)])


@patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None))
def test_summarize_no_reference_content():
"""Basic test for DocsSummarizer using mocked index and query engine."""
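
The new test relies on unittest.mock call tracking to confirm that the sample placeholders pass through the restructuring helpers. A condensed, standalone sketch of the same pattern; the module and function names here are invented for illustration:

# Standalone sketch of the assert_has_calls pattern used in the test above.
# "app", app.run(), and app.restructure() are hypothetical stand-ins.
from unittest.mock import ANY, call, patch

import app  # hypothetical module whose run() calls app.restructure("sample", model)

with patch("app.restructure", return_value="stub") as restructure:
    app.run()
    # call() records the expected arguments; ANY matches the model loosely.
    restructure.assert_has_calls([call("sample", ANY)])
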
