Commit: Merge branch 'aiw' of https://github.com/rishabhranawat/reasoning-gym into aiw
Showing 17 changed files with 1,358 additions and 81 deletions.
@@ -0,0 +1,3 @@
outputs/
wandb/
verl_output.log
@@ -0,0 +1,19 @@
### env setup

```
conda create --name verl python=3.12 -y
conda activate verl
pip install flash-attn --no-build-isolation
pip install vllm==0.7.0 ray wandb
```

### clone and install veRL

tested with verl HEAD a65c9157bc0b85b64cd753de19f94e80a11bd871

```
git clone https://github.com/volcengine/verl.git
cd verl
pip install -e .
```
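
A quick import check (a minimal sketch, not part of the original setup notes; it only assumes the packages installed above) can confirm the environment before launching training:

```
# sanity check: the editable verl install and its key dependencies import cleanly
python -c "import verl, vllm, ray, flash_attn; print('environment OK')"
```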
@@ -0,0 +1,167 @@
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: null
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: False # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length} # not used for opensource
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    # number of responses (i.e. num sample times)
    n: 1 # > 1 for grpo

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: null
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
      fsdp_size: -1
  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
  micro_batch_size_per_gpu: null # set a number
  max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  total_epochs: 30
  total_training_steps: null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: [ 'console', 'wandb' ]
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
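
A config in this shape is consumed by verl's Hydra-based entry point, so individual fields can be overridden on the command line. Below is a minimal launch sketch, assuming this file is the active ppo_trainer config for `verl.trainer.main_ppo`; the model path and the particular overrides shown are illustrative, not taken from the repository:

```
# sketch: launch PPO training, overriding a few fields of the config above
python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.model.path=$HOME/models/deepseek-llm-7b-chat \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.experiment_name=gsm8k
```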
@@ -0,0 +1,9 @@
#!/bin/bash

export N_GPUS=4
export BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ROLLOUT_TP_SIZE=2
export EXPERIMENT_NAME=chain_sum_llama
export VLLM_ATTENTION_BACKEND=XFORMERS

bash ./train.sh
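
The train.sh this wrapper calls is not shown in this diff. As a hypothetical sketch only, such a wrapper typically forwards the exported variables to verl's entry point as Hydra overrides; the dataset paths and the redirection into verl_output.log below are assumptions (the log file name is suggested by the .gitignore entry above), not code from the repository:

```
#!/bin/bash
# hypothetical train.sh: forward the exported variables as Hydra overrides
set -euo pipefail

python3 -m verl.trainer.main_ppo \
    data.train_files="$HOME/data/chain_sum/train.parquet" \
    data.val_files="$HOME/data/chain_sum/test.parquet" \
    actor_rollout_ref.model.path="$BASE_MODEL" \
    actor_rollout_ref.rollout.tensor_model_parallel_size="$ROLLOUT_TP_SIZE" \
    trainer.n_gpus_per_node="$N_GPUS" \
    trainer.experiment_name="$EXPERIMENT_NAME" \
    2>&1 | tee verl_output.log
```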