diff --git a/langtest/augmentation/augmenter.py b/langtest/augmentation/augmenter.py
index f587adc27..dd65efe78 100644
--- a/langtest/augmentation/augmenter.py
+++ b/langtest/augmentation/augmenter.py
@@ -8,10 +8,15 @@ from langtest.transform import TestFactory
 from langtest.tasks.task import TaskManager
 from langtest.utils.custom_types.sample import Sample
+from langtest.logger import logger
 
 
 class DataAugmenter:
-    def __init__(self, task: Union[str, TaskManager], config: Union[str, dict]) -> None:
+    def __init__(
+        self,
+        task: Union[str, TaskManager],
+        config: Union[str, dict],
+    ) -> None:
         """
         Initialize the DataAugmenter.
@@ -241,11 +246,20 @@ def prepare_hash_map(
 
         return hashmap
 
-    def save(self, file_path: str):
+    def save(self, file_path: str, for_gen_ai: bool = False) -> None:
         """
         Save the augmented data.
         """
-        self.__datafactory.export(data=self.__augmented_data, output_path=file_path)
+        try:
+            # A .json output path is only valid for NER when for_gen_ai is
+            # True (Gen AI Lab export); otherwise reject it.
+            if not for_gen_ai and self.__task.task_name == "ner":
+                if file_path.endswith(".json"):
+                    raise ValueError("A .json file path requires for_gen_ai=True for NER tasks")
+
+            self.__datafactory.export(data=self.__augmented_data, output_path=file_path)
+        except Exception as e:
+            logger.error(f"Error in saving the augmented data: {e}")
 
     def __or__(self, other: Iterable):
         results = self.augment(other)
diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 1d89303ae..5e35fc97b 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -26,6 +26,7 @@ from ..errors import Warnings, Errors
 import glob
 from pkg_resources import resource_filename
+from langtest.logger import logger
 
 COLUMN_MAPPER = {
     "text-classification": {
@@ -551,14 +552,42 @@ def export_data(self, data: List[NERSample], output_path: str):
             output_path (str): path to save the data to
         """
-        otext = ""
-        temp_id = None
-        for i in data:
-            text, temp_id = Formatter.process(i, output_format="conll", temp_id=temp_id)
-            otext += text + "\n"
-
-        with open(output_path, "wb") as fwriter:
-            fwriter.write(bytes(otext, encoding="utf-8"))
+        if output_path.endswith(".conll"):
+            otext = ""
+            temp_id = None
+            for i in data:
+                text, temp_id = Formatter.process(
+                    i, output_format="conll", temp_id=temp_id
+                )
+                otext += text + "\n"
+
+            with open(output_path, "wb") as fwriter:
+                fwriter.write(bytes(otext, encoding="utf-8"))
+
+        elif output_path.endswith(".json"):
+            import json
+            from .utils import process_document
+
+            logger.warning("JSON export of NER samples is intended for Gen AI Lab use only")
+            logger.info("Converting NER samples to JSON format")
+
+            otext_list = []
+            temp_id = None
+            for i in data:
+                otext, temp_id = Formatter.process(
+                    i, output_format="json", temp_id=temp_id
+                )
+                processed_text = process_document(otext)
+                # attach test metadata to each exported record
+                temp_dict = processed_text["data"]
+                temp_dict["test_type"] = i.test_type or "null"
+                temp_dict["category"] = i.category or "null"
+
+                processed_text["data"] = temp_dict
+                otext_list.append(processed_text)
+
+            with open(output_path, "w") as fwriter:
+                json.dump(otext_list, fwriter)
 
     def __token_validation(self, tokens: str) -> (bool, List[List[str]]):  # type: ignore
         """Validates the tokens in a sentence.
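A minimal usage sketch of the `save()` gating added above. The `DataAugmenter` constructor and `save()` signatures come from this patch; the task value, config path, `augment()` call, and file names are illustrative assumptions:

```python
from langtest.augmentation.augmenter import DataAugmenter

# Hypothetical task/config values; only the signatures come from this patch.
augmenter = DataAugmenter(task="ner", config="augment_config.yaml")
augmenter.augment("train.conll")  # assumed augmentation step, not part of this patch

augmenter.save("augmented.conll")                  # CoNLL export, unchanged behavior
augmenter.save("augmented.json", for_gen_ai=True)  # new Gen AI Lab JSON export
augmenter.save("augmented.json")                   # for NER this raises ValueError,
                                                   # which save() catches and logs
```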
diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py
index 621fe34e0..808c0ade2 100644
--- a/langtest/datahandler/format.py
+++ b/langtest/datahandler/format.py
@@ -235,6 +235,13 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, str]]:
 
         return text, temp_id
 
+    @staticmethod
+    def to_json(sample: NERSample, temp_id: int = None) -> Tuple[str, int]:
+        """Returns the CoNLL text of a NERSample for downstream JSON conversion."""
+
+        text, temp_id = NEROutputFormatter.to_conll(sample, temp_id)
+        return text, temp_id
+
 
 class QAFormatter(BaseFormatter):
     def to_jsonl(sample: QASample, *args, **kwargs):
diff --git a/langtest/datahandler/utils.py b/langtest/datahandler/utils.py
new file mode 100644
index 000000000..87a90fd07
--- /dev/null
+++ b/langtest/datahandler/utils.py
@@ -0,0 +1,123 @@
+from datetime import datetime, timezone
+
+
+def get_results(tokens, labels, text):
+    """Converts BIO-tagged tokens into Label Studio style span annotations."""
+    current_entity = None
+    current_span = []
+    results = []
+    char_pos = 0  # Tracks the character position in the text
+
+    for i, (token, label) in enumerate(zip(tokens, labels)):
+        token_start = char_pos
+        token_end = token_start + len(token)
+        if label.startswith("B-"):
+            # close any open entity before starting a new one
+            if current_entity:
+                results.append(
+                    {
+                        "value": {
+                            "start": current_span[0],
+                            "end": current_span[-1],
+                            "text": text[current_span[0] : current_span[-1]],
+                            "labels": [current_entity],
+                            "confidence": 1,
+                        },
+                        "from_name": "label",
+                        "to_name": "text",
+                        "type": "labels",
+                    }
+                )
+            current_entity = label[2:]
+            current_span = [token_start, token_end]
+        elif label.startswith("I-") and current_entity:
+            # extend the open entity to cover this token
+            current_span[-1] = token_end
+        elif label == "O" and current_entity:
+            results.append(
+                {
+                    "value": {
+                        "start": current_span[0],
+                        "end": current_span[-1],
+                        "text": text[current_span[0] : current_span[-1]],
+                        "labels": [current_entity],
+                        "confidence": 1,
+                    },
+                    "from_name": "label",
+                    "to_name": "text",
+                    "type": "labels",
+                }
+            )
+            current_entity = None
+            current_span = []
+
+        # advance past the token, plus the separating space unless the
+        # next token is punctuation (which attaches without a space)
+        char_pos = (
+            token_end + 1
+            if i + 1 < len(tokens) and tokens[i + 1] not in [".", ",", "!", "?"]
+            else token_end
+        )
+
+    # flush an entity that runs to the end of the text
+    if current_entity:
+        results.append(
+            {
+                "value": {
+                    "start": current_span[0],
+                    "end": current_span[-1],
+                    "text": text[current_span[0] : current_span[-1]],
+                    "labels": [current_entity],
+                    "confidence": 1,
+                },
+                "from_name": "label",
+                "to_name": "text",
+                "type": "labels",
+            }
+        )
+    return results
+
+
+def process_document(doc):
+    """Parses a 4-column CoNLL document into a Label Studio style record."""
+    tokens = []
+    labels = []
+
+    # strip the -DOCSTART- marker; its line then fails the 4-column check below
+    doc = doc.replace("-DOCSTART-", "")
+
+    for line in doc.strip().split("\n"):
+        if line.strip():
+            parts = line.strip().split()
+            if len(parts) == 4:
+                token, _, _, label = parts
+                tokens.append(token)
+                labels.append(label)
+
+    # rebuild the sentence text, attaching punctuation without a space
+    text = ""
+    for token in tokens:
+        if token in {".", ",", "!", "?"}:
+            text = text.rstrip() + token + " "
+        else:
+            text += token + " "
+
+    text = text.rstrip()
+
+    results = get_results(tokens, labels, text)
+    now = datetime.now(timezone.utc)  # datetime.utcnow() is deprecated
+    current_date = now.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+    json_output = {
+        "created_ago": current_date,
+        "result": results,
+        "honeypot": True,
+        "lead_time": 10,
+        "confidence_range": [0, 1],
+        "submitted_at": current_date,
+        "updated_at": current_date,
+        "predictions": [],
+        "created_at": current_date,
+        "data": {"text": text},
+    }
+
+    return json_output
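To show what the new utils.py emits, a self-contained sketch running `process_document` on a tiny 4-column CoNLL snippet (the sentence is invented for demonstration; the import path is the new module added above):

```python
from langtest.datahandler.utils import process_document

doc = """-DOCSTART- -X- -X- O

John NNP B-NP B-PER
Smith NNP I-NP I-PER
lives VBZ B-VP O
in IN B-PP O
Paris NNP B-NP B-LOC
. . O O
"""

record = process_document(doc)
print(record["data"]["text"])  # John Smith lives in Paris.
for r in record["result"]:
    v = r["value"]
    print(v["labels"], repr(v["text"]), v["start"], v["end"])
# ['PER'] 'John Smith' 0 10
# ['LOC'] 'Paris' 20 25
```

Note the punctuation handling: the final "." attaches to "Paris" with no space, so `char_pos` in `get_results` stays aligned with the rebuilt text.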