-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: introduce validation component for distillation pipeline (#3284)
- Loading branch information
1 parent
326b05d
commit 7c566f9
Showing
12 changed files
with
986 additions
and
32 deletions.
There are no files selected for viewing
45 changes: 45 additions & 0 deletions
45
assets/training/distillation/components/data_generation/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
## Data Generation Component | ||
|
||
### Name | ||
|
||
oss_distillation_generate_data | ||
|
||
### Version | ||
|
||
0.0.5 | ||
|
||
### Type | ||
|
||
command | ||
|
||
### Description | ||
|
||
Component to generate data from teacher model enpoint | ||
|
||
## Inputs | ||
|
||
| Name | Description | Type | Optional | | ||
|--------------------| ----------------------------------------------------------------------------------- | ------- | ------- | | ||
| train_file_path | Path to the registered training data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | True | | ||
| validation_file_path | Path to the registered training data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | True | ||
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True | ||
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True | ||
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True | ||
| teacher_model_max_new_tokens | Teacher model max_new_tokens inference parameter. | integer | True | ||
| teacher_model_temperature | Teacher model temperature inference parameter. | number | True | ||
| teacher_model_top_p | Teacher model top_p inference parameter. | number | True | | | ||
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True | | ||
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True | ||
| teacher_model_stop | Teacher model stop inference parameter. | string | True | ||
| request_batch_size | No of data records to hit teacher model endpoint in one go. | integer | True | ||
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True | ||
| enable_chain_of_thought | Enable Chain of thought for data generation. | string | True | ||
| data_generation_task_type | Data generation task types, supported values - NLI, CONVERSATION, NLU_QA. | string | False | ||
| validation_output | Validation status from validation component. | uri_file | True | ||
|
||
## Outputs | ||
|
||
| Name | Description | Type | | ||
| -------------------- | -------------------------------------------------------- | ------------ | | ||
| generated_train_file_path | Generated training data. | uri_file | | ||
| generated_validation_file_path | Generated validation data. | uri_file | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
assets/training/distillation/components/pipeline_validation/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
## Pipeline Validation Component | ||
|
||
### Name | ||
|
||
oss_distillation_validate_pipeline | ||
|
||
### Version | ||
|
||
0.0.1 | ||
|
||
### Type | ||
|
||
command | ||
|
||
### Description | ||
|
||
Component to validate all inputs to the distillation pipeline. | ||
|
||
## Inputs | ||
|
||
| Name | Description | Type | Optional | | ||
|--------------------| ----------------------------------------------------------------------------------- | ------- | ------- | | ||
| train_file_path | Path to the registered training data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | True | | ||
| validation_file_path | Path to the registered training data set in `jsonl, json, csv, tsv and parquet` format. | uri_file | True | ||
| teacher_model_endpoint_name | Teacher model endpoint name. | string | True | ||
| teacher_model_endpoint_url | Teacher model endpoint URL. | string | True | ||
| teacher_model_endpoint_key | Teacher model endpoint key. | string | True | ||
| teacher_model_max_new_tokens | Teacher model max_new_tokens inference parameter. | integer | True | ||
| teacher_model_temperature | Teacher model temperature inference parameter. | number | True | ||
| teacher_model_top_p | Teacher model top_p inference parameter. | number | True | | | ||
| teacher_model_frequency_penalty | Teacher model frequency penalty inference parameter. | number | True | | ||
| teacher_model_presence_penalty | Teacher model presence penalty inference parameter. | number | True | ||
| teacher_model_stop | Teacher model stop inference parameter. | string | True | ||
| request_batch_size | No of data records to hit teacher model endpoint in one go. | integer | True | ||
| min_endpoint_success_ratio | The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | number | True | ||
| enable_chain_of_thought | Enable Chain of thought for data generation. | string | True | ||
| num_train_epochs | Number of training epochs. | string | True | ||
| data_generation_task_type | Data generation task types, supported values - NLI, CONVERSATION, NLU_QA. | string | False | ||
| per_device_train_batch_size | Train batch size. | integer | True | ||
| learning_rate | Start learning rate. | number | True | ||
|
||
## Outputs | ||
|
||
| Name | Description | Type | | ||
| -------------------- | -------------------------------------------------------- | ------------ | | ||
| validation_info | Validation status file. | uri_file | |
3 changes: 3 additions & 0 deletions
3
assets/training/distillation/components/pipeline_validation/asset.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
type: component | ||
spec: spec.yaml | ||
categories: ["Foundational Models", "Finetune", "Distillation"] |
147 changes: 147 additions & 0 deletions
147
assets/training/distillation/components/pipeline_validation/spec.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json | ||
name: oss_distillation_validate_pipeline | ||
version: 0.0.1 | ||
type: command | ||
|
||
is_deterministic: true | ||
|
||
display_name: OSS Distillation Validate Pipeline | ||
description: Component to validate inputs to the distillation pipeline | ||
|
||
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/66 | ||
|
||
code: ../../src | ||
|
||
inputs: | ||
# Inputs | ||
train_file_path: | ||
type: uri_file | ||
description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`. | ||
mode: rw_mount | ||
|
||
validation_file_path: | ||
type: uri_file | ||
optional: true | ||
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`. | ||
mode: rw_mount | ||
|
||
teacher_model_endpoint_name: | ||
type: string | ||
optional: true | ||
description: Teacher model endpoint name | ||
|
||
teacher_model_endpoint_url: | ||
type: string | ||
optional: true | ||
description: Teacher model endpoint URL | ||
|
||
teacher_model_endpoint_key: | ||
type: string | ||
optional: true | ||
description: Teacher model endpoint key | ||
|
||
teacher_model_max_new_tokens: | ||
type: integer | ||
default: 128 | ||
description: Teacher model max_new_tokens inference parameter | ||
|
||
teacher_model_temperature: | ||
type: number | ||
default: 0.2 | ||
description: Teacher model temperature inference parameter | ||
|
||
teacher_model_top_p: | ||
type: number | ||
default: 0.1 | ||
description: Teacher model top_p inference parameter | ||
|
||
teacher_model_frequency_penalty: | ||
type: number | ||
default: 0.0 | ||
description: Teacher model frequency penalty inference parameter | ||
|
||
teacher_model_presence_penalty: | ||
type: number | ||
default: 0.0 | ||
description: Teacher model presence penalty inference parameter | ||
|
||
teacher_model_stop: | ||
type: string | ||
optional: true | ||
description: Teacher model stop inference parameter | ||
|
||
request_batch_size: | ||
type: integer | ||
default: 10 | ||
description: No of data records to hit teacher model endpoint in one go | ||
|
||
min_endpoint_success_ratio: | ||
type: number | ||
default: 0.7 | ||
description: > | ||
The minimum value of (successful_requests / total_requests) required for classifying inference as successful. | ||
If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed. | ||
By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.) | ||
enable_chain_of_thought: | ||
type: string | ||
default: "true" | ||
description: Enable Chain of thought for data generation | ||
|
||
data_generation_task_type: | ||
type: string | ||
enum: | ||
- NLI | ||
- CONVERSATION | ||
- NLU_QA | ||
description: > | ||
Data generation task type. Supported values are: | ||
1. NLI: Generate Natural Language Inference data | ||
2. CONVERSATION: Generate conversational data (multi/single turn) | ||
3. NLU_QA: Generate Natural Language Understanding data for Question Answering data | ||
num_train_epochs: | ||
type: integer | ||
default: 1 | ||
optional: true | ||
description: training epochs | ||
|
||
per_device_train_batch_size: | ||
type: integer | ||
default: 1 | ||
optional: true | ||
description: Train batch size | ||
|
||
learning_rate: | ||
type: number | ||
default: 3e-04 | ||
optional: true | ||
description: Start learning rate. | ||
|
||
outputs: | ||
validation_info: | ||
type: uri_file | ||
description: Validation status. | ||
mode: rw_mount | ||
|
||
command: >- | ||
python validate_pipeline.py | ||
--train_file_path ${{inputs.train_file_path}} | ||
$[[--validation_file_path ${{inputs.validation_file_path}}]] | ||
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]] | ||
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]] | ||
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]] | ||
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}} | ||
--teacher_model_temperature ${{inputs.teacher_model_temperature}} | ||
--teacher_model_top_p ${{inputs.teacher_model_top_p}} | ||
--teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}} | ||
--teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}} | ||
$[[--teacher_model_stop ${{inputs.teacher_model_stop}}]] | ||
--request_batch_size ${{inputs.request_batch_size}} | ||
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}} | ||
--enable_chain_of_thought ${{inputs.enable_chain_of_thought}} | ||
--data_generation_task_type ${{inputs.data_generation_task_type}} | ||
$[[--num_train_epochs ${{inputs.num_train_epochs}}]] | ||
$[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]] | ||
$[[--learning_rate ${{inputs.learning_rate}}]] | ||
--validation_info ${{outputs.validation_info}} |
File renamed without changes.
Oops, something went wrong.