From 24918324a8f0e3848889f985ec03506cfed2bb72 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Fri, 13 Dec 2024 21:18:02 +0000
Subject: [PATCH 1/2] fix init_quant

---
 examples/quantize.py | 137 +++++++++++++++++++++++++++++++++++++++----
 1 file changed, 124 insertions(+), 13 deletions(-)

diff --git a/examples/quantize.py b/examples/quantize.py
index 923e9766..29308498 100644
--- a/examples/quantize.py
+++ b/examples/quantize.py
@@ -1,19 +1,130 @@
+import torch.nn as nn
+
 from awq import AutoAWQForCausalLM
-from transformers import AutoTokenizer
+from awq.utils.qwen_vl_utils import process_vision_info
+from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device
+
+# Specify paths and hyperparameters for quantization
+model_path = "mistral-community/pixtral-12b"
+quant_path = "pixtral-12b-awq"
+quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
+
+model = AutoAWQForCausalLM.from_pretrained(
+    model_path
+)
+# FIXME: hack to make pixtral work
+model.processor.tokenizer.pad_token_id = 11
+
+def print_module_devices(model):
+    for name, module in model.named_modules():
+        # Check parameters
+        param_devices = {
+            param_name: param.device
+            for param_name, param in module.named_parameters(recurse=False)
+        }
+
+        # Check buffers
+        buffer_devices = {
+            buffer_name: buffer.device
+            for buffer_name, buffer in module.named_buffers(recurse=False)
+        }
+
+        if param_devices or buffer_devices:
+            if param_devices:
+                for param_name, device in param_devices.items():
+                    print(f" {name} {param_name}: {device}")
+            if buffer_devices:
+                for buffer_name, device in buffer_devices.items():
+                    print(f" {name} {buffer_name}: {device}")
+
+
+# We define our own quantizer by extending the AwqQuantizer.
+# The main difference is in how the samples are processed when
+# the quantization process is initialized.
+class PixtralAwqQuantizer(AwqQuantizer):
+    def init_quant(self, n_samples=None, max_seq_len=None):
+        modules = self.awq_model.get_model_layers(self.model)
+        samples = self.calib_data
+
+        inps = []
+        layer_kwargs = {}
+
+        best_device = get_best_device()
+        modules[0] = modules[0].to(best_device)
+        self.awq_model.move_embed(self.model, best_device)
+
+        # FIXME: Hacky way to move the vision part to the right device
+        self.model.vision_tower = self.model.vision_tower.to(best_device)
+        self.model.multi_modal_projector = self.model.multi_modal_projector.to(best_device)
+
+        # get input and kwargs to layer 0
+        # with_kwargs is only supported in PyTorch 2.0
+        # use this Catcher hack for now
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+
+            def forward(self, *args, **kwargs):
+                # assume first input to forward is hidden states
+                if len(args) > 0:
+                    hidden_states = args[0]
+                    del args
+                else:
+                    first_key = list(kwargs.keys())[0]
+                    hidden_states = kwargs.pop(first_key)
+
+                inps.append(hidden_states)
+                layer_kwargs.update(kwargs)
+                raise ValueError  # early exit to break later inference
+
+        # patch layer 0 to catch input and kwargs
+        modules[0] = Catcher(modules[0])
+        print_module_devices(self.model)
+        try:
+            self.model(**samples.to(best_device))
+        except ValueError:  # work with early exit
+            pass
+        modules[0] = modules[0].module  # restore
+
+        del samples
+        inps = inps[0]
+
+        modules[0] = modules[0].cpu()
+        self.awq_model.move_embed(self.model, "cpu")
+
+        clear_memory()
+
+        return modules, layer_kwargs, inps
+
+def prepare_dataset(n_sample: int = 8) -> list[list[dict]]:
+    from datasets import load_dataset
 
-model_path = 'Qwen/Qwen2.5-14B-Instruct'
-quant_path = 'Qwen2.5-14B-Instruct-awq'
-quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
+    dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]")
+    return [
+        [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": sample["url"]},
+                    {"type": "text", "text": "generate a caption for this image"},
+                ],
+            },
+            {"role": "assistant", "content": sample["caption"]},
+        ]
+        for sample in dataset
+    ]
 
-# Load model
-model = AutoAWQForCausalLM.from_pretrained(model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+dataset = prepare_dataset()
 
-# Quantize
-model.quantize(tokenizer, quant_config=quant_config)
+# process the dataset into tensors
+text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True)
+image_inputs, video_inputs = process_vision_info(dataset)
+inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
 
-# Save quantized model
-model.save_quantized(quant_path)
-tokenizer.save_pretrained(quant_path)
+# Then just run the calibration process by one line of code
+model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=PixtralAwqQuantizer)
 
-print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
+# Save the model
+model.model.config.use_cache = model.model.generation_config.use_cache = True
+model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
\ No newline at end of file

From 710695ce069f5bcc57fee818be433cc05acfbf6d Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Fri, 13 Dec 2024 21:50:34 +0000
Subject: [PATCH 2/2] pivot to text only (pixtral_multimodal.py for keepers sake)

---
 examples/pixtral_multimodal.py | 130 +++++++++++++++++++++++++++++++++++++++++
 examples/quantize.py           | 116 +++++++++++------------------
 2 files changed, 174 insertions(+), 72 deletions(-)
 create mode 100644 examples/pixtral_multimodal.py

diff --git a/examples/pixtral_multimodal.py b/examples/pixtral_multimodal.py
new file mode 100644
index 00000000..29308498
--- /dev/null
+++ b/examples/pixtral_multimodal.py
@@ -0,0 +1,130 @@
+import torch.nn as nn
+
+from awq import AutoAWQForCausalLM
+from awq.utils.qwen_vl_utils import process_vision_info
+from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device
+
+# Specify paths and hyperparameters for quantization
+model_path = "mistral-community/pixtral-12b"
+quant_path = "pixtral-12b-awq"
+quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
+
+model = AutoAWQForCausalLM.from_pretrained(
+    model_path
+)
+# FIXME: hack to make pixtral work
+model.processor.tokenizer.pad_token_id = 11
+
+def print_module_devices(model):
+    for name, module in model.named_modules():
+        # Check parameters
+        param_devices = {
+            param_name: param.device
+            for param_name, param in module.named_parameters(recurse=False)
+        }
+
+        # Check buffers
+        buffer_devices = {
+            buffer_name: buffer.device
+            for buffer_name, buffer in module.named_buffers(recurse=False)
+        }
+
+        if param_devices or buffer_devices:
+            if param_devices:
+                for param_name, device in param_devices.items():
+                    print(f" {name} {param_name}: {device}")
+            if buffer_devices:
+                for buffer_name, device in buffer_devices.items():
+                    print(f" {name} {buffer_name}: {device}")
+
+
+# We define our own quantizer by extending the AwqQuantizer.
+# The main difference is in how the samples are processed when
+# the quantization process is initialized.
+class PixtralAwqQuantizer(AwqQuantizer):
+    def init_quant(self, n_samples=None, max_seq_len=None):
+        modules = self.awq_model.get_model_layers(self.model)
+        samples = self.calib_data
+
+        inps = []
+        layer_kwargs = {}
+
+        best_device = get_best_device()
+        modules[0] = modules[0].to(best_device)
+        self.awq_model.move_embed(self.model, best_device)
+
+        # FIXME: Hacky way to move the vision part to the right device
+        self.model.vision_tower = self.model.vision_tower.to(best_device)
+        self.model.multi_modal_projector = self.model.multi_modal_projector.to(best_device)
+
+        # get input and kwargs to layer 0
+        # with_kwargs is only supported in PyTorch 2.0
+        # use this Catcher hack for now
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+
+            def forward(self, *args, **kwargs):
+                # assume first input to forward is hidden states
+                if len(args) > 0:
+                    hidden_states = args[0]
+                    del args
+                else:
+                    first_key = list(kwargs.keys())[0]
+                    hidden_states = kwargs.pop(first_key)
+
+                inps.append(hidden_states)
+                layer_kwargs.update(kwargs)
+                raise ValueError  # early exit to break later inference
+
+        # patch layer 0 to catch input and kwargs
+        modules[0] = Catcher(modules[0])
+        print_module_devices(self.model)
+        try:
+            self.model(**samples.to(best_device))
+        except ValueError:  # work with early exit
+            pass
+        modules[0] = modules[0].module  # restore
+
+        del samples
+        inps = inps[0]
+
+        modules[0] = modules[0].cpu()
+        self.awq_model.move_embed(self.model, "cpu")
+
+        clear_memory()
+
+        return modules, layer_kwargs, inps
+
+def prepare_dataset(n_sample: int = 8) -> list[list[dict]]:
+    from datasets import load_dataset
+
+    dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]")
+    return [
+        [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": sample["url"]},
+                    {"type": "text", "text": "generate a caption for this image"},
+                ],
+            },
+            {"role": "assistant", "content": sample["caption"]},
+        ]
+        for sample in dataset
+    ]
+
+dataset = prepare_dataset()
+
+# process the dataset into tensors
+text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True)
+image_inputs, video_inputs = process_vision_info(dataset)
+inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
+
+# Then just run the calibration process by one line of code
+model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=PixtralAwqQuantizer)
+
+# Save the model
+model.model.config.use_cache = model.model.generation_config.use_cache = True
+model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
\ No newline at end of file
diff --git a/examples/quantize.py b/examples/quantize.py
index 29308498..d7065b73 100644
--- a/examples/quantize.py
+++ b/examples/quantize.py
@@ -1,50 +1,35 @@
+import torch
 import torch.nn as nn
+from transformers import AutoTokenizer
 
 from awq import AutoAWQForCausalLM
-from awq.utils.qwen_vl_utils import process_vision_info
-from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device
+from awq.quantize.quantizer import (
+    AwqQuantizer,
+    clear_memory,
+    get_best_device,
+    get_calib_dataset,
+)
 
-# Specify paths and hyperparameters for quantization
 model_path = "mistral-community/pixtral-12b"
 quant_path = "pixtral-12b-awq"
-quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
+quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path
-)
-# FIXME: hack to make pixtral work
-model.processor.tokenizer.pad_token_id = 11
-
-def print_module_devices(model):
-    for name, module in model.named_modules():
-        # Check parameters
-        param_devices = {
-            param_name: param.device
-            for param_name, param in module.named_parameters(recurse=False)
-        }
-
-        # Check buffers
-        buffer_devices = {
-            buffer_name: buffer.device
-            for buffer_name, buffer in module.named_buffers(recurse=False)
-        }
-
-        if param_devices or buffer_devices:
-            if param_devices:
-                for param_name, device in param_devices.items():
-                    print(f" {name} {param_name}: {device}")
-            if buffer_devices:
-                for buffer_name, device in buffer_devices.items():
-                    print(f" {name} {buffer_name}: {device}")
-
-
-# We define our own quantizer by extending the AwqQuantizer.
-# The main difference is in how the samples are processed when
-# the quantization process is initialized.
-class PixtralAwqQuantizer(AwqQuantizer):
-    def init_quant(self, n_samples=None, max_seq_len=None):
+# Load model
+model = AutoAWQForCausalLM.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+class PixtralTextQuantizer(AwqQuantizer):
+    def init_quant(self, n_samples=128, max_seq_len=512):
         modules = self.awq_model.get_model_layers(self.model)
-        samples = self.calib_data
+        samples = get_calib_dataset(
+            data=self.calib_data,
+            tokenizer=self.tokenizer,
+            n_samples=n_samples,
+            max_seq_len=max_seq_len,
+            split=self.split,
+            text_column=self.text_column,
+        )
+        samples = torch.cat(samples, dim=0)
 
         inps = []
         layer_kwargs = {}
@@ -80,13 +65,18 @@ def forward(self, *args, **kwargs):
 
         # patch layer 0 to catch input and kwargs
         modules[0] = Catcher(modules[0])
-        print_module_devices(self.model)
         try:
-            self.model(**samples.to(best_device))
+            self.model(samples.to(best_device))
         except ValueError:  # work with early exit
             pass
         modules[0] = modules[0].module  # restore
 
+        # Update the layer kwargs with the `prepare_inputs_for_generation` method
+        # that takes care of everything to avoid unexpected errors.
+        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
+        # Pop the input_ids as they are not needed at all.
+        layer_kwargs.pop("input_ids")
+
         del samples
         inps = inps[0]
 
@@ -95,36 +85,18 @@ def forward(self, *args, **kwargs):
 
         clear_memory()
 
+        if layer_kwargs.get("attention_mask") is not None:
+            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
+                best_device
+            )
+
         return modules, layer_kwargs, inps
 
-def prepare_dataset(n_sample: int = 8) -> list[list[dict]]:
-    from datasets import load_dataset
-
-    dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]")
-    return [
-        [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": sample["url"]},
-                    {"type": "text", "text": "generate a caption for this image"},
-                ],
-            },
-            {"role": "assistant", "content": sample["caption"]},
-        ]
-        for sample in dataset
-    ]
-
-dataset = prepare_dataset()
-
-# process the dataset into tensors
-text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True)
-image_inputs, video_inputs = process_vision_info(dataset)
-inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
-
-# Then just run the calibration process by one line of code
-model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=PixtralAwqQuantizer)
-
-# Save the model
-model.model.config.use_cache = model.model.generation_config.use_cache = True
-model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
\ No newline at end of file
+# Quantize
+model.quantize(tokenizer, quant_config=quant_config, quantizer_cls=PixtralTextQuantizer)
+
+# Save quantized model
+model.save_quantized(quant_path)
+tokenizer.save_pretrained(quant_path)
+
+print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
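
Not part of the patch series above: a minimal sketch of how the resulting checkpoint might be smoke-tested after applying the patches and running examples/quantize.py. It assumes AutoAWQ's standard from_quantized loading path and a text-only prompt; the script name, prompt, fuse_layers setting, and device handling are illustrative assumptions rather than anything taken from the patches.

# sanity_check_quantized.py (illustrative sketch; names and settings are assumptions)
import torch
from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM

quant_path = "pixtral-12b-awq"  # same output path used in examples/quantize.py

# Load the AWQ-quantized weights and the saved tokenizer
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

# Run a short text-only generation to confirm the quantized model still produces coherent output
prompt = "Briefly explain what AWQ quantization does."
inputs = tokenizer(prompt, return_tensors="pt").to(next(model.model.parameters()).device)

with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=64)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))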