Copy of the inputs for the evaluation during compilation #1150
Conversation
```python
return x + y
```

```python
input = torch.zeros(shape, requires_grad=False)
framework_input = input.detach().clone()
```
Why are we doing detach and clone here?
Because if we run inference on the torch module and then want to run it on the compiled module, the inputs for the compiled module would be different due to the changes made by torch (by default torch is not functional, meaning it may modify its inputs in place).
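For illustration, a minimal standalone sketch of the issue (the `InplaceAdd` module here is hypothetical, not from this PR):

```python
import torch

# Hypothetical module that mutates its input in place, like the models
# this PR is guarding against.
class InplaceAdd(torch.nn.Module):
    def forward(self, x):
        y = x.add_(1)      # in-place: x is modified, y aliases x
        return x + y

input = torch.zeros(2, 2, requires_grad=False)
framework_input = input.detach().clone()    # snapshot before the torch forward

golden = InplaceAdd()(framework_input)      # mutates framework_input only
assert torch.equal(input, torch.zeros(2, 2))  # the original input is untouched
```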
```python
compiled_model = forge.compile(framework_model, sample_inputs=tt_inputs, module_name="inplace")
tty = compiled_model(*tt_inputs)[0]

compare_with_golden(golden=y, calculated=tty)
```
Should we use the standard verify function?
Due to the nature of the in-place problem we are facing, I have separate inputs for the compiled module and the torch module, so I can't pass them to the verify function.
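To make the constraint concrete, here is a hedged sketch of what would go wrong if a single shared input list were fed to both models (assuming the standard verify function works that way):

```python
# Hypothetical shared-input flow, i.e. what a single-input verify would do:
inputs = [torch.zeros(2, 2)]
golden = framework_model(*inputs)    # in-place forward mutates inputs[0]
tty = compiled_model(*inputs)[0]     # now sees the mutated tensor
# golden and tty are computed from different effective inputs -> spurious mismatch
```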
```python
tty = compiled_model(*tt_inputs)[0]

compare_with_golden(golden=y, calculated=tty)
print(framework_input)
```
Do we need these prints?
Not really.
```python
# convert tensor from tf to torch
y = torch.tensor(y.numpy())

compare_with_golden(golden=y, calculated=tty)
```
Similar comments as for the PT example.
I would still hold this PR until we resolve which policy we want to follow (functional, or per-framework).
Summary
This PR ensures that a copy of the input tensors is created when running forward during compilation (e.g. when verifying outputs using the forward pass of the framework model). This prevents unintended modifications to the original input tensors if in-place operations are performed during the forward pass.
In order for this to work, the corresponding PR in tt-tvm needs to be merged first.
Why is this needed?
Some models perform in-place operations on input tensors during the forward pass, which can lead to unintended changes in the original inputs. By making a copy of the inputs, we ensure correctness and avoid potential issues when running forward during compilation.
Example test:
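(The original example did not survive the page export; below is a reconstruction from the diff fragments in this thread. `InplaceModule`, the shape, and the import path for `compare_with_golden` are assumptions, not verbatim from the PR.)

```python
import torch
import forge
from forge.verify.compare import compare_with_golden  # assumed import path

# Hypothetical in-place model, pieced together from the diff fragments above.
class InplaceModule(torch.nn.Module):
    def forward(self, x):
        y = x.add_(1)    # in-place update; y aliases x
        return x + y

shape = (1, 32, 32)
input = torch.zeros(shape, requires_grad=False)
framework_input = input.detach().clone()   # protects the compile-time inputs

framework_model = InplaceModule()
y = framework_model(framework_input)       # golden output; mutates the copy only

tt_inputs = [input]
compiled_model = forge.compile(framework_model, sample_inputs=tt_inputs, module_name="inplace")
tty = compiled_model(*tt_inputs)[0]

compare_with_golden(golden=y, calculated=tty)
```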
❗❗❗ IMPORTANT NOTES ❗❗❗
In training mode, if the inputs require grad, pytorch will attempt the forward pass and throw a runtime error:
a leaf Variable that requires grad is being used in an in-place operation.
while the compiled model silently computes the gradients; this needs to be addressed.
With this change, our compiler will not modify its inputs, so it will act like tensorflow (which does not modify its input tensors), but it won't be aligned with pytorch, which allows in-place changes to the input tensor.
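For reference, the pytorch error above can be reproduced in a few lines:

```python
import torch

x = torch.zeros(2, 2, requires_grad=True)   # leaf tensor tracked by autograd
try:
    x.add_(1)                               # in-place op on a leaf that requires grad
except RuntimeError as err:
    print(err)  # a leaf Variable that requires grad is being used in an in-place operation.
```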