From 423352276f0925e80c24ff235b949c33a32b71b6 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom Date: Fri, 28 Apr 2023 20:25:55 +0000 Subject: [PATCH] feat(sanity-check): implem version of gptq now added --- quantize/gptq/quant.py | 182 +++++++++++++++++++++++++++++ quantize/gptq/sanity_check_main.py | 102 +++++++++++++--- 2 files changed, 271 insertions(+), 13 deletions(-) diff --git a/quantize/gptq/quant.py b/quantize/gptq/quant.py index e301c521..00cb2819 100644 --- a/quantize/gptq/quant.py +++ b/quantize/gptq/quant.py @@ -147,6 +147,188 @@ def make_quant(module, names, bits, groupsize, name=''): for name1, child in module.named_children(): make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) +def make_quant_custom(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + + bias_name = attr.replace('w', 'b') + layer_name = attr.replace('w', 'quant') + setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None)) + + +class QuantLinear_custom(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): + super().__init__() + if bits not in [2,3,4,8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else infeatures + self.maxq = 2 ** self.bits - 1 + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2,4,8]: + self.register_buffer('wf',torch.tensor(list(range(0,32,self.bits)), dtype=torch.int32).unsqueeze(0),persistent=False) + elif self.bits == 3: + self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],], dtype=torch.int32).reshape(1,3,12), persistent=False) + + self.kernel_switch_threshold = kernel_switch_threshold + self.is_cuda = is_cuda + + def pack(self, weight, bias, scales, zeros, g_idx = None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if bias is not None: + self.bias = bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((weight[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight = torch.cat(intweight,dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32//self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32//self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold): + out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + if self.bits == 2: + quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + out = out.half() + else: + if self.bits in [2,4,8]: + zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4) + zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6) + zeros = zeros & 0x7 + zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1) + weight = (weight >> self.wf.unsqueeze(-1))&0x7 + weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4) + weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6) + weight = weight & 0x7 + weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx])) + out = torch.matmul(x.half(), weights) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + class QuantLinear(nn.Module): def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): super().__init__() diff --git a/quantize/gptq/sanity_check_main.py b/quantize/gptq/sanity_check_main.py index eec4114a..31803adf 100644 --- a/quantize/gptq/sanity_check_main.py +++ b/quantize/gptq/sanity_check_main.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn import torch.optim as optim +from collections import OrderedDict +import torch.nn.functional as F from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2 from gptq import * @@ -34,9 +36,8 @@ def load_quant(model, checkpoint, wbits, groupsize): # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) - for name in ["linear4"]: - if name in layers: - del layers[name] + if "linear4" in layers: + del layers["linear4"] make_quant(model, layers, wbits, groupsize) model.load_state_dict(torch.load(checkpoint)) @@ -258,8 +259,8 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False) ### begin GPTQ_CUSTOM def __init__(self, checkpoint_path): super().__init__() - self.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) - + self.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + def _fill_subset(self, layer_id): is_last_layer = (layer_id == self.nb_layers - 1) if is_last_layer: @@ -292,7 +293,7 @@ def fasterquant(self, layer_id, quantizers): print(layer_id, name) print('Quantizing ...') scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False) - quantizers[f"linear{layer_id + 1}"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) + quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) ## end GPTQ_CUSTOM @@ -301,6 +302,19 @@ def my_linear(self, x, weight, bias): out = x @ weight.weight + bias weight.add_batch(x) return out + + def forward(self, x): + if len(x.shape) == 4: + x = x.view(x.size(0), -1) + + residual = x + x = F.relu(self.linear0_quant(x)) + x = self.linear1_quant(x) + x = F.relu(x) + residual + x = self.linear2_quant(x) + x = F.relu(x) + residual + x = super().my_linear(x, self.linear3_w, self.linear3_b) + return x ## End SimpleNet_V2 @@ -321,9 +335,11 @@ def quantize_gptq_custom(model, train_loader): quantizers = {} for layer_id in range(nb_layers): - + if not is_last_layer(layer_id): - + + print(f"Quantizing layer {layer_id} ...") + model.alloc_gptq(layer_id) for i in range(nsamples): @@ -342,12 +358,56 @@ def quantize_gptq_custom(model, train_loader): return quantizers - def model_pack_custom(model, quantizers, wbits, groupsize): - pass + # Extract weights and bias from model + is_weight = re.compile(r'^linear\d+_w$') + weights, bias = OrderedDict(), OrderedDict() + for name, param in model.w.items(): + if is_weight.match(name): + weights[name] = param + else: + bias[name] = param + + make_quant_custom(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [QuantLinear_custom]) + + print('Packing ...') + for i in range(len(qlayers)): + name_w, name_b, layer_quant_name = f'linear{i}_w', f'linear{i}_b', f'linear{i}_quant' + quantizers[name_w],scale,zero,g_idx = quantizers[name_w] + qlayers[layer_quant_name].pack(weights[name_w], bias[name_b], scale, zero, g_idx) + print('Done.') + return model + +def load_quant_custom(model, checkpoint, wbits, groupsize): + print('Loading model ...') + model = model.eval() + # Extract weights and bias from model + is_weight = re.compile(r'^linear\d+_w$') + weights, bias = OrderedDict(), OrderedDict() + for name, param in model.w.items(): + if is_weight.match(name): + weights[name] = param + else: + bias[name] = param + + # Create linear layer out of weights and bias + layers = {} + for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()): + layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True) + layers[w_name].weight.data = w_param + layers[w_name].bias.data = b_param + + # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way) + # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification) + if "linear3_w" in layers: + del layers["linear3_w"] + + make_quant_custom(model, layers, wbits, groupsize) + model.load_state_dict(torch.load(checkpoint)) + print('Done.') + return model -def load_quant_custom(model, quantizers, wbits, groupsize): - pass def assert_parameters(model, model_custom): is_weight = re.compile(r'^linear\d+.weight$') @@ -371,6 +431,7 @@ def assert_parameters(model, model_custom): parser.add_argument("--eval_gptq", action="store_true") parser.add_argument("--train_custom", action="store_true") parser.add_argument("--gptq_custom", action="store_true") + parser.add_argument("--eval_gptq_custom", action="store_true") parser.add_argument("--pyquant", action="store_true") args = parser.parse_args() @@ -381,7 +442,9 @@ def assert_parameters(model, model_custom): criterion = nn.CrossEntropyLoss() train_loader, _, _ = MNISTloader(train_val_split=0.95).load() - #TODO: Do Custom packing + #TODO: Do custom eval gptq + #TODO: Is reference GPTQ quantizing bias as well ? + #TODO: Add seed everywhere in GPT for reproducibility ## ================== REFERENCE ================== if args.train: @@ -430,6 +493,19 @@ def assert_parameters(model, model_custom): model_pack_custom(model, quantizers, WBITS, GROUPSIZE) torch.save(model.state_dict(), "model_quantized_custom.pt") print("Done Custom GPTQ") + elif args.eval_gptq_custom: + model = GPTQ_CUSTOM("./model_custom.pt") + device = torch.device("cuda:0") + model = load_quant_custom(model, "model_quantized_custom.pt", WBITS, GROUPSIZE) + model = model.to(device) + + start = time.time() + val_loss, val_acc = evaluate(device, model, criterion, train_loader) + end = time.time() + + print(f"wbits = {WBITS} using {device}") + print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}") + print(f"Latency: {end - start}") ## ================== MISC ================== elif args.pyquant: # Baseline post-training quantization from Pytorch