Skip to content

Commit

Permalink
Core: Multiple Inputs and Keyword Arguments
Browse files Browse the repository at this point in the history
- use the additions to forward hooks in torch 2.0.0 to pass keyword
  arguments
- handle multiple inputs and outputs in core.Hook and core.BasicHook, by
  passing all required grad_outputs and inputs to the backward
  implementation

TODO:

- finish draft and test implementation
- add tests
- add documentation

- This stands in conflict with #168, but promises a better
  implementation by handling inputs and outputs as common to a single
  function, rather than individually as proposed in #168
  • Loading branch information
chr5tphr committed Aug 10, 2023
1 parent c30e3cc commit d127af8
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 47 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def replace(mobj):
'click',
'numpy',
'Pillow',
'torch>=1.7.0',
'torch>=2.0.0',
'torchvision',
],
setup_requires=[
Expand Down
167 changes: 121 additions & 46 deletions src/zennit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
import weakref
from contextlib import contextmanager
from itertools import compress

import torch

Expand Down Expand Up @@ -234,6 +235,20 @@ def modifier_wrapper(input, name):
return zero_params_wrapper


def uncompress(data, selector, compressed):
    '''Generator which reverses :py:func:`itertools.compress`.

    For every element of `selector`, one element of `data` is consumed. If the selector element is truthy, the
    consumed `data` element is discarded and the next element of `compressed` is yielded in its place; otherwise the
    `data` element itself is yielded. Iteration stops as soon as any of the three iterables is exhausted.

    Parameters
    ----------
    data: iterable
        Values to yield at positions where `selector` is falsy. May be infinite, e.g. ``itertools.repeat(None)``.
    selector: iterable
        Truthy/falsy flags, as passed to :py:func:`itertools.compress`.
    compressed: iterable
        Values to yield at positions where `selector` is truthy, in order.

    Yields
    ------
    object
        The next element of either `compressed` or `data`, depending on `selector`.
    '''
    its = iter(selector)
    itc = iter(compressed)
    itd = iter(data)
    try:
        # loop over the full selector; the first exhausted iterator ends the generator
        while True:
            if next(its):
                # skip the data element which is replaced by the compressed one
                next(itd)
                yield next(itc)
            else:
                yield next(itd)
    except StopIteration:
        return


class ParamMod:
'''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
Expand Down Expand Up @@ -361,6 +376,7 @@ class Identity(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs):
    '''Pass all inputs through unchanged, flagging those which do not require a gradient as non-differentiable.'''
    frozen = [tensor for tensor in inputs if not tensor.requires_grad]
    ctx.mark_non_differentiable(*frozen)
    return inputs

@staticmethod
Expand All @@ -376,56 +392,88 @@ def __init__(self):
self.active = True
self.tensor_handles = RemovableHandleList()

def pre_forward(self, module, input):
@staticmethod
def _inject_grad_fn(args):
    '''Pass all tensors in `args` through a single Identity function to obtain a common `grad_fn`.

    Parameters
    ----------
    args: tuple
        Positional arguments (or outputs) of a module, which may contain non-tensor elements.

    Returns
    -------
    grad_fn: object or None
        The `grad_fn` of the tensors returned by `Identity.apply`, or None if no tensor in `args` requires a
        gradient.
    post_args: tuple
        `args` with each tensor replaced by the corresponding result of `Identity.apply`, or the unmodified `args`
        if no tensor requires a gradient.
    tensor_mask: list of bool
        For each element of `args`, whether it is a :py:class:`torch.Tensor`.

    Raises
    ------
    RuntimeError
        If none of the tensors returned by `Identity.apply` has a `grad_fn`.
    '''
    tensor_mask = [isinstance(elem, torch.Tensor) for elem in args]
    tensors = tuple(compress(args, tensor_mask))

    # only if gradient required; always return three values, since callers unpack three
    if not any(tensor.requires_grad for tensor in tensors):
        return None, args, tensor_mask

    # add identity to ensure .grad_fn exists
    post_tensors = Identity.apply(*tensors)
    grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
    if grad_fn is None:
        raise RuntimeError('Backward hook could not be registered!')

    # put the identity outputs back at the positions of the original tensors, keeping non-tensor elements
    post_args = tuple(uncompress(args, tensor_mask, post_tensors))
    # NOTE: cloning the elements of post_args (work-around to support in-place operations) is currently disabled
    return grad_fn, post_args, tensor_mask

def pre_forward(self, module, args, kwargs):
'''Apply an Identity to the input before the module to register a backward hook.'''
hook_ref = weakref.ref(self)

grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
if grad_fn is None:
return

@functools.wraps(self.backward)
def wrapper(grad_input, grad_output):
hook = hook_ref()
if hook is not None and hook.active:
return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
return hook.backward(
module,
list(uncompress(
repeat(None),
input_tensor_mask,
grad_input,
)),
hook.stored_tensors['grad_output'],
)
return None

if not isinstance(input, tuple):
input = (input,)
# register the input tensor gradient hook
self.tensor_handles.append(grad_fn.register_hook(wrapper))

# only if gradient required
if input[0].requires_grad:
# add identity to ensure .grad_fn exists
post_input = Identity.apply(*input)
# register the input tensor gradient hook
self.tensor_handles.append(
post_input[0].grad_fn.register_hook(wrapper)
)
# work around to support in-place operations
post_input = tuple(elem.clone() for elem in post_input)
else:
# no gradient required
post_input = input
return post_input[0] if len(post_input) == 1 else post_input
return post_args, kwargs

def post_forward(self, module, input, output):
def post_forward(self, module, args, kwargs, output):
'''Register a backward-hook to the resulting tensor right after the forward.'''
hook_ref = weakref.ref(self)

single = not isinstance(output, tuple)
if single:
output = (output,)

grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
if grad_fn is None:
return

@functools.wraps(self.pre_backward)
def wrapper(grad_input, grad_output):
hook = hook_ref()
if hook is not None and hook.active:
return hook.pre_backward(module, grad_input, grad_output)
return hook.pre_backward(
module,
grad_input,
list(uncompress(
repeat(None),
output_tensor_mask,
grad_output
))
)
return None

if not isinstance(output, tuple):
output = (output,)
# register the output tensor gradient hook
self.tensor_handles.append(grad_fn.register_hook(wrapper))

# only if gradient required
if output[0].grad_fn is not None:
# register the output tensor gradient hook
self.tensor_handles.append(
output[0].grad_fn.register_hook(wrapper)
)
return output[0] if len(output) == 1 else output
if single:
return post_output[0]
return post_output

def pre_backward(self, module, grad_input, grad_output):
'''Store the grad_output for the backward hook'''
Expand All @@ -452,9 +500,9 @@ def register(self, module):
'''Register this instance by registering all hooks to the supplied module.'''
return RemovableHandleList([
RemovableHandle(self),
module.register_forward_pre_hook(self.pre_forward),
module.register_forward_hook(self.post_forward),
module.register_forward_hook(self.forward),
module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
module.register_forward_hook(self.post_forward, with_kwargs=True),
module.register_forward_hook(self.forward, with_kwargs=True),
])


Expand Down Expand Up @@ -518,31 +566,58 @@ def __init__(
self.gradient_mapper = gradient_mapper
self.reducer = reducer

def forward(self, module, input, output):
def forward(self, module, args, kwargs, output):
'''Forward hook to save module in-/outputs.'''
self.stored_tensors['input'] = input
self.stored_tensors['input'] = args
self.stored_tensors['kwargs'] = kwargs

def backward(self, module, grad_input, grad_output):
'''Backward hook to compute LRP based on the class attributes.'''
original_input = self.stored_tensors['input'][0].clone()
input_mask = [elem is not None for elem in self.stored_tensors['input']]
output_mask = [elem is not None for elem in grad_output]
cgrad_output = tuple(compress(grad_output, output_mask))

original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
inputs = []
outputs = []
for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
input = in_mod(original_input).requires_grad_()
mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
args = tuple(uncompress(original_inputs, input_mask, mod_args))
with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
output = modified.forward(input)
output = out_mod(output)
inputs.append(input)
output = modified.forward(*args, **kwargs)
if not isinstance(output, tuple):
output = (output,)
output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
inputs.append(compress(args, input_mask))
outputs.append(output)
grad_outputs = self.gradient_mapper(grad_output[0], outputs)

inputs = list(zip(*inputs))
outputs = list(zip(*outputs))
input_struct = [len(elem) for elem in inputs]
output_struct = [len(elem) for elem in outputs]

grad_outputs = tuple(
self.gradient_mapper(gradout, outs)
for gradout, outs in zip(cgrad_output, outputs)
)
inputs_flat = sum(inputs, tuple())
outputs_flat = sum(outputs, tuple())
grad_outputs_flat = sum(grad_outputs, tuple())

gradients = torch.autograd.grad(
outputs,
inputs,
grad_outputs=grad_outputs,
create_graph=grad_output[0].requires_grad
inputs_flat,
outputs_flat,
grad_outputs=grad_outputs_flat,
create_graph=any(tensor.requires_grad for tensor in cgrad_output)
)
relevance = self.reducer(inputs, gradients)
return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)

input_it = iter(input)
inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
gradient_it = iter(gradients)
gradients_re = [tuple(islice(gradient_it, size)) for size in input_struct]

relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs_re, gradients_re))
return tuple(uncompress(repeat(None), input_mask, relevances))

def copy(self):
'''Return a copy of this hook.
Expand Down

0 comments on commit d127af8

Please sign in to comment.