diff --git a/src/qonnx/.DS_Store b/src/qonnx/.DS_Store new file mode 100644 index 00000000..605cb96a Binary files /dev/null and b/src/qonnx/.DS_Store differ diff --git a/src/qonnx/core/data_layout.py b/src/qonnx/core/data_layout.py index 4a5d87a4..3e90a144 100644 --- a/src/qonnx/core/data_layout.py +++ b/src/qonnx/core/data_layout.py @@ -34,16 +34,20 @@ NCW = ["N", "C", "W"] NWC = ["N", "W", "C"] NC = ["N", "C"] +# 5-dimension video input, D for sequence depth +NCDHW = ["N", "C", "D", "H", "W"] +NDHWC = ["N", "D", "H", "W", "C"] + UNKNOWN = [] def is_channels_last(layout): return layout[-1] == "C" - def get_channels_last_layout_for_ndims(ndims): - return {4: NHWC, 3: NWC, 2: NC}[ndims] + return {5: NDHWC, 4: NHWC, 3: NWC, 2: NC}[ndims] def get_channels_first_layout_for_ndims(ndims): - return {4: NCHW, 3: NCW, 2: NC}[ndims] + return {5: NCDHW, 4: NCHW, 3: NCW, 2: NC}[ndims] + diff --git a/src/qonnx/transformation/.DS_Store b/src/qonnx/transformation/.DS_Store new file mode 100644 index 00000000..5008ddfc Binary files /dev/null and b/src/qonnx/transformation/.DS_Store differ diff --git a/src/qonnx/transformation/infer_data_layouts.py b/src/qonnx/transformation/infer_data_layouts.py index 81143e45..8b9c2f77 100644 --- a/src/qonnx/transformation/infer_data_layouts.py +++ b/src/qonnx/transformation/infer_data_layouts.py @@ -50,6 +50,11 @@ def _dims_to_layout(model, node, ndims): return DataLayout.NWC elif layout == "NC" and ndims == 2: return DataLayout.NC + # 5D + elif layout == "NCDHW" and ndims == 5: + return DataLayout.NCDHW + elif layout == "NDHWC" and ndims == 5: + return DataLayout.NDHWC else: return DataLayout.UNKNOWN else: @@ -59,6 +64,9 @@ def _dims_to_layout(model, node, ndims): return DataLayout.NWC elif ndims == 2: return DataLayout.NC + # 5D + elif ndims == 5: + return DataLayout.NCDHW else: return DataLayout.UNKNOWN else: @@ -135,6 +143,11 @@ def apply(self, model): graph_modified = True warnings.warn("Assuming 2D input is NC") model.set_tensor_layout(inp_name, DataLayout.NC) + # 5D + elif len(inp_shape) == 5: + graph_modified = True + warnings.warn("Assuming 5D input is NCDHW") + model.set_tensor_layout(inp_name, DataLayout.NCDHW) else: raise Exception( """Unknown number of dims for input, don't know