[CI] Fix CI with new transformers release (#946)
* fix CI with transformers release

* final fix
younesbelkada authored Nov 3, 2023
1 parent cc1de98 commit 951ca18
Showing 3 changed files with 99 additions and 51 deletions.
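In short: recent transformers releases (v4.35+) made safetensors the default serialization format, so save_pretrained now writes model.safetensors where older releases wrote pytorch_model.bin. The test assertions and the value-head checkpoint loader below are updated accordingly. A minimal sketch of the behaviour change (not part of the commit; any tiny model works, here one of the test models used below):

import os
import tempfile

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-GPT2LMHeadModel")

with tempfile.TemporaryDirectory() as tmp_dir:
    # save_pretrained defaults to safe_serialization=True in recent releases,
    # so the checkpoint is written as model.safetensors, not pytorch_model.bin.
    model.save_pretrained(tmp_dir)
    assert "model.safetensors" in os.listdir(tmp_dir)
    assert "pytorch_model.bin" not in os.listdir(tmp_dir)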
2 changes: 2 additions & 0 deletions tests/test_modeling_value_head.py
@@ -30,6 +30,8 @@
"trl-internal-testing/tiny-random-BloomForCausalLM",
"trl-internal-testing/tiny-random-GPT2LMHeadModel",
"trl-internal-testing/tiny-random-CodeGenForCausalLM-sharded",
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors-sharded",
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors"
# "trl-internal-testing/tiny-random-LlamaForCausalLM", uncomment on the next transformers release
]
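The two new GPTNeoX entries exercise safetensors checkpoints in both single-file and sharded form. A rough sketch of what "sharded" means here, mirroring the loader change in trl/models/modeling_base.py below (assumption: the test repo stays available on the Hub): the index JSON maps each parameter name to the shard file that stores it.

import json

from huggingface_hub import hf_hub_download

index_path = hf_hub_download(
    "trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors-sharded",
    "model.safetensors.index.json",
)
with open(index_path, "r") as f:
    index = json.load(f)

# Every shard file referenced by at least one parameter.
shard_files = sorted(set(index["weight_map"].values()))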

26 changes: 13 additions & 13 deletions tests/test_sft_trainer.py
@@ -150,7 +150,7 @@ def test_sft_trainer(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))

def test_sft_trainer_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -226,7 +226,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@@ -253,7 +253,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@@ -278,7 +278,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-1"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))

def test_sft_trainer_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -305,7 +305,7 @@ def test_sft_trainer_with_model(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@@ -331,7 +331,7 @@ def test_sft_trainer_with_model(self):

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))

# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -358,7 +358,7 @@ def test_sft_trainer_with_model(self):

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))

# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -383,7 +383,7 @@ def test_sft_trainer_with_model(self):

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@@ -407,7 +407,7 @@ def test_sft_trainer_with_model(self):

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-1"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))

def test_data_collator_completion_lm(self):
response_template = "### Response:\n"
@@ -529,7 +529,7 @@ def test_sft_trainer_infinite_with_model(self):
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

# make sure the trainer did 5 steps
self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-5"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5"))

def test_sft_trainer_infinite_with_model_epochs(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -557,7 +557,7 @@ def test_sft_trainer_infinite_with_model_epochs(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# make sure the trainer did 5 steps
self.assertTrue("pytorch_model.bin" in os.listdir(tmp_dir + "/checkpoint-4"))
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4"))

def test_sft_trainer_with_model_neftune(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -641,7 +641,7 @@ def test_peft_sft_trainer(self):

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("pytorch_model.bin" not in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

@require_peft
def test_peft_sft_trainer_neftune(self):
@@ -699,7 +699,7 @@ def test_peft_sft_trainer_neftune(self):

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("pytorch_model.bin" not in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

# Make sure forward pass works fine to check if embeddings forward is not broken.
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
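The PEFT assertions above rely on adapter-only checkpoints: save_pretrained on a peft model writes the adapter weights and config but never the full model weights, in either serialization format. A minimal sketch of that layout (assumptions: peft is installed, and the adapter weight filename varies across peft versions, so only the config and the absence of full weights are checked):

import os
import tempfile

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-GPT2LMHeadModel")
peft_model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"]))

with tempfile.TemporaryDirectory() as tmp_dir:
    peft_model.save_pretrained(tmp_dir)
    files = os.listdir(tmp_dir)
    assert "adapter_config.json" in files
    # Neither serialization format of the full model should be present.
    assert "model.safetensors" not in files
    assert "pytorch_model.bin" not in files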
122 changes: 84 additions & 38 deletions trl/models/modeling_base.py
@@ -21,6 +21,7 @@
 from accelerate import Accelerator
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, HFValidationError, LocalEntryNotFoundError
+from safetensors.torch import load_file as safe_load_file
 from transformers import PreTrainedModel
 
 from ..import_utils import is_peft_available, is_transformers_greater_than, is_xpu_available
@@ -248,65 +249,58 @@ class and the arguments that are specific to trl models. The kwargs
         # state_dict is removed from the model after loading it.
         is_resuming_training = True
         if isinstance(pretrained_model_name_or_path, str):
+            safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors")
             filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
-            sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
-            is_shared = False
 
-            if not os.path.exists(filename):
-                try:
-                    filename = hf_hub_download(
+            sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
+            safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
+            is_sharded = False
+            use_safe = os.path.exists(safe_filename)
+
+            if not (os.path.exists(filename) or os.path.exists(safe_filename)):
+                # Try with `pytorch_model.bin`
+                filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
+                    pretrained_model,
+                    pretrained_model_name_or_path,
+                    sharded_index_filename,
+                    token=token,
+                )
+                # Try with safetensors
+                if filename is None and files_to_download is None:
+                    safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
+                        pretrained_model,
                         pretrained_model_name_or_path,
-                        "pytorch_model.bin",
+                        safe_sharded_index_filename,
                         token=token,
+                        model_name="model.safetensors",
+                        model_index_name="model.safetensors.index.json",
                     )
-                # sharded
-                except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
-                    if os.path.exists(sharded_index_filename):
-                        index_file_name = sharded_index_filename
-                    else:
-                        try:
-                            index_file_name = hf_hub_download(
-                                pretrained_model_name_or_path,
-                                "pytorch_model.bin.index.json",
-                                token=token,
-                            )
-                        except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
-                            # not continue training, do not have v_head weight
-                            is_resuming_training = False
-                            logging.warning(
-                                f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', "
-                                f"and no v_head weight is found. This IS expected if you are not resuming PPO training."
-                            )
-                    # load json
-                    if is_resuming_training:
-                        with open(index_file_name, "r") as f:
-                            index = json.load(f)
-                        # check filename with `v_head` or any known extra module:
-                        files_to_download = set()
-                        for k, v in index["weight_map"].items():
-                            if any([module in k for module in cls.supported_modules]):
-                                files_to_download.add(v)
-                        is_shared = True
+                    use_safe = True
+                else:
+                    use_safe = False
+
+            loading_func = safe_load_file if use_safe else torch.load
+            load_kwargs = {} if use_safe else {"map_location": "cpu"}
 
             if is_resuming_training:
-                if is_shared:
+                if is_sharded:
                     # download each file and add it to the state_dict
                     state_dict = {}
+
                     for shard_file in files_to_download:
                         filename = hf_hub_download(
                             pretrained_model_name_or_path,
                             shard_file,
                             token=token,
                         )
-                        state_dict.update(torch.load(filename, map_location="cpu"))
+                        state_dict.update(loading_func(filename, **load_kwargs))
                 else:
-                    state_dict = torch.load(filename, map_location="cpu")
+                    state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs)
 
         else:
             state_dict = pretrained_model_name_or_path.state_dict()
 
         model.is_peft_model = is_peft_model
 
         model.current_device = current_device
 
         if is_resuming_training:
@@ -322,6 +316,58 @@ class and the arguments that are specific to trl models. The kwargs

         return model
 
+    @classmethod
+    def _get_checkpoint_from_hub(
+        cls,
+        pretrained_model,
+        pretrained_model_name_or_path,
+        index_filename,
+        token=None,
+        model_name="pytorch_model.bin",
+        model_index_name="pytorch_model.bin.index.json",
+    ):
+        files_to_download = None
+        filename = None
+        is_resuming_training = True
+        is_sharded = False
+
+        try:
+            filename = hf_hub_download(
+                pretrained_model_name_or_path,
+                model_name,
+                token=token,
+            )
+        # sharded
+        except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
+            if os.path.exists(index_filename):
+                index_file_name = index_filename
+            else:
+                try:
+                    index_file_name = hf_hub_download(
+                        pretrained_model_name_or_path,
+                        model_index_name,
+                        token=token,
+                    )
+                except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
+                    # not continue training, do not have v_head weight
+                    is_resuming_training = False
+                    logging.warning(
+                        f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', "
+                        f"and no v_head weight is found. This IS expected if you are not resuming PPO training."
+                    )
+            # load json
+            if is_resuming_training:
+                with open(index_file_name, "r") as f:
+                    index = json.load(f)
+                # check filename with `v_head` or any known extra module:
+                files_to_download = set()
+                for k, v in index["weight_map"].items():
+                    if any([module in k for module in cls.supported_modules]):
+                        files_to_download.add(v)
+                is_sharded = True
+
+        return filename, files_to_download, is_sharded, is_resuming_training
+
     @classmethod
     def _get_current_device(cls):
         r"""
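Taken together, the modeling_base.py changes make the value-head checkpoint loader format-agnostic: safetensors is preferred when present, with pytorch_model.bin as the fallback. A standalone sketch (not the trl API) of the selection logic the diff introduces, for a local checkpoint directory:

import os

import torch
from safetensors.torch import load_file as safe_load_file

def load_state_dict(checkpoint_dir):
    # Prefer safetensors when present; otherwise fall back to torch.load.
    safe_filename = os.path.join(checkpoint_dir, "model.safetensors")
    filename = os.path.join(checkpoint_dir, "pytorch_model.bin")
    use_safe = os.path.exists(safe_filename)

    loading_func = safe_load_file if use_safe else torch.load
    load_kwargs = {} if use_safe else {"map_location": "cpu"}

    return loading_func(safe_filename if use_safe else filename, **load_kwargs)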
