[Major] add mixtral8x7b support
cylinbao committed Apr 11, 2024
1 parent 97647be commit 7dafad0
Showing 5 changed files with 117 additions and 202 deletions.
8 changes: 4 additions & 4 deletions model/main.py
@@ -6,7 +6,7 @@
from pprint import pprint
from modelutils_llama import quantize_model_llama, reorder_model_llama, quantize_model_gptq_llama, add_act_quant_wrapper_llama
from modelutils_opt import quantize_model_opt, reorder_model_opt, quantize_model_gptq_opt, add_act_quant_wrapper_opt
from modelutils_mixtral import quantize_model_mixtral, add_act_quant_wrapper_mixtral
from modelutils_mixtral import quantize_model_mixtral, add_act_quant_wrapper_mixtral, reorder_model_mixtral
from parallel_utils import map_layers_to_multi_gpus
from LMClass import LMClass
from eval import pattern_match
@@ -211,10 +211,10 @@ def skip(*args, **kwargs):
eval_func = opt_eval
elif "mixtral" in args.model.lower():
model = get_mixtral(args.model)
# get_act_stats_func = get_act_stats_mixtral
# reorder_model_func = reorder_model_mixtral
get_act_stats_func = get_act_stats_llama
reorder_model_func = reorder_model_mixtral
add_act_quant_wrapper_func = add_act_quant_wrapper_mixtral
# quantize_model_gptq_func = quantize_model_gptq_mixtral
quantize_model_gptq_func = quantize_model_gptq_llama
quantize_model_func = quantize_model_mixtral
eval_func = llama_eval
model.eval()
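For orientation, here is a minimal sketch of how the helpers resolved above might be chained for a Mixtral checkpoint. It illustrates the implied pipeline order only, not the actual control flow in main.py; DEV, reorder_index and act_scales are placeholder names for the device and the activation statistics the script computes elsewhere.

# Hypothetical driver sketch; signatures match the functions touched by this commit.
model = get_mixtral(args.model)
reorder_model_func = reorder_model_mixtral                 # new in this commit
add_act_quant_wrapper_func = add_act_quant_wrapper_mixtral
quantize_model_func = quantize_model_mixtral

model = reorder_model_func(model, DEV, args, reorder_index)       # placeholder reorder_index
model = add_act_quant_wrapper_func(model, DEV, args, act_scales)  # placeholder act_scales
model = quantize_model_func(model, DEV, args)
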
16 changes: 9 additions & 7 deletions model/modelutils_llama.py
@@ -3,8 +3,10 @@
import torch.nn as nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
from qLinearLayer import find_qlinear_layers
from qLlamaLayer import QLlamaDecoderLayer
from qMixtralLayer import QMixtralDecoderLayer
from gptq import GPTQ, Quantizer_GPTQ
from functools import partial

@@ -44,15 +46,12 @@ def reorder_model_llama(model, device, args, reorder_index):
# Not reordered due to the RoPE embedding.
m.self_attn.q_proj.reorder(
in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'q_proj', 'input')],
# out_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'output')]
out_reorder_index=None
)
m.self_attn.k_proj.reorder(
in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')],
# out_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'output')]
out_reorder_index=None
)

m.self_attn.v_proj.reorder(
in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'v_proj', 'input')],
out_reorder_index=None
@@ -201,16 +200,19 @@ def forward(self, inp, **kwargs):

quantizers = {}
for i in tqdm(range(len(layers))):
m = None
if isinstance(layers[i], LlamaDecoderLayer):
m = QLlamaDecoderLayer(
originalLayer=layers[i],
args=args,
)
elif isinstance(layers[i], QLlamaDecoderLayer):
elif isinstance(layers[i], MixtralDecoderLayer):
m = QMixtralDecoderLayer(
originalLayer=layers[i],
args=args,
)
elif isinstance(layers[i], QLlamaDecoderLayer) or isinstance(layers[i], QMixtralDecoderLayer):
m = layers[i]

if m is None:
else:
continue

layer = m.to(device)
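As a side note, the per-layer dispatch that quantize_model_gptq_llama now performs can be summarized by the hypothetical helper below (not part of the repo): wrap plain Hugging Face decoder layers, pass through layers that are already wrapped, and return None for anything else so the caller skips them.

# Hypothetical helper, equivalent to the isinstance chain above.
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
from qLlamaLayer import QLlamaDecoderLayer
from qMixtralLayer import QMixtralDecoderLayer

def wrap_decoder_layer(layer, args):
    if isinstance(layer, LlamaDecoderLayer):
        return QLlamaDecoderLayer(originalLayer=layer, args=args)
    if isinstance(layer, MixtralDecoderLayer):
        return QMixtralDecoderLayer(originalLayer=layer, args=args)
    if isinstance(layer, (QLlamaDecoderLayer, QMixtralDecoderLayer)):
        return layer
    return None  # caller does `continue` on None
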
267 changes: 83 additions & 184 deletions model/modelutils_mixtral.py
@@ -10,72 +10,86 @@

from quant import quantize_activation_wrapper, quantize_attn_v_wrapper, quantize_attn_k_wrapper

# def reorder_model_llama(model, device, args, reorder_index):
# model.config.use_cache = False
# layers = model.model.layers
# assert reorder_index is not None, "Reorder index is None"
#
#
# for i in tqdm(range(len(layers))):
# layers[i] = layers[i].to(device)
# layers[i] = layers[i].to(device)
# if isinstance(layers[i], LlamaDecoderLayer):
# m = QLlamaDecoderLayer(
# originalLayer=layers[i],
# args=args,
# )
# elif isinstance(layers[i], QLlamaDecoderLayer):
# m = layers[i]
#
# nameTemplate = 'layers.{}.{}.{}.{}' # Something like layers.10.self_attn.q_proj
#
# m.mlp.gate_proj.reorder(
# in_reorder_index=reorder_index[nameTemplate.format(i, 'mlp', 'gate_proj', 'input')],
# out_reorder_index=reorder_index[nameTemplate.format(i, 'mlp', 'down_proj', 'input')]
# )
# m.mlp.up_proj.reorder(
# in_reorder_index=reorder_index[nameTemplate.format(i, 'mlp', 'up_proj', 'input')],
# out_reorder_index=reorder_index[nameTemplate.format(i, 'mlp', 'down_proj', 'input')]
# )
# m.mlp.down_proj.reorder(
# in_reorder_index=reorder_index[nameTemplate.format(i, 'mlp', 'down_proj', 'input')],
# out_reorder_index=None
# )
# # K has outlier should be kept.
# # Not reorder due to the RoPE embedding.
# m.self_attn.q_proj.reorder(
# in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'q_proj', 'input')],
# # out_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'output')]
# out_reorder_index=None
# )
# m.self_attn.k_proj.reorder(
# in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')],
# # out_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'output')]
# out_reorder_index=None
# )
#
# m.self_attn.v_proj.reorder(
# in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'v_proj', 'input')],
# out_reorder_index=None
# )
# m.self_attn.o_proj.reorder(
# in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'o_proj', 'input')],
# out_reorder_index=None
# )
# m.input_layernorm.register_buffer('reorder_index',
# reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')] # Random choose one from k,q,v proj.
# )
# m.post_attention_layernorm.register_buffer('reorder_index',
# reorder_index[nameTemplate.format(i, 'mlp', 'gate_proj', 'input')]
# )
# m.self_attn.register_buffer('reorder_index', reorder_index[nameTemplate.format(i, 'self_attn', 'o_proj', 'input')])
#
# layers[i] = layers[i].cpu()
# layers[i] = m.cpu()
# del m
# torch.cuda.empty_cache()
# return model
#
def reorder_model_mixtral(model, device, args, reorder_index):
model.config.use_cache = False
layers = model.model.layers
assert reorder_index is not None, "Reorder index is None"


for i in tqdm(range(len(layers))):
layers[i] = layers[i].to(device)
layers[i] = layers[i].to(device)
if isinstance(layers[i], MixtralDecoderLayer):
m = QMixtralDecoderLayer(
originalLayer=layers[i],
args=args,
)
elif isinstance(layers[i], QMixtralDecoderLayer):
m = layers[i]

# Reordering for the attention block
nameTemplate = 'layers.{}.{}.{}.{}' # e.g. layers.10.self_attn.q_proj.input

m.input_layernorm.register_buffer('reorder_index',
reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')] # Arbitrarily pick one of the k/q/v projections; they share the same input.
)

# K has outliers that should be kept.
# Not reordered due to the RoPE embedding.
m.self_attn.q_proj.reorder(
in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')],
out_reorder_index=None
)
m.self_attn.k_proj.reorder(
in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')],
out_reorder_index=None
)

m.self_attn.v_proj.reorder(
in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')],
out_reorder_index=None
)
m.self_attn.o_proj.reorder(
in_reorder_index=reorder_index[nameTemplate.format(i, 'self_attn', 'o_proj', 'input')],
out_reorder_index=None
)

m.self_attn.register_buffer('reorder_index', reorder_index[nameTemplate.format(i, 'self_attn', 'o_proj', 'input')])

# Reordering for the MoE block
nameTemplate_moe = 'layers.{}.{}.{}.{}.{}.{}' # e.g. layers.10.block_sparse_moe.experts.0.w1.input

# Pick expert 0's w1/w2 ordering and apply it to all related modules
m.block_sparse_moe.gate.reorder(
in_reorder_index=reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w1', 'input')],
out_reorder_index=None
)

num_experts = m.block_sparse_moe.num_experts
for j in range(num_experts):
m.block_sparse_moe.experts[j].w1.reorder(
in_reorder_index=reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w1', 'input')],
out_reorder_index=reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w2', 'input')]
)
m.block_sparse_moe.experts[j].w3.reorder(
in_reorder_index=reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w1', 'input')],
out_reorder_index=reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w2', 'input')]
)
m.block_sparse_moe.experts[j].w2.reorder(
in_reorder_index=reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w2', 'input')],
out_reorder_index=None
)

m.post_attention_layernorm.register_buffer('reorder_index',
reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w1', 'input')],
)

layers[i] = layers[i].cpu()
layers[i] = m.cpu()
del m
torch.cuda.empty_cache()
return model


def add_act_quant_wrapper_mixtral(model, device, args, scales):
model.config.use_cache = False
@@ -100,8 +114,8 @@ def add_act_quant_wrapper_mixtral(model, device, args, scales):
for expert in m.block_sparse_moe.experts:
expert.act_quant = partial(quantize_activation_wrapper, args=args)

m.input_layernorm.act_quant = partial(quantize_activation_wrapper, args=args)
m.post_attention_layernorm.act_quant = partial(quantize_activation_wrapper, args=args)
m.act_quant = partial(quantize_activation_wrapper, args=args)
m.block_sparse_moe.act_quant = partial(quantize_activation_wrapper, args=args)

layers[i] = m.cpu()
torch.cuda.empty_cache()
@@ -132,119 +146,4 @@ def quantize_model_mixtral(model, device, args):

layers[i] = m.cpu()
torch.cuda.empty_cache()
return model

# def quantize_model_gptq_llama(model, device, args, dataloader):
# print('Starting GPTQ quantization ...')
#
# use_cache = model.config.use_cache
# model.config.use_cache = False
# layers = model.model.layers
#
# model.model.embed_tokens = model.model.embed_tokens.to(device)
# model.model.norm = model.model.norm.to(device)
# layers[0] = layers[0].to(device)
#
# dtype = next(iter(model.parameters())).dtype
# inps = torch.zeros(
# (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device
# )
#
# cache = {'i': 0, 'attention_mask': None}
#
# class Catcher(nn.Module):
# def __init__(self, module):
# super().__init__()
# self.module = module
# def forward(self, inp, **kwargs):
# inps[cache['i']] = inp
# cache['i'] += 1
# cache['attention_mask'] = kwargs['attention_mask']
# cache['position_ids'] = kwargs['position_ids']
# raise ValueError
#
# layers[0] = Catcher(layers[0])
# for batch in dataloader:
# try:
# model(batch[0].to(device))
# except ValueError:
# pass
# layers[0] = layers[0].module
# layers[0] = layers[0].cpu()
# model.model.embed_tokens = model.model.embed_tokens.cpu()
# model.model.norm = model.model.norm.cpu()
# torch.cuda.empty_cache()
#
# outs = torch.zeros_like(inps)
# attention_mask = cache['attention_mask']
# position_ids = cache['position_ids']
#
# quantizers = {}
# for i in tqdm(range(len(layers))):
# m = None
# if isinstance(layers[i], LlamaDecoderLayer):
# m = QLlamaDecoderLayer(
# originalLayer=layers[i],
# args=args,
# )
# elif isinstance(layers[i], QLlamaDecoderLayer):
# m = layers[i]
#
# if m is None:
# continue
#
# layer = m.to(device)
#
# block_layers = find_qlinear_layers(layer)
#
# sequential = [list(block_layers.keys())]
#
# for names in sequential:
# subset = {n: block_layers[n] for n in names}
#
# gptq = {}
# for name in subset:
# gptq[name] = GPTQ(
# subset[name], n_out=args.keeper, keeper_precision=args.keeper_precision
# )
# gptq[name].quantizer = Quantizer_GPTQ()
# gptq[name].quantizer.configure(
# args.wbits, perchannel=True, sym=args.w_sym, mse=False,
# channel_group=args.weight_channel_group,
# clip_ratio=args.w_clip_ratio
# )
#
# def add_batch(name):
# def tmp(_, inp, out):
# gptq[name].add_batch(inp[0].data, out.data)
# return tmp
#
# handles = []
# for name in subset:
# handles.append(subset[name].register_forward_hook(add_batch(name)))
# for j in range(args.nsamples):
# layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
# for h in handles:
# h.remove()
#
# for name in subset:
# gptq[name].fasterquant(
# percdamp=args.percdamp, groupsize=args.weight_group_size
# )
# quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer.cpu()
# gptq[name].free()
#
# del gptq
#
# for j in range(args.nsamples):
# outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
#
# layers[i] = layer.cpu()
# del layer, m
# torch.cuda.empty_cache()
# gc.collect()
#
# inps, outs = outs, inps
#
# model.config.use_cache = use_cache
# return model
return model
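
To summarize the MoE reordering introduced above: every expert in a block reuses expert 0's channel ordering, so the router gate and the post-attention layernorm need only one reorder index per layer. The fragment below condenses the loop from reorder_model_mixtral; m, i and reorder_index refer to the same variables as in that function.

# Condensed view of the per-layer MoE reordering (same logic as the loop above).
w1_idx = reorder_index['layers.{}.block_sparse_moe.experts.0.w1.input'.format(i)]
w2_idx = reorder_index['layers.{}.block_sparse_moe.experts.0.w2.input'.format(i)]

m.block_sparse_moe.gate.reorder(in_reorder_index=w1_idx, out_reorder_index=None)
for expert in m.block_sparse_moe.experts:
    expert.w1.reorder(in_reorder_index=w1_idx, out_reorder_index=w2_idx)  # reads the block input
    expert.w3.reorder(in_reorder_index=w1_idx, out_reorder_index=w2_idx)  # reads the block input
    expert.w2.reorder(in_reorder_index=w2_idx, out_reorder_index=None)    # reads the intermediate (w1*w3) activation
m.post_attention_layernorm.register_buffer('reorder_index', w1_idx)
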
7 changes: 5 additions & 2 deletions model/qLinearLayer.py
@@ -4,7 +4,8 @@

def find_qlinear_layers(module, name=''):
if type(module) == QLinearLayer:
return {name: module}
if module.enable_quant:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_qlinear_layers(
@@ -16,11 +17,13 @@ class QLinearLayer(nn.Module):
def __init__(
self,
originalLayer: nn.Linear,
args
args,
enable_quant: bool = True
):
super().__init__()
self.args = args
self.register_buffer('weight', originalLayer.weight)
self.enable_quant = enable_quant # whether to allow quant on weights, default True
if originalLayer.bias is not None:
self.register_buffer('bias', originalLayer.bias)
else:
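The new enable_quant flag lets a wrapped linear layer opt out of weight quantization: find_qlinear_layers now only returns layers with the flag set, so the GPTQ loop in modelutils_llama.py never touches the others. Below is a hypothetical illustration; moe and args are placeholders, and the actual call sites for enable_quant=False presumably live in qMixtralLayer.py, which is not included in this excerpt.

# Hypothetical usage of the new flag (not taken from the repo).
from qLinearLayer import QLinearLayer, find_qlinear_layers

moe.gate = QLinearLayer(originalLayer=moe.gate, args=args, enable_quant=False)
assert moe.gate not in find_qlinear_layers(moe).values()  # the gate stays in full precision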