Skip to content

Commit

Permalink
Add system message to KnowledgeEngine and LLMClient
Browse files Browse the repository at this point in the history
  • Loading branch information
caufieldjh committed Aug 13, 2024
1 parent b40112b commit ffeed2a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
10 changes: 9 additions & 1 deletion src/ontogpt/clients/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class LLMClient:

temperature: float = 1.0

system_message: str = ""
"""System message to be provided to the LLM."""

def __post_init__(self):
# Get appropriate API key for the model source
# and other details if needed
Expand Down Expand Up @@ -65,14 +68,19 @@ def complete(self, prompt, show_prompt: bool = False, **kwargs) -> str:

response = None

these_messages = [{"content": prompt, "role": "user"}]

if self.system_message:
these_messages.insert(0, {"content": self.system_message, "role": "system"})

try:
# TODO: expose user prompt to CLI
response = completion(
api_key=self.api_key,
api_base=self.api_base,
api_version=self.api_version,
model=self.model,
messages=[{"content": prompt, "role": "user"}],
messages=these_messages,
temperature=self.temperature,
caching=True,
custom_llm_provider=self.custom_llm_provider,
Expand Down
4 changes: 4 additions & 0 deletions src/ontogpt/engines/knowledge_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class KnowledgeEngine(ABC):
temperature: float = 1.0
"""Temperature for LLM completions - this is passed to the LLMClient."""

system_message: str = ""
"""System message to be provided to the LLM."""

def __post_init__(self):
if self.template_details:
(
Expand All @@ -167,6 +170,7 @@ def __post_init__(self):
api_version=self.api_version,
api_base=self.api_base,
custom_llm_provider=self.model_provider,
system_message=self.system_message,
)

# We retrieve encoding
Expand Down
1 change: 0 additions & 1 deletion src/ontogpt/prompts/qa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pathlib import Path

QA_PROMPT_DIR_PATH = Path(__file__).parent
GENERIC_QA_PROMPT = QA_PROMPT_DIR_PATH / "generic.jinja2"

0 comments on commit ffeed2a

Please sign in to comment.