Skip to content

Commit

Permalink
fix tests on main
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Dec 6, 2023
1 parent 7f2401b commit 71a7337
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions tests/test_peft_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def test_save_pretrained_peft(self):
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)

# check that the files `adapter_model.bin` and `adapter_config.json` are in the directory
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
self.assertTrue(
os.path.isfile(f"{tmp_dir}/adapter_model.bin"),
msg=f"{tmp_dir}/adapter_model.bin does not exist",
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
)
self.assertTrue(
os.path.exists(f"{tmp_dir}/adapter_config.json"),
Expand Down Expand Up @@ -177,10 +177,10 @@ def test_load_pretrained_peft(self):
pretrained_model.save_pretrained(tmp_dir)
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)

# check that the files `adapter_model.bin` and `adapter_config.json` are in the directory
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
self.assertTrue(
os.path.isfile(f"{tmp_dir}/adapter_model.bin"),
msg=f"{tmp_dir}/adapter_model.bin does not exist",
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
)
self.assertTrue(
os.path.exists(f"{tmp_dir}/adapter_config.json"),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def test_peft_sft_trainer(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

Expand Down Expand Up @@ -693,7 +693,7 @@ def test_peft_sft_trainer_gc(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

Expand Down Expand Up @@ -751,7 +751,7 @@ def test_peft_sft_trainer_neftune(self):
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))

Expand Down

0 comments on commit 71a7337

Please sign in to comment.