diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index ee4fe573d..30cc1129c 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -212,8 +212,16 @@ def load_model(
         kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
     elif device == "mps":
         kwargs = {"torch_dtype": torch.float16}
-        # Avoid bugs in mps backend by not using in-place operations.
-        replace_llama_attn_with_non_inplace_operations()
+        import transformers
+
+        version = tuple(int(v) for v in transformers.__version__.split("."))
+        if version < (4, 35, 0):
+            # NOTE: Recent transformers releases appear to fix the mps issue, and
+            # they also changed internals in a way that breaks our original patch,
+            # so we only apply the patch for older transformers versions.
+
+            # Avoid bugs in mps backend by not using in-place operations.
+            replace_llama_attn_with_non_inplace_operations()
     elif device == "xpu":
         kwargs = {"torch_dtype": torch.bfloat16}
         # Try to load ipex, while it looks unused, it links into torch for xpu support
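
One caveat with the tuple-based comparison added above: transformers.__version__ can carry pre-release suffixes such as "4.36.0.dev0", in which case int(v) raises a ValueError. A minimal sketch of a more tolerant check, assuming packaging is available (it is listed as a transformers dependency); this is illustrative only, not part of the patch:

    import transformers
    from packaging import version  # packaging ships alongside transformers

    # Parse the full version string so pre-release tags like "4.36.0.dev0"
    # compare correctly instead of crashing the int() conversion.
    if version.parse(transformers.__version__) < version.parse("4.35.0"):
        ...  # apply replace_llama_attn_with_non_inplace_operations() here, as above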