def process_conversational_request(idx: str, data: dict, url: str, endpoint_key: str) -> dict:
    """Process a single conversational request.

    Walks the conversation in order. System and user messages are kept as they
    are; every other message is replaced by the teacher model's reply to the
    conversation built so far. Processing stops at the first non-200 response.

    Uses ``logger`` and ``_invoke_endpoint`` from the surrounding scope.

    Args:
        idx (str): Row index in Input data.
        data (dict): Payload dict. ``messages`` holds the conversation; every
            other key is forwarded to the endpoint unchanged. The dict itself
            is not modified.
        url (str): Endpoint URL
        endpoint_key (str): key to authenticate endpoint request

    Returns:
        dict: result dictionary with the keys ``idx``, ``status_code`` (status of
        the last endpoint call, or None when no call was made), ``messages``
        (the regenerated conversation) and ``exception`` (None on success).
    """
    passthrough_roles = ('system', 'user')
    known_roles = passthrough_roles + ('assistant',)
    try:
        logger.info(f"request_data: {repr(data)}")
        messages = data.get("messages") or []
        # Build a separate dict instead of popping, so the caller's payload stays intact.
        endpoint_params = {key: value for key, value in data.items() if key != "messages"}
        if not messages:  # empty messages
            return {
                "idx": idx,
                "status_code": None,
                "messages": [],
                "exception": "Empty messages",
            }

        # Basic validation for the input data
        first_role = messages[0]['role']
        if first_role != 'system':
            logger.warning(f"First message should be system, but got {first_role}")
            # TODO: handle this case
        for message in messages[1:]:
            role = message['role']
            # Assistant turns are expected (they are the ones regenerated below),
            # so only roles outside the known set are worth a warning.
            if role not in known_roles:
                logger.warning(f"role should be system, user or assistant, but got {role}")
                # TODO: handle this case

        synthetic_responses = []
        status_code = None  # stays None until the endpoint has been called at least once
        for message in messages:
            if message['role'] in passthrough_roles:
                synthetic_responses.append(message)
                continue

            # replace the assistant content from the model
            response = _invoke_endpoint(
                url=url,
                key=endpoint_key,
                # Send a snapshot: synthetic_responses keeps growing after this call.
                data={"messages": list(synthetic_responses)} | endpoint_params,
            )
            status_code = response.status_code
            if status_code != 200:
                break
            logger.info(f"response_text: {response.text}")
            response_data = response.json()
            logger.info(f"JSON response: {response_data}")
            # response content should be structured as below for a successful vllm response
            prediction_result = response_data['choices'][0]["message"]["content"].strip()
            synthetic_responses.append({'role': 'assistant', 'content': prediction_result})

        if status_code is None:
            # No turn needed generating, so there is no response to report on.
            return {
                "idx": idx,
                "status_code": None,
                "messages": [],
                "exception": "No assistant message to generate",
            }
        return {
            "idx": idx,
            "status_code": status_code,
            "messages": synthetic_responses,
            "exception": None,
        }
    except Exception as e:
        logger.error(f"idx: {idx}. exception: {e}")
        return {
            "idx": idx,
            "status_code": None,
            "messages": [],
            "exception": e,
        }


def replace_cot_system_message(messages: List[dict]) -> List[dict]:
    """Return a copy of ``messages`` with every system message swapped for the CoT system prompt.

    Args:
        messages (List[dict]): Conversation messages; each one is read through its ``role`` key.

    Returns:
        List[dict]: New list; the input list is left unchanged.
    """
    # Replace the system message without changing the original messages list
    cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
    return [(cot_system_message if message['role'] == 'system' else message) for message in messages]


def batch_process_conversation_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
    """Batch process conversation data and do a bulk request to teacher model endpoint.

    Reads the input JSONL file in chunks of ``batch_size`` rows, sends every row
    through ``process_conversational_request`` on a thread pool, and writes the
    successfully regenerated conversations to ``output_file_path`` as JSONL.

    Uses ``inference_params``, ``teacher_model_endpoint_url``,
    ``teacher_model_endpoint_key`` and ``min_endpoint_success_ratio`` from the
    surrounding scope.

    Args:
        input_file_path (Path): Input data file path
        output_file_path (Path): Output JSONL file path; parent directories are created if missing
        batch_size (int): Input batch size for processing rows in train and validation dataset

    Raises:
        Exception: if the input has no rows, or if success ratio is less than min_endpoint_success_ratio
    """
    batches = pd.read_json(input_file_path, lines=True, chunksize=batch_size)
    total_rows = 0
    output_data = []
    error_map = {}
    ERROR = "error"

    for batch in batches:
        total_rows += len(batch)
        futures = []

        with ThreadPoolExecutor() as executor:
            for idx, row in batch.iterrows():
                # The conversation is read from the first column of the row.
                request_data = {
                    "messages": row.iloc[0],
                    **inference_params,
                }
                futures.append(
                    executor.submit(
                        process_conversational_request,
                        idx,
                        request_data,
                        teacher_model_endpoint_url,
                        teacher_model_endpoint_key,
                    )
                )

            # wait for results to complete
            future_results = {}
            for future in as_completed(futures):
                result = future.result()
                future_results[result["idx"]] = result

        for idx in batch.index:
            future_result = future_results.get(idx)
            logger.info(future_result)
            if future_result is None:
                logger.error(f"row {idx} not found in future_results")
                error_map[ERROR] = error_map.get(ERROR, 0) + 1
                # Nothing more to inspect for this row; subscripting None would raise.
                continue
            if future_result['exception']:
                logger.error(f"row {idx} failed with exception: {future_result['exception']}")
                error_map[ERROR] = error_map.get(ERROR, 0) + 1
            elif future_result['status_code'] != 200:
                status_code = future_result['status_code']
                logger.warning(f"row {idx} request status_code: {status_code} != 200")
                error_map[status_code] = error_map.get(status_code, 0) + 1
            else:
                output_data.append(future_result['messages'])

    output_file_path.parent.mkdir(exist_ok=True, parents=True)
    with open(output_file_path, 'w') as f:
        for entry in output_data:
            f.write(json.dumps(entry) + '\n')

    if error_map:
        logger.info("Error summary. With key denoting non-200 status code or some other error.")
        for k, v in error_map.items():
            logger.warning(f"{k} => {v}")
    error_count = sum(error_map.values())

    if total_rows == 0:
        # Fail with a clear message instead of a ZeroDivisionError in the ratio below.
        raise Exception(f"No rows found in dataset {input_file_path}.")

    success_ratio = float(total_rows - error_count) / total_rows
    logger.info(f"Success rate was {success_ratio} for {input_file_path}")
    if success_ratio < min_endpoint_success_ratio:
        msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
        raise Exception(msg)
@@ -388,12 +532,16 @@ def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size raise Exception(msg) logger.info("Processing train file") + #TODO: conditionally the batch_process_conversation_data based on the data_generation_task_type batch_process_data(train_file_path, generated_train_file_path, request_batch_size) + # batch_process_conversation_data(train_file_path, generated_train_file_path, request_batch_size) logger.info("Data generated and saved for train file") if validation_file_path: logger.info("Processing validation file") batch_process_data(validation_file_path, generated_validation_file_path, request_batch_size) + # TODO: conditionally the batch_process_conversation_data based on the data_generation_task_type + # batch_process_conversation_data(validation_file_path, generated_validation_file_path, request_batch_size) logger.info("Data generated and saved for validation file")