Add window_replacement module at prompt_maker node #559

Merged 2 commits on Jul 8, 2024
1 change: 1 addition & 0 deletions autorag/nodes/promptmaker/__init__.py
@@ -1,2 +1,3 @@
from .fstring import fstring
from .long_context_reorder import long_context_reorder
+from .window_replacement import window_replacement
11 changes: 10 additions & 1 deletion autorag/nodes/promptmaker/base.py
@@ -1,11 +1,12 @@
import functools
import logging
+import os
from pathlib import Path
from typing import List, Union

import pandas as pd

-from autorag.utils import result_to_dataframe
+from autorag.utils import result_to_dataframe, fetch_contents

logger = logging.getLogger("AutoRAG")

@@ -31,6 +32,14 @@ def wrapper(
assert "retrieve_scores" in previous_result.columns, "previous_result must have retrieve_scores column."
retrieve_scores = previous_result["retrieve_scores"].tolist()
return func(prompt, query, retrieved_contents, retrieve_scores)
elif func.__name__ == 'window_replacement':
retrieved_ids = previous_result["retrieved_ids"].tolist()
# load corpus
data_dir = os.path.join(project_dir, "data")
corpus_data = pd.read_parquet(os.path.join(data_dir, "corpus.parquet"), engine='pyarrow')
# get metadata from corpus
retrieved_metadata = fetch_contents(corpus_data, retrieved_ids, column_name='metadata')
vkehfdl1 marked this conversation as resolved.
Show resolved Hide resolved
return func(prompt, query, retrieved_contents, retrieved_metadata)
else:
raise NotImplementedError(f"Module {func.__name__} is not implemented or not supported.")

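
For context, a minimal sketch of what this new branch does at runtime, assuming a standard AutoRAG project layout with `data/corpus.parquet`. It mirrors the calls added above; the project directory and ids are hypothetical, and the decorator plumbing is omitted:

```python
import os

import pandas as pd

from autorag.utils import fetch_contents

project_dir = "./my_project"  # hypothetical project directory
retrieved_ids = [["doc1", "doc2"], ["doc3", "doc4"]]  # from previous_result["retrieved_ids"]

# Load the corpus that was ingested into the project.
corpus_data = pd.read_parquet(os.path.join(project_dir, "data", "corpus.parquet"), engine="pyarrow")

# Look up each retrieved id's metadata dict; window_replacement later reads
# metadata["window"] from these entries.
retrieved_metadata = fetch_contents(corpus_data, retrieved_ids, column_name="metadata")
```
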
46 changes: 46 additions & 0 deletions autorag/nodes/promptmaker/window_replacement.py
@@ -0,0 +1,46 @@
import logging
from typing import List, Dict

from autorag.nodes.promptmaker.base import prompt_maker_node

logger = logging.getLogger("AutoRAG")


@prompt_maker_node
def window_replacement(prompt: str, queries: List[str],
                       retrieved_contents: List[List[str]],
                       retrieved_metadata: List[List[Dict]]) -> List[str]:
    """
    Replace retrieved_contents with their window metadata to create prompts
    (only available for a corpus chunked with the sentence window method).
    You must specify a prompt or a list of prompts in the config YAML file, like this:

    .. code:: yaml

        nodes:
        - node_type: prompt_maker
          modules:
          - module_type: window_replacement
            prompt: [Answer this question: {query} \n\n {retrieved_contents},
              Read the passages carefully and answer this question: {query} \n\n Passages: {retrieved_contents}]

    :param prompt: A prompt string.
    :param queries: List of query strings.
    :param retrieved_contents: List of retrieved contents.
    :param retrieved_metadata: List of retrieved metadata.
    :return: Prompts made by window_replacement.
    """

    def window_replacement_row(_prompt: str, _query: str, _retrieved_contents: List[str],
                               _retrieved_metadata: List[Dict]) -> str:
        window_list = []
        for content, metadata in zip(_retrieved_contents, _retrieved_metadata):
            if 'window' in metadata:
                window_list.append(metadata['window'])
            else:
                window_list.append(content)
                logger.info("No 'window' metadata found for a retrieved content; using the content itself. "
                            "window_replacement is only available for a corpus chunked with the sentence window method.")
        contents_str = "\n\n".join(window_list)
        return _prompt.format(query=_query, retrieved_contents=contents_str)

    return list(map(lambda x: window_replacement_row(prompt, x[0], x[1], x[2]),
                    zip(queries, retrieved_contents, retrieved_metadata)))
1 change: 1 addition & 0 deletions autorag/support.py
@@ -58,6 +58,7 @@ def get_support_modules(module_name: str) -> Callable:
    # prompt_maker
    'fstring': ('autorag.nodes.promptmaker', 'fstring'),
    'long_context_reorder': ('autorag.nodes.promptmaker', 'long_context_reorder'),
+    'window_replacement': ('autorag.nodes.promptmaker', 'window_replacement'),
    # generator
    'llama_index_llm': ('autorag.nodes.generator', 'llama_index_llm'),
    'vllm': ('autorag.nodes.generator', 'vllm'),
8 changes: 8 additions & 0 deletions docs/source/api_spec/autorag.nodes.promptmaker.rst
@@ -36,6 +36,14 @@ autorag.nodes.promptmaker.run module
   :undoc-members:
   :show-inheritance:

+autorag.nodes.promptmaker.window\_replacement module
+----------------------------------------------------
+
+.. automodule:: autorag.nodes.promptmaker.window_replacement
+   :members:
+   :undoc-members:
+   :show-inheritance:

Module contents
---------------

1 change: 1 addition & 0 deletions docs/source/nodes/prompt_maker/prompt_maker.md
@@ -68,4 +68,5 @@ maxdepth: 1
---
fstring.md
long_context_reorder.md
+window_replacement.md
```
33 changes: 33 additions & 0 deletions docs/source/nodes/prompt_maker/window_replacement.md
@@ -0,0 +1,33 @@
---
myst:
html_meta:
title: AutoRAG - Window Replacement
description: Learn about Window Replacement module in AutoRAG
keywords: AutoRAG,RAG,Advanced RAG,prompt
---

# Window Replacement

📌 Only available for a corpus chunked with the `sentence window` method

The `window_replacement` module is a prompt maker based
on [llama_index](https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/MetadataReplacementDemo/).

It replaces each retrieved content with its surrounding sentence window when building the prompt. This is most useful
for large documents/indexes, as it helps to retrieve fine-grained details while still giving the LLM the surrounding
context.

It makes a prompt from a query and `retrieved_contents`, as sketched below.
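
A minimal sketch of the replacement logic, mirroring the module's implementation; the metadata dicts and window text here are illustrative only:

```python
# Use the window when present, otherwise fall back to the raw content.
query = "What is the capital of Japan?"
retrieved_contents = ["Tokyo is the capital of Japan."]
retrieved_metadata = [{"window": "... Tokyo is the capital of Japan. It is a huge metropolitan city. ..."}]

prompt = "Answer this question: {query} \n\n {retrieved_contents}"

windows = [meta.get("window", content)
           for content, meta in zip(retrieved_contents, retrieved_metadata)]
print(prompt.format(query=query, retrieved_contents="\n\n".join(windows)))
```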

## **Module Parameters**

**prompt**: This is the prompt that will be input to the LLM. Since it is formatted as an f-string, it must
contain `{query}` and `{retrieved_contents}`.

## **Example config.yaml**

```yaml
modules:
- module_type: window_replacement
prompt: [ "Tell me something about the question: {query} \n\n {retrieved_contents}",
"Question: {query} \n Something to read: {retrieved_contents} \n What's your answer?" ]
```
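
The module looks up each retrieved document's `metadata` in the project's `data/corpus.parquet` and uses its `window` value; rows without one fall back to the raw contents and log an info message. As a sketch, assuming the corpus schema from this PR's tests (`doc_id`, `contents`, `metadata`), a compatible corpus could be written like this:

```python
import pandas as pd

# Hypothetical two-row corpus; the "window" strings are illustrative.
corpus_df = pd.DataFrame({
    "doc_id": ["doc1", "doc2"],
    "contents": ["Tokyo is the capital of Japan.", "Beijing is the capital of China."],
    "metadata": [
        {"window": "... Tokyo is the capital of Japan. It is a huge metropolitan city. ..."},
        {"window": "... Beijing is the capital of China. It is a huge metropolitan city. ..."},
    ],
})
corpus_df.to_parquet("data/corpus.parquet")  # the path AutoRAG reads under the project dir
```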
3 changes: 3 additions & 0 deletions sample_config/full.yaml
@@ -140,6 +140,9 @@ node_lines:
        - module_type: long_context_reorder
          prompt: [ "Tell me something about the question: {query} \n\n {retrieved_contents}",
                    "Question: {query} \n Something to read: {retrieved_contents} \n What's your answer?" ]
+        - module_type: window_replacement
+          prompt: [ "Tell me something about the question: {query} \n\n {retrieved_contents}",
+                    "Question: {query} \n Something to read: {retrieved_contents} \n What's your answer?" ]
    - node_type: generator
      strategy:
        metrics:
17 changes: 16 additions & 1 deletion tests/autorag/nodes/promptmaker/test_prompt_maker_base.py
@@ -6,8 +6,23 @@
["Tokyo is the capital of Japan.", "Tokyo, the capital of Japan, is a huge metropolitan city."],
["Beijing is the capital of China.", "Beijing, the capital of China, is a huge metropolitan city."]]
retrieve_scores = [[0.9, 0.8], [0.9, 0.8]]
retrieved_ids = [["doc1", "doc2"], ["doc3", "doc4"]]
previous_result = pd.DataFrame({
"query": queries,
"retrieved_contents": retrieved_contents,
"retrieve_scores": retrieve_scores
"retrieve_scores": retrieve_scores,
"retrieved_ids": retrieved_ids
})

doc_id = ["doc1", "doc2", "doc3", "doc4", "doc5"]
contents = ["This is a test document 1.", "This is a test document 2.", "This is a test document 3.",
"This is a test document 4.", "This is a test document 5."]
metadata = [{'window': 'havertz arsenal doosan minji naeun gaeun lets go'} for _ in range(5)]
corpus_df = pd.DataFrame({"doc_id": doc_id, "contents": contents, "metadata": metadata})

retrieved_metadata = [
[{'window': 'havertz arsenal doosan minji naeun gaeun lets go'},
{'window': 'havertz arsenal doosan minji naeun gaeun lets go'}],
[{'window': 'havertz arsenal doosan minji naeun gaeun lets go'},
{'window': 'havertz arsenal doosan minji naeun gaeun lets go'}]
vkehfdl1 marked this conversation as resolved.
Show resolved Hide resolved
]
40 changes: 40 additions & 0 deletions tests/autorag/nodes/promptmaker/test_window_replacement.py
@@ -0,0 +1,40 @@
import os.path
import tempfile

import pytest

from autorag.nodes.promptmaker import window_replacement
from tests.autorag.nodes.promptmaker.test_prompt_maker_base import (prompt, queries, retrieved_contents,
                                                                    retrieved_metadata, previous_result, corpus_df)


@pytest.fixture
def pseudo_project_dir():
    with tempfile.TemporaryDirectory() as project_dir:
        data_dir = os.path.join(project_dir, 'data')
        os.makedirs(data_dir)
        corpus_df.to_parquet(os.path.join(data_dir, 'corpus.parquet'))
        yield project_dir


def test_window_replacement():
    window_replacement_original = window_replacement.__wrapped__
    result_prompts = window_replacement_original(prompt, queries, retrieved_contents, retrieved_metadata)
    assert len(result_prompts) == 2
    assert isinstance(result_prompts, list)
    assert result_prompts[0] == "Answer this question: What is the capital of Japan? \n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go"
    assert result_prompts[1] == "Answer this question: What is the capital of China? \n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go"


def test_window_replacement_node(pseudo_project_dir):
    result = window_replacement(project_dir=pseudo_project_dir,
                                previous_result=previous_result,
                                prompt=prompt)
    assert len(result) == 2
    assert result.columns.tolist() == ["prompts"]
    assert result['prompts'][0] == "Answer this question: What is the capital of Japan? \n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go"
    assert result['prompts'][1] == "Answer this question: What is the capital of China? \n\n havertz arsenal doosan minji naeun gaeun lets go\n\nhavertz arsenal doosan minji naeun gaeun lets go"