From 8004832ae5dd47019e07cdaa32d1fd91a83cd58a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 17:12:53 +0000 Subject: [PATCH] Fix (runtime_act): fix negative group_dim handling --- src/brevitas/export/inference/handler.py | 4 ++-- src/brevitas/proxy/groupwise_float_parameter_quant.py | 2 +- src/brevitas/proxy/groupwise_float_runtime_quant.py | 2 +- src/brevitas/proxy/groupwise_int_parameter_quant.py | 2 +- src/brevitas/proxy/groupwise_int_runtime_quant.py | 2 +- src/brevitas/quant/solver/common.py | 2 +- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 2 +- src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 2 +- src/brevitas/utils/quant_utils.py | 5 ++--- src/brevitas_examples/common/generative/quantize.py | 4 ++-- 10 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 59944c2b0..fa17fcf70 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -131,7 +131,7 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) return output_args @@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) return output_args diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index 206e983b5..df586811a 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -30,7 +30,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor: diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 5d76e4635..f62662c71 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -23,7 +23,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor( diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 51ff97c28..6bffdba23 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -30,7 +30,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor: diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index 96d047808..48432cdc6 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -23,7 +23,7 @@ def group_size(self): def apply_input_view(self, x): x = super().apply_input_view(x) - start_dim = self.group_dim if self.group_dim != -1 else -2 + start_dim = start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1 return x.flatten(start_dim, start_dim + 1) def create_quant_tensor( diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 69b4c9438..90ebdd815 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -181,7 +181,7 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None): elif scaling_per_output == ScalingPerOutputType.TENSOR: return None elif scaling_per_output == ScalingPerOutputType.GROUP: - reduce_dim = group_dim + 1 if group_dim != -1 else -1 + reduce_dim = group_dim + 1 if group_dim > 0 else group_dim return reduce_dim @value diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 60c5ba84f..ba4f500e3 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -97,7 +97,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): - group_dim = group_dim if group_dim != -1 else -2 + group_dim = group_dim if group_dim > 0 else group_dim - 1 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index fa7e8438e..b0d8f1f48 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -83,7 +83,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): - group_dim = group_dim if group_dim != -1 else -2 + group_dim = group_dim if group_dim > 0 else group_dim - 1 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index d0d245089..39e4a8fd8 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -220,9 +220,8 @@ def float_to_int_impl_to_enum(module): def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_shape): - final_shape = dequant_shape curr_shape = value_.shape - start_dim = group_dim if group_dim != -1 else -2 + start_dim = group_dim if group_dim > 0 else group_dim - 1 new_value = value_.flatten(start_dim, start_dim + 1) if scale_.shape != (): new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) @@ -237,7 +236,7 @@ def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_sha # First, we compute how much we padded along the group_dim shape # Then, we unbind the tensor along the group_dim shape, and drop the padded columns # Finally, we stack the remaining tensors - unpadding_shape = final_shape[group_dim] + unpadding_shape = dequant_shape[group_dim] residual = new_value.shape[group_dim] - unpadding_shape if residual > 0: diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 778955285..83fdb8e5e 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -388,10 +388,10 @@ def generate_quantizers( elif input_quant_granularity == 'per_group': q_scaled_quant = sym_input_quant.let( **{ - 'group_dim': 2, 'group_size': input_group_size}) + 'group_dim': -1, 'group_size': input_group_size}) k_transposed_quant = sym_input_quant.let( **{ - 'group_dim': 1, 'group_size': input_group_size}) + 'group_dim': -2, 'group_size': input_group_size}) v_quant = q_scaled_quant attn_output_weights_quant = q_scaled_quant else: