Commit 46f7ef2
fix typo in LmHeadLinearAllreduce initialization
ftian1 committed Jan 16, 2025
1 parent 9b947a7 commit 46f7ef2
Showing 2 changed files with 1 addition and 3 deletions.
2 changes: 1 addition & 1 deletion deepspeed/inference/quantization/layers.py
@@ -135,7 +135,7 @@ def forward(self, input: Tensor) -> Tensor:
 class QuantizedLmHeadLinearAllreduce(nn.Linear):
 
     def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
-        super(QuantizedLinearLayer, self).__init__(in_features=pre_quant_layer.weight.shape[1],
+        super(QuantizedLmHeadLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
                                                    out_features=pre_quant_layer.weight.shape[0],
                                                    bias=pre_quant_layer.bias is not None,
                                                    device=pre_quant_layer.weight.device,
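
Note for readers (context added here, not part of the commit): the pre-fix line named the wrong class in a two-argument super() call. Python requires that self be an instance or subtype of the class passed as super()'s first argument, and QuantizedLmHeadLinearAllreduce subclasses nn.Linear rather than QuantizedLinearLayer, so constructing the layer would raise a TypeError. A minimal runnable sketch with hypothetical classes:

# Hypothetical classes (not DeepSpeed code) illustrating the failure mode:
# two-argument super() demands that self be an instance or subtype of the
# class named in its first argument.

class Base:
    def __init__(self, n: int) -> None:
        self.n = n

class Sibling(Base):
    def __init__(self, n: int) -> None:
        super(Sibling, self).__init__(n)

class Broken(Base):
    def __init__(self, n: int) -> None:
        super(Sibling, self).__init__(n)  # wrong class, mirrors the pre-fix line

class Fixed(Base):
    def __init__(self, n: int) -> None:
        super(Fixed, self).__init__(n)  # or simply super().__init__(n)

Fixed(8)  # constructs fine
try:
    Broken(8)
except TypeError as err:
    print(err)  # super(type, obj): obj must be an instance or subtype of type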
2 changes: 0 additions & 2 deletions deepspeed/inference/quantization/quantization.py
@@ -85,8 +85,6 @@ def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> nn.Module:
         if is_zero3_enabled:
             module.weight.all_gather()
 
-        assert module.weight.dtype == torch.float16, 'Model weight is expected in half.'
-
         new_module = QUANTIZATION_LAYER_MAPPINGS[type(module)](matched_quantization_config, module)
 
         if is_zero3_enabled:
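
Note for readers (context added here, not part of the commit): the removed assertion hard-required fp16 weights, so a model loaded in bf16 or fp32 would trip it before the quantization layer swap could run; dropping it presumably lets other dtypes through. A minimal sketch of the removed guard, using a hypothetical helper name:

import torch
import torch.nn as nn

# Hypothetical helper (not the DeepSpeed API) reproducing the deleted guard:
# any module whose weights are not fp16 trips the assertion.
def check_weight_dtype(module: nn.Module) -> None:
    assert module.weight.dtype == torch.float16, 'Model weight is expected in half.'

check_weight_dtype(nn.Linear(4, 4).half())  # passes: weights are fp16
try:
    check_weight_dtype(nn.Linear(4, 4, dtype=torch.bfloat16))  # rejected pre-fix
except AssertionError as err:
    print(err)  # Model weight is expected in half.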
