refactor: remove orchestrator abstraction from API (#289)
* refactor: remove orchestrator abstraction from API

* Remove orchestrator in GPT-J config

* Add `reward_fn` arg to NeMo constructor to match base trainer API

* Initial support for `make_experience` in NeMo ILQL

* Run pre-commit

* Remove unused sampling util
jon-tow authored Feb 10, 2023
1 parent eb62d08 commit 81e935a
Showing 27 changed files with 466 additions and 607 deletions.
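Migration note (illustrative, not part of the commit): every config hunk below drops the `train.orchestrator` key. A minimal sketch of the same change, applied to a config that has already been parsed into a plain dict, might look like this. The helper name `strip_orchestrator` is hypothetical and is not part of trlX.

```python
from copy import deepcopy
from typing import Any, Dict


def strip_orchestrator(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `config` without the removed `train.orchestrator` key.

    Hypothetical helper for illustration only; it is not part of trlX.
    """
    migrated = deepcopy(config)
    # pop() with a default is a no-op for configs that were already migrated.
    migrated.get("train", {}).pop("orchestrator", None)
    return migrated


# Values mirror the `train:` section of configs/ppo_config.yml before this commit.
old_config = {
    "train": {
        "eval_interval": 100,
        "pipeline": "PromptPipeline",
        "orchestrator": "PPOOrchestrator",
        "trainer": "AcceleratePPOTrainer",
    }
}
new_config = strip_orchestrator(old_config)
```

The helper copies its input, so the original dict is left untouched, and running it twice gives the same result as running it once.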
1 change: 0 additions & 1 deletion configs/ilql_config.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 100

pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
trainer: "AccelerateILQLTrainer"
seed: 1000

1 change: 0 additions & 1 deletion configs/nemo_ilql_config.yml
@@ -7,7 +7,6 @@ train:
eval_interval: 20

pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
trainer: "NeMoILQLTrainer"
trainer_kwargs:
pretrained_model: "/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/"
1 change: 0 additions & 1 deletion configs/ppo_config.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 100

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
1 change: 0 additions & 1 deletion configs/ppo_gptj.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 16

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
1 change: 0 additions & 1 deletion configs/sft_config.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 100

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AccelerateSFTTrainer"

model:
3 changes: 1 addition & 2 deletions configs/test_config.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 128 # eval interval

pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load
trainer: "AcceleratePPOTrainer" # Name of model trainer to load

model:
@@ -36,7 +35,7 @@ scheduler:
method:
name: "ppoconfig" # Name of RL method config
num_rollouts: 128 # Number of rollouts to collect per epoch
chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
chunk_size: 128 # Number of rollouts to collect in one loop
ppo_epochs: 4 # Number of ppo epochs
init_kl_coef: 0.2 # init kl coefficient
target: 6 # target kl coefficient, set None for fixed kl coef
1 change: 0 additions & 1 deletion docs/source/index.rst
@@ -14,7 +14,6 @@ currently supports training using PPO or ILQL for models up to 20B using Acceler

data
models
orchestrator
configs
pipeline
examples
23 changes: 0 additions & 23 deletions docs/source/orchestrator.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/source/pipeline.rst
@@ -4,7 +4,7 @@ Pipelines
************************

Pipelines are how you read from a dataset with trlX. Rollout stores are how models store experiences created
for them by the orchestrator. It is these experiences in their rollout store that they are trained on.
for them. It is these experiences in their rollout store that they are trained on.

**General**

@@ -8,7 +8,6 @@ train:
eval_interval: 16

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
1 change: 0 additions & 1 deletion examples/randomwalks/configs/ilql_randomwalks.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 16

pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
trainer: "AccelerateILQLTrainer"

seed: 1000
1 change: 0 additions & 1 deletion examples/randomwalks/configs/ppo_randomwalks.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 20

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
@@ -9,7 +9,6 @@ train:
save_best: False

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
1 change: 0 additions & 1 deletion examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml
@@ -8,7 +8,6 @@ train:
eval_interval: 200

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
17 changes: 2 additions & 15 deletions trlx/data/__init__.py
@@ -1,31 +1,18 @@
from dataclasses import dataclass
from typing import Any, Iterable
from typing import Iterable

from torchtyping import TensorType

from . import configs


@dataclass
class GeneralElement:
"""
General element outputted by data pipeline being read by orchestrator.
General element outputted by a data pipeline
"""

pass


@dataclass
class SimElement:
"""
Batch element for Gyarados or Gyarados-like similarity scoring model
"""

content: Any = None
preference: Any = None
score: float = None


@dataclass
class RLElement:
"""
4 changes: 0 additions & 4 deletions trlx/data/configs.py
@@ -152,9 +152,6 @@ class TrainConfig:
:param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
:type pipeline: str
:param orchestrator: Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator
:type orchestrator: str
:param trainer: Trainer to use for training. One of the registered trainers present in trlx.trainer
:type trainer: str
@@ -193,7 +190,6 @@ class TrainConfig:
eval_interval: int

pipeline: str # One of the pipelines in framework.pipeline
orchestrator: str # One of the orchestrators
trainer: str # One of the trainers
trainer_kwargs: Dict[str, Any] = field(default_factory=dict) # Extra keyword arguments for the trainer
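
Illustrative note (not part of the commit): with `orchestrator` gone from the dataclass fields, a stale config that still passes the key would presumably be rejected if the config object is built by keyword expansion. The sketch below uses a hypothetical, trimmed `MiniTrainConfig` holding only the fields visible in this hunk. Whether trlX's own loader constructs `TrainConfig` this way is an assumption.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class MiniTrainConfig:
    """Hypothetical, trimmed stand-in for the fields visible in this hunk."""

    eval_interval: int
    pipeline: str
    trainer: str
    trainer_kwargs: Dict[str, Any] = field(default_factory=dict)


def try_build(**kwargs: Any) -> str:
    """Report whether the keyword arguments are accepted by the dataclass."""
    try:
        MiniTrainConfig(**kwargs)
        return "ok"
    except TypeError:
        # A dataclass-generated __init__ raises TypeError on unknown keywords.
        return "rejected"


migrated = try_build(
    eval_interval=100,
    pipeline="PromptPipeline",
    trainer="AcceleratePPOTrainer",
)
stale = try_build(
    eval_interval=100,
    pipeline="PromptPipeline",
    orchestrator="PPOOrchestrator",  # key removed by this commit
    trainer="AcceleratePPOTrainer",
)
```

Under these assumptions the migrated arguments are accepted and the stale ones are rejected, which is why the YAML files in this commit drop the key alongside the dataclass field.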

46 changes: 0 additions & 46 deletions trlx/orchestrator/__init__.py

This file was deleted.

132 changes: 0 additions & 132 deletions trlx/orchestrator/offline_orchestrator.py

This file was deleted.
