diff --git a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
index 1aadbf25c..76b3fc8d1 100644
--- a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
+++ b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
@@ -84,7 +84,8 @@ def set_bit_widths(mixed_precision_enable: bool,
 
 def _get_node_qc_by_bit_widths(node: BaseNode,
                                bit_width_cfg: List[int],
-                               node_index_in_graph: int) -> Any:
+                               node_index_in_graph: int,
+                               fw_info) -> Any:
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg.
diff --git a/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py b/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py
index cd9fde61e..3a96191ef 100644
--- a/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py
+++ b/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py
@@ -95,13 +95,14 @@ def get_trainable_quantizer_quantization_candidates(n: BaseNode, attr: str = Non
     """
 
     if attr is not None:
-        # all candidates must have the same weights quantization method
-        weights_quantization_methods = set([cfg.weights_quantization_cfg.weights_quantization_method for cfg in n.candidates_quantization_cfg])
-        if len(weights_quantization_methods) > 1:
-            Logger.critical(f"Invalid 'candidates_quantization_cfg': Inconsistent weights "
-                            f"quantization methods detected: {weights_quantization_methods}. "
-                            f"Trainable quantizer requires all candidates to have the same weights "
-                            f"quantization method.")  # pragma: no cover
+        # all candidates must have the same weights quantization method
+        weights_quantization_methods = set([cfg.weights_quantization_cfg.get_attr_config(attr).weights_quantization_method
+                                            for cfg in n.candidates_quantization_cfg])
+        if len(weights_quantization_methods) > 1:
+            Logger.critical(f"Invalid 'candidates_quantization_cfg': Inconsistent weights "
+                            f"quantization methods detected: {weights_quantization_methods}. "
+                            f"Trainable quantizer requires all candidates to have the same weights "
+                            f"quantization method.")  # pragma: no cover
 
     # all candidates must have the same activation quantization method
     activation_quantization_methods = set([cfg.activation_quantization_cfg.activation_quantization_method
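
For context, the second hunk moves the consistency check from the node-level weights config to the per-attribute config returned by `get_attr_config(attr)`, so candidates are compared on the quantization method of the specific weight attribute being trained. Below is a minimal, self-contained sketch of that check pattern; the `Mock*` classes and the helper `check_attr_methods_consistent` are illustrative stand-ins, not MCT's actual classes.

```python
# Minimal sketch (not MCT code) of the per-attribute consistency check:
# each candidate exposes a per-attribute weights config via get_attr_config(attr),
# and all candidates must agree on the weights quantization method for that attribute.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class MockAttrWeightsCfg:
    weights_quantization_method: str  # e.g. "SYMMETRIC", "POWER_OF_TWO"
    weights_n_bits: int


@dataclass
class MockWeightsCfg:
    attrs: Dict[str, MockAttrWeightsCfg]

    def get_attr_config(self, attr: str) -> MockAttrWeightsCfg:
        # Mirrors the per-attribute lookup the diff switches to.
        return self.attrs[attr]


@dataclass
class MockCandidateCfg:
    weights_quantization_cfg: MockWeightsCfg


def check_attr_methods_consistent(candidates: List[MockCandidateCfg], attr: str) -> str:
    """Return the shared quantization method for `attr`, or raise if candidates disagree."""
    methods = {c.weights_quantization_cfg.get_attr_config(attr).weights_quantization_method
               for c in candidates}
    if len(methods) > 1:
        raise ValueError(f"Inconsistent weights quantization methods for '{attr}': {methods}")
    return methods.pop()


if __name__ == "__main__":
    # Two mixed-precision candidates that differ only in bit width, not in method.
    candidates = [
        MockCandidateCfg(MockWeightsCfg({"kernel": MockAttrWeightsCfg("SYMMETRIC", 8)})),
        MockCandidateCfg(MockWeightsCfg({"kernel": MockAttrWeightsCfg("SYMMETRIC", 4)})),
    ]
    print(check_attr_methods_consistent(candidates, "kernel"))  # -> SYMMETRIC
```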