From 9ac30bcc99c174b8d9e40fd76faeefeaf5ef47cc Mon Sep 17 00:00:00 2001 From: Pete Date: Thu, 3 Oct 2024 20:55:34 +0100 Subject: [PATCH 01/14] Explore handling multiple outputs --- .../qupath/ext/instanseg/core/InstanSeg.java | 24 +++++---- .../InstanSegOutputToObjectConverter.java | 52 ++++++++++++++++--- .../ext/instanseg/core/MatTranslator.java | 14 +++-- .../core/TilePredictionProcessor.java | 17 +++--- 4 files changed, 78 insertions(+), 29 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 1fbb7b0..eb77688 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -12,6 +12,7 @@ import qupath.lib.experimental.pixels.OpenCVProcessor; import qupath.lib.experimental.pixels.OutputHandler; import qupath.lib.experimental.pixels.Parameters; +import qupath.lib.experimental.pixels.PixelProcessor; import qupath.lib.experimental.pixels.Processor; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; @@ -208,7 +209,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec var inputChannels = getInputChannels(imageData); try (var model = Criteria.builder() - .setTypes(Mat.class, Mat.class) + .setTypes(Mat.class, Mat[].class) .optModelUrls(String.valueOf(modelPath.toUri())) .optProgress(new ProgressBar()) .optDevice(device) // Remove this line if devices are problematic! @@ -220,7 +221,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec printResourceCount("Resource count before prediction", (BaseNDManager)baseManager.getParentManager()); baseManager.debugDump(2); - BlockingQueue> predictors = new ArrayBlockingQueue<>(nPredictors); + BlockingQueue> predictors = new ArrayBlockingQueue<>(nPredictors); try { for (int i = 0; i < nPredictors; i++) { @@ -234,8 +235,9 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec var predictionProcessor = createProcessor(predictors, inputChannels, tileDims, padToInputSize); var outputHandler = createOutputHandler(preferredOutputClass, randomColors, boundaryThreshold); var postProcessor = createPostProcessor(); - - var processor = OpenCVProcessor.builder(predictionProcessor) + var processor = new PixelProcessor.Builder() + .processor(predictionProcessor) + .maskSupplier(OpenCVProcessor.createMatMaskSupplier()) .imageSupplier((parameters) -> ImageOps.buildImageDataOp(inputChannels) .apply(parameters.getImageData(), parameters.getRegionRequest())) .tiler(tiler) @@ -271,6 +273,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec } } + /** * Check if we are requesting tiles for debugging purposes. * When this is true, we should create objects that represent the tiles - not the objects to be detected. @@ -280,7 +283,7 @@ private static boolean debugTiles() { return System.getProperty("instanseg.debug.tiles", "false").strip().equalsIgnoreCase("true"); } - private static Processor createProcessor(BlockingQueue> predictors, + private static Processor createProcessor(BlockingQueue> predictors, Collection inputChannels, int tileDims, boolean padToInputSize) { if (debugTiles()) @@ -288,7 +291,7 @@ private static Processor createProcessor(BlockingQueue parameters) { + private static Mat[] createOnes(Parameters parameters) { var tileRequest = parameters.getTileRequest(); int width, height; if (tileRequest == null) { @@ -299,13 +302,14 @@ private static Mat createOnes(Parameters parameters) { width = tileRequest.getTileWidth(); height = tileRequest.getTileHeight(); } - return Mat.ones(height, width, opencv_core.CV_8UC1).asMat(); + return new Mat[]{Mat.ones(height, width, opencv_core.CV_8UC1).asMat()}; } - private static OutputHandler createOutputHandler(Class preferredOutputClass, boolean randomColors, + private static OutputHandler createOutputHandler(Class preferredOutputClass, boolean randomColors, int boundaryThreshold) { - if (debugTiles()) - return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); + // TODO: Reinstate this for Mat[] output (it was written for Mat output) +// if (debugTiles()) +// return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); return new PruneObjectOutputHandler<>( new InstanSegOutputToObjectConverter(preferredOutputClass, randomColors), boundaryThreshold); } diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index a9956e0..e7f431e 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -14,6 +14,7 @@ import qupath.lib.objects.PathObject; import qupath.lib.objects.PathObjects; import qupath.lib.objects.PathTileObject; +import qupath.lib.objects.classes.PathClass; import qupath.lib.regions.ImagePlane; import qupath.lib.roi.GeometryTools; import qupath.lib.roi.interfaces.ROI; @@ -21,6 +22,7 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -28,7 +30,7 @@ import java.util.function.Function; import java.util.stream.Collectors; -class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter { +class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter { private static final Logger logger = LoggerFactory.getLogger(InstanSegOutputToObjectConverter.class); @@ -46,18 +48,19 @@ class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectCo } @Override - public List convertToObjects(Parameters params, Mat output) { + public List convertToObjects(Parameters params, Mat[] output) { if (output == null) { return List.of(); } - int nChannels = output.channels(); + var matLabels = output[0]; + int nChannels = matLabels.channels(); if (nChannels < 1 || nChannels > 2) throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels); List> roiMaps = new ArrayList<>(); ImagePlane plane = params.getRegionRequest().getImagePlane(); - for (var mat : OpenCVTools.splitChannels(output)) { + for (var mat : OpenCVTools.splitChannels(matLabels)) { var image = OpenCVTools.matToSimpleImage(mat, 0); var geoms = ContourTracing.createGeometries(image, params.getRegionRequest(), 1, -1); roiMaps.add(geoms.entrySet().stream() @@ -68,6 +71,17 @@ public List convertToObjects(Parameters params, Mat output ); } + // If we have two outputs, the second may give classifications - arrange by row + Map classifications = new HashMap<>(); + if (output.length > 1) { + var matClass = output[1]; + int nRows = matClass.rows(); + for (int r = 0; r < nRows; r++) { + double[] doubles = OpenCVTools.extractDoubles(matClass.row(r)); + classifications.put(r, doubles); + } + } + // We reverse the order because the smaller output (e.g. nucleus) comes before the larger out (e.g. cell) // and we want to iterate in the opposite order. If this changes (or becomes inconsistent) we may need to // sum pixels or areas. @@ -83,8 +97,11 @@ public List convertToObjects(Parameters params, Mat output Map childROIs = roiMaps.size() >= 2 ? roiMaps.get(1) : Collections.emptyMap(); pathObjects = parentROIs.entrySet().stream().map(entry -> { var parent = entry.getValue(); - var child = childROIs.getOrDefault(entry.getKey(), null); - return PathObjects.createCellObject(parent, child); + var label = entry.getKey(); + var child = childROIs.getOrDefault(label, null); + var cell = PathObjects.createCellObject(parent, child); + assignClassificationsIfAvailable(cell, classifications.getOrDefault(label, null)); + return cell; }).toList(); } else { Function createObjectFun = createObjectFun(preferredObjectClass); @@ -92,17 +109,19 @@ public List convertToObjects(Parameters params, Mat output Map parentMap = roiMaps.getFirst(); List> childMaps = roiMaps.size() == 1 ? Collections.emptyList() : roiMaps.subList(1, roiMaps.size()); for (var entry : parentMap.entrySet()) { + var label = entry.getKey(); var roi = entry.getValue(); var pathObject = createObjectFun.apply(roi); if (roiMaps.size() > 1) { for (var subMap : childMaps) { - var childROI = subMap.get(entry.getKey()); + var childROI = subMap.get(label); if (childROI != null) { var childObject = createObjectFun.apply(childROI); pathObject.addChildObject(childObject); } } } + assignClassificationsIfAvailable(pathObject, classifications.getOrDefault(label, null)); pathObjects.add(pathObject); } } @@ -115,6 +134,25 @@ public List convertToObjects(Parameters params, Mat output return pathObjects; } + private static void assignClassificationsIfAvailable(PathObject pathObject, double[] values) { + if (values == null) + return; + try (var ml = pathObject.getMeasurementList()) { + int maxInd = 0; + double maxVal = values[0]; + for (int i = 0; i < values.length; i++) { + double val = values[i]; + if (val > maxVal) { + maxVal = val; + maxInd = i; + } + pathObject.getMeasurementList().put("Prediction " + i, val); + } + pathObject.setPathClass(PathClass.fromString("Class " + maxInd)); + } + } + + /** * Assign a random color to a PathObject and all descendants, returning the object. * @param pathObject diff --git a/src/main/java/qupath/ext/instanseg/core/MatTranslator.java b/src/main/java/qupath/ext/instanseg/core/MatTranslator.java index 2788c92..3d29fdd 100644 --- a/src/main/java/qupath/ext/instanseg/core/MatTranslator.java +++ b/src/main/java/qupath/ext/instanseg/core/MatTranslator.java @@ -7,10 +7,8 @@ import org.bytedeco.opencv.opencv_core.Mat; import qupath.ext.djl.DjlTools; -import java.util.Arrays; - -class MatTranslator implements Translator { +class MatTranslator implements Translator { private final String inputLayoutNd; private final String outputLayoutNd; @@ -59,9 +57,15 @@ public NDList processInput(TranslatorContext ctx, Mat input) { } @Override - public Mat processOutput(TranslatorContext ctx, NDList list) { + public Mat[] processOutput(TranslatorContext ctx, NDList list) { var array = list.getFirst(); - return DjlTools.ndArrayToMat(array, outputLayoutNd); + var labels = DjlTools.ndArrayToMat(array, outputLayoutNd); + var output = new Mat[list.size()]; + output[0] = labels; + for (int i = 1; i < list.size(); i++) { + output[i] = DjlTools.ndArrayToMat(list.get(i), "HW"); + } + return output; } } diff --git a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java index 53af730..5ce3ebf 100644 --- a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java +++ b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java @@ -32,11 +32,11 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -class TilePredictionProcessor implements Processor { +class TilePredictionProcessor implements Processor { private static final Logger logger = LoggerFactory.getLogger(TilePredictionProcessor.class); - private final BlockingQueue> predictors; + private final BlockingQueue> predictors; private final int inputWidth; private final int inputHeight; @@ -59,7 +59,7 @@ class TilePredictionProcessor implements Processor { */ private final Map normalization = Collections.synchronizedMap(new WeakHashMap<>()); - TilePredictionProcessor(BlockingQueue> predictors, + TilePredictionProcessor(BlockingQueue> predictors, Collection channels, int inputWidth, int inputHeight, boolean doPadding) { this.predictors = predictors; @@ -108,7 +108,7 @@ public boolean wasInterrupted() { } @Override - public Mat process(Parameters params) throws IOException { + public Mat[] process(Parameters params) throws IOException { var mat = params.getImage(); @@ -136,7 +136,7 @@ public Mat process(Parameters params) throws IOException { mat = mat2; } - Predictor predictor = null; + Predictor predictor = null; try { predictor = predictors.take(); logger.debug("Predicting tile {}", mat); @@ -148,9 +148,12 @@ public Mat process(Parameters params) throws IOException { OpenCVTools.matToImagePlus("Output " + params.getRegionRequest(), matOutput).show(); } - matOutput.convertTo(matOutput, opencv_core.CV_32S); + // Handle the first output (labels) + // There may be other outputs (classiications, features), but we don't handle those here + matOutput[0].convertTo(matOutput[0], opencv_core.CV_32S); if (padding != null) - matOutput = OpenCVTools.crop(matOutput, padding); + matOutput[0] = OpenCVTools.crop(matOutput[0], padding); + return matOutput; } catch (TranslateException e) { nTilesFailed.incrementAndGet(); From c965d4ceb5311503645316cadf2dcef8055e16ec Mon Sep 17 00:00:00 2001 From: Pete Date: Thu, 3 Oct 2024 21:02:09 +0100 Subject: [PATCH 02/14] Labels start at 1 --- .../ext/instanseg/core/InstanSegOutputToObjectConverter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index e7f431e..0634cdf 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -78,7 +78,7 @@ public List convertToObjects(Parameters params, Mat[] outp int nRows = matClass.rows(); for (int r = 0; r < nRows; r++) { double[] doubles = OpenCVTools.extractDoubles(matClass.row(r)); - classifications.put(r, doubles); + classifications.put(r+1, doubles); } } From 648a57bc103c11793ec5b805a083c58247ce8d5c Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Tue, 10 Dec 2024 21:36:07 +0000 Subject: [PATCH 03/14] Draft RDF code --- .../qupath/ext/instanseg/core/InstanSeg.java | 31 +++++-- .../ext/instanseg/core/InstanSegModel.java | 46 ++++++++++ .../InstanSegOutputToObjectConverter.java | 92 +++++++++++++++---- .../ext/instanseg/ui/CheckModelCache.java | 14 ++- .../ext/instanseg/ui/InstanSegController.java | 1 - 5 files changed, 145 insertions(+), 39 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index ed5706c..a76b8ed 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -9,6 +9,7 @@ import org.bytedeco.opencv.opencv_core.Mat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import qupath.bioimageio.spec.BioimageIoSpec; import qupath.lib.experimental.pixels.OpenCVProcessor; import qupath.lib.experimental.pixels.OutputHandler; import qupath.lib.experimental.pixels.Parameters; @@ -57,10 +58,10 @@ public class InstanSeg { private final InstanSegModel model; private final Device device; private final TaskRunner taskRunner; - private final Class preferredOutputClass; + private final Class preferredOutputType; // This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently. - // However we might want to reinstate it, possibly as a proportion of the padding amount. + // However, we might want to reinstate it, possibly as a proportion of the padding amount. private final int boundaryThreshold = 1; private InstanSeg(Builder builder) { @@ -72,7 +73,7 @@ private InstanSeg(Builder builder) { this.model = builder.model; this.device = builder.device; this.taskRunner = builder.taskRunner; - this.preferredOutputClass = builder.preferredOutputClass; + this.preferredOutputType = builder.preferredOutputClass; this.randomColors = builder.randomColors; this.makeMeasurements = builder.makeMeasurements; } @@ -157,13 +158,23 @@ private void makeMeasurements(ImageData imageData, Collection imageData, Collection pathObjects) { long startTime = System.currentTimeMillis(); - Optional oModelPath = model.getPath(); if (oModelPath.isEmpty()) { return InstanSegResults.emptyInstance(); } Path modelPath = oModelPath.get().resolve("instanseg.pt"); + Optional> oOutputTensors = this.model.getOutputs(); + if (oOutputTensors.isEmpty()) { + throw new IllegalArgumentException("No output tensors available even though model is available"); + } + var outputTensors = oOutputTensors.get(); + + List outputClasses = this.model.getClasses(); + if (outputClasses.isEmpty() && outputTensors.size() > 1) { + logger.warn("No output classes available, classes will be set as 'Class 1' etc."); + } + // Provide some way to change the number of predictors, even if this can't be specified through the UI // See https://forum.image.sc/t/instanseg-under-utilizing-cpu-only-2-3-cores/104496/7 int nPredictors = Integer.parseInt(System.getProperty("instanseg.numPredictors", "1")); @@ -176,8 +187,6 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec logger.warn("Padding to input size is turned on - this is likely to be slower (but could help fix any issues)"); } String layout = "CHW"; - - // TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant? String layoutOutput = "CHW"; // Get the downsample - this may be specified by the user, or determined from the model spec @@ -236,7 +245,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec var tiler = createTiler(downsample, tileDims, padding); var predictionProcessor = createProcessor(predictors, inputChannels, tileDims, padToInputSize); - var outputHandler = createOutputHandler(preferredOutputClass, randomColors, boundaryThreshold); + var outputHandler = createOutputHandler(preferredOutputType, randomColors, boundaryThreshold, outputTensors, outputClasses); var postProcessor = createPostProcessor(); var processor = new PixelProcessor.Builder() .processor(predictionProcessor) @@ -281,7 +290,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec /** * Check if we are requesting tiles for debugging purposes. * When this is true, we should create objects that represent the tiles - not the objects to be detected. - * @return + * @return Whether the system debugging property is set. */ private static boolean debugTiles() { return System.getProperty("instanseg.debug.tiles", "false").strip().equalsIgnoreCase("true"); @@ -312,11 +321,13 @@ private static Mat[] createOnes(Parameters parameters) { private static OutputHandler createOutputHandler(Class preferredOutputClass, boolean randomColors, - int boundaryThreshold) { + int boundaryThreshold, + List outputTensors, + List outputClasses) { // TODO: Reinstate this for Mat[] output (it was written for Mat output) // if (debugTiles()) // return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); - var converter = new InstanSegOutputToObjectConverter(preferredOutputClass, randomColors); + var converter = new InstanSegOutputToObjectConverter(outputTensors, outputClasses, preferredOutputClass, randomColors); if (boundaryThreshold >= 0) { return new PruneObjectOutputHandler<>(converter, boundaryThreshold); } else { diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index 3b35262..ef26d60 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -20,6 +20,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -248,6 +250,50 @@ public Optional getNumChannels() { return getModel().flatMap(model -> Optional.of(extractChannelNum(model))); } + /** + * Try to check the output tensors from the model spec. + * @return The output tensors if the model is downloaded, otherwise empty. + */ + public Optional> getOutputs() { + return getModel().flatMap(model -> Optional.ofNullable(model.getOutputs())); + } + + /** + * Try to check the output classes from the model spec. + * @return The output classes if the model is downloaded, and it's present, otherwise empty. + */ + public List getClasses() { + var config = model.getConfig().getOrDefault("qupath", null); + if (config instanceof Map configMap) { + List classes = new ArrayList<>(); + var tmp = (List) configMap.get("classes"); + System.out.println(tmp); + for (var t: tmp) { + classes.add(t.toString()); + } + return classes; + } + return List.of(); + } + + public enum OutputType { + // "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation" + INSTANCE_SEGMENTATION("instance_segmentation"), + CELL_EMBEDDINGS("cell_embeddings"), + CELL_PROBABILITIES("cell_probabilities"), + CELL_CLASSES("cell_classes"), + SEMANTIC_SEGMENTATION("semantic_segmentation"); + + private final String type; + OutputType(String type) { + this.type = type; + } + @Override + public String toString() { + return type; + } + } + private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) { int ind = model.getInputs().getFirst().getAxes().toLowerCase().indexOf("c"); var shape = model.getInputs().getFirst().getShape(); diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index 0634cdf..c01ea9d 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -4,6 +4,7 @@ import org.locationtech.jts.geom.Geometry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import qupath.bioimageio.spec.BioimageIoSpec; import qupath.lib.analysis.images.ContourTracing; import qupath.lib.common.ColorTools; import qupath.lib.experimental.pixels.OutputHandler; @@ -41,9 +42,16 @@ class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectCo * This may be turned off or made optional in the future. */ private final boolean randomColors; + private final List outputTensors; + private final List outputClasses; - InstanSegOutputToObjectConverter(Class preferredObjectClass, boolean randomColors) { - this.preferredObjectClass = preferredObjectClass; + InstanSegOutputToObjectConverter(List outputTensors, + List outputClasses, + Class preferredOutputType, + boolean randomColors) { + this.outputTensors = outputTensors; + this.outputClasses = outputClasses; + this.preferredObjectClass = preferredOutputType; this.randomColors = randomColors; } @@ -71,14 +79,23 @@ public List convertToObjects(Parameters params, Mat[] outp ); } + // "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation" // If we have two outputs, the second may give classifications - arrange by row - Map classifications = new HashMap<>(); + List> auxiliaryValues = new ArrayList<>(); + auxiliaryValues.add(new HashMap<>()); if (output.length > 1) { - var matClass = output[1]; - int nRows = matClass.rows(); - for (int r = 0; r < nRows; r++) { - double[] doubles = OpenCVTools.extractDoubles(matClass.row(r)); - classifications.put(r+1, doubles); + Map auxVals = new HashMap<>(); + System.out.println(output[1].dims()); + System.out.println(output[1].rows()); + System.out.println(output[1].cols()); + for (int i = 1; i < output.length; i++) { + var matClass = output[i]; + int nRows = matClass.rows(); + for (int r = 0; r < nRows; r++) { + double[] doubles = OpenCVTools.extractDoubles(matClass.row(r)); + auxVals.put(r+1, doubles); + } + auxiliaryValues.add(auxVals); } } @@ -100,7 +117,14 @@ public List convertToObjects(Parameters params, Mat[] outp var label = entry.getKey(); var child = childROIs.getOrDefault(label, null); var cell = PathObjects.createCellObject(parent, child); - assignClassificationsIfAvailable(cell, classifications.getOrDefault(label, null)); + for (int i = 1; i < output.length; i++) { + handleAuxOutput( + cell, + auxiliaryValues.get(i).getOrDefault(label, null), + InstanSegModel.OutputType.valueOf(outputTensors.get(i).getName().toUpperCase()), + outputClasses + ); + } return cell; }).toList(); } else { @@ -121,7 +145,14 @@ public List convertToObjects(Parameters params, Mat[] outp } } } - assignClassificationsIfAvailable(pathObject, classifications.getOrDefault(label, null)); + for (int i = 1; i < output.length; i++) { + handleAuxOutput( + pathObject, + auxiliaryValues.get(i).getOrDefault(label, null), + InstanSegModel.OutputType.valueOf(outputTensors.get(i).getName().toUpperCase()), + outputClasses + ); + } pathObjects.add(pathObject); } } @@ -134,21 +165,42 @@ public List convertToObjects(Parameters params, Mat[] outp return pathObjects; } - private static void assignClassificationsIfAvailable(PathObject pathObject, double[] values) { + private static void handleAuxOutput(PathObject pathObject, double[] values, InstanSegModel.OutputType outputType, List outputClasses) { if (values == null) return; - try (var ml = pathObject.getMeasurementList()) { - int maxInd = 0; - double maxVal = values[0]; + if (outputType == InstanSegModel.OutputType.CELL_PROBABILITIES) { + try (var ml = pathObject.getMeasurementList()) { + int maxInd = 0; + double maxVal = values[0]; + for (int i = 0; i < values.length; i++) { + double val = values[i]; + if (val > maxVal) { + maxVal = val; + maxInd = i; + } + ml.put("Probability " + i, val); + } + pathObject.setPathClass(PathClass.fromString(outputClasses.get(maxInd))); + // todo: get class names from RDF + } + } + if (outputType == InstanSegModel.OutputType.CELL_EMBEDDINGS) { + try (var ml = pathObject.getMeasurementList()) { + for (int i = 0; i < values.length; i++) { + double val = values[i]; + ml.put("Embedding " + i, val); + } + } + } + if (outputType == InstanSegModel.OutputType.CELL_CLASSES) { for (int i = 0; i < values.length; i++) { double val = values[i]; - if (val > maxVal) { - maxVal = val; - maxInd = i; - } - pathObject.getMeasurementList().put("Prediction " + i, val); + pathObject.setPathClass(PathClass.fromString("Class " + outputClasses.get((int)val))); + // todo: get class names from RDF } - pathObject.setPathClass(PathClass.fromString("Class " + maxInd)); + } + if (outputType == InstanSegModel.OutputType.SEMANTIC_SEGMENTATION) { + throw new UnsupportedOperationException("No idea what to do here!"); } } diff --git a/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java b/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java index fb0fdbf..82e9108 100644 --- a/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java +++ b/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java @@ -5,9 +5,7 @@ import javafx.beans.property.SimpleObjectProperty; import javafx.beans.value.ObservableValue; import org.controlsfx.control.CheckComboBox; -import org.controlsfx.control.CheckModel; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.WeakHashMap; @@ -22,7 +20,7 @@ public class CheckModelCache { private final ObjectProperty value = new SimpleObjectProperty<>(); private final CheckComboBox checkbox; - private final Map> lastChecks = new WeakHashMap<>(); + private final Map> lastChecks = new WeakHashMap<>(); private CheckModelCache(ObservableValue value, CheckComboBox checkbox) { this.value.bind(value); @@ -48,7 +46,7 @@ public static CheckModelCache create(ObservableValue value, Chec private void handleValueChange(ObservableValue observable, S oldValue, S newValue) { if (oldValue != null) { - lastChecks.put(oldValue, List.copyOf(checkbox.getCheckModel().getCheckedItems())); + lastChecks.put(oldValue, List.copyOf(checkbox.getCheckModel().getCheckedIndices())); } } @@ -72,11 +70,11 @@ public ReadOnlyObjectProperty valueProperty() { * @return true if the checks were restored, false otherwise */ public boolean restoreChecks() { - List checks = lastChecks.get(value.get()); - if (checks != null && new HashSet<>(checkbox.getItems()).containsAll(checks)) { + List checks = lastChecks.get(value.get()); + if (checks != null && checks.stream().allMatch(i -> i < checkbox.getItems().size())) { var checkModel = checkbox.getCheckModel(); checkModel.clearChecks(); - checks.forEach(checkModel::check); + checks.forEach(checkModel::checkIndices); return true; } else { return false; @@ -108,7 +106,7 @@ public boolean resetChecks() { public boolean snapshotChecks() { var val = value.get(); if (val != null) { - lastChecks.put(val, List.copyOf(checkbox.getCheckModel().getCheckedItems())); + lastChecks.put(val, List.copyOf(checkbox.getCheckModel().getCheckedIndices())); return true; } else { return false; diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 09220ca..bdb3ad9 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -382,7 +382,6 @@ private void updateInputChannels(ImageData imageData) { comboInputChannels.getCheckModel().checkIndices(0, 1, 2); } } - } } From 1c8872f5a83209f84d0d73dc5b702376e20d0f46 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Wed, 18 Dec 2024 14:15:05 +0000 Subject: [PATCH 04/14] 0.5 spec changes --- .../qupath/ext/instanseg/core/InstanSeg.java | 1 - .../ext/instanseg/core/InstanSegModel.java | 17 ++++++++++------- .../core/InstanSegOutputToObjectConverter.java | 3 --- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index a76b8ed..9ae9621 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -156,7 +156,6 @@ private void makeMeasurements(ImageData imageData, Collection imageData, Collection pathObjects) { - long startTime = System.currentTimeMillis(); Optional oModelPath = model.getPath(); if (oModelPath.isEmpty()) { diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index ef26d60..d97f03c 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -21,7 +21,6 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -29,6 +28,8 @@ import java.util.zip.ZipInputStream; import java.util.Objects; +import static qupath.bioimageio.spec.BioimageIoSpec.getAxesString; + public class InstanSegModel { private static final Logger logger = LoggerFactory.getLogger(InstanSegModel.class); @@ -266,10 +267,11 @@ public List getClasses() { var config = model.getConfig().getOrDefault("qupath", null); if (config instanceof Map configMap) { List classes = new ArrayList<>(); - var tmp = (List) configMap.get("classes"); - System.out.println(tmp); - for (var t: tmp) { - classes.add(t.toString()); + var el = configMap.get("classes"); + if (el != null && el instanceof List elList) { + for (var t: elList) { + classes.add(t.toString()); + } } return classes; } @@ -295,7 +297,8 @@ public String toString() { } private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) { - int ind = model.getInputs().getFirst().getAxes().toLowerCase().indexOf("c"); + String axes = getAxesString(model.getInputs().getFirst().getAxes()); + int ind = axes.indexOf("c"); var shape = model.getInputs().getFirst().getShape(); if (shape.getShapeStep()[ind] == 1) { return ANY_CHANNELS; @@ -437,7 +440,7 @@ private Optional> getPixelSize() { public Optional getOutputChannels() { return getModel().map(model -> { var output = model.getOutputs().getFirst(); - String axes = output.getAxes().toLowerCase(); + String axes = getAxesString(output.getAxes()); int ind = axes.indexOf("c"); var shape = output.getShape().getShape(); if (shape != null && shape.length > ind) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index c01ea9d..e43db07 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -85,9 +85,6 @@ public List convertToObjects(Parameters params, Mat[] outp auxiliaryValues.add(new HashMap<>()); if (output.length > 1) { Map auxVals = new HashMap<>(); - System.out.println(output[1].dims()); - System.out.println(output[1].rows()); - System.out.println(output[1].cols()); for (int i = 1; i < output.length; i++) { var matClass = output[i]; int nRows = matClass.rows(); From 1a7bd67efd9b7d80ea839915bf12d863f1aef9fd Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 3 Jan 2025 11:08:48 +0000 Subject: [PATCH 05/14] tmp --- .../qupath/ext/instanseg/core/InstanSeg.java | 8 +-- .../ext/instanseg/core/InstanSegModel.java | 6 +- .../InstanSegOutputToObjectConverter.java | 63 +++++++++---------- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 9ae9621..d9a596b 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -207,6 +207,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec // Create an int[] representing a boolean array of channels to use boolean[] outputChannelArray = null; if (outputChannels != null && outputChannels.length > 0) { + //noinspection OptionalGetWithoutIsPresent outputChannelArray = new boolean[model.getOutputChannels().get()]; // safe to call get because of previous checks for (int c : outputChannels) { if (c < 0 || c >= outputChannelArray.length) { @@ -244,7 +245,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec var tiler = createTiler(downsample, tileDims, padding); var predictionProcessor = createProcessor(predictors, inputChannels, tileDims, padToInputSize); - var outputHandler = createOutputHandler(preferredOutputType, randomColors, boundaryThreshold, outputTensors, outputClasses); + var outputHandler = createOutputHandler(preferredOutputType, randomColors, boundaryThreshold, outputTensors); var postProcessor = createPostProcessor(); var processor = new PixelProcessor.Builder() .processor(predictionProcessor) @@ -321,12 +322,11 @@ private static Mat[] createOnes(Parameters parameters) { private static OutputHandler createOutputHandler(Class preferredOutputClass, boolean randomColors, int boundaryThreshold, - List outputTensors, - List outputClasses) { + List outputTensors) { // TODO: Reinstate this for Mat[] output (it was written for Mat output) // if (debugTiles()) // return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); - var converter = new InstanSegOutputToObjectConverter(outputTensors, outputClasses, preferredOutputClass, randomColors); + var converter = new InstanSegOutputToObjectConverter(outputTensors, preferredOutputClass, randomColors); if (boundaryThreshold >= 0) { return new PruneObjectOutputHandler<>(converter, boundaryThreshold); } else { diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index d97f03c..5ab044f 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -281,9 +281,9 @@ public List getClasses() { public enum OutputType { // "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation" INSTANCE_SEGMENTATION("instance_segmentation"), - CELL_EMBEDDINGS("cell_embeddings"), - CELL_PROBABILITIES("cell_probabilities"), - CELL_CLASSES("cell_classes"), + DETECTION_EMBEDDINGS("detection_embeddings"), + DETECTION_LOGITS("detection_logits"), + DETECTION_CLASSES("detection_classes"), SEMANTIC_SEGMENTATION("semantic_segmentation"); private final String type; diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index e43db07..4a6cd99 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -43,14 +43,11 @@ class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectCo */ private final boolean randomColors; private final List outputTensors; - private final List outputClasses; InstanSegOutputToObjectConverter(List outputTensors, - List outputClasses, Class preferredOutputType, boolean randomColors) { this.outputTensors = outputTensors; - this.outputClasses = outputClasses; this.preferredObjectClass = preferredOutputType; this.randomColors = randomColors; } @@ -118,8 +115,7 @@ public List convertToObjects(Parameters params, Mat[] outp handleAuxOutput( cell, auxiliaryValues.get(i).getOrDefault(label, null), - InstanSegModel.OutputType.valueOf(outputTensors.get(i).getName().toUpperCase()), - outputClasses + outputTensors.get(i) ); } return cell; @@ -146,8 +142,7 @@ public List convertToObjects(Parameters params, Mat[] outp handleAuxOutput( pathObject, auxiliaryValues.get(i).getOrDefault(label, null), - InstanSegModel.OutputType.valueOf(outputTensors.get(i).getName().toUpperCase()), - outputClasses + outputTensors.get(i) ); } pathObjects.add(pathObject); @@ -162,40 +157,44 @@ public List convertToObjects(Parameters params, Mat[] outp return pathObjects; } - private static void handleAuxOutput(PathObject pathObject, double[] values, InstanSegModel.OutputType outputType, List outputClasses) { + private static void handleAuxOutput(PathObject pathObject, double[] values, BioimageIoSpec.OutputTensor outputTensor) { + List outputClasses = List.of(); // todo if (values == null) return; - if (outputType == InstanSegModel.OutputType.CELL_PROBABILITIES) { - try (var ml = pathObject.getMeasurementList()) { - int maxInd = 0; - double maxVal = values[0]; - for (int i = 0; i < values.length; i++) { - double val = values[i]; - if (val > maxVal) { - maxVal = val; - maxInd = i; + var outputType = InstanSegModel.OutputType.valueOf(outputTensor.getName().toUpperCase()); + switch(outputType) { + case DETECTION_LOGITS -> { + try (var ml = pathObject.getMeasurementList()) { + int maxInd = 0; + double maxVal = values[0]; + for (int i = 0; i < values.length; i++) { + double val = values[i]; + if (val > maxVal) { + maxVal = val; + maxInd = i; + } + ml.put("Logit class " + i, val); } - ml.put("Probability " + i, val); + pathObject.setPathClass(PathClass.fromString(outputClasses.get(maxInd))); + // todo: get class names from RDF } - pathObject.setPathClass(PathClass.fromString(outputClasses.get(maxInd))); - // todo: get class names from RDF } - } - if (outputType == InstanSegModel.OutputType.CELL_EMBEDDINGS) { - try (var ml = pathObject.getMeasurementList()) { - for (int i = 0; i < values.length; i++) { - double val = values[i]; - ml.put("Embedding " + i, val); + case DETECTION_EMBEDDINGS -> { + try (var ml = pathObject.getMeasurementList()) { + for (int i = 0; i < values.length; i++) { + double val = values[i]; + ml.put("Embedding " + i, val); + } } } - } - if (outputType == InstanSegModel.OutputType.CELL_CLASSES) { - for (int i = 0; i < values.length; i++) { - double val = values[i]; - pathObject.setPathClass(PathClass.fromString("Class " + outputClasses.get((int)val))); - // todo: get class names from RDF + case DETECTION_CLASSES -> { + for (double val : values) { + pathObject.setPathClass(PathClass.fromString("Class " + outputClasses.get((int) val))); + // todo: get class names from RDF + } } } + if (outputType == InstanSegModel.OutputType.SEMANTIC_SEGMENTATION) { throw new UnsupportedOperationException("No idea what to do here!"); } From 6c07c625f6039034d8a54d141aa308e8d60dffac Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Wed, 8 Jan 2025 17:23:47 +0000 Subject: [PATCH 06/14] tmp --- .../instanseg/core/InstanSegOutputToObjectConverter.java | 6 +++++- src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index 4a6cd99..abd7706 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -158,7 +158,11 @@ public List convertToObjects(Parameters params, Mat[] outp } private static void handleAuxOutput(PathObject pathObject, double[] values, BioimageIoSpec.OutputTensor outputTensor) { - List outputClasses = List.of(); // todo + List outputClasses = new ArrayList<>(); // todo: get from RDF + int nClasses = outputTensor.getShape().getShape()[2]; + for (int i = 0; i < nClasses; i++) { + outputClasses.add("Class" + i); + } if (values == null) return; var outputType = InstanSegModel.OutputType.valueOf(outputTensor.getName().toUpperCase()); diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java index 8132bd2..c597d48 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java @@ -62,6 +62,7 @@ protected Void call() { return null; } // TODO: HANDLE OUTPUT CHANNELS! + // todo: Unclear what this means int nOutputs = model.getOutputChannels().orElse(1); int[] outputChannels = new int[0]; if (nOutputs <= 0) { From c393f4aa420c1203044aaedb904a42a56572a1aa Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Thu, 9 Jan 2025 15:50:44 +0000 Subject: [PATCH 07/14] Fetch class names from RDF for class outputs at least --- .../InstanSegOutputToObjectConverter.java | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index abd7706..815108b 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -158,11 +158,18 @@ public List convertToObjects(Parameters params, Mat[] outp } private static void handleAuxOutput(PathObject pathObject, double[] values, BioimageIoSpec.OutputTensor outputTensor) { - List outputClasses = new ArrayList<>(); // todo: get from RDF - int nClasses = outputTensor.getShape().getShape()[2]; - for (int i = 0; i < nClasses; i++) { - outputClasses.add("Class" + i); + List outputClasses; + var description = outputTensor.getDataDescription(); + if (description instanceof BioimageIoSpec.NominalOrOrdinalDataDescription dataDescription) { + outputClasses = dataDescription.getValues().stream().map(Object::toString).toList(); + } else { + outputClasses = new ArrayList<>(); + int nClasses = outputTensor.getShape().getShape()[2]; // batch, object, class + for (int i = 0; i < nClasses; i++) { + outputClasses.add("Class" + i); + } } + if (values == null) return; var outputType = InstanSegModel.OutputType.valueOf(outputTensor.getName().toUpperCase()); @@ -177,10 +184,9 @@ private static void handleAuxOutput(PathObject pathObject, double[] values, Bioi maxVal = val; maxInd = i; } - ml.put("Logit class " + i, val); + ml.put("Logit " + outputClasses.get(i), val); } pathObject.setPathClass(PathClass.fromString(outputClasses.get(maxInd))); - // todo: get class names from RDF } } case DETECTION_EMBEDDINGS -> { @@ -193,8 +199,7 @@ private static void handleAuxOutput(PathObject pathObject, double[] values, Bioi } case DETECTION_CLASSES -> { for (double val : values) { - pathObject.setPathClass(PathClass.fromString("Class " + outputClasses.get((int) val))); - // todo: get class names from RDF + pathObject.setPathClass(PathClass.fromString(outputClasses.get((int) val))); } } } From 605ca66168bf67ea6a5d47b76ab501b5bd458405 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Thu, 9 Jan 2025 16:42:21 +0000 Subject: [PATCH 08/14] Linting --- .../qupath/ext/instanseg/core/InstanSeg.java | 14 ++++--- .../InstanSegOutputToObjectConverter.java | 13 +++---- .../ext/instanseg/core/InstanSegResults.java | 2 +- .../core/PruneObjectOutputHandler.java | 2 +- .../ext/instanseg/core/PytorchManager.java | 18 ++++----- .../core/TilePredictionProcessor.java | 12 ++++-- .../ext/instanseg/ui/CheckModelCache.java | 14 +++---- .../ext/instanseg/ui/InputChannelItem.java | 4 +- .../ext/instanseg/ui/InstanSegController.java | 38 +++++++++---------- .../instanseg/ui/InstanSegPreferences.java | 3 +- .../ext/instanseg/ui/InstanSegTask.java | 1 - .../ext/instanseg/ui/ModelListCell.java | 6 +-- .../ext/instanseg/ui/OutputChannelItem.java | 2 +- .../java/qupath/ext/instanseg/ui/Watcher.java | 17 ++++----- 14 files changed, 75 insertions(+), 71 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index d9a596b..9cc1ac2 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -315,7 +315,9 @@ private static Mat[] createOnes(Parameters parameters) { width = tileRequest.getTileWidth(); height = tileRequest.getTileHeight(); } - return new Mat[]{Mat.ones(height, width, opencv_core.CV_8UC1).asMat()}; + try (var ones = Mat.ones(height, width, opencv_core.CV_8UC1)) { + return new Mat[]{ones.asMat()}; + } } @@ -345,8 +347,8 @@ private static Tiler createTiler(double downsample, int tileDims, int padding) { /** * Get the input channels to use; if we don't have any specified, use all of them - * @param imageData - * @return + * @param imageData The image data + * @return The possible input channels. */ private List getInputChannels(ImageData imageData) { if (inputChannels == null || inputChannels.isEmpty()) { @@ -375,8 +377,8 @@ private static ObjectProcessor createPostProcessor() { /** * Print resource count for debugging purposes. * If we are not logging at debug level, do nothing. - * @param title - * @param manager + * @param title The name to be used in the log. + * @param manager The NDManager to print from. */ private static void printResourceCount(String title, BaseNDManager manager) { if (logger.isDebugEnabled()) { @@ -565,7 +567,7 @@ public Builder randomColors() { /** * Optionally request that random colors be used for the output objects. - * @param doRandomColors + * @param doRandomColors Whether to use random colors for output object. * @return this builder */ public Builder randomColors(boolean doRandomColors) { diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index 815108b..f5f05d2 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -212,16 +212,15 @@ private static void handleAuxOutput(PathObject pathObject, double[] values, Bioi /** * Assign a random color to a PathObject and all descendants, returning the object. - * @param pathObject - * @param rng - * @return + * + * @param pathObject The PathObject + * @param rng A random number generator. */ - private static PathObject assignRandomColor(PathObject pathObject, Random rng) { + private static void assignRandomColor(PathObject pathObject, Random rng) { pathObject.setColor(randomRGB(rng)); for (var child : pathObject.getChildObjects()) { assignRandomColor(child, rng); } - return pathObject; } private static ROI geometryToFilledROI(Geometry geom, ImagePlane plane) { @@ -245,8 +244,8 @@ else if (Objects.equals(PathTileObject.class, preferredObjectClass)) /** * Create annotations that are locked by default, to reduce the risk of editing them accidentally. - * @param roi - * @return + * @param roi The region of interest + * @return A locked annotation object. */ private static PathObject createLockedAnnotation(ROI roi) { var annotation = PathObjects.createAnnotationObject(roi); diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java b/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java index 72fd4ca..9d1eb9a 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java @@ -21,7 +21,7 @@ public record InstanSegResults( /** * Get an empty instance of InstanSegResults. - * @return + * @return An empty instance with default values. */ public static InstanSegResults emptyInstance() { return EMPTY; diff --git a/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java index fc6bdea..fa23712 100644 --- a/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java +++ b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java @@ -41,7 +41,7 @@ public boolean handleOutput(Parameters params, U output) { List newObjects = converter.convertToObjects(params, output); if (newObjects == null) return false; - // If using a proxy object (eg tile), + // If using a proxy object (e.g. tile), // we want to remove things touching the tile boundary, // then add the objects to the proxy rather than the parent var parentOrProxy = params.getParentOrProxy(); diff --git a/src/main/java/qupath/ext/instanseg/core/PytorchManager.java b/src/main/java/qupath/ext/instanseg/core/PytorchManager.java index f0ba496..3ef9de3 100644 --- a/src/main/java/qupath/ext/instanseg/core/PytorchManager.java +++ b/src/main/java/qupath/ext/instanseg/core/PytorchManager.java @@ -65,7 +65,7 @@ public static Collection getAvailableDevices() { /** * Query if the PyTorch engine is already available, without a need to download. - * @return + * @return whether PyTorch is available. */ public static boolean hasPyTorchEngine() { return getEngineOffline() != null; @@ -86,10 +86,10 @@ static Engine getEngineOffline() { /** * Call a function with the "offline" property set to true (to block automatic downloads). - * @param callable - * @return - * @param - * @throws Exception + * @param callable The function to be called. + * @return The return value of the callable. + * @param The return type of the callable. + * @throws Exception If the callable does. */ private static T callOffline(Callable callable) throws Exception { return callWithTempProperty("ai.djl.offline", "true", callable); @@ -97,10 +97,10 @@ private static T callOffline(Callable callable) throws Exception { /** * Call a function with the "offline" property set to false (to allow automatic downloads). - * @param callable - * @return - * @param - * @throws Exception + * @param callable The function to be called. + * @return The return value of the callable. + * @param The return type of the callable. + * @throws Exception If the callable does. */ private static T callOnline(Callable callable) throws Exception { return callWithTempProperty("ai.djl.offline", "false", callable); diff --git a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java index 5ce3ebf..305015f 100644 --- a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java +++ b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java @@ -91,7 +91,7 @@ public int getTilesFailedCount() { * The number of channels does not influence the result. *

* One use of this is to help assess the impact of padding on the processing time. - * @return the pixels that were processed + * @return the pixels that were processed. */ public long getPixelsProcessedCount() { return nPixelsProcessed.get(); @@ -101,7 +101,7 @@ public long getPixelsProcessedCount() { * Check if the processing was interrupted. * This can be used to determine if the processing was stopped prematurely, * and failed tiles were not necessarily errors. - * @return + * @return whether processing was interrupted. */ public boolean wasInterrupted() { return wasInterrupted.get(); @@ -192,7 +192,13 @@ public Mat[] process(Parameters params) throws IOException { * @return Percentile-based normalisation based on the bounding box, * or default tile-based percentile normalisation if that fails. */ - private static ImageOp getNormalization(ImageData imageData, ROI roi, Collection channels, double lowPerc, double highPerc) { + private static ImageOp getNormalization( + ImageData imageData, + ROI roi, + Collection channels, + double lowPerc, + double highPerc) { + var defaults = ImageOps.Normalize.percentile(lowPerc, highPerc, true, 1e-6); try { BufferedImage image; diff --git a/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java b/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java index 82e9108..e848d31 100644 --- a/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java +++ b/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java @@ -34,11 +34,11 @@ private CheckModelCache(ObservableValue value, CheckComboBox checkbox) { *

* This can then be used to restore the checks later, if needed. * - * @param value - * @param checkBox - * @return - * @param - * @param + * @param value the observable value that this cache corresponds to (e.g., a specific model's name). + * @param checkBox the CheckComboBox containing the possibly checked items. + * @return A {@link CheckModelCache} + * @param The observable value type. + * @param The type of checked item. */ public static CheckModelCache create(ObservableValue value, CheckComboBox checkBox) { return new CheckModelCache<>(value, checkBox); @@ -52,7 +52,7 @@ private void handleValueChange(ObservableValue observable, S oldVal /** * Get the value property. - * @return + * @return The value property. */ public ReadOnlyObjectProperty valueProperty() { return value; @@ -101,7 +101,7 @@ public boolean resetChecks() { * Create a snapshot of the checks currently associated with the observable value. * This is useful in case some other checkbox manipulation is required without switching the value, * and we want to restore the checks later (e.g. changing the items). - * @return + * @return whether the snapshot was possible. */ public boolean snapshotChecks() { var val = value.get(); diff --git a/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java b/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java index 4a477a6..3376024 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java @@ -63,8 +63,8 @@ String getConstructor() { /** * Get a list of available channels for the given image data. - * @param imageData - * @return + * @param imageData The image data. + * @return A list of channels, including color deconvolution channels if present. */ static List getAvailableChannels(ImageData imageData) { List list = new ArrayList<>(); diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 28c5c0c..67d5425 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -143,7 +143,7 @@ public class InstanSegController extends BorderPane { private List remoteModels; - private BooleanProperty requestingPyTorch = new SimpleBooleanProperty(false); + private final BooleanProperty requestingPyTorch = new SimpleBooleanProperty(false); // Listener for property changes in the current ImageData; these can be required to update the input channels private final PropertyChangeListener imageDataPropertyChangeListener = this::handleImageDataPropertyChange; @@ -249,9 +249,7 @@ private void configureInfoButton() { infoButton.disableProperty().bind(selectedModelIsAvailable.not()); WebView webView = WebViews.create(true); PopOver infoPopover = new PopOver(webView); - infoButton.setOnAction(e -> { - parseMarkdown(selectedModel.get(), webView, infoButton, infoPopover); - }); + infoButton.setOnAction(e -> parseMarkdown(selectedModel.get(), webView, infoButton, infoPopover)); } private void configureDownloadButton() { @@ -485,12 +483,15 @@ private void refreshModelChoice() { return; var modelDir = InstanSegUtils.getModelDirectory().orElse(null); - try { - model.checkIfDownloaded(modelDir.resolve("downloaded"), false); - } catch (IOException e) { - logger.debug("Error checking zip or RDF file(s); this shouldn't happen", e); - Dialogs.showErrorNotification(resources.getString("title"), resources.getString("error.checkingModel")); + if (modelDir != null) { + try { + model.checkIfDownloaded(modelDir.resolve("downloaded"), false); + } catch (IOException e) { + logger.debug("Error checking zip or RDF file(s); this shouldn't happen", e); + Dialogs.showErrorNotification(resources.getString("title"), resources.getString("error.checkingModel")); + } } + boolean isDownloaded = modelDir != null && model.isValid(); if (!isDownloaded || qupath.getImageData() == null) { return; @@ -511,7 +512,7 @@ private void refreshModelChoice() { /** * Try to download the currently-selected model in another thread. - * @return + * @return A Future that is completed when the download finishes. */ private CompletableFuture downloadSelectedModelAsync() { var model = selectedModel.get(); @@ -523,8 +524,8 @@ private CompletableFuture downloadSelectedModelAsync() { /** * Try to download the specified model in another thread. - * @param model - * @return + * @param model The model + * @return A future that is completed when the download finishes. */ private CompletableFuture downloadModelAsync(InstanSegModel model) { var modelDir = InstanSegUtils.getModelDirectory().orElse(null); @@ -538,9 +539,9 @@ private CompletableFuture downloadModelAsync(InstanSegModel mode /** * Try to download the specified model to the given directory in the current thread. - * @param model - * @param modelDir - * @return + * @param model The model. + * @param modelDir The model directory. + * @return The downloaded model. */ private InstanSegModel downloadModel(InstanSegModel model, Path modelDir) { Objects.requireNonNull(modelDir); @@ -574,10 +575,8 @@ private static void parseMarkdown(InstanSegModel model, WebView webView, Button // If the markdown doesn't start with a title, pre-pending the model title & description (if available) if (!body.startsWith("#")) { - var sb = new StringBuilder(); - sb.append("## ").append(model.getName()).append("\n\n"); - sb.append("----\n\n"); - doc.prependChild(parser.parse(sb.toString())); + String sb = "## " + model.getName() + "\n\n----\n\n"; + doc.prependChild(parser.parse(sb)); } webView.getEngine().loadContent(HtmlRenderer.builder().build().render(doc)); infoPopover.show(infoButton); @@ -623,6 +622,7 @@ private static List getRemoteModels() { } } InputStream in = InstanSegController.class.getResourceAsStream("model-index.json"); + assert in != null; String cont = new BufferedReader(new InputStreamReader(in)) .lines() .collect(Collectors.joining("\n")); diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java index 493dd20..47760a0 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java @@ -3,7 +3,6 @@ import javafx.beans.property.BooleanProperty; import javafx.beans.property.IntegerProperty; import javafx.beans.property.ObjectProperty; -import javafx.beans.property.Property; import javafx.beans.property.StringProperty; import qupath.lib.common.GeneralTools; import qupath.lib.gui.prefs.PathPrefs; @@ -56,7 +55,7 @@ enum OnlinePermission { /** * MPS should work reliably (and much faster) on Apple Silicon, so set as default. * Everywhere else, use CPU as we can't count on a GPU/CUDA being available. - * @return + * @return The default device string. */ private static String getDefaultDevice() { if (GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) { diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java index c597d48..0d4df5b 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java @@ -14,7 +14,6 @@ import qupath.lib.plugins.workflow.DefaultScriptableWorkflowStep; import java.awt.image.BufferedImage; -import java.io.File; import java.nio.file.Path; import java.util.Arrays; import java.util.List; diff --git a/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java b/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java index bb253f4..f6233a4 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java +++ b/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java @@ -12,10 +12,10 @@ public class ModelListCell extends ListCell { - private ResourceBundle resources = InstanSegResources.getResources(); + private final ResourceBundle resources = InstanSegResources.getResources(); - private Glyph web = createOnlineIcon(); - private Tooltip tooltip; + private final Glyph web = createOnlineIcon(); + private final Tooltip tooltip; public ModelListCell() { super(); diff --git a/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java b/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java index 2fab682..c1bfaf3 100644 --- a/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java @@ -40,7 +40,7 @@ static List getOutputsForChannelCount(int nChannels) { /** * Get the index of the output channel. - * @return + * @return The index of the output channel */ public int getIndex() { return index; diff --git a/src/main/java/qupath/ext/instanseg/ui/Watcher.java b/src/main/java/qupath/ext/instanseg/ui/Watcher.java index e37f321..153ff11 100644 --- a/src/main/java/qupath/ext/instanseg/ui/Watcher.java +++ b/src/main/java/qupath/ext/instanseg/ui/Watcher.java @@ -64,7 +64,7 @@ private Watcher() { /** * Get a singleton instance of the Watcher. - * @return + * @return The watcher. */ static Watcher getInstance() { return instance; @@ -176,7 +176,7 @@ private void processEvents() { /** * Add the model, after checking it won't be a duplicate. - * @param model + * @param model The model. */ private void ensureModelInList(InstanSegModel model) { for (var existingModel : models) { @@ -232,7 +232,7 @@ private void refreshAllModelPaths() { /** * Update all models in response to a change in the model paths. - * @param change + * @param change The change event. */ private void handleModelPathsChanged(ListChangeListener.Change change) { if (!Platform.isFxApplicationThread()) @@ -250,13 +250,12 @@ private void handleModelPathsChanged(ListChangeListener.Change c /** * Get all the local models in a directory. - * @param dir - * @return + * @param dir The model directory. + * @return A list of model paths. */ private static List getModelPathsInDir(Path dir) { - try { - return Files.list(dir) - .filter(InstanSegModel::isValidModel) + try (var paths = Files.list(dir)) { + return paths.filter(InstanSegModel::isValidModel) .toList(); } catch (IOException e) { logger.error("Unable to list files in directory", e); @@ -272,7 +271,7 @@ private static List getModelPathsInDir(Path dir) { * Any calling code needs to figure out of models are really the same or different, which requires some * decision-making (e.g. is a model that has been downloaded the same as the local representation...? * Or are two models the same if the directories are duplicated, so that one has a different path...?). - * @return + * @return A list of models. */ ObservableList getModels() { return modelsUnmodifiable; From 7eab7aa54f643e71608b3d58d9d2bdaaf930a299 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Wed, 29 Jan 2025 12:20:58 +0000 Subject: [PATCH 09/14] Update to spec 0.2.0-SNAPSHOT --- .../qupath/ext/instanseg/core/InstanSeg.java | 6 ++-- .../ext/instanseg/core/InstanSegModel.java | 28 +++++++++++-------- .../InstanSegOutputToObjectConverter.java | 11 ++++---- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 9cc1ac2..2cbba4d 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -9,7 +9,7 @@ import org.bytedeco.opencv.opencv_core.Mat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import qupath.bioimageio.spec.BioimageIoSpec; +import qupath.bioimageio.spec.tensor.OutputTensor; import qupath.lib.experimental.pixels.OpenCVProcessor; import qupath.lib.experimental.pixels.OutputHandler; import qupath.lib.experimental.pixels.Parameters; @@ -163,7 +163,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec } Path modelPath = oModelPath.get().resolve("instanseg.pt"); - Optional> oOutputTensors = this.model.getOutputs(); + Optional> oOutputTensors = this.model.getOutputs(); if (oOutputTensors.isEmpty()) { throw new IllegalArgumentException("No output tensors available even though model is available"); } @@ -324,7 +324,7 @@ private static Mat[] createOnes(Parameters parameters) { private static OutputHandler createOutputHandler(Class preferredOutputClass, boolean randomColors, int boundaryThreshold, - List outputTensors) { + List outputTensors) { // TODO: Reinstate this for Mat[] output (it was written for Mat output) // if (debugTiles()) // return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index 57e18b8..8f0fc9a 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -2,12 +2,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import qupath.bioimageio.spec.BioimageIoSpec; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.FileInputStream; import java.io.FileOutputStream; + +import qupath.bioimageio.spec.Model; +import qupath.bioimageio.spec.Resource; +import qupath.bioimageio.spec.tensor.OutputTensor; import qupath.lib.common.GeneralTools; import qupath.lib.images.servers.PixelCalibration; @@ -28,7 +31,8 @@ import java.util.zip.ZipInputStream; import java.util.Objects; -import static qupath.bioimageio.spec.BioimageIoSpec.getAxesString; +import static qupath.bioimageio.spec.tensor.axes.Axes.getAxesString; + public class InstanSegModel { @@ -42,10 +46,10 @@ public class InstanSegModel { public static final int ANY_CHANNELS = -1; private Path path = null; - private BioimageIoSpec.BioimageIoModel model = null; + private Model model = null; private final String name; - private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) { + private InstanSegModel(Model bioimageIoModel) { this.model = bioimageIoModel; this.path = Paths.get(model.getBaseURI()); this.version = model.getVersion(); @@ -65,7 +69,7 @@ private InstanSegModel(String name, String version, URL modelURL) { * @throws IOException If the directory can't be found or isn't a valid model directory. */ public static InstanSegModel fromPath(Path path) throws IOException { - return new InstanSegModel(BioimageIoSpec.parseModel(path)); + return new InstanSegModel(Model.parseModel(path)); } /** @@ -106,7 +110,7 @@ public void checkIfDownloaded(Path downloadedModelDir, boolean downloadIfNotVali downloadIfNotValid); this.path = unzipIfNeeded(zipFile); if (this.path != null) { - this.model = BioimageIoSpec.parseModel(path.toFile()); + this.model = Model.parseModel(path.toFile()); this.version = model.getVersion(); } } @@ -215,7 +219,7 @@ public Optional getPath() { public String toString() { String name = getName(); String parent = getPath().map(Path::getParent).map(Path::getFileName).map(Path::toString).orElse(null); - String version = getModel().map(BioimageIoSpec.BioimageIoModel::getVersion).orElse(this.version); + String version = getModel().map(Resource::getVersion).orElse(this.version); if (parent != null && !parent.equals(name)) { name = parent + "/" + name; } @@ -258,7 +262,7 @@ public Optional getNumChannels() { * Try to check the output tensors from the model spec. * @return The output tensors if the model is downloaded, otherwise empty. */ - public Optional> getOutputs() { + public Optional> getOutputs() { return getModel().flatMap(model -> Optional.ofNullable(model.getOutputs())); } @@ -299,7 +303,7 @@ public String toString() { } } - private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) { + private static int extractChannelNum(Model model) { String axes = getAxesString(model.getInputs().getFirst().getAxes()); int ind = axes.indexOf("c"); var shape = model.getInputs().getFirst().getShape(); @@ -314,7 +318,7 @@ private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) { * Retrieve the BioImage model spec. * @return The BioImageIO model spec for this InstanSeg model. */ - private Optional getModel() { + private Optional getModel() { return Optional.ofNullable(model); } @@ -341,7 +345,7 @@ private static boolean isDownloadedAlready(Path zipFile) { return false; } try { - BioimageIoSpec.parseModel(zipFile.toFile()); + Model.parseModel(zipFile.toFile()); } catch (IOException e) { logger.warn("Invalid zip file", e); return false; @@ -353,7 +357,7 @@ private Path unzipIfNeeded(Path zipFile) throws IOException { if (zipFile == null) { return null; } - var zipSpec = BioimageIoSpec.parseModel(zipFile); + var zipSpec = Model.parseModel(zipFile); String version = zipSpec.getVersion(); var outdir = zipFile.resolveSibling(getFolderName(zipSpec.getName(), version)); if (!isUnpackedAlready(outdir)) { diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index f5f05d2..e5fcde0 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -4,7 +4,8 @@ import org.locationtech.jts.geom.Geometry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import qupath.bioimageio.spec.BioimageIoSpec; +import qupath.bioimageio.spec.tensor.OutputTensor; +import qupath.bioimageio.spec.tensor.Tensors; import qupath.lib.analysis.images.ContourTracing; import qupath.lib.common.ColorTools; import qupath.lib.experimental.pixels.OutputHandler; @@ -42,9 +43,9 @@ class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectCo * This may be turned off or made optional in the future. */ private final boolean randomColors; - private final List outputTensors; + private final List outputTensors; - InstanSegOutputToObjectConverter(List outputTensors, + InstanSegOutputToObjectConverter(List outputTensors, Class preferredOutputType, boolean randomColors) { this.outputTensors = outputTensors; @@ -157,10 +158,10 @@ public List convertToObjects(Parameters params, Mat[] outp return pathObjects; } - private static void handleAuxOutput(PathObject pathObject, double[] values, BioimageIoSpec.OutputTensor outputTensor) { + private static void handleAuxOutput(PathObject pathObject, double[] values, OutputTensor outputTensor) { List outputClasses; var description = outputTensor.getDataDescription(); - if (description instanceof BioimageIoSpec.NominalOrOrdinalDataDescription dataDescription) { + if (description instanceof Tensors.NominalOrOrdinalDataDescription dataDescription) { outputClasses = dataDescription.getValues().stream().map(Object::toString).toList(); } else { outputClasses = new ArrayList<>(); From 10a34d590c714fa7a631852e20ef82e9bfc3c92e Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 7 Feb 2025 10:15:39 +0000 Subject: [PATCH 10/14] Fix model name --- .../java/qupath/ext/instanseg/core/InstanSegModel.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index 8f0fc9a..efa8ce6 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -69,7 +69,7 @@ private InstanSegModel(String name, String version, URL modelURL) { * @throws IOException If the directory can't be found or isn't a valid model directory. */ public static InstanSegModel fromPath(Path path) throws IOException { - return new InstanSegModel(Model.parseModel(path)); + return new InstanSegModel(Model.parse(path)); } /** @@ -110,7 +110,7 @@ public void checkIfDownloaded(Path downloadedModelDir, boolean downloadIfNotVali downloadIfNotValid); this.path = unzipIfNeeded(zipFile); if (this.path != null) { - this.model = Model.parseModel(path.toFile()); + this.model = Model.parse(path.toFile()); this.version = model.getVersion(); } } @@ -345,7 +345,7 @@ private static boolean isDownloadedAlready(Path zipFile) { return false; } try { - Model.parseModel(zipFile.toFile()); + Model.parse(zipFile.toFile()); } catch (IOException e) { logger.warn("Invalid zip file", e); return false; @@ -357,7 +357,7 @@ private Path unzipIfNeeded(Path zipFile) throws IOException { if (zipFile == null) { return null; } - var zipSpec = Model.parseModel(zipFile); + var zipSpec = Model.parse(zipFile); String version = zipSpec.getVersion(); var outdir = zipFile.resolveSibling(getFolderName(zipSpec.getName(), version)); if (!isUnpackedAlready(outdir)) { From 281578480e09e83c2c61048dc487c6e370ee1840 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 7 Feb 2025 15:39:27 +0000 Subject: [PATCH 11/14] Try to handle matched logits and class names --- .../ext/instanseg/core/InstanSegModel.java | 10 ++ .../InstanSegOutputToObjectConverter.java | 96 ++++++++++++++----- 2 files changed, 80 insertions(+), 26 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index efa8ce6..ec2244d 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -301,6 +301,16 @@ public enum OutputType { public String toString() { return type; } + + public static OutputType fromString(String value) { + for (OutputType t: values()) { + if (t.type.equalsIgnoreCase(value)) { + return t; + } + } + logger.error("Unknown output type {}", value); + return null; + } } private static int extractChannelNum(Model model) { diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index e5fcde0..1abaea4 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -22,13 +22,7 @@ import qupath.lib.roi.interfaces.ROI; import qupath.opencv.tools.OpenCVTools; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Random; +import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -104,6 +98,7 @@ public List convertToObjects(Parameters params, Mat[] outp (roiMaps.size() == 2 && preferredObjectClass == null); List pathObjects; + Map> outputNameToClasses = fetchOutputClasses(outputTensors); if (createCells) { Map parentROIs = roiMaps.get(0); Map childROIs = roiMaps.size() >= 2 ? roiMaps.get(1) : Collections.emptyMap(); @@ -113,10 +108,12 @@ public List convertToObjects(Parameters params, Mat[] outp var child = childROIs.getOrDefault(label, null); var cell = PathObjects.createCellObject(parent, child); for (int i = 1; i < output.length; i++) { + // todo: handle paired logits and class labels handleAuxOutput( cell, auxiliaryValues.get(i).getOrDefault(label, null), - outputTensors.get(i) + outputTensors.get(i), + outputNameToClasses.get(outputTensors.get(i).getName()) ); } return cell; @@ -140,10 +137,12 @@ public List convertToObjects(Parameters params, Mat[] outp } } for (int i = 1; i < output.length; i++) { + // todo: handle paired logits and class labels handleAuxOutput( pathObject, auxiliaryValues.get(i).getOrDefault(label, null), - outputTensors.get(i) + outputTensors.get(i), + outputNameToClasses.get(outputTensors.get(i).getName()) ); } pathObjects.add(pathObject); @@ -158,36 +157,86 @@ public List convertToObjects(Parameters params, Mat[] outp return pathObjects; } - private static void handleAuxOutput(PathObject pathObject, double[] values, OutputTensor outputTensor) { - List outputClasses; + private Map> fetchOutputClasses(List outputTensors) { + Map> out = new HashMap<>(); + // todo: loop through and check type + // if there's only one output, or if there's no pairing, then return nothing + if (outputTensors.size() == 1) { + return out; + } + + var classOutputs = outputTensors.stream().filter(ot -> ot.getName().startsWith("detection_classes")).toList(); + var logitOutputs = outputTensors.stream().filter(ot -> ot.getName().startsWith("detection_logits")).toList(); + + var classTypeToClassNames = classOutputs.stream() + .collect( + Collectors.toMap( + co -> co.getName(), + co -> getClassNames(co) + ) + ); + out.putAll(classTypeToClassNames); + // try to find logits that correspond to classes (eg, detection_classes_foo and detection_logits_foo) + var logitNameToClassName = logitOutputs.stream().collect(Collectors.toMap( + lo -> lo.getName(), + lo -> { + var matchingClassNames = classTypeToClassNames.entrySet().stream() + .filter(es -> { + return es.getKey().replace("detection_classes_", "").equals(lo.getName().replace("detection_logits_", "")); + }).toList(); + if (matchingClassNames.size() > 1) { + logger.warn("More than one matching class name for logits {}, choosing the first", lo.getName()); + } else if (matchingClassNames.size() == 0) { + // try to get default class names anyway... + return getClassNames(lo); + } + return matchingClassNames.getFirst().getValue(); + })); + out.putAll(logitNameToClassName); + return out; + } + + private List getClassNames(OutputTensor outputTensor) { var description = outputTensor.getDataDescription(); - if (description instanceof Tensors.NominalOrOrdinalDataDescription dataDescription) { + List outputClasses; + if (description != null && description instanceof Tensors.NominalOrOrdinalDataDescription dataDescription) { outputClasses = dataDescription.getValues().stream().map(Object::toString).toList(); } else { outputClasses = new ArrayList<>(); - int nClasses = outputTensor.getShape().getShape()[2]; // batch, object, class + int nClasses = outputTensor.getShape().getShape()[2]; // output axes are batch, index, class for (int i = 0; i < nClasses; i++) { outputClasses.add("Class" + i); } } + return outputClasses; + } + private static void handleAuxOutput(PathObject pathObject, double[] values, OutputTensor outputTensor, List outputClasses) { if (values == null) return; - var outputType = InstanSegModel.OutputType.valueOf(outputTensor.getName().toUpperCase()); + var outputType = InstanSegModel.OutputType.fromString(outputTensor.getName()); + if (outputType == null) { + return; + } switch(outputType) { case DETECTION_LOGITS -> { try (var ml = pathObject.getMeasurementList()) { - int maxInd = 0; - double maxVal = values[0]; +// int maxInd = 0; +// double maxVal = values[0]; for (int i = 0; i < values.length; i++) { double val = values[i]; - if (val > maxVal) { - maxVal = val; - maxInd = i; - } +// if (val > maxVal) { +// maxVal = val; +// maxInd = i; +// } ml.put("Logit " + outputClasses.get(i), val); } - pathObject.setPathClass(PathClass.fromString(outputClasses.get(maxInd))); +// pathObject.setPathClass(PathClass.fromString(outputClasses.get(maxInd))); + } + } + case DETECTION_CLASSES -> { + for (double val : values) { + pathObject.setPathClass(PathClass.fromString(outputClasses.get((int) val))); } } case DETECTION_EMBEDDINGS -> { @@ -198,11 +247,6 @@ private static void handleAuxOutput(PathObject pathObject, double[] values, Outp } } } - case DETECTION_CLASSES -> { - for (double val : values) { - pathObject.setPathClass(PathClass.fromString(outputClasses.get((int) val))); - } - } } if (outputType == InstanSegModel.OutputType.SEMANTIC_SEGMENTATION) { From 77d57a0e49194230ff805c01181fd8836188a8f8 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 7 Feb 2025 15:40:18 +0000 Subject: [PATCH 12/14] Type not class --- .../java/qupath/ext/instanseg/core/InstanSeg.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 22d5c39..a56e66a 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -77,7 +77,7 @@ private InstanSeg(Builder builder) { this.model = builder.model; this.device = builder.device; this.taskRunner = builder.taskRunner; - this.preferredOutputType = builder.preferredOutputClass; + this.preferredOutputType = builder.preferredOutputType; this.randomColors = builder.randomColors; this.makeMeasurements = builder.makeMeasurements; this.optionalArgs.putAll(builder.optionalArgs); @@ -326,14 +326,14 @@ private static Mat[] createOnes(Parameters parameters) { } - private static OutputHandler createOutputHandler(Class preferredOutputClass, + private static OutputHandler createOutputHandler(Class preferredOutputType, boolean randomColors, int boundaryThreshold, List outputTensors) { // TODO: Reinstate this for Mat[] output (it was written for Mat output) // if (debugTiles()) // return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); - var converter = new InstanSegOutputToObjectConverter(outputTensors, preferredOutputClass, randomColors); + var converter = new InstanSegOutputToObjectConverter(outputTensors, preferredOutputType, randomColors); if (boundaryThreshold >= 0) { return new PruneObjectOutputHandler<>(converter, boundaryThreshold); } else { @@ -413,7 +413,7 @@ public static final class Builder { private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(); private Collection channels; private InstanSegModel model; - private Class preferredOutputClass; + private Class preferredOutputType; private final Map optionalArgs = new LinkedHashMap<>(); Builder() {} @@ -654,7 +654,7 @@ public Builder device(Device device) { * @return this builder */ public Builder outputCells() { - this.preferredOutputClass = PathCellObject.class; + this.preferredOutputType = PathCellObject.class; return this; } @@ -663,7 +663,7 @@ public Builder outputCells() { * @return this builder */ public Builder outputDetections() { - this.preferredOutputClass = PathDetectionObject.class; + this.preferredOutputType = PathDetectionObject.class; return this; } @@ -672,7 +672,7 @@ public Builder outputDetections() { * @return this builder */ public Builder outputAnnotations() { - this.preferredOutputClass = PathAnnotationObject.class; + this.preferredOutputType = PathAnnotationObject.class; return this; } From 6ea5831f106eb26e002fafaa5b19b2035d175285 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 7 Feb 2025 20:04:59 +0000 Subject: [PATCH 13/14] Address comments --- build.gradle.kts | 1 - .../qupath/ext/instanseg/core/InstanSeg.java | 5 --- .../ext/instanseg/core/InstanSegModel.java | 38 +++++++------------ .../InstanSegOutputToObjectConverter.java | 30 +++++++-------- 4 files changed, 29 insertions(+), 45 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 0bb322d..887a0f7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -20,7 +20,6 @@ dependencies { implementation(libs.bioimageio.spec) implementation(libs.deepJavaLibrary) - implementation("io.github.qupath:qupath-extension-djl:0.4.0-SNAPSHOT") // For testing testImplementation(libs.junit) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index a56e66a..a52c80b 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -174,11 +174,6 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec } var outputTensors = oOutputTensors.get(); - List outputClasses = this.model.getClasses(); - if (outputClasses.isEmpty() && outputTensors.size() > 1) { - logger.warn("No output classes available, classes will be set as 'Class 1' etc."); - } - // Provide some way to change the number of predictors, even if this can't be specified through the UI // See https://forum.image.sc/t/instanseg-under-utilizing-cpu-only-2-3-cores/104496/7 int nPredictors = Integer.parseInt(System.getProperty("instanseg.numPredictors", "1")); diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index ec2244d..228d372 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -23,7 +23,6 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -267,25 +266,9 @@ public Optional> getOutputs() { } /** - * Try to check the output classes from the model spec. - * @return The output classes if the model is downloaded, and it's present, otherwise empty. + * Types of output tensors that may be supported by InstanSeg models. */ - public List getClasses() { - var config = model.getConfig().getOrDefault("qupath", null); - if (config instanceof Map configMap) { - List classes = new ArrayList<>(); - var el = configMap.get("classes"); - if (el != null && el instanceof List elList) { - for (var t: elList) { - classes.add(t.toString()); - } - } - return classes; - } - return List.of(); - } - - public enum OutputType { + public enum OutputTensorType { // "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation" INSTANCE_SEGMENTATION("instance_segmentation"), DETECTION_EMBEDDINGS("detection_embeddings"), @@ -294,22 +277,29 @@ public enum OutputType { SEMANTIC_SEGMENTATION("semantic_segmentation"); private final String type; - OutputType(String type) { + + OutputTensorType(String type) { this.type = type; } + @Override public String toString() { return type; } - public static OutputType fromString(String value) { - for (OutputType t: values()) { + /** + * Get the output type from a string, ignoring case + * @param value the input String to be matched against the possible values + * @return the matching output type, or empty if no match + */ + public static Optional fromString(String value) { + for (OutputTensorType t: values()) { if (t.type.equalsIgnoreCase(value)) { - return t; + return Optional.of(t); } } logger.error("Unknown output type {}", value); - return null; + return Optional.empty(); } } diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index 1abaea4..f50c5a3 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -22,7 +22,14 @@ import qupath.lib.roi.interfaces.ROI; import qupath.opencv.tools.OpenCVTools; -import java.util.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; import java.util.function.Function; import java.util.stream.Collectors; @@ -214,24 +221,18 @@ private List getClassNames(OutputTensor outputTensor) { private static void handleAuxOutput(PathObject pathObject, double[] values, OutputTensor outputTensor, List outputClasses) { if (values == null) return; - var outputType = InstanSegModel.OutputType.fromString(outputTensor.getName()); - if (outputType == null) { + var outputType = InstanSegModel.OutputTensorType.fromString(outputTensor.getName()); + if (outputType.isEmpty()) { return; } - switch(outputType) { + switch(outputType.get()) { case DETECTION_LOGITS -> { + // we could also assign classes here, but assume for now this is handled internally and supplied as binary output try (var ml = pathObject.getMeasurementList()) { -// int maxInd = 0; -// double maxVal = values[0]; for (int i = 0; i < values.length; i++) { double val = values[i]; -// if (val > maxVal) { -// maxVal = val; -// maxInd = i; -// } ml.put("Logit " + outputClasses.get(i), val); } -// pathObject.setPathClass(PathClass.fromString(outputClasses.get(maxInd))); } } case DETECTION_CLASSES -> { @@ -247,10 +248,9 @@ private static void handleAuxOutput(PathObject pathObject, double[] values, Outp } } } - } - - if (outputType == InstanSegModel.OutputType.SEMANTIC_SEGMENTATION) { - throw new UnsupportedOperationException("No idea what to do here!"); + case SEMANTIC_SEGMENTATION -> { + throw new UnsupportedOperationException("No idea what to do here!"); + } } } From 554185fbd105686e47373bd1c719f1921aa296d9 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Fri, 7 Feb 2025 20:13:29 +0000 Subject: [PATCH 14/14] Add DJL extension back in --- build.gradle.kts | 1 + 1 file changed, 1 insertion(+) diff --git a/build.gradle.kts b/build.gradle.kts index 887a0f7..0bb322d 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -20,6 +20,7 @@ dependencies { implementation(libs.bioimageio.spec) implementation(libs.deepJavaLibrary) + implementation("io.github.qupath:qupath-extension-djl:0.4.0-SNAPSHOT") // For testing testImplementation(libs.junit)