Skip to content

Commit

Permalink
fix last test
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Dec 6, 2023
1 parent 71a7337 commit 89576e7
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ def add_and_load_reward_modeling_adapter(
pretrained_model.train()

filename = os.path.join(adapter_model_id, "adapter_model.bin")
safe_loading = False
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
Expand All @@ -452,13 +453,28 @@ def add_and_load_reward_modeling_adapter(
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
filename = os.path.join(adapter_model_id, "adapter_model.safetensors")
safe_loading = True
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.safetensors",
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
else:
local_filename = filename
else:
local_filename = filename

adapter_state_dict = torch.load(local_filename, map_location="cpu")
loading_func = safe_load_file if safe_loading else torch.load
load_kwargs = {} if safe_loading else {"map_location": "cpu"}

adapter_state_dict = loading_func(local_filename, **load_kwargs)

for score_name_candidate in cls.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
Expand Down

0 comments on commit 89576e7

Please sign in to comment.