Skip to content

Commit

Permalink
Merge pull request #1990 from Giskard-AI/GSK-3609-Avoid-redundant-questions-in-data-generation
Browse files Browse the repository at this point in the history

GSK-3609  Avoid redundant questions in data generation
  • Loading branch information
henchaves authored Oct 31, 2024
2 parents 9fe01f7 + 6498980 commit 4f88d80
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 8 deletions.
11 changes: 11 additions & 0 deletions giskard/rag/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,17 @@ def get_failure_plot(self, question_evaluation: Sequence[dict] = None):
def get_random_document(self):
    """Pick one document from the knowledge base.

    The draw is delegated to the knowledge base's own random generator
    (``self._rng``), so results follow whatever seeding that generator has.

    Returns
    -------
    A single element of ``self._documents``.
    """
    candidates = self._documents
    picked = self._rng.choice(candidates)
    return picked

def get_random_documents(self, n: int, with_replacement=False):
    """Draw ``n`` documents from the knowledge base.

    Parameters
    ----------
    n : int
        Number of documents to return.
    with_replacement : bool, optional
        If True, every draw is independent, so the same document may be
        returned several times even when ``n`` is small. If False (default),
        documents are drawn without repetition for as long as possible.

    Returns
    -------
    list
        A list of ``n`` documents. Without replacement, the first
        ``min(n, len(self._documents))`` entries are all distinct; if ``n``
        exceeds the number of available documents, the remainder is filled
        with additional draws made with replacement.
    """
    if with_replacement:
        return list(self._rng.choice(self._documents, n, replace=True))

    # Take as many distinct documents as the knowledge base can provide.
    docs = list(self._rng.choice(self._documents, min(n, len(self._documents)), replace=False))

    # Top up only when the knowledge base was too small to satisfy the request;
    # a strict comparison avoids a pointless zero-size draw when it was enough.
    if len(docs) < n:
        docs.extend(self._rng.choice(self._documents, n - len(docs), replace=True))

    return docs

def get_neighbors(self, seed_document: Document, n_neighbors: int = 4, similarity_threshold: float = 0.2):
seed_embedding = seed_document.embeddings

Expand Down
6 changes: 4 additions & 2 deletions giskard/rag/question_generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ class GenerateFromSingleQuestionMixin:
_question_type: str

def generate_questions(self, knowledge_base: KnowledgeBase, num_questions: int, *args, **kwargs) -> Iterator[Dict]:
for _ in range(num_questions):
docs = knowledge_base.get_random_documents(num_questions)

for doc in docs:
try:
yield self.generate_single_question(knowledge_base, *args, **kwargs)
yield self.generate_single_question(knowledge_base, *args, **kwargs, seed_document=doc)
except Exception as e: # @TODO: specify exceptions
logger.error(f"Encountered error in question generation: {e}. Skipping.")
logger.exception(e)
Expand Down
4 changes: 2 additions & 2 deletions giskard/rag/question_generators/double_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ class DoubleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio
_question_type = "double"

def generate_single_question(
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
) -> QuestionSample:
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()
context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
)
Expand Down
4 changes: 2 additions & 2 deletions giskard/rag/question_generators/oos_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class OutOfScopeGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestionGene
_question_type = "out of scope"

def generate_single_question(
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
) -> QuestionSample:
"""
Generate a question from a list of context documents.
Expand All @@ -87,7 +87,7 @@ def generate_single_question(
Tuple[dict, dict]
The generated question and the metadata of the question.
"""
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()

context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
Expand Down
11 changes: 9 additions & 2 deletions giskard/rag/question_generators/simple_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ class SimpleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio

_question_type = "simple"

def generate_single_question(self, knowledge_base: KnowledgeBase, agent_description: str, language: str) -> dict:
def generate_single_question(
self,
knowledge_base: KnowledgeBase,
agent_description: str,
language: str,
seed_document=None,
) -> dict:
"""
Generate a question from a list of context documents.
Expand All @@ -80,7 +86,8 @@ def generate_single_question(self, knowledge_base: KnowledgeBase, agent_descript
QuestionSample
The generated question and the metadata of the question.
"""
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()

context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
)
Expand Down
26 changes: 26 additions & 0 deletions tests/rag/test_knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,32 @@
from giskard.rag.knowledge_base import KnowledgeBase


def test_knowledge_base_get_random_documents():
    """Check the size and uniqueness guarantees of ``KnowledgeBase.get_random_documents``.

    The knowledge base is built from five rows, and the cases cover requests
    that are smaller than, equal to, and larger than that, plus the
    ``with_replacement=True`` path.
    """
    llm_client = Mock()
    embeddings = Mock()
    embeddings.embed.side_effect = [np.random.rand(5, 10), np.random.rand(3, 10)]

    kb = KnowledgeBase.from_pandas(
        df=pd.DataFrame({"text": ["This is a test string"] * 5}), llm_client=llm_client, embedding_model=embeddings
    )

    # n smaller than the number of documents: all returned documents are distinct
    docs = kb.get_random_documents(3)
    assert len(docs) == 3
    assert len({doc.id for doc in docs}) == len(docs)

    # n equal to the number of documents: every document is returned exactly once
    docs = kb.get_random_documents(5)
    assert len(docs) == 5
    assert len({doc.id for doc in docs}) == 5
    assert all(doc == kb[doc.id] for doc in docs)

    # n larger than the number of documents: every document is covered,
    # and the surplus is filled with repeated documents
    docs = kb.get_random_documents(10)
    assert len(docs) == 10
    assert len({doc.id for doc in docs}) == 5
    assert all(doc == kb[doc.id] for doc in docs)

    # Sampling with replacement still returns exactly n known documents
    docs = kb.get_random_documents(10, with_replacement=True)
    assert len(docs) == 10
    assert all(doc == kb[doc.id] for doc in docs)


def test_knowledge_base_creation_from_df():
dimension = 8
df = pd.DataFrame(["This is a test string"] * 5)
Expand Down
3 changes: 3 additions & 0 deletions tests/rag/test_question_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_simple_question_generation():
Document(dict(content="Milk is produced by cows, goats or sheep.")),
]
knowledge_base.get_random_document = Mock(return_value=documents[0])
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = SimpleQuestionsGenerator(llm_client=llm_client)
Expand Down Expand Up @@ -212,6 +213,7 @@ def test_double_question_generation():
Document(dict(content="Milk is produced by cows, goats or sheep.")),
]
knowledge_base.get_random_document = Mock(return_value=documents[0])
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = DoubleQuestionsGenerator(llm_client=llm_client)
Expand Down Expand Up @@ -304,6 +306,7 @@ def test_oos_question_generation():
dict(content="Paul Graham liked to buy a baguette every day at the local market."), doc_id="1"
)
)
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = OutOfScopeGenerator(llm_client=llm_client)
Expand Down
2 changes: 2 additions & 0 deletions tests/rag/test_testset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def test_question_generation_fail(caplog):
knowledge_base.__getitem__ = lambda obj, idx: documents[0]
knowledge_base.topics = ["Cheese", "Ski"]

knowledge_base.get_random_documents = Mock(return_value=documents)

simple_gen = Mock()
simple_gen.generate_questions.return_value = [q1, q2]
failing_gen = SimpleQuestionsGenerator(llm_client=Mock())
Expand Down

0 comments on commit 4f88d80

Please sign in to comment.