Skip to content

Commit

Permalink
bug correction
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Nov 5, 2024
1 parent 14c7f52 commit 8919757
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion examples/advanced/llm_hf/sft_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ def main():
if train_mode.lower() == "sft":
job = FedJob(name="llm_hf_sft", min_clients=num_clients)
output_path = "sft"
mode = 0
elif train_mode.lower() == "peft":
job = FedJob(name="llm_hf_peft", min_clients=num_clients)
output_path = "peft"
mode = 1
else:
raise ValueError(f"Invalid train_mode: {train_mode}, only SFT and PEFT are supported.")

Expand Down Expand Up @@ -80,7 +82,7 @@ def main():
data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
runner = ScriptRunner(
script=train_script,
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --mode {train_mode} --clean_up {clean_up}",
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --mode {mode} --clean_up {clean_up}",
)
job.to(runner, site_name, tasks=["train"])

Expand Down
2 changes: 1 addition & 1 deletion examples/advanced/llm_hf/src/hf_sft_peft_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def evaluate(input_weights, mode):
out_param["model." + key] = out_param.pop(key).cpu()

# cast out_param to float32 preparing for communication
out_param = {k: v.to(torch.float32) for k, v in out_param.items()}
out_param = {k: v.to(torch.float16) for k, v in out_param.items()}

# construct trained FL model
output_model = flare.FLModel(
Expand Down

0 comments on commit 8919757

Please sign in to comment.