Primitive called with ArrayBox object on forward pass when argument passed by keyword #681

Open
andre-al opened this issue Feb 27, 2025 · 0 comments


andre-al commented Feb 27, 2025

For context, I am working on a project where I need to use functions from outside autograd that do not support autograd's ArrayBox arguments, but whose derivatives are simple enough that I can implement them directly and extend autograd.

When doing so, I kept running into intermittent issues (the worst kind) where these functions were sometimes still being called with ArrayBox arguments despite being defined as primitives, and an error was raised whenever that happened.

After some tinkering, I traced the core issue: when a primitive is called with a keyword argument, the forward pass runs the function on the ArrayBox itself instead of on its unboxed value.

I include a minimal working example below:

```python
from autograd import grad
from autograd.extend import primitive, defvjp

@primitive
def f(x):
    print(f'x={x}')
    return x**2

def f_vjp(ans, x):
    return lambda g: 2*g*x

defvjp(f, f_vjp)

def g(y):
    return f(y)    # x passed positionally

def h(y):
    return f(x=y)  # x passed by keyword

dg = grad(g)
print(f'dg={dg(1.)}')

dh = grad(h)
print(f'dh={dh(1.)}')
```

This code outputs:

```
x=1.0
dg=2.0
x=Autograd ArrayBox with value 1.0
dh=2.0
```

In words: f(x) is declared a primitive, and all it does is print x and return x**2.
Both g and h are direct wrappers around f, but g calls f with x as a positional argument while h passes it by keyword.

When evaluating grad(g) at a point, the forward pass prints the argument as it evaluates the primitive f, correctly outputting x=1.0: the value has been "unboxed" to a plain float.

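When evaluating grad(h) instead, the forward pass prints that x is an ArrayBox with value 1.0. The gradient still comes out as 2.0 here only because x**2 works on an ArrayBox, so autograd differentiates it directly and f_vjp is never used. It would be a real problem if the reason f was made a primitive were that it cannot handle boxed arguments.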

I haven't checked the source code to see whether there is a simple way to fix this behavior, but even if there isn't, at the very least a well-described warning would be very useful, so that others don't spend as much time as I did puzzling over what could be causing errors like the ones I was seeing.
