
Fix GRU to match pytorch (#2701). #2704

Open · wants to merge 2 commits into main
Conversation

nwhitehead

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

This addresses issue #2701

Changes

  • Update the GRU implementation of the "new" gate to match the PyTorch implementation. This can change numerical output in some cases.

  • Add GRU unit test with sequence length > 1.

  • Fix GRU input state dimensions and hidden state handling. This is an API change, since the dimensions of the optional hidden state input are corrected to the right sizes.

These changes do affect numerical results and change the API slightly. I think just updating to the correct API dimensions is the best approach, since the previous implementation was incorrect, not just different from PyTorch.

Testing

These changes were tested with a small unit test. For this test, the correct values were computed by hand using the GRU equations.

I also tested these changes against PyTorch. The PyTorch weights and biases were randomly initialized, saved, and then split into per-gate sections using a custom script. The input and output tensors were saved separately and loaded into a Rust test program. With this PR, the results from burn and torch agree to about 6 decimal digits. I tried input sizes of 1, 2, and 8; hidden sizes of 1, 2, and 8; and sequence lengths of 1, 2, and 3.
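For illustration, here is a minimal sketch of what the burn side of such a comparison could look like. It is not the actual test program from this PR: the module paths, the `forward` signature, and the expected values are assumptions and placeholders.

```rust
// Hypothetical sketch of the PyTorch comparison described above, assuming
// burn's `GruConfig` / `Gru` API on the `NdArray` backend. Exact paths and
// signatures may differ from the real code.
use burn::backend::NdArray;
use burn::nn::{Gru, GruConfig};
use burn::tensor::Tensor;

type B = NdArray<f32>;

fn main() {
    let device = Default::default();

    // Input of shape [batch=1, seq_len=3, d_input=2], also fed to torch.nn.GRU.
    let input = Tensor::<B, 3>::from_floats(
        [[[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]],
        &device,
    );

    // d_input=2, d_hidden=2, bias enabled. In the real comparison the gate
    // weights and biases were overwritten with the values exported from PyTorch.
    let gru: Gru<B> = GruConfig::new(2, 2, true).init(&device);

    // No initial hidden state (PyTorch also defaults to zeros).
    let output = gru.forward(input, None);

    // Reference output saved from PyTorch for the same weights and input
    // (placeholder zeros here).
    let expected = Tensor::<B, 3>::from_floats(
        [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]],
        &device,
    );

    // Agreement to roughly 6 decimal digits, as reported above.
    output
        .into_data()
        .assert_approx_eq(&expected.into_data(), 6);
}
```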


codecov bot commented Jan 15, 2025

Codecov Report

Attention: Patch coverage is 80.48780% with 8 lines in your changes missing coverage. Please review.

Project coverage is 83.19%. Comparing base (f630b3b) to head (3134f82).

Files with missing lines           | Patch % | Lines
crates/burn-core/src/nn/rnn/gru.rs | 80.48%  | 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2704      +/-   ##
==========================================
- Coverage   83.20%   83.19%   -0.01%     
==========================================
  Files         819      819              
  Lines      106814   106823       +9     
==========================================
+ Hits        88870    88873       +3     
- Misses      17944    17950       +6     

☔ View full report in Codecov by Sentry.

@laggui (Member) left a comment

Thanks for bringing this up and tackling it yourself 🙏

Fun fact, the reset gate changes don't originate from PyTorch 😄

The implementation we had is based on the latest v3 revisions (published at EMNLP) and applies the reset gate to the hidden state before the matrix multiplication.

The changes in your PR are based on the original v1 and apply the reset gate after.

Curiously, PyTorch cites efficiency for its differing implementation (without much explanation). If you're interested, check out this awesome explanation behind the motivation to move the reset gate.
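For reference, a sketch of the two formulations being contrasted, in the usual GRU notation (the first follows the PyTorch docs; this is a restatement for clarity, not text from the PR):

```latex
% v1 / PyTorch: reset gate applied after the hidden-state matmul
n_t = \tanh\big(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\big)

% v3 (previous burn behaviour): reset gate applied before the matmul
n_t = \tanh\big(W_{in} x_t + b_{in} + W_{hn} (r_t \odot h_{t-1}) + b_{hn}\big)
```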

Your implementation LGTM! But what do you think about supporting both via a config? And we could provide the references I just linked in the doc.

Comment on lines 139 to 142
-        let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
-
-        // n(ew)g(ate) tensor
-        let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
+        let biased_ng_input_sum =
+            self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate);
Member


Since both implementations are valid, what do you think about having a field like reset_after to preserve both? And we could default to reset_after = true.

Correct me if I am wrong, but with your changes this could be an easy condition:

let biased_ng_input_sum = if self.reset_after {
    self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
} else {
    let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
    self.gate_product(&input_t, &reset_t, None, &self.new_gate)
};

And document the differences between both in the docstring.

Lmk what you think.
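For what it's worth, a minimal sketch of what that flag could look like, assuming burn's `Config` derive and the naming suggested above (other existing `GruConfig` fields are omitted; this is not the merged code):

```rust
use burn::config::Config;

// Hypothetical, trimmed-down GruConfig with the suggested flag.
#[derive(Config, Debug)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// Whether a bias should be applied during the GRU transformation.
    pub bias: bool,
    /// If `true` (the suggested default), apply the reset gate after the
    /// hidden-state matrix multiplication, matching PyTorch (v1 paper).
    /// If `false`, apply it to the hidden state before the multiplication
    /// (v3 paper, the previous behaviour).
    #[config(default = true)]
    pub reset_after: bool,
}
```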

Author


Yes, that seems like a good idea to me. I added a commit to do this.
