macOS x86_64 MPS Support #118

Closed
jaredbrewer opened this issue Dec 20, 2024 · 19 comments
Comments

@jaredbrewer

This is largely an add-on to a previous issue raised for QuPath more broadly (qupath/qupath#1733), which has since been resolved.

Given that we can now expect to have access to PyTorch on macOS x64 (for the time being) in QuPath v0.6, the performance of InstanSeg can be materially improved by adding non-CPU device options to PyTorch. I note that the logic in PyTorchManager.java, at least in one location, is as follows:

            if (GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) {
                availableDevices.add("mps");
            }

This is logical, but the Metal 3 framework supports many recent AMD GPUs on macOS (up to and including the RX 6800/6900, which is only a single generation old). MPS is compatible with Intel Macs in general and should be an available device on most or all of them, which would speed up computations significantly. Metal's internal fallback is CPU inference, so exposing the option simply allows the system to determine the preferred device.

@alanocallaghan
Collaborator

Can you test whether enabling MPS in this way enables GPU acceleration in this context? I'm reluctant to change it otherwise without a testing environment, which we can't really get without substantial expense and effort.

@jaredbrewer
Author

Yes, I am happy to test this out. Are there other locations that point toward the appropriate device based on operating system/architecture? I can make a branch and evaluate whether this works.

@alanocallaghan
Collaborator

alanocallaghan commented Dec 20, 2024

No, if you make a branch and remove the System.getProperty check for aarch64, then any Mac system will have MPS enabled.
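For illustration, a minimal sketch of what that change might look like, based only on the snippet quoted above (the surrounding method and the availableDevices variable are assumed from that snippet, not taken from the full file):

    // Sketch only: offer MPS on any Mac, Apple Silicon or Intel,
    // and rely on the CPU fallback if the Metal device can't be used.
    if (GeneralTools.isMac()) {
        availableDevices.add("mps");
    }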

@jaredbrewer
Author

jaredbrewer commented Dec 20, 2024

The outcome of this is a bit strange, but perhaps you have additional insight:

  • I downloaded and compiled QuPath from the GitHub repo; PyTorch 2.2.2 downloaded fine and seems to work (great!). Note: I removed InstanSeg from the "include-extras" option.
  • Downloaded and packaged this repository, removing the architecture checks.
  • The "mps" option correctly appears in the InstanSeg dialog and runs significantly faster than the "cpu" option.

However, the MPS option reports that most of the tiles failed, an error I haven't seen before. I tried adjusting the threads and tile size, to no avail. The dialog text suggests a memory issue, but that doesn't appear to be the case based on my system monitoring, and the GPU is indeed engaged when the MPS option is selected. I can try to troubleshoot further, as this seems like a useful addition for me and likely others on "legacy" hardware.

@jaredbrewer
Author

If I run this script via the interpreter:

qupath.ext.instanseg.core.InstanSeg.builder()
    .modelPath("models/fluorescence_nuclei_and_cells")
    .device("mps")
    .inputChannels({input channels go here})
    .outputChannels()
    .tileDims(256)
    .interTilePadding(32)
    .nThreads(4)
    .makeMeasurements(false)
    .randomColors(false)
    .build()
    .detectObjects()

I get this error, which is much more helpful:

ERROR: Error in prediction
ai.djl.translate.TranslateException: ai.djl.engine.EngineException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/InstanSeg/utils/loss/instanseg_loss/___torch_mangle_3796.py", line 280, in forward
          lab = lab0
        else:
          lab1 = torch.unsqueeze(torch.stack(_23), 0)
                                 ~~~~~~~~~~~ <--- HERE
          lab = lab1
        _142 = torch.eq((torch.size(lab))[1], 2)

Traceback of TorchScript, original code (most recent call last):
  File "/home/thibaut_goldsborough/Documents/Projects/Project_InstanSeg/InstanSeg_Public/instanseg_thibaut/InstanSeg/utils/loss/instanseg_loss.py", line 1417, in forward
                    lab = labels_list[0][None, None]  # 1,1,H,W
                else:
                    lab = torch.stack(labels_list)[None] 
                          ~~~~~~~~~~~ <--- HERE
    
                if lab.shape[1] == 2 and resolve_cell_and_nucleus: #nuclei and cells
RuntimeError: torch.cat(): all input tensors must be on the same device. Received mps:0 and cpu

    at ai.djl.inference.Predictor.batchPredict(Predictor.java:197)
    at ai.djl.inference.Predictor.predict(Predictor.java:133)
    at qupath.ext.instanseg.core.TilePredictionProcessor.process(TilePredictionProcessor.java:143)
    at qupath.ext.instanseg.core.TilePredictionProcessor.process(TilePredictionProcessor.java:35)
    at qupath.lib.experimental.pixels.OpenCVProcessor$PointerScopeProcessor.process(OpenCVProcessor.java:261)
    at qupath.lib.experimental.pixels.PixelProcessor$ProcessorTask.run(PixelProcessor.java:273)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Unknown Source)
    at java.base/java.util.concurrent.FutureTask.run(Unknown Source)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Unknown Source)
    at java.base/java.util.concurrent.FutureTask.run(Unknown Source)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
    at java.base/java.lang.Thread.run(Unknown Source)
  Caused by The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/InstanSeg/utils/loss/instanseg_loss/___torch_mangle_3796.py", line 280, in forward
          lab = lab0
        else:
          lab1 = torch.unsqueeze(torch.stack(_23), 0)
                                 ~~~~~~~~~~~ <--- HERE
          lab = lab1
        _142 = torch.eq((torch.size(lab))[1], 2)

Traceback of TorchScript, original code (most recent call last):
  File "/home/thibaut_goldsborough/Documents/Projects/Project_InstanSeg/InstanSeg_Public/instanseg_thibaut/InstanSeg/utils/loss/instanseg_loss.py", line 1417, in forward
                    lab = labels_list[0][None, None]  # 1,1,H,W
                else:
                    lab = torch.stack(labels_list)[None] 
                          ~~~~~~~~~~~ <--- HERE
    
                if lab.shape[1] == 2 and resolve_cell_and_nucleus: #nuclei and cells
RuntimeError: torch.cat(): all input tensors must be on the same device. Received mps:0 and cpu
        at ai.djl.pytorch.jni.PyTorchLibrary.moduleRunMethod(Native Method)
        at ai.djl.pytorch.jni.IValueUtils.forward(IValueUtils.java:57)
        at ai.djl.pytorch.engine.PtSymbolBlock.forwardInternal(PtSymbolBlock.java:155)
        at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
        at ai.djl.nn.Block.forward(Block.java:127)
        at ai.djl.inference.Predictor.predictInternal(Predictor.java:147)
        at ai.djl.inference.Predictor.batchPredict(Predictor.java:188)
        at ai.djl.inference.Predictor.predict(Predictor.java:133)
        at qupath.ext.instanseg.core.TilePredictionProcessor.process(TilePredictionProcessor.java:143)
        at qupath.ext.instanseg.core.TilePredictionProcessor.process(TilePredictionProcessor.java:35)
        at qupath.lib.experimental.pixels.OpenCVProcessor$PointerScopeProcessor.process(OpenCVProcessor.java:261)
        at qupath.lib.experimental.pixels.PixelProcessor$ProcessorTask.run(PixelProcessor.java:273)
        at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Unknown Source)
        at java.base/java.util.concurrent.FutureTask.run(Unknown Source)
        at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Unknown Source)
        at java.base/java.util.concurrent.FutureTask.run(Unknown Source)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
        at java.base/java.lang.Thread.run(Unknown Source)

@alanocallaghan
Collaborator

Hmm, looks like this will require some changes in the model code

@jaredbrewer
Copy link
Author

jaredbrewer commented Dec 21, 2024

Looking through the InstanSeg repo, there are some locations where MPS is referenced as a device option; the question, then, is how this works on Apple Silicon Macs, where MPS is already enabled. If there are ways I can help, I would be happy to assist; the model itself seems well structured.

This passage in the referenced instanseg_loss.py looks like a potential source of the issue, since tensors are moved to CPU when MPS is detected and could later end up mixed with tensors still on mps:

                    original_device = x.device

                    if x.is_mps:
                        device = 'cpu'
                        mesh_grid_flat = mesh_grid_flat.to(device)
                        x = x.to(device)
                        mask_map = mask_map.to(device)

@petebankhead
Member

petebankhead commented Dec 21, 2024

I remember @ThibautGoldsborough had to do a lot of work to get MPS support working on Apple Silicon. It required a lot of extra moving of data between devices, because some operations were either unsupported on MPS or seemingly-supported-but-actually-buggy. Sometimes the bugs caused crashes, but sometimes they were more subtle (e.g. torch.clamp gave wrong values until 2.4.1).

I think it's only really worth trying to get this working in QuPath if it already works in Python. Therefore, if you want to investigate, I think the first task would be to see what you need to do to get the InstanSeg Python code working fully on MPS, while limited to PyTorch 2.2.2 at the very latest.

@petebankhead
Member

petebankhead commented Dec 21, 2024

Actually, possibly ignore my last comment...

I don't know what the issue is, but I can't run the fluorescence model from the repo in its current state either. The fluorescence model works with CPU only, and the brightfield model works for both CPU and MPS.

The work to introduce model versioning for QuPath v0.6.0-rc4 is still ongoing. In the current state, at least two things appear to be broken:

  1. Downloaded models aren't recognised by the extension: mode.isValid() always returns false because there is no path set. I think this is a regression introduced in Implement model versioning #113 after it was previously handled in Move model dir choice to top of dialog #69
  2. The TorchScript for the fluorescence model itself seems to be broken for MPS, somewhere in handling nuclei and cells together.

I think 2. is the case because if I run QuPath v0.6.0-rc3 and copy the model file to where the extension now expects it (i.e. inside downloaded/fluorescence_nuclei_and_cells-0.1.0), the inference works. So the extension is downloading a slightly different model file from the one used with v0.6.0-rc3.

Because of 1, I expected I'd have to copy the model every time I launch QuPath. But actually I only have to do it once; the extension seems to unnecessarily redownload it on every relaunch, but doesn't actually overwrite the contents.

Mentioning @alanocallaghan and @ThibautGoldsborough as I suspect solving both of these problems will require both of them - but it's holiday time now so it shouldn't happen any time soon.

@jaredbrewer if you want to investigate before then, I'd suggest using the brightfield model only when testing with MPS. Or...

The good news

Alternatively, it should be easier to check if you just use a script instead of the UI - then you shouldn't need to customise the extension code at all. With that, I can run

qupath.ext.instanseg.core.InstanSeg.builder()
    .modelPath("/path/to/InstanSeg models/fluorescence_nuclei_and_cells")
    .device("mps")
    .outputChannels()
    .randomColors(true)
    .build()
    .detectObjects()

Although I can only try it on Apple Silicon, I'm using an x86_64 JDK, so it should be running via Rosetta 2. I find that if I switch "mps" to "cpu", it slows down massively.
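If it helps to quantify that, here is a rough timing sketch for the script editor, reusing the builder call above (the model path is a placeholder, and switching the device string is the only change between runs):

    // Sketch only: compare wall-clock time for "mps" vs "cpu" on the same image.
    long start = System.nanoTime();
    qupath.ext.instanseg.core.InstanSeg.builder()
        .modelPath("/path/to/InstanSeg models/fluorescence_nuclei_and_cells")
        .device("mps")        // change to "cpu" for the comparison run
        .outputChannels()
        .randomColors(true)
        .build()
        .detectObjects();
    System.out.println("Elapsed: " + (System.nanoTime() - start) / 1e9 + " s");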

tbh I had no idea that MPS worked with PyTorch on non-Apple-Silicon hardware. @jaredbrewer if you can confirm it works on an actual Intel computer, then we should definitely enable the option. With a bit of luck, there are no changes required on the Python side at all (other than to fix the regression that also affects Apple Silicon).

@jaredbrewer
Author

jaredbrewer commented Dec 21, 2024

Hi @petebankhead,

Thank you for the very detailed investigation on this. And yes, of course none of this is an emergency and should wait for a more convenient time for everybody.

I can evaluate whether the brightfield model will run properly on Intel and edit this comment; I anticipate that it should, given that even in the "failed" previous attempt the GPU was successfully engaged at runtime.

Edit: Yes, I can confirm that MPS processing works completely fine with the brightfield model using an example image from GTEx (https://gtexportal.org/home/histologyPage), so it does seem to be an issue with the fluorescence model. I'm glad my issue was not so particular as to cause trouble, but instead highlighted a potential bug that may impact more users at this point (with MPS broadly, including on Apple Silicon).

@jaredbrewer
Author

This is partially tangential to the question at hand, but it is worth noting that versions up to and including the latest commits on pytorch/main are functionally compatible with x86 macOS if you compile them yourself. PyTorch deprecated support with the caveat that they would not intentionally introduce breaking changes and would accept PRs that resolve any that emerge, so long as this does not interfere with other functionality. It seems that no such breaking changes have as yet been introduced.

It raises the question of whether similar compatibility exists with DJL and might allow for a more complete resolution of some of these related issues.

@petebankhead
Member

I'm not sure I understand entirely what you're thinking. DJL itself will download PyTorch using its own links, as well as the JNI 'glue' needed to make them accessible from Java - and this too would need to be built for x86.

If DJL would support x86 macOS, then that's great. But if not (and I guess they had their reasons for dropping x86 support), I don't think we can realistically take on the task of maintaining a custom fork of DJL and creating our own builds of both PyTorch and DJL just for QuPath - and I don't see a simpler alternative to get it working.

@jaredbrewer
Author

> I'm not sure I understand entirely what you're thinking. DJL itself will download PyTorch using its own links, as well as the JNI 'glue' needed to make them accessible from Java - and this too would need to be built for x86.
>
> If DJL would support x86 macOS, then that's great. But if not (and I guess they had their reasons for dropping x86 support), I don't think we can realistically take on the task of maintaining a custom fork of DJL and creating our own builds of both PyTorch and DJL just for QuPath - and I don't see a simpler alternative to get it working.

Upon further investigation, it does appear that this is probably extraneous, but an enterprising individual can indeed compile the existing PyTorch JNI, DJL (with tests turned off to avoid the automatic download of pytorch-engine), and PyTorch 2.5.1 (or presumably another version supported by pytorch-engine 0.32) and get a macOS x86-compatible version of DJL with PyTorch support. These kinds of workarounds could be incorporated into QuPath, directly or indirectly, in the future by users needing access to models that are not backwards compatible with the last supported version on x86.

My assumption is that DJL dropped macOS x86 support because PyTorch itself did, and there would no longer be guarantees of compatibility or stability going forward. However, to the degree that Java is OS-agnostic, there does not appear to be any other technical reason that they could not continue to support macOS x86, since ONNX, JAX, and TensorFlow are still compatible with no announced deprecation plans; the recent-ish release of Keras 3.0 also allows for interchangeable execution of models and code written for different libraries.
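As an aside, if anyone does try a custom build, a small diagnostic along these lines (a sketch assuming only DJL's core Engine API is on the classpath; this is not part of the extension) can confirm which engine, version and default device DJL actually resolves at runtime:

    import ai.djl.engine.Engine;

    public class DjlDeviceCheck {
        public static void main(String[] args) {
            // Report the engine DJL resolves by default (e.g. PyTorch) and its version.
            Engine engine = Engine.getInstance();
            System.out.println("Engine: " + engine.getEngineName() + " " + engine.getVersion());
            // The default device; whether "mps" is usable depends on the native build in use.
            System.out.println("Default device: " + engine.defaultDevice());
            System.out.println("os.arch: " + System.getProperty("os.arch"));
        }
    }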

@alanocallaghan
Collaborator

> Downloaded models aren't recognised by the extension: mode.isValid() always returns false because there is no path set. I think this is a regression introduced in #113 after it was previously handled in #69

I've tried to reproduce this umpteen times and simply cannot. For me, downloading models works exactly as expected - models are only downloaded if needed, and they work fine afterwards, including distinguishing between local and downloaded models in the UI. My only minor complaints are that for remote models you need to click Download for the extension to find the files on disk, and that the message reads "Model downloading" even if we're not downloading anything, just finding the existing files on disk.

@alanocallaghan
Collaborator

alanocallaghan commented Jan 3, 2025

> The TorchScript for the fluorescence model itself seems to be broken for MPS, somewhere in handling nuclei and cells together.

I've replaced the fluorescence model used by the current instanseg extension main branch with the RC1 version (albeit with an updated semver string), so it should work on MPS again.

@petebankhead
Member

> > Downloaded models aren't recognised by the extension: mode.isValid() always returns false because there is no path set. I think this is a regression introduced in #113 after it was previously handled in #69
>
> I've tried to reproduce this umpteen times and simply cannot. For me, downloading models works exactly as expected - models are only downloaded if needed, and they work fine afterwards, including distinguishing between local and downloaded models in the UI. My only minor complaints are that for remote models you need to click Download for the extension to find the files on disk, and that the message reads "Model downloading" even if we're not downloading anything, just finding the existing files on disk.

It may be the same as the minor complaint, I just didn't find it so minor :)
I click 'Run', not Download, and then it tells me that it needs to download the model. I believed it when it said it was downloading, although I realise it might have been lying to me.

Then, for the brightfield model, it resets the channels and complains that they aren't selected... so I have to set them manually.

Both of these feel like regressions that will definitely earn complaints.

alanocallaghan added a commit to alanocallaghan/qupath-extension-instanseg that referenced this issue Jan 3, 2025
@alanocallaghan
Collaborator

Okay, that makes a lot more sense. Resolved in #119.

The channel cache I'm not sure about; I didn't implement that and haven't been able to figure out how or why it does and doesn't work.

@petebankhead
Member

I checked the PR quickly and it seems to be behaving well, so I've merged it.

petebankhead added a commit to petebankhead/qupath-extension-instanseg that referenced this issue Jan 7, 2025
Seems it isn't limited to usefulness on Apple Silicon.
See qupath#118
We can't verify this, but it does improve things substantially when running the Intel build on Apple Silicon (still required for CZI files, for example) - so is beneficial anyway.
@petebankhead
Member

I've merged #120 to exclude the system property check, so mps should be available on any Mac.

We can't check that it always works, but at least it shouldn't be made unnecessarily unavailable when it would work.

I think that should resolve the main issue here.
