Skip to content

Commit

Permalink
base_prompt: Stop supporting deprecated variables notation
Browse files Browse the repository at this point in the history
  • Loading branch information
irgolic committed Nov 21, 2023
1 parent 58d2316 commit a6fbb35
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions guardrails/prompt/base_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import regex

from guardrails.namespace_template import NamespaceTemplate
from guardrails.utils.constants import constants
from guardrails.utils.parsing_utils import get_template_variables

Expand All @@ -17,8 +16,13 @@ class BasePrompt:
def __init__(self, source: str, output_schema: Optional[str] = None):
self.format_instructions_start = self.get_format_instructions_idx(source)

# Substitute constants in the prompt.
source = self.substitute_constants(source)
# Warn if the prompt uses the old constant schema.
if self.uses_old_constant_schema(source):
warnings.warn(
"You may be using an an unsupported convention for specifying "
"guardrails variables. Follow the new namespaced convention "
"documented here: https://docs.guardrailsai.com/0-2-migration/"
)

# If an output schema is provided, substitute it in the prompt.
if output_schema:
Expand All @@ -44,28 +48,6 @@ def variable_names(self):
def format_instructions(self):
return self.source[self.format_instructions_start :]

def substitute_constants(self, text):
"""Substitute constants in the prompt."""
# Substitute constants by reading the constants file.
# Regex to extract all occurrences of ${gr.<constant_name>}
if self.uses_old_constant_schema(text):
warnings.warn(
"It appears that you are using an old schema for gaurdrails variables, "
"follow the new namespaced convention "
"documented here: https://docs.guardrailsai.com/0-2-migration/"
)

matches = re.findall(r"\${gr\.(\w+)}", text)

# Substitute all occurrences of ${gr.<constant_name>}
# with the value of the constant.
for match in matches:
template = NamespaceTemplate(text)
mapping = {f"gr.{match}": constants[match]}
text = template.safe_substitute(**mapping)

return text

def uses_old_constant_schema(self, text) -> bool:
matches = re.findall(r"@(\w+)", text)
if len(matches) == 0:
Expand Down

0 comments on commit a6fbb35

Please sign in to comment.