Merge pull request fastmachinelearning#976 from vloncar/channels_last_flatten

Remove unnecessary transposes related to conversion to channels_last format
jmitrevs authored Apr 19, 2024
2 parents 1616caf + 295ba9f commit a357b7a
Showing 4 changed files with 166 additions and 9 deletions.
4 changes: 3 additions & 1 deletion hls4ml/model/optimizer/__init__.py
@@ -35,8 +35,10 @@
     [
         'infer_precision_types',
         'channels_last_converter',
+        'remove_transpose_before_flatten',
+        'remove_nop_transpose',
+        'remove_single_channel_transpose',
         'fuse_bias_add',
-        'remove_useless_transpose',
         'expand_layer_group',
         'output_rounding_saturation_mode',
         'qkeras_factorize_alpha',
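
A note on the flow entries above: each snake_case name corresponds to an OptimizerPass class defined in the passes changed below (e.g. 'remove_nop_transpose' maps to RemoveNopTranspose), which is why renaming RemoveUselessTranspose required swapping 'remove_useless_transpose' out of this list. A toy sketch of that CamelCase-to-snake_case correspondence (the helper below is illustrative only, not hls4ml API):

import re


def pass_name(cls_name: str) -> str:
    # Insert '_' before every interior capital letter, then lowercase:
    # 'RemoveNopTranspose' -> 'remove_nop_transpose'
    return re.sub(r'(?<!^)(?=[A-Z])', '_', cls_name).lower()


assert pass_name('RemoveNopTranspose') == 'remove_nop_transpose'
assert pass_name('RemoveSingleChannelTranspose') == 'remove_single_channel_transpose'
assert pass_name('RemoveTransposeBeforeFlatten') == 'remove_transpose_before_flatten'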
62 changes: 61 additions & 1 deletion hls4ml/model/optimizer/passes/convert_to_channels_last.py
@@ -2,8 +2,9 @@
 # Based on https://github.com/fastmachinelearning/qonnx/blob/
 # 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py

-from hls4ml.model.layers import Concatenate, Input, Reshape
+from hls4ml.model.layers import Concatenate, Dense, Input, Reshape, Transpose
 from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.types import WeightVariable


 class ChannelsLastConverter(OptimizerPass):
@@ -133,3 +134,62 @@ def transform(self, model, node):

         node.channels_last_converted = True
         return True
+
+
+class RemoveTransposeBeforeFlatten(OptimizerPass):
+    '''After the channels last conversion, model may have a sequence: Transpose -> Flatten -> Dense.
+    In this case we can remove the expensive transpose and instead transpose the weights of the Dense layer.'''
+
+    def match(self, node):
+        if node.model.config.get_config_value('IOType') != 'io_parallel':
+            return False
+
+        if hasattr(node, '_channels_last_keep_transpose') and node._channels_last_keep_transpose:
+            return False
+
+        if isinstance(node, Reshape):
+            input_node = node.get_input_node()
+            output_nodes = node.get_output_nodes()
+            if (
+                len(node.get_attr('target_shape')) == 1
+                and isinstance(input_node, Transpose)
+                and len(output_nodes) == 1
+                and isinstance(output_nodes[0], Dense)
+            ):
+                return True
+
+        return False
+
+    def transform(self, model, node):
+        transpose_node = node.get_input_node()
+        dense_node = node.get_output_nodes()[0]
+        input_shape = transpose_node.get_output_variable().shape
+
+        if len(input_shape) == 2:  # Usually after Conv1D
+            tran_axis = [1, 0, 2]
+        elif len(input_shape) == 3:  # Usually after Conv2D
+            tran_axis = [1, 2, 0, 3]
+        else:  # In this case we bail
+            node._channels_last_keep_transpose = True
+            return False
+
+        weight_var = dense_node.get_weights('weight')
+        # Transpose the weights to achieve the same computation with transposed input
+        weight_data_t = weight_var.data.reshape(*input_shape, -1).transpose(*tran_axis)
+        weight_data_t = weight_data_t.reshape(-1, weight_data_t.shape[-1])
+        new_weight_var = WeightVariable(
+            var_name=weight_var.name,
+            type_name=weight_var.type.name,
+            precision=weight_var.type.precision,
+            quantizer=weight_var.quantizer,
+            data=weight_data_t,
+            index=dense_node.index,
+        )
+
+        # Update the weight variable of the node
+        dense_node.set_attr('weight', new_weight_var)
+
+        # Get rid of the Transpose node
+        model.remove_node(transpose_node)
+
+        return True
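
To see why transposing the Dense weights preserves the computation, consider the Conv2D case: the removed Transpose would have reordered a channels-last tensor (H, W, C) back into channels-first (C, H, W) before flattening. The standalone numpy sketch below (illustrative shapes, not part of the commit) mirrors the reshape/transpose/reshape sequence from transform and checks that both paths agree:

import numpy as np

C, H, W, N = 4, 6, 6, 5
x_cl = np.random.rand(H, W, C)    # channels-last activation entering the Transpose
w = np.random.rand(C * H * W, N)  # Dense weights laid out for a channels-first flatten

# Original path: Transpose (back to channels-first) -> Flatten -> Dense
y_ref = x_cl.transpose(2, 0, 1).reshape(-1) @ w

# Optimized path: Flatten in channels-last order -> Dense with reordered weights,
# using the same tran_axis = [1, 2, 0, 3] as the pass above
w_t = w.reshape(C, H, W, N).transpose(1, 2, 0, 3).reshape(-1, N)
y_opt = x_cl.reshape(-1) @ w_t

np.testing.assert_allclose(y_ref, y_opt, rtol=1e-12)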
40 changes: 33 additions & 7 deletions hls4ml/model/optimizer/passes/transpose_opt.py
@@ -1,21 +1,47 @@
-from hls4ml.model.layers import Transpose
+from hls4ml.model.layers import Input, Transpose
 from hls4ml.model.optimizer import OptimizerPass


-class RemoveUselessTranspose(OptimizerPass):
+class RemoveNopTranspose(OptimizerPass):
+    """
+    Remove a transpose layer if it doesn't do anything to a 1D array. i.e, 1D input and perm = [0]
+    """
+
     def match(self, node):
         is_match = isinstance(node, Transpose) and node.get_attr('perm') == [0]  # Useless transpose
         return is_match

     def transform(self, model, node):
-        """
-        Remove a transpose layer if it doesn't do anything. i.e 1D input and perm = [0]
-        """
-        print(f"Unnessary {node.name} in the model, optimizing ...")
+        print(f'Unnecessary transpose node ({node.name}) detected, optimizing ...')
         if not node.get_output_nodes():
-            print(f"WARNING: {node.name} is the output layer! No rewiring performed.")
+            print(f'WARNING: {node.name} is the output layer! No rewiring performed.')
             model.remove_node(node, rewire=False)  # Don't rewire if there is no output layer
         else:
             model.remove_node(node, rewire=True)

         return True
+
+
+class RemoveSingleChannelTranspose(OptimizerPass):
+    """
+    Remove transpose of inputs if the number of channels is 1 as for io_parallel this doesn't affect the array
+    representation used
+    """
+
+    def match(self, node):
+        if node.model.config.get_config_value('IOType') != 'io_parallel':
+            return False
+
+        return (
+            isinstance(node, Transpose)
+            and isinstance(node.get_input_node(), Input)
+            and node.get_input_variable().shape[0] == 1
+        )
+
+    def transform(self, model, node):
+        # Adjust the input shape and remove the Transpose node
+        input_var = node.get_input_variable()
+        input_var.shape.append(input_var.shape.pop(0))
+        model.remove_node(node)
+
+        return True
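
The justification for RemoveSingleChannelTranspose is that with a single channel the permutation does not change the flat element order that io_parallel arrays use, so only the recorded shape needs updating (the shape.append(shape.pop(0)) above). A quick numpy check (illustrative, not part of the commit):

import numpy as np

H, W = 8, 8
x = np.random.rand(1, H, W)  # single-channel, channels-first input

x_t = x.transpose(1, 2, 0)   # what the Transpose node would compute: shape (H, W, 1)

# Same data in the same flat order, so the node can simply be removed
np.testing.assert_array_equal(x.reshape(-1), x_t.reshape(-1))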
69 changes: 69 additions & 0 deletions test/pytest/test_pytorch_api.py
@@ -740,3 +740,72 @@ def test_skipped_layers(backend, io_type):
     hls_prediction = hls_model.predict(hls_input).flatten()

     np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
+
+
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
+@pytest.mark.parametrize('io_type', ['io_parallel'])  # Only io_parallel for now
+@pytest.mark.parametrize('tensor_rank', [2, 3])
+def test_remove_transpose(backend, io_type, tensor_rank):
+    class TestModel(nn.Module):
+        def __init__(self, tensor_rank):
+            super().__init__()
+            if tensor_rank == 2:
+                self.conv1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+            else:
+                self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+
+        def forward(self, x):
+            # In the hls4ml model, there should be a Transpose node on the input tensor before conv1
+            x = self.conv1(x)
+            x = self.relu1(x)
+            x = self.flatten(x)  # This should result in a Transpose node that we aim to remove
+            x = self.fc1(x)
+            x = self.relu2(x)
+            return x
+
+    model = TestModel(tensor_rank=tensor_rank)
+    if tensor_rank == 2:
+        input_shape = (1, 8)
+        input_tensor = torch.randn(10, 1, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 1)).detach().numpy())
+    else:
+        input_shape = (1, 8, 8)
+        input_tensor = torch.randn(10, 1, 8, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 3, 1)).detach().numpy())
+
+    batch_input_shape = (None,) + input_shape
+    config = config_from_pytorch_model(
+        model,
+        default_precision='ap_fixed<32,16>',
+        inputs_channel_last=False,  # Crucial for testing if the first Transpose was removed
+        transpose_outputs=False,
+    )
+    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_transpose_nop_{tensor_rank}d_{backend}_{io_type}')
+    hls_model = convert_from_pytorch_model(
+        model,
+        batch_input_shape,
+        hls_config=config,
+        output_dir=output_dir,
+        io_type=io_type,
+        backend=backend,
+    )
+
+    hls_model.compile()
+
+    # Test optimizers removed the two Transpose layers
+    transpose_layers = [layer for layer in list(hls_model.get_layers()) if layer.class_name == 'Transpose']
+    assert len(transpose_layers) == 0
+
+    # Test predictions match
+    pytorch_prediction = model(input_tensor).detach().numpy().flatten()
+    hls_prediction = hls_model.predict(hls_input).flatten()
+
+    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
