diff --git a/holmes/core/investigation.py b/holmes/core/investigation.py
index c7d7879..3b1407f 100644
--- a/holmes/core/investigation.py
+++ b/holmes/core/investigation.py
@@ -1,8 +1,7 @@
-from typing import Optional
-from rich.console import Console
 from holmes.common.env_vars import HOLMES_POST_PROCESSING_PROMPT
 from holmes.config import Config
+from holmes.core.investigation_structured_output import process_response_into_sections
 from holmes.core.issue import Issue
 from holmes.core.models import InvestigateRequest, InvestigationResult
 from holmes.core.supabase_dal import SupabaseDal
 
@@ -36,13 +35,15 @@ def investigate_issues(investigate_request: InvestigateRequest, dal: SupabaseDal
         issue,
         prompt=investigate_request.prompt_template,
         post_processing_prompt=HOLMES_POST_PROCESSING_PROMPT,
-        sections=investigate_request.sections,
         instructions=resource_instructions,
         global_instructions=global_instructions
     )
+
+    (text_response, sections) = process_response_into_sections(investigation.result)
+
     return InvestigationResult(
-        analysis=investigation.result,
-        sections=investigation.sections,
+        analysis=text_response,
+        sections=sections,
         tool_calls=investigation.tool_calls or [],
         instructions=investigation.instructions,
     )
diff --git a/holmes/core/investigation_structured_output.py b/holmes/core/investigation_structured_output.py
index e47d28e..abdf6e8 100644
--- a/holmes/core/investigation_structured_output.py
+++ b/holmes/core/investigation_structured_output.py
@@ -1,6 +1,15 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Tuple, Union
+import json
 
-DEFAULT_SECTIONS = {
+from pydantic import RootModel
+
+InputSectionsDataType = Dict[str, str]
+
+OutputSectionsDataType = Optional[Dict[str, Union[str, None]]]
+
+SectionsData = RootModel[OutputSectionsDataType]
+
+DEFAULT_SECTIONS:InputSectionsDataType = {
     "Alert Explanation": "1-2 sentences explaining the alert itself - note don't say \"The alert indicates a warning event related to a Kubernetes pod doing blah\" rather just say \"The pod XYZ did blah\" because that is what the user actually cares about",
     "Investigation": "What you checked and found",
     "Conclusions and Possible Root causes": "What conclusions can you reach based on the data you found? what are possible root causes (if you have enough conviction to say) or what uncertainty remains. Don't say root cause but 'possible root causes'. Be clear to distinguish between what you know for certain and what is a possible explanation",
@@ -10,7 +19,7 @@
     "External links": "Provide links to external sources. Where to look when investigating this issue. For example provide links to relevant runbooks, etc. Add a short sentence describing each link."
 }
 
 
-def get_output_format_for_investigation(sections: Dict[str, str]) -> Dict[str, Any]:
+def get_output_format_for_investigation(sections: InputSectionsDataType) -> Dict[str, Any]:
     properties = {}
     required_fields = []
@@ -34,12 +43,32 @@ def get_output_format_for_investigation(sections: Dict[str, str]) -> Dict[str, A
 
     return output_format
 
-def combine_sections(sections: Any) -> str:
-    if isinstance(sections, dict):
-        content = ''
-        for section_title, section_content in sections.items():
-            if section_content:
-                # content = content + f'\n# {" ".join(section_title.split("_")).title()}\n{section_content}'
-                content = content + f'\n# {section_title}\n{section_content}\n'
-        return content
-    return f"{sections}"
+def combine_sections(sections: Dict) -> str:
+    content = ''
+    for section_title, section_content in sections.items():
+        if section_content:
+            content = content + f'\n# {section_title}\n{section_content}\n'
+    return content
+
+
+def process_response_into_sections(response: Any) -> Tuple[str, OutputSectionsDataType]:
+    if isinstance(response, dict):
+        # Even if the result is already structured, run it through the code below to validate the JSON
+        response = json.dumps(response)
+
+    if not isinstance(response, str):
+        # If it's not a string, convert it so it can be parsed below
+        response = str(response)
+
+
+    try:
+        parsed_json = json.loads(response)
+        # TODO: forcing dict values into strings would make this more resilient, as SectionsData only accepts None/str values
+        sections = SectionsData(root=parsed_json).root
+        if sections:
+            combined = combine_sections(sections)
+            return (combined, sections)
+    except Exception:
+        pass
+
+    return (response, None)
diff --git a/holmes/core/models.py b/holmes/core/models.py
index 71683cf..1a8ed0e 100644
--- a/holmes/core/models.py
+++ b/holmes/core/models.py
@@ -1,3 +1,4 @@
+from holmes.core.investigation_structured_output import InputSectionsDataType
 from holmes.core.tool_calling_llm import ToolCallResult
 from typing import Optional, List, Dict, Any, Union
 from pydantic import BaseModel, model_validator
@@ -21,7 +22,7 @@ class InvestigateRequest(BaseModel):
     include_tool_calls: bool = False
     include_tool_call_results: bool = False
     prompt_template: str = "builtin://generic_investigation.jinja2"
-    sections: Optional[Dict[str, str]] = None
+    sections: Optional[InputSectionsDataType] = None
     # TODO in the future
     # response_handler: ...
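
Note (not part of the patch): a minimal sketch of how the new process_response_into_sections helper behaves, assuming the holmes package is importable; the sample payloads below are made up.

    import json

    from holmes.core.investigation_structured_output import process_response_into_sections

    # Structured LLM output: the JSON is validated through SectionsData and
    # flattened into markdown-style "# <section title>" blocks; sections with
    # empty content are skipped by combine_sections.
    structured = json.dumps({
        "Alert Explanation": "The pod was OOMKilled.",
        "Investigation": "Checked pod events and the container memory limits.",
        "External links": None,
    })
    text, sections = process_response_into_sections(structured)

    # Plain text that is not valid JSON is returned unchanged, with sections=None.
    text, sections = process_response_into_sections("The pod was OOMKilled.")
    assert sections is None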
diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py
index b0c1a44..2183409 100644
--- a/holmes/core/tool_calling_llm.py
+++ b/holmes/core/tool_calling_llm.py
@@ -3,7 +3,7 @@
 import logging
 import textwrap
 from typing import List, Optional, Dict, Type, Union
-from holmes.core.investigation_structured_output import DEFAULT_SECTIONS, get_output_format_for_investigation, combine_sections
+from holmes.core.investigation_structured_output import DEFAULT_SECTIONS, InputSectionsDataType, get_output_format_for_investigation
 from holmes.core.performance_timing import PerformanceTiming
 from holmes.utils.tags import format_tags_in_string, parse_messages_tags
 from holmes.plugins.prompts import load_and_render_prompt
@@ -27,14 +27,11 @@ class ToolCallResult(BaseModel):
     description: str
     result: str
 
-
 class LLMResult(BaseModel):
     tool_calls: Optional[List[ToolCallResult]] = None
-    sections: Optional[Dict[str, Union[str, None]]] = None
     result: Optional[str] = None
     unprocessed_result: Optional[str] = None
     instructions: List[str] = []
-    # TODO: clean up these two
     prompt: Optional[str] = None
     messages: Optional[List[dict]] = None
 
@@ -159,22 +156,12 @@ def call(
             tools_to_call = getattr(response_message, "tool_calls", None)
             text_response = response_message.content
 
-            sections:Optional[Dict[str, str]] = None
-            if isinstance(text_response, str):
-                try:
-                    parsed_json = json.loads(text_response)
-                    text_response = parsed_json
-                except json.JSONDecodeError:
-                    pass
-                if not isinstance(text_response, str):
-                    sections = text_response
-                    text_response = combine_sections(sections)
-
             if not tools_to_call:
                 # For chatty models post process and summarize the result
                 # this only works for calls where user prompt is explicitly passed through
                 if post_process_prompt and user_prompt:
-                    logging.info(f"Running post processing on investigation.")
+                    logging.info("Running post processing on investigation.")
                     raw_response = text_response
                     post_processed_response = self._post_processing_call(
                         prompt=user_prompt,
@@ -185,7 +172,6 @@
                     perf_timing.end()
                     return LLMResult(
                         result=post_processed_response,
-                        sections=sections,
                         unprocessed_result=raw_response,
                         tool_calls=tool_calls,
                         prompt=json.dumps(messages, indent=2),
@@ -195,7 +181,6 @@
                 perf_timing.end()
                 return LLMResult(
                     result=text_response,
-                    sections=sections,
                     tool_calls=tool_calls,
                     prompt=json.dumps(messages, indent=2),
                     messages=messages,
@@ -231,7 +216,6 @@ def _invoke_tool(
             logging.warning(
                 f"Failed to parse arguments for tool: {tool_name}. args: {tool_to_call.function.arguments}"
            )
-
        tool_call_id = tool_to_call.id
        tool = self.tool_executor.get_tool_by_name(tool_name)
@@ -358,7 +342,7 @@ def investigate(
         console: Optional[Console] = None,
         global_instructions: Optional[Instructions] = None,
         post_processing_prompt: Optional[str] = None,
-        sections: Optional[Dict[str, str]] = None
+        sections: Optional[InputSectionsDataType] = None
     ) -> LLMResult:
         runbooks = self.runbook_manager.get_instructions_for_issue(issue)
diff --git a/server.py b/server.py
index 8df2261..15e8d46 100644
--- a/server.py
+++ b/server.py
@@ -148,7 +148,6 @@ def workload_health_check(request: WorkloadHealthRequest):
 
         system_prompt = load_and_render_prompt(request.prompt_template, context={'alerts': workload_alerts})
 
-
         ai = config.create_toolcalling_llm(dal=dal)
 
         structured_output = {"type": "json_object"}
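
Note (not part of the patch): a sketch of how a caller might supply custom sections, assuming the holmes package is importable. The section titles below are illustrative; any Dict[str, str] mapping a title to a description of its expected content fits the InputSectionsDataType alias introduced above, and the import of DEFAULT_SECTIONS in tool_calling_llm.py suggests the investigate flow falls back to it when no sections are requested.

    from holmes.core.investigation_structured_output import (
        DEFAULT_SECTIONS,
        get_output_format_for_investigation,
    )

    # Titles become the keys the model is asked to fill in; the descriptions
    # tell it what belongs in each section.
    custom_sections = {
        "Alert Explanation": "One or two sentences on what actually happened",
        "Remediation": "Concrete follow-up steps, if any",
    }

    # Presumably this format is handed to the LLM so that its answer contains
    # one entry per requested section title.
    output_format = get_output_format_for_investigation(custom_sections or DEFAULT_SECTIONS)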