Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use cuda for packing #407

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion auto_round/export/export_to_autogptq/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ def pack_layer(name, model, layer_config, backend, pbar):
# so far can only pack layer on CPU
qlayer.to("cpu")
##force to float32 to be compatible with torch 2.0
layer, scale, zero = layer.to("cpu"), scale.to("cpu"), zero.to("cpu").to(torch.float32)
if sym:
zero = 2**(bits-1)
layer, scale = layer.to("cpu"), scale.to("cpu")
else:
layer, scale, zero = layer.to("cpu"), scale.to("cpu"), zero.to("cpu").to(torch.float32)
sig = inspect.signature(qlayer.pack)
param_count = len(sig.parameters)
if param_count == 2:
Expand Down
86 changes: 72 additions & 14 deletions auto_round/export/export_to_autogptq/qlinear_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
def post_init(self):
pass

def pack(self, linear, scales, zeros, g_idx=None):
def pack_cpu(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
Expand Down Expand Up @@ -107,13 +107,10 @@ def pack(self, linear, scales, zeros, g_idx=None):
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
Expand All @@ -124,17 +121,78 @@ def pack(self, linear, scales, zeros, g_idx=None):
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)

def pack(self, linear, scales, zeros, g_idx):
if torch.cuda.is_available():
return self.pack_cuda(linear, scales, zeros, g_idx)
else:
return self.pack_cpu(linear, scales, zeros, g_idx)

def pack_cuda(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
scales_t = scales.t().contiguous()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
if linear.bias is not None:
self.bias = linear.bias.clone().half()
self.scales = scales_t.clone().half()

W = W.to("cuda")
repeat_scales = scales.to("cuda").repeat_interleave(self.group_size, 1)
if isinstance(zeros, torch.Tensor):
repeat_zeros = zeros.to("cuda").repeat_interleave(self.group_size, 1)
else:
repeat_zeros = zeros

intweight_cuda = torch.round(W / repeat_scales + repeat_zeros).to(torch.int).t().contiguous().to("cpu")
intweight = intweight_cuda.numpy().astype(np.uint32)

i = 0
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
while row < qweight.shape[0]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

if isinstance(zeros, torch.Tensor):
zeros = zeros.t().contiguous()
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
else:
zeros -= 1
shape = scales_t.shape
value = 0
for j in range(0, (32 // self.bits)):
value |= zeros << (self.bits * j)
qzeros = np.ones((shape[0], shape[1] // 32 * self.bits), dtype=np.uint32) * value
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)


__all__ = ["QuantLinear"]
Loading