From 4300366edc7cfc99b42ddf1fdd3a6d5ab3a0e682 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 21 Jun 2024 15:51:39 +0100 Subject: [PATCH 01/28] Sketch API --- build.gradle | 1 + .../qupath/ext/instanseg/core/InstanSeg.java | 28 +++++++++++++++++++ .../ext/instanseg/core/InstanSegModel.java | 20 +++++++------ .../ext/instanseg/ui/InstanSegController.java | 15 ++++++++-- .../java/qupath/ext/instanseg/ui/Watcher.java | 2 +- 5 files changed, 54 insertions(+), 12 deletions(-) create mode 100644 src/main/java/qupath/ext/instanseg/core/InstanSeg.java diff --git a/build.gradle b/build.gradle index c017a79..fa2ec0b 100644 --- a/build.gradle +++ b/build.gradle @@ -35,6 +35,7 @@ ext.qupathJavaVersion = 21 * but shouldn't be bundled up for use in the extension. */ dependencies { + implementation project(path: ':qupath-core') // Main QuPath user interface jar. // Automatically includes other QuPath jars as subdependencies. shadow "io.github.qupath:qupath-gui-fx:${qupathVersion}" diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java new file mode 100644 index 0000000..b4eee4f --- /dev/null +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -0,0 +1,28 @@ +package qupath.ext.instanseg.core; + +import qupath.lib.images.ImageData; +import qupath.lib.images.servers.ColorTransforms; +import qupath.lib.plugins.TaskRunner; +import qupath.lib.scripting.QP; + +import java.util.Collection; + +public class InstanSeg { + // todo: chainable setters, and throw warnings for any missing params...? + int tileDims = 512; + double downsample = 1; + int padding = 40; + int boundary = 20; + boolean twoChannel; + ImageData imageData; + Collection channels; + TaskRunner taskRunner; + InstanSegModel model; + + // todo: API like eg... new InstanSeg().padding(40).channels(List.of(1, 2, 3)).runInference() ? + + public void runInference() { + + } + +} diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index 08d87fa..81edccb 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -31,7 +31,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -56,10 +55,14 @@ public InstanSegModel(URL modelURL, String name) { this.name = name; } - public static InstanSegModel createModel(Path path) throws IOException { + public static InstanSegModel fromPath(Path path) throws IOException { return new InstanSegModel(BioimageIoSpec.parseModel(path.toFile())); } + public static InstanSeg fromName(String name) { + // todo: instantiate built-in models somehow! + } + public BioimageIoSpec.BioimageIoModel getModel() { if (model == null) { try { @@ -132,10 +135,10 @@ public String toString() { public void runInstanSeg( Collection pathObjects, ImageData imageData, - List channels, - int tileSize, + Collection channels, + int tileDims, double downsample, - String deviceName, + Device device, boolean nucleiOnly, TaskRunner taskRunner) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException { @@ -145,7 +148,7 @@ public void runInstanSeg( int padding = 40; // todo: setting? or just based on tile size. Should discuss. int boundary = 20; - if (tileSize == 128) { + if (tileDims == 128) { padding = 25; boundary = 15; } @@ -156,7 +159,6 @@ public void runInstanSeg( // TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant? String layoutOutput = "CHW"; - var device = Device.fromName(deviceName); try (var model = Criteria.builder() .setTypes(Mat.class, Mat.class) @@ -181,9 +183,9 @@ public void runInstanSeg( printResourceCount("Resource count after creating predictors", (BaseNDManager)baseManager.getParentManager()); - int sizeWithoutPadding = (int) Math.ceil(downsample * (tileSize - (double) padding)); + int sizeWithoutPadding = (int) Math.ceil(downsample * (tileDims - (double) padding)); var predictionProcessor = new TilePredictionProcessor(predictors, baseManager, - layout, layoutOutput, channels, tileSize, tileSize, padToInputSize); + layout, layoutOutput, channels, tileDims, tileDims, padToInputSize); var processor = OpenCVProcessor.builder(predictionProcessor) .imageSupplier((parameters) -> ImageOps.buildImageDataOp(channels).apply(parameters.getImageData(), parameters.getRegionRequest())) .tiler(Tiler.builder(sizeWithoutPadding) diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 3c2f17a..2dac02c 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -1,5 +1,6 @@ package qupath.ext.instanseg.ui; +import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.repository.zoo.ModelNotFoundException; import javafx.application.Platform; @@ -144,6 +145,16 @@ private void updateChannelPicker(ImageData imageData) { comboChannels.getCheckModel().checkIndices(IntStream.range(0, imageData.getServer().nChannels()).toArray()); } + private static void addToHistoryWorkflow(ImageData imageData) { + // todo: need to instantiate the model, then run it... + // imageData.getHistoryWorkflow() + // .addStep( + // new DefaultScriptableWorkflowStep( + // resources.getString("workflow.title"), + // WSInfer.class.getName() + ".runInference(\"" + modelName + "\")" + // )); + } + private static String getCheckComboBoxText(CheckComboBox comboBox) { int n = comboBox.getCheckModel().getCheckedItems().stream() .filter(Objects::nonNull) @@ -332,7 +343,7 @@ static void addModelsFromPath(String dir, ComboBox box) { try (var ps = Files.list(path)) { for (var file: ps.toList()) { if (InstanSegModel.isValidModel(file)) { - box.getItems().add(InstanSegModel.createModel(file)); + box.getItems().add(InstanSegModel.fromPath(file)); } } } catch (IOException e) { @@ -377,7 +388,7 @@ protected Void call() { selectedChannels, InstanSegPreferences.tileSizeProperty().get(), model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), - deviceChoices.getSelectionModel().getSelectedItem(), + Device.fromName(deviceChoices.getSelectionModel().getSelectedItem()), nucleiOnlyCheckBox.isSelected(), QPEx.createTaskRunner(InstanSegPreferences.numThreadsProperty().getValue())); } catch (ModelNotFoundException | MalformedModelException | diff --git a/src/main/java/qupath/ext/instanseg/ui/Watcher.java b/src/main/java/qupath/ext/instanseg/ui/Watcher.java index 5519970..967975b 100644 --- a/src/main/java/qupath/ext/instanseg/ui/Watcher.java +++ b/src/main/java/qupath/ext/instanseg/ui/Watcher.java @@ -88,7 +88,7 @@ void processEvents() { if (kind == ENTRY_CREATE && InstanSegModel.isValidModel(name)) { try { - modelChoiceBox.getItems().add(InstanSegModel.createModel(child)); + modelChoiceBox.getItems().add(InstanSegModel.fromPath(child)); } catch (IOException e) { logger.error("Unable to add model", e); } From db2f81eba9acb62015b5b495850f0c8a2b3b1b1d Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 21 Jun 2024 15:52:08 +0100 Subject: [PATCH 02/28] Shutup error --- src/main/java/qupath/ext/instanseg/core/InstanSegModel.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index 81edccb..f3a5e4f 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -61,6 +61,7 @@ public static InstanSegModel fromPath(Path path) throws IOException { public static InstanSeg fromName(String name) { // todo: instantiate built-in models somehow! + return null; } public BioimageIoSpec.BioimageIoModel getModel() { From 93fec877805f6222728598a610d839afcddc67b1 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 24 Jun 2024 16:16:49 +0100 Subject: [PATCH 03/28] Draft API with builder pattern --- .../qupath/ext/instanseg/core/InstanSeg.java | 154 ++++++++++++++++-- .../ext/instanseg/core/InstanSegModel.java | 21 ++- .../ext/instanseg/ui/InstanSegController.java | 49 ++++-- 3 files changed, 187 insertions(+), 37 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index b4eee4f..010779a 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -1,28 +1,156 @@ package qupath.ext.instanseg.core; +import ai.djl.Device; +import ai.djl.MalformedModelException; +import ai.djl.repository.zoo.ModelNotFoundException; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; import qupath.lib.plugins.TaskRunner; +import qupath.lib.plugins.TaskRunnerUtils; import qupath.lib.scripting.QP; +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.nio.file.Path; import java.util.Collection; +import java.util.stream.IntStream; public class InstanSeg { - // todo: chainable setters, and throw warnings for any missing params...? - int tileDims = 512; - double downsample = 1; - int padding = 40; - int boundary = 20; - boolean twoChannel; - ImageData imageData; - Collection channels; - TaskRunner taskRunner; - InstanSegModel model; + private int tileDims; + private double downsample; + private int padding; + private int boundary; + private int numOutputChannels; + private ImageData imageData; + private Collection channels; + private TaskRunner taskRunner; + private InstanSegModel model; + private Device device; - // todo: API like eg... new InstanSeg().padding(40).channels(List.of(1, 2, 3)).runInference() ? - - public void runInference() { + public static Builder builder() { + return new Builder(); + } + public void detectObjects() throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException { + model.runInstanSeg( + QP.getSelectedObjects(), + imageData, + channels, + tileDims, + downsample, + padding, + boundary, + device, + numOutputChannels == 1, + taskRunner + ); } + public static final class Builder { + private int tileDims = 512; + private double downsample = 1; + private int padding = 40; + private int boundary = 20; + private int numOutputChannels = 2; + private ImageData imageData = QP.getCurrentImageData(); + // todo: set default? + private Collection channels; + private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(); + private Device device; + private InstanSegModel model; + + Builder() {} + + public Builder tileDims(int tileDims) { + this.tileDims = tileDims; + return this; + } + + public Builder downsample(double downsample) { + this.downsample = downsample; + return this; + } + + public Builder interTilePadding(int padding) { + this.padding = padding; + return this; + } + + public Builder tileBoundary(int boundary) { + this.boundary = boundary; + return this; + } + + public Builder numOutputChannels(int numOutputChannels) { + this.numOutputChannels = numOutputChannels; + return this; + } + + public Builder imageData(ImageData imageData) { + this.imageData = imageData; + return this; + } + + public Builder channels(Collection channels) { + this.channels = channels; + return this; + } + + public Builder allChannels() { + // todo: lazy eval this? + return channelIndices(IntStream.of(imageData.getServer().nChannels()).boxed().toList()); + } + + public Builder channelIndices(Collection channels) { + this.channels = channels.stream() + .map(ColorTransforms::createChannelExtractor) + .toList(); + return this; + } + + public Builder channelNames(Collection channels) { + this.channels = channels.stream() + .map(ColorTransforms::createChannelExtractor) + .toList(); + return this; + } + + public Builder taskRunner(TaskRunner taskRunner) { + this.taskRunner = taskRunner; + return this; + } + + public Builder model(InstanSegModel model) { + this.model = model; + return this; + } + + public Builder modelPath(Path path) throws IOException { + return model(InstanSegModel.fromPath(path)); + } + + public Builder modelName(String name) { + return model(InstanSegModel.fromName(name)); + } + + public Builder device(String device) { + this.device = Device.fromName(device); + return this; + } + + public InstanSeg build() { + InstanSeg instanSeg = new InstanSeg(); + instanSeg.channels = this.channels; + instanSeg.taskRunner = this.taskRunner; + instanSeg.numOutputChannels = this.numOutputChannels; + instanSeg.tileDims = this.tileDims; + instanSeg.boundary = this.boundary; + instanSeg.model = this.model; + instanSeg.padding = this.padding; + instanSeg.downsample = this.downsample; + instanSeg.imageData = this.imageData; + instanSeg.device = this.device; + return instanSeg; + } + } } diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index f3a5e4f..8ec2f96 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -59,8 +59,8 @@ public static InstanSegModel fromPath(Path path) throws IOException { return new InstanSegModel(BioimageIoSpec.parseModel(path.toFile())); } - public static InstanSeg fromName(String name) { - // todo: instantiate built-in models somehow! + public static InstanSegModel fromName(String name) { + // todo: instantiate built-in models somehow return null; } @@ -133,12 +133,14 @@ public String toString() { return getName(); } - public void runInstanSeg( + void runInstanSeg( Collection pathObjects, ImageData imageData, Collection channels, int tileDims, double downsample, + int padding, + int boundary, Device device, boolean nucleiOnly, TaskRunner taskRunner) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException { @@ -147,12 +149,13 @@ public void runInstanSeg( Path modelPath = getPath().resolve("instanseg.pt"); int nPredictors = 1; // todo: change me? - int padding = 40; // todo: setting? or just based on tile size. Should discuss. - int boundary = 20; - if (tileDims == 128) { - padding = 25; - boundary = 15; - } + // int padding = 40; // todo: setting? or just based on tile size. Should discuss. + // int boundary = 20; + // if (tileDims == 128) { + // padding = 25; + // boundary = 15; + // } + // Optionally pad images to the required size boolean padToInputSize = true; String layout = "CHW"; diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 2dac02c..79903d6 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -1,6 +1,5 @@ package qupath.ext.instanseg.ui; -import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.repository.zoo.ModelNotFoundException; import javafx.application.Platform; @@ -30,6 +29,7 @@ import org.controlsfx.control.SearchableComboBox; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import qupath.ext.instanseg.core.InstanSeg; import qupath.ext.instanseg.core.InstanSegModel; import qupath.fx.dialogs.Dialogs; import qupath.fx.dialogs.FileChoosers; @@ -42,6 +42,7 @@ import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; import qupath.lib.images.servers.ImageServer; +import qupath.lib.plugins.workflow.DefaultScriptableWorkflowStep; import qupath.lib.scripting.QP; import java.awt.image.BufferedImage; @@ -147,12 +148,7 @@ private void updateChannelPicker(ImageData imageData) { private static void addToHistoryWorkflow(ImageData imageData) { // todo: need to instantiate the model, then run it... - // imageData.getHistoryWorkflow() - // .addStep( - // new DefaultScriptableWorkflowStep( - // resources.getString("workflow.title"), - // WSInfer.class.getName() + ".runInference(\"" + modelName + "\")" - // )); + } private static String getCheckComboBoxText(CheckComboBox comboBox) { @@ -367,6 +363,7 @@ private void runInstanSeg() { var model = modelChoiceBox.getSelectionModel().getSelectedItem(); ImageServer server = qupath.getImageData().getServer(); + // todo: how to record this in workflow? List selectedChannels = comboChannels .getCheckModel().getCheckedItems() .stream() @@ -382,15 +379,37 @@ protected Void call() { downloadPyTorch(); } try { - model.runInstanSeg( - QP.getSelectedObjects(), - QP.getCurrentImageData(), - selectedChannels, + String cmd = String.format(""" + def instanSeg = InstanSeg.builder() + .modelPath("%s") + .device("%s") + .numOutputChannels(%d) + .channels(selectedChannels) + .tileDims(%d) + .downsample(%f) + .build(); + """, + model.getPath(), + deviceChoices.getSelectionModel().getSelectedItem(), + nucleiOnlyCheckBox.isSelected() ? 1:2, + // todo: channels, InstanSegPreferences.tileSizeProperty().get(), - model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), - Device.fromName(deviceChoices.getSelectionModel().getSelectedItem()), - nucleiOnlyCheckBox.isSelected(), - QPEx.createTaskRunner(InstanSegPreferences.numThreadsProperty().getValue())); + model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize() + ); + QP.getCurrentImageData().getHistoryWorkflow() + .addStep( + new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) + ); + var instanSeg = InstanSeg.builder() + .model(model) // todo: set this in workflow somehow + .device(deviceChoices.getSelectionModel().getSelectedItem()) + .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 1:2) + .channels(selectedChannels) + .tileDims(InstanSegPreferences.tileSizeProperty().get()) + .downsample(model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize()) + .taskRunner(QPEx.createTaskRunner(InstanSegPreferences.numThreadsProperty().getValue())) + .build(); + instanSeg.detectObjects(); } catch (ModelNotFoundException | MalformedModelException | IOException | InterruptedException e) { Dialogs.showErrorMessage("Unable to run InstanSeg", e); From 512c6123c4477946bc2521ac72cc7bdcb536748b Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 25 Jun 2024 14:06:21 +0100 Subject: [PATCH 04/28] Hacky way to instantiate channel selectors --- .../qupath/ext/instanseg/core/InstanSeg.java | 95 +++++++++++++------ .../ext/instanseg/ui/ChannelSelectItem.java | 27 +++++- .../ext/instanseg/ui/InstanSegController.java | 43 +++++---- 3 files changed, 115 insertions(+), 50 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 010779a..d8a57fc 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -3,10 +3,9 @@ import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.repository.zoo.ModelNotFoundException; +import qupath.lib.gui.scripting.QPEx; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; -import qupath.lib.plugins.TaskRunner; -import qupath.lib.plugins.TaskRunnerUtils; import qupath.lib.scripting.QP; import java.awt.image.BufferedImage; @@ -16,16 +15,17 @@ import java.util.stream.IntStream; public class InstanSeg { - private int tileDims; - private double downsample; - private int padding; - private int boundary; - private int numOutputChannels; - private ImageData imageData; - private Collection channels; - private TaskRunner taskRunner; - private InstanSegModel model; - private Device device; + private final int tileDims; + private final double downsample; + private final int padding; + private final int boundary; + private final int numOutputChannels; + private final ImageData imageData; + private final Collection channels; + private final int nThreads; + private final InstanSegModel model; + private final Device device; + public static Builder builder() { return new Builder(); @@ -42,7 +42,7 @@ public void detectObjects() throws ModelNotFoundException, MalformedModelExcepti boundary, device, numOutputChannels == 1, - taskRunner + QPEx.createTaskRunner(nThreads) ); } @@ -53,11 +53,10 @@ public static final class Builder { private int boundary = 20; private int numOutputChannels = 2; private ImageData imageData = QP.getCurrentImageData(); - // todo: set default? private Collection channels; - private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(); private Device device; private InstanSegModel model; + private int nThreads = 1; Builder() {} @@ -91,6 +90,11 @@ public Builder imageData(ImageData imageData) { return this; } + public Builder currentImageData() { + this.imageData = QP.getCurrentImageData(); + return this; + } + public Builder channels(Collection channels) { this.channels = channels; return this; @@ -115,8 +119,8 @@ public Builder channelNames(Collection channels) { return this; } - public Builder taskRunner(TaskRunner taskRunner) { - this.taskRunner = taskRunner; + public Builder nThreads(int nThreads) { + this.nThreads = nThreads; return this; } @@ -133,24 +137,53 @@ public Builder modelName(String name) { return model(InstanSegModel.fromName(name)); } - public Builder device(String device) { - this.device = Device.fromName(device); + public Builder device(String deviceName) { + this.device = Device.fromName(deviceName); + return this; + } + + public Builder device(Device device) { + this.device = device; return this; } public InstanSeg build() { - InstanSeg instanSeg = new InstanSeg(); - instanSeg.channels = this.channels; - instanSeg.taskRunner = this.taskRunner; - instanSeg.numOutputChannels = this.numOutputChannels; - instanSeg.tileDims = this.tileDims; - instanSeg.boundary = this.boundary; - instanSeg.model = this.model; - instanSeg.padding = this.padding; - instanSeg.downsample = this.downsample; - instanSeg.imageData = this.imageData; - instanSeg.device = this.device; - return instanSeg; + if (imageData == null) { + this.currentImageData(); + } + if (channels == null) { + allChannels(); + } + if (device == null) { + device("cpu"); + } + return new InstanSeg( + this.tileDims, + this.downsample, + this.padding, + this.boundary, + this.numOutputChannels, + this.imageData, + this.channels, + this.nThreads, + this.model, + this.device + ); } + + } + + private InstanSeg(int tileDims, double downsample, int padding, int boundary, int numOutputChannels, ImageData imageData, + Collection channels, int nThreads, InstanSegModel model, Device device) { + this.tileDims = tileDims; + this.downsample = downsample; + this.padding = padding; + this.boundary = boundary; + this.numOutputChannels = numOutputChannels; + this.imageData = imageData; + this.channels = channels; + this.nThreads = nThreads; + this.model = model; + this.device = device; } } diff --git a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java index b5f2dcf..e5e55ae 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java @@ -1,7 +1,11 @@ package qupath.ext.instanseg.ui; +import qupath.lib.color.ColorDeconvolutionStains; import qupath.lib.images.servers.ColorTransforms; +import java.util.Collection; +import java.util.stream.Collectors; + /** * Super simple class to deal with channel selection dropdown items that have different display and selection names. * e.g., the first channel in non-RGB images is shown as "Channel 1 (C1)" but the actual name is "Channel 1". @@ -9,14 +13,25 @@ class ChannelSelectItem { private final String name; private final ColorTransforms.ColorTransform transform; + private final String constructor; + + // todo: public method to get a constructor for the colortransform ChannelSelectItem(String name) { this.name = name; this.transform = ColorTransforms.createChannelExtractor(name); + this.constructor = String.format("ColorTransforms.createChannelExtractor(\"%s\")", name); } - ChannelSelectItem(String name, ColorTransforms.ColorTransform transform) { + ChannelSelectItem(String name, int i) { this.name = name; - this.transform = transform; + this.transform = ColorTransforms.createChannelExtractor(i); + this.constructor = String.format("ColorTransforms.createChannelExtractor(%d)", i); + } + + ChannelSelectItem(ColorDeconvolutionStains stains, int i) { + this.name = stains.getStain(i).getName(); + this.transform = ColorTransforms.createColorDeconvolvedChannel(stains, i); + this.constructor = String.format("ColorTransforms.createColorDeconvolvedChannel(stains, %d)", i); } @Override @@ -31,4 +46,12 @@ public String getName() { public ColorTransforms.ColorTransform getTransform() { return transform; } + + public String getConstructor() { + return this.constructor; + } + + public static String toConstructorString(Collection items) { + return "List.of(" + items.stream().map(ChannelSelectItem::getConstructor).collect(Collectors.joining(", ")) + ")"; + } } diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 79903d6..6964986 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -37,10 +37,8 @@ import qupath.lib.common.ThreadTools; import qupath.lib.display.ChannelDisplayInfo; import qupath.lib.gui.QuPathGUI; -import qupath.lib.gui.scripting.QPEx; import qupath.lib.gui.tools.GuiTools; import qupath.lib.images.ImageData; -import qupath.lib.images.servers.ColorTransforms; import qupath.lib.images.servers.ImageServer; import qupath.lib.plugins.workflow.DefaultScriptableWorkflowStep; import qupath.lib.scripting.QP; @@ -172,6 +170,11 @@ private void addSetFromVisible(CheckComboBox comboChannels) { var channelNames = activeChannels.stream() .map(ChannelDisplayInfo::getName) .toList(); + if (qupath.getImageData() != null && !qupath.getImageData().getServer().isRGB()) { + channelNames = channelNames.stream() + .map(s -> s.replaceAll(" \\(C\\d+\\)$", "")) + .toList(); + } var comboItems = comboChannels.getItems(); for (int i = 0; i < comboItems.size(); i++) { if (channelNames.contains(comboItems.get(i).getName())) { @@ -197,29 +200,30 @@ private static Collection getAvailableChannels(ImageData i var server = imageData.getServer(); int i = 1; boolean hasDuplicates = false; + ChannelSelectItem item; for (var channel : server.getMetadata().getChannels()) { var name = channel.getName(); - var transform = ColorTransforms.createChannelExtractor(name); if (names.contains(name)) { logger.warn("Found duplicate channel name! Channel " + i + " (name '" + name + "')."); logger.warn("Using channel indices instead of names because of duplicated channel names."); hasDuplicates = true; } names.add(name); + // if (!server.isRGB()) { + // name += " (C" + i + ")"; + // } if (hasDuplicates) { - transform = ColorTransforms.createChannelExtractor(i - 1); - } - if (!server.isRGB()) { - name += " (C" + i + ")"; + item = new ChannelSelectItem(name, i - 1); + } else { + item = new ChannelSelectItem(name); } - list.add(new ChannelSelectItem(name, transform)); + list.add(item); i++; } var stains = imageData.getColorDeconvolutionStains(); if (stains != null) { for (i = 1; i < 4; i++) { - var transform = ColorTransforms.createColorDeconvolvedChannel(stains, i); - list.add(new ChannelSelectItem(transform.getName(), transform)); + list.add(new ChannelSelectItem(stains, i)); } } return list; @@ -364,11 +368,11 @@ private void runInstanSeg() { var model = modelChoiceBox.getSelectionModel().getSelectedItem(); ImageServer server = qupath.getImageData().getServer(); // todo: how to record this in workflow? - List selectedChannels = comboChannels + List selectedChannels = comboChannels .getCheckModel().getCheckedItems() .stream() .filter(Objects::nonNull) - .map(ChannelSelectItem::getTransform) + // .map(ChannelSelectItem::getTransform) .toList(); var task = new Task() { @@ -380,34 +384,39 @@ protected Void call() { } try { String cmd = String.format(""" + var channels = %s; def instanSeg = InstanSeg.builder() .modelPath("%s") .device("%s") .numOutputChannels(%d) - .channels(selectedChannels) + .channels(channels) .tileDims(%d) .downsample(%f) + .nthreads(%d) .build(); + instanSeg.detectObjects(); """, + ChannelSelectItem.toConstructorString(selectedChannels), model.getPath(), deviceChoices.getSelectionModel().getSelectedItem(), nucleiOnlyCheckBox.isSelected() ? 1:2, // todo: channels, InstanSegPreferences.tileSizeProperty().get(), - model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize() + model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), + InstanSegPreferences.numThreadsProperty().getValue() ); QP.getCurrentImageData().getHistoryWorkflow() .addStep( new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) ); var instanSeg = InstanSeg.builder() - .model(model) // todo: set this in workflow somehow + .model(model) .device(deviceChoices.getSelectionModel().getSelectedItem()) .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 1:2) - .channels(selectedChannels) + .channels(selectedChannels.stream().map(ChannelSelectItem::getTransform).toList()) .tileDims(InstanSegPreferences.tileSizeProperty().get()) .downsample(model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize()) - .taskRunner(QPEx.createTaskRunner(InstanSegPreferences.numThreadsProperty().getValue())) + .nThreads(InstanSegPreferences.numThreadsProperty().getValue()) .build(); instanSeg.detectObjects(); } catch (ModelNotFoundException | MalformedModelException | From 616829da0e737e74d61ddf54c16032e000d84d9a Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 25 Jun 2024 14:07:39 +0100 Subject: [PATCH 05/28] Remove qupath-core from build --- build.gradle | 1 - 1 file changed, 1 deletion(-) diff --git a/build.gradle b/build.gradle index fa2ec0b..c017a79 100644 --- a/build.gradle +++ b/build.gradle @@ -35,7 +35,6 @@ ext.qupathJavaVersion = 21 * but shouldn't be bundled up for use in the extension. */ dependencies { - implementation project(path: ':qupath-core') // Main QuPath user interface jar. // Automatically includes other QuPath jars as subdependencies. shadow "io.github.qupath:qupath-gui-fx:${qupathVersion}" From b255fa768d0b16c58de024078233d6309d11eb44 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 25 Jun 2024 14:21:37 +0100 Subject: [PATCH 06/28] Add jdoc and make less public --- .../qupath/ext/instanseg/core/InstanSeg.java | 91 ++++++++++++++++++- .../ext/instanseg/core/InstanSegModel.java | 10 +- .../ext/instanseg/ui/ChannelSelectItem.java | 8 +- .../ext/instanseg/ui/InstanSegController.java | 86 +++++++----------- 4 files changed, 127 insertions(+), 68 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index d8a57fc..687bd24 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -1,8 +1,6 @@ package qupath.ext.instanseg.core; import ai.djl.Device; -import ai.djl.MalformedModelException; -import ai.djl.repository.zoo.ModelNotFoundException; import qupath.lib.gui.scripting.QPEx; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; @@ -31,7 +29,7 @@ public static Builder builder() { return new Builder(); } - public void detectObjects() throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException { + public void detectObjects() { model.runInstanSeg( QP.getSelectedObjects(), imageData, @@ -60,51 +58,99 @@ public static final class Builder { Builder() {} + /** + * Set the width and height of tiles + * @param tileDims The tile width and height + * @return A modified builder + */ public Builder tileDims(int tileDims) { this.tileDims = tileDims; return this; } + /** + * Set the downsample to be used in region requests + * @param downsample The downsample to be used + * @return A modified builder + */ public Builder downsample(double downsample) { this.downsample = downsample; return this; } + /** + * Set the padding (overlap) between tiles + * @param padding The extra size added to tiles to allow overlap + * @return A modified builder + */ public Builder interTilePadding(int padding) { this.padding = padding; return this; } + /** + * Set the size of the overlap region between tiles + * @param boundary The width in pixels that overlaps between tiles + * @return A modified builder + */ public Builder tileBoundary(int boundary) { this.boundary = boundary; return this; } + /** + * Set the number of output channels + * @param numOutputChannels The number of output channels (1 or 2 currently) + * @return A modified builder + */ public Builder numOutputChannels(int numOutputChannels) { this.numOutputChannels = numOutputChannels; return this; } + /** + * Set the imageData to be used + * @param imageData An imageData instance + * @return A modified builder + */ public Builder imageData(ImageData imageData) { this.imageData = imageData; return this; } + /** + * Set the imageData to the currently visible one + * @return A modified builder + */ public Builder currentImageData() { this.imageData = QP.getCurrentImageData(); return this; } + /** + * Set the channels to be used in inference + * @param channels A collection of channels to be used in inference + * @return A modified builder + */ public Builder channels(Collection channels) { this.channels = channels; return this; } + /** + * Set the model to use all channels for inference + * @return A modified builder + */ public Builder allChannels() { // todo: lazy eval this? return channelIndices(IntStream.of(imageData.getServer().nChannels()).boxed().toList()); } + /** + * Set the channels using indices + * @param channels Integers used to specify the channels used + * @return A modified builder + */ public Builder channelIndices(Collection channels) { this.channels = channels.stream() .map(ColorTransforms::createChannelExtractor) @@ -112,6 +158,11 @@ public Builder channelIndices(Collection channels) { return this; } + /** + * Set the channel names to be used + * @param channels A set of channel names + * @return A modified builder + */ public Builder channelNames(Collection channels) { this.channels = channels.stream() .map(ColorTransforms::createChannelExtractor) @@ -119,34 +170,68 @@ public Builder channelNames(Collection channels) { return this; } + /** + * Set the number of threads used + * @param nThreads The number of threads to be used + * @return A modified builder + */ public Builder nThreads(int nThreads) { this.nThreads = nThreads; return this; } + /** + * Set the specific model to be used + * @param model An already instantiated InstanSeg model. + * @return A modified builder + */ public Builder model(InstanSegModel model) { this.model = model; return this; } + /** + * Set the specific model by path + * @param path A path on disk to create an InstanSeg model from. + * @return A modified builder + */ public Builder modelPath(Path path) throws IOException { return model(InstanSegModel.fromPath(path)); } + /** + * Set the specific model to be used + * @param name The name of a built-in model + * @return A modified builder + */ public Builder modelName(String name) { return model(InstanSegModel.fromName(name)); } + /** + * Set the device to be used + * @param deviceName The name of the device to be used (eg, "gpu", "mps"). + * @return A modified builder + */ public Builder device(String deviceName) { this.device = Device.fromName(deviceName); return this; } + /** + * Set the device to be used + * @param device The {@link Device} to be used + * @return A modified builder + */ public Builder device(Device device) { this.device = device; return this; } + /** + * Build the InstanSeg instance. + * @return An InstanSeg instance ready for object detection. + */ public InstanSeg build() { if (imageData == null) { this.currentImageData(); diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index 8ec2f96..b353bba 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -1,11 +1,9 @@ package qupath.ext.instanseg.core; import ai.djl.Device; -import ai.djl.MalformedModelException; import ai.djl.inference.Predictor; import ai.djl.ndarray.BaseNDManager; import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.training.util.ProgressBar; import com.google.gson.internal.LinkedTreeMap; import org.bytedeco.opencv.opencv_core.Mat; @@ -143,18 +141,12 @@ void runInstanSeg( int boundary, Device device, boolean nucleiOnly, - TaskRunner taskRunner) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException { + TaskRunner taskRunner) { nFailed = 0; Path modelPath = getPath().resolve("instanseg.pt"); int nPredictors = 1; // todo: change me? - // int padding = 40; // todo: setting? or just based on tile size. Should discuss. - // int boundary = 20; - // if (tileDims == 128) { - // padding = 25; - // boundary = 15; - // } // Optionally pad images to the required size boolean padToInputSize = true; diff --git a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java index e5e55ae..c0c311a 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java @@ -39,19 +39,19 @@ public String toString() { return this.name; } - public String getName() { + String getName() { return name; } - public ColorTransforms.ColorTransform getTransform() { + ColorTransforms.ColorTransform getTransform() { return transform; } - public String getConstructor() { + private String getConstructor() { return this.constructor; } - public static String toConstructorString(Collection items) { + static String toConstructorString(Collection items) { return "List.of(" + items.stream().map(ChannelSelectItem::getConstructor).collect(Collectors.joining(", ")) + ")"; } } diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 6964986..6ca94d1 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -1,7 +1,5 @@ package qupath.ext.instanseg.ui; -import ai.djl.MalformedModelException; -import ai.djl.repository.zoo.ModelNotFoundException; import javafx.application.Platform; import javafx.beans.binding.Bindings; import javafx.beans.property.ObjectProperty; @@ -144,11 +142,6 @@ private void updateChannelPicker(ImageData imageData) { comboChannels.getCheckModel().checkIndices(IntStream.range(0, imageData.getServer().nChannels()).toArray()); } - private static void addToHistoryWorkflow(ImageData imageData) { - // todo: need to instantiate the model, then run it... - - } - private static String getCheckComboBoxText(CheckComboBox comboBox) { int n = comboBox.getCheckModel().getCheckedItems().stream() .filter(Objects::nonNull) @@ -209,9 +202,6 @@ private static Collection getAvailableChannels(ImageData i hasDuplicates = true; } names.add(name); - // if (!server.isRGB()) { - // name += " (C" + i + ")"; - // } if (hasDuplicates) { item = new ChannelSelectItem(name, i - 1); } else { @@ -372,7 +362,6 @@ private void runInstanSeg() { .getCheckModel().getCheckedItems() .stream() .filter(Objects::nonNull) - // .map(ChannelSelectItem::getTransform) .toList(); var task = new Task() { @@ -382,48 +371,41 @@ protected Void call() { if (!PytorchManager.hasPyTorchEngine()) { downloadPyTorch(); } - try { - String cmd = String.format(""" - var channels = %s; - def instanSeg = InstanSeg.builder() - .modelPath("%s") - .device("%s") - .numOutputChannels(%d) - .channels(channels) - .tileDims(%d) - .downsample(%f) - .nthreads(%d) - .build(); - instanSeg.detectObjects(); - """, - ChannelSelectItem.toConstructorString(selectedChannels), - model.getPath(), - deviceChoices.getSelectionModel().getSelectedItem(), - nucleiOnlyCheckBox.isSelected() ? 1:2, - // todo: channels, - InstanSegPreferences.tileSizeProperty().get(), - model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), - InstanSegPreferences.numThreadsProperty().getValue() - ); - QP.getCurrentImageData().getHistoryWorkflow() - .addStep( - new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) - ); - var instanSeg = InstanSeg.builder() - .model(model) - .device(deviceChoices.getSelectionModel().getSelectedItem()) - .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 1:2) - .channels(selectedChannels.stream().map(ChannelSelectItem::getTransform).toList()) - .tileDims(InstanSegPreferences.tileSizeProperty().get()) - .downsample(model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize()) - .nThreads(InstanSegPreferences.numThreadsProperty().getValue()) + String cmd = String.format(""" + var channels = %s; + def instanSeg = InstanSeg.builder() + .modelPath("%s") + .device("%s") + .numOutputChannels(%d) + .channels(channels) + .tileDims(%d) + .downsample(%f) + .nthreads(%d) .build(); - instanSeg.detectObjects(); - } catch (ModelNotFoundException | MalformedModelException | - IOException | InterruptedException e) { - Dialogs.showErrorMessage("Unable to run InstanSeg", e); - logger.error("Unable to run InstanSeg", e); - } + instanSeg.detectObjects(); + """, + ChannelSelectItem.toConstructorString(selectedChannels), + model.getPath(), + deviceChoices.getSelectionModel().getSelectedItem(), + nucleiOnlyCheckBox.isSelected() ? 1:2, + InstanSegPreferences.tileSizeProperty().get(), + model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), + InstanSegPreferences.numThreadsProperty().getValue() + ); + QP.getCurrentImageData().getHistoryWorkflow() + .addStep( + new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) + ); + var instanSeg = InstanSeg.builder() + .model(model) + .device(deviceChoices.getSelectionModel().getSelectedItem()) + .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 1:2) + .channels(selectedChannels.stream().map(ChannelSelectItem::getTransform).toList()) + .tileDims(InstanSegPreferences.tileSizeProperty().get()) + .downsample(model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize()) + .nThreads(InstanSegPreferences.numThreadsProperty().getValue()) + .build(); + instanSeg.detectObjects(); QP.fireHierarchyUpdate(); if (model.nFailed() > 0) { var errorMessage = String.format(resources.getString("error.tiles-failed"), model.nFailed()); From 91038c422c7513db3fe5692492220545cd2bdf92 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 25 Jun 2024 14:29:54 +0100 Subject: [PATCH 07/28] Add overloaded path fun, and fix typo in thread spec --- src/main/java/qupath/ext/instanseg/core/InstanSeg.java | 9 +++++++++ .../qupath/ext/instanseg/ui/InstanSegController.java | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 687bd24..369cafc 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -199,6 +199,15 @@ public Builder modelPath(Path path) throws IOException { return model(InstanSegModel.fromPath(path)); } + /** + * Set the specific model by path + * @param path A path on disk to create an InstanSeg model from. + * @return A modified builder + */ + public Builder modelPath(String path) throws IOException { + return modelPath(Path.of(path)); + } + /** * Set the specific model to be used * @param name The name of a built-in model diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 6ca94d1..6f41908 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -380,7 +380,7 @@ protected Void call() { .channels(channels) .tileDims(%d) .downsample(%f) - .nthreads(%d) + .nThreads(%d) .build(); instanSeg.detectObjects(); """, From dbbaec9553c187c8bfae929ca2a9c8d55def7c8c Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 25 Jun 2024 16:28:31 +0100 Subject: [PATCH 08/28] Simpler groovy --- src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java | 2 +- src/main/java/qupath/ext/instanseg/ui/InstanSegController.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java index c0c311a..853781d 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java @@ -52,6 +52,6 @@ private String getConstructor() { } static String toConstructorString(Collection items) { - return "List.of(" + items.stream().map(ChannelSelectItem::getConstructor).collect(Collectors.joining(", ")) + ")"; + return "[" + items.stream().map(ChannelSelectItem::getConstructor).collect(Collectors.joining(", ")) + "]"; } } diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 6f41908..64ef672 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -372,7 +372,7 @@ protected Void call() { downloadPyTorch(); } String cmd = String.format(""" - var channels = %s; + def channels = %s; def instanSeg = InstanSeg.builder() .modelPath("%s") .device("%s") From 8a2181c27b67569e15fb3730f966d8e0adc3815a Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 25 Jun 2024 16:44:43 +0100 Subject: [PATCH 09/28] Add varargs versions --- .../qupath/ext/instanseg/core/InstanSeg.java | 42 ++++++++++++++++++- .../ext/instanseg/ui/ChannelSelectItem.java | 2 +- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 369cafc..490ea5d 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -9,7 +9,9 @@ import java.awt.image.BufferedImage; import java.io.IOException; import java.nio.file.Path; +import java.util.Arrays; import java.util.Collection; +import java.util.List; import java.util.stream.IntStream; public class InstanSeg { @@ -137,13 +139,25 @@ public Builder channels(Collection channels) { return this; } + /** + * Set the channels to be used in inference + * @param channels Channels to be used in inference + * @return A modified builder + */ + public Builder channels(ColorTransforms.ColorTransform... channels) { + this.channels = List.of(channels); + return this; + } + /** * Set the model to use all channels for inference * @return A modified builder */ public Builder allChannels() { - // todo: lazy eval this? - return channelIndices(IntStream.of(imageData.getServer().nChannels()).boxed().toList()); + return channelIndices( + IntStream.of(imageData.getServer().nChannels()) + .boxed() + .toList()); } /** @@ -158,6 +172,18 @@ public Builder channelIndices(Collection channels) { return this; } + /** + * Set the channels using indices + * @param channels Integers used to specify the channels used + * @return A modified builder + */ + public Builder channelIndices(int... channels) { + this.channels = Arrays.stream(channels).boxed() + .map(ColorTransforms::createChannelExtractor) + .toList(); + return this; + } + /** * Set the channel names to be used * @param channels A set of channel names @@ -170,6 +196,18 @@ public Builder channelNames(Collection channels) { return this; } + /** + * Set the channel names to be used + * @param channels A set of channel names + * @return A modified builder + */ + public Builder channelNames(String... channels) { + this.channels = Arrays.stream(channels) + .map(ColorTransforms::createChannelExtractor) + .toList(); + return this; + } + /** * Set the number of threads used * @param nThreads The number of threads to be used diff --git a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java index 853781d..af67901 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java @@ -47,7 +47,7 @@ ColorTransforms.ColorTransform getTransform() { return transform; } - private String getConstructor() { + String getConstructor() { return this.constructor; } From fcc6d7548559a46e2b6b46f3af17f7b190e35472 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Wed, 26 Jun 2024 03:51:59 +0100 Subject: [PATCH 10/28] Add import --- src/main/java/qupath/ext/instanseg/ui/InstanSegController.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 64ef672..342c190 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -372,6 +372,8 @@ protected Void call() { downloadPyTorch(); } String cmd = String.format(""" + import qupath.ext.instanseg.core.InstanSeg + def channels = %s; def instanSeg = InstanSeg.builder() .modelPath("%s") From e525f57a46eb6fd8b5054525fdce6edc8a6ebecc Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Wed, 26 Jun 2024 15:41:39 +0100 Subject: [PATCH 11/28] Remove QP uses --- .../qupath/ext/instanseg/core/InstanSeg.java | 16 +++------------- .../ext/instanseg/core/InstanSegModel.java | 4 ++-- .../ext/instanseg/ui/InstanSegController.java | 13 ++++++++----- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 490ea5d..c568122 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -4,7 +4,6 @@ import qupath.lib.gui.scripting.QPEx; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; -import qupath.lib.scripting.QP; import java.awt.image.BufferedImage; import java.io.IOException; @@ -33,8 +32,8 @@ public static Builder builder() { public void detectObjects() { model.runInstanSeg( - QP.getSelectedObjects(), imageData, + imageData.getHierarchy().getSelectionModel().getSelectedObjects(), channels, tileDims, downsample, @@ -52,7 +51,7 @@ public static final class Builder { private int padding = 40; private int boundary = 20; private int numOutputChannels = 2; - private ImageData imageData = QP.getCurrentImageData(); + private ImageData imageData; private Collection channels; private Device device; private InstanSegModel model; @@ -120,15 +119,6 @@ public Builder imageData(ImageData imageData) { return this; } - /** - * Set the imageData to the currently visible one - * @return A modified builder - */ - public Builder currentImageData() { - this.imageData = QP.getCurrentImageData(); - return this; - } - /** * Set the channels to be used in inference * @param channels A collection of channels to be used in inference @@ -281,7 +271,7 @@ public Builder device(Device device) { */ public InstanSeg build() { if (imageData == null) { - this.currentImageData(); + throw new IllegalStateException("imageData cannot be null!"); } if (channels == null) { allChannels(); diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index b353bba..60aa782 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -59,7 +59,7 @@ public static InstanSegModel fromPath(Path path) throws IOException { public static InstanSegModel fromName(String name) { // todo: instantiate built-in models somehow - return null; + throw new UnsupportedOperationException("Fetching models by name is not yet implemented!"); } public BioimageIoSpec.BioimageIoModel getModel() { @@ -132,8 +132,8 @@ public String toString() { } void runInstanSeg( - Collection pathObjects, ImageData imageData, + Collection pathObjects, Collection channels, int tileDims, double downsample, diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 342c190..69b7414 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -39,7 +39,6 @@ import qupath.lib.images.ImageData; import qupath.lib.images.servers.ImageServer; import qupath.lib.plugins.workflow.DefaultScriptableWorkflowStep; -import qupath.lib.scripting.QP; import java.awt.image.BufferedImage; import java.io.File; @@ -381,6 +380,7 @@ protected Void call() { .numOutputChannels(%d) .channels(channels) .tileDims(%d) + .imageData(QP.getCurrentImageData()) .downsample(%f) .nThreads(%d) .build(); @@ -394,12 +394,13 @@ protected Void call() { model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), InstanSegPreferences.numThreadsProperty().getValue() ); - QP.getCurrentImageData().getHistoryWorkflow() + qupath.getImageData().getHistoryWorkflow() .addStep( new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) ); var instanSeg = InstanSeg.builder() .model(model) + .imageData(qupath.getImageData()) .device(deviceChoices.getSelectionModel().getSelectedItem()) .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 1:2) .channels(selectedChannels.stream().map(ChannelSelectItem::getTransform).toList()) @@ -408,7 +409,7 @@ protected Void call() { .nThreads(InstanSegPreferences.numThreadsProperty().getValue()) .build(); instanSeg.detectObjects(); - QP.fireHierarchyUpdate(); + qupath.getImageData().getHierarchy().fireHierarchyChangedEvent(this); if (model.nFailed() > 0) { var errorMessage = String.format(resources.getString("error.tiles-failed"), model.nFailed()); logger.error(errorMessage); @@ -436,12 +437,14 @@ private void downloadPyTorch() { @FXML private void selectAllAnnotations() { - QP.selectAnnotations(); + var hierarchy = qupath.getImageData().getHierarchy(); + hierarchy.getSelectionModel().setSelectedObjects(hierarchy.getAnnotationObjects(), null); } @FXML private void selectAllTMACores() { - QP.selectTMACores(); + var hierarchy = qupath.getImageData().getHierarchy(); + hierarchy.getSelectionModel().setSelectedObjects(hierarchy.getTMAGrid().getTMACoreList(), null); } /** From 2fb90023ce740f77eb51c496ab341f05a6c50267 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Thu, 27 Jun 2024 14:03:52 +0100 Subject: [PATCH 12/28] Remove QP and QPEx usage --- .../qupath/ext/instanseg/core/InstanSeg.java | 43 ++++-- .../ext/instanseg/ui/InstanSegController.java | 145 +++++++++++------- 2 files changed, 118 insertions(+), 70 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index c568122..d5414a0 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -1,9 +1,12 @@ package qupath.ext.instanseg.core; import ai.djl.Device; -import qupath.lib.gui.scripting.QPEx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; +import qupath.lib.plugins.TaskRunner; +import qupath.lib.plugins.TaskRunnerUtils; import java.awt.image.BufferedImage; import java.io.IOException; @@ -14,6 +17,8 @@ import java.util.stream.IntStream; public class InstanSeg { + private static final Logger logger = LoggerFactory.getLogger(InstanSeg.class); + private final int tileDims; private final double downsample; private final int padding; @@ -21,9 +26,9 @@ public class InstanSeg { private final int numOutputChannels; private final ImageData imageData; private final Collection channels; - private final int nThreads; private final InstanSegModel model; private final Device device; + private final TaskRunner taskRunner; public static Builder builder() { @@ -31,6 +36,7 @@ public static Builder builder() { } public void detectObjects() { + // todo: replace createTaskRunner model.runInstanSeg( imageData, imageData.getHierarchy().getSelectionModel().getSelectedObjects(), @@ -41,21 +47,24 @@ public void detectObjects() { boundary, device, numOutputChannels == 1, - QPEx.createTaskRunner(nThreads) + taskRunner ); } + + + public static final class Builder { private int tileDims = 512; private double downsample = 1; private int padding = 40; private int boundary = 20; private int numOutputChannels = 2; + private Device device = Device.fromName("cpu"); private ImageData imageData; private Collection channels; - private Device device; private InstanSegModel model; - private int nThreads = 1; + private TaskRunner taskRunner; Builder() {} @@ -204,7 +213,17 @@ public Builder channelNames(String... channels) { * @return A modified builder */ public Builder nThreads(int nThreads) { - this.nThreads = nThreads; + this.taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(nThreads); + return this; + } + + /** + * Set the TaskRunner + * @param taskRunner An object that will run tasks and show progress + * @return A modified builder + */ + public Builder taskRunner(TaskRunner taskRunner) { + this.taskRunner = taskRunner; return this; } @@ -276,9 +295,6 @@ public InstanSeg build() { if (channels == null) { allChannels(); } - if (device == null) { - device("cpu"); - } return new InstanSeg( this.tileDims, this.downsample, @@ -287,16 +303,15 @@ public InstanSeg build() { this.numOutputChannels, this.imageData, this.channels, - this.nThreads, this.model, - this.device - ); + this.device, + this.taskRunner); } } private InstanSeg(int tileDims, double downsample, int padding, int boundary, int numOutputChannels, ImageData imageData, - Collection channels, int nThreads, InstanSegModel model, Device device) { + Collection channels, InstanSegModel model, Device device, TaskRunner taskRunner) { this.tileDims = tileDims; this.downsample = downsample; this.padding = padding; @@ -304,8 +319,8 @@ private InstanSeg(int tileDims, double downsample, int padding, int boundary, in this.numOutputChannels = numOutputChannels; this.imageData = imageData; this.channels = channels; - this.nThreads = nThreads; this.model = model; this.device = device; + this.taskRunner = taskRunner; } } diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 69b7414..cfc3e0d 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -5,6 +5,7 @@ import javafx.beans.property.ObjectProperty; import javafx.beans.property.SimpleObjectProperty; import javafx.beans.property.StringProperty; +import javafx.beans.value.ObservableValue; import javafx.collections.FXCollections; import javafx.collections.ListChangeListener; import javafx.concurrent.Task; @@ -35,6 +36,7 @@ import qupath.lib.common.ThreadTools; import qupath.lib.display.ChannelDisplayInfo; import qupath.lib.gui.QuPathGUI; +import qupath.lib.gui.TaskRunnerFX; import qupath.lib.gui.tools.GuiTools; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ImageServer; @@ -363,62 +365,7 @@ private void runInstanSeg() { .filter(Objects::nonNull) .toList(); - var task = new Task() { - @Override - protected Void call() { - // Ensure PyTorch engine is available - if (!PytorchManager.hasPyTorchEngine()) { - downloadPyTorch(); - } - String cmd = String.format(""" - import qupath.ext.instanseg.core.InstanSeg - - def channels = %s; - def instanSeg = InstanSeg.builder() - .modelPath("%s") - .device("%s") - .numOutputChannels(%d) - .channels(channels) - .tileDims(%d) - .imageData(QP.getCurrentImageData()) - .downsample(%f) - .nThreads(%d) - .build(); - instanSeg.detectObjects(); - """, - ChannelSelectItem.toConstructorString(selectedChannels), - model.getPath(), - deviceChoices.getSelectionModel().getSelectedItem(), - nucleiOnlyCheckBox.isSelected() ? 1:2, - InstanSegPreferences.tileSizeProperty().get(), - model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), - InstanSegPreferences.numThreadsProperty().getValue() - ); - qupath.getImageData().getHistoryWorkflow() - .addStep( - new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) - ); - var instanSeg = InstanSeg.builder() - .model(model) - .imageData(qupath.getImageData()) - .device(deviceChoices.getSelectionModel().getSelectedItem()) - .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 1:2) - .channels(selectedChannels.stream().map(ChannelSelectItem::getTransform).toList()) - .tileDims(InstanSegPreferences.tileSizeProperty().get()) - .downsample(model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize()) - .nThreads(InstanSegPreferences.numThreadsProperty().getValue()) - .build(); - instanSeg.detectObjects(); - qupath.getImageData().getHierarchy().fireHierarchyChangedEvent(this); - if (model.nFailed() > 0) { - var errorMessage = String.format(resources.getString("error.tiles-failed"), model.nFailed()); - logger.error(errorMessage); - Dialogs.showErrorMessage(resources.getString("title"), - errorMessage); - } - return null; - } - }; + var task = new InstanSegTask(server, model, selectedChannels); pendingTask.set(task); // Reset the pending task when it completes (either successfully or not) task.stateProperty().addListener((observable, oldValue, newValue) -> { @@ -429,6 +376,92 @@ protected Void call() { }); } + private class InstanSegTask extends Task { + + private final List channels; + private final ImageServer server; + private final InstanSegModel model; + + InstanSegTask(ImageServer server, InstanSegModel model, List channels) { + this.server = server; + this.model = model; + this.channels = channels; + // this.progressListener = new ProgressDialog(this. + // QuPathGUI.getInstance().getStage(), e -> { + // if (Dialogs.showYesNoDialog(getDialogTitle(), resources.getString("ui.stop-tasks"))) { + // cancel(true); + // e.consume(); + // } + // }); + this.stateProperty().addListener(this::handleStateChange); + } + + private void handleStateChange(ObservableValue value, Worker.State oldValue, Worker.State newValue) { + // if (progressListener != null && newValue == Worker.State.CANCELLED) + // progressListener.cancel(); + } + + + @Override + protected Void call() { + // Ensure PyTorch engine is available + if (!PytorchManager.hasPyTorchEngine()) { + downloadPyTorch(); + } + String cmd = String.format(""" + import qupath.ext.instanseg.core.InstanSeg + + def channels = %s; + def instanSeg = InstanSeg.builder() + .modelPath("%s") + .device("%s") + .numOutputChannels(%d) + .channels(channels) + .tileDims(%d) + .imageData(QP.getCurrentImageData()) + .downsample(%f) + .nThreads(QPEx.createTaskRunner(%d)) + .build(); + instanSeg.detectObjects(); + """, + ChannelSelectItem.toConstructorString(channels), + model.getPath(), + deviceChoices.getSelectionModel().getSelectedItem(), + nucleiOnlyCheckBox.isSelected() ? 1 : 2, + InstanSegPreferences.tileSizeProperty().get(), + model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), + InstanSegPreferences.numThreadsProperty().getValue() + ); + qupath.getImageData().getHistoryWorkflow() + .addStep( + new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) + ); + var taskRunner = new TaskRunnerFX( + QuPathGUI.getInstance(), + InstanSegPreferences.numThreadsProperty().getValue()); + + var instanSeg = InstanSeg.builder() + .model(model) + .imageData(qupath.getImageData()) + .device(deviceChoices.getSelectionModel().getSelectedItem()) + .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 1 : 2) + .channels(channels.stream().map(ChannelSelectItem::getTransform).toList()) + .tileDims(InstanSegPreferences.tileSizeProperty().get()) + .downsample(model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize()) + .taskRunner(taskRunner) + .build(); + instanSeg.detectObjects(); + qupath.getImageData().getHierarchy().fireHierarchyChangedEvent(this); + if (model.nFailed() > 0) { + var errorMessage = String.format(resources.getString("error.tiles-failed"), model.nFailed()); + logger.error(errorMessage); + Dialogs.showErrorMessage(resources.getString("title"), + errorMessage); + } + return null; + } + } + private void downloadPyTorch() { Platform.runLater(() -> Dialogs.showInfoNotification(resources.getString("title"), resources.getString("ui.pytorch-downloading"))); PytorchManager.getEngineOnline(); From 679302578906fd1f381539315094038f890abf8d Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Thu, 27 Jun 2024 16:14:56 +0100 Subject: [PATCH 13/28] Default runner --- src/main/java/qupath/ext/instanseg/core/InstanSeg.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index d5414a0..65491c8 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -61,10 +61,10 @@ public static final class Builder { private int boundary = 20; private int numOutputChannels = 2; private Device device = Device.fromName("cpu"); + private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(); private ImageData imageData; private Collection channels; private InstanSegModel model; - private TaskRunner taskRunner; Builder() {} From d0b3a33086e3049ce893666d37f76facaf5545ef Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Thu, 27 Jun 2024 23:26:50 +0100 Subject: [PATCH 14/28] Javadocs and gradle build --- build.gradle | 1 + .../qupath/ext/instanseg/core/InstanSeg.java | 23 +++++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/build.gradle b/build.gradle index c017a79..4f82db5 100644 --- a/build.gradle +++ b/build.gradle @@ -35,6 +35,7 @@ ext.qupathJavaVersion = 21 * but shouldn't be bundled up for use in the extension. */ dependencies { + // Main QuPath user interface jar. // Automatically includes other QuPath jars as subdependencies. shadow "io.github.qupath:qupath-gui-fx:${qupathVersion}" diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 65491c8..21101a1 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -5,6 +5,7 @@ import org.slf4j.LoggerFactory; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; +import qupath.lib.objects.PathObject; import qupath.lib.plugins.TaskRunner; import qupath.lib.plugins.TaskRunnerUtils; @@ -31,15 +32,28 @@ public class InstanSeg { private final TaskRunner taskRunner; + /** + * Create a builder object for InstanSeg. + * @return A builder, which may not be valid. + */ public static Builder builder() { return new Builder(); } + /** + * Run inference for the currently selected PathObjects. + */ public void detectObjects() { - // todo: replace createTaskRunner + detectObjects(imageData.getHierarchy().getSelectionModel().getSelectedObjects()); + } + + /** + * Run inference for a collection of PathObjects. + */ + public void detectObjects(Collection pathObjects) { model.runInstanSeg( imageData, - imageData.getHierarchy().getSelectionModel().getSelectedObjects(), + pathObjects, channels, tileDims, downsample, @@ -52,8 +66,9 @@ public void detectObjects() { } - - + /** + * A builder class for InstanSeg. + */ public static final class Builder { private int tileDims = 512; private double downsample = 1; From 8e2b44463240ce962426eaea2804507eefa75831 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 1 Jul 2024 13:13:44 +0100 Subject: [PATCH 15/28] Update InstanSegController.java --- .../ext/instanseg/ui/InstanSegController.java | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index cfc3e0d..91370b4 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -386,19 +386,6 @@ private class InstanSegTask extends Task { this.server = server; this.model = model; this.channels = channels; - // this.progressListener = new ProgressDialog(this. - // QuPathGUI.getInstance().getStage(), e -> { - // if (Dialogs.showYesNoDialog(getDialogTitle(), resources.getString("ui.stop-tasks"))) { - // cancel(true); - // e.consume(); - // } - // }); - this.stateProperty().addListener(this::handleStateChange); - } - - private void handleStateChange(ObservableValue value, Worker.State oldValue, Worker.State newValue) { - // if (progressListener != null && newValue == Worker.State.CANCELLED) - // progressListener.cancel(); } From 0ab3ff1046ab129d3dfe55f322899ba0ef15b94c Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Thu, 11 Jul 2024 14:46:07 +0100 Subject: [PATCH 16/28] Make first arg concrete, later varargs --- .../qupath/ext/instanseg/core/InstanSeg.java | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 21101a1..2241e25 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -12,6 +12,7 @@ import java.awt.image.BufferedImage; import java.io.IOException; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -158,8 +159,10 @@ public Builder channels(Collection channels) { * @param channels Channels to be used in inference * @return A modified builder */ - public Builder channels(ColorTransforms.ColorTransform... channels) { - this.channels = List.of(channels); + public Builder channels(ColorTransforms.ColorTransform channel, ColorTransforms.ColorTransform... channels) { + var l = Arrays.asList(channels); + l.add(channel); + this.channels = l; return this; } @@ -191,10 +194,13 @@ public Builder channelIndices(Collection channels) { * @param channels Integers used to specify the channels used * @return A modified builder */ - public Builder channelIndices(int... channels) { - this.channels = Arrays.stream(channels).boxed() - .map(ColorTransforms::createChannelExtractor) - .toList(); + public Builder channelIndices(int channel, int... channels) { + List l = new ArrayList<>(); + l.add(ColorTransforms.createChannelExtractor(channel)); + for (int i: channels) { + l.add(ColorTransforms.createChannelExtractor(i)); + } + this.channels = l; return this; } @@ -215,10 +221,13 @@ public Builder channelNames(Collection channels) { * @param channels A set of channel names * @return A modified builder */ - public Builder channelNames(String... channels) { - this.channels = Arrays.stream(channels) - .map(ColorTransforms::createChannelExtractor) - .toList(); + public Builder channelNames(String channel, String... channels) { + List l = new ArrayList<>(); + l.add(ColorTransforms.createChannelExtractor(channel)); + for (String s: channels) { + l.add(ColorTransforms.createChannelExtractor(s)); + } + this.channels = l; return this; } From 7ab16e1c0abd69dadfa77ef9dbe9a0a08a5df968 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Thu, 11 Jul 2024 18:11:48 +0100 Subject: [PATCH 17/28] =?UTF-8?q?Fix=20channels=C2=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/qupath/ext/instanseg/core/InstanSeg.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 2241e25..00767c2 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -171,10 +171,12 @@ public Builder channels(ColorTransforms.ColorTransform channel, ColorTransforms. * @return A modified builder */ public Builder allChannels() { - return channelIndices( - IntStream.of(imageData.getServer().nChannels()) + // assignment is just to suppress IDE suggestion for void return val + var tmp = channelIndices( + IntStream.range(0, imageData.getServer().nChannels()) .boxed() .toList()); + return this; } /** @@ -317,7 +319,8 @@ public InstanSeg build() { throw new IllegalStateException("imageData cannot be null!"); } if (channels == null) { - allChannels(); + // assignment is just to suppress IDE suggestion for void return + var tmp = allChannels(); } return new InstanSeg( this.tileDims, From 79f45aa067437e46ccb609596b055e09a963b2f5 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 12 Jul 2024 13:36:31 +0100 Subject: [PATCH 18/28] Enable setting output classes; resolve #36 --- .../qupath/ext/instanseg/core/InstanSeg.java | 35 ++++++- .../ext/instanseg/core/InstanSegModel.java | 4 +- .../core/OutputToObjectConverter.java | 99 +++++++++++++++++-- 3 files changed, 125 insertions(+), 13 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 00767c2..a751720 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -5,6 +5,9 @@ import org.slf4j.LoggerFactory; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; +import qupath.lib.objects.PathAnnotationObject; +import qupath.lib.objects.PathCellObject; +import qupath.lib.objects.PathDetectionObject; import qupath.lib.objects.PathObject; import qupath.lib.plugins.TaskRunner; import qupath.lib.plugins.TaskRunnerUtils; @@ -31,6 +34,7 @@ public class InstanSeg { private final InstanSegModel model; private final Device device; private final TaskRunner taskRunner; + private final List> outputClasses; /** @@ -62,6 +66,7 @@ public void detectObjects(Collection pathObjects) { boundary, device, numOutputChannels == 1, + outputClasses, taskRunner ); } @@ -81,6 +86,7 @@ public static final class Builder { private ImageData imageData; private Collection channels; private InstanSegModel model; + private List> outputClasses; Builder() {} @@ -310,6 +316,26 @@ public Builder device(Device device) { return this; } + public Builder outputClasses(List> outputClasses) { + this.outputClasses = outputClasses; + return this; + } + + public Builder outputCells() { + this.outputClasses = List.of(PathCellObject.class); + return this; + } + + public Builder outputDetections() { + this.outputClasses = List.of(PathDetectionObject.class); + return this; + } + + public Builder outputAnnotations() { + this.outputClasses = List.of(PathAnnotationObject.class); + return this; + } + /** * Build the InstanSeg instance. * @return An InstanSeg instance ready for object detection. @@ -332,13 +358,17 @@ public InstanSeg build() { this.channels, this.model, this.device, - this.taskRunner); + this.taskRunner, + this.outputClasses); } } + + private InstanSeg(int tileDims, double downsample, int padding, int boundary, int numOutputChannels, ImageData imageData, - Collection channels, InstanSegModel model, Device device, TaskRunner taskRunner) { + Collection channels, InstanSegModel model, Device device, TaskRunner taskRunner, + List> outputClasses) { this.tileDims = tileDims; this.downsample = downsample; this.padding = padding; @@ -349,5 +379,6 @@ private InstanSeg(int tileDims, double downsample, int padding, int boundary, in this.model = model; this.device = device; this.taskRunner = taskRunner; + this.outputClasses = outputClasses; } } diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index 60aa782..ecc2b3f 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -141,6 +142,7 @@ void runInstanSeg( int boundary, Device device, boolean nucleiOnly, + List> outputClasses, TaskRunner taskRunner) { nFailed = 0; @@ -189,7 +191,7 @@ void runInstanSeg( .cropTiles(false) .build() ) - .outputHandler(new OutputToObjectConverter.PruneObjectOutputHandler<>(new OutputToObjectConverter(), boundary)) + .outputHandler(new OutputToObjectConverter.PruneObjectOutputHandler<>(new OutputToObjectConverter(outputClasses), boundary)) .padding(padding) .merger(ObjectMerger.createIoUMerger(0.2)) .downsample(downsample) diff --git a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java index 0127717..8cc8259 100644 --- a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java @@ -2,10 +2,15 @@ import org.bytedeco.opencv.opencv_core.Mat; import org.locationtech.jts.geom.Envelope; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import qupath.lib.analysis.images.ContourTracing; import qupath.lib.experimental.pixels.OutputHandler; import qupath.lib.experimental.pixels.Parameters; import qupath.lib.experimental.pixels.PixelProcessorUtils; +import qupath.lib.objects.PathAnnotationObject; +import qupath.lib.objects.PathCellObject; +import qupath.lib.objects.PathDetectionObject; import qupath.lib.objects.PathObject; import qupath.lib.objects.PathObjects; import qupath.lib.roi.GeometryTools; @@ -16,11 +21,18 @@ import java.util.List; import java.util.Map; import java.util.Random; +import java.util.function.BiFunction; import java.util.stream.Collectors; class OutputToObjectConverter implements OutputHandler.OutputToObjectConverter { + private static final Logger logger = LoggerFactory.getLogger(OutputToObjectConverter.class); private static final long seed = 1243; + private List> classes; + + public OutputToObjectConverter(List> outputClasses) { + this.classes = outputClasses; + } @Override public List convertToObjects(Parameters params, Mat output) { @@ -39,11 +51,47 @@ public List convertToObjects(Parameters params, Mat output ); } var rng = new Random(seed); + + // if of length 1, can be + // cellObject (with or without nucleus) + // annotations + // detections + // if of length 2, then can be: + // detection <- annotation + // annotation <- annotation + // detection <- detection + BiFunction function; + if (classes.size() == 1) { + // todo + if (classes.get(0) == PathAnnotationObject.class) { + function = OutputToObjectConverter::annotationInsideAnnotation; + } else if (classes.get(0) == PathDetectionObject.class) { + function = OutputToObjectConverter::detectionInsideDetection; + } else if (classes.get(0) == PathCellObject.class) { + function = OutputToObjectConverter::createCell; + } else { + function = OutputToObjectConverter::createCell; + logger.warn("Unknown output {}", classes.get(0)); + } + } else { + assert classes.size() == 2; + if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathAnnotationObject.class) { + function = OutputToObjectConverter::detectionInsideAnnotation; + } else if (classes.get(0) == PathAnnotationObject.class && classes.get(1) == PathAnnotationObject.class) { + function = OutputToObjectConverter::annotationInsideAnnotation; + } else if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathDetectionObject.class) { + function = OutputToObjectConverter::detectionInsideDetection; + } else { + logger.warn("Unknown combination of outputs {} <- {}", classes.get(0), classes.get(1)); + function = OutputToObjectConverter::createCell; + } + } + if (roiMaps.size() == 1) { // One-channel detected, represent using detection objects return roiMaps.get(0).values().stream() - .map(p -> { - var obj = PathObjects.createDetectionObject(p); + .map(roi -> { + var obj = function.apply(roi, null); obj.setColor( rng.nextInt(255), rng.nextInt(255), @@ -55,24 +103,55 @@ public List convertToObjects(Parameters params, Mat output } else { // Two channels detected, represent using cell objects // We assume that the labels are matched - and we can't have a nucleus without a cell - Map nucleusROIs = roiMaps.get(0); - Map cellROIs = roiMaps.get(1); + Map childROIs = roiMaps.get(0); + Map parentROIs = roiMaps.get(1); List cells = new ArrayList<>(); - for (var entry : cellROIs.entrySet()) { - var cell = entry.getValue(); - var nucleus = nucleusROIs.getOrDefault(entry.getKey(), null); - var cellObject = PathObjects.createCellObject(cell, nucleus); - cellObject.setColor( + for (var entry : parentROIs.entrySet()) { + var parent = entry.getValue(); + var child = childROIs.getOrDefault(entry.getKey(), null); + var outputObject = function.apply(parent, child); + outputObject.setColor( rng.nextInt(255), rng.nextInt(255), rng.nextInt(255) ); - cells.add(cellObject); + cells.add(outputObject); } return cells; } } + private static PathObject detectionInsideDetection(ROI parent, ROI child) { + var parentDetection = PathObjects.createDetectionObject(parent); + if (child != null) { + var childDetection = PathObjects.createDetectionObject(child); + parentDetection.addChildObject(childDetection); + } + return parentDetection; + } + + private static PathObject detectionInsideAnnotation(ROI parent, ROI child) { + var parentAnnotation = PathObjects.createAnnotationObject(parent); + if (child != null) { + var childDetection = PathObjects.createDetectionObject(child); + parentAnnotation.addChildObject(childDetection); + } + return parentAnnotation; + } + + private static PathObject annotationInsideAnnotation(ROI parent, ROI child) { + var parentAnnotation = PathObjects.createAnnotationObject(parent); + if (child != null) { + var childAnnotation = PathObjects.createAnnotationObject(child); + parentAnnotation.addChildObject(childAnnotation); + } + return parentAnnotation; + } + + private static PathObject createCell(ROI parent, ROI child) { + return PathObjects.createCellObject(parent, child); + } + static class PruneObjectOutputHandler implements OutputHandler { private final OutputToObjectConverter converter; From a91a355b5e5f381eb4f17af696303e986252364c Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 12 Jul 2024 13:39:38 +0100 Subject: [PATCH 19/28] JDoc --- .../qupath/ext/instanseg/core/InstanSeg.java | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index a751720..c7e1010 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -316,21 +316,40 @@ public Builder device(Device device) { return this; } + /** + * Specify the output class(es) + * @param outputClasses A list specifying what type the output should be. + * eg, [PathDetectionObject.class, PathAnnotationObject.class] + * specifies to create detections nested inside annotations. + * @return A modified builder + */ public Builder outputClasses(List> outputClasses) { this.outputClasses = outputClasses; return this; } + /** + * Specify cells as the output class, possibly without nuclei + * @return A modified builder + */ public Builder outputCells() { this.outputClasses = List.of(PathCellObject.class); return this; } + /** + * Specify (possibly nested) detections as the output class + * @return A modified builder + */ public Builder outputDetections() { this.outputClasses = List.of(PathDetectionObject.class); return this; } + /** + * Specify (possibly nested) annotations as the output class + * @return A modified builder + */ public Builder outputAnnotations() { this.outputClasses = List.of(PathAnnotationObject.class); return this; @@ -348,6 +367,9 @@ public InstanSeg build() { // assignment is just to suppress IDE suggestion for void return var tmp = allChannels(); } + if (outputClasses == null) { + var tmp = outputCells(); + } return new InstanSeg( this.tileDims, this.downsample, From 181ecb64456b3755bd06cea39eeb42496e2a2d81 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 15 Jul 2024 12:03:24 +0100 Subject: [PATCH 20/28] Functional solution to colouring --- .../core/OutputToObjectConverter.java | 83 +++++++------------ 1 file changed, 32 insertions(+), 51 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java index 8cc8259..0c8fd6c 100644 --- a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java @@ -5,6 +5,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import qupath.lib.analysis.images.ContourTracing; +import qupath.lib.common.ColorTools; import qupath.lib.experimental.pixels.OutputHandler; import qupath.lib.experimental.pixels.Parameters; import qupath.lib.experimental.pixels.PixelProcessorUtils; @@ -22,6 +23,7 @@ import java.util.Map; import java.util.Random; import java.util.function.BiFunction; +import java.util.function.Function; import java.util.stream.Collectors; class OutputToObjectConverter implements OutputHandler.OutputToObjectConverter { @@ -64,41 +66,33 @@ public List convertToObjects(Parameters params, Mat output if (classes.size() == 1) { // todo if (classes.get(0) == PathAnnotationObject.class) { - function = OutputToObjectConverter::annotationInsideAnnotation; + function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); } else if (classes.get(0) == PathDetectionObject.class) { - function = OutputToObjectConverter::detectionInsideDetection; + function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); } else if (classes.get(0) == PathCellObject.class) { - function = OutputToObjectConverter::createCell; + function = createCellFun(rng); } else { - function = OutputToObjectConverter::createCell; - logger.warn("Unknown output {}", classes.get(0)); + function = createCellFun(rng); + logger.warn("Unknown output {}, defaulting to cells", classes.get(0)); } } else { assert classes.size() == 2; if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathAnnotationObject.class) { - function = OutputToObjectConverter::detectionInsideAnnotation; + function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createAnnotationObject, rng); } else if (classes.get(0) == PathAnnotationObject.class && classes.get(1) == PathAnnotationObject.class) { - function = OutputToObjectConverter::annotationInsideAnnotation; + function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); } else if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathDetectionObject.class) { - function = OutputToObjectConverter::detectionInsideDetection; + function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); } else { - logger.warn("Unknown combination of outputs {} <- {}", classes.get(0), classes.get(1)); - function = OutputToObjectConverter::createCell; + logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1)); + function = createCellFun(rng); } } if (roiMaps.size() == 1) { // One-channel detected, represent using detection objects return roiMaps.get(0).values().stream() - .map(roi -> { - var obj = function.apply(roi, null); - obj.setColor( - rng.nextInt(255), - rng.nextInt(255), - rng.nextInt(255) - ); - return obj; - }) + .map(roi -> function.apply(roi, null)) .collect(Collectors.toList()); } else { // Two channels detected, represent using cell objects @@ -110,46 +104,33 @@ public List convertToObjects(Parameters params, Mat output var parent = entry.getValue(); var child = childROIs.getOrDefault(entry.getKey(), null); var outputObject = function.apply(parent, child); - outputObject.setColor( - rng.nextInt(255), - rng.nextInt(255), - rng.nextInt(255) - ); cells.add(outputObject); } return cells; } } - private static PathObject detectionInsideDetection(ROI parent, ROI child) { - var parentDetection = PathObjects.createDetectionObject(parent); - if (child != null) { - var childDetection = PathObjects.createDetectionObject(child); - parentDetection.addChildObject(childDetection); - } - return parentDetection; - } - - private static PathObject detectionInsideAnnotation(ROI parent, ROI child) { - var parentAnnotation = PathObjects.createAnnotationObject(parent); - if (child != null) { - var childDetection = PathObjects.createDetectionObject(child); - parentAnnotation.addChildObject(childDetection); - } - return parentAnnotation; + private static BiFunction createCellFun(Random rng) { + return (parent, child) -> { + var cell = PathObjects.createCellObject(parent, child); + var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); + cell.setColor(color); + return cell; + }; } - private static PathObject annotationInsideAnnotation(ROI parent, ROI child) { - var parentAnnotation = PathObjects.createAnnotationObject(parent); - if (child != null) { - var childAnnotation = PathObjects.createAnnotationObject(child); - parentAnnotation.addChildObject(childAnnotation); - } - return parentAnnotation; - } - - private static PathObject createCell(ROI parent, ROI child) { - return PathObjects.createCellObject(parent, child); + private static BiFunction createObjectsFun(Function createParentFun, Function createChildFun, Random rng) { + return (parent, child) -> { + var parentObj = createParentFun.apply(parent); + var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); + parentObj.setColor(color); + if (child != null) { + var childObj = createChildFun.apply(child); + childObj.setColor(color); + parentObj.addChildObject(childObj); + } + return parentObj; + }; } static class PruneObjectOutputHandler implements OutputHandler { From 6f2e2e7a5c4a57505e20f1293a16111724766132 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 15 Jul 2024 12:11:50 +0100 Subject: [PATCH 21/28] Tidy comments --- .../instanseg/core/OutputToObjectConverter.java | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java index 0c8fd6c..759dbcb 100644 --- a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java @@ -54,17 +54,12 @@ public List convertToObjects(Parameters params, Mat output } var rng = new Random(seed); - // if of length 1, can be - // cellObject (with or without nucleus) - // annotations - // detections - // if of length 2, then can be: - // detection <- annotation - // annotation <- annotation - // detection <- detection BiFunction function; if (classes.size() == 1) { - // todo + // if of length 1, can be + // cellObject (with or without nucleus) + // annotations + // detections if (classes.get(0) == PathAnnotationObject.class) { function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); } else if (classes.get(0) == PathDetectionObject.class) { @@ -76,6 +71,10 @@ public List convertToObjects(Parameters params, Mat output logger.warn("Unknown output {}, defaulting to cells", classes.get(0)); } } else { + // if of length 2, then can be: + // detection <- annotation + // annotation <- annotation + // detection <- detection assert classes.size() == 2; if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathAnnotationObject.class) { function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createAnnotationObject, rng); From a001ecbb1b6856657581adbb2a550dea53270252 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 15 Jul 2024 12:14:09 +0100 Subject: [PATCH 22/28] Smaller comment --- .../ext/instanseg/core/OutputToObjectConverter.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java index 759dbcb..6d51dc2 100644 --- a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java @@ -57,9 +57,7 @@ public List convertToObjects(Parameters params, Mat output BiFunction function; if (classes.size() == 1) { // if of length 1, can be - // cellObject (with or without nucleus) - // annotations - // detections + // cellObject (with or without nucleus), annotations, detections if (classes.get(0) == PathAnnotationObject.class) { function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); } else if (classes.get(0) == PathDetectionObject.class) { @@ -72,9 +70,7 @@ public List convertToObjects(Parameters params, Mat output } } else { // if of length 2, then can be: - // detection <- annotation - // annotation <- annotation - // detection <- detection + // detection <- annotation, annotation <- annotation, detection <- detection assert classes.size() == 2; if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathAnnotationObject.class) { function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createAnnotationObject, rng); From 2e5d46f4ef1db338c975d929ccd28aec222bd2e9 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 30 Jul 2024 17:48:51 +0100 Subject: [PATCH 23/28] Update ChannelSelectItem.java --- src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java index af67901..eaf7bf5 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java @@ -15,7 +15,6 @@ class ChannelSelectItem { private final ColorTransforms.ColorTransform transform; private final String constructor; - // todo: public method to get a constructor for the colortransform ChannelSelectItem(String name) { this.name = name; this.transform = ColorTransforms.createChannelExtractor(name); From b5103fb4c3d27e10e0b001d66d348a4f043bfe91 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 30 Jul 2024 17:49:07 +0100 Subject: [PATCH 24/28] Update build.gradle --- build.gradle | 1 - 1 file changed, 1 deletion(-) diff --git a/build.gradle b/build.gradle index 4f82db5..c017a79 100644 --- a/build.gradle +++ b/build.gradle @@ -35,7 +35,6 @@ ext.qupathJavaVersion = 21 * but shouldn't be bundled up for use in the extension. */ dependencies { - // Main QuPath user interface jar. // Automatically includes other QuPath jars as subdependencies. shadow "io.github.qupath:qupath-gui-fx:${qupathVersion}" From 2f0704b4f72832d9ff49b7ff835b51e6e29796d4 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 5 Aug 2024 12:46:24 +0100 Subject: [PATCH 25/28] Restructure code and add some javadocs --- .../ext/instanseg/core/InstanSegModel.java | 2 +- .../InstansegOutputToObjectConverter.java | 128 +++++++++ .../core/OutputToObjectConverter.java | 269 ------------------ .../core/PruneObjectOutputHandler.java | 164 +++++++++++ 4 files changed, 293 insertions(+), 270 deletions(-) create mode 100644 src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java delete mode 100644 src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java create mode 100644 src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index ecc2b3f..e054136 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -191,7 +191,7 @@ void runInstanSeg( .cropTiles(false) .build() ) - .outputHandler(new OutputToObjectConverter.PruneObjectOutputHandler<>(new OutputToObjectConverter(outputClasses), boundary)) + .outputHandler(new PruneObjectOutputHandler<>(new InstansegOutputToObjectConverter(outputClasses), boundary)) .padding(padding) .merger(ObjectMerger.createIoUMerger(0.2)) .downsample(downsample) diff --git a/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java new file mode 100644 index 0000000..93aa94c --- /dev/null +++ b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java @@ -0,0 +1,128 @@ +package qupath.ext.instanseg.core; + +import org.bytedeco.opencv.opencv_core.Mat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import qupath.lib.analysis.images.ContourTracing; +import qupath.lib.common.ColorTools; +import qupath.lib.experimental.pixels.OutputHandler; +import qupath.lib.experimental.pixels.Parameters; +import qupath.lib.objects.PathAnnotationObject; +import qupath.lib.objects.PathCellObject; +import qupath.lib.objects.PathDetectionObject; +import qupath.lib.objects.PathObject; +import qupath.lib.objects.PathObjects; +import qupath.lib.roi.interfaces.ROI; +import qupath.opencv.tools.OpenCVTools; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +class InstansegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter { + private static final Logger logger = LoggerFactory.getLogger(InstansegOutputToObjectConverter.class); + + private static final long seed = 1243; + private final List> classes; + + InstansegOutputToObjectConverter(List> outputClasses) { + this.classes = outputClasses; + } + + @Override + public List convertToObjects(Parameters params, Mat output) { + if (output == null) { + return List.of(); + } + int nChannels = output.channels(); + if (nChannels < 1 || nChannels > 2) + throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels); + + List> roiMaps = new ArrayList<>(); + for (var mat : OpenCVTools.splitChannels(output)) { + var image = OpenCVTools.matToSimpleImage(mat, 0); + roiMaps.add( + ContourTracing.createROIs(image, params.getRegionRequest(), 1, -1) + ); + } + var rng = new Random(seed); + + BiFunction function; + if (classes.size() == 1) { + // if of length 1, can be + // cellObject (with or without nucleus), annotations, detections + if (classes.get(0) == PathAnnotationObject.class) { + function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); + } else if (classes.get(0) == PathDetectionObject.class) { + function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); + } else if (classes.get(0) == PathCellObject.class) { + function = createCellFun(rng); + } else { + function = createCellFun(rng); + logger.warn("Unknown output {}, defaulting to cells", classes.get(0)); + } + } else { + // if of length 2, then can be: + // detection <- annotation, annotation <- annotation, detection <- detection + assert classes.size() == 2; + if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathAnnotationObject.class) { + function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createAnnotationObject, rng); + } else if (classes.get(0) == PathAnnotationObject.class && classes.get(1) == PathAnnotationObject.class) { + function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); + } else if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathDetectionObject.class) { + function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); + } else { + logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1)); + function = createCellFun(rng); + } + } + + if (roiMaps.size() == 1) { + // One-channel detected, represent using detection objects + return roiMaps.get(0).values().stream() + .map(roi -> function.apply(roi, null)) + .collect(Collectors.toList()); + } else { + // Two channels detected, represent using cell objects + // We assume that the labels are matched - and we can't have a nucleus without a cell + Map childROIs = roiMaps.get(0); + Map parentROIs = roiMaps.get(1); + List cells = new ArrayList<>(); + for (var entry : parentROIs.entrySet()) { + var parent = entry.getValue(); + var child = childROIs.getOrDefault(entry.getKey(), null); + var outputObject = function.apply(parent, child); + cells.add(outputObject); + } + return cells; + } + } + + private static BiFunction createCellFun(Random rng) { + return (parent, child) -> { + var cell = PathObjects.createCellObject(parent, child); + var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); + cell.setColor(color); + return cell; + }; + } + + private static BiFunction createObjectsFun(Function createParentFun, Function createChildFun, Random rng) { + return (parent, child) -> { + var parentObj = createParentFun.apply(parent); + var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); + parentObj.setColor(color); + if (child != null) { + var childObj = createChildFun.apply(child); + childObj.setColor(color); + parentObj.addChildObject(childObj); + } + return parentObj; + }; + } + +} diff --git a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java deleted file mode 100644 index 6d51dc2..0000000 --- a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java +++ /dev/null @@ -1,269 +0,0 @@ -package qupath.ext.instanseg.core; - -import org.bytedeco.opencv.opencv_core.Mat; -import org.locationtech.jts.geom.Envelope; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import qupath.lib.analysis.images.ContourTracing; -import qupath.lib.common.ColorTools; -import qupath.lib.experimental.pixels.OutputHandler; -import qupath.lib.experimental.pixels.Parameters; -import qupath.lib.experimental.pixels.PixelProcessorUtils; -import qupath.lib.objects.PathAnnotationObject; -import qupath.lib.objects.PathCellObject; -import qupath.lib.objects.PathDetectionObject; -import qupath.lib.objects.PathObject; -import qupath.lib.objects.PathObjects; -import qupath.lib.roi.GeometryTools; -import qupath.lib.roi.interfaces.ROI; -import qupath.opencv.tools.OpenCVTools; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; - -class OutputToObjectConverter implements OutputHandler.OutputToObjectConverter { - private static final Logger logger = LoggerFactory.getLogger(OutputToObjectConverter.class); - - private static final long seed = 1243; - private List> classes; - - public OutputToObjectConverter(List> outputClasses) { - this.classes = outputClasses; - } - - @Override - public List convertToObjects(Parameters params, Mat output) { - if (output == null) { - return List.of(); - } - int nChannels = output.channels(); - if (nChannels < 1 || nChannels > 2) - throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels); - - List> roiMaps = new ArrayList<>(); - for (var mat : OpenCVTools.splitChannels(output)) { - var image = OpenCVTools.matToSimpleImage(mat, 0); - roiMaps.add( - ContourTracing.createROIs(image, params.getRegionRequest(), 1, -1) - ); - } - var rng = new Random(seed); - - BiFunction function; - if (classes.size() == 1) { - // if of length 1, can be - // cellObject (with or without nucleus), annotations, detections - if (classes.get(0) == PathAnnotationObject.class) { - function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); - } else if (classes.get(0) == PathDetectionObject.class) { - function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); - } else if (classes.get(0) == PathCellObject.class) { - function = createCellFun(rng); - } else { - function = createCellFun(rng); - logger.warn("Unknown output {}, defaulting to cells", classes.get(0)); - } - } else { - // if of length 2, then can be: - // detection <- annotation, annotation <- annotation, detection <- detection - assert classes.size() == 2; - if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathAnnotationObject.class) { - function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createAnnotationObject, rng); - } else if (classes.get(0) == PathAnnotationObject.class && classes.get(1) == PathAnnotationObject.class) { - function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); - } else if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathDetectionObject.class) { - function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); - } else { - logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1)); - function = createCellFun(rng); - } - } - - if (roiMaps.size() == 1) { - // One-channel detected, represent using detection objects - return roiMaps.get(0).values().stream() - .map(roi -> function.apply(roi, null)) - .collect(Collectors.toList()); - } else { - // Two channels detected, represent using cell objects - // We assume that the labels are matched - and we can't have a nucleus without a cell - Map childROIs = roiMaps.get(0); - Map parentROIs = roiMaps.get(1); - List cells = new ArrayList<>(); - for (var entry : parentROIs.entrySet()) { - var parent = entry.getValue(); - var child = childROIs.getOrDefault(entry.getKey(), null); - var outputObject = function.apply(parent, child); - cells.add(outputObject); - } - return cells; - } - } - - private static BiFunction createCellFun(Random rng) { - return (parent, child) -> { - var cell = PathObjects.createCellObject(parent, child); - var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); - cell.setColor(color); - return cell; - }; - } - - private static BiFunction createObjectsFun(Function createParentFun, Function createChildFun, Random rng) { - return (parent, child) -> { - var parentObj = createParentFun.apply(parent); - var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); - parentObj.setColor(color); - if (child != null) { - var childObj = createChildFun.apply(child); - childObj.setColor(color); - parentObj.addChildObject(childObj); - } - return parentObj; - }; - } - - static class PruneObjectOutputHandler implements OutputHandler { - - private final OutputToObjectConverter converter; - private final int boundaryThreshold; - - PruneObjectOutputHandler(OutputToObjectConverter converter, int boundaryThreshold) { - this.converter = converter; - this.boundaryThreshold = boundaryThreshold; - } - - @Override - public boolean handleOutput(Parameters params, U output) { - if (output == null) - return false; - else { - List newObjects = converter.convertToObjects(params, output); - if (newObjects == null) - return false; - // If using a proxy object (eg tile), - // we want to remove things touching the tile boundary, - // then add the objects to the proxy rather than the parent - var parentOrProxy = params.getParentOrProxy(); - parentOrProxy.clearChildObjects(); - - // remove features within N pixels of the region request boundaries - var bounds = GeometryTools.createRectangle( - params.getRegionRequest().getX(), params.getRegionRequest().getY(), - params.getRegionRequest().getWidth(), params.getRegionRequest().getHeight()); - - int width = params.getServer().getWidth(); - int height = params.getServer().getHeight(); - - newObjects = newObjects.parallelStream() - .filter(p -> doesntTouchBoundaries(p.getROI().getGeometry().getEnvelopeInternal(), bounds.getEnvelopeInternal(), boundaryThreshold, width, height)) - .toList(); - - if (!newObjects.isEmpty()) { - // since we're using IoU to merge objects, we want to keep anything that is within the overall object bounding box - var parent = params.getParent().getROI(); - newObjects = newObjects.parallelStream() - .flatMap(p -> PixelProcessorUtils.maskObject(parent, p).stream()) - .toList(); - } - parentOrProxy.addChildObjects(newObjects); - parentOrProxy.setLocked(true); - return true; - } - } - - - /** - * Tests if a detection is near the boundary of a parent region. - * It first checks if the detection is on the edge of the overall image, in which case it should be kept, - * unless it is at the edge of the image and the perpendicular edge of the parent region. - * For example, on the left side of the image, but on the top/bottom edge of the parent region. - * Then, it checks if the detection is on the boundary of the parent region. - * @param det The detection object. - * @param region The region containing all detection objects. - * @param boundaryPixels The size of the boundary, in pixels, to use for removing object. - * @param imageWidth The width of the image, in pixels. - * @param imageHeight The height of the image, in pixels. - * @return Whether the detection object should be removed, based on these criteria. - */ - private boolean doesntTouchBoundaries(Envelope det, Envelope region, int boundaryPixels, int imageWidth, int imageHeight) { - // keep any objects at the boundary of the annotation, except the stuff around region boundaries - if (touchesLeftOfImage(det, boundaryPixels)) { - if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { - return true; - } - if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, boundaryPixels))) { - return true; - } - } - if (touchesTopOfImage(det, boundaryPixels)) { - if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { - return true; - } - if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { - return true; - } - } - - if (touchesRightOfImage(det, imageWidth, boundaryPixels)) { - if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { - return true; - } - if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, boundaryPixels))) { - return true; - } - } - if (touchesBottomOfImage(det, imageHeight, boundaryPixels)) { - if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { - return true; - } - if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { - return true; - } - } - - // remove any objects at other region boundaries - return !(touchesLeftOfRegion(det, region, boundaryPixels) - || touchesRightOfRegion(det, region, boundaryPixels) - || touchesBottomOfRegion(det, region, boundaryPixels) - || touchesTopOfRegion(det, region, boundaryPixels)); - } - } - - private static boolean touchesLeftOfImage(Envelope det, int boundary) { - return det.getMinX() < boundary; - } - - private static boolean touchesRightOfImage(Envelope det, int width, int boundary) { - return width - det.getMaxX() < boundary; - } - - private static boolean touchesTopOfImage(Envelope det, int boundary) { - return det.getMinY() < boundary; - } - - private static boolean touchesBottomOfImage(Envelope det, int height, int boundary) { - return height - det.getMaxY() < boundary; - } - - private static boolean touchesLeftOfRegion(Envelope det, Envelope region, int boundary) { - return det.getMinX() - region.getMinX() < boundary; - } - - private static boolean touchesRightOfRegion(Envelope det, Envelope region, int boundary) { - return region.getMaxX() - det.getMaxX() < boundary; - } - - private static boolean touchesTopOfRegion(Envelope det, Envelope region, int boundary) { - return det.getMinY() - region.getMinY() < boundary; - } - - private static boolean touchesBottomOfRegion(Envelope det, Envelope region, int boundary) { - return region.getMaxY() - det.getMaxY() < boundary; - } -} diff --git a/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java new file mode 100644 index 0000000..493d54a --- /dev/null +++ b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java @@ -0,0 +1,164 @@ +package qupath.ext.instanseg.core; + +import org.locationtech.jts.geom.Envelope; +import qupath.lib.experimental.pixels.OutputHandler; +import qupath.lib.experimental.pixels.Parameters; +import qupath.lib.experimental.pixels.PixelProcessorUtils; +import qupath.lib.objects.PathObject; +import qupath.lib.roi.GeometryTools; + +import java.util.List; + +class PruneObjectOutputHandler implements OutputHandler { + + private final OutputToObjectConverter converter; + private final int boundaryThreshold; + + /** + * An output handler that prunes the output, removing any objects that are + * within a certain distance (in pixels) to the tile boundaries, leaving + * all objects on the border of the image. + *

+ * Relies on having a relatively large overlap between tiles. + *

+ * Useful if you want to use for example IoU to merge objects between tiles, + * where the general QuPath approach of merging objects with shared + * boundaries won't work. + * @param converter An output to object converter. + * @param boundaryThreshold The size of the boundary, in pixels, to use for removing objects. + * See {@link #doesntTouchBoundaries} for more details. + */ + PruneObjectOutputHandler(OutputToObjectConverter converter, int boundaryThreshold) { + this.converter = converter; + this.boundaryThreshold = boundaryThreshold; + } + + @Override + public boolean handleOutput(Parameters params, U output) { + if (output == null) + return false; + else { + List newObjects = converter.convertToObjects(params, output); + if (newObjects == null) + return false; + // If using a proxy object (eg tile), + // we want to remove things touching the tile boundary, + // then add the objects to the proxy rather than the parent + var parentOrProxy = params.getParentOrProxy(); + parentOrProxy.clearChildObjects(); + + // remove features within N pixels of the region request boundaries + var bounds = GeometryTools.createRectangle( + params.getRegionRequest().getX(), params.getRegionRequest().getY(), + params.getRegionRequest().getWidth(), params.getRegionRequest().getHeight()); + + int width = params.getServer().getWidth(); + int height = params.getServer().getHeight(); + + newObjects = newObjects.parallelStream() + .filter(p -> doesntTouchBoundaries(p.getROI().getGeometry().getEnvelopeInternal(), bounds.getEnvelopeInternal(), boundaryThreshold, width, height)) + .toList(); + + if (!newObjects.isEmpty()) { + // since we're using IoU to merge objects, we want to keep anything that is within the overall object bounding box + var parent = params.getParent().getROI(); + newObjects = newObjects.parallelStream() + .flatMap(p -> PixelProcessorUtils.maskObject(parent, p).stream()) + .toList(); + } + parentOrProxy.addChildObjects(newObjects); + parentOrProxy.setLocked(true); + return true; + } + } + + + /** + * Tests if a detection is near the boundary of a parent region. + * It first checks if the detection is on the edge of the overall image, in which case it should be kept, + * unless it is at the edge of the image and the perpendicular edge of the parent region. + * For example, on the left side of the image, but on the top/bottom edge of the parent region. + * Then, it checks if the detection is on the boundary of the parent region. + * + * @param det The detection object. + * @param region The region containing all detection objects. + * @param boundaryPixels The size of the boundary, in pixels, to use for removing objects. + * @param imageWidth The width of the image, in pixels. + * @param imageHeight The height of the image, in pixels. + * @return Whether the detection object should be removed, based on these criteria. + */ + private boolean doesntTouchBoundaries(Envelope det, Envelope region, int boundaryPixels, int imageWidth, int imageHeight) { + // keep any objects at the boundary of the annotation, except the stuff around region boundaries + if (touchesLeftOfImage(det, boundaryPixels)) { + if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { + return true; + } + if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, boundaryPixels))) { + return true; + } + } + if (touchesTopOfImage(det, boundaryPixels)) { + if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { + return true; + } + if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { + return true; + } + } + + if (touchesRightOfImage(det, imageWidth, boundaryPixels)) { + if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { + return true; + } + if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, boundaryPixels))) { + return true; + } + } + if (touchesBottomOfImage(det, imageHeight, boundaryPixels)) { + if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { + return true; + } + if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { + return true; + } + } + + // remove any objects at other region boundaries + return !(touchesLeftOfRegion(det, region, boundaryPixels) + || touchesRightOfRegion(det, region, boundaryPixels) + || touchesBottomOfRegion(det, region, boundaryPixels) + || touchesTopOfRegion(det, region, boundaryPixels)); + } + + private static boolean touchesLeftOfImage(Envelope det, int boundary) { + return det.getMinX() < boundary; + } + + private static boolean touchesRightOfImage(Envelope det, int width, int boundary) { + return width - det.getMaxX() < boundary; + } + + private static boolean touchesTopOfImage(Envelope det, int boundary) { + return det.getMinY() < boundary; + } + + private static boolean touchesBottomOfImage(Envelope det, int height, int boundary) { + return height - det.getMaxY() < boundary; + } + + private static boolean touchesLeftOfRegion(Envelope det, Envelope region, int boundary) { + return det.getMinX() - region.getMinX() < boundary; + } + + private static boolean touchesRightOfRegion(Envelope det, Envelope region, int boundary) { + return region.getMaxX() - det.getMaxX() < boundary; + } + + private static boolean touchesTopOfRegion(Envelope det, Envelope region, int boundary) { + return det.getMinY() - region.getMinY() < boundary; + } + + private static boolean touchesBottomOfRegion(Envelope det, Envelope region, int boundary) { + return region.getMaxY() - det.getMaxY() < boundary; + } +} From b0cda49f1d1c2e68eaf4d2a81e125b8cb9c838a6 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 5 Aug 2024 12:50:36 +0100 Subject: [PATCH 26/28] Javadocs --- .../qupath/ext/instanseg/core/InstanSegModel.java | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index e054136..613f67a 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -49,15 +49,27 @@ private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) { this.name = model.getName(); } - public InstanSegModel(URL modelURL, String name) { + + private InstanSegModel(URL modelURL, String name) { this.modelURL = modelURL; this.name = name; } + /** + * Create an InstanSeg model from an existing path. + * @param path The path to the folder that contains the model .pt file and the config YAML file. + * @return A handle on the model that can be used for inference. + * @throws IOException If the directory can't be found or isn't a valid model directory. + */ public static InstanSegModel fromPath(Path path) throws IOException { return new InstanSegModel(BioimageIoSpec.parseModel(path.toFile())); } + /** + * Request an InstanSeg model from the set of available models + * @param name The model name + * @return The specified model. + */ public static InstanSegModel fromName(String name) { // todo: instantiate built-in models somehow throw new UnsupportedOperationException("Fetching models by name is not yet implemented!"); @@ -167,6 +179,7 @@ void runInstanSeg( .build() .loadModel()) { + BaseNDManager baseManager = (BaseNDManager)model.getNDManager(); printResourceCount("Resource count before prediction", (BaseNDManager)baseManager.getParentManager()); From 15b10f3aecf3b3a1ce6ce6ce6dc2982ea943910a Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 5 Aug 2024 13:09:10 +0100 Subject: [PATCH 27/28] Simplify code flow --- .../InstansegOutputToObjectConverter.java | 59 ++++++++++++------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java index 93aa94c..3952696 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java @@ -53,32 +53,12 @@ public List convertToObjects(Parameters params, Mat output BiFunction function; if (classes.size() == 1) { - // if of length 1, can be - // cellObject (with or without nucleus), annotations, detections - if (classes.get(0) == PathAnnotationObject.class) { - function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); - } else if (classes.get(0) == PathDetectionObject.class) { - function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); - } else if (classes.get(0) == PathCellObject.class) { - function = createCellFun(rng); - } else { - function = createCellFun(rng); - logger.warn("Unknown output {}, defaulting to cells", classes.get(0)); - } + function = getOneClassBiFunction(classes, rng); } else { // if of length 2, then can be: // detection <- annotation, annotation <- annotation, detection <- detection assert classes.size() == 2; - if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathAnnotationObject.class) { - function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createAnnotationObject, rng); - } else if (classes.get(0) == PathAnnotationObject.class && classes.get(1) == PathAnnotationObject.class) { - function = createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); - } else if (classes.get(0) == PathDetectionObject.class && classes.get(1) == PathDetectionObject.class) { - function = createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); - } else { - logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1)); - function = createCellFun(rng); - } + function = getTwoClassBiFunction(classes, rng); } if (roiMaps.size() == 1) { @@ -102,6 +82,41 @@ public List convertToObjects(Parameters params, Mat output } } + private static BiFunction getOneClassBiFunction(List> classes, Random rng) { + // if of length 1, can be + // cellObject (with or without nucleus), annotations, detections + if (classes.get(0) == PathAnnotationObject.class) { + return createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); + } else if (classes.get(0) == PathDetectionObject.class) { + return createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); + } else if (classes.get(0) == PathCellObject.class) { + return createCellFun(rng); + } else { + logger.warn("Unknown output {}, defaulting to cells", classes.get(0)); + return createCellFun(rng); + } + } + + private static BiFunction getTwoClassBiFunction(List> classes, Random rng) { + Function fun0, fun1; + var knownClasses = List.of(PathDetectionObject.class, PathAnnotationObject.class); + if (!knownClasses.contains(classes.get(0)) || !knownClasses.contains(classes.get(1))) { + logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1)); + return createCellFun(rng); + } + if (classes.get(0) == PathDetectionObject.class) { + fun0 = PathObjects::createDetectionObject; + } else { + fun0 = PathObjects::createAnnotationObject; + } + if (classes.get(1) == PathDetectionObject.class) { + fun1 = PathObjects::createDetectionObject; + } else { + fun1 = PathObjects::createAnnotationObject; + } + return createObjectsFun(fun0, fun1, rng); + } + private static BiFunction createCellFun(Random rng) { return (parent, child) -> { var cell = PathObjects.createCellObject(parent, child); From f612894586feecbdafeac1f4ee83f38747d464da Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Mon, 5 Aug 2024 13:12:19 +0100 Subject: [PATCH 28/28] Throw error in builder when output channels and output classes mismatch --- src/main/java/qupath/ext/instanseg/core/InstanSeg.java | 3 +++ .../ext/instanseg/core/InstansegOutputToObjectConverter.java | 1 + 2 files changed, 4 insertions(+) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index c7e1010..9a8ce2d 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -370,6 +370,9 @@ public InstanSeg build() { if (outputClasses == null) { var tmp = outputCells(); } + if (outputClasses.size() > 1 && numOutputChannels == 1) { + throw new IllegalArgumentException("Cannot have multiple output types when using only one output channel."); + } return new InstanSeg( this.tileDims, this.downsample, diff --git a/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java index 3952696..85b4271 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java @@ -42,6 +42,7 @@ public List convertToObjects(Parameters params, Mat output if (nChannels < 1 || nChannels > 2) throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels); + List> roiMaps = new ArrayList<>(); for (var mat : OpenCVTools.splitChannels(output)) { var image = OpenCVTools.matToSimpleImage(mat, 0);