Commit 57079e7
feat(quantize): full gptq pipeline now integrated with RWKV (quite slow for some layers + needs tests)
3outeille committed Apr 25, 2023
1 parent dba2670 commit 57079e7
Showing 1 changed file with 47 additions and 29 deletions.
76 changes: 47 additions & 29 deletions quantize/tmp_rwkv.py
@@ -27,8 +27,14 @@ def __init__(self, weight, name):
self.columns = self.weight.shape[0]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
self.deactivate_add_batch_call = False

def add_batch(self, inp):

# After calling fasterquant, we don't want to call add_batch anymore
if self.deactivate_add_batch_call:
return

if len(inp.shape) == 2:
inp = inp.unsqueeze(0)

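Context for the new deactivate_add_batch_call flag: in GPTQ, add_batch folds each calibration batch into the running Hessian estimate H built from the layer's input activations, and fasterquant later consumes that H, so forward passes executed after quantization must not keep updating it. The helper below is a minimal, self-contained sketch of that standard update, modeled on the reference GPTQ implementation; the elided body of add_batch in this file may differ in its details.

import math

import torch

def hessian_update(H: torch.Tensor, nsamples: int, inp: torch.Tensor):
    # Fold one calibration batch into the running estimate H (proportional to the sum of x x^T).
    if inp.dim() == 2:
        inp = inp.unsqueeze(0)
    batch = inp.shape[0]
    x = inp.reshape(-1, inp.shape[-1]).t().float()  # (columns, tokens)
    H = H * (nsamples / (nsamples + batch))         # downweight the old estimate
    nsamples += batch
    x = math.sqrt(2 / nsamples) * x
    return H + x @ x.t(), nsamples
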
@@ -253,30 +259,20 @@ def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr
return x + out, xx

def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry):
# x = (2048, 768)
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
# xx = (2048, 768)
sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
# sx = (2048, 768)
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
# kx = (2048, 768)
# rx = (2048, 768)

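# Each GPTQ-wrapped matrix records the exact activation fed to its matmul (rx -> rw, kx -> kw, vx -> vw)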
r = torch.sigmoid(rx @ rw.weight)
# r = (2048, 768)
rw.add_batch(rx)
vx = torch.square(torch.relu(kx @ kw.weight))
# vx = (2048, 3072)
# kx: (2048, 768)
# kw.weight: (768, 3072)
# vx: (2048, 3072)
kw.add_batch(kx)
out = r * (vx @ vw.weight)
vw.add_batch(vx)
return x + out, xx[-1,:]

def forward_block(self, x, state, i, seq_mode, full_output=False):
def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False):
with torch.no_grad():
args = self.args

@@ -344,6 +340,11 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
omx=omx, orx=orx, omy=omy, ory=ory,
)

kw.deactivate_add_batch_call = True
vw.deactivate_add_batch_call = True
rw.deactivate_add_batch_call = True
ow.deactivate_add_batch_call = True

if dd.stream:
del kw, vw, rw, ow
@@ -378,28 +379,34 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
)

# Deactivate add_batch() after quantization is applied
kw.deactivate_add_batch_call = True
vw.deactivate_add_batch_call = True
rw.deactivate_add_batch_call = True

if dd.stream:
del kw, vw, rw

if self.RESCALE_LAYER > 0:
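# Periodic rescaling: halve activations every RESCALE_LAYER blocks so magnitudes stay bounded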
if (i+1) % self.RESCALE_LAYER == 0:
x = x / 2

dd = self.strategy[args.n_layer]
x = x[-1,:] if (seq_mode and (not full_output)) else x
x = x.to(dtype=dd.atype, device=dd.device)

#TODO: Add GPTQ support for head & ln_out
x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
if self.w['head.weight'].dtype != torch.uint8:
x = x @ self.w['head.weight']
else:
if seq_mode and full_output:
x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
if is_last_layer:
dd = self.strategy[args.n_layer]
x = x[-1,:] if (seq_mode and (not full_output)) else x
x = x.to(dtype=dd.atype, device=dd.device)

#TODO: Add GPTQ support for head & ln_out
x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
if self.w['head.weight'].dtype != torch.uint8:
x = x @ self.w['head.weight']
else:
x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
if seq_mode and full_output:
x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
else:
x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])

return x.float(), state
return x.float()

### end RWKV

@@ -420,11 +427,13 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
print("tokens.shape", tokens.shape)

model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
is_last_layer = [False] * (model.args.n_layer - 1) + [True]

#TODO: Do the same in GPU side
with torch.no_grad():
seq_mode = len(tokens) > 1
x = model.w['emb.weight'][tokens if seq_mode else tokens[0]]
inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]]
outs = torch.zeros_like(inps)
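# inps feeds block i and outs collects its outputs; the two buffers are swapped after each block (see the end of the loop)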

quantizers = {}

@@ -433,14 +442,23 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
model.alloc_gptq(layer_id)

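# Pass 1: run the calibration samples through block layer_id; the GPTQ wrappers record activations via add_batch()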
for j in range(NSAMPLES):
_ = model.forward_block(x[j], state=None, i=layer_id, seq_mode=seq_mode)

if not is_last_layer[layer_id]:
outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
else:
_ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])

model.fasterquant(layer_id, quantizers)

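# Pass 2: re-run the samples through the now-quantized block (add_batch is deactivated) to produce the inputs for the next block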
for j in range(NSAMPLES):
if not is_last_layer[layer_id]:
outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
else:
_ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
model.free_gptq()

#TODO: Since we quantize per block, we should pass the outputs of block 0 to input of block 1 ?
# inps, outs = outs, inps
if not is_last_layer[layer_id]:
# We need to pass the outputs of block i as input of block i+1 (except for last block)
inps, outs = outs, inps

# TODO: create a function that checks if all weights were properly quantized
print("Done")
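Regarding the TODO above, a possible shape for that check is sketched below. It is hypothetical: it assumes fasterquant registers each quantized matrix in quantizers under its checkpoint key (e.g. blocks.{i}.att.key.weight), which this commit does not confirm, and it only verifies that every expected projection has an entry rather than inspecting the weights themselves.

def check_all_quantized(model, quantizers):
    # Hypothetical helper for the TODO above: confirm every att/ffn projection
    # of every block ended up with an entry in `quantizers`.
    suffixes = (
        'att.key.weight', 'att.value.weight', 'att.receptance.weight',
        'att.output.weight', 'ffn.key.weight', 'ffn.value.weight',
        'ffn.receptance.weight',
    )
    expected = [f'blocks.{i}.{s}' for i in range(model.args.n_layer) for s in suffixes]
    missing = [name for name in expected if name not in quantizers]
    for name in missing:
        print('not quantized:', name)
    return not missing

Called as check_all_quantized(model, quantizers) just before the final print, it would turn the TODO into an explicit pass/fail step.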
