diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index b7e35c0d..04f8342a 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -109,6 +109,7 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: dq_init = model.get_initializer(dq_inp) dq_scale_v = model.get_initializer(dq_scale) dq_zeropt_v = model.get_initializer(dq_zeropt) + axis = get_by_name(dq_node.attribute, "axis") if quant_candidates is None and dq_init is None: continue if any([x is None for x in [dq_scale_v, dq_zeropt_v]]): @@ -123,13 +124,24 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: # read quantized weight dtype for standalone deqnt q_vi = model.get_tensor_valueinfo(dq_inp) (bitwidth, signed, narrow) = extract_elem_type(q_vi.type.tensor_type.elem_type) + scale_factor, zeropt = dq_scale, dq_zeropt + # fix scale factor for Quant (different shape expectations wrt broadcasting) + if not (axis is None): + axis_i = axis.i + ishape = model.get_tensor_shape(dq_inp) + desired_shp = [1] * len(ishape) + desired_shp[axis_i] = dq_scale_v.shape[0] + dq_scale_v = dq_scale_v.reshape(desired_shp) + dq_zeropt_v = dq_zeropt_v.reshape(desired_shp) + model.set_initializer(scale_factor, dq_scale_v) + model.set_initializer(zeropt, dq_zeropt_v) # overwrite DQ initializer with scaled version scaled_qnt_t = (dq_init - dq_zeropt_v) * dq_scale_v scaled_qnt_t = scaled_qnt_t.astype(np.float32) model.set_initializer(dq_inp, scaled_qnt_t) q_inp = dq_inp final_out = dq_node.output[0] - scale_factor, zeropt = dq_scale, dq_zeropt + nodes_to_remove.append(dq_node) elif quant_candidates[0].op_type in ["QuantizeLinear", "Clip"]: clip_range = None @@ -167,20 +179,19 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: value_info = model.get_tensor_valueinfo(quant_node.output[0]) (bitwidth, signed, narrow) = extract_elem_type(value_info.type.tensor_type.elem_type, clip_range) scale_factor, zeropt = q_scale, q_zeropt + if not (axis is None): + axis_i = axis.i + ishape = model.get_tensor_shape(dq_inp) + desired_shp = [1] * len(ishape) + desired_shp[axis_i] = dq_scale_v.shape[0] + dq_scale_v = dq_scale_v.reshape(desired_shp) + dq_zeropt_v = dq_zeropt_v.reshape(desired_shp) + model.set_initializer(scale_factor, dq_scale_v) + model.set_initializer(zeropt, dq_zeropt_v) else: # handle all other cases, skip continue - axis = get_by_name(dq_node.attribute, "axis") - # fix scale factor for Quant (different shape expectations wrt broadcasting) - if not (axis is None): - axis_i = axis.i - ishape = model.get_tensor_shape(dq_inp) - desired_shp = [1] * len(ishape) - desired_shp[axis_i] = dq_scale_v.shape[0] - dq_scale_v = dq_scale_v.reshape(desired_shp) - dq_zeropt_v = dq_zeropt_v.reshape(desired_shp) - model.set_initializer(scale_factor, dq_scale_v) - model.set_initializer(zeropt, dq_zeropt_v) + # create new Quant node for suitable cases new_q_node_name = "Quant_" + q_inp bw_tensor_name = f"{new_q_node_name}_bitwidth"