diff --git a/sweepai/agents/query_filter_agent.py b/sweepai/agents/query_filter_agent.py new file mode 100644 index 0000000000..55700f6b7a --- /dev/null +++ b/sweepai/agents/query_filter_agent.py @@ -0,0 +1,22 @@ +from sweepai.core.prompts import doc_query_rewriter_prompt +from sweepai.utils.openai_proxy import OpenAIProxy +from sweepai.watch import logger + + +class QueryFilterAgent: + def __init__(self): + self.openai_proxy = OpenAIProxy() + + def filter_query(self, search_query: str) -> str: + try: + prompt = doc_query_rewriter_prompt.format(issue=search_query) + filtered_query = self.openai_proxy.call_openai( + model="gpt-3.5-turbo", + messages=[{"role": "system", "content": prompt}], + max_tokens=60, + temperature=0.0 + ) + return filtered_query + except Exception as e: + logger.error(f"Error filtering query: {e}") + raise e diff --git a/sweepai/utils/ticket_utils.py b/sweepai/utils/ticket_utils.py index 48ec020c59..17ce48284d 100644 --- a/sweepai/utils/ticket_utils.py +++ b/sweepai/utils/ticket_utils.py @@ -15,6 +15,7 @@ from sweepai.utils.event_logger import posthog from sweepai.utils.github_utils import ClonedRepo from sweepai.utils.progress import TicketProgress +from sweepai.agents.query_filter_agent import QueryFilterAgent @file_cache() def get_top_k_snippets( @@ -109,6 +110,7 @@ def fetch_relevant_files( logger.info("Fetching relevant files...") try: search_query = (title + summary + replies_text).strip("\n") + search_query = QueryFilterAgent().filter_query(search_query) replies_text = f"\n{replies_text}" if replies_text else "" formatted_query = (f"{title.strip()}\n{summary.strip()}" + replies_text).strip( "\n"