Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
floatingbigcat committed Jan 14, 2025
1 parent a40c125 commit e3ebc30
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion scripts/eval_prompt_based.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

# Task Selection
TASK="ai2_arc" # Available options: mbpp, math, ai2_arc
TASK="mbpp2" # Available options: mbpp2, math, ai2_arc

# First Stage Inference: Classification Expert
# Set to 'None' if not using cls expert
Expand Down
2 changes: 1 addition & 1 deletion scripts/train_task_expert.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# !/bin/bash

# Task Selection
TASK="ai2_arc" # Available options: mbpp, math, ai2_arc
TASK="mbpp2" # Available options: mbpp2, gsm8k, ai2_arc, cls

# Training Setting
NUM_ITERS=200
Expand Down
2 changes: 1 addition & 1 deletion svd_reinforce_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def main(cfg):
gpu=gpu,
)

if resuming_from_ckpt:
if resuming_from_ckpt and os.path.exists(load_ckpt):
print(f"Starting from checkpoint at: {load_ckpt}")
# load the lora weight
if use_lora:
Expand Down
3 changes: 2 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ def eval_model_experts_prompt_based(

# Load and apply expert model parameters if available
if expert_model_path:
expert_params = torch.load(expert_model_path)
policy.load_state_dict(torch.load(expert_model_path))
expert_params = policy.get_learnable_params()
updated_params = forward(
policy=policy,
model=model,
Expand Down

0 comments on commit e3ebc30

Please sign in to comment.