Skip to content

Commit

Permalink
[core] Fix failing tests on main (#1065)
Browse files Browse the repository at this point in the history
* fix tests on main

* fix last test
  • Loading branch information
younesbelkada authored Dec 6, 2023
1 parent 7f2401b commit ee44946
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
12 changes: 6 additions & 6 deletions tests/test_peft_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def test_save_pretrained_peft(self):
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)

# check that the files `adapter_model.bin` and `adapter_config.json` are in the directory
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
self.assertTrue(
os.path.isfile(f"{tmp_dir}/adapter_model.bin"),
msg=f"{tmp_dir}/adapter_model.bin does not exist",
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
)
self.assertTrue(
os.path.exists(f"{tmp_dir}/adapter_config.json"),
Expand Down Expand Up @@ -177,10 +177,10 @@ def test_load_pretrained_peft(self):
pretrained_model.save_pretrained(tmp_dir)
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)

# check that the files `adapter_model.bin` and `adapter_config.json` are in the directory
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
self.assertTrue(
os.path.isfile(f"{tmp_dir}/adapter_model.bin"),
msg=f"{tmp_dir}/adapter_model.bin does not exist",
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
)
self.assertTrue(
os.path.exists(f"{tmp_dir}/adapter_config.json"),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def test_peft_sft_trainer(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

Expand Down Expand Up @@ -693,7 +693,7 @@ def test_peft_sft_trainer_gc(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

Expand Down Expand Up @@ -751,7 +751,7 @@ def test_peft_sft_trainer_neftune(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

Expand Down
24 changes: 20 additions & 4 deletions trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ def add_and_load_reward_modeling_adapter(
pretrained_model.train()

filename = os.path.join(adapter_model_id, "adapter_model.bin")
safe_loading = False
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
Expand All @@ -452,13 +453,28 @@ def add_and_load_reward_modeling_adapter(
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
filename = os.path.join(adapter_model_id, "adapter_model.safetensors")
safe_loading = True
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.safetensors",
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
else:
local_filename = filename
else:
local_filename = filename

adapter_state_dict = torch.load(local_filename, map_location="cpu")
loading_func = safe_load_file if safe_loading else torch.load
load_kwargs = {} if safe_loading else {"map_location": "cpu"}

adapter_state_dict = loading_func(local_filename, **load_kwargs)

for score_name_candidate in cls.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
Expand Down

0 comments on commit ee44946

Please sign in to comment.