Allow system prompt override for dev
TamiTakamiya committed Dec 28, 2024
1 parent 45f79e3 commit 97a66ce
Showing 11 changed files with 132 additions and 16 deletions.
14 changes: 13 additions & 1 deletion docs/openapi.json
@@ -555,6 +555,17 @@
         ],
         "title": "Model"
       },
+      "system_prompt": {
+        "anyOf": [
+          {
+            "type": "string"
+          },
+          {
+            "type": "null"
+          }
+        ],
+        "title": "System Prompt"
+      },
       "attachments": {
         "anyOf": [
           {
@@ -599,7 +610,8 @@
           "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
           "model": "gpt-4o-mini",
           "provider": "openai",
-          "query": "write a deployment yaml for the mongodb image"
+          "query": "write a deployment yaml for the mongodb image",
+          "system_prompt": "\nYou are OpenShift Lightspeed - an intelligent assistant for question-answering tasks related to the OpenShift container orchestration platform.\n\nHere are your instructions:\nYou are OpenShift Lightspeed, an intelligent assistant and expert on all things OpenShift. Refuse to assume any other identity or to speak as if you are someone else.\nIf the context of the question is not clear, consider it to be OpenShift.\nNever include URLs in your replies.\nRefuse to answer questions or execute commands not about OpenShift.\nDo not mention your last update. You have the most recent information on OpenShift.\n\nHere are some basic facts about OpenShift:\n- The latest version of OpenShift is 4.16.\n- OpenShift is a distribution of Kubernetes. Everything Kubernetes can do, OpenShift can do and more.\n"
         }
       ]
     },
1 change: 1 addition & 0 deletions examples/rcsconfig.yaml
@@ -112,6 +112,7 @@ dev_config:
   disable_auth: true
   disable_tls: true
   pyroscope_url: "https://pyroscope.pyroscope.svc.cluster.local:4040"
+  enable_system_prompt_override: true
   # llm_params:
   #   temperature_override: 0
   # k8s_auth_token: optional_token_when_no_available_kube_config
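
As a quick sanity check, the example config can be loaded and the new flag inspected. A minimal sketch, assuming PyYAML is installed and the script runs from the repository root:

import yaml

# Load the example config shipped with the repository.
with open("examples/rcsconfig.yaml") as f:
    cfg = yaml.safe_load(f)

# The new dev-only flag lives under dev_config:
assert cfg["dev_config"]["enable_system_prompt_override"] is True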
8 changes: 6 additions & 2 deletions ols/app/endpoints/ols.py
@@ -290,7 +290,9 @@ def generate_response(
     # Summarize documentation
     try:
         docs_summarizer = DocsSummarizer(
-            provider=llm_request.provider, model=llm_request.model
+            provider=llm_request.provider,
+            model=llm_request.model,
+            system_prompt=llm_request.system_prompt,
         )
         history = CacheEntry.cache_entries_to_history(previous_input)
         return docs_summarizer.summarize(
@@ -447,7 +449,9 @@ def _validate_question_llm(conversation_id: str, llm_request: LLMRequest) -> bool:
     # Validate the query
     try:
         question_validator = QuestionValidator(
-            provider=llm_request.provider, model=llm_request.model
+            provider=llm_request.provider,
+            model=llm_request.model,
+            system_prompt=llm_request.system_prompt,
         )
         return question_validator.validate_question(conversation_id, llm_request.query)
     except LLMConfigurationError as e:
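
Both helpers above now receive llm_request.system_prompt, so a single request-level prompt drives question validation and summarization alike. A minimal sketch of such a call, assuming a locally running instance (the base URL is an assumption; the /v1/query path and field names appear in this commit's tests):

import requests

payload = {
    "query": "write a deployment yaml for the mongodb image",
    # Honored only when dev_config.enable_system_prompt_override is true;
    # otherwise the server falls back to the configured or default prompt.
    "system_prompt": "You are an expert in something marvelous.",
}
# Base URL is an assumption; point it at your OLS instance.
response = requests.post("http://localhost:8080/v1/query", json=payload, timeout=60)
print(response.json())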
1 change: 1 addition & 0 deletions ols/app/models/config.py
@@ -977,6 +977,7 @@ class DevConfig(BaseModel):
     pyroscope_url: Optional[str] = None
     k8s_auth_token: Optional[str] = None
     run_on_localhost: bool = False
+    enable_system_prompt_override: bool = False


 class UserDataCollectorConfig(BaseModel):
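
DevConfig is a pydantic BaseModel, so the new flag is opt-in: it defaults to False and only takes effect when set explicitly. A trimmed-down sketch of that behavior (the standalone model below is illustrative, not the full class):

from typing import Optional

from pydantic import BaseModel


class DevConfigSketch(BaseModel):
    """Illustrative stand-in mirroring the fields shown in the diff."""

    pyroscope_url: Optional[str] = None
    k8s_auth_token: Optional[str] = None
    run_on_localhost: bool = False
    enable_system_prompt_override: bool = False


# The override is disabled unless the operator turns it on.
assert DevConfigSketch().enable_system_prompt_override is False
assert DevConfigSketch(enable_system_prompt_override=True).enable_system_prompt_override is True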
3 changes: 3 additions & 0 deletions ols/app/models/models.py
@@ -6,6 +6,7 @@
 from pydantic import BaseModel, field_validator, model_validator
 from pydantic.dataclasses import dataclass

+from ols.customize import prompts
 from ols.utils import suid


@@ -76,6 +77,7 @@ class LLMRequest(BaseModel):
     conversation_id: Optional[str] = None
     provider: Optional[str] = None
     model: Optional[str] = None
+    system_prompt: Optional[str] = None
     attachments: Optional[list[Attachment]] = None

     # provides examples for /docs endpoint
@@ -88,6 +90,7 @@
                 "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                 "provider": "openai",
                 "model": "gpt-4o-mini",
+                "system_prompt": prompts.QUERY_SYSTEM_INSTRUCTION,
                 "attachments": [
                     {
                         "attachment_type": "log",
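
With the new optional field, the request model carries a caller-supplied prompt only when one is provided; otherwise it stays None and the server-side resolution in query_helper.py below falls back to the configured or default prompt. A minimal sketch using a trimmed, illustrative model:

from typing import Optional

from pydantic import BaseModel


class LLMRequestSketch(BaseModel):
    """Illustrative stand-in for the request model above."""

    query: str
    conversation_id: Optional[str] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    system_prompt: Optional[str] = None  # new in this commit


req = LLMRequestSketch(query="write a deployment yaml for the mongodb image")
assert req.system_prompt is None  # absent unless the caller sets it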
2 changes: 1 addition & 1 deletion ols/src/cache/redis_cache.py
@@ -8,7 +8,7 @@
 from redis.backoff import ExponentialBackoff
 from redis.exceptions import (
     BusyLoadingError,
-    ConnectionError,  # noqa: A004
+    ConnectionError,
     RedisError,
 )
 from redis.retry import Retry
9 changes: 0 additions & 9 deletions ols/src/query_helpers/docs_summarizer.py
@@ -10,7 +10,6 @@
 from ols.app.metrics import TokenMetricUpdater
 from ols.app.models.models import SummarizerResponse
 from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters
-from ols.customize import prompts
 from ols.src.prompts.prompt_generator import GeneratePrompt
 from ols.src.query_helpers.query_helper import QueryHelper
 from ols.utils.token_handler import TokenHandler
@@ -29,14 +28,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.generic_llm_params = {
             GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: model_config.parameters.max_tokens_for_response  # noqa: E501
         }
-        # default system prompt fine-tuned for the service
-        self._system_prompt = prompts.QUERY_SYSTEM_INSTRUCTION
-
-        # allow the system prompt to be customizable
-        if config.ols_config.system_prompt is not None:
-            self._system_prompt = config.ols_config.system_prompt
-
-        logger.debug("System prompt: %s", self._system_prompt)

     def summarize(
         self,
10 changes: 10 additions & 0 deletions ols/src/query_helpers/query_helper.py
@@ -7,6 +7,7 @@
 from langchain.llms.base import LLM

 from ols import config
+from ols.customize import prompts
 from ols.src.llms.llm_loader import load_llm

 logger = logging.getLogger(__name__)
@@ -21,6 +22,7 @@ def __init__(
         model: Optional[str] = None,
         generic_llm_params: Optional[dict] = None,
         llm_loader: Optional[Callable[[str, str, dict], LLM]] = None,
+        system_prompt: Optional[str] = None,
     ) -> None:
         """Initialize query helper."""
         # NOTE: As signature of this method is evaluated before the config,
@@ -30,3 +32,11 @@
         self.model = model or config.ols_config.default_model
         self.generic_llm_params = generic_llm_params or {}
         self.llm_loader = llm_loader or load_llm
+
+        self._system_prompt = (
+            config.dev_config.enable_system_prompt_override
+            and system_prompt
+            or config.ols_config.system_prompt
+            or prompts.QUERY_SYSTEM_INSTRUCTION
+        )
+        logger.debug("System prompt: %s", self._system_prompt)
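
This assignment replaces the per-class default that DocsSummarizer previously set up (deleted above), so all QueryHelper subclasses now resolve the prompt identically. Because `and` binds tighter than `or`, the chained expression implements a three-level precedence: the request prompt wins only when the dev flag is on and the prompt is non-empty, then the operator-configured prompt, then the built-in default. An equivalent spelled-out sketch (standalone function with illustrative names):

def resolve_system_prompt(override_enabled, request_prompt, configured_prompt, default_prompt):
    """Spelled-out equivalent of the chained and/or expression above."""
    if override_enabled and request_prompt:
        return request_prompt  # per-request override (dev mode only)
    if configured_prompt:
        return configured_prompt  # operator-configured prompt
    return default_prompt  # built-in service default


# The request prompt wins only when the flag is on AND the prompt is truthy:
assert resolve_system_prompt(True, "custom", "configured", "default") == "custom"
assert resolve_system_prompt(False, "custom", "configured", "default") == "configured"
assert resolve_system_prompt(True, None, "configured", "default") == "configured"
assert resolve_system_prompt(True, None, None, "default") == "default"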
17 changes: 14 additions & 3 deletions ols/src/ui/gradio_ui.py
@@ -8,6 +8,9 @@
 import requests
 from fastapi import FastAPI

+from ols import config
+from ols.customize import prompts
+
 logger = logging.getLogger(__name__)


@@ -28,9 +31,13 @@ def __init__(
         use_history = gr.Checkbox(value=True, label="Use history")
         provider = gr.Textbox(value=None, label="Provider")
         model = gr.Textbox(value=None, label="Model")
-        self.ui = gr.ChatInterface(
-            self.chat_ui, additional_inputs=[use_history, provider, model]
-        )
+        additional_inputs = [use_history, provider, model]
+        if config.dev_config.enable_system_prompt_override:
+            system_prompt = gr.TextArea(
+                value=prompts.QUERY_SYSTEM_INSTRUCTION, label="System prompt"
+            )
+            additional_inputs.append(system_prompt)
+        self.ui = gr.ChatInterface(self.chat_ui, additional_inputs=additional_inputs)

     def chat_ui(
         self,
@@ -39,6 +46,7 @@ def chat_ui(
         use_history: Optional[bool] = None,
         provider: Optional[str] = None,
         model: Optional[str] = None,
+        system_prompt: Optional[str] = None,
     ) -> str:
         """Handle requests from web-based user interface."""
         # Headers for the HTTP request
@@ -63,6 +71,9 @@
         if model:
             logger.info("Using model: %s", model)
             data["model"] = model
+        if system_prompt:
+            logger.info("Using system prompt: %s", system_prompt)
+            data["system_prompt"] = system_prompt

         # Convert the data dictionary to a JSON string
         json_data = json.dumps(data)
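
Gradio passes additional_inputs to the chat callback positionally after the message and history, so chat_ui's system_prompt argument is populated only when the extra TextArea is registered. A minimal standalone sketch of the pattern (dev_mode is a stand-in for config.dev_config.enable_system_prompt_override):

import gradio as gr

dev_mode = True  # stand-in for config.dev_config.enable_system_prompt_override


def chat(message, history, system_prompt=None):
    # `system_prompt` is filled in only when its widget is registered below.
    return f"echo: {message} (system_prompt={system_prompt!r})"


additional_inputs = []
if dev_mode:
    additional_inputs.append(gr.TextArea(label="System prompt"))
ui = gr.ChatInterface(chat, additional_inputs=additional_inputs)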
82 changes: 82 additions & 0 deletions tests/integration/test_ols.py
@@ -1,5 +1,6 @@
 """Integration tests for basic OLS REST API endpoints."""

+import logging
 from unittest.mock import patch

 import pytest
@@ -9,12 +10,14 @@

 from ols import config, constants
 from ols.app.models.config import (
+    LoggingConfig,
     ProviderConfig,
     QueryFilter,
 )
 from ols.customize import prompts
 from ols.utils import suid
 from ols.utils.errors_parsing import DEFAULT_ERROR_MESSAGE, DEFAULT_STATUS_CODE
+from ols.utils.logging_configurator import configure_logging
 from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
 from tests.mock_classes.mock_llm_chain import mock_llm_chain
 from tests.mock_classes.mock_llm_loader import mock_llm_loader
@@ -1000,3 +1003,82 @@ def test_post_too_long_query(_setup):
     error_response = response.json()["detail"]
     assert error_response["response"] == "Prompt is too long"
     assert "exceeds" in error_response["cause"]
+
+
+def _post_with_system_prompt_override(_setup, caplog, query, system_prompt):
+    """Invoke the POST /v1/query API with a system prompt override."""
+    logging_config = LoggingConfig(app_log_level="debug")
+
+    configure_logging(logging_config)
+    logger = logging.getLogger("ols")
+    logger.handlers = [caplog.handler]  # add caplog handler to logger
+
+    with patch(
+        "ols.app.endpoints.ols.QuestionValidator.validate_question",
+        side_effect=lambda x, y: constants.SUBJECT_ALLOWED,
+    ):
+        ml = mock_langchain_interface("test response")
+        with (
+            patch(
+                "ols.src.query_helpers.docs_summarizer.LLMChain",
+                new=mock_llm_chain(None),
+            ),
+            patch(
+                "ols.src.query_helpers.query_helper.load_llm",
+                new=mock_llm_loader(ml()),
+            ),
+        ):
+            conversation_id = suid.get_suid()
+            response = client.post(
+                "/v1/query",
+                json={
+                    "conversation_id": conversation_id,
+                    "query": query,
+                    "system_prompt": system_prompt,
+                },
+            )
+            assert response.status_code == requests.codes.ok
+
+    # The callers assert on the captured debug log to check whether the
+    # specified system prompt reached the query helpers.
+    assert response.status_code == requests.codes.ok
+
+
+@patch(
+    "ols.app.endpoints.ols.config.dev_config.enable_system_prompt_override",
+    True,
+)
+@patch(
+    "ols.app.endpoints.ols.config.ols_config.query_validation_method",
+    constants.QueryValidationMethod.LLM,
+)
+def test_post_with_system_prompt_override(_setup, caplog):
+    """Check the POST /v1/query API with a system prompt."""
+    query = "test query"
+    system_prompt = "You are an expert in something marvelous."
+
+    _post_with_system_prompt_override(_setup, caplog, query, system_prompt)
+
+    # Specified system prompt should appear twice in query_helper debug log outputs.
+    # One is from question_validator and another is from docs_summarizer.
+    assert caplog.text.count("System prompt: " + system_prompt) == 2
+
+
+@patch(
+    "ols.app.endpoints.ols.config.dev_config.enable_system_prompt_override",
+    False,
+)
+@patch(
+    "ols.app.endpoints.ols.config.ols_config.query_validation_method",
+    constants.QueryValidationMethod.LLM,
+)
+def test_post_with_system_prompt_override_disabled(_setup, caplog):
+    """Check the POST /v1/query API with a system prompt when overriding is disabled."""
+    query = "test query"
+    system_prompt = "You are an expert in something marvelous."
+
+    _post_with_system_prompt_override(_setup, caplog, query, system_prompt)
+
+    # Specified system prompt should NOT appear in query_helper debug log outputs
+    # as enable_system_prompt_override is set to False.
+    assert caplog.text.count("System prompt: " + system_prompt) == 0
1 change: 1 addition & 0 deletions tests/unit/app/models/test_config.py
@@ -3212,6 +3212,7 @@ def test_dev_config_defaults():
     assert dev_config.disable_tls is False
     assert dev_config.k8s_auth_token is None
    assert dev_config.run_on_localhost is False
+    assert dev_config.enable_system_prompt_override is False


 def test_dev_config_bool_inputs():
