Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/augmentations #1091

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 26 additions & 60 deletions langtest/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
import yaml

from langtest.augmentation.utils import AzureOpenAIConfig, OpenAIConfig
from langtest.datahandler.datasource import DataFactory
from langtest.transform import TestFactory
from langtest.transform.utils import create_terminology
Expand Down Expand Up @@ -324,6 +325,7 @@ def __init__(
generate_templates=False,
show_templates=False,
num_extra_templates=10,
model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
) -> None:
"""This constructor for the TemplaticAugment class.

Expand All @@ -341,21 +343,23 @@ def __init__(
given_template = self.__templates[:]
for template in given_template:
generated_templates: List[str] = self.__generate_templates(
template, num_extra_templates
template, num_extra_templates, model_config
)

while len(generated_templates) < num_extra_templates:
temp_templates = self.__generate_templates(
template, num_extra_templates
template,
num_extra_templates,
model_config,
)
generated_templates.extend(temp_templates)

if generated_templates:
# Extend the existing templates list

self.__templates.extend(generated_templates[:num_extra_templates])
except Exception as e:
raise Errors.E095(msg=e)
except Exception as e_msg:
raise Errors.E095(e=e_msg)

if show_templates:
[print(template) for template in self.__templates]
Expand Down Expand Up @@ -596,63 +600,25 @@ def add_spaces_around_punctuation(text: str):

return text

def __generate_templates(self, template, num_extra_templates) -> List[str]:
def __generate_templates(
self,
template: str,
num_extra_templates: int,
model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
) -> List[str]:
if try_import_lib("openai"):
import openai
from pydantic import BaseModel, validator

client = openai.OpenAI()

class Templates(BaseModel):
templates: List[str]

def __post_init__(self):
self.templates = [i.strip('"') for i in self.templates]

@validator("templates", each_item=True, allow_reuse=True)
def check_templates(cls, v: str):
if not v:
raise ValueError("No templates generated.")
return v.strip('"')

def remove_invalid_templates(self, original_template):
# extract variable names using regex
regexs = r"{([^{}]*)}"
original_vars = re.findall(regexs, original_template)
original_vars = set([var.strip() for var in original_vars])

# remove invalid templates
valid_templates = []
for template in self.templates:
template_vars: List[str] = re.findall(regexs, template)
template_vars = set([var.strip() for var in template_vars])
if template_vars == original_vars:
valid_templates.append(template)
self.templates = valid_templates

prompt = (
f"Based on the provided template, create {num_extra_templates} new and unique templates that are "
"variations on this theme. Present these as a list, with each template as a quoted string. The list should "
"contain only the templates, without any additional text or explanation. Ensure that the structure of "
"these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
"Template:\n"
f"{template}\n"
)
response = client.beta.chat.completions.parse(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": f"Action: Generate up to {num_extra_templates} templates and ensure that the structure of the variables within the templates remains unchanged and don't add any extra variables.",
},
{"role": "user", "content": prompt},
],
max_tokens=500,
temperature=0,
response_format=Templates,
from langtest.augmentation.utils import (
generate_templates_azoi, # azoi means Azure OpenAI
generate_templates_openai,
)

generated_response = response.choices[0].message.parsed
generated_response.remove_invalid_templates(template)
params = model_config.copy() if model_config else {}

if model_config and model_config.get("provider") == "openai":
return generate_templates_openai(template, num_extra_templates, params)

return generated_response.templates[:num_extra_templates]
elif model_config and model_config.get("provider") == "azure":
return generate_templates_azoi(template, num_extra_templates, params)

else:
return generate_templates_openai(template, num_extra_templates)
174 changes: 174 additions & 0 deletions langtest/augmentation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import re
from typing import List, TypedDict, Union
import os

from pydantic import BaseModel, validator
from langtest.logger import logger


class OpenAIConfig(TypedDict):
"""OpenAI Configuration for API Key and Provider."""

api_key: str = os.environ.get("OPENAI_API_KEY")
base_url: Union[str, None] = None
organization: Union[str, None] = (None,)
project: Union[str, None] = (None,)
provider: str = "openai"


class AzureOpenAIConfig(TypedDict):
"""Azure OpenAI Configuration for API Key and Provider."""

from openai.lib.azure import AzureADTokenProvider

azure_endpoint: str
api_version: str
api_key: str
provider: str
azure_deployment: Union[str, None] = None
azure_ad_token: Union[str, None] = (None,)
azure_ad_token_provider: Union[AzureADTokenProvider, None] = (None,)
organization: Union[str, None] = (None,)


class Templates(BaseModel):
"""Model to validate generated templates."""

templates: List[str]

def __post_init__(self):
"""Post init method to remove quotes from templates."""
self.templates = [i.strip('"') for i in self.templates]
logger.info(f"Generated templates: {self.templates}")

@validator("templates", each_item=True, allow_reuse=True)
def check_templates(cls, v: str):
"""Validator to check if templates are generated."""
if not v:
raise ValueError("No templates generated.")
return v.strip('"')

def remove_invalid_templates(self, original_template):
"""Remove invalid templates based on the original template."""
# extract variable names using regex
regexs = r"{([^{}]*)}"
original_vars = re.findall(regexs, original_template)
original_vars = set([var.strip() for var in original_vars])

# remove invalid templates
valid_templates = []
for template in self.templates:
template_vars: List[str] = re.findall(regexs, template)
template_vars = set([var.strip() for var in template_vars])
if template_vars == original_vars:
valid_templates.append(template)
logger.info(f"Valid template: {template}")
else:
logger.warning(
f"Invalid Variables in template: {template} - {template_vars}"
)

self.templates = valid_templates
logger.info(f"Valid templates: {self.templates}")


def generate_templates_azoi(
template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
):
"""Generate new templates based on the provided template using Azure OpenAI API."""
import openai

if "provider" in model_config:
del model_config["provider"]

client = openai.AzureOpenAI(**model_config)

prompt = (
"Based on the provided template, create {num_extra_templates} new and unique templates that are "
"variations on this theme. Present these as a list, with each template as a quoted string. The list should "
"contain only the templates, without any additional text or explanation. Ensure that the structure of "
"these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
"Template:\n"
"{template}\n"
)

response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": f"Generate up to {num_extra_templates} templates based on the provided template.\n\n JSON Output Schema: {Templates.schema()}\n",
},
{
"role": "user",
"content": prompt.format(
template="The {ORG} company is located in {LOC}",
num_extra_templates=2,
),
},
{
"role": "assistant",
"content": '["The {ORG} corporation is based out of {LOC}",\n "The {ORG} organization operates in {LOC}"]',
},
{
"role": "user",
"content": prompt.format(
template=template, num_extra_templates=num_extra_templates
),
},
],
temperature=0.1,
max_tokens=1000,
)

import json

try:
clean_response = response.choices[0].message.content.replace("'", '"')
gen_templates = Templates(templates=json.loads(clean_response))
gen_templates.remove_invalid_templates(template)

return gen_templates.templates[:num_extra_templates]

except json.JSONDecodeError as e:
logger.error(f"Error decoding response: {e}")
raise ValueError(f"Error decoding response: {e}")


def generate_templates_openai(
template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
):
"""Generate new templates based on the provided template using OpenAI API."""
import openai

if "provider" in model_config:
del model_config["provider"]

client = openai.OpenAI(**model_config)

prompt = (
f"Based on the provided template, create {num_extra_templates} new and unique templates that are "
"variations on this theme. Present these as a list, with each template as a quoted string. The list should "
"contain only the templates, without any additional text or explanation. Ensure that the structure of "
"these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
"Template:\n"
f"{template}\n"
)
response = client.beta.chat.completions.parse(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": f"Action: Generate up to {num_extra_templates} templates and ensure that the structure of the variables within the templates remains unchanged and don't add any extra variables.",
},
{"role": "user", "content": prompt},
],
max_tokens=100,
temperature=0.1,
response_format=Templates,
)

generated_response = response.choices[0].message.parsed
generated_response.remove_invalid_templates(template)

return generated_response.templates[:num_extra_templates]
2 changes: 1 addition & 1 deletion langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import csv
import importlib
import logging
import os
import random
import re
Expand All @@ -11,6 +10,7 @@
import jsonlines
import pandas as pd
from langtest.tasks.task import TaskManager
from langtest.logger import logger as logging

from .format import Formatter
from langtest.utils.custom_types import (
Expand Down
Loading