Skip to content

Commit

Permalink
Use pop params to reduce error at similarity percentile cutoff
Browse files Browse the repository at this point in the history
  • Loading branch information
vkehfdl1 committed Dec 12, 2024
1 parent c14517f commit bd91531
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
11 changes: 5 additions & 6 deletions autorag/nodes/passagefilter/similarity_percentile_cutoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
embedding_query_content,
)
from autorag.utils import result_to_dataframe
from autorag.utils.util import empty_cuda_cache
from autorag.utils.util import empty_cuda_cache, pop_params


class SimilarityPercentileCutoff(BasePassageFilter):
Expand All @@ -24,7 +24,7 @@ def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
Default is "openai" which is OpenAI text-embedding-ada-002 embedding model.
"""
super().__init__(project_dir, *args, **kwargs)
embedding_model_str = kwargs.get("embedding_model", "openai")
embedding_model_str = kwargs.pop("embedding_model", "openai")
self.embedding_model = embedding_models[embedding_model_str]()

def __del__(self):
Expand All @@ -34,11 +34,10 @@ def __del__(self):
empty_cuda_cache()

@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
def pure(self, previous_result: pd.DataFrame, **kwargs):
queries, contents, scores, ids = self.cast_to_run(previous_result)
if "embedding_model" in kwargs.keys():
del kwargs["embedding_model"]
return self._pure(queries, contents, scores, ids, *args, **kwargs)
kwargs = pop_params(self._pure, kwargs)
return self._pure(queries, contents, scores, ids, **kwargs)

def _pure(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def test_similarity_percentile_cutoff(similarity_percentile_cutoff_instance):
)
def test_similarity_percentile_cutoff_node():
result_df = SimilarityPercentileCutoff.run_evaluator(
project_dir=project_dir, previous_result=previous_result, percentile=0.9
project_dir=project_dir,
previous_result=previous_result,
percentile=0.9,
embedding_model="openai_embed_3_large",
)
base_passage_filter_node_test(result_df)

0 comments on commit bd91531

Please sign in to comment.