From 6e30991e97ba5403c21442296052023e3edf7315 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 23 Jan 2025 21:00:11 +0800
Subject: [PATCH] FEAT Add gptqmodel support (#2247)

Add support for gptqmodel quantization. This is a replacement for auto-gptq. For now, both packages are supported, but since auto-gptq is no longer being developed, it will be deprecated and removed at some point in the future.

---------

Signed-off-by: jiqing-feng
Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com>
Co-authored-by: Qubitium-ModelCloud
Co-authored-by: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Co-authored-by: LRL
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
---
 Makefile                                     |   1 +
 docs/source/developer_guides/quantization.md |  26 ++
 src/peft/import_utils.py                     |  27 ++
 src/peft/tuners/adalora/model.py             |  16 +-
 src/peft/tuners/lora/gptq.py                 |  14 +-
 src/peft/tuners/lora/model.py                |   3 +-
 src/peft/utils/__init__.py                   |   2 +
 src/peft/utils/other.py                      |  89 +++--
 tests/test_common_gpu.py                     |   6 +-
 tests/test_gptqmodel.py                      | 349 +++++++++++++++++++
 tests/testing_utils.py                       |  12 +-
 11 files changed, 508 insertions(+), 37 deletions(-)
 create mode 100644 tests/test_gptqmodel.py

diff --git a/Makefile b/Makefile
index 2052bc5465..0d93db3229 100644
--- a/Makefile
+++ b/Makefile
@@ -34,6 +34,7 @@ tests_core_single_gpu:
 tests_common_gpu:
 	python -m pytest tests/test_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",)
 	python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",)
+	python -m pytest tests/test_gptqmodel.py $(if $(IS_GITHUB_CI),--report-log "gptqmodel_gpu.log",)
 
 tests_examples_multi_gpu_bnb:
 	python -m pytest -m "multi_gpu_tests and bitsandbytes" tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)
diff --git a/docs/source/developer_guides/quantization.md b/docs/source/developer_guides/quantization.md
index 1d0271ba90..5067156d8b 100644
--- a/docs/source/developer_guides/quantization.md
+++ b/docs/source/developer_guides/quantization.md
@@ -107,6 +107,32 @@ QLoRA adds trainable weights to all the linear layers in the transformer archite
 config = LoraConfig(target_modules="all-linear", ...)
 ```
 
+## GPTQ quantization
+
+You can learn more about GPTQ-based 2-, 3-, 4-, and 8-bit quantization in the [GPTQModel](https://github.com/ModelCloud/GPTQModel) repository and the Transformers [GPTQ](https://huggingface.co/docs/transformers/quantization/gptq) docs. For training on top of an already quantized model, PEFT supports both the [GPTQModel](https://github.com/ModelCloud/GPTQModel) and [AutoGPTQ](https://github.com/autogptq/autogptq) libraries, but we recommend GPTQModel because AutoGPTQ will be deprecated in a future release.
+
+```bash
+# gptqmodel install
+pip install gptqmodel --no-build-isolation
+```
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
+model_id = "facebook/opt-125m"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+gptq_config = GPTQConfig(bits=4, group_size=128, dataset="wikitext2", tokenizer=tokenizer)
+
+quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)
+
+# save quantized model
+quantized_model.save_pretrained("./opt-125m-gptq")
+tokenizer.save_pretrained("./opt-125m-gptq")
+```
+
+Once quantized, you can fine-tune the GPTQ model with the PEFT APIs.
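+
+As a minimal sketch, you can attach a trainable LoRA adapter to the checkpoint saved above (the LoRA hyperparameters and `target_modules` below are only illustrative; adjust them for your model):
+
+```py
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from transformers import AutoModelForCausalLM
+
+# reload the GPTQ-quantized checkpoint saved above
+model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map="auto")
+model = prepare_model_for_kbit_training(model)
+
+# illustrative LoRA settings; q_proj/v_proj match OPT-style attention layers
+config = LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], r=16, lora_alpha=32)
+peft_model = get_peft_model(model, config)
+peft_model.print_trainable_parameters()
+```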
+ ## AQLM quantization Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a Large Language Models compression method. It quantizes multiple weights together and takes advantage of interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes. This allows it to compress models down to as low as 2-bit with considerably low accuracy losses. diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 7599aed35f..97404aeb4b 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -49,6 +49,33 @@ def is_auto_gptq_available(): ) +@lru_cache +def is_gptqmodel_available(): + if importlib.util.find_spec("gptqmodel") is not None: + GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.7.0") + OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99") + version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel")) + if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel: + if is_optimum_available(): + version_optimum = packaging.version.parse(importlib_metadata.version("optimum")) + if OPTIMUM_MINIMUM_VERSION <= version_optimum: + return True + else: + raise ImportError( + f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher. Found version {version_optimum}, " + f"but only versions above {OPTIMUM_MINIMUM_VERSION} are supported" + ) + else: + raise ImportError( + f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher to be installed." + ) + else: + raise ImportError( + f"Found an incompatible version of gptqmodel. Found version {version_gptqmodel}, " + f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported" + ) + + @lru_cache def is_optimum_available() -> bool: return importlib.util.find_spec("optimum") is not None diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py index db5759a5ac..3c52ecdf2f 100644 --- a/src/peft/tuners/adalora/model.py +++ b/src/peft/tuners/adalora/model.py @@ -17,7 +17,7 @@ import torch from transformers.pytorch_utils import Conv1D -from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_gptqmodel_available from peft.tuners.lora import LoraConfig, LoraModel from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import ( @@ -25,6 +25,7 @@ _freeze_adapter, _get_submodules, get_auto_gptq_quant_linear, + get_gptqmodel_quant_linear, get_quantization_config, ) from peft.utils.integrations import gather_params_ctx @@ -135,7 +136,8 @@ def _create_and_replace( # If it is not an AdaLoraLayer, create a new module, else update it with new adapters if not isinstance(target, AdaLoraLayer): - new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None + new_module = self._create_new_module(lora_config, adapter_name, target, device_map=device_map, **kwargs) if adapter_name not in self.active_adapters: # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) @@ -150,7 +152,7 @@ def _create_and_replace( ) @staticmethod - def _create_new_module(lora_config, adapter_name, target, **kwargs): + def _create_new_module(lora_config, adapter_name, target, device_map=None, **kwargs): # avoid eager bnb import if is_bnb_available(): import bitsandbytes as bnb @@ -160,7 +162,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): from 
.bnb import SVDLinear4bit gptq_quantization_config = kwargs.get("gptq_quantization_config", None) - AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + + if is_gptqmodel_available(): + QuantLinear = get_gptqmodel_quant_linear(gptq_quantization_config, device_map=device_map) + else: + QuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) @@ -189,7 +195,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): } ) new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs) - elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear): + elif QuantLinear is not None and isinstance(target, QuantLinear): new_module = SVDQuantLinear(target, adapter_name, **kwargs) else: if isinstance(target_base_layer, torch.nn.Linear): diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index d4826016dc..0fb8cd49a3 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -16,9 +16,10 @@ import torch +from peft.import_utils import is_gptqmodel_available from peft.tuners.lora.layer import LoraLayer from peft.tuners.tuners_utils import BaseTunerLayer -from peft.utils import get_auto_gptq_quant_linear +from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear class QuantLinear(torch.nn.Module, LoraLayer): @@ -106,10 +107,15 @@ def dispatch_gptq( else: target_base_layer = target - gptq_quantization_config = kwargs.get("gptq_quantization_config", None) - AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + cfg = kwargs.get("gptq_quantization_config", None) - if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear): + if is_gptqmodel_available(): + device_map = kwargs.get("device_map", None) + quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map) + else: + quant_linear = get_auto_gptq_quant_linear(cfg) + + if quant_linear is not None and isinstance(target_base_layer, quant_linear): new_module = QuantLinear(target, adapter_name, **kwargs) target.qweight = target_base_layer.qweight diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 2967b8da9c..92d9b6257f 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -232,7 +232,8 @@ def _create_and_replace( lora_bias=lora_config.lora_bias, ) else: - new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None + new_module = self._create_new_module(lora_config, adapter_name, target, device_map=device_map, **kwargs) if adapter_name not in self.active_adapters: # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 99a86e5b23..00dbb44d93 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -39,6 +39,7 @@ bloom_model_postprocess_past_key_value, cast_mixed_precision_params, get_auto_gptq_quant_linear, + get_gptqmodel_quant_linear, get_quantization_config, id_tensor_storage, infer_device, @@ -77,6 +78,7 @@ "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", "get_auto_gptq_quant_linear", + "get_gptqmodel_quant_linear", "get_peft_model_state_dict", "get_quantization_config", "id_tensor_storage", diff --git a/src/peft/utils/other.py 
b/src/peft/utils/other.py index 825612009d..77a8b07630 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -30,7 +30,7 @@ from packaging import version from safetensors.torch import storage_ptr, storage_size -from ..import_utils import is_auto_gptq_available, is_torch_tpu_available +from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available from .constants import ( CONFIG_NAME, EMBEDDING_LAYER_NAMES, @@ -610,30 +610,73 @@ def get_auto_gptq_quant_linear(gptq_quantization_config): """ Get the right AutoGPTQQuantLinear class based on the quantization config file """ - if gptq_quantization_config is not None and is_auto_gptq_available(): + if gptq_quantization_config is None: + return None + + if is_auto_gptq_available(): from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + else: + return None - desc_act = gptq_quantization_config.desc_act - group_size = gptq_quantization_config.group_size - bits = gptq_quantization_config.bits - if hasattr(gptq_quantization_config, "use_exllama"): - use_exllama = gptq_quantization_config.use_exllama - else: - use_exllama = not gptq_quantization_config.disable_exllama - if hasattr(gptq_quantization_config, "exllama_config"): - exllama_version = gptq_quantization_config.exllama_config["version"] - else: - exllama_version = 1 - AutoGPTQQuantLinear = dynamically_import_QuantLinear( - use_triton=False, - desc_act=desc_act, - group_size=group_size, - bits=bits, - disable_exllama=not (use_exllama and exllama_version == 1), - disable_exllamav2=not (use_exllama and exllama_version == 2), - ) - return AutoGPTQQuantLinear - return None + desc_act = gptq_quantization_config.desc_act + group_size = gptq_quantization_config.group_size + bits = gptq_quantization_config.bits + if hasattr(gptq_quantization_config, "use_exllama"): + use_exllama = gptq_quantization_config.use_exllama + else: + use_exllama = not gptq_quantization_config.disable_exllama + if hasattr(gptq_quantization_config, "exllama_config"): + exllama_version = gptq_quantization_config.exllama_config["version"] + else: + exllama_version = 1 + + QuantLinear = dynamically_import_QuantLinear( + use_triton=False, + desc_act=desc_act, + group_size=group_size, + bits=bits, + disable_exllama=not (use_exllama and exllama_version == 1), + disable_exllamav2=not (use_exllama and exllama_version == 2), + ) + + return QuantLinear + + +def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None): + """ + Get the right GPTQQuantLinear class based on the quantization config file + """ + if gptq_quantization_config is None: + return None + + if not is_gptqmodel_available(): + return None + + from gptqmodel.utils.importer import hf_select_quant_linear + + desc_act = gptq_quantization_config.desc_act + group_size = gptq_quantization_config.group_size + bits = gptq_quantization_config.bits + checkpoint_format = ( + gptq_quantization_config.checkpoint_format + if hasattr(gptq_quantization_config, "checkpoint_format") + else "gptq" + ) + sym = gptq_quantization_config.sym + meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None + + QuantLinear = hf_select_quant_linear( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + device_map=device_map, + checkpoint_format=checkpoint_format, + meta=meta, + backend="auto_trainable", + ) + + return QuantLinear def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py 
index a92e2c8171..2d1f8cba1e 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -406,19 +406,19 @@ def test_lora_gptq_quantization_from_pretrained_safetensors(self): config = LoraConfig(task_type="CAUSAL_LM") peft_model = get_peft_model(model, config) - peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device)) with tempfile.TemporaryDirectory() as tmp_dir: peft_model.save_pretrained(tmp_dir) model = AutoModelForCausalLM.from_pretrained(**kwargs) model = PeftModel.from_pretrained(model, tmp_dir) model = prepare_model_for_kbit_training(model) - model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device)) # loading a 2nd adapter works, #1239 model.load_adapter(tmp_dir, "adapter2") model.set_adapter("adapter2") - model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device)) # check that both adapters are in the same layer assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A diff --git a/tests/test_gptqmodel.py b/tests/test_gptqmodel.py new file mode 100644 index 0000000000..1eaf7f5096 --- /dev/null +++ b/tests/test_gptqmodel.py @@ -0,0 +1,349 @@ +# Note: These tests were copied from test_common_gpu.py and test_gpu_examples.py as they can run on CPU too. +# +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import os +import tempfile +import unittest + +import pytest +import torch +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + Trainer, + TrainingArguments, +) + +from peft import ( + AdaLoraConfig, + LoraConfig, + PeftModel, + get_peft_model, + prepare_model_for_kbit_training, +) +from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device + +from .testing_utils import ( + require_gptqmodel, + require_optimum, + require_torch_multi_gpu, +) + + +@require_gptqmodel +class PeftGPTQModelCommonTests(unittest.TestCase): + r""" + A common tester to run common operations that are performed on GPU/CPU such as generation, loading in 8bit, etc. + """ + + def setUp(self): + self.causal_lm_model_id = "facebook/opt-350m" + self.device = infer_device() + + def tearDown(self): + r""" + Efficient mechanism to free GPU memory after each test. Based on + https://github.com/huggingface/transformers/issues/21094 + """ + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + def test_lora_gptq_quantization_from_pretrained_safetensors(self): + r""" + Tests that the gptqmodel quantization using LoRA works as expected with safetensors weights. 
+ """ + from transformers import GPTQConfig + + model_id = "marcsun13/opt-350m-gptq-4bit" + quantization_config = GPTQConfig(bits=4, use_exllama=False) + kwargs = { + "pretrained_model_name_or_path": model_id, + "torch_dtype": torch.float16, + "device_map": "auto", + "quantization_config": quantization_config, + } + model = AutoModelForCausalLM.from_pretrained(**kwargs) + model = prepare_model_for_kbit_training(model) + + config = LoraConfig(task_type="CAUSAL_LM") + peft_model = get_peft_model(model, config) + peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device)) + + with tempfile.TemporaryDirectory() as tmp_dir: + peft_model.save_pretrained(tmp_dir) + model = AutoModelForCausalLM.from_pretrained(**kwargs) + model = PeftModel.from_pretrained(model, tmp_dir) + model = prepare_model_for_kbit_training(model) + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device)) + + # loading a 2nd adapter works, #1239 + model.load_adapter(tmp_dir, "adapter2") + model.set_adapter("adapter2") + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device)) + + # check that both adapters are in the same layer + assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A + assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A + + +@require_gptqmodel +@require_optimum +class PeftGPTQModelTests(unittest.TestCase): + r""" + GPTQ + peft tests + """ + + def setUp(self): + from transformers import GPTQConfig + + self.causal_lm_model_id = "marcsun13/opt-350m-gptq-4bit" + self.quantization_config = GPTQConfig(bits=4, backend="auto_trainable") + self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + + def tearDown(self): + r""" + Efficient mechanism to free GPU memory after each test. Based on + https://github.com/huggingface/transformers/issues/21094 + """ + gc.collect() + torch.cuda.empty_cache() + + def _check_inference_finite(self, model, batch): + # try inference without Trainer class + training = model.training + model.eval() + output = model(**batch.to(model.device)) + assert torch.isfinite(output.logits).all() + model.train(training) + + def test_causal_lm_training(self): + r""" + Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set + correctly. 
+ """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=self.quantization_config, + ) + + model = prepare_model_for_kbit_training(model) + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + def test_adalora_causalLM(self): + r""" + Tests the gptq training with adalora + """ + + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=self.quantization_config, + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + peft_config = AdaLoraConfig( + init_r=6, + target_r=4, + tinit=50, + tfinal=100, + deltaT=5, + beta1=0.3, + beta2=0.3, + orth_reg_weight=0.2, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, peft_config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True) + self._check_inference_finite(model, batch) + + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + @require_torch_multi_gpu + def test_causal_lm_training_multi_gpu(self): + r""" + Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set + correctly. 
+ """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=self.quantization_config, + ) + + assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count())) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + def test_non_default_adapter_name(self): + # See issue 1346 + config = LoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + + # default adapter name + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=self.quantization_config, + ) + model = prepare_model_for_kbit_training(model) + model = get_peft_model(model, config) + n_trainable_default, n_total_default = model.get_nb_trainable_parameters() + + # other adapter name + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=self.quantization_config, + ) + model = prepare_model_for_kbit_training(model) + model = get_peft_model(model, config, adapter_name="other") + n_trainable_other, n_total_other = model.get_nb_trainable_parameters() + + assert n_trainable_other > 0 + # sanity check + assert n_trainable_default == n_trainable_other + assert n_total_default == n_total_other diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 32bd6515ad..d5071923b8 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -24,6 +24,7 @@ is_auto_awq_available, is_auto_gptq_available, is_eetq_available, + is_gptqmodel_available, is_hqq_available, is_optimum_available, is_torchao_available, @@ -95,7 +96,16 @@ def require_auto_gptq(test_case): """ Decorator marking a test that requires auto-gptq. These tests are skipped when auto-gptq isn't installed. """ - return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) + return unittest.skipUnless(is_gptqmodel_available() or is_auto_gptq_available(), "test requires auto-gptq")( + test_case + ) + + +def require_gptqmodel(test_case): + """ + Decorator marking a test that requires gptqmodel. These tests are skipped when gptqmodel isn't installed. + """ + return unittest.skipUnless(is_gptqmodel_available(), "test requires gptqmodel")(test_case) def require_aqlm(test_case):