diff --git a/tests/test_rerank.py b/tests/test_rerank.py index 9d37e92..19ced38 100644 --- a/tests/test_rerank.py +++ b/tests/test_rerank.py @@ -3,11 +3,17 @@ import pytest from rerankers.models.flashrank_ranker import FlashRankRanker from rerankers.models.ranker import BaseRanker +from scipy.stats import kendalltau from raglite import RAGLiteConfig, hybrid_search, rerank_chunks, retrieve_chunks from raglite._database import Chunk +def kendall_tau(a: list[Chunk], b: list[Chunk]) -> float: + """Measure the Kendall rank correlation coefficient between two lists.""" + return kendalltau(range(len(a)), [a.index(c) for c in b])[0] + + @pytest.fixture( params=[ pytest.param(None, id="no_reranker"), @@ -47,9 +53,9 @@ def test_reranker( assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True)) # Rerank the chunks given an inverted chunk order. reranked_chunks = rerank_chunks(query, chunks[::-1], config=raglite_test_config) - if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder: - assert reranked_chunks[0] in chunks[:3] + if reranker: + assert kendall_tau(chunks, reranked_chunks) >= kendall_tau(chunks[::-1], reranked_chunks) # Test that we can also rerank given the chunk_ids only. reranked_chunks = rerank_chunks(query, chunk_ids[::-1], config=raglite_test_config) - if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder: - assert reranked_chunks[0] in chunks[:3] + if reranker: + assert kendall_tau(chunks, reranked_chunks) >= kendall_tau(chunks[::-1], reranked_chunks)