RLHF torchtune #262

Draft · wants to merge 6 commits into base: master
Changes from 1 commit
update code to latest
pierre.delaunay committed Aug 30, 2024
commit c6d7624f4d2aca4442bd56a30225d9861b1acbd8
3 changes: 2 additions & 1 deletion .gitignore
@@ -56,4 +56,5 @@ benchmarks/voir
 benchmarks/*/base/
 benchmarks/lightning/lightning_logs/
 
-benchmarks/*/src/
+benchmarks/*/src/
+benchmarks/llm/tune
12 changes: 12 additions & 0 deletions benchmarks/llm/benchfile.py
@@ -6,6 +6,10 @@
 from milabench.commands import SimpleCommand
 
 
+URL = "https://github.com/pytorch/torchtune.git"
+BRANCH = "a83eeff0079a73ee04a11e8fc2573ed8f671b231"
+
+
 class Torchtune(TorchrunAllGPU):
     @property
     def executable(self):
@@ -40,6 +44,14 @@ class Llm(Package):
     async def install(self):
         await super().install()  # super() call installs the requirements
 
+        # Clone the right version of torchtune
+        tune = self.dirs.code / "tune"
+        if not tune.exists():
+            tune.clone_subtree(URL, BRANCH)
+
+        # make an editable install
+        await self.pip_install("-e", str(tune))
+
     def build_run_plan(self):
         exec = SimpleCommand(self)
         return TorchtuneAllNodes(exec).use_stdout()
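
The install hook above pins torchtune to an exact commit before making an editable install, which keeps the benchmark reproducible across runs. `clone_subtree` is a milabench helper whose implementation is not part of this diff; a minimal sketch of what a pinned-commit clone could look like (hypothetical `clone_pinned` helper, standard git plumbing only):

# Hypothetical sketch of a pinned-commit clone, assuming clone_subtree
# behaves roughly like this; milabench's actual helper is not in this diff.
import subprocess
from pathlib import Path

def clone_pinned(dest: Path, url: str, ref: str) -> None:
    """Clone `url` into `dest` and check out the pinned commit `ref`."""
    dest.mkdir(parents=True, exist_ok=True)
    subprocess.run(["git", "init", "-q"], cwd=dest, check=True)
    subprocess.run(["git", "remote", "add", "origin", url], cwd=dest, check=True)
    # Fetch only the pinned commit (GitHub allows fetching arbitrary SHAs),
    # so the clone stays shallow and exactly reproducible.
    subprocess.run(["git", "fetch", "--depth", "1", "origin", ref], cwd=dest, check=True)
    subprocess.run(["git", "checkout", "-q", "FETCH_HEAD"], cwd=dest, check=True)

Pinning to a SHA rather than a branch means the benchmark's torchtune version only moves when a commit like this one updates BRANCH.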
2 changes: 1 addition & 1 deletion benchmarks/llm/configs/llama3_70B_full.yaml
@@ -32,7 +32,7 @@ model:
   _component_: torchtune.models.llama3_1.llama3_1_70b
 
 checkpointer:
-  _component_: torchtune.utils.FullModelHFCheckpointer
+  _component_: torchtune.training.checkpointing.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files: [
     model-00001-of-00030.safetensors,
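
This rename, repeated in the configs below, tracks torchtune's module reorganization, which moved the checkpointers out of `torchtune.utils`. The `_component_` key is a dotted import path that torchtune resolves at runtime, so a stale path fails at instantiation time rather than at YAML load time. A minimal sketch of how such an entry is resolved, assuming torchtune's `config.instantiate` helper:

# Sketch: resolving the `_component_` entry of a config like the one above.
# `config.instantiate` imports the dotted path and calls it with the
# remaining keys, so the old `torchtune.utils.FullModelHFCheckpointer`
# path would raise an import error after the reorganization.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("benchmarks/llm/configs/llama3_70B_full.yaml")
checkpointer = config.instantiate(cfg.checkpointer)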
2 changes: 1 addition & 1 deletion benchmarks/llm/configs/llama3_70B_lora.yaml
@@ -24,7 +24,7 @@ tokenizer:
 
 safetensors: true
 checkpointer:
-  _component_: torchtune.utils.FullModelHFCheckpointer
+  _component_: torchtune.training.checkpointing.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files: [
     model-00001-of-00030.safetensors,
2 changes: 1 addition & 1 deletion benchmarks/llm/configs/llama3_8B_lora.yaml
@@ -32,7 +32,7 @@ model:
   lora_alpha: 16
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.checkpointing.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
2 changes: 1 addition & 1 deletion benchmarks/llm/configs/llama3_8B_lora_single_device.yaml
@@ -31,7 +31,7 @@ tokenizer:
   path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.checkpointing.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
2 changes: 1 addition & 1 deletion benchmarks/llm/configs/llama3_8B_qat_full.yaml
@@ -29,7 +29,7 @@ model:
   _component_: torchtune.models.llama3_1.llama3_1_8b
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.checkpointing.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
2 changes: 1 addition & 1 deletion benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml
@@ -30,7 +30,7 @@ tokenizer:
   path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.checkpointing.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
4 changes: 2 additions & 2 deletions benchmarks/llm/prepare.py
@@ -5,10 +5,10 @@
 
 from omegaconf import OmegaConf
 from argklass import ArgumentParser
-from torchtune._cli.tune import TuneCLIParser
-
 from benchmate.ux import long_action
 
+from torchtune._cli.tune import TuneCLIParser
+
 
 @dataclass
 class Arguments:
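
Reordering the `TuneCLIParser` import is cosmetic, but it highlights that prepare.py drives the `tune` CLI programmatically rather than shelling out, and that `torchtune._cli` is a private module which may change between the pinned commit and newer releases. A hedged sketch of that pattern (the `download` arguments are an assumed example, not taken from this PR):

# Illustrative only: TuneCLIParser is internal torchtune API, so this
# mirrors how the `tune` entry point drives it rather than a stable contract.
from torchtune._cli.tune import TuneCLIParser

parser = TuneCLIParser()
# Roughly equivalent to running `tune download meta-llama/Meta-Llama-3-8B-Instruct`
# from the shell; the repo id here is a placeholder.
args = parser.parse_args(["download", "meta-llama/Meta-Llama-3-8B-Instruct"])
parser.run(args)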