diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index d6c06128ef..3b004659d2 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -137,10 +137,10 @@ def test_save_pretrained_peft(self): with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) - # check that the files `adapter_model.bin` and `adapter_config.json` are in the directory + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory self.assertTrue( - os.path.isfile(f"{tmp_dir}/adapter_model.bin"), - msg=f"{tmp_dir}/adapter_model.bin does not exist", + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + msg=f"{tmp_dir}/adapter_model.safetensors does not exist", ) self.assertTrue( os.path.exists(f"{tmp_dir}/adapter_config.json"), @@ -177,10 +177,10 @@ def test_load_pretrained_peft(self): pretrained_model.save_pretrained(tmp_dir) model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir) - # check that the files `adapter_model.bin` and `adapter_config.json` are in the directory + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory self.assertTrue( - os.path.isfile(f"{tmp_dir}/adapter_model.bin"), - msg=f"{tmp_dir}/adapter_model.bin does not exist", + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + msg=f"{tmp_dir}/adapter_model.safetensors does not exist", ) self.assertTrue( os.path.exists(f"{tmp_dir}/adapter_config.json"), diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 53d53a32e1..dec0e4036c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -651,7 +651,7 @@ def test_peft_sft_trainer(self): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) - self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) @@ -693,7 +693,7 @@ def test_peft_sft_trainer_gc(self): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) - self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")) @@ -751,7 +751,7 @@ def test_peft_sft_trainer_neftune(self): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) - self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir + "/checkpoint-2")) + self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")) self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))