Merge pull request #15 from calvinmccarter-at-lightmatter/reducesum
LSTM conversion, training & multi-device support, and more
Talmaj authored Oct 7, 2021
2 parents 8eb6ae8 + 4418de8 commit dc7684c
Showing 69 changed files with 2,963 additions and 281 deletions.
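The headline changes are LSTM conversion and training/multi-device support. For orientation, a converted model behaves like any other pytorch module; a minimal sketch, assuming an ONNX file that contains an LSTM node (the path and input shape below are illustrative, not taken from this PR):

```python
import onnx
import torch
from onnx2pytorch import ConvertModel

onnx_model = onnx.load("lstm_model.onnx")  # illustrative path
pytorch_model = ConvertModel(onnx_model)   # LSTM nodes convert via LSTMWrapper
output = pytorch_model(torch.randn(5, 1, 10))  # illustrative (seq_len, batch, input_size) input
```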
4 changes: 2 additions & 2 deletions README.md
@@ -52,7 +52,7 @@ The Uncompromising Code Formatter: [Black](https://github.com/psf/black)
 ```black {source_file_or_directory}```
 
 Install it into pre-commit hook to always commit nicely formatted code:
-```pre-commmit install```
+```pre-commit install```
 
 ### Testing
 [Pytest](https://docs.pytest.org/en/latest/) and [tox](https://tox.readthedocs.io/en/latest/).
@@ -66,4 +66,4 @@ Add any custom models to `./fixtures` folder to test their conversion.
 ### Debugging
 Set `ConvertModel(..., debug=True)` to compare each converted
 activation from pytorch with the activation from onnxruntime.
-This helps identify where in the graph the activations start to differ.
\ No newline at end of file
+This helps identify where in the graph the activations start to differ.
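The debugging workflow described in the README amounts to a single flag; a minimal sketch (the model path is illustrative):

```python
import onnx
from onnx2pytorch import ConvertModel

onnx_model = onnx.load("model.onnx")  # illustrative path
pytorch_model = ConvertModel(onnx_model, debug=True)
# With debug=True, each converted activation is compared against the
# corresponding onnxruntime activation, so a failing comparison pinpoints
# the first node in the graph where the outputs start to differ.
```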
22 changes: 22 additions & 0 deletions onnx2pytorch/constants.py
@@ -0,0 +1,22 @@
+from torch import nn
+from torch.nn.modules.conv import _ConvNd
+from torch.nn.modules.pooling import _MaxPoolNd
+from onnx2pytorch.operations import (
+    BatchNormWrapper,
+    InstanceNormWrapper,
+    Loop,
+    LSTMWrapper,
+    Split,
+    TopK,
+)
+
+
+COMPOSITE_LAYERS = (nn.Sequential,)
+MULTIOUTPUT_LAYERS = (_MaxPoolNd, Loop, LSTMWrapper, Split, TopK)
+STANDARD_LAYERS = (
+    _ConvNd,
+    BatchNormWrapper,
+    InstanceNormWrapper,
+    LSTMWrapper,
+    nn.Linear,
+)
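These constants are plain tuples of classes, so downstream conversion code can branch on them with `isinstance`. A hypothetical illustration (the `describe` helper below is not part of the repo):

```python
from torch import nn
from onnx2pytorch.constants import MULTIOUTPUT_LAYERS, STANDARD_LAYERS

def describe(layer: nn.Module) -> str:
    # Hypothetical helper: isinstance accepts a tuple of classes,
    # which is exactly how these groupings are meant to be consumed.
    if isinstance(layer, MULTIOUTPUT_LAYERS):
        return "multi-output layer"
    if isinstance(layer, STANDARD_LAYERS):
        return "standard (weighted) layer"
    return "other"

print(describe(nn.Linear(4, 2)))  # "standard (weighted) layer"
```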
155 changes: 102 additions & 53 deletions onnx2pytorch/convert/attribute.py
@@ -41,6 +41,8 @@ def extract_attr_values(attr):
         value = numpy_helper.to_array(attr.t)
     elif attr.type == AttributeType["STRING"]:
         value = attr.s.decode()
+    elif attr.type == AttributeType["GRAPH"]:
+        value = attr.g
     else:
         raise NotImplementedError(
             "Extraction of attribute type {} not implemented.".format(attr.type)
@@ -52,21 +54,27 @@ def extract_attributes(node):
     """Extract onnx attributes. Map onnx feature naming to pytorch."""
     kwargs = {}
     for attr in node.attribute:
-        if attr.name == "dilations":
-            kwargs["dilation"] = extract_attr_values(attr)
-        elif attr.name == "group":
-            kwargs["groups"] = extract_attr_values(attr)
-        elif attr.name == "kernel_shape":
-            kwargs["kernel_size"] = extract_attr_values(attr)
-        elif attr.name == "pads":
-            params = extract_attr_values(attr)
-            if node.op_type == "Pad":
-                kwargs["padding"] = extract_padding_params(params)
-            else:
-                # Works for Conv, MaxPooling and other layers from convert_layer func
-                kwargs["padding"] = extract_padding_params_for_conv_layer(params)
-        elif attr.name == "strides":
-            kwargs["stride"] = extract_attr_values(attr)
+        if attr.name == "activation_alpha":
+            kwargs["activation_alpha"] = extract_attr_values(attr)
+        elif attr.name == "activation_beta":
+            kwargs["activation_beta"] = extract_attr_values(attr)
+        elif attr.name == "activations":
+            kwargs["activations"] = extract_attr_values(attr)
+        elif attr.name == "alpha":
+            if node.op_type == "LeakyRelu":
+                kwargs["negative_slope"] = extract_attr_values(attr)
+            elif node.op_type in ("Elu", "ThresholdedRelu"):
+                kwargs["alpha"] = extract_attr_values(attr)
+            else:
+                kwargs["weight_multiplier"] = extract_attr_values(attr)
+        elif attr.name == "auto_pad":
+            value = extract_attr_values(attr)
+            if value == "NOTSET":
+                pass
+            else:
+                raise NotImplementedError(
+                    "auto_pad={} functionality not implemented.".format(value)
+                )
         elif attr.name == "axis" and node.op_type == "Flatten":
             kwargs["start_dim"] = extract_attr_values(attr)
         elif attr.name == "axis" or attr.name == "axes":
@@ -75,62 +83,103 @@
                 kwargs["dim"] = v[0]
             else:
                 kwargs["dim"] = v
-        elif attr.name == "keepdims":
-            kwargs["keepdim"] = bool(extract_attr_values(attr))
+        elif attr.name == "beta":
+            kwargs["bias_multiplier"] = extract_attr_values(attr)
+        elif attr.name == "body":
+            kwargs["body"] = extract_attr_values(attr)
+        elif attr.name == "ceil_mode":
+            kwargs["ceil_mode"] = bool(extract_attr_values(attr))
+        elif attr.name == "center_point_box":
+            kwargs["center_point_box"] = extract_attr_values(attr)
+        elif attr.name == "clip":
+            kwargs["clip"] = extract_attr_values(attr)
+        elif attr.name == "coordinate_transformation_mode":
+            arg = extract_attr_values(attr)
+            if arg == "align_corners":
+                kwargs["align_corners"] = True
+            else:
+                warnings.warn(
+                    "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
+                    "Result might differ.".format(arg)
+                )
+        elif attr.name == "dilations":
+            kwargs["dilation"] = extract_attr_values(attr)
+        elif attr.name == "direction":
+            kwargs["direction"] = extract_attr_values(attr)
+        elif attr.name == "ends":
+            kwargs["ends"] = extract_attr_values(attr)
         elif attr.name == "epsilon":
             kwargs["eps"] = extract_attr_values(attr)
+        elif attr.name == "group":
+            kwargs["groups"] = extract_attr_values(attr)
+        elif attr.name == "hidden_size":
+            kwargs["hidden_size"] = extract_attr_values(attr)
+        elif attr.name == "input_forget":
+            kwargs["input_forget"] = extract_attr_values(attr)
+        elif attr.name == "keepdims":
+            kwargs["keepdim"] = bool(extract_attr_values(attr))
+        elif attr.name == "kernel_shape":
+            kwargs["kernel_size"] = extract_attr_values(attr)
+        elif attr.name == "largest":
+            kwargs["largest"] = extract_attr_values(attr)
+        elif attr.name == "layout":
+            kwargs["layout"] = extract_attr_values(attr)
+        elif attr.name == "mode":
+            kwargs["mode"] = extract_attr_values(attr)
         elif attr.name == "momentum":
             kwargs["momentum"] = extract_attr_values(attr)
-        elif attr.name == "ceil_mode":
-            kwargs["ceil_mode"] = bool(extract_attr_values(attr))
-        elif attr.name == "value":
-            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "noop_with_empty_axes":
+            kwargs["noop_with_empty_axes"] = extract_attr_values(attr)
+        elif attr.name == "output_shape" and node.op_type == "ConvTranspose":
+            raise NotImplementedError(
+                "ConvTranspose with dynamic padding not implemented."
+            )
+        elif attr.name == "pads":
+            params = extract_attr_values(attr)
+            if node.op_type == "Pad":
+                kwargs["padding"] = extract_padding_params(params)
+            else:
+                # Works for Conv, MaxPooling and other layers from convert_layer func
+                kwargs["padding"] = extract_padding_params_for_conv_layer(params)
         elif attr.name == "perm":
             kwargs["dims"] = extract_attr_values(attr)
-        elif attr.name == "split":
-            kwargs["split_size_or_sections"] = extract_attr_values(attr)
+        elif attr.name == "repeats":
+            kwargs["repeats"] = extract_attr_values(attr)
+        elif attr.name == "sorted":
+            kwargs["sorted"] = extract_attr_values(attr)
+        elif attr.name == "sparse_value":
+            kwargs["constant"] = extract_attr_values(attr)
         elif attr.name == "spatial":
             kwargs["spatial"] = extract_attr_values(attr)  # Batch norm parameter
+        elif attr.name == "split":
+            kwargs["split_size_or_sections"] = extract_attr_values(attr)
+        elif attr.name == "strides":
+            kwargs["stride"] = extract_attr_values(attr)
+        elif attr.name == "starts":
+            kwargs["starts"] = extract_attr_values(attr)
         elif attr.name == "to":
             kwargs["dtype"] = TENSOR_PROTO_MAPPING[extract_attr_values(attr)].lower()
-        elif attr.name == "mode":
-            kwargs["mode"] = extract_attr_values(attr)
         elif attr.name == "transB":
             kwargs["transpose_weight"] = not extract_attr_values(attr)
         elif attr.name == "transA":
             kwargs["transpose_activation"] = bool(extract_attr_values(attr))
-        elif attr.name == "alpha" and node.op_type == "LeakyRelu":
-            kwargs["negative_slope"] = extract_attr_values(attr)
-        elif attr.name == "alpha" and node.op_type == "Elu":
-            kwargs["alpha"] = extract_attr_values(attr)
-        elif attr.name == "alpha":
-            kwargs["weight_multiplier"] = extract_attr_values(attr)
-        elif attr.name == "beta":
-            kwargs["bias_multiplier"] = extract_attr_values(attr)
-        elif attr.name == "starts":
-            kwargs["starts"] = extract_attr_values(attr)
-        elif attr.name == "ends":
-            kwargs["ends"] = extract_attr_values(attr)
-        elif attr.name == "coordinate_transformation_mode":
-            arg = extract_attr_values(attr)
-            if arg == "align_corners":
-                kwargs["align_corners"] = True
-            else:
-                warnings.warn(
-                    "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
-                    "Result might differ.".format(arg)
-                )
+        elif attr.name == "value":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_float":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_floats":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_int":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_ints":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_string":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_strings":
+            kwargs["constant"] = extract_attr_values(attr)
         elif node.op_type == "Resize":
             # These parameters are not used, warn in Resize operator
             kwargs[attr.name] = extract_attr_values(attr)
-        elif attr.name == "auto_pad":
-            value = extract_attr_values(attr)
-            if value == "NOTSET":
-                pass
-            else:
-                raise NotImplementedError(
-                    "auto_pad={} functionality not implemented.".format(value)
-                )
         else:
             raise NotImplementedError(
                 "Extraction of attribute {} not implemented.".format(attr.name)
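For orientation, `extract_attributes` maps ONNX attribute names onto pytorch keyword arguments, e.g. LeakyRelu's `alpha` becomes `negative_slope`. A small sketch using `onnx.helper` (the node below is illustrative, not from this diff):

```python
from onnx import helper
from onnx2pytorch.convert.attribute import extract_attributes

# Build a standalone ONNX node carrying a float attribute.
node = helper.make_node("LeakyRelu", inputs=["x"], outputs=["y"], alpha=0.02)
print(extract_attributes(node))  # expected roughly: {"negative_slope": 0.02}
```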
15 changes: 10 additions & 5 deletions onnx2pytorch/convert/debug.py
@@ -10,13 +10,18 @@ def debug_model_conversion(onnx_model, inputs, pred_act, node, rtol=1e-3, atol=1
         raise TypeError("inputs should be in a list.")
 
     if not all(isinstance(x, np.ndarray) for x in inputs):
-        inputs = [x.detach().numpy() for x in inputs]
+        inputs = [x.detach().cpu().numpy() for x in inputs]
 
     exp_act = get_activation_value(onnx_model, inputs, list(node.output))
     if isinstance(pred_act, list):
         assert len(exp_act) == len(pred_act)
         for a, b in zip(exp_act, pred_act):
-            assert torch.allclose(torch.from_numpy(a), b, rtol=rtol, atol=atol)
+            exp = torch.from_numpy(a).cpu()
+            pred = b.cpu()
+            assert torch.equal(torch.tensor(exp.shape), torch.tensor(pred.shape))
+            assert torch.allclose(exp, pred, rtol=rtol, atol=atol)
     else:
-        a = torch.from_numpy(exp_act[0])
-        b = pred_act
-        assert torch.allclose(a, b, rtol=rtol, atol=atol)
+        exp = torch.from_numpy(exp_act[0]).cpu()
+        pred = pred_act.cpu()
+        assert torch.equal(torch.tensor(exp.shape), torch.tensor(pred.shape))
+        assert torch.allclose(exp, pred, rtol=rtol, atol=atol)
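The added shape assertions are not redundant: `torch.allclose` broadcasts its arguments, so activations with mismatched but broadcastable shapes could previously compare as close. A minimal demonstration of the pitfall:

```python
import torch

a = torch.zeros(3, 1)
b = torch.zeros(3, 3)
print(torch.allclose(a, b))  # True: (3, 1) broadcasts against (3, 3)
# Comparing the shapes as tensors catches the mismatch explicitly:
print(torch.equal(torch.tensor(a.shape), torch.tensor(b.shape)))  # False
```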