-
-
Notifications
You must be signed in to change notification settings - Fork 269
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
22 changed files
with
533 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .prev_next_augmenter import prev_next_augmenter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import functools | ||
import itertools | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from typing import List, Union, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
|
||
from autorag import embedding_models | ||
from autorag.evaluate.metric.util import calculate_cosine_similarity | ||
from autorag.utils import result_to_dataframe, validate_qa_dataset, fetch_contents, sort_by_scores | ||
from autorag.utils.util import reconstruct_list, filter_dict_keys | ||
|
||
logger = logging.getLogger("AutoRAG") | ||
|
||
|
||
def passage_augmenter_node(func):
    """
    Decorator that turns an id-augmenting function into a passage augmenter node.

    The returned wrapper reads the queries and retrieved ids from ``previous_result``,
    lets ``func`` expand the retrieved ids, fetches the matching contents from
    ``<project_dir>/data/corpus.parquet``, scores every content against its query
    with cosine similarity of embeddings, and reorders each row with ``sort_by_scores``.

    :param func: The augmenter function. It is called with ``ids_list`` (list of lists of ids)
        and must return the augmented list of lists of ids.
    :return: The wrapped function. It takes ``project_dir``, ``previous_result`` and module
        parameters, and produces the augmented contents, ids and scores
        (post-processed by ``result_to_dataframe``).
    """
    @functools.wraps(func)
    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def wrapper(
            project_dir: Union[str, Path],
            previous_result: pd.DataFrame,
            *args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        validate_qa_dataset(previous_result)
        data_dir = os.path.join(project_dir, "data")

        # find queries columns
        assert "query" in previous_result.columns, "previous_result must have query column."
        queries = previous_result["query"].tolist()

        # find ids columns
        assert "retrieved_ids" in previous_result.columns, "previous_result must have retrieved_ids column."
        ids = previous_result["retrieved_ids"].tolist()

        corpus_df = pd.read_parquet(os.path.join(data_dir, "corpus.parquet"))

        # The embedding model is a node-level option used only for scoring below.
        # Pop it BEFORE dispatching so it is never forwarded to an augmenter
        # that does not accept it.
        embedding_model_str = kwargs.pop("embedding_model", 'openai')

        if func.__name__ == 'prev_next_augmenter':
            # Take an explicit copy: assigning into a column slice of corpus_df
            # is a chained assignment (SettingWithCopyWarning).
            slim_corpus_df = corpus_df[["doc_id", "metadata"]].copy()
            slim_corpus_df['metadata'] = slim_corpus_df['metadata'].apply(filter_dict_keys, keys=['prev_id', 'next_id'])

            mode = kwargs.pop("mode", 'next')
            num_passages = kwargs.pop("num_passages", 1)

            # get augmented ids
            ids = func(ids_list=ids, corpus_df=slim_corpus_df, mode=mode, num_passages=num_passages)
        else:
            ids = func(*args, ids_list=ids, **kwargs)

        # fetch contents from corpus to use augmented ids
        contents = fetch_contents(corpus_df, ids)

        # embed queries and contents for getting scores
        query_embeddings, contents_embeddings = embedding_query_content(queries, contents, embedding_model_str,
                                                                        batch=128)

        # get scores from calculated cosine similarity (one score per content, per query)
        scores = [np.array([calculate_cosine_similarity(query_embedding, x) for x in content_embeddings]).tolist()
                  for query_embedding, content_embeddings in zip(query_embeddings, contents_embeddings)]

        # sort by scores
        df = pd.DataFrame({
            'contents': contents,
            'ids': ids,
            'scores': scores,
        })
        df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
        augmented_contents, augmented_ids, augmented_scores = \
            df['contents'].tolist(), df['ids'].tolist(), df['scores'].tolist()

        return augmented_contents, augmented_ids, augmented_scores

    return wrapper
|
||
|
||
def embedding_query_content(queries: List[str], contents_list: List[List[str]],
                            embedding_model: str, batch: int = 128):
    """
    Embed the queries and their candidate contents with the same embedding model.

    :param queries: The list of query strings.
    :param contents_list: One list of content strings per query.
    :param embedding_model: Key used to look the model up in ``embedding_models``.
    :param batch: Value assigned to the model's ``embed_batch_size``.
        Default is 128.
    :return: A tuple of (query embeddings, content embeddings), where the content
        embeddings are regrouped with ``reconstruct_list`` using the per-query lengths.
    """
    model = embedding_models[embedding_model]
    model.embed_batch_size = batch

    # Queries are embedded in one batched call.
    query_embeddings = model.get_text_embedding_batch(queries)

    # Contents are flattened for a single batched call, then regrouped per query.
    per_query_counts = [len(contents) for contents in contents_list]
    flat_contents = list(itertools.chain.from_iterable(contents_list))
    flat_embeddings = model.get_text_embedding_batch(flat_contents)
    content_embeddings = reconstruct_list(flat_embeddings, per_query_counts)

    # Drop the local reference and release cached GPU memory when CUDA is in use.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return query_embeddings, content_embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import List | ||
|
||
import pandas as pd | ||
|
||
from autorag.nodes.passageaugmenter.base import passage_augmenter_node | ||
|
||
|
||
@passage_augmenter_node
def prev_next_augmenter(ids_list: List[List[str]],
                        corpus_df: pd.DataFrame,
                        num_passages: int = 1,
                        mode: str = 'next'
                        ) -> List[List[str]]:
    """
    Add passages before and/or after the retrieved passage.
    For more information, visit https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/PrevNextPostprocessorDemo/.

    :param ids_list: The list of lists of ids retrieved
    :param corpus_df: The corpus dataframe
    :param num_passages: The number of passages to add before and after the retrieved passage
        Default is 1.
    :param mode: The mode of augmentation
        'prev': add passages before the retrieved passage
        'next': add passages after the retrieved passage
        'both': add passages before and after the retrieved passage
        Default is 'next'.
    :return: The list of lists of augmented ids
    :raises ValueError: If mode is not one of 'prev', 'next' or 'both'.
    """
    if mode not in ['prev', 'next', 'both']:
        raise ValueError(f"mode must be 'prev', 'next', or 'both', but got {mode}")

    # Augment each query's retrieved ids independently.
    return [prev_next_augmenter_pure(ids, corpus_df, mode, num_passages) for ids in ids_list]
|
||
|
||
def prev_next_augmenter_pure(ids: List[str], corpus_df: pd.DataFrame, mode: str, num_passages: int):
    """
    Expand one list of retrieved ids with their linked neighbour ids.

    :param ids: The ids retrieved for a single query.
    :param corpus_df: Dataframe with a 'doc_id' column and a 'metadata' column whose
        values support ``.get('prev_id')`` / ``.get('next_id')``.
    :param mode: 'prev', 'next' or 'both'. Any other value leaves the ids unchanged.
    :param num_passages: The maximum number of links to follow in each direction.
    :return: A flat list where every id is surrounded by its neighbours in document order
        (previous ids, the id itself, next ids).
    """
    include_prev = mode in ('prev', 'both')
    include_next = mode in ('next', 'both')

    def follow_links(origin, link_key):
        # Walk the chain starting at `origin`, stopping early when a link is None.
        chain = []
        cursor = origin
        for _ in range(num_passages):
            matching_metadata = corpus_df.loc[corpus_df['doc_id'] == cursor, 'metadata'].values
            cursor = matching_metadata[0].get(link_key)
            if cursor is None:
                return chain
            chain.append(cursor)
        return chain

    expanded = []
    for passage_id in ids:
        if include_prev:
            # The chain is collected nearest-first; reverse it to restore document order.
            expanded.extend(reversed(follow_links(passage_id, 'prev_id')))
        expanded.append(passage_id)
        if include_next:
            expanded.extend(follow_links(passage_id, 'next_id'))
    return expanded
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import logging | ||
import os | ||
import pathlib | ||
from typing import List, Callable, Dict | ||
|
||
import pandas as pd | ||
|
||
from autorag.nodes.retrieval.run import evaluate_retrieval_node | ||
from autorag.strategy import measure_speed, filter_by_threshold, select_best_average | ||
|
||
logger = logging.getLogger("AutoRAG") | ||
|
||
|
||
def run_passage_augmenter_node(modules: List[Callable],
                               module_params: List[Dict],
                               previous_result: pd.DataFrame,
                               node_line_dir: str,
                               strategies: Dict,
                               ) -> pd.DataFrame:
    """
    Run every passage augmenter module, evaluate the results and keep the best one.

    Each module result is saved as ``<index>.parquet`` under
    ``<node_line_dir>/passage_augmenter``, together with a ``summary.csv`` and the
    selected result as ``best_<index>.parquet``.

    :param modules: The passage augmenter modules to run.
    :param module_params: The keyword parameters for each module, in the same order as modules.
    :param previous_result: The previous result dataframe. It must contain the
        'retrieved_contents', 'retrieved_ids' and 'retrieve_scores' columns,
        which are replaced by the selected module's output.
    :param node_line_dir: The node line directory. The project directory is taken
        to be two levels above it.
    :param strategies: The strategies for this node. 'metrics' is required;
        'speed_threshold' is optional.
    :return: The previous result (without its retrieval columns) concatenated with the best result.
    :raises ValueError: If strategies has no 'metrics'.
    """
    os.makedirs(node_line_dir, exist_ok=True)
    project_dir = pathlib.PurePath(node_line_dir).parent.parent
    retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()

    # Validate the strategy before running any module: running them is the expensive
    # part, and without metrics their results could not be evaluated anyway.
    metrics = strategies.get('metrics')
    if metrics is None:
        raise ValueError("You must at least one metrics for passage_augmenter evaluation.")

    results, execution_times = zip(*(
        measure_speed(module, project_dir=project_dir, previous_result=previous_result, **params)
        for module, params in zip(modules, module_params)))
    # execution time per row of the result
    average_times = [elapsed / len(results[0]) for elapsed in execution_times]

    # run metrics before filtering
    results = [evaluate_retrieval_node(result, retrieval_gt, metrics) for result in results]

    # save results to folder
    save_dir = os.path.join(node_line_dir, "passage_augmenter")  # node name
    os.makedirs(save_dir, exist_ok=True)
    filepaths = [os.path.join(save_dir, f'{idx}.parquet') for idx in range(len(modules))]
    for result, filepath in zip(results, filepaths):
        result.to_parquet(filepath, index=False)
    filenames = [os.path.basename(filepath) for filepath in filepaths]

    summary_df = pd.DataFrame({
        'filename': filenames,
        'module_name': [module.__name__ for module in modules],
        'module_params': module_params,
        'execution_time': average_times,
        **{f'passage_augmenter_{metric}': [result[metric].mean() for result in results] for metric in metrics},
    })

    # filter by strategies
    if strategies.get('speed_threshold') is not None:
        results, filenames = filter_by_threshold(results, average_times, strategies['speed_threshold'], filenames)
    selected_result, selected_filename = select_best_average(results, metrics, filenames)
    # change metric name columns to passage_augmenter_metric_name
    selected_result = selected_result.rename(columns={
        metric_name: f'passage_augmenter_{metric_name}' for metric_name in metrics})
    # drop retrieval result columns in previous_result
    previous_result = previous_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])
    best_result = pd.concat([previous_result, selected_result], axis=1)

    # add 'is_best' column to summary file
    summary_df['is_best'] = summary_df['filename'] == selected_filename

    # save files
    summary_df.to_csv(os.path.join(save_dir, "summary.csv"), index=False)
    best_result.to_parquet(os.path.join(save_dir, f'best_{os.path.splitext(selected_filename)[0]}.parquet'),
                           index=False)
    return best_result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .preprocess import validate_qa_dataset, validate_corpus_dataset, cast_qa_dataset, cast_corpus_dataset | ||
from .util import fetch_contents, result_to_dataframe | ||
from .util import fetch_contents, result_to_dataframe, sort_by_scores |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
autorag.nodes.passageaugmenter package | ||
====================================== | ||
|
||
Submodules | ||
---------- | ||
|
||
autorag.nodes.passageaugmenter.base module | ||
------------------------------------------ | ||
|
||
.. automodule:: autorag.nodes.passageaugmenter.base | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
autorag.nodes.passageaugmenter.prev\_next\_augmenter module | ||
----------------------------------------------------------- | ||
|
||
.. automodule:: autorag.nodes.passageaugmenter.prev_next_augmenter | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
autorag.nodes.passageaugmenter.run module | ||
----------------------------------------- | ||
|
||
.. automodule:: autorag.nodes.passageaugmenter.run | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
Module contents | ||
--------------- | ||
|
||
.. automodule:: autorag.nodes.passageaugmenter | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
Oops, something went wrong.