From 697a162d586dd3cb0130535a1c1ddcb5a8a53db6 Mon Sep 17 00:00:00 2001 From: kimbwook Date: Fri, 5 Jul 2024 17:54:02 +0900 Subject: [PATCH 1/2] Add window_replacement module at prompt_maker node --- autorag/nodes/promptmaker/__init__.py | 1 + autorag/nodes/promptmaker/base.py | 11 ++++- .../nodes/promptmaker/window_replacement.py | 45 +++++++++++++++++++ autorag/support.py | 1 + .../api_spec/autorag.nodes.promptmaker.rst | 8 ++++ .../source/nodes/prompt_maker/prompt_maker.md | 1 + .../nodes/prompt_maker/window_replacement.md | 33 ++++++++++++++ sample_config/full.yaml | 3 ++ .../promptmaker/test_prompt_maker_base.py | 17 ++++++- .../promptmaker/test_window_replacement.py | 40 +++++++++++++++++ 10 files changed, 158 insertions(+), 2 deletions(-) create mode 100644 autorag/nodes/promptmaker/window_replacement.py create mode 100644 docs/source/nodes/prompt_maker/window_replacement.md create mode 100644 tests/autorag/nodes/promptmaker/test_window_replacement.py diff --git a/autorag/nodes/promptmaker/__init__.py b/autorag/nodes/promptmaker/__init__.py index 05337ff11..226292a5e 100644 --- a/autorag/nodes/promptmaker/__init__.py +++ b/autorag/nodes/promptmaker/__init__.py @@ -1,2 +1,3 @@ from .fstring import fstring from .long_context_reorder import long_context_reorder +from .window_replacement import window_replacement diff --git a/autorag/nodes/promptmaker/base.py b/autorag/nodes/promptmaker/base.py index 1002abb25..4737af246 100644 --- a/autorag/nodes/promptmaker/base.py +++ b/autorag/nodes/promptmaker/base.py @@ -1,11 +1,12 @@ import functools import logging +import os from pathlib import Path from typing import List, Union import pandas as pd -from autorag.utils import result_to_dataframe +from autorag.utils import result_to_dataframe, fetch_contents logger = logging.getLogger("AutoRAG") @@ -31,6 +32,14 @@ def wrapper( assert "retrieve_scores" in previous_result.columns, "previous_result must have retrieve_scores column." 
retrieve_scores = previous_result["retrieve_scores"].tolist() return func(prompt, query, retrieved_contents, retrieve_scores) + elif func.__name__ == 'window_replacement': + retrieved_ids = previous_result["retrieved_ids"].tolist() + # load corpus + data_dir = os.path.join(project_dir, "data") + corpus_data = pd.read_parquet(os.path.join(data_dir, "corpus.parquet"), engine='pyarrow') + # get metadata from corpus + retrieved_metadata = fetch_contents(corpus_data, retrieved_ids, column_name='metadata') + return func(prompt, query, retrieved_contents, retrieved_metadata) else: raise NotImplementedError(f"Module {func.__name__} is not implemented or not supported.") diff --git a/autorag/nodes/promptmaker/window_replacement.py b/autorag/nodes/promptmaker/window_replacement.py new file mode 100644 index 000000000..02efc0583 --- /dev/null +++ b/autorag/nodes/promptmaker/window_replacement.py @@ -0,0 +1,45 @@ +import logging +from typing import List, Dict + +from autorag.nodes.promptmaker.base import prompt_maker_node + +logger = logging.getLogger("AutoRAG") + + +@prompt_maker_node +def window_replacement(prompt: str, queries: List[str], + retrieved_contents: List[List[str]], + retrieved_metadata: List[List[Dict]]) -> List[str]: + """ + Replace retrieved_contents with window to create a Prompt + (only available for corpus chunked with Sentence window method) + You must type a prompt or prompt list at config yaml file like this: + + .. Code:: yaml + nodes: + - node_type: prompt_maker + modules: + - module_type: window_replacement + prompt: [Answer this question: {query} \n\n {retrieved_contents}, + Read the passages carefully and answer this question: {query} \n\n Passages: {retrieved_contents}] + + :param prompt: A prompt string. + :param queries: List of query strings. + :param retrieved_contents: List of retrieved contents. + :param retrieved_metadata: List of retrieved metadata. + :return: Prompts that made by window_replacement. 
+ """ + + def window_replacement_row(_prompt: str, _query: str, _retrieved_contents, _retrieved_metadata: List[Dict]) -> str: + window_list = [] + for content, metadata in zip(_retrieved_contents, _retrieved_metadata): + if 'window' in metadata: + window_list.append(metadata['window']) + else: + window_list.append(content) + logger.info("If you use a summarizer, the reorder will not proceed.") + contents_str = "\n\n".join(window_list) + return _prompt.format(query=_query, retrieved_contents=contents_str) + + return list(map(lambda x: window_replacement_row(prompt, x[0], x[1], x[2]), + zip(queries, retrieved_contents, retrieved_metadata))) diff --git a/autorag/support.py b/autorag/support.py index 22089acdb..f9f33e79b 100644 --- a/autorag/support.py +++ b/autorag/support.py @@ -58,6 +58,7 @@ def get_support_modules(module_name: str) -> Callable: # prompt_maker 'fstring': ('autorag.nodes.promptmaker', 'fstring'), 'long_context_reorder': ('autorag.nodes.promptmaker', 'long_context_reorder'), + 'window_replacement': ('autorag.nodes.promptmaker', 'window_replacement'), # generator 'llama_index_llm': ('autorag.nodes.generator', 'llama_index_llm'), 'vllm': ('autorag.nodes.generator', 'vllm'), diff --git a/docs/source/api_spec/autorag.nodes.promptmaker.rst b/docs/source/api_spec/autorag.nodes.promptmaker.rst index dbaa8136c..feb44c807 100644 --- a/docs/source/api_spec/autorag.nodes.promptmaker.rst +++ b/docs/source/api_spec/autorag.nodes.promptmaker.rst @@ -36,6 +36,14 @@ autorag.nodes.promptmaker.run module :undoc-members: :show-inheritance: +autorag.nodes.promptmaker.window\_replacement module +---------------------------------------------------- + +.. 
automodule:: autorag.nodes.promptmaker.window_replacement
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Module contents
 ---------------

diff --git a/docs/source/nodes/prompt_maker/prompt_maker.md b/docs/source/nodes/prompt_maker/prompt_maker.md
index 76a593daf..dd7d6830d 100644
--- a/docs/source/nodes/prompt_maker/prompt_maker.md
+++ b/docs/source/nodes/prompt_maker/prompt_maker.md
@@ -68,4 +68,5 @@ maxdepth: 1
 ---
 fstring.md
 long_context_reorder.md
+window_replacement.md
 ```
\ No newline at end of file
diff --git a/docs/source/nodes/prompt_maker/window_replacement.md b/docs/source/nodes/prompt_maker/window_replacement.md
new file mode 100644
index 000000000..9908682f4
--- /dev/null
+++ b/docs/source/nodes/prompt_maker/window_replacement.md
@@ -0,0 +1,33 @@
+---
+myst:
+  html_meta:
+    title: AutoRAG - Window Replacement
+    description: Learn about Window Replacement module in AutoRAG
+    keywords: AutoRAG,RAG,Advanced RAG,prompt
+---
+
+# Window Replacement
+
+📌Only available for corpus chunked with `sentence window` method
+
+The `window_replacement` module is a prompt maker based
+on [llama_index](https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/MetadataReplacementDemo/).
+
+Replace retrieved_contents with window to create a Prompt. This is most useful for large documents/indexes, as it helps
+to retrieve more fine-grained details.
+
+Make a prompt using `window_replacement` from a query and retrieved_contents.
+
+## **Module Parameters**
+
+**prompt**: This is the prompt that will be input to llm. Since it is created using an fstring, it must
+contain `{query}` and `{retrieved_contents}`.
+
+## **Example config.yaml**
+
+```yaml
+modules:
+  - module_type: window_replacement
+    prompt: [ "Tell me something about the question: {query} \n\n {retrieved_contents}",
+              "Question: {query} \n Something to read: {retrieved_contents} \n What's your answer?" 
] +``` diff --git a/sample_config/full.yaml b/sample_config/full.yaml index 4e3795060..51a3744a2 100644 --- a/sample_config/full.yaml +++ b/sample_config/full.yaml @@ -140,6 +140,9 @@ node_lines: - module_type: long_context_reorder prompt: [ "Tell me something about the question: {query} \n\n {retrieved_contents}", "Question: {query} \n Something to read: {retrieved_contents} \n What's your answer?" ] + - module_type: window_replacement + prompt: [ "Tell me something about the question: {query} \n\n {retrieved_contents}", + "Question: {query} \n Something to read: {retrieved_contents} \n What's your answer?" ] - node_type: generator strategy: metrics: diff --git a/tests/autorag/nodes/promptmaker/test_prompt_maker_base.py b/tests/autorag/nodes/promptmaker/test_prompt_maker_base.py index ab6e466c3..d670b2945 100644 --- a/tests/autorag/nodes/promptmaker/test_prompt_maker_base.py +++ b/tests/autorag/nodes/promptmaker/test_prompt_maker_base.py @@ -6,8 +6,23 @@ ["Tokyo is the capital of Japan.", "Tokyo, the capital of Japan, is a huge metropolitan city."], ["Beijing is the capital of China.", "Beijing, the capital of China, is a huge metropolitan city."]] retrieve_scores = [[0.9, 0.8], [0.9, 0.8]] +retrieved_ids = [["doc1", "doc2"], ["doc3", "doc4"]] previous_result = pd.DataFrame({ "query": queries, "retrieved_contents": retrieved_contents, - "retrieve_scores": retrieve_scores + "retrieve_scores": retrieve_scores, + "retrieved_ids": retrieved_ids }) + +doc_id = ["doc1", "doc2", "doc3", "doc4", "doc5"] +contents = ["This is a test document 1.", "This is a test document 2.", "This is a test document 3.", + "This is a test document 4.", "This is a test document 5."] +metadata = [{'window': 'havertz arsenal doosan minji naeun gaeun lets go'} for _ in range(5)] +corpus_df = pd.DataFrame({"doc_id": doc_id, "contents": contents, "metadata": metadata}) + +retrieved_metadata = [ + [{'window': 'havertz arsenal doosan minji naeun gaeun lets go'}, + {'window': 'havertz arsenal 
doosan minji naeun gaeun lets go'}], + [{'window': 'havertz arsenal doosan minji naeun gaeun lets go'}, + {'window': 'havertz arsenal doosan minji naeun gaeun lets go'}] +] diff --git a/tests/autorag/nodes/promptmaker/test_window_replacement.py b/tests/autorag/nodes/promptmaker/test_window_replacement.py new file mode 100644 index 000000000..779c1896f --- /dev/null +++ b/tests/autorag/nodes/promptmaker/test_window_replacement.py @@ -0,0 +1,40 @@ +import os.path +import tempfile + +import pytest + +from autorag.nodes.promptmaker import window_replacement +from tests.autorag.nodes.promptmaker.test_prompt_maker_base import (prompt, queries, retrieved_contents, + retrieved_metadata, previous_result, corpus_df) + + +@pytest.fixture +def pseudo_project_dir(): + with tempfile.TemporaryDirectory() as project_dir: + data_dir = os.path.join(project_dir, 'data') + os.makedirs(data_dir) + corpus_df.to_parquet(os.path.join(data_dir, 'corpus.parquet')) + yield project_dir + + +def test_window_replacement(): + window_replacement_original = window_replacement.__wrapped__ + result_prompts = window_replacement_original(prompt, queries, retrieved_contents, retrieved_metadata) + assert len(result_prompts) == 2 + assert isinstance(result_prompts, list) + assert result_prompts[ + 0] == "Answer this question: What is the capital of Japan? \n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go" + assert result_prompts[ + 1] == "Answer this question: What is the capital of China? \n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go" + + +def test_window_replacement_node(pseudo_project_dir): + result = window_replacement(project_dir=pseudo_project_dir, + previous_result=previous_result, + prompt=prompt) + assert len(result) == 2 + assert result.columns == ["prompts"] + assert result['prompts'][ + 0] == "Answer this question: What is the capital of Japan? 
\n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go"
+    assert result['prompts'][
+               1] == "Answer this question: What is the capital of China? \n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go"

From afda497ba9ae7dcb880fd85925a0369c74ab69fa Mon Sep 17 00:00:00 2001
From: kimbwook
Date: Fri, 5 Jul 2024 17:55:43 +0900
Subject: [PATCH 2/2] change logger comment

---
 autorag/nodes/promptmaker/window_replacement.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/autorag/nodes/promptmaker/window_replacement.py b/autorag/nodes/promptmaker/window_replacement.py
index 02efc0583..f5def71d3 100644
--- a/autorag/nodes/promptmaker/window_replacement.py
+++ b/autorag/nodes/promptmaker/window_replacement.py
@@ -37,7 +37,8 @@ def window_replacement_row(_prompt: str, _query: str, _retrieved_contents, _retr
                 window_list.append(metadata['window'])
             else:
                 window_list.append(content)
-                logger.info("If you use a summarizer, the reorder will not proceed.")
+                logger.info("Only available for corpus chunked with Sentence window method. "
+                            "window_replacement will not proceed.")
         contents_str = "\n\n".join(window_list)
         return _prompt.format(query=_query, retrieved_contents=contents_str)