Fix MPS backend 'index out of range' error (lm-sys#2737)
suquark authored and zhanghao.smooth committed Jan 26, 2024
1 parent 5077879 commit a336364
Showing 1 changed file with 10 additions and 2 deletions.
fastchat/model/model_adapter.py (12 changes: 10 additions, 2 deletions)
@@ -212,8 +212,16 @@ def load_model(
                 kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
     elif device == "mps":
         kwargs = {"torch_dtype": torch.float16}
-        # Avoid bugs in mps backend by not using in-place operations.
-        replace_llama_attn_with_non_inplace_operations()
+        import transformers
+
+        version = tuple(int(v) for v in transformers.__version__.split("."))
+        if version < (4, 35, 0):
+            # NOTE: Recent transformers library seems to fix the mps issue, also
+            # it has made some changes causing compatibility issues with our
+            # original patch. So we only apply the patch for older versions.
+
+            # Avoid bugs in mps backend by not using in-place operations.
+            replace_llama_attn_with_non_inplace_operations()
     elif device == "xpu":
         kwargs = {"torch_dtype": torch.bfloat16}
         # Try to load ipex, while it looks unused, it links into torch for xpu support
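
The gate added here splits transformers.__version__ on dots and casts every component to int, so a pre-release string such as "4.36.0.dev0" would raise a ValueError. Below is a minimal sketch of a more tolerant check, assuming the third-party packaging library is available and that the patch function is importable from fastchat.model.monkey_patch_non_inplace, as model_adapter.py does elsewhere; neither assumption is part of this commit.

    # Sketch only, not part of the commit: a version gate that tolerates
    # pre-release suffixes such as "4.36.0.dev0", which would break the
    # tuple(int(v) ...) parse above. Assumes the third-party `packaging`
    # library is installed (pip install packaging).
    import transformers
    from packaging.version import Version

    # Assumed import path, matching how model_adapter.py pulls in the patch.
    from fastchat.model.monkey_patch_non_inplace import (
        replace_llama_attn_with_non_inplace_operations,
    )

    if Version(transformers.__version__) < Version("4.35.0"):
        # Older transformers: apply the non-inplace attention patch for mps.
        replace_llama_attn_with_non_inplace_operations()

With transformers 4.35.0 or newer, the patch is skipped entirely, matching the commit's note that recent releases fix the mps issue on their own.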
