For context, I am working on a project where I need to use functions from outside autograd that do not accept autograd's ArrayBox arguments, but whose derivatives are simple enough that I can implement them directly and extend autograd.
When doing so, I ran into intermittent issues (the worst kind of issue) where these functions were sometimes still called with ArrayBox arguments despite being defined as primitives, and whenever that happened an error was raised.
After some tinkering, I finally traced the core issue: when a primitive is called with keyword arguments, its forward pass receives the ArrayBox version of the argument instead of the unboxed value.
A minimal working example (described in words below) outputs:

```
x=1.0
dg=2.0
x=Autograd ArrayBox with value 1.0
dh=2.0
```
In words, `f(x)` is declared a primitive. All it does is `print(x)` and return `x**2`.
Both `g` and `h` are direct wrappers around `f`, but while `g` calls `f` with a positional `x` argument, `h` does so by keyword.
When evaluating `grad(g)` at a point, the forward pass prints the argument when it evaluates the primitive `f`, correctly outputting `x=1.0`, "unboxed" as a plain float.
When evaluating `grad(h)` instead, the forward pass prints that `x` is an ArrayBox with value 1.0. While this is not a problem here, it would be if the function had been made a primitive precisely because it is incompatible with boxed arguments.
I haven't checked the source code to see whether there is a simple way to fix this behavior, but even if there isn't, at the very least a well-described warning would be very useful, so that others don't spend as much time as I did puzzling over what could be causing errors like the ones I was having.
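Until this is addressed, one possible workaround is to route calls through a small wrapper that rebinds keyword arguments to positions before they reach the primitive. The sketch below uses a hypothetical helper, `call_positionally` (not part of autograd), and assumes the primitive preserves the original function's signature via `functools.wraps` so that `inspect.signature` can see it. The `spy` function is only a stdlib stand-in for a `(*args, **kwargs)` primitive wrapper:

```python
import functools
import inspect


def call_positionally(func):
    """Hypothetical helper: pass every positional-or-keyword argument by position."""
    sig = inspect.signature(func)  # follows __wrapped__ set by functools.wraps

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # bound.args holds all positional-or-keyword values in order;
        # bound.kwargs holds only keyword-only values (and any **kwargs).
        return func(*bound.args, **bound.kwargs)

    return wrapper


def f_raw(x, scale=1.0):
    return scale * x


seen = []  # records how the stand-in wrapper was actually called


@functools.wraps(f_raw)
def spy(*args, **kwargs):
    # Stand-in for a (*args, **kwargs) wrapper such as a primitive
    seen.append((args, kwargs))
    return f_raw(*args, **kwargs)


f = call_positionally(spy)
print(f(x=2.0), seen[-1])  # keyword call arrives as positional (2.0, 1.0)
```

With autograd, one would register the derivative on the primitive itself (`defvjp(f_prim, ...)`) and then expose `f = call_positionally(f_prim)` to callers. Keyword-only parameters and `**kwargs` catch-alls are still passed by keyword, so this does not help for those.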