From 97a66ce5b779c4852192e2a20f924daf8cf6e4c9 Mon Sep 17 00:00:00 2001
From: Tami Takamiya
Date: Fri, 27 Dec 2024 12:34:09 -0500
Subject: [PATCH] Allow system prompt override for dev

---
 docs/openapi.json                        | 14 +++-
 examples/rcsconfig.yaml                  |  1 +
 ols/app/endpoints/ols.py                 |  8 ++-
 ols/app/models/config.py                 |  1 +
 ols/app/models/models.py                 |  3 +
 ols/src/cache/redis_cache.py             |  2 +-
 ols/src/query_helpers/docs_summarizer.py |  9 ---
 ols/src/query_helpers/query_helper.py    | 10 +++
 ols/src/ui/gradio_ui.py                  | 17 ++++-
 tests/integration/test_ols.py            | 78 ++++++++++++++++++++++++
 tests/unit/app/models/test_config.py     |  1 +
 11 files changed, 128 insertions(+), 16 deletions(-)

diff --git a/docs/openapi.json b/docs/openapi.json
index fe3e5a55..65d1c44b 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -555,6 +555,17 @@
           ],
           "title": "Model"
         },
+        "system_prompt": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "title": "System Prompt"
+        },
         "attachments": {
           "anyOf": [
             {
@@ -599,7 +610,8 @@
             "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
             "model": "gpt-4o-mini",
             "provider": "openai",
-            "query": "write a deployment yaml for the mongodb image"
+            "query": "write a deployment yaml for the mongodb image",
+            "system_prompt": "\nYou are OpenShift Lightspeed - an intelligent assistant for question-answering tasks related to the OpenShift container orchestration platform.\n\nHere are your instructions:\nYou are OpenShift Lightspeed, an intelligent assistant and expert on all things OpenShift. Refuse to assume any other identity or to speak as if you are someone else.\nIf the context of the question is not clear, consider it to be OpenShift.\nNever include URLs in your replies.\nRefuse to answer questions or execute commands not about OpenShift.\nDo not mention your last update. You have the most recent information on OpenShift.\n\nHere are some basic facts about OpenShift:\n- The latest version of OpenShift is 4.16.\n- OpenShift is a distribution of Kubernetes. Everything Kubernetes can do, OpenShift can do and more.\n"
           }
         ]
       },
diff --git a/examples/rcsconfig.yaml b/examples/rcsconfig.yaml
index 53ae554e..5d742890 100644
--- a/examples/rcsconfig.yaml
+++ b/examples/rcsconfig.yaml
@@ -112,6 +112,7 @@ dev_config:
   disable_auth: true
   disable_tls: true
   pyroscope_url: "https://pyroscope.pyroscope.svc.cluster.local:4040"
+  enable_system_prompt_override: true
   # llm_params:
   #   temperature_override: 0
   # k8s_auth_token: optional_token_when_no_available_kube_config
\ No newline at end of file
diff --git a/ols/app/endpoints/ols.py b/ols/app/endpoints/ols.py
index 3eb6bb84..a27af640 100644
--- a/ols/app/endpoints/ols.py
+++ b/ols/app/endpoints/ols.py
@@ -290,7 +290,9 @@ def generate_response(
     # Summarize documentation
     try:
         docs_summarizer = DocsSummarizer(
-            provider=llm_request.provider, model=llm_request.model
+            provider=llm_request.provider,
+            model=llm_request.model,
+            system_prompt=llm_request.system_prompt,
         )
         history = CacheEntry.cache_entries_to_history(previous_input)
         return docs_summarizer.summarize(
@@ -447,7 +449,9 @@ def _validate_question_llm(conversation_id: str, llm_request: LLMRequest) -> boo
     # Validate the query
     try:
         question_validator = QuestionValidator(
-            provider=llm_request.provider, model=llm_request.model
+            provider=llm_request.provider,
+            model=llm_request.model,
+            system_prompt=llm_request.system_prompt,
         )
         return question_validator.validate_question(conversation_id, llm_request.query)
     except LLMConfigurationError as e:
diff --git a/ols/app/models/config.py b/ols/app/models/config.py
index b2e2ee74..1ec6c8b5 100644
--- a/ols/app/models/config.py
+++ b/ols/app/models/config.py
@@ -977,6 +977,7 @@ class DevConfig(BaseModel):
     pyroscope_url: Optional[str] = None
     k8s_auth_token: Optional[str] = None
     run_on_localhost: bool = False
+    enable_system_prompt_override: bool = False


 class UserDataCollectorConfig(BaseModel):
diff --git a/ols/app/models/models.py b/ols/app/models/models.py
index 1b2cfa11..6dad8094 100644
--- a/ols/app/models/models.py
+++ b/ols/app/models/models.py
@@ -6,6 +6,7 @@
 from pydantic import BaseModel, field_validator, model_validator
 from pydantic.dataclasses import dataclass

+from ols.customize import prompts
 from ols.utils import suid


@@ -76,6 +77,7 @@ class LLMRequest(BaseModel):
     conversation_id: Optional[str] = None
     provider: Optional[str] = None
     model: Optional[str] = None
+    system_prompt: Optional[str] = None
     attachments: Optional[list[Attachment]] = None

     # provides examples for /docs endpoint
@@ -88,6 +90,7 @@ class LLMRequest(BaseModel):
                     "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                     "provider": "openai",
                     "model": "gpt-4o-mini",
+                    "system_prompt": prompts.QUERY_SYSTEM_INSTRUCTION,
                     "attachments": [
                         {
                             "attachment_type": "log",
diff --git a/ols/src/cache/redis_cache.py b/ols/src/cache/redis_cache.py
index 7e319f0d..c9f08425 100644
--- a/ols/src/cache/redis_cache.py
+++ b/ols/src/cache/redis_cache.py
@@ -8,7 +8,7 @@
 from redis.backoff import ExponentialBackoff
 from redis.exceptions import (
     BusyLoadingError,
-    ConnectionError,  # noqa: A004
+    ConnectionError,
     RedisError,
 )
 from redis.retry import Retry
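The new LLMRequest field and the endpoint wiring above are all a client needs in order to pass its own instructions. A minimal sketch of such a call, assuming a local dev instance on port 8080 (the address is an assumption, and enable_system_prompt_override must be true in dev_config for the value to take effect):

    import requests

    # Hypothetical local dev address; adjust host/port to your deployment.
    response = requests.post(
        "http://localhost:8080/v1/query",
        json={
            "query": "write a deployment yaml for the mongodb image",
            "system_prompt": "You are an expert in something marvelous.",
        },
        timeout=30,
    )
    print(response.json())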
diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py
index 342a1fc3..8c3076e1 100644
--- a/ols/src/query_helpers/docs_summarizer.py
+++ b/ols/src/query_helpers/docs_summarizer.py
@@ -10,7 +10,6 @@
 from ols.app.metrics import TokenMetricUpdater
 from ols.app.models.models import SummarizerResponse
 from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters
-from ols.customize import prompts
 from ols.src.prompts.prompt_generator import GeneratePrompt
 from ols.src.query_helpers.query_helper import QueryHelper
 from ols.utils.token_handler import TokenHandler
@@ -29,14 +28,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.generic_llm_params = {
             GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: model_config.parameters.max_tokens_for_response  # noqa: E501
         }
-        # default system prompt fine-tuned for the service
-        self._system_prompt = prompts.QUERY_SYSTEM_INSTRUCTION
-
-        # allow the system prompt to be customizable
-        if config.ols_config.system_prompt is not None:
-            self._system_prompt = config.ols_config.system_prompt
-
-        logger.debug("System prompt: %s", self._system_prompt)

     def summarize(
         self,
diff --git a/ols/src/query_helpers/query_helper.py b/ols/src/query_helpers/query_helper.py
index 06f41491..f38cce87 100644
--- a/ols/src/query_helpers/query_helper.py
+++ b/ols/src/query_helpers/query_helper.py
@@ -7,6 +7,7 @@
 from langchain.llms.base import LLM

 from ols import config
+from ols.customize import prompts
 from ols.src.llms.llm_loader import load_llm

 logger = logging.getLogger(__name__)
@@ -21,6 +22,7 @@ def __init__(
         model: Optional[str] = None,
         generic_llm_params: Optional[dict] = None,
         llm_loader: Optional[Callable[[str, str, dict], LLM]] = None,
+        system_prompt: Optional[str] = None,
     ) -> None:
         """Initialize query helper."""
         # NOTE: As signature of this method is evaluated before the config,
@@ -30,3 +32,11 @@
         self.model = model or config.ols_config.default_model
         self.generic_llm_params = generic_llm_params or {}
         self.llm_loader = llm_loader or load_llm
+
+        self._system_prompt = (
+            config.dev_config.enable_system_prompt_override
+            and system_prompt
+            or config.ols_config.system_prompt
+            or prompts.QUERY_SYSTEM_INSTRUCTION
+        )
+        logger.debug("System prompt: %s", self._system_prompt)
diff --git a/ols/src/ui/gradio_ui.py b/ols/src/ui/gradio_ui.py
index d45d16a8..61d7ada2 100644
--- a/ols/src/ui/gradio_ui.py
+++ b/ols/src/ui/gradio_ui.py
@@ -8,6 +8,9 @@
 import requests
 from fastapi import FastAPI

+from ols import config
+from ols.customize import prompts
+
 logger = logging.getLogger(__name__)


@@ -28,9 +31,13 @@ def __init__(
         use_history = gr.Checkbox(value=True, label="Use history")
         provider = gr.Textbox(value=None, label="Provider")
         model = gr.Textbox(value=None, label="Model")
-        self.ui = gr.ChatInterface(
-            self.chat_ui, additional_inputs=[use_history, provider, model]
-        )
+        additional_inputs = [use_history, provider, model]
+        if config.dev_config.enable_system_prompt_override:
+            system_prompt = gr.TextArea(
+                value=prompts.QUERY_SYSTEM_INSTRUCTION, label="System prompt"
+            )
+            additional_inputs.append(system_prompt)
+        self.ui = gr.ChatInterface(self.chat_ui, additional_inputs=additional_inputs)

     def chat_ui(
         self,
@@ -39,6 +46,7 @@ def chat_ui(
         use_history: Optional[bool] = None,
         provider: Optional[str] = None,
         model: Optional[str] = None,
+        system_prompt: Optional[str] = None,
     ) -> str:
         """Handle requests from web-based user interface."""
         # Headers for the HTTP request
@@ -63,6 +71,9 @@
         if model:
             logger.info("Using model: %s", model)
             data["model"] = model
+        if system_prompt:
+            logger.info("Using system prompt: %s", system_prompt)
+            data["system_prompt"] = system_prompt

         # Convert the data dictionary to a JSON string
         json_data = json.dumps(data)
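The chain added to QueryHelper.__init__ leans on `and`/`or` returning operands rather than booleans, so it picks the first truthy candidate in priority order: the request prompt (only while the dev override flag is on), then ols_config.system_prompt, then the packaged default. A standalone sketch of that precedence, with stub values in place of the real config objects:

    def resolve_system_prompt(override_enabled, request_prompt, config_prompt, default_prompt):
        # `x and y` yields y when x is truthy; `x or y` yields its first truthy
        # operand, so the first available candidate in priority order wins.
        return override_enabled and request_prompt or config_prompt or default_prompt

    assert resolve_system_prompt(True, "custom", "from-config", "default") == "custom"
    assert resolve_system_prompt(True, None, "from-config", "default") == "from-config"
    assert resolve_system_prompt(False, "custom", "from-config", "default") == "from-config"
    assert resolve_system_prompt(False, "custom", None, "default") == "default"

One caveat of the idiom: an override prompt of "" (empty string) is falsy and silently falls through to the configured prompt, which an explicit if/else would not do.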
diff --git a/tests/integration/test_ols.py b/tests/integration/test_ols.py
index 6a5cc628..21328f44 100644
--- a/tests/integration/test_ols.py
+++ b/tests/integration/test_ols.py
@@ -1,5 +1,6 @@
 """Integration tests for basic OLS REST API endpoints."""

+import logging
 from unittest.mock import patch

 import pytest
@@ -9,12 +10,14 @@

 from ols import config, constants
 from ols.app.models.config import (
+    LoggingConfig,
     ProviderConfig,
     QueryFilter,
 )
 from ols.customize import prompts
 from ols.utils import suid
 from ols.utils.errors_parsing import DEFAULT_ERROR_MESSAGE, DEFAULT_STATUS_CODE
+from ols.utils.logging_configurator import configure_logging
 from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
 from tests.mock_classes.mock_llm_chain import mock_llm_chain
 from tests.mock_classes.mock_llm_loader import mock_llm_loader
@@ -1000,3 +1003,78 @@ def test_post_too_long_query(_setup):
     error_response = response.json()["detail"]
     assert error_response["response"] == "Prompt is too long"
     assert "exceeds" in error_response["cause"]
+
+
+def _post_with_system_prompt_override(_setup, caplog, query, system_prompt):
+    """Invoke the POST /v1/query API with a system prompt override."""
+    logging_config = LoggingConfig(app_log_level="debug")
+
+    configure_logging(logging_config)
+    logger = logging.getLogger("ols")
+    logger.handlers = [caplog.handler]  # add caplog handler to logger
+
+    with patch(
+        "ols.app.endpoints.ols.QuestionValidator.validate_question",
+        side_effect=lambda x, y: constants.SUBJECT_ALLOWED,
+    ):
+        ml = mock_langchain_interface("test response")
+        with (
+            patch(
+                "ols.src.query_helpers.docs_summarizer.LLMChain",
+                new=mock_llm_chain(None),
+            ),
+            patch(
+                "ols.src.query_helpers.query_helper.load_llm",
+                new=mock_llm_loader(ml()),
+            ),
+        ):
+            conversation_id = suid.get_suid()
+            response = client.post(
+                "/v1/query",
+                json={
+                    "conversation_id": conversation_id,
+                    "query": query,
+                    "system_prompt": system_prompt,
+                },
+            )
+            assert response.status_code == requests.codes.ok
+
+
+@patch(
+    "ols.app.endpoints.ols.config.dev_config.enable_system_prompt_override",
+    True,
+)
+@patch(
+    "ols.app.endpoints.ols.config.ols_config.query_validation_method",
+    constants.QueryValidationMethod.LLM,
+)
+def test_post_with_system_prompt_override(_setup, caplog):
+    """Check the POST /v1/query API with a system prompt."""
+    query = "test query"
+    system_prompt = "You are an expert in something marvelous."
+
+    _post_with_system_prompt_override(_setup, caplog, query, system_prompt)
+
+    # Specified system prompt should appear twice in query_helper debug log outputs.
+    # One is from question_validator and another is from docs_summarizer.
+    assert caplog.text.count("System prompt: " + system_prompt) == 2
+
+
+@patch(
+    "ols.app.endpoints.ols.config.dev_config.enable_system_prompt_override",
+    False,
+)
+@patch(
+    "ols.app.endpoints.ols.config.ols_config.query_validation_method",
+    constants.QueryValidationMethod.LLM,
+)
+def test_post_with_system_prompt_override_disabled(_setup, caplog):
+    """Check the POST /v1/query API with a system prompt when overriding is disabled."""
+    query = "test query"
+    system_prompt = "You are an expert in something marvelous."
+
+    _post_with_system_prompt_override(_setup, caplog, query, system_prompt)
+
+    # Specified system prompt should NOT appear in query_helper debug log outputs
+    # as enable_system_prompt_override is set to False.
+    assert caplog.text.count("System prompt: " + system_prompt) == 0
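The expected count of 2 in the enabled test comes from QueryHelper.__init__ running twice per request, once for QuestionValidator and once for DocsSummarizer, and logging the resolved prompt each time. A self-contained sketch of the caplog counting technique used above (the emit helper is hypothetical and merely stands in for QueryHelper.__init__):

    import logging

    def emit(prompt: str) -> None:
        # Stand-in for QueryHelper.__init__ logging the resolved prompt.
        logging.getLogger("ols").debug("System prompt: %s", prompt)

    def test_prompt_logged_twice(caplog):
        caplog.set_level(logging.DEBUG, logger="ols")
        emit("You are an expert in something marvelous.")  # question validation
        emit("You are an expert in something marvelous.")  # docs summarization
        assert caplog.text.count("System prompt: You are an expert") == 2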
diff --git a/tests/unit/app/models/test_config.py b/tests/unit/app/models/test_config.py
index 0a6d65f0..c97c305e 100644
--- a/tests/unit/app/models/test_config.py
+++ b/tests/unit/app/models/test_config.py
@@ -3212,6 +3212,7 @@ def test_dev_config_defaults():
     assert dev_config.disable_tls is False
     assert dev_config.k8s_auth_token is None
     assert dev_config.run_on_localhost is False
+    assert dev_config.enable_system_prompt_override is False


 def test_dev_config_bool_inputs():
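The default asserted above can also be exercised directly against the model. A small sketch, assuming DevConfig's remaining fields all carry defaults as the excerpt suggests:

    from ols.app.models.config import DevConfig

    dev_config = DevConfig()
    assert dev_config.enable_system_prompt_override is False  # off unless opted into

    dev_config = DevConfig(enable_system_prompt_override=True)
    assert dev_config.enable_system_prompt_override is True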