Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Pixtral #681

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions examples/pixtral_multimodal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import torch.nn as nn

from awq import AutoAWQForCausalLM
from awq.utils.qwen_vl_utils import process_vision_info
from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device

# Specify paths and hyperparameters for quantization
model_path = "mistral-community/pixtral-12b"
quant_path = "pixtral-12b-awq"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(
model_path
)
# FIXME: hack to make pixtral work
model.processor.tokenizer.pad_token_id = 11

def print_module_devices(model):
for name, module in model.named_modules():
# Check parameters
param_devices = {
param_name: param.device
for param_name, param in module.named_parameters(recurse=False)
}

# Check buffers
buffer_devices = {
buffer_name: buffer.device
for buffer_name, buffer in module.named_buffers(recurse=False)
}

if param_devices or buffer_devices:
if param_devices:
for param_name, device in param_devices.items():
print(f" {name} {param_name}: {device}")
if buffer_devices:
for buffer_name, device in buffer_devices.items():
print(f" {name} {buffer_name}: {device}")


# We define our own quantizer by extending the AwqQuantizer.
# The main difference is in how the samples are processed when
# the quantization process initialized.
class PixtralAwqQuantizer(AwqQuantizer):
def init_quant(self, n_samples=None, max_seq_len=None):
modules = self.awq_model.get_model_layers(self.model)
samples = self.calib_data

inps = []
layer_kwargs = {}

best_device = get_best_device()
modules[0] = modules[0].to(best_device)
self.awq_model.move_embed(self.model, best_device)

# FIXME: Hacky way to move the vision part to the right device
self.model.vision_tower = self.model.vision_tower.to(best_device)
self.model.multi_modal_projector = self.model.multi_modal_projector.to(best_device)

# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, *args, **kwargs):
# assume first input to forward is hidden states
if len(args) > 0:
hidden_states = args[0]
del args
else:
first_key = list(kwargs.keys())[0]
hidden_states = kwargs.pop(first_key)

inps.append(hidden_states)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference

# patch layer 0 to catch input and kwargs
modules[0] = Catcher(modules[0])
print_module_devices(self.model)
try:
self.model(**samples.to(best_device))
except ValueError: # work with early exit
pass
modules[0] = modules[0].module # restore

del samples
inps = inps[0]

modules[0] = modules[0].cpu()
self.awq_model.move_embed(self.model, "cpu")

clear_memory()

return modules, layer_kwargs, inps

def prepare_dataset(n_sample: int = 8) -> list[list[dict]]:
from datasets import load_dataset

dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]")
return [
[
{
"role": "user",
"content": [
{"type": "image", "image": sample["url"]},
{"type": "text", "text": "generate a caption for this image"},
],
},
{"role": "assistant", "content": sample["caption"]},
]
for sample in dataset
]

dataset = prepare_dataset()

# process the dataset into tensors
text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(dataset)
inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

# Then just run the calibration process by one line of code
model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=PixtralAwqQuantizer)

# Save the model
model.model.config.use_cache = model.model.generation_config.use_cache = True
model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
91 changes: 87 additions & 4 deletions examples/quantize.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,99 @@
from awq import AutoAWQForCausalLM
import torch
import torch.nn as nn
from transformers import AutoTokenizer

model_path = 'Qwen/Qwen2.5-14B-Instruct'
quant_path = 'Qwen2.5-14B-Instruct-awq'
from awq import AutoAWQForCausalLM
from awq.quantize.quantizer import (
AwqQuantizer,
clear_memory,
get_best_device,
get_calib_dataset,
)

model_path = "mistral-community/pixtral-12b"
quant_path = "pixtral-12b-awq"
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

class PixtralTextQuantizer(AwqQuantizer):
def init_quant(self, n_samples=128, max_seq_len=512):
modules = self.awq_model.get_model_layers(self.model)
samples = get_calib_dataset(
data=self.calib_data,
tokenizer=self.tokenizer,
n_samples=n_samples,
max_seq_len=max_seq_len,
split=self.split,
text_column=self.text_column,
)
samples = torch.cat(samples, dim=0)

inps = []
layer_kwargs = {}

best_device = get_best_device()
modules[0] = modules[0].to(best_device)
self.awq_model.move_embed(self.model, best_device)

# FIXME: Hacky way to move the vision part to the right device
self.model.vision_tower = self.model.vision_tower.to(best_device)
self.model.multi_modal_projector = self.model.multi_modal_projector.to(best_device)

# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, *args, **kwargs):
# assume first input to forward is hidden states
if len(args) > 0:
hidden_states = args[0]
del args
else:
first_key = list(kwargs.keys())[0]
hidden_states = kwargs.pop(first_key)

inps.append(hidden_states)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference

# patch layer 0 to catch input and kwargs
modules[0] = Catcher(modules[0])
try:
self.model(samples.to(best_device))
except ValueError: # work with early exit
pass
modules[0] = modules[0].module # restore

# Update the layer kwargs with `prepare_inputs_for_generation` method
# that takes care of everything to avoid unexpected errors.
layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
# Pop the input_ids as they are not needed at all.
layer_kwargs.pop("input_ids")

del samples
inps = inps[0]

modules[0] = modules[0].cpu()
self.awq_model.move_embed(self.model, "cpu")

clear_memory()

if layer_kwargs.get("attention_mask") is not None:
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
best_device
)

return modules, layer_kwargs, inps

# Quantize
model.quantize(tokenizer, quant_config=quant_config)
model.quantize(tokenizer, quant_config=quant_config, quantizer_cls=PixtralTextQuantizer)

# Save quantized model
model.save_quantized(quant_path)
Expand Down