Update for conversational distillation (#3221)
* Update for conversational distillation

* Fixed linting errors

* Removed trailing white spaces

* Added the placeholder for data_generation_task_type

* Added doc strings

* Updated enum and added doc strings

* Removed whitespaces

* Updated based on the comments.

* Fixed linting issues

* Updated to Python 3.8 and fixed some minor bugs

* Updated the description of the parameter data_generation_task_type
Reverted changes for CoT

* Updated the description for data_generation_task_type

* Removed the required flag from data_generation_task_type

* Lint fixes

* Lint fixes

* Updated as per comments.

* Updated the version of the pipeline and datagen components
babu-namburi authored Aug 6, 2024
1 parent 1e3f36f commit 5f06b01
Showing 4 changed files with 238 additions and 9 deletions.
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.2
version: 0.0.3
type: command

is_deterministic: True
@@ -85,6 +85,18 @@ inputs:
type: string
default: "true"
description: Enable Chain of thought for data generation

data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering
outputs:
generated_train_file_path:
@@ -113,5 +125,6 @@ command: >-
--request_batch_size ${{inputs.request_batch_size}}
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}
--data_generation_task_type ${{inputs.data_generation_task_type}}
--generated_train_file_path ${{outputs.generated_train_file_path}}
--generated_validation_file_path ${{outputs.generated_validation_file_path}}
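
The new data_generation_task_type input threads straight through to the component's command section above. A minimal sketch, assuming a hypothetical inputs dict in place of the resolved ${{inputs.*}} bindings, of the flags the rendered command appends:

# Hypothetical dict standing in for the resolved ${{inputs.*}} bindings above.
inputs = {
    "enable_chain_of_thought": "true",
    "data_generation_task_type": "CONVERSATION",
}

# Tail of the rendered command line, mirroring the spec's command section.
args = [
    "--enable_chain_of_thought", inputs["enable_chain_of_thought"],
    "--data_generation_task_type", inputs["data_generation_task_type"],
]
print(" ".join(args))
# --enable_chain_of_thought true --data_generation_task_type CONVERSATION

The constants module below defines the matching DataGenerationTaskType enum.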
@@ -4,7 +4,7 @@
"""Data generator constants."""

import re

from enum import EnumMeta, Enum

# COMPONENT META
COMPONENT_NAME = "oss_distillation_generate_data"

@@ -85,3 +85,23 @@ class InferenceMode:
HFTV2_TEXT_GENERATION = "hftv2_text_generation"
VLLM_CHAT_COMPLETION = "vllm_chat_completion"
VLLM_TEXT_GENERATION = "vllm_text_generation"


class MetaEnum(EnumMeta):
"""Metaclass for Enum classes, enabling the `in` operator to check if a value is in the Enum."""

def __contains__(cls, item):
"""Check if the item is in the Enum."""
try:
cls(item)
except ValueError:
return False
return True


class DataGenerationTaskType(str, Enum, metaclass=MetaEnum):
"""Enum for data generation task types."""

NLI = "NLI"
CONVERSATION = "CONVERSATION"
NLU_QUESTION_ANSWERING = "NLU_QA"
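
For reference, a runnable sketch of how the MetaEnum added above behaves: the overridden __contains__ lets callers test a raw string against the enum with the in operator, and the str mixin makes members compare equal to their values. The demo lines at the end are illustrative, not part of the commit.

from enum import Enum, EnumMeta


class MetaEnum(EnumMeta):
    """Metaclass enabling `in` checks against raw values."""

    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True


class DataGenerationTaskType(str, Enum, metaclass=MetaEnum):
    NLI = "NLI"
    CONVERSATION = "CONVERSATION"
    NLU_QUESTION_ANSWERING = "NLU_QA"


print("CONVERSATION" in DataGenerationTaskType)   # True
print("SUMMARIZATION" in DataGenerationTaskType)  # False
print(DataGenerationTaskType.NLI == "NLI")        # True, via the str mixin

The data generation script below derives its CLI choices from this enum.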
@@ -43,6 +43,7 @@
STOP_TOKEN,
SUPPORTED_FILE_FORMATS,
VLLM_CHAT_SCORE_PATH,
DataGenerationTaskType
)

from common.utils import (
@@ -180,6 +181,18 @@ def get_parser():
help="This enables Chain of Thought"
)

parser.add_argument(
"--data_generation_task_type",
type=str,
required=True,
help="""Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering
""",
choices=[v.value for v in DataGenerationTaskType]
)

return parser
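
Because choices is built from the enum values, argparse rejects unsupported task types at parse time. A standalone sketch of that behavior (this parser is illustrative, not the component's full parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_generation_task_type",
    type=str,
    required=True,
    choices=["NLI", "CONVERSATION", "NLU_QA"],
)

args = parser.parse_args(["--data_generation_task_type", "CONVERSATION"])
print(args.data_generation_task_type)  # CONVERSATION

# An unsupported value makes argparse exit with status 2 and a message like:
#   argument --data_generation_task_type: invalid choice: 'CLASSIFICATION'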


@@ -231,7 +244,8 @@ def generate_synthetic_data(
generated_train_file_path: Path,
generated_validation_file_path: Path,
train_file_path: Path,
validation_file_path: Path = None,
data_generation_task_type: str,
validation_file_path: Path = None
):
"""Generate and save synthetic data under output_dataset.
@@ -246,7 +260,6 @@
train_file_path (Path): Train JSONL file path
data_generation_task_type (str): Data generation task type
validation_file_path (Path, optional): Validation JSONL file path. Defaults to None.
"""

def process_request(idx: str, enable_cot: bool, data: dict, url: str, endpoint_key: str) -> dict:
"""Process a single request.
@@ -295,11 +308,172 @@ def process_request(idx: str, enable_cot: bool, data: dict, url: str, endpoint_k
"exception": e,
}

def process_conversational_request(idx: str, data: dict, url: str, endpoint_key: str):
"""Process a single conversational request.
Args:
idx (str): Row index in Input data.
data (dict): Payload dict
url (str): Endpoint URL
endpoint_key (str): key to authenticate endpoint request
Returns:
dict: result dictionary
"""
try:
logger.info(f"request_data: {repr(data)}")
# Basic validation for the input data
messages = data.pop("messages", [])
if not messages: # empty messages
return {
"idx": idx,
"status_code": None,
"messages": [],
"exception": "Empty messages"
}
first_message = messages[0]
if first_message['role'] != 'system':
logger.warning(f"First message should be system, but got {first_message['role']}")
return {"idx": idx,
"status_code": None,
"messages": [],
"exception": ("Incorrect format.\n"
f"First message should be system, but got {first_message['role']}"),
}
for message in messages[1:]:
role = message['role']
if role not in ('assistant', 'user'):
logger.warning(f"role should be assistant or user, but got {role}")
return {"idx": idx,
"status_code": None,
"messages": [],
"exception": f"Incorrect format.\nRole should be assistant or user, but got {role}"
}

synthetic_responses = []
for message in messages:
role = message['role']
if role in ('system', 'user'):
synthetic_responses.append(message)
else:
data_with_inference_parameters = {"messages": synthetic_responses}
for key, value in data.items():
data_with_inference_parameters[key] = value
# replace the assistant content from the model
response: Response = _invoke_endpoint(url=url, key=endpoint_key,
data=data_with_inference_parameters)
if response.status_code != 200:
break
logger.info(f"response_text: {response.text}")
response_data = response.json()

logger.info(f"JSON response: {response_data}")
prediction_result = (
None if response.status_code != 200
# response content should be structured as below for a successful vllm response
else response_data['choices'][0]["message"]["content"].strip()
)
synthetic_responses.append({'role': 'assistant', 'content': prediction_result})
return {
"idx": idx,
"status_code": response.status_code,
"messages": synthetic_responses,
"exception": (f"Not able to generate synthetic response for all turns for idx: {idx}"
if response.status_code != 200
else
None),
}
except Exception as e:
logger.error(f"idx: {idx}. exception: {e}")
return {
"idx": idx,
"status_code": None,
"messages": [],
"exception": e,
}
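
process_conversational_request replays a conversation turn by turn: system and user messages pass through unchanged, while each assistant turn is regenerated by the teacher endpoint from the accumulated history. A hypothetical JSONL row illustrating the expected input shape (all content values are placeholders):

# One hypothetical input row for the CONVERSATION task type. Existing assistant
# turns are discarded and regenerated by the teacher model, turn by turn.
record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is knowledge distillation?"},
        {"role": "assistant", "content": "<regenerated by the teacher>"},
        {"role": "user", "content": "Give a short example."},
        {"role": "assistant", "content": "<regenerated by the teacher>"},
    ]
}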

def replace_cot_system_message(messages: List[dict]) -> List[dict]:
# Replace the system message without changing the original messages list
cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
return [(cot_system_message if message['role'] == 'system' else message) for message in messages]

def batch_process_conversation_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
"""Batch process conversation data and send bulk requests to the teacher model endpoint.
Args:
input_file_path (Path): Input data file path
output_file_path (Path): Path to output directory
batch_size (int): Input batch size for processing rows in train and validation dataset
Raises:
Exception: if success ratio is less than min_endpoint_success_ratio
"""
train_df = pd.read_json(input_file_path, lines=True, chunksize=batch_size)
total_rows = 0
error_count = 0
output_data = []
error_map = {}
ERROR = "error"

for batch in train_df:
total_rows += len(batch)
futures = []

with ThreadPoolExecutor() as executor:
for idx, row in batch.iterrows():
messages = row.iloc[0]
request_data = {
"messages": messages,
**inference_params,
}
futures.append(
executor.submit(
process_conversational_request,
idx,
request_data,
teacher_model_endpoint_url,
teacher_model_endpoint_key
)
)

# wait for results to complete
future_results = {
result["idx"]: result
for result in [future.result() for future in as_completed(futures)]
}

idx = 0
for idx, row in batch.iterrows():
future_result = future_results.get(idx)
logger.info(future_result)
if future_result is None:
logger.error(f"row {idx} not found in future_results")
error_map[ERROR] = error_map.get(ERROR, 0) + 1
elif future_result['exception']:
logger.error(f"row {idx} failed with exception: {future_result['exception']}")
error_map[ERROR] = error_map.get(ERROR, 0) + 1
elif future_result['status_code'] != 200:
logger.warning(f"row {idx} request status_code: {future_result['status_code']} != 200")
error_map[future_result['status_code']] = error_map.get(future_result['status_code'], 0) + 1
else:
output_data.append({"messages": future_result['messages']})
Path(output_file_path.parent).mkdir(exist_ok=True, parents=True)
with open(output_file_path, 'w') as f:
for entry in output_data:
f.write(json.dumps(entry) + '\n')

if error_map:
logger.info("Error summary. With key denoting non-200 status code or some other error.")
for k, v in error_map.items():
error_count += v
logger.warning(f"{k} => {v}")

success_ratio = float(total_rows - error_count) / total_rows
logger.info(f"Success rate was {success_ratio} for {input_file_path}")
if success_ratio < min_endpoint_success_ratio:
msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
raise Exception(msg)
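
The gate at the end of the function is plain arithmetic over the batch results. A worked example with assumed counts:

# Assumed counts for illustration.
total_rows = 100
error_count = 7  # empty/malformed rows plus non-200 responses
min_endpoint_success_ratio = 0.95

success_ratio = float(total_rows - error_count) / total_rows  # 0.93
if success_ratio < min_endpoint_success_ratio:  # 0.93 < 0.95, so this raises
    raise Exception(f"Success ratio {success_ratio} < {min_endpoint_success_ratio}.")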

def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
"""Batch process data and do a bulk request to teacher model endpoint.
@@ -386,14 +560,21 @@ def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size
if success_ratio < min_endpoint_success_ratio:
msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
raise Exception(msg)

logger.info("Processing train file")
batch_process_data(train_file_path, generated_train_file_path, request_batch_size)

if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
batch_process_conversation_data(train_file_path, generated_train_file_path, request_batch_size)
else:
batch_process_data(train_file_path, generated_train_file_path, request_batch_size)

logger.info("Data generated and saved for train file")

if validation_file_path:
logger.info("Processing validation file")
batch_process_data(validation_file_path, generated_validation_file_path, request_batch_size)
if data_generation_task_type == DataGenerationTaskType.CONVERSATION:
batch_process_conversation_data(validation_file_path, generated_validation_file_path, request_batch_size)
else:
batch_process_data(validation_file_path, generated_validation_file_path, request_batch_size)
logger.info("Data generated and saved for validation file")


@@ -417,6 +598,7 @@ def data_import(args: Namespace):
request_batch_size = args.request_batch_size
min_endpoint_success_ratio = args.min_endpoint_success_ratio
enable_cot_str = args.enable_chain_of_thought
data_generation_task_type = args.data_generation_task_type

# validate file formats
_validate_file_paths_with_supported_formats([args.train_file_path, args.validation_file_path])
@@ -485,6 +667,7 @@
generated_train_file_path=generated_train_file_path,
generated_validation_file_path=generated_validation_file_path,
train_file_path=train_file_path,
data_generation_task_type=data_generation_task_type,
validation_file_path=validation_file_path,
)

17 changes: 15 additions & 2 deletions assets/training/distillation/components/pipeline/spec.yaml
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_pipeline
version: 0.0.2
version: 0.0.3
type: pipeline


@@ -126,6 +126,18 @@ inputs:
default: "false"
description: Enable Chain of thought for data generation

data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering
## OSS Finetune Input Parameters
number_of_gpu_to_use_finetuning:
type: integer
@@ -193,7 +205,7 @@ outputs:
jobs:
oss_distillation_generate_data:
type: command
component: azureml:oss_distillation_generate_data:0.0.2
component: azureml:oss_distillation_generate_data:0.0.3
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
@@ -211,6 +223,7 @@ jobs:
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
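
Callers of the bumped pipeline component pass the new input like any other. A minimal sketch, assuming the azure-ai-ml v2 SDK and a local copy of the spec; the path and input values are placeholders, and component nodes are normally created inside an @pipeline-decorated function:

from azure.ai.ml import load_component

# Path is illustrative; point at a local copy of the pipeline spec.
oss_distillation = load_component(
    source="assets/training/distillation/components/pipeline/spec.yaml"
)

# Bind the new input alongside existing ones (values are placeholders).
node = oss_distillation(
    data_generation_task_type="CONVERSATION",
    enable_chain_of_thought="false",
)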
