FEAT Add gptqmodel support (#2247)
Add support for gptqmodel quantization. This is a replacement for
auto-gptq.

For now, both packages are supported, but since auto-gptq is no longer
being developed, it will be deprecated and removed at some point in the
future.

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: LRL-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
6 people authored Jan 23, 2025
1 parent 1b9bcb2 commit 6e30991
Show file tree
Hide file tree
Showing 11 changed files with 508 additions and 37 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -34,6 +34,7 @@ tests_core_single_gpu:
tests_common_gpu:
python -m pytest tests/test_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",)
python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",)
python -m pytest tests/test_gptqmodel.py $(if $(IS_GITHUB_CI),--report-log "gptqmodel_gpu.log",)

tests_examples_multi_gpu_bnb:
python -m pytest -m "multi_gpu_tests and bitsandbytes" tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)
26 changes: 26 additions & 0 deletions docs/source/developer_guides/quantization.md
@@ -107,6 +107,32 @@ QLoRA adds trainable weights to all the linear layers in the transformer architecture
config = LoraConfig(target_modules="all-linear", ...)
```

## GPTQ quantization

You can learn more about GPTQ-based 2-, 3-, 4-, and 8-bit quantization in the [GPTQModel](https://github.com/ModelCloud/GPTQModel) repository and the Transformers [GPTQ](https://huggingface.co/docs/transformers/quantization/gptq) docs. For training on top of a quantized model, PEFT supports both the [GPTQModel](https://github.com/ModelCloud/GPTQModel) and [AutoGPTQ](https://github.com/autogptq/autogptq) libraries, but we recommend GPTQModel because AutoGPTQ will be deprecated in a future release.

```bash
# gptqmodel install
pip install gptqmodel --no-build-isolation
```

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)

gptq_config = GPTQConfig(bits=4, group_size=128, dataset="wikitext2", tokenizer=tokenizer)

quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)

# save quantized model
quantized_model.save_pretrained("./opt-125m-gptq")
tokenizer.save_pretrained("./opt-125m-gptq")
```

Once quantized, you can post-train the GPTQ model with the PEFT APIs, for example by attaching a LoRA adapter as sketched below.
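
As a rough sketch, the quantized checkpoint saved above can be reloaded and wrapped with a LoRA adapter. The adapter hyperparameters (`r`, `lora_alpha`, `target_modules`) below are illustrative choices for the OPT architecture, not values prescribed by this commit.

```py
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# reload the GPTQ-quantized checkpoint saved above
model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map="auto")

# attach a LoRA adapter; only the adapter weights are trainable
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention projections in OPT
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()  # the quantized base weights stay frozen
```

The resulting `peft_model` trains only the adapter weights on top of the frozen quantized base model and can be used with a standard training loop such as the Transformers `Trainer`.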

## AQLM quantization

Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a Large Language Models compression method. It quantizes multiple weights together and takes advantage of interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes. This allows it to compress models down to as low as 2-bit with considerably low accuracy losses.
27 changes: 27 additions & 0 deletions src/peft/import_utils.py
@@ -49,6 +49,33 @@ def is_auto_gptq_available():
)


@lru_cache
def is_gptqmodel_available():
if importlib.util.find_spec("gptqmodel") is not None:
GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.7.0")
OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99")
version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel"))
if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel:
if is_optimum_available():
version_optimum = packaging.version.parse(importlib_metadata.version("optimum"))
if OPTIMUM_MINIMUM_VERSION <= version_optimum:
return True
else:
raise ImportError(
f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher. Found version {version_optimum}, "
f"but only versions above {OPTIMUM_MINIMUM_VERSION} are supported"
)
else:
raise ImportError(
f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher to be installed."
)
else:
raise ImportError(
f"Found an incompatible version of gptqmodel. Found version {version_gptqmodel}, "
f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported"
)


@lru_cache
def is_optimum_available() -> bool:
return importlib.util.find_spec("optimum") is not None
16 changes: 11 additions & 5 deletions src/peft/tuners/adalora/model.py
@@ -17,14 +17,15 @@
import torch
from transformers.pytorch_utils import Conv1D

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_gptqmodel_available
from peft.tuners.lora import LoraConfig, LoraModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import (
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
_freeze_adapter,
_get_submodules,
get_auto_gptq_quant_linear,
get_gptqmodel_quant_linear,
get_quantization_config,
)
from peft.utils.integrations import gather_params_ctx
@@ -135,7 +136,8 @@ def _create_and_replace(

# If it is not an AdaLoraLayer, create a new module, else update it with new adapters
if not isinstance(target, AdaLoraLayer):
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
new_module = self._create_new_module(lora_config, adapter_name, target, device_map=device_map, **kwargs)
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
@@ -150,7 +152,7 @@
)

@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
def _create_new_module(lora_config, adapter_name, target, device_map=None, **kwargs):
# avoid eager bnb import
if is_bnb_available():
import bitsandbytes as bnb
@@ -160,7 +162,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
from .bnb import SVDLinear4bit

gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

if is_gptqmodel_available():
QuantLinear = get_gptqmodel_quant_linear(gptq_quantization_config, device_map=device_map)
else:
QuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
@@ -189,7 +195,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
}
)
new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
elif QuantLinear is not None and isinstance(target, QuantLinear):
new_module = SVDQuantLinear(target, adapter_name, **kwargs)
else:
if isinstance(target_base_layer, torch.nn.Linear):
14 changes: 10 additions & 4 deletions src/peft/tuners/lora/gptq.py
@@ -16,9 +16,10 @@

import torch

from peft.import_utils import is_gptqmodel_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_auto_gptq_quant_linear
from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear


class QuantLinear(torch.nn.Module, LoraLayer):
@@ -106,10 +107,15 @@ def dispatch_gptq(
else:
target_base_layer = target

gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
cfg = kwargs.get("gptq_quantization_config", None)

if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
if is_gptqmodel_available():
device_map = kwargs.get("device_map", None)
quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map)
else:
quant_linear = get_auto_gptq_quant_linear(cfg)

if quant_linear is not None and isinstance(target_base_layer, quant_linear):
new_module = QuantLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight

3 changes: 2 additions & 1 deletion src/peft/tuners/lora/model.py
@@ -232,7 +232,8 @@ def _create_and_replace(
lora_bias=lora_config.lora_bias,
)
else:
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
new_module = self._create_new_module(lora_config, adapter_name, target, device_map=device_map, **kwargs)
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
2 changes: 2 additions & 0 deletions src/peft/utils/__init__.py
@@ -39,6 +39,7 @@
bloom_model_postprocess_past_key_value,
cast_mixed_precision_params,
get_auto_gptq_quant_linear,
get_gptqmodel_quant_linear,
get_quantization_config,
id_tensor_storage,
infer_device,
@@ -77,6 +78,7 @@
"bloom_model_postprocess_past_key_value",
"cast_mixed_precision_params",
"get_auto_gptq_quant_linear",
"get_gptqmodel_quant_linear",
"get_peft_model_state_dict",
"get_quantization_config",
"id_tensor_storage",
89 changes: 66 additions & 23 deletions src/peft/utils/other.py
@@ -30,7 +30,7 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size

from ..import_utils import is_auto_gptq_available, is_torch_tpu_available
from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available
from .constants import (
CONFIG_NAME,
EMBEDDING_LAYER_NAMES,
@@ -610,30 +610,73 @@ def get_auto_gptq_quant_linear(gptq_quantization_config):
"""
Get the right AutoGPTQQuantLinear class based on the quantization config file
"""
if gptq_quantization_config is not None and is_auto_gptq_available():
if gptq_quantization_config is None:
return None

if is_auto_gptq_available():
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
else:
return None

desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
if hasattr(gptq_quantization_config, "use_exllama"):
use_exllama = gptq_quantization_config.use_exllama
else:
use_exllama = not gptq_quantization_config.disable_exllama
if hasattr(gptq_quantization_config, "exllama_config"):
exllama_version = gptq_quantization_config.exllama_config["version"]
else:
exllama_version = 1
AutoGPTQQuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=not (use_exllama and exllama_version == 1),
disable_exllamav2=not (use_exllama and exllama_version == 2),
)
return AutoGPTQQuantLinear
return None
desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
if hasattr(gptq_quantization_config, "use_exllama"):
use_exllama = gptq_quantization_config.use_exllama
else:
use_exllama = not gptq_quantization_config.disable_exllama
if hasattr(gptq_quantization_config, "exllama_config"):
exllama_version = gptq_quantization_config.exllama_config["version"]
else:
exllama_version = 1

QuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=not (use_exllama and exllama_version == 1),
disable_exllamav2=not (use_exllama and exllama_version == 2),
)

return QuantLinear


def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None):
"""
Get the right GPTQQuantLinear class based on the quantization config file
"""
if gptq_quantization_config is None:
return None

if not is_gptqmodel_available():
return None

from gptqmodel.utils.importer import hf_select_quant_linear

desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
checkpoint_format = (
gptq_quantization_config.checkpoint_format
if hasattr(gptq_quantization_config, "checkpoint_format")
else "gptq"
)
sym = gptq_quantization_config.sym
meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None

QuantLinear = hf_select_quant_linear(
bits=bits,
group_size=group_size,
desc_act=desc_act,
sym=sym,
device_map=device_map,
checkpoint_format=checkpoint_format,
meta=meta,
backend="auto_trainable",
)

return QuantLinear


def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
6 changes: 3 additions & 3 deletions tests/test_common_gpu.py
@@ -406,19 +406,19 @@ def test_lora_gptq_quantization_from_pretrained_safetensors(self):

config = LoraConfig(task_type="CAUSAL_LM")
peft_model = get_peft_model(model, config)
peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

with tempfile.TemporaryDirectory() as tmp_dir:
peft_model.save_pretrained(tmp_dir)
model = AutoModelForCausalLM.from_pretrained(**kwargs)
model = PeftModel.from_pretrained(model, tmp_dir)
model = prepare_model_for_kbit_training(model)
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

# loading a 2nd adapter works, #1239
model.load_adapter(tmp_dir, "adapter2")
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

# check that both adapters are in the same layer
assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A