Skip to content

Commit

Permalink
add deps for veRL experiment in README
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Feb 1, 2025
1 parent e671b97 commit 3f24df3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 18 deletions.
19 changes: 19 additions & 0 deletions examples/veRL/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
### env setup

```
conda create --name verl python=3.12 -y
conda activate verl
pip install flash-attn --no-build-isolation
pip install vllm==0.7.0 ray wandb
```

### clone and install veRL

tested with verl HEAD a65c9157bc0b85b64cd753de19f94e80a11bd871

```
git clone https://github.com/volcengine/verl.git
cd verl
pip install -e .
```
34 changes: 16 additions & 18 deletions examples/veRL/main_ppo_custom_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,22 @@
"""

from typing import Optional
from omegaconf import OmegaConf, open_dict
import reasoning_gym
from reasoning_gym.utils import extract_answer

import reasoning_gym.utils
from verl import DataProto
import hydra
import ray
import torch
from torch.utils.data import Dataset, DataLoader
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer

import ray
import hydra


from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.model import compute_position_id_with_mask
from verl.utils.dataset.rl_dataset import collate_fn
import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask

import reasoning_gym
import reasoning_gym.utils
from reasoning_gym.utils import extract_answer


class RewardManager:
Expand Down Expand Up @@ -262,12 +260,12 @@ def _create_dataloader(self):

@ray.remote
def main_task(config, compute_score=None):
from verl.utils.fs import copy_local_path_from_hdfs
from transformers import AutoTokenizer

# print initial config
from pprint import pprint

from omegaconf import OmegaConf
from transformers import AutoTokenizer
from verl.utils.fs import copy_local_path_from_hdfs

pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
Expand All @@ -283,15 +281,15 @@ def main_task(config, compute_score=None):
# define worker classes
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

ray_worker_group_cls = NVMegatronRayWorkerGroup

Expand Down

0 comments on commit 3f24df3

Please sign in to comment.