diff --git a/model/main.py b/model/main.py index b016d47..aaf63ec 100644 --- a/model/main.py +++ b/model/main.py @@ -286,7 +286,7 @@ def skip(*args, **kwargs): param.requires_grad = False if args.multigpu: - if "llama" in args.model.lower(): + if ("llama" in args.model.lower()) or ("mixtral" in args.model.lower()): map_layers_to_multi_gpus(lm.model.model.layers) input_device = lm.model.model.layers[0].device output_device = lm.model.model.layers[-1].device