feat: introduce validation component for distillation pipeline (#3284)
ShobithNandakumar authored Aug 26, 2024
1 parent 326b05d commit 7c566f9
Showing 12 changed files with 986 additions and 32 deletions.
45 changes: 45 additions & 0 deletions assets/training/distillation/components/data_generation/README.md
@@ -0,0 +1,45 @@
## Data Generation Component

### Name

oss_distillation_generate_data

### Version

0.0.5

### Type

command

### Description

Component to generate data from a teacher model endpoint.

## Inputs

| Name | Description | Type | Optional |
| --- | --- | --- | --- |
| train_file_path | Path to the registered training data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | True |
| validation_file_path | Path to the registered validation data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | True |
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True |
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True |
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True |
| teacher_model_max_new_tokens | Teacher model `max_new_tokens` inference parameter. | integer | True |
| teacher_model_temperature | Teacher model `temperature` inference parameter. | number | True |
| teacher_model_top_p | Teacher model `top_p` inference parameter. | number | True |
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True |
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True |
| teacher_model_stop | Teacher model stop inference parameter. | string | True |
| request_batch_size | Number of data records sent to the teacher model endpoint in one request. | integer | True |
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True |
| enable_chain_of_thought | Enable chain-of-thought prompting for data generation. | string | True |
| data_generation_task_type | Data generation task type; supported values are NLI, CONVERSATION and NLU_QA. | string | False |
| validation_output | Validation status from the validation component. | uri_file | True |

## Outputs

| Name | Description | Type |
| -------------------- | -------------------------------------------------------- | ------------ |
| generated_train_file_path | Generated training data. | uri_file |
| generated_validation_file_path | Generated validation data. | uri_file |
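For orientation, here is a minimal sketch of how this component might be consumed from the azure-ai-ml (v2) Python SDK, assuming the component is available from the `azureml` registry at the version above. The subscription, workspace, data asset, and endpoint values are placeholders, not part of this commit.

```python
# Hypothetical usage sketch (azure-ai-ml v2); workspace details, data assets and
# endpoint values below are placeholders, not part of this commit.
from azure.ai.ml import MLClient, Input
from azure.ai.ml.dsl import pipeline
from azure.identity import DefaultAzureCredential

credential = DefaultAzureCredential()
ml_client = MLClient(
    credential,
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)

# Resolve the registered component from the azureml registry.
registry_client = MLClient(credential, registry_name="azureml")
generate_data = registry_client.components.get(
    name="oss_distillation_generate_data", version="0.0.5"
)

@pipeline()
def distillation_data_generation(train_data, teacher_endpoint_url, teacher_endpoint_key):
    step = generate_data(
        train_file_path=train_data,
        teacher_model_endpoint_url=teacher_endpoint_url,
        teacher_model_endpoint_key=teacher_endpoint_key,
        data_generation_task_type="NLI",
    )
    return {"generated_train_file_path": step.outputs.generated_train_file_path}

job = distillation_data_generation(
    train_data=Input(type="uri_file", path="azureml:<train-data-asset>:1"),
    teacher_endpoint_url="<endpoint-url>",
    teacher_endpoint_key="<endpoint-key>",
)
ml_client.jobs.create_or_update(job, experiment_name="oss-distillation")
```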
15 changes: 11 additions & 4 deletions assets/training/distillation/components/data_generation/spec.yaml
@@ -1,14 +1,14 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.4
version: 0.0.5
type: command

is_deterministic: True

display_name: OSS Distillation Generate Data
description: Component to generate data from teacher model endpoint

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/63
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66

inputs:
# Inputs
@@ -97,7 +97,14 @@ inputs:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
# Output of validation component.
validation_output:
type: uri_file
optional: true
description: Validation status.
mode: rw_mount

outputs:
generated_train_file_path:
type: uri_file
@@ -108,7 +115,7 @@ outputs:
description: Generated validation data
mode: rw_mount

code: src/
code: ../../src
command: >-
python generate_data.py
--train_file_path ${{inputs.train_file_path}}
49 changes: 46 additions & 3 deletions assets/training/distillation/components/pipeline/spec.yaml
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
name: oss_distillation_pipeline
version: 0.0.5
version: 0.0.6
type: pipeline


@@ -9,6 +9,10 @@ description: Component to generate data from teacher model endpoint and finetune

inputs:
# Compute parameters
instance_type_pipeline_validation:
type: string
optional: True
description: Instance type to be used for validation component. The parameter compute_pipeline_validation must be set to 'serverless' for instance_type to be used.
instance_type_data_generation:
type: string
optional: true
@@ -25,6 +29,12 @@ inputs:
default: Singularity.ND96amrs_A100_v4
description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used

compute_pipeline_validation:
type: string
optional: True
default: 'serverless'
description: Compute to be used for the validation component

compute_data_generation:
type: string
optional: true
@@ -50,8 +60,8 @@ inputs:
compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
# ########################### Data Generator Component ########################### #

## OSS Data generator Input Parameters
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
@@ -138,7 +148,8 @@ inputs:
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
## OSS Finetune Input Parameters
# ########################### Finetuning Component ########################### #

number_of_gpu_to_use_finetuning:
type: integer
default: 1
@@ -203,6 +214,37 @@ outputs:
mode: rw_mount

jobs:
oss_distillation_validate_pipeline:
type: command
component: azureml:oss_distillation_validate_pipeline:0.0.1
compute: '${{parent.inputs.compute_pipeline_validation}}'
resources:
instance_type: '${{parent.inputs.instance_type_pipeline_validation}}'
identity:
type: user_identity
inputs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
learning_rate: '${{parent.inputs.learning_rate}}'
outputs:
validation_info:
type: uri_file
path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.json

oss_distillation_generate_data:
type: command
component: azureml:oss_distillation_generate_data:0.0.4
@@ -226,6 +268,7 @@ jobs:
teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
outputs:
generated_train_file_path:
type: uri_file
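The `validation_output` wiring above deserves a note: the generate-data job consumes the validation job's `validation_info` output, so Azure ML's data dependency guarantees that validation completes before data generation starts (whether generate_data.py also inspects the file is not shown in this diff). Below is a minimal, hypothetical sketch of the same output-to-input ordering pattern with the v2 SDK DSL; the commands, script names, and `./src` path are placeholders, not the ones in this commit.

```python
# Hypothetical sketch of the ordering pattern used above: a downstream step consumes an
# upstream step's output file, so the pipeline graph schedules them sequentially.
# Script names and the ./src path are placeholders.
from azure.ai.ml import command, Input, Output
from azure.ai.ml.dsl import pipeline

ENV = "azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66"

validate_step = command(
    name="validate",
    command="python validate.py --status_file ${{outputs.status}}",
    outputs={"status": Output(type="uri_file")},
    environment=ENV,
    code="./src",
)

generate_step = command(
    name="generate",
    command="python generate.py --validation_status ${{inputs.validation_status}}",
    inputs={"validation_status": Input(type="uri_file")},
    environment=ENV,
    code="./src",
)

@pipeline()
def ordered_pipeline():
    v = validate_step()
    # Binding the upstream output to this input creates the ordering edge.
    generate_step(validation_status=v.outputs.status)

pipeline_job = ordered_pipeline()  # submit with MLClient.jobs.create_or_update(...)
```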
@@ -0,0 +1,46 @@
## Pipeline Validation Component

### Name

oss_distillation_validate_pipeline

### Version

0.0.1

### Type

command

### Description

Component to validate all inputs to the distillation pipeline.

## Inputs

| Name | Description | Type | Optional |
| --- | --- | --- | --- |
| train_file_path | Path to the registered training data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | True |
| validation_file_path | Path to the registered validation data set in `jsonl`, `json`, `csv`, `tsv` or `parquet` format. | uri_file | True |
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True |
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True |
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True |
| teacher_model_max_new_tokens | Teacher model `max_new_tokens` inference parameter. | integer | True |
| teacher_model_temperature | Teacher model `temperature` inference parameter. | number | True |
| teacher_model_top_p | Teacher model `top_p` inference parameter. | number | True |
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True |
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True |
| teacher_model_stop | Teacher model stop inference parameter. | string | True |
| request_batch_size | Number of data records sent to the teacher model endpoint in one request. | integer | True |
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True |
| enable_chain_of_thought | Enable chain-of-thought prompting for data generation. | string | True |
| num_train_epochs | Number of training epochs. | integer | True |
| data_generation_task_type | Data generation task type; supported values are NLI, CONVERSATION and NLU_QA. | string | False |
| per_device_train_batch_size | Train batch size. | integer | True |
| learning_rate | Start learning rate. | number | True |

## Outputs

| Name | Description | Type |
| -------------------- | -------------------------------------------------------- | ------------ |
| validation_info | Validation status file. | uri_file |
@@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation"]
147 changes: 147 additions & 0 deletions assets/training/distillation/components/pipeline_validation/spec.yaml
@@ -0,0 +1,147 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_validate_pipeline
version: 0.0.1
type: command

is_deterministic: true

display_name: OSS Distillation Validate Pipeline
description: Component to validate inputs to the distillation pipeline

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66

code: ../../src

inputs:
# Inputs
train_file_path:
type: uri_file
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount

validation_file_path:
type: uri_file
optional: true
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount

teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name

teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint URL

teacher_model_endpoint_key:
type: string
optional: true
description: Teacher model endpoint key

teacher_model_max_new_tokens:
type: integer
default: 128
description: Teacher model max_new_tokens inference parameter

teacher_model_temperature:
type: number
default: 0.2
description: Teacher model temperature inference parameter

teacher_model_top_p:
type: number
default: 0.1
description: Teacher model top_p inference parameter

teacher_model_frequency_penalty:
type: number
default: 0.0
description: Teacher model frequency penalty inference parameter

teacher_model_presence_penalty:
type: number
default: 0.0
description: Teacher model presence penalty inference parameter

teacher_model_stop:
type: string
optional: true
description: Teacher model stop inference parameter

request_batch_size:
type: integer
default: 10
description: No of data records to hit teacher model endpoint in one go

min_endpoint_success_ratio:
type: number
default: 0.7
description: >
The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
enable_chain_of_thought:
type: string
default: "true"
description: Enable Chain of thought for data generation

data_generation_task_type:
type: string
enum:
- NLI
- CONVERSATION
- NLU_QA
description: >
Data generation task type. Supported values are:
1. NLI: Generate Natural Language Inference data
2. CONVERSATION: Generate conversational data (multi/single turn)
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data
num_train_epochs:
type: integer
default: 1
optional: true
description: training epochs

per_device_train_batch_size:
type: integer
default: 1
optional: true
description: Train batch size

learning_rate:
type: number
default: 3e-04
optional: true
description: Start learning rate.

outputs:
validation_info:
type: uri_file
description: Validation status.
mode: rw_mount

command: >-
python validate_pipeline.py
--train_file_path ${{inputs.train_file_path}}
$[[--validation_file_path ${{inputs.validation_file_path}}]]
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
--teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
--teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
$[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
--request_batch_size ${{inputs.request_batch_size}}
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}
--data_generation_task_type ${{inputs.data_generation_task_type}}
$[[--num_train_epochs ${{inputs.num_train_epochs}}]]
$[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]]
$[[--learning_rate ${{inputs.learning_rate}}]]
--validation_info ${{outputs.validation_info}}
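The referenced `validate_pipeline.py` lives under `../../src` and is not shown in this diff. As a rough, hypothetical skeleton of the contract the command above implies — parse the pipeline inputs, run checks, and write a status file to `--validation_info` — it could look like the following; the argument subset, checks, and JSON shape are assumptions, not the actual implementation.

```python
# Hypothetical skeleton for a validate_pipeline.py-style entry point; the real script in
# ../../src is not part of this diff, and the JSON payload shape here is an assumption.
import argparse
import json


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Validate distillation pipeline inputs.")
    parser.add_argument("--train_file_path", required=True)
    parser.add_argument("--validation_file_path", default=None)
    parser.add_argument("--teacher_model_endpoint_url", default=None)
    parser.add_argument("--teacher_model_max_new_tokens", type=int, default=128)
    parser.add_argument("--teacher_model_temperature", type=float, default=0.2)
    parser.add_argument("--min_endpoint_success_ratio", type=float, default=0.7)
    parser.add_argument("--data_generation_task_type",
                        choices=["NLI", "CONVERSATION", "NLU_QA"], required=True)
    parser.add_argument("--validation_info", required=True,
                        help="Path to write the validation status file.")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    errors = []
    # Example checks only; the real component's validation rules are not shown here.
    if not 0.0 <= args.teacher_model_temperature <= 1.0:
        errors.append("teacher_model_temperature must be between 0 and 1")
    if not 0.0 <= args.min_endpoint_success_ratio <= 1.0:
        errors.append("min_endpoint_success_ratio must be between 0 and 1")

    status = {"validation_status": "failed" if errors else "succeeded", "errors": errors}
    with open(args.validation_info, "w") as f:
        json.dump(status, f)

    if errors:
        raise ValueError("; ".join(errors))


if __name__ == "__main__":
    main()
```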