Update for conversational distillation
babu-namburi committed Aug 2, 2024
1 parent 5c0050a commit ad7ac8e
Showing 1 changed file with 149 additions and 1 deletion.
@@ -246,7 +246,6 @@ def generate_synthetic_data(
        train_file_path (Path): Train JSONL file path
        validation_file_path (Path, optional): Validation JSONL file path. Defaults to None.
    """

    def process_request(idx: str, enable_cot: bool, data: dict, url: str, endpoint_key: str) -> dict:
        """Process a single request.
@@ -294,12 +293,157 @@ def process_request(idx: str, enable_cot: bool, data: dict, url: str, endpoint_key: str) -> dict:
"text": None,
"exception": e,
}

    def process_conversational_request(idx: str, data: dict, url: str, endpoint_key: str) -> dict:
        """Process a single conversational request.

        Args:
            idx (str): Row index in input data.
            data (dict): Payload dict
            url (str): Endpoint URL
            endpoint_key (str): Key to authenticate endpoint request

        Returns:
            dict: Result dictionary
        """
        try:
            logger.info(f"request_data: {repr(data)}")
            # Basic validation for the input data
            messages = data.pop("messages", [])
            if not messages:  # empty messages
                return {
                    "idx": idx,
                    "status_code": None,
                    "messages": [],
                    "exception": "Empty messages",
                }
            first_message = messages[0]
            if first_message['role'] != 'system':
                logger.warning(f"First message should be system, but got {first_message['role']}")
                # TODO: handle this case
            for message in messages[1:]:
                role = message['role']
                if role not in ('user', 'assistant'):
                    logger.warning(f"Role after the system turn should be user or assistant, but got {role}")
                    # TODO: handle this case
            # Default to 200 so a conversation with no assistant turn to regenerate
            # is still reported as a success.
            status_code = 200
            synthetic_responses = []
            for message in messages:
                role = message['role']
                if role in ('system', 'user'):
                    synthetic_responses.append(message)
                else:
                    # Replace the assistant content with the teacher model's completion,
                    # generated from the turns accumulated so far.
                    response: Response = _invoke_endpoint(
                        url=url, key=endpoint_key, data={"messages": synthetic_responses} | data
                    )
                    status_code = response.status_code
                    if status_code != 200:
                        break
                    logger.info(f"response_text: {response.text}")
                    response_data = response.json()

                    logger.info(f"JSON response: {response_data}")
                    # Response content should be structured as below for a successful vllm response
                    prediction_result = response_data['choices'][0]["message"]["content"].strip()
                    synthetic_responses.append({'role': 'assistant', 'content': prediction_result})
            return {
                "idx": idx,
                "status_code": status_code,
                "messages": synthetic_responses,
                "exception": None,
            }
        except Exception as e:
            logger.error(f"idx: {idx}. exception: {e}")
            return {
                "idx": idx,
                "status_code": None,
                "messages": [],
                "exception": e,
            }
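
    # Illustrative example (not part of the original change): the payload this helper
    # consumes, assuming an OpenAI-style chat-completions teacher endpoint. Assistant
    # turns in the input are placeholders; each one is regenerated by the teacher from
    # the turns that precede it.
    #
    #     request_data = {
    #         "messages": [
    #             {"role": "system", "content": "You are a concise math tutor."},
    #             {"role": "user", "content": "What is 2 + 2?"},
    #             {"role": "assistant", "content": "(replaced by the teacher)"},
    #         ],
    #         **inference_params,  # e.g. temperature, max_tokens
    #     }
    #     result = process_conversational_request(
    #         "0", request_data, teacher_model_endpoint_url, teacher_model_endpoint_key
    #     )
    #     # result["messages"] now ends with the teacher-generated assistant turn.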

    def replace_cot_system_message(messages: List[dict]) -> List[dict]:
        # Replace the system message without changing the original messages list
        cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
        return [(cot_system_message if message['role'] == 'system' else message) for message in messages]
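
    # Illustrative usage (not part of the original change): only the system turn is
    # swapped for COT_SYSTEM_PROMPT, and a new list is returned, so the caller's
    # original messages list stays intact.
    #
    #     msgs = [{"role": "system", "content": "You are terse."},
    #             {"role": "user", "content": "Why is the sky blue?"}]
    #     swapped = replace_cot_system_message(msgs)
    #     assert swapped[0]["content"] == COT_SYSTEM_PROMPT
    #     assert msgs[0]["content"] == "You are terse."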

    def batch_process_conversation_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
        """Batch process conversational data and send bulk requests to the teacher model endpoint.

        Args:
            input_file_path (Path): Input data file path
            output_file_path (Path): Path to the output JSONL file
            batch_size (int): Input batch size for processing rows in train and validation dataset

        Raises:
            Exception: if success ratio is less than min_endpoint_success_ratio
        """
        train_df = pd.read_json(input_file_path, lines=True, chunksize=batch_size)
        total_rows = 0
        error_count = 0
        output_data = []
        error_map = {}
        ERROR = "error"

        for batch in train_df:
            total_rows += len(batch)
            futures = []

            with ThreadPoolExecutor() as executor:
                for idx, row in batch.iterrows():
                    messages = row.iloc[0]
                    request_data = {
                        "messages": messages,
                        **inference_params,
                    }
                    futures.append(
                        executor.submit(
                            process_conversational_request,
                            idx,
                            request_data,
                            teacher_model_endpoint_url,
                            teacher_model_endpoint_key,
                        )
                    )

            # Wait for all requests in the batch to complete
            future_results = {
                result["idx"]: result
                for result in [future.result() for future in as_completed(futures)]
            }

            for idx, row in batch.iterrows():
                future_result = future_results.get(idx)
                logger.info(future_result)
                if future_result is None:
                    # Missing result: count the error and skip to the next row
                    logger.error(f"row {idx} not found in future_results")
                    error_map[ERROR] = error_map.get(ERROR, 0) + 1
                elif future_result['exception']:
                    logger.error(f"row {idx} failed with exception: {future_result['exception']}")
                    error_map[ERROR] = error_map.get(ERROR, 0) + 1
                elif future_result['status_code'] != 200:
                    logger.warning(f"row {idx} request status_code: {future_result['status_code']} != 200")
                    error_map[future_result['status_code']] = error_map.get(future_result['status_code'], 0) + 1
                else:
                    output_data.append(future_result['messages'])

        Path(output_file_path.parent).mkdir(exist_ok=True, parents=True)
        with open(output_file_path, 'w') as f:
            for entry in output_data:
                f.write(json.dumps(entry) + '\n')

        if error_map:
            logger.info("Error summary, with each key denoting a non-200 status code or other error.")
            for k, v in error_map.items():
                error_count += v
                logger.warning(f"{k} => {v}")

        success_ratio = float(total_rows - error_count) / total_rows
        logger.info(f"Success ratio was {success_ratio} for {input_file_path}")
        if success_ratio < min_endpoint_success_ratio:
            msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
            raise Exception(msg)
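
    # Illustrative call (hypothetical paths; endpoint settings and
    # min_endpoint_success_ratio come from the enclosing component's configuration).
    # Each input JSONL line is expected to carry the conversation in its first
    # column, e.g. {"messages": [...]}:
    #
    #     batch_process_conversation_data(
    #         input_file_path=Path("train.jsonl"),
    #         output_file_path=Path("generated/train.jsonl"),
    #         batch_size=request_batch_size,
    #     )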

    def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
        """Batch process data and do a bulk request to teacher model endpoint.
@@ -388,12 +532,16 @@ def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
            raise Exception(msg)

logger.info("Processing train file")
#TODO: conditionally the batch_process_conversation_data based on the data_generation_task_type
batch_process_data(train_file_path, generated_train_file_path, request_batch_size)
# batch_process_conversation_data(train_file_path, generated_train_file_path, request_batch_size)
logger.info("Data generated and saved for train file")

if validation_file_path:
logger.info("Processing validation file")
batch_process_data(validation_file_path, generated_validation_file_path, request_batch_size)
# TODO: conditionally the batch_process_conversation_data based on the data_generation_task_type
# batch_process_conversation_data(validation_file_path, generated_validation_file_path, request_batch_size)
logger.info("Data generated and saved for validation file")

