From 46f7ef255a2562a14cc80d536d6669f9d41d1a87 Mon Sep 17 00:00:00 2001
From: "Tian, Feng"
Date: Wed, 6 Dec 2023 14:34:21 +0800
Subject: [PATCH] fix typo in LmHeadLinearAllreduce initialization

---
 deepspeed/inference/quantization/layers.py       | 2 +-
 deepspeed/inference/quantization/quantization.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/deepspeed/inference/quantization/layers.py b/deepspeed/inference/quantization/layers.py
index 94c095f2b6ae..5533ffca2ecf 100644
--- a/deepspeed/inference/quantization/layers.py
+++ b/deepspeed/inference/quantization/layers.py
@@ -135,7 +135,7 @@ def forward(self, input: Tensor) -> Tensor:
 class QuantizedLmHeadLinearAllreduce(nn.Linear):
 
     def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
-        super(QuantizedLinearLayer, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+        super(QuantizedLmHeadLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
                                                    out_features=pre_quant_layer.weight.shape[0],
                                                    bias=pre_quant_layer.bias is not None,
                                                    device=pre_quant_layer.weight.device,
diff --git a/deepspeed/inference/quantization/quantization.py b/deepspeed/inference/quantization/quantization.py
index 43f5d74c6efc..c1e8ab9df612 100644
--- a/deepspeed/inference/quantization/quantization.py
+++ b/deepspeed/inference/quantization/quantization.py
@@ -85,8 +85,6 @@ def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> n
         if is_zero3_enabled:
             module.weight.all_gather()
 
-        assert module.weight.dtype == torch.float16, 'Model weight is expected in half.'
-
         new_module = QUANTIZATION_LAYER_MAPPINGS[type(module)](matched_quantization_config, module)
 
         if is_zero3_enabled:
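
Why the one-line change in layers.py matters: Python's two-argument form
super(C, self) requires self to be an instance (or subtype) of C.
QuantizedLinearLayer and QuantizedLmHeadLinearAllreduce are sibling
subclasses of nn.Linear, not ancestor and descendant, so the old call
raised a TypeError the moment a QuantizedLmHeadLinearAllreduce was
constructed. A minimal sketch of the failure mode, using hypothetical
stand-in classes rather than the DeepSpeed layers themselves:

    class Base:
        def __init__(self) -> None:
            self.initialized = True

    class Sibling(Base):        # stands in for QuantizedLinearLayer
        pass

    class Head(Base):           # stands in for QuantizedLmHeadLinearAllreduce
        def __init__(self, broken: bool = False) -> None:
            if broken:
                # Old code: names the sibling class, but self is a Head,
                # not a Sibling, so Python rejects the super() call.
                super(Sibling, self).__init__()
            else:
                # Fixed code: names the defining class itself.
                super(Head, self).__init__()

    Head()                      # constructs fine
    try:
        Head(broken=True)
    except TypeError as err:
        # prints: super(type, obj): obj must be an instance or subtype of type
        print(err)

The quantization.py hunk is a separate cleanup: with the fp16-only
assertion gone, _init_group_wise_weight_quantization no longer rejects
models whose weights are loaded in other dtypes (bfloat16, for example);
the patch itself carries no rationale, so that reading is inferred from
the code change alone.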