Hello, when using PRMTrainer I found that the training data type (torch.float32 vs. torch.bfloat16) has a significant impact on the model's accuracy. Specifically, I fine-tuned the Llama-3.2-1B-Instruct model on trl-lib/math_shepherd. The training config is:
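For reference, a minimal sketch of the setup (the model id, output dir, and hyperparameter values below are illustrative placeholders, not necessarily the exact ones from my runs):

```python
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer
from trl import PRMConfig, PRMTrainer

model_id = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForTokenClassification.from_pretrained(model_id, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-lib/math_shepherd")

training_args = PRMConfig(
    output_dir="llama3.2-1b-prm",
    per_device_train_batch_size=8,   # kept identical across the two runs
    learning_rate=1e-5,              # kept identical across the two runs
    bf16=True,                       # the only change: True for the bf16 run, False (fp32) otherwise
    logging_steps=10,
)

trainer = PRMTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()
```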
It is worth noting that, to keep the comparison fair, hyperparameters such as batch size and learning rate were kept identical between the two runs, and both started from the same model initialization.
The resulting loss curves are as follows:
When using torch.bfloat16, the model seems to converge to a worse local optimum. I then evaluated both models on the trl-lib/math_shepherd test split, and the results are as follows:
|      | Label Acc | Strict Label Acc |
| ---- | --------- | ---------------- |
| bf16 | 72.29     | 32.53            |
| fp32 | 86.31     | 67.53            |
Strict Label Acc means that every step in an entire chain of thought is classified correctly.
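For clarity, this is roughly how the two metrics are computed (a sketch with a hypothetical helper; Label Acc is taken as per-step accuracy):

```python
from typing import List, Tuple

def label_accuracies(preds: List[List[int]], labels: List[List[int]]) -> Tuple[float, float]:
    """preds / labels: one list of per-step 0/1 judgements per chain of thought."""
    total_steps = correct_steps = strict_correct = 0
    for p, l in zip(preds, labels):
        correct_steps += sum(int(pi == li) for pi, li in zip(p, l))
        total_steps += len(l)
        strict_correct += int(all(pi == li for pi, li in zip(p, l)))  # the whole chain must match
    label_acc = 100 * correct_steps / total_steps      # "Label Acc": fraction of steps classified correctly
    strict_acc = 100 * strict_correct / len(labels)    # "Strict Label Acc": fraction of chains fully correct
    return label_acc, strict_acc
```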
It can be seen that the data type significantly affects training performance. Is this phenomenon normal? Does it suggest that we should prefer fp32 whenever possible when training with PRMTrainer? 🤔
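For context, one reason the dtype could matter this much is that bfloat16 keeps only about 8 bits of mantissa (vs. 24 for fp32), so small differences in logits or gradients that fp32 resolves can simply be rounded away, e.g.:

```python
import torch

x = torch.tensor(1.001, dtype=torch.float32)
print(x.item())                       # ~1.0010000467
print(x.to(torch.bfloat16).item())    # 1.0 -- the 1e-3 perturbation is lost (bf16 eps ≈ 7.8e-3)
```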