PRM Performance on Different Data Type #2591

Open

TanZhendong opened this issue Jan 19, 2025 · 0 comments

Labels: 🏋 PRM (Related to PRM) · ❓ question (Seeking clarification or more information)

@TanZhendong
Hello, when using the PRM Trainer, I found that the data type (torch.float32 vs. torch.bfloat16) has a significant impact on the model's accuracy. Specifically, I fine-tuned the Llama-3.2-1B-Instruct model on trl-lib/math_shepherd. The training config is:

```python
training_args = PRMConfig(
    output_dir='xxx',
    logging_steps=5,
    num_train_epochs=1,
    learning_rate=2e-6,
    per_device_train_batch_size=8,
    lr_scheduler_type='cosine',
    warmup_ratio=0.1,
    save_strategy='epoch',
)
```
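
For reference, here is a minimal sketch of how the two runs are set up so that only the model's dtype differs (the model and dataset names are from above; the rest, including the use of AutoModelForTokenClassification with two labels, follows the usual TRL PRM setup and is an assumption about my script):

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer
from trl import PRMTrainer

model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Toggle between torch.float32 and torch.bfloat16 here;
# everything else is identical between the two runs.
model = AutoModelForTokenClassification.from_pretrained(
    model_name, num_labels=2, torch_dtype=torch.bfloat16
)

train_dataset = load_dataset("trl-lib/math_shepherd", split="train")

trainer = PRMTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```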

It is worth noting that, to ensure a fair comparison, hyperparameters such as the batch size and learning rate were kept identical across both runs, and both runs start from the same model initialization.
The resulting loss curves are as follows:
[Image: training loss curves, bf16 vs. fp32]
When using torch.bfloat16, the model seems to converge to a local optimum. I then evaluated both models on the test split of trl-lib/math_shepherd, with the following results:

|      | Label Acc | Strict Label Acc |
|------|-----------|------------------|
| bf16 | 72.29     | 32.53            |
| fp32 | 86.31     | 67.53            |

Strict Label Acc counts a sample as correct only if every step in the entire chain of thought is classified correctly.
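For clarity, a small sketch of how I compute the two metrics (the helper name and data layout are my own, not from TRL):

```python
import numpy as np

def label_accuracies(pred_labels, true_labels):
    """Each argument is a list of per-chain label lists (one label per step).

    Label Acc: fraction of individual step labels predicted correctly.
    Strict Label Acc: fraction of chains with ALL step labels correct.
    """
    step_hits, step_total, strict_hits = 0, 0, 0
    for pred, true in zip(pred_labels, true_labels):
        matches = np.asarray(pred) == np.asarray(true)
        step_hits += matches.sum()
        step_total += matches.size
        strict_hits += bool(matches.all())
    return step_hits / step_total, strict_hits / len(true_labels)
```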
It can be seen that the data type significantly affects the training performance of the model. Is this phenomenon normal? Does it suggest that we should prefer fp32 whenever possible when training with PRMTrainer? 🤔

@github-actions github-actions bot added the ❓ question Seeking clarification or more information label Jan 19, 2025
@qgallouedec qgallouedec added the 🏋 PRM Related to PRM label Jan 19, 2025