diff --git a/tests/rag/test_knowledge_base.py b/tests/rag/test_knowledge_base.py index 23b221a057..25a38fe2b4 100644 --- a/tests/rag/test_knowledge_base.py +++ b/tests/rag/test_knowledge_base.py @@ -20,17 +20,18 @@ def test_knowledge_base_get_random_documents(): # Test when k is smaller than the number of documents docs = kb.get_random_documents(3) assert len(docs) == 3 - assert all([doc in kb._documents for doc in docs]) + # Check that all document IDs are unique + assert len(set(doc.id for doc in docs)) == len(docs) # Test when k is equal to the number of documents docs = kb.get_random_documents(5) assert len(docs) == 5 - assert all([doc in kb._documents for doc in docs]) + assert all([doc == kb[doc.id] for doc in docs]) # Test when k is larger than the number of documents docs = kb.get_random_documents(10) assert len(docs) == 10 - assert all([doc in kb._documents for doc in docs]) + assert all([doc == kb[doc.id] for doc in docs]) def test_knowledge_base_creation_from_df():