diff --git a/TrainingExtensions/torch/src/python/aimet_torch/arch_checker/arch_checker_rules.py b/TrainingExtensions/torch/src/python/aimet_torch/arch_checker/arch_checker_rules.py index 8377612dbc8..ac3f4f5607a 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/arch_checker/arch_checker_rules.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/arch_checker/arch_checker_rules.py @@ -110,14 +110,29 @@ def _examine_intermediate_padding_in_op_subset(op_subset): If both 1st and 2nd conv have paddings, add conv2 to inter_pad_node_list. """ if len(op_subset) == 4: - conv1, _, _, conv2 = op_subset + conv1, _, _, next_node = op_subset else: - conv1, _, conv2 = op_subset + conv1, _, next_node = op_subset - conv1_padding = sum(conv1.get_module().padding) - conv2_padding = sum(conv2.get_module().padding) - if conv1_padding and conv2_padding: - inter_pad_op_list.append(conv2) + previous_padding = sum(conv1.get_module().padding) + + # Examine all following nodes, ignoring activations, and break on the first node that is neither an activation nor a conv. + while next_node: + if next_node.type in _support_conv_op_type: + current_padding = sum(next_node.get_module().padding) + + if previous_padding and current_padding: + inter_pad_op_list.append(next_node) + + previous_padding = previous_padding or current_padding + + next_outputs = next_node.output_ops + + # Break unless next_node has exactly one output and that output is an activation/conv. 
+ if len(next_outputs) != 1 or (next_outputs[0].type not in _support_activation_op_type and next_outputs[0].type not in _support_conv_op_type): + break + else: + next_node = next_outputs[0] _support_activation_op_type = ("Relu", "Tanh", "HardSwish") _support_conv_op_type = ("Conv", "Conv2D") diff --git a/TrainingExtensions/torch/test/python/test_arch_checker.py b/TrainingExtensions/torch/test/python/test_arch_checker.py index 8e78285c5dc..203bb4591ed 100644 --- a/TrainingExtensions/torch/test/python/test_arch_checker.py +++ b/TrainingExtensions/torch/test/python/test_arch_checker.py @@ -74,6 +74,7 @@ def __init__(self): self.bn2 = torch.nn.BatchNorm2d(32) self.conv4 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False) + # conv5 has intermediate paddings when considering (conv3, conv4, conv5) # conv6 has no intermediate paddings self.conv5 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=2, bias=False) self.relu3 = torch.nn.ReLU() @@ -129,17 +130,19 @@ def __init__(self): self.relu1 = torch.nn.ReLU() self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False) - # conv4 has no intermediate paddings since conv3 has no paddings + # conv4 has intermediate paddings when considering (conv1, conv2, conv3, conv4) self.conv3 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=0, bias=False) self.relu2 = torch.nn.ReLU() self.conv4 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False) - + + # PReLU is not a supported activation, so stop examining the following nodes for the (conv1, relu1, conv2) pattern. 
+ self.prelu = torch.nn.PReLU() # conv6 has no intermediate paddings self.conv5 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=2, bias=False) self.relu3 = torch.nn.ReLU() self.conv6 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=0, bias=False) - # conv7 has no intermediate paddings + # conv7, conv8 has no intermediate paddings self.conv7 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=0, bias=False) self.relu4 = torch.nn.ReLU() self.conv8 = torch.nn.Conv2d(32, 32, kernel_size=2, padding=0, bias=False) @@ -153,6 +156,7 @@ def forward(self, x): x = self.relu2(x) x = self.conv4(x) + x = self.prelu(x) x = self.conv5(x) x = self.relu3(x) x = self.conv6(x) @@ -198,14 +202,17 @@ def forward(self, x): x = self.conv1(x) x = self.relu(x) x = self.conv2(x) + x = self.prelu(x) # Break point for variable length x = self.conv3(x) x = self.tanh(x) x = self.conv4(x) + x = self.prelu(x) # Break point for variable length x = self.conv5(x) x = self.hardswich(x) x = self.conv6(x) + x = self.prelu(x) # Break point for variable length x = self.conv7(x) x = self.prelu(x) @@ -324,6 +331,7 @@ def test_intermediate_padding(self): ArchChecker.check_model_arch(model, self.dummy_input) arch_checker_report = ArchChecker._arch_checker_report assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_with_BN.conv2"].failed_checks + assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_with_BN.conv5"].failed_checks assert "Model_inter_pad_with_BN.conv4" not in arch_checker_report.raw_report assert "Model_inter_pad_with_BN.conv6" not in arch_checker_report.raw_report assert "Model_inter_pad_with_BN.conv8" not in arch_checker_report.raw_report @@ -334,7 +342,7 @@ def test_intermediate_padding(self): ArchChecker.check_model_arch(model, self.dummy_input) arch_checker_report = ArchChecker._arch_checker_report assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_without_BN.conv2"].failed_checks - 
assert "Model_inter_pad_without_BN.conv4" not in arch_checker_report.raw_report + assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_without_BN.conv4"].failed_checks assert "Model_inter_pad_without_BN.conv6" not in arch_checker_report.raw_report assert "Model_inter_pad_without_BN.conv8" not in arch_checker_report.raw_report arch_checker_report.reset_raw_report() @@ -342,10 +350,16 @@ def test_intermediate_padding(self): model = Model_inter_pad_act_type() ArchChecker.check_model_arch(model, self.dummy_input) arch_checker_report = ArchChecker._arch_checker_report + + assert "_check_intermediate_padding" not in arch_checker_report.raw_report["Model_inter_pad_act_type.conv1"].failed_checks assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_act_type.conv2"].failed_checks + assert "Model_inter_pad_act_type.conv3" not in arch_checker_report.raw_report assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_act_type.conv4"].failed_checks + assert "Model_inter_pad_act_type.conv5" not in arch_checker_report.raw_report assert "_check_intermediate_padding" in arch_checker_report.raw_report["Model_inter_pad_act_type.conv6"].failed_checks + assert "Model_inter_pad_act_type.conv7" not in arch_checker_report.raw_report assert "Model_inter_pad_act_type.conv8" not in arch_checker_report.raw_report + arch_checker_report.reset_raw_report() filepath = ArchChecker._arch_checker_report._get_write_path(".html")