Patch/2.3.1 #1103

Merged · 14 commits · Sep 11, 2024
1 change: 1 addition & 0 deletions .github/workflows/build_and_test.yml
@@ -8,6 +8,7 @@ on:
   pull_request:
     branches:
       - "release/*"
+      - "patch/*"
       - "main"
 
 jobs:
28 changes: 15 additions & 13 deletions langtest/augmentation/base.py
@@ -19,7 +19,6 @@
 from langtest.utils.custom_types.predictions import NERPrediction, SequenceLabel
 from langtest.utils.custom_types.sample import NERSample
 from langtest.tasks import TaskManager
-from ..utils.lib_manager import try_import_lib
 from ..errors import Errors
 
 
@@ -358,6 +357,9 @@ def __init__(
             # Extend the existing templates list
 
             self.__templates.extend(generated_templates[:num_extra_templates])
+        except ModuleNotFoundError:
+            raise ImportError(Errors.E097())
+
         except Exception as e_msg:
             raise Errors.E095(e=e_msg)
@@ -606,19 +608,19 @@ def __generate_templates(
         num_extra_templates: int,
         model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
     ) -> List[str]:
-        if try_import_lib("openai"):
-            from langtest.augmentation.utils import (
-                generate_templates_azoi,  # azoi means Azure OpenAI
-                generate_templates_openai,
-            )
+        """This method is used to generate extra templates from a given template."""
+        from langtest.augmentation.utils import (
+            generate_templates_azoi,  # azoi means Azure OpenAI
+            generate_templates_openai,
+        )
 
-            params = model_config.copy() if model_config else {}
+        params = model_config.copy() if model_config else {}
 
-            if model_config and model_config.get("provider") == "openai":
-                return generate_templates_openai(template, num_extra_templates, params)
+        if model_config and model_config.get("provider") == "openai":
+            return generate_templates_openai(template, num_extra_templates, params)
 
-            elif model_config and model_config.get("provider") == "azure":
-                return generate_templates_azoi(template, num_extra_templates, params)
+        elif model_config and model_config.get("provider") == "azure":
+            return generate_templates_azoi(template, num_extra_templates, params)
 
-            else:
-                return generate_templates_openai(template, num_extra_templates)
+        else:
+            return generate_templates_openai(template, num_extra_templates)
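Note on the change above: the try_import_lib("openai") guard is gone, and openai is now imported lazily inside the helpers, with a missing install surfaced through the new Errors.E097. A minimal standalone sketch of that lazy-import pattern (illustrative only, not langtest's actual helper):

```python
def load_openai():
    """Return the openai module, raising an actionable error when missing."""
    try:
        import openai  # optional dependency, imported only when needed

        return openai
    except ModuleNotFoundError:
        # Mirrors the Errors.E097 message introduced by this patch.
        raise ImportError(
            "Failed to load openai. Please install it using `pip install openai`"
        )
```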
6 changes: 3 additions & 3 deletions langtest/augmentation/utils.py
@@ -19,15 +19,13 @@ class OpenAIConfig(TypedDict):
 class AzureOpenAIConfig(TypedDict):
     """Azure OpenAI Configuration for API Key and Provider."""
 
-    from openai.lib.azure import AzureADTokenProvider
-
     azure_endpoint: str
     api_version: str
     api_key: str
     provider: str
     azure_deployment: Union[str, None] = None
     azure_ad_token: Union[str, None] = (None,)
-    azure_ad_token_provider: Union[AzureADTokenProvider, None] = (None,)
+    azure_ad_token_provider = (None,)
     organization: Union[str, None] = (None,)
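Since AzureOpenAIConfig is a TypedDict, callers supply a plain dict. A hedged usage sketch, with keys taken from the definition above and placeholder values (the optional fields can simply be omitted):

```python
azure_config = {
    "azure_endpoint": "https://my-resource.openai.azure.com",  # placeholder endpoint
    "api_version": "2024-02-01",  # placeholder API version
    "api_key": "<AZURE_OPENAI_API_KEY>",
    "provider": "azure",  # routes generation to generate_templates_azoi below
}
```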


@@ -76,6 +74,7 @@ def generate_templates_azoi(
     template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
 ):
     """Generate new templates based on the provided template using Azure OpenAI API."""
+
     import openai
 
     if "provider" in model_config:
@@ -139,6 +138,7 @@ def generate_templates_openai(
     template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
 ):
     """Generate new templates based on the provided template using OpenAI API."""
+
     import openai
 
     if "provider" in model_config:
128 changes: 92 additions & 36 deletions langtest/datahandler/datasource.py
@@ -172,7 +172,7 @@ class DataFactory:
     data_sources: Dict[str, BaseDataset] = BaseDataset.data_sources
     CURATED_BIAS_DATASETS = ["BoolQ", "XSum"]
 
-    def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
+    def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
         """Initializes DataFactory object.
 
         Args:
@@ -232,6 +232,9 @@ def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
         self.init_cls: BaseDataset = None
         self.kwargs = kwargs
 
+        if self.task == "ner" and "doc_wise" in self._custom_label:
+            self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
+
     def load_raw(self):
         """Loads the data into a raw format"""
         self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
@@ -257,7 +260,9 @@ def load(self) -> List[Sample]:
             return DataFactory.load_curated_bias(self._file_path)
         else:
             self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
-                self._file_path, task=self.task, **self.kwargs
+                self._file_path,
+                task=self.task,
+                **self.kwargs,
             )
 
         loaded_data = self.init_cls.load_data()
@@ -425,7 +430,9 @@ class ConllDataset(BaseDataset):
 
     COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}
 
-    def __init__(self, file_path: str, task: TaskManager) -> None:
+    def __init__(
+        self, file_path: Union[str, Dict[str, str]], task: TaskManager, **kwargs
+    ) -> None:
         """Initializes ConllDataset object.
 
         Args:
@@ -434,7 +441,7 @@ def __init__(self, file_path: str, task: TaskManager) -> None:
         """
         super().__init__()
         self._file_path = file_path
-
+        self.doc_wise = kwargs.get("doc_wise") if "doc_wise" in kwargs else False
         self.task = task
 
     def load_raw_data(self) -> List[Dict]:
@@ -495,42 +502,42 @@ def load_data(self) -> List[NERSample]:
             ]
             for d_id, doc in enumerate(docs):
                 # file content to sentence split
-                sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
-
-                if sentences == [""]:
-                    continue
-
-                for sent in sentences:
-                    # sentence string to token level split
-                    tokens = sent.strip().split("\n")
-
-                    # get annotations from token level split
-                    valid_tokens, token_list = self.__token_validation(tokens)
-
-                    if not valid_tokens:
-                        logging.warning(Warnings.W004(sent=sent))
-                        continue
-
-                    # get token and labels from the split
+                if self.doc_wise:
+                    tokens = doc.strip().split("\n")
                     ner_labels = []
                     cursor = 0
-                    for split in token_list:
-                        ner_labels.append(
-                            NERPrediction.from_span(
-                                entity=split[-1],
-                                word=split[0],
+
+                    for token in tokens:
+                        token_list = token.split()
+
+                        if len(token_list) == 0:
+                            pred = NERPrediction.from_span(
+                                entity="",
+                                word="\n",
                                 start=cursor,
-                                end=cursor + len(split[0]),
-                                doc_id=d_id,
-                                doc_name=(
-                                    docs_strings[d_id] if len(docs_strings) > 0 else ""
-                                ),
-                                pos_tag=split[1],
-                                chunk_tag=split[2],
+                                end=cursor,
+                                pos_tag="",
+                                chunk_tag="",
                             )
-                        )
-                        # +1 to account for the white space
-                        cursor += len(split[0]) + 1
+                            ner_labels.append(pred)
+                        else:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=token_list[-1],
+                                    word=token_list[0],
+                                    start=cursor,
+                                    end=cursor + len(token_list[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=token_list[1],
+                                    chunk_tag=token_list[2],
+                                )
+                            )
+                            cursor += len(token_list[0]) + 1
 
                     original = " ".join([label.span.word for label in ner_labels])
 
                     data.append(
                         self.task.get_sample_class(
                             original=original,
                             expected_results=NEROutput(predictions=ner_labels),
                         )
                     )
+
+                else:
+                    sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
+
+                    if sentences == [""]:
+                        continue
+
+                    for sent in sentences:
+                        # sentence string to token level split
+                        tokens = sent.strip().split("\n")
+
+                        # get annotations from token level split
+                        valid_tokens, token_list = self.__token_validation(tokens)
+
+                        if not valid_tokens:
+                            logging.warning(Warnings.W004(sent=sent))
+                            continue
+
+                        # get token and labels from the split
+                        ner_labels = []
+                        cursor = 0
+                        for split in token_list:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=split[-1],
+                                    word=split[0],
+                                    start=cursor,
+                                    end=cursor + len(split[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=split[1],
+                                    chunk_tag=split[2],
+                                )
+                            )
+                            # +1 to account for the white space
+                            cursor += len(split[0]) + 1
+
+                        original = " ".join([label.span.word for label in ner_labels])
+
+                        data.append(
+                            self.task.get_sample_class(
+                                original=original,
+                                expected_results=NEROutput(predictions=ner_labels),
+                            )
+                        )
         self.dataset_size = len(data)
         return data
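Taken together, the datasource changes plumb a new doc_wise option from the input config down to the CoNLL loader: DataFactory.__init__ copies it from a dict-style data source into self.kwargs, ConllDataset stores it, and load_data then keeps each document as a single sample (tokenizing line by line, with blank lines carried as "\n" predictions) instead of splitting into sentences. A hypothetical config sketch; apart from doc_wise, the key names are assumptions for illustration:

```python
# Hypothetical dict-style data source for the "ner" task.
input_data = {
    "data_source": "path/to/dataset.conll",  # placeholder path
    "doc_wise": True,  # new flag: one sample per document, not per sentence
}
```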
68 changes: 39 additions & 29 deletions langtest/datahandler/format.py
@@ -195,43 +195,53 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, str]]:
         test_case = sample.test_case
         original = sample.original
         if test_case:
-            test_case_items = test_case.split()
-            norm_test_case_items = test_case.lower().split()
-            norm_original_items = original.lower().split()
+            test_case_items = test_case.split(" ")
+            norm_test_case_items = test_case.lower().split(" ")
+            norm_original_items = original.lower().split(" ")
             temp_len = 0
             for jdx, item in enumerate(norm_test_case_items):
-                try:
-                    if item in norm_original_items and jdx >= norm_original_items.index(
-                        item
-                    ):
-                        oitem_index = norm_original_items.index(item)
-                        j = sample.expected_results.predictions[oitem_index + temp_len]
-                        if temp_id != j.doc_id and jdx == 0:
-                            text += f"{j.doc_name}\n\n"
-                            temp_id = j.doc_id
-                        text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
-                        norm_original_items.pop(oitem_index)
-                        temp_len += 1
-                    else:
-                        o_item = sample.expected_results.predictions[jdx].span.word
-                        letters_count = len(set(item) - set(o_item))
-                        if (
-                            len(norm_test_case_items) == len(original.lower().split())
-                            or letters_count < 2
-                        ):
-                            tl = sample.expected_results.predictions[jdx]
-                            text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
-                        else:
-                            text += f"{test_case_items[jdx]} -X- -X- O\n"
-                except IndexError:
-                    text += f"{test_case_items[jdx]} -X- -X- O\n"
+                if test_case_items[jdx] == "\n":
+                    text += "\n"  # add a newline character after each sentence
+                else:
+                    try:
+                        if (
+                            item in norm_original_items
+                            and jdx >= norm_original_items.index(item)
+                        ):
+                            oitem_index = norm_original_items.index(item)
+                            j = sample.expected_results.predictions[
+                                oitem_index + temp_len
+                            ]
+                            if temp_id != j.doc_id and jdx == 0:
+                                text += f"{j.doc_name}\n\n"
+                                temp_id = j.doc_id
+                            text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                            norm_original_items.pop(oitem_index)
+                            temp_len += 1
+                        else:
+                            o_item = sample.expected_results.predictions[jdx].span.word
+                            letters_count = len(set(item) - set(o_item))
+                            if (
+                                len(norm_test_case_items)
+                                == len(original.lower().split(" "))
+                                or letters_count < 2
+                            ):
+                                tl = sample.expected_results.predictions[jdx]
+                                text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
+                            else:
+                                text += f"{test_case_items[jdx]} -X- -X- O\n"
+                    except IndexError:
+                        text += f"{test_case_items[jdx]} -X- -X- O\n"
 
         else:
             for j in sample.expected_results.predictions:
-                if temp_id != j.doc_id:
-                    text += f"{j.doc_name}\n\n"
-                    temp_id = j.doc_id
-                text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                if j.span.word == "\n":
+                    text += "\n"
+                else:
+                    if temp_id != j.doc_id:
+                        text += f"{j.doc_name}\n\n"
+                        temp_id = j.doc_id
+                    text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
 
         return text, temp_id
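The switch from str.split() to str.split(" ") above is load-bearing: document-wise samples carry sentence boundaries as literal "\n" tokens, and only an explicit separator keeps them around for the new == "\n" checks. A quick standalone demonstration with a toy string:

```python
s = "Smith works \n in Paris"
print(s.split())     # ['Smith', 'works', 'in', 'Paris']        -> "\n" swallowed
print(s.split(" "))  # ['Smith', 'works', '\n', 'in', 'Paris']  -> "\n" preserved
```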
1 change: 1 addition & 0 deletions langtest/errors.py
@@ -275,6 +275,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
     E095 = ("Failed to make API request: {e}")
     E096 = ("Failed to generate the templates in Augmentation: {msg}")
+    E097 = ("Failed to load openai. Please install it using `pip install openai`")
 
 
 class ColumnNameError(Exception):
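For context, Errors.E097() is callable because of the ErrorsWithCodes metaclass. A minimal sketch of that pattern, assumed from the call sites in this diff rather than taken from langtest's source, to explain the raise ImportError(Errors.E097()) syntax in base.py:

```python
class ErrorsWithCodes(type):
    """Metaclass sketch: expose string templates as message factories."""

    def __getattribute__(cls, code):
        msg = super().__getattribute__(code)
        if code.startswith("__"):  # leave dunder attributes untouched
            return msg
        return lambda **kwargs: f"[{code}] {msg.format(**kwargs)}"


class Errors(metaclass=ErrorsWithCodes):
    E097 = "Failed to load openai. Please install it using `pip install openai`"


print(Errors.E097())  # [E097] Failed to load openai. ...
```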