diff --git a/holmes/core/supabase_dal.py b/holmes/core/supabase_dal.py index d771086a..a9e88051 100644 --- a/holmes/core/supabase_dal.py +++ b/holmes/core/supabase_dal.py @@ -2,12 +2,16 @@ import json import logging import os +import threading from typing import Dict, Optional, List +from uuid import uuid4 import yaml +from postgrest.types import ReturnMethod from supabase import create_client from supabase.lib.client_options import ClientOptions from pydantic import BaseModel +from cachetools import TTLCache from holmes.common.env_vars import (ROBUSTA_CONFIG_PATH, ROBUSTA_ACCOUNT_ID, STORE_URL, STORE_API_KEY, STORE_EMAIL, STORE_PASSWORD) @@ -19,6 +23,7 @@ ISSUES_TABLE = "Issues" EVIDENCE_TABLE = "Evidence" RUNBOOKS_TABLE = "HolmesRunbooks" +SESSION_TOKENS_TABLE = "AuthTokens" class RobustaConfig(BaseModel): sinks_config: List[Dict[str, Dict]] @@ -42,7 +47,10 @@ def __init__(self): logging.info(f"Initializing robusta store for account {self.account_id}") options = ClientOptions(postgrest_client_timeout=SUPABASE_TIMEOUT_SECONDS) self.client = create_client(self.url, self.api_key, options) - self.sign_in() + self.user_id = self.sign_in() + ttl = int(os.environ.get("SAAS_SESSION_TOKEN_TTL_SEC", "82800")) # 23 hours + self.token_cache = TTLCache(maxsize=1, ttl=ttl) + self.lock = threading.Lock() @staticmethod def __load_robusta_config() -> Optional[RobustaToken]: @@ -87,11 +95,12 @@ def __init_config(self) -> bool: # valid only if all store parameters are provided return all([self.account_id, self.url, self.api_key, self.email, self.password]) - def sign_in(self): + def sign_in(self) -> str: logging.info("Supabase DAL login") res = self.client.auth.sign_in_with_password({"email": self.email, "password": self.password}) self.client.auth.set_session(res.session.access_token, res.session.refresh_token) self.client.postgrest.auth(res.session.access_token) + return res.user.id def get_issue_data(self, issue_id: str) -> Optional[Dict]: # TODO this could be done in a single atomic SELECT, but there is no @@ -147,6 +156,27 @@ def get_resource_instructions(self, type: str, name: str) -> List[str]: return [] + def create_session_token(self) -> str: + token = str(uuid4()) + self.client.table(SESSION_TOKENS_TABLE).insert( + { + "account_id": self.account_id, + "user_id": self.user_id, + "token": token, + "type": "HOLMES", + }, returning=ReturnMethod.minimal # must use this, because the user cannot read this table + ).execute() + return token + + def get_ai_credentials(self) -> (str, str): + with self.lock: + session_token = self.token_cache.get("session_token") + if not session_token: + session_token = self.create_session_token() + self.token_cache["session_token"] = session_token + + return self.account_id, session_token + def get_workload_issues(self, resource: dict, since_hours: float) -> List[str]: if not self.enabled or not resource: return [] diff --git a/holmes/core/tool_calling_llm.py b/holmes/core/tool_calling_llm.py index 6cec1ce7..f03fd450 100644 --- a/holmes/core/tool_calling_llm.py +++ b/holmes/core/tool_calling_llm.py @@ -86,7 +86,7 @@ def check_llm(self, model, api_key): #if not litellm.supports_function_calling(model=model): # raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.") def get_context_window_size(self) -> int: - return litellm.model_cost[self.model]['max_input_tokens'] + return litellm.model_cost[self.model]['max_input_tokens'] def count_tokens_for_message(self, messages: list[dict]) -> int: return litellm.token_counter(model=self.model, diff --git a/poetry.lock b/poetry.lock index b9b8d53b..8e970156 100644 --- a/poetry.lock +++ b/poetry.lock @@ -226,6 +226,17 @@ urllib3 = [ [package.extras] crt = ["awscrt (==0.20.11)"] +[[package]] +name = "cachetools" +version = "5.5.0" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"}, + {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"}, +] + [[package]] name = "certifi" version = "2024.7.4" @@ -2680,4 +2691,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "202b5a7f3db3f4a664085e2dab801f3117c30755ac0af380561409f3d040b867" +content-hash = "6ed6b5adf1a5d796990b8b0d3eb0df0ab7cd113663f035997b0000b414aea16f" diff --git a/pyproject.toml b/pyproject.toml index fdc5bb27..ecfe8e00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ urllib3 = "^1.26.19" boto3 = "^1.34.145" setuptools = "^72.1.0" aiohttp = "^3.10.2" +cachetools = "^5.5.0" [build-system] requires = ["poetry-core"] diff --git a/server.py b/server.py index a7e77c3e..4effd853 100644 --- a/server.py +++ b/server.py @@ -6,7 +6,9 @@ print("added custom certificate") # DO NOT ADD ANY IMPORTS OR CODE ABOVE THIS LINE -# IMPORTING ABOVE MIGHT INITIALIZE AN HTTPS CLIENT THAT DOESN'T TRUST THE CUSTOM CERTIFICATEE +# IMPORTING ABOVE MIGHT INITIALIZE AN HTTPS CLIENT THAT DOESN'T TRUST THE CUSTOM CERTIFICATE + + import jinja2 import logging import uvicorn @@ -15,7 +17,7 @@ from typing import Dict, Callable from litellm.exceptions import AuthenticationError from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +from pydantic import SecretStr from rich.console import Console from holmes.common.env_vars import ( @@ -40,7 +42,6 @@ ) from holmes.plugins.prompts import load_and_render_prompt from holmes.core.tool_calling_llm import ToolCallingLLM -import jinja2 def init_logging(): @@ -69,9 +70,16 @@ def init_logging(): config = Config.load_from_env() +def load_robusta_api_key(): + if os.environ.get("ROBUSTA_AI"): + account_id, token = dal.get_ai_credentials() + config.api_key = SecretStr(f"{account_id} {token}") + + @app.post("/api/investigate") def investigate_issues(investigate_request: InvestigateRequest): try: + load_robusta_api_key() context = dal.get_issue_data( investigate_request.context.get("robusta_issue_id") ) @@ -112,7 +120,7 @@ def investigate_issues(investigate_request: InvestigateRequest): @app.post("/api/workload_health_check") def workload_health_check(request: WorkloadHealthRequest): - + load_robusta_api_key() try: resource = request.resource workload_alerts: list[str] = [] @@ -149,6 +157,7 @@ def workload_health_check(request: WorkloadHealthRequest): def handle_issue_conversation( conversation_request: ConversationRequest, ai: ToolCallingLLM ): + load_robusta_api_key() context_window = ai.get_context_window_size() number_of_tools = len( conversation_request.context.investigation_result.tools @@ -240,6 +249,7 @@ def handle_issue_conversation( @app.post("/api/conversation") def converstation(conversation_request: ConversationRequest): try: + load_robusta_api_key() ai = config.create_toolcalling_llm(console, allowed_toolsets=ALLOWED_TOOLSETS) handler = conversation_type_handlers.get(conversation_request.conversation_type)