
feat: add apple silicon support #469

Open · wants to merge 4 commits into main
Conversation

NripeshN

No description provided.


@IFaTaK left a comment


Comment on memory management for CUDA, MPS, and CPU

In fp8_cast_bf16.py at line 97:

if len(loaded_files) > 2:
    oldest_file = next(iter(loaded_files))
    del loaded_files[oldest_file]
    torch.cuda.empty_cache()

The current implementation calls torch.cuda.empty_cache() for memory management, but this only works on CUDA devices. For MPS (Metal Performance Shaders) and CPU, we should handle memory differently. Here's an updated version that works across all devices:

import gc

...

# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
    oldest_file = next(iter(loaded_files))
    del loaded_files[oldest_file]

    # Check if CUDA is available, then free memory on CUDA device
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Check if MPS is available, then free memory on MPS device
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    # Otherwise, clean up memory for CPU
    else:
        gc.collect()

Why this change is important:

  • CUDA: If CUDA is available, torch.cuda.empty_cache() will be called to free unused GPU memory.
  • MPS: If MPS is available on macOS, torch.mps.empty_cache() will be used for memory management.
  • CPU: For CPU or other non-GPU devices, we use gc.collect() to trigger garbage collection and free memory.

This ensures memory is managed correctly across all devices—whether you're using CUDA, MPS, or CPU. Would you consider updating this logic for better compatibility across environments?
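If it helps keep the loading loop tidy, the same checks could also live in a small helper (just a sketch; the function name free_device_memory is mine, not something in the repo):

import gc

import torch


def free_device_memory() -> None:
    """Release cached memory on whichever backend is in use (sketch, not repo code)."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    else:
        gc.collect()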

Thanks! 🚀

Review comment on inference/model.py (outdated, resolved)
@NripeshN (Author) commented Feb 2, 2025

@IFaTaK
Thanks for the suggestion, just made all the changes.

p.s. Sorry for the silly error😅

@NripeshN NripeshN requested a review from IFaTaK February 2, 2025 10:43
@@ -1,4 +1,4 @@
-torch==2.4.1
+torch==2.6.0
 triton==3.0.0
@IFaTaK Feb 2, 2025

Suggested Change to requirements.txt
It seems that torch 2.6.0 specifically requires triton 3.2.0. Currently, requirements.txt lists triton==3.0.0, which causes a conflict with torch==2.6.0.

I recommend updating the triton version in requirements.txt to 3.2.0 to match the requirement from torch 2.6.0. This should resolve the version compatibility issue.

However, please note that triton==3.2.0 is not currently installable via pip on macOS or Windows. Users on those platforms will need to install Triton manually by cloning the repository and building it from source. I suggest adding documentation for that in case someone needs to follow those steps; the installation instructions in the official Triton repository cover this.

This leads to updating requirements.txt as follows:

torch==2.6.0
triton==3.2.0; platform_system == "Linux"
transformers==4.46.3
safetensors==0.4.5

🚀

@NripeshN (Author)

Actually, looking further into this, Triton does not have a pip package at all for Mac or Windows for any version. It currently only supports Linux (https://github.com/triton-lang/triton?tab=readme-ov-file#compatibility).

Maybe we can avoid using Triton altogether and use pure PyTorch, but that would come with its own issues. Not sure how we can proceed.
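For what it's worth, if we go the optional-Triton route, the fallback could be gated with a simple import guard (a minimal sketch, not code from this PR; the HAS_TRITON flag is illustrative):

# Hedged sketch: keep Triton optional so the same script can run on platforms
# where it cannot be pip-installed (e.g. macOS, Windows).
try:
    import triton  # expected to succeed only on Linux
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

print(f"Triton available: {HAS_TRITON}")  # downstream code can branch on this flag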

@NripeshN (Author) Feb 2, 2025

One potential solution is to have requirements.txt look something like this:

torch==2.6.0
git+https://github.com/triton-lang/triton@main#subdirectory=python
transformers==4.46.3
safetensors==0.4.5

This way we install Triton from source, which should work on every OS, but Triton would still not support GPU acceleration on Macs as of now.

@IFaTaK

While installing Triton from source using git+https://github.com/triton-lang/triton@main#subdirectory=python ensures compatibility across all OSes, I am concerned that Triton may still require additional build steps.

That said, if we want to fully support Metal on Apple Silicon, we would need to use MLX.
However, MLX is not as mature as Triton and lacks many of the optimizations and utilities that Triton provides for CUDA. Using MLX would require significant custom implementation, particularly for matrix operations and kernel optimizations. We'd need to manually write custom Metal shaders and utilize the Python API to replicate functionality that Triton handles out of the box.
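To make the comparison concrete, here is a tiny illustration of MLX's PyTorch-like Python API (a hedged sketch assuming mlx is installed on Apple silicon; anything beyond basic ops, such as block-wise dequantization kernels, would need custom Metal work on top of this):

# Minimal MLX example (assumes `pip install mlx` on Apple silicon).
# MLX arrays live in unified memory and ops run on the GPU by default.
import mlx.core as mx

a = mx.random.normal((1024, 1024))
b = mx.random.normal((1024, 1024))
c = mx.matmul(a, b)  # lazily recorded
mx.eval(c)           # force evaluation
print(c.shape)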

This would involve a substantial amount of work to make MLX work like Triton. It could be a great project, but it'll definitely take some time to get it right. 🚀

@NripeshN (Author)


Triton is currently built around NVIDIA's CUDA (NVPTX) and AMD's ROCm backends. It would not work on other backends unless one is added (via Metal for Apple, maybe). MLX is an array framework similar to PyTorch. We could implement this entire kernel using PyTorch's native tensor operations, JIT, or even TorchScript, which would enable cross-platform support, but this could come with trade-offs such as performance.
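As a rough illustration of the pure-PyTorch route, a block-wise FP8-to-BF16 dequantization could look something like this (a hedged sketch only; the function name, the 128x128 block size, and the tensor layout are assumptions, not the repo's actual kernel):

import torch


def weight_dequant_torch(weight_fp8: torch.Tensor, scale: torch.Tensor,
                         block_size: int = 128) -> torch.Tensor:
    """Block-wise FP8 -> BF16 dequantization in plain PyTorch (illustrative only)."""
    m, n = weight_fp8.shape
    w = weight_fp8.to(torch.float32)
    # Expand each per-block scale so it covers its block_size x block_size tile.
    scale_rows = scale.repeat_interleave(block_size, dim=0)[:m]
    scale_full = scale_rows.repeat_interleave(block_size, dim=1)[:, :n]
    return (w * scale_full).to(torch.bfloat16)

On Apple silicon this would run on the MPS backend (or CPU), avoiding the Triton dependency entirely, though likely slower than a fused kernel.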

@mowentian (Contributor)

Thanks, this change is a good fit for an apple-silicon branch, because we have no Apple device to test it...
