From 5ef5e906da993852a71d4991c84e87e33cf772c9 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 5 Nov 2024 16:20:40 -0500 Subject: [PATCH] bug correction --- examples/advanced/llm_hf/src/hf_sft_peft_fl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py index c79d36253c..1ae1dab54d 100755 --- a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py +++ b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py @@ -218,7 +218,7 @@ def evaluate(input_weights, mode): out_param["model." + key] = out_param.pop(key).cpu() # cast out_param to float32 preparing for communication - out_param = {k: v.to(torch.float16) for k, v in out_param.items()} + out_param = {k: v.to(torch.float32) for k, v in out_param.items()} # construct trained FL model output_model = flare.FLModel(