Currently, the `cg` method in `extras/gradient.py` has a `with_numpy` decorator, which means the data makes a round trip through the CPU. It should be able to run directly on the GPU.
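The sketch below shows what the round trip looks like in practice. The decorator `numpy_roundtrip` and the example function `scaled_sum` are hypothetical. They are not the library's actual `with_numpy` implementation, only an illustration of the device-to-host and host-to-device copies that a decorator of this kind implies. The sketch also runs without CuPy installed, in which case no copies happen.

```python
import functools

import numpy as np

try:
    import cupy as cp  # optional GPU backend
except ImportError:
    cp = None


def _is_cupy(x):
    return cp is not None and isinstance(x, cp.ndarray)


def numpy_roundtrip(func):
    """Hypothetical stand-in for a with_numpy-style decorator.

    CuPy positional arguments are copied to host memory, `func` runs on
    NumPy arrays, and the result is copied back to the GPU if any input
    was a CuPy array. Keyword arguments are passed through unchanged.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        on_gpu = any(_is_cupy(a) for a in args)
        # Device -> host copy for every CuPy argument.
        host_args = [cp.asnumpy(a) if _is_cupy(a) else a for a in args]
        result = func(*host_args, **kwargs)
        if on_gpu:
            return cp.asarray(result)  # host -> device copy
        return result
    return wrapper


@numpy_roundtrip
def scaled_sum(x, alpha=2.0):
    # Always receives NumPy arrays, even when the caller passed CuPy arrays.
    return alpha * x.sum()
```

For large arrays inside an iterative solver, these two copies per call are the overhead that switching to a CuPy-aware decorator is meant to remove.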
Plan of action:
- Use `with_numpy_cupy` instead.
- Adapt the computation inside to use the correct array module.
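Below is a minimal sketch of the second step. It assumes the operator is Hermitian positive definite and is available as a callable. The array module `xp` is chosen from the input with `cupy.get_array_module`, so every operation stays on the device that holds the data. The local wrapper `get_array_module`, the `cg` signature, and the `tol` and `max_iter` parameters are illustrative assumptions, not the actual signature in `extras/gradient.py`.

```python
import numpy as np

try:
    import cupy as cp  # optional GPU backend
except ImportError:
    cp = None


def get_array_module(x):
    """Return cupy for CuPy arrays and numpy otherwise (works without CuPy)."""
    if cp is not None:
        return cp.get_array_module(x)
    return np


def cg(apply_op, b, x0=None, max_iter=50, tol=1e-6):
    """Conjugate gradient for A x = b, with A Hermitian positive definite.

    apply_op: hypothetical callable computing A @ x on the same device as x.
    All arithmetic goes through `xp`, so GPU inputs never leave the GPU.
    """
    xp = get_array_module(b)
    x = xp.zeros_like(b) if x0 is None else x0.copy()
    r = b - apply_op(x)                # initial residual
    p = r.copy()                       # initial search direction
    rs_old = xp.vdot(r, r).real        # squared residual norm
    b_norm = xp.linalg.norm(b)
    for _ in range(max_iter):
        # Stop once the relative residual is small (also handles b == 0).
        if xp.sqrt(rs_old) <= tol * b_norm:
            break
        Ap = apply_op(p)
        alpha = rs_old / xp.vdot(p, Ap).real   # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = xp.vdot(r, r).real
        p = r + (rs_new / rs_old) * p          # next A-conjugate direction
        rs_old = rs_new
    return x
```

For example, `cg(lambda v: A @ v, b)` with a NumPy `A` and `b` returns a NumPy array. If both are CuPy arrays, the same call should run entirely on the GPU and return a CuPy array without any host copy.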
Hi. The MIRTorch package has something similar, and this link describes a good way to compute the gradient. These clues might help: https://github.com/guanhuaw/Bjork