From adc5238bdff782ef8892e069e67cfeda164d9996 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Thu, 2 Jan 2025 12:39:40 +0000
Subject: [PATCH] Fix multi-GPU setup

---
 .../common/accelerate_utils/accelerate.py |  1 -
 src/brevitas_examples/llm/main.py         | 83 ++++++++++++-------
 2 files changed, 51 insertions(+), 33 deletions(-)

diff --git a/src/brevitas_examples/common/accelerate_utils/accelerate.py b/src/brevitas_examples/common/accelerate_utils/accelerate.py
index ead616ed2..369c456c0 100644
--- a/src/brevitas_examples/common/accelerate_utils/accelerate.py
+++ b/src/brevitas_examples/common/accelerate_utils/accelerate.py
@@ -407,7 +407,6 @@ def offload_model(
     else:
         device_map = infer_auto_device_map(
             model, memory_map, no_split_module_classes=model._no_split_modules)
-
     model = dispatch_model(model, device_map)
 
     # Fixes an asymetric behavior in Accelerate where hooks are not attached at all when a single device is used.
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 947314185..08eb92769 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -3,6 +3,7 @@
 
 import argparse
 from copy import deepcopy
+from functools import partial
 from functools import wraps
 import os
 import sys
@@ -20,6 +21,7 @@
 from brevitas.export import export_torch_qcdq
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas.graph.base import ModuleInstanceFuseRotationWeights
+from brevitas.graph.base import Transform
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
@@ -56,27 +58,31 @@ def set_seed(seed):
     torch.random.manual_seed(seed)
 
 
-def on_process(process_index: int):
+def is_main_process():
+    return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0]
 
-    def decorator(func: Callable):
+
+def on_process(func: Callable, process_index: int):
 
-        @wraps(func)
-        def _wrapper(model, *args, **kwargs):
-            curr_process_index = int(os.environ.get('LOCAL_RANK', -1))
+    @wraps(func)
+    def _wrapper(model, *args, **kwargs):
+        curr_process_index = int(os.environ.get('LOCAL_RANK', -1))
 
-            if curr_process_index == -1 or (process_index == curr_process_index):
-                print(f"Applying {func.__name__} on process index {curr_process_index}")
-                return func(model, *args, **kwargs)
-            else:
-                print(f"Skipping function {func.__name__} on process index {curr_process_index}")
-                return model
+        if curr_process_index == -1 or (process_index == curr_process_index):
+            print(f"Applying {func.__name__} on process index {curr_process_index}")
+            return func(model, *args, **kwargs)
+        else:
+            print(f"Skipping function {func.__name__} on process index {curr_process_index}")
+            return model
 
-        return _wrapper
+    return _wrapper
 
-    return decorator
 
+on_main_process = partial(on_process, process_index=0)
+
 
-def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module:
+@on_main_process
+def apply_fused_rotations(model: torch.nn.Module, rewriters: List[Transform]) -> torch.nn.Module:
     model = offload_model(model)
     for r in rewriters:
         if isinstance(r, ModuleInstanceFuseRotationWeights):
@@ -85,6 +91,15 @@ def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module:
     return model
 
 
+@on_main_process
+def evaluate_model(model: torch.nn.Module, validation_loader, args, tokenizer):
+    model = offload_model(model)
+    ppl = compute_perplexity(
+        model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+    print(f"Perplexity ({args.dataset}): {ppl:.3f}")
+    remove_hooks(model)
+
+
 # TODO: Use no_grad? The result of fusing the rotations would yield tensor with requires_grad set to False,
 # which might no be a problem, as that flag is set in the appropiate QAT/PTQ algorithms.
 def fused_rotation_no_fx(
@@ -282,12 +297,9 @@ def main(args, unknown_args=None):
 
     if args.eval:
         assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously"
-        print("Float model eval...")
-        model = offload_model(model)
-        float_ppl = compute_perplexity(
-            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
-        remove_hooks(model)
-        print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")
+        print("Evaluating float model...")
+        evaluate_model(model, validation_loader, args, tokenizer)
+        print("Float evaluation done.")
 
     if args.replace_rmsnorm:
         model = replace_rmsnorm_with_torch(model, model.config)
@@ -437,12 +449,21 @@ def mock_save_pretrained_fn(*args, **kwargs):
     if args.bias_corr:
         model = add_zero_bias_to_linear(model)
 
-    model = offload_model(model)
-
-    with torch.no_grad():
-        model(**calibration_loader[0])
-
-    remove_hooks(model)
+    # We need to run a calibration forward pass to initialize quantization-related
+    # parameters, e.g. scales. In DDP, parameters are synchronized across replicas
+    # before optimization, so this pass does not need to run in every process: the
+    # parameters of the main process will be broadcast to each replica.
+    if is_main_process():
+        model = offload_model(model)
+        with torch.no_grad():
+            model(**calibration_loader[0])
+        remove_hooks(model)
+    else:
+        # TODO: Generalize this logic. Currently, only ParameterFromStatsFromParameterZeroPoint
+        # and ParameterFromStatsFromParameterScaling have the attribute init_done.
+        for module in model.modules():
+            if hasattr(module, "init_done"):
+                module.init_done = True
 
     if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']:
         apply_rotation_optimization(
@@ -453,6 +474,7 @@ def mock_save_pretrained_fn(*args, **kwargs):
         )
 
         remove_hooks(model)
+        torch.cuda.empty_cache()
 
     if args.act_calibration:
         print("Apply act calibration...")
@@ -489,12 +511,9 @@ def mock_save_pretrained_fn(*args, **kwargs):
         print("Bias correction applied.")
 
     if args.eval and not args.no_quantize:
-        print("Model eval...")
-        model = offload_model(model)
-        quant_ppl = compute_perplexity(
-            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
-        print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
-        remove_hooks(model)
+        print("Evaluating quantized model...")
+        evaluate_model(model, validation_loader, args, tokenizer)
+        print("Quantized evaluation done.")
 
     if args.checkpoint_name is not None:
         print(f"Saving checkpoint to {args.checkpoint_name}")
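
Note for reviewers: the snippet below is a minimal standalone sketch (not part of
the patch) of the LOCAL_RANK gating that on_process/on_main_process implement
above. The mark_calibrated step and its dict "model" are hypothetical stand-ins
for the gated work; torchrun sets LOCAL_RANK per process, while a plain python
run has none, so the wrapped function then always executes.

    import os
    from functools import partial, wraps
    from typing import Callable


    def on_process(func: Callable, process_index: int):
        """Run func only on the given local rank; other ranks return the input unchanged."""

        @wraps(func)
        def _wrapper(model, *args, **kwargs):
            curr_process_index = int(os.environ.get('LOCAL_RANK', -1))
            # -1 means no distributed launcher set LOCAL_RANK (plain `python` run),
            # so the function is applied unconditionally in that case.
            if curr_process_index == -1 or process_index == curr_process_index:
                return func(model, *args, **kwargs)
            # Any other rank skips the work and returns the model untouched.
            return model

        return _wrapper


    # Fixing process_index turns on_process into a plain decorator, as in the patch.
    on_main_process = partial(on_process, process_index=0)


    @on_main_process
    def mark_calibrated(model: dict) -> dict:
        # Hypothetical stand-in for a run-once step, e.g. the calibration forward
        # pass gated by is_main_process() in the patch.
        model["calibrated"] = True
        return model


    if __name__ == "__main__":
        print(mark_calibrated({"calibrated": False}))

Running this with plain python prints {'calibrated': True}; under e.g.
torchrun --nproc_per_node=2, only the process with LOCAL_RANK 0 applies the
function, mirroring how the patch restricts the calibration forward pass and
perplexity evaluation to the main process.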