From dd2a3f796c666021ee42e710f14859bf86bc66a7 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 10 Oct 2024 16:45:31 +0200 Subject: [PATCH] [InferDataLayouts] Add fallback for layout propagation If there is no tensor layout annotation for the input of a non-FINN operator, annotate with the same fallback defaults as for the FINN-ops instead of propagating empty/invalid layouts. --- .../transformation/infer_data_layouts.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/qonnx/transformation/infer_data_layouts.py b/src/qonnx/transformation/infer_data_layouts.py index 81143e45..241311a4 100644 --- a/src/qonnx/transformation/infer_data_layouts.py +++ b/src/qonnx/transformation/infer_data_layouts.py @@ -62,9 +62,23 @@ def _dims_to_layout(model, node, ndims): else: return DataLayout.UNKNOWN else: - # propagate input layout to output - # TODO this won't work for concat, squeeze/unsqueeze/reshape... - return model.get_tensor_layout(node.input[0]) + # Check whether there is a layout annotation for the first input + # TODO: There are multi-input operations, why should the first + # determine the output layout? + if layout := model.get_tensor_layout(node.input[0]): + # If annotation present: propagate input layout to output + # TODO: this won't work for concat, squeeze/unsqueeze/reshape... + return layout + # Fallback to the same defaults as for the FINN-Ops above + else: + if ndims == 4: + return DataLayout.NHWC + elif ndims == 3: + return DataLayout.NWC + elif ndims == 2: + return DataLayout.NC + else: + return DataLayout.UNKNOWN def _infer_node_data_layout(model, node):