From 29b9e35e8e3ddcf8f61356ac25e2853a7ec14b46 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 7 Nov 2023 11:06:14 +0000
Subject: [PATCH] Fix (llm): add checks for group dimensions

---
 src/brevitas_examples/llm/llm_quant/export.py     |  1 +
 src/brevitas_examples/llm/llm_quant/quantizers.py | 12 ++++++------
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py
index 2d99f17e9..e1da58822 100644
--- a/src/brevitas_examples/llm/llm_quant/export.py
+++ b/src/brevitas_examples/llm/llm_quant/export.py
@@ -118,6 +118,7 @@ def pack_int_weights(self, bit_width, int_weights):
         if bit_width == 8:
             return int_weights
         elif bit_width == 4 or bit_width == 2:
+            assert int_weights.shape[1] * bit_width % 8 == 0, "Number of columns multiplied by the bit-width must be a multiple of 8"
             packed_int_weights = torch.zeros(
                 (int_weights.shape[0], int_weights.shape[1] * bit_width // 8),
                 device=int_weights.device,
diff --git a/src/brevitas_examples/llm/llm_quant/quantizers.py b/src/brevitas_examples/llm/llm_quant/quantizers.py
index 28590a0e8..8c71f1447 100644
--- a/src/brevitas_examples/llm/llm_quant/quantizers.py
+++ b/src/brevitas_examples/llm/llm_quant/quantizers.py
@@ -32,22 +32,22 @@ class WeightSymmetricGroupQuantMixin(ExtendedInjector):
     @value
     def expanded_scaling_shape(module, block_size):
         if isinstance(module, nn.Conv2d):
-            return module.weight.size(0), module.weight.size(1) // block_size, block_size, module.weight.size(2), module.weight.size(3)
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, block_size, module.weight.size(2), module.weight.size(3)
         elif isinstance(module, nn.Linear):
-            return module.weight.size(0), module.weight.size(1) // block_size, block_size
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, block_size
         elif isinstance(module, nn.Embedding):
-            return module.weight.size(0), module.weight.size(1) // block_size, block_size
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, block_size
         else:
             raise RuntimeError("Module not supported.")
 
     @value
     def scaling_shape(module, block_size):
         if isinstance(module, nn.Conv2d):
-            return module.weight.size(0), module.weight.size(1) // block_size, 1, module.weight.size(2), module.weight.size(3)
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, 1, module.weight.size(2), module.weight.size(3)
         elif isinstance(module, nn.Linear):
-            return module.weight.size(0), module.weight.size(1) // block_size, 1
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, 1
         elif isinstance(module, nn.Embedding):
-            return module.weight.size(0), module.weight.size(1) // block_size, 1
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, 1
         else:
             raise RuntimeError("Module not supported.")
 
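
As an illustration of the arithmetic behind these checks, a minimal sketch in plain Python; the helpers check_packable and num_groups are hypothetical stand-ins for the touched code paths, not code from the patch or from Brevitas.

# Illustrative sketch of the arithmetic behind the new checks; check_packable
# and num_groups are hypothetical helpers, not functions from Brevitas.

def check_packable(num_cols: int, bit_width: int) -> int:
    # Same condition as the new assert in pack_int_weights: 2- or 4-bit values
    # can only be packed into uint8 if the columns fill whole bytes.
    assert num_cols * bit_width % 8 == 0, (
        "Number of columns multiplied by the bit-width must be a multiple of 8")
    return num_cols * bit_width // 8  # number of packed uint8 columns


def num_groups(num_cols: int, block_size: int) -> int:
    # Same ceiling division now used for the scaling shapes, so a column count
    # that is not an exact multiple of block_size still gets enough groups.
    return (num_cols + block_size - 1) // block_size


if __name__ == "__main__":
    print(check_packable(num_cols=4096, bit_width=4))  # 2048 packed columns
    print(num_groups(num_cols=4096, block_size=128))   # 32 groups
    print(num_groups(num_cols=4100, block_size=128))   # 33 groups (floor division gave 32)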