feat: add apple silicon support #469
base: main
Conversation
Comment on memory management for CUDA, MPS, and CPU
In `fp8_cast_bf16.py` at line 97:
```python
if len(loaded_files) > 2:
    oldest_file = next(iter(loaded_files))
    del loaded_files[oldest_file]
    torch.cuda.empty_cache()
```
The current implementation calls `torch.cuda.empty_cache()` for memory management, but this only works on CUDA devices. For MPS (Metal Performance Shaders) and CPU, we should handle memory differently. Here's an updated version that works across all devices:
```python
import gc
...
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
    oldest_file = next(iter(loaded_files))
    del loaded_files[oldest_file]
    # If CUDA is available, free unused memory on the CUDA device
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # If MPS is available, free unused memory on the MPS device
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    # Otherwise, trigger Python garbage collection for CPU
    else:
        gc.collect()
```
Why this change is important:
- CUDA: If CUDA is available, `torch.cuda.empty_cache()` is called to free unused GPU memory.
- MPS: If MPS is available on macOS, `torch.mps.empty_cache()` is used for memory management.
- CPU: For CPU or other non-GPU devices, `gc.collect()` triggers garbage collection to free memory.

This ensures memory is managed correctly across all devices, whether you're using CUDA, MPS, or CPU. Would you consider updating this logic for better compatibility across environments?
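On the same theme, a minimal sketch of how device selection could be made uniform when loading tensors; the `pick_device` helper is hypothetical and not part of this PR:

```python
import torch

def pick_device() -> torch.device:
    # Prefer CUDA, then MPS on Apple Silicon, then fall back to CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# Example: move a tensor to whichever accelerator is present
x = torch.randn(4, 4).to(pick_device())
```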
Thanks! 🚀
@IFaTaK p.s. Sorry for the silly error 😅
inference/requirements.txt
Outdated
```diff
@@ -1,4 +1,4 @@
-torch==2.4.1
+torch==2.6.0
 triton==3.0.0
```
Suggested Change to requirements.txt
It seems that torch `2.6.0` specifically requires triton `3.2.0`. Currently, the requirements.txt lists `triton==3.0.0`, which causes a conflict with `torch==2.6.0`.

I recommend updating the `triton` version in the `requirements.txt` file to `3.2.0` to match the requirement from torch `2.6.0`. The updated section would look like this:
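```
torch==2.6.0
triton==3.2.0
```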
This should resolve the version compatibility issue.
However, please note that `triton==3.2.0` does not currently work on macOS or Windows through pip. Users on those platforms will need to install Triton manually by cloning the repository and building it from source; I suggest adding documentation for those steps in case someone needs to follow them. You can follow the installation instructions from the official Triton repository to do this.

This leads to updating `requirements.txt` as follows:
```
torch==2.6.0
triton==3.2.0; platform_system == "Linux"
transformers==4.46.3
safetensors==0.4.5
```
🚀
Actually, looking further into this, Triton does not have a pip package at all for Mac or Windows for any version; it currently only supports Linux (https://github.com/triton-lang/triton?tab=readme-ov-file#compatibility).

Maybe we can avoid using Triton altogether and use pure PyTorch, but that would come with its own issues. Not sure how we can proceed.
One potential solution is to have `requirements.txt` look something like this:
```
torch==2.6.0
git+https://github.com/triton-lang/triton@main#subdirectory=python
transformers==4.46.3
safetensors==0.4.5
```
This way we install Triton from source, which should work on every OS, but Triton still would not support GPU acceleration on Macs as of now.
While downloading Triton from source using `git+https://github.com/triton-lang/triton@main#subdirectory=python` ensures compatibility across all OSes, I am concerned that Triton may still require additional building.
That said, if we want to fully support Metal on Apple Silicon, we would need to use MLX.
However, MLX is not as mature as Triton and lacks many of the optimizations and utilities that Triton provides for CUDA. Using MLX would require significant custom implementation, particularly for matrix operations and kernel optimizations. We'd need to manually write custom Metal shaders and utilize the Python API to replicate functionality that Triton handles out of the box.
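For a sense of what that would look like, here is a minimal, illustrative MLX snippet (assuming the `mlx` package is installed); it only shows the PyTorch-like array API, not the custom Metal shaders that real Triton parity would require:

```python
import mlx.core as mx

# MLX offers PyTorch-like array operations that run on Apple's Metal backend
a = mx.random.normal((1024, 1024))
b = mx.random.normal((1024, 1024))
c = mx.matmul(a, b)  # dispatched to the GPU via Metal
mx.eval(c)           # MLX is lazy; force evaluation here
```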
This would involve a substantial amount of work to make MLX work like Triton. It could be a great project, but it'll definitely take some time to get it right. 🚀
Triton is currently built around NVIDIA's CUDA (NVPTX) and AMD's ROCm backends; it would not work on other backends unless one is added (via Metal for Apple, perhaps). MLX is an array framework similar to PyTorch. We could implement this entire kernel using PyTorch's native tensor operations, JIT, or even TorchScript, which would enable cross-platform support, but this could come with its own trade-offs, like performance.
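As a rough illustration of the pure-PyTorch route, a blockwise fp8-to-bf16 dequantization can be written with native tensor ops alone. This sketch is hypothetical: the function name, argument layout, and block size are assumptions for illustration, not the repository's actual kernel:

```python
import torch

def weight_dequant_torch(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Multiply each (block_size x block_size) tile of the weight by its per-block scale."""
    m, n = x.shape
    # Expand the grid of per-block scales back up to the full weight shape
    s_full = s.repeat_interleave(block_size, dim=0)[:m, :]
    s_full = s_full.repeat_interleave(block_size, dim=1)[:, :n]
    return (x.to(torch.float32) * s_full).to(torch.bfloat16)
```

This trades a fused Triton kernel for a few full-size temporaries, which is exactly the kind of performance trade-off mentioned above.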
Thanks! This change is better suited to an apple-silicon branch, because we have no Apple device to test it on...