diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index 7b0945d..a52c80b 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -9,9 +9,11 @@ import org.bytedeco.opencv.opencv_core.Mat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import qupath.bioimageio.spec.tensor.OutputTensor; import qupath.lib.experimental.pixels.OpenCVProcessor; import qupath.lib.experimental.pixels.OutputHandler; import qupath.lib.experimental.pixels.Parameters; +import qupath.lib.experimental.pixels.PixelProcessor; import qupath.lib.experimental.pixels.Processor; import qupath.lib.images.ImageData; import qupath.lib.images.servers.ColorTransforms; @@ -58,7 +60,7 @@ public class InstanSeg { private final InstanSegModel model; private final Device device; private final TaskRunner taskRunner; - private final Class preferredOutputClass; + private final Class preferredOutputType; private final Map optionalArgs = new LinkedHashMap<>(); // This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently. @@ -75,7 +77,7 @@ private InstanSeg(Builder builder) { this.model = builder.model; this.device = builder.device; this.taskRunner = builder.taskRunner; - this.preferredOutputClass = builder.preferredOutputClass; + this.preferredOutputType = builder.preferredOutputType; this.randomColors = builder.randomColors; this.makeMeasurements = builder.makeMeasurements; this.optionalArgs.putAll(builder.optionalArgs); @@ -160,13 +162,18 @@ private void makeMeasurements(ImageData imageData, Collection imageData, Collection pathObjects) { long startTime = System.currentTimeMillis(); - Optional oModelPath = model.getPath(); if (oModelPath.isEmpty()) { return InstanSegResults.emptyInstance(); } Path modelPath = oModelPath.get().resolve("instanseg.pt"); + Optional> oOutputTensors = this.model.getOutputs(); + if (oOutputTensors.isEmpty()) { + throw new IllegalArgumentException("No output tensors available even though model is available"); + } + var outputTensors = oOutputTensors.get(); + // Provide some way to change the number of predictors, even if this can't be specified through the UI // See https://forum.image.sc/t/instanseg-under-utilizing-cpu-only-2-3-cores/104496/7 int nPredictors = Integer.parseInt(System.getProperty("instanseg.numPredictors", "1")); @@ -179,8 +186,6 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec logger.warn("Padding to input size is turned on - this is likely to be slower (but could help fix any issues)"); } String layout = "CHW"; - - // TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant? String layoutOutput = "CHW"; // Get the downsample - this may be specified by the user, or determined from the model spec @@ -202,6 +207,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec // Create an int[] representing a boolean array of channels to use boolean[] outputChannelArray = null; if (outputChannels != null && outputChannels.length > 0) { + //noinspection OptionalGetWithoutIsPresent outputChannelArray = new boolean[model.getOutputChannels().get()]; // safe to call get because of previous checks for (int c : outputChannels) { if (c < 0 || c >= outputChannelArray.length) { @@ -215,7 +221,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec var inputChannels = getInputChannels(imageData); try (var model = Criteria.builder() - .setTypes(Mat.class, Mat.class) + .setTypes(Mat.class, Mat[].class) .optModelUrls(String.valueOf(modelPath.toUri())) .optProgress(new ProgressBar()) .optDevice(device) // Remove this line if devices are problematic! @@ -227,7 +233,7 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec printResourceCount("Resource count before prediction", (BaseNDManager)baseManager.getParentManager()); baseManager.debugDump(2); - BlockingQueue> predictors = new ArrayBlockingQueue<>(nPredictors); + BlockingQueue> predictors = new ArrayBlockingQueue<>(nPredictors); try { for (int i = 0; i < nPredictors; i++) { @@ -239,10 +245,11 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec var tiler = createTiler(downsample, tileDims, padding); var predictionProcessor = createProcessor(predictors, inputChannels, tileDims, padToInputSize); - var outputHandler = createOutputHandler(preferredOutputClass, randomColors, boundaryThreshold); + var outputHandler = createOutputHandler(preferredOutputType, randomColors, boundaryThreshold, outputTensors); var postProcessor = createPostProcessor(); - - var processor = OpenCVProcessor.builder(predictionProcessor) + var processor = new PixelProcessor.Builder() + .processor(predictionProcessor) + .maskSupplier(OpenCVProcessor.createMatMaskSupplier()) .imageSupplier((parameters) -> ImageOps.buildImageDataOp(inputChannels) .apply(parameters.getImageData(), parameters.getRegionRequest())) .tiler(tiler) @@ -279,16 +286,17 @@ private InstanSegResults runInstanSeg(ImageData imageData, Collec } } + /** * Check if we are requesting tiles for debugging purposes. * When this is true, we should create objects that represent the tiles - not the objects to be detected. - * @return + * @return Whether the system debugging property is set. */ private static boolean debugTiles() { return System.getProperty("instanseg.debug.tiles", "false").strip().equalsIgnoreCase("true"); } - private static Processor createProcessor(BlockingQueue> predictors, + private static Processor createProcessor(BlockingQueue> predictors, Collection inputChannels, int tileDims, boolean padToInputSize) { if (debugTiles()) @@ -296,7 +304,7 @@ private static Processor createProcessor(BlockingQueue parameters) { + private static Mat[] createOnes(Parameters parameters) { var tileRequest = parameters.getTileRequest(); int width, height; if (tileRequest == null) { @@ -307,15 +315,20 @@ private static Mat createOnes(Parameters parameters) { width = tileRequest.getTileWidth(); height = tileRequest.getTileHeight(); } - return Mat.ones(height, width, opencv_core.CV_8UC1).asMat(); + try (var ones = Mat.ones(height, width, opencv_core.CV_8UC1)) { + return new Mat[]{ones.asMat()}; + } } - private static OutputHandler createOutputHandler(Class preferredOutputClass, - boolean randomColors, - int boundaryThreshold) { - if (debugTiles()) - return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); - var converter = new InstanSegOutputToObjectConverter(preferredOutputClass, randomColors); + + private static OutputHandler createOutputHandler(Class preferredOutputType, + boolean randomColors, + int boundaryThreshold, + List outputTensors) { + // TODO: Reinstate this for Mat[] output (it was written for Mat output) +// if (debugTiles()) +// return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter()); + var converter = new InstanSegOutputToObjectConverter(outputTensors, preferredOutputType, randomColors); if (boundaryThreshold >= 0) { return new PruneObjectOutputHandler<>(converter, boundaryThreshold); } else { @@ -334,8 +347,8 @@ private static Tiler createTiler(double downsample, int tileDims, int padding) { /** * Get the input channels to use; if we don't have any specified, use all of them - * @param imageData - * @return + * @param imageData The image data + * @return The possible input channels. */ private List getInputChannels(ImageData imageData) { if (inputChannels == null || inputChannels.isEmpty()) { @@ -364,8 +377,8 @@ private static ObjectProcessor createPostProcessor() { /** * Print resource count for debugging purposes. * If we are not logging at debug level, do nothing. - * @param title - * @param manager + * @param title The name to be used in the log. + * @param manager The NDManager to print from. */ private static void printResourceCount(String title, BaseNDManager manager) { if (logger.isDebugEnabled()) { @@ -395,7 +408,7 @@ public static final class Builder { private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(); private Collection channels; private InstanSegModel model; - private Class preferredOutputClass; + private Class preferredOutputType; private final Map optionalArgs = new LinkedHashMap<>(); Builder() {} @@ -555,7 +568,7 @@ public Builder randomColors() { /** * Optionally request that random colors be used for the output objects. - * @param doRandomColors + * @param doRandomColors Whether to use random colors for output object. * @return this builder */ public Builder randomColors(boolean doRandomColors) { @@ -636,7 +649,7 @@ public Builder device(Device device) { * @return this builder */ public Builder outputCells() { - this.preferredOutputClass = PathCellObject.class; + this.preferredOutputType = PathCellObject.class; return this; } @@ -645,7 +658,7 @@ public Builder outputCells() { * @return this builder */ public Builder outputDetections() { - this.preferredOutputClass = PathDetectionObject.class; + this.preferredOutputType = PathDetectionObject.class; return this; } @@ -654,7 +667,7 @@ public Builder outputDetections() { * @return this builder */ public Builder outputAnnotations() { - this.preferredOutputClass = PathAnnotationObject.class; + this.preferredOutputType = PathAnnotationObject.class; return this; } diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index d887661..228d372 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -2,12 +2,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import qupath.bioimageio.spec.BioimageIoSpec; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.FileInputStream; import java.io.FileOutputStream; + +import qupath.bioimageio.spec.Model; +import qupath.bioimageio.spec.Resource; +import qupath.bioimageio.spec.tensor.OutputTensor; import qupath.lib.common.GeneralTools; import qupath.lib.images.servers.PixelCalibration; @@ -27,6 +30,9 @@ import java.util.zip.ZipInputStream; import java.util.Objects; +import static qupath.bioimageio.spec.tensor.axes.Axes.getAxesString; + + public class InstanSegModel { private static final Logger logger = LoggerFactory.getLogger(InstanSegModel.class); @@ -39,10 +45,10 @@ public class InstanSegModel { public static final int ANY_CHANNELS = -1; private Path path = null; - private BioimageIoSpec.BioimageIoModel model = null; + private Model model = null; private final String name; - private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) { + private InstanSegModel(Model bioimageIoModel) { this.model = bioimageIoModel; this.path = Paths.get(model.getBaseURI()); this.version = model.getVersion(); @@ -62,7 +68,7 @@ private InstanSegModel(String name, String version, URL modelURL) { * @throws IOException If the directory can't be found or isn't a valid model directory. */ public static InstanSegModel fromPath(Path path) throws IOException { - return new InstanSegModel(BioimageIoSpec.parseModel(path)); + return new InstanSegModel(Model.parse(path)); } /** @@ -103,7 +109,7 @@ public void checkIfDownloaded(Path downloadedModelDir, boolean downloadIfNotVali downloadIfNotValid); this.path = unzipIfNeeded(zipFile); if (this.path != null) { - this.model = BioimageIoSpec.parseModel(path.toFile()); + this.model = Model.parse(path.toFile()); this.version = model.getVersion(); } } @@ -212,7 +218,7 @@ public Optional getPath() { public String toString() { String name = getName(); String parent = getPath().map(Path::getParent).map(Path::getFileName).map(Path::toString).orElse(null); - String version = getModel().map(BioimageIoSpec.BioimageIoModel::getVersion).orElse(this.version); + String version = getModel().map(Resource::getVersion).orElse(this.version); if (parent != null && !parent.equals(name)) { name = parent + "/" + name; } @@ -251,8 +257,55 @@ public Optional getNumChannels() { return getModel().flatMap(model -> Optional.of(extractChannelNum(model))); } - private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) { - int ind = model.getInputs().getFirst().getAxes().toLowerCase().indexOf("c"); + /** + * Try to check the output tensors from the model spec. + * @return The output tensors if the model is downloaded, otherwise empty. + */ + public Optional> getOutputs() { + return getModel().flatMap(model -> Optional.ofNullable(model.getOutputs())); + } + + /** + * Types of output tensors that may be supported by InstanSeg models. + */ + public enum OutputTensorType { + // "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation" + INSTANCE_SEGMENTATION("instance_segmentation"), + DETECTION_EMBEDDINGS("detection_embeddings"), + DETECTION_LOGITS("detection_logits"), + DETECTION_CLASSES("detection_classes"), + SEMANTIC_SEGMENTATION("semantic_segmentation"); + + private final String type; + + OutputTensorType(String type) { + this.type = type; + } + + @Override + public String toString() { + return type; + } + + /** + * Get the output type from a string, ignoring case + * @param value the input String to be matched against the possible values + * @return the matching output type, or empty if no match + */ + public static Optional fromString(String value) { + for (OutputTensorType t: values()) { + if (t.type.equalsIgnoreCase(value)) { + return Optional.of(t); + } + } + logger.error("Unknown output type {}", value); + return Optional.empty(); + } + } + + private static int extractChannelNum(Model model) { + String axes = getAxesString(model.getInputs().getFirst().getAxes()); + int ind = axes.indexOf("c"); var shape = model.getInputs().getFirst().getShape(); if (shape.getShapeStep()[ind] == 1) { return ANY_CHANNELS; @@ -265,7 +318,7 @@ private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) { * Retrieve the BioImage model spec. * @return The BioImageIO model spec for this InstanSeg model. */ - private Optional getModel() { + private Optional getModel() { return Optional.ofNullable(model); } @@ -292,7 +345,7 @@ private static boolean isDownloadedAlready(Path zipFile) { return false; } try { - BioimageIoSpec.parseModel(zipFile.toFile()); + Model.parse(zipFile.toFile()); } catch (IOException e) { logger.warn("Invalid zip file", e); return false; @@ -304,7 +357,7 @@ private Path unzipIfNeeded(Path zipFile) throws IOException { if (zipFile == null) { return null; } - var zipSpec = BioimageIoSpec.parseModel(zipFile); + var zipSpec = Model.parse(zipFile); String version = zipSpec.getVersion(); var outdir = zipFile.resolveSibling(getFolderName(zipSpec.getName(), version)); if (!isUnpackedAlready(outdir)) { @@ -399,7 +452,7 @@ private Optional> getPixelSize() { public Optional getOutputChannels() { return getModel().map(model -> { var output = model.getOutputs().getFirst(); - String axes = output.getAxes().toLowerCase(); + String axes = getAxesString(output.getAxes()); int ind = axes.indexOf("c"); var shape = output.getShape().getShape(); if (shape != null && shape.length > ind) diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java index a9956e0..f50c5a3 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java @@ -4,6 +4,8 @@ import org.locationtech.jts.geom.Geometry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import qupath.bioimageio.spec.tensor.OutputTensor; +import qupath.bioimageio.spec.tensor.Tensors; import qupath.lib.analysis.images.ContourTracing; import qupath.lib.common.ColorTools; import qupath.lib.experimental.pixels.OutputHandler; @@ -14,13 +16,16 @@ import qupath.lib.objects.PathObject; import qupath.lib.objects.PathObjects; import qupath.lib.objects.PathTileObject; +import qupath.lib.objects.classes.PathClass; import qupath.lib.regions.ImagePlane; import qupath.lib.roi.GeometryTools; import qupath.lib.roi.interfaces.ROI; import qupath.opencv.tools.OpenCVTools; + import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -28,7 +33,7 @@ import java.util.function.Function; import java.util.stream.Collectors; -class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter { +class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter { private static final Logger logger = LoggerFactory.getLogger(InstanSegOutputToObjectConverter.class); @@ -39,25 +44,30 @@ class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectCo * This may be turned off or made optional in the future. */ private final boolean randomColors; + private final List outputTensors; - InstanSegOutputToObjectConverter(Class preferredObjectClass, boolean randomColors) { - this.preferredObjectClass = preferredObjectClass; + InstanSegOutputToObjectConverter(List outputTensors, + Class preferredOutputType, + boolean randomColors) { + this.outputTensors = outputTensors; + this.preferredObjectClass = preferredOutputType; this.randomColors = randomColors; } @Override - public List convertToObjects(Parameters params, Mat output) { + public List convertToObjects(Parameters params, Mat[] output) { if (output == null) { return List.of(); } - int nChannels = output.channels(); + var matLabels = output[0]; + int nChannels = matLabels.channels(); if (nChannels < 1 || nChannels > 2) throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels); List> roiMaps = new ArrayList<>(); ImagePlane plane = params.getRegionRequest().getImagePlane(); - for (var mat : OpenCVTools.splitChannels(output)) { + for (var mat : OpenCVTools.splitChannels(matLabels)) { var image = OpenCVTools.matToSimpleImage(mat, 0); var geoms = ContourTracing.createGeometries(image, params.getRegionRequest(), 1, -1); roiMaps.add(geoms.entrySet().stream() @@ -68,6 +78,23 @@ public List convertToObjects(Parameters params, Mat output ); } + // "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation" + // If we have two outputs, the second may give classifications - arrange by row + List> auxiliaryValues = new ArrayList<>(); + auxiliaryValues.add(new HashMap<>()); + if (output.length > 1) { + Map auxVals = new HashMap<>(); + for (int i = 1; i < output.length; i++) { + var matClass = output[i]; + int nRows = matClass.rows(); + for (int r = 0; r < nRows; r++) { + double[] doubles = OpenCVTools.extractDoubles(matClass.row(r)); + auxVals.put(r+1, doubles); + } + auxiliaryValues.add(auxVals); + } + } + // We reverse the order because the smaller output (e.g. nucleus) comes before the larger out (e.g. cell) // and we want to iterate in the opposite order. If this changes (or becomes inconsistent) we may need to // sum pixels or areas. @@ -78,13 +105,25 @@ public List convertToObjects(Parameters params, Mat output (roiMaps.size() == 2 && preferredObjectClass == null); List pathObjects; + Map> outputNameToClasses = fetchOutputClasses(outputTensors); if (createCells) { Map parentROIs = roiMaps.get(0); Map childROIs = roiMaps.size() >= 2 ? roiMaps.get(1) : Collections.emptyMap(); pathObjects = parentROIs.entrySet().stream().map(entry -> { var parent = entry.getValue(); - var child = childROIs.getOrDefault(entry.getKey(), null); - return PathObjects.createCellObject(parent, child); + var label = entry.getKey(); + var child = childROIs.getOrDefault(label, null); + var cell = PathObjects.createCellObject(parent, child); + for (int i = 1; i < output.length; i++) { + // todo: handle paired logits and class labels + handleAuxOutput( + cell, + auxiliaryValues.get(i).getOrDefault(label, null), + outputTensors.get(i), + outputNameToClasses.get(outputTensors.get(i).getName()) + ); + } + return cell; }).toList(); } else { Function createObjectFun = createObjectFun(preferredObjectClass); @@ -92,17 +131,27 @@ public List convertToObjects(Parameters params, Mat output Map parentMap = roiMaps.getFirst(); List> childMaps = roiMaps.size() == 1 ? Collections.emptyList() : roiMaps.subList(1, roiMaps.size()); for (var entry : parentMap.entrySet()) { + var label = entry.getKey(); var roi = entry.getValue(); var pathObject = createObjectFun.apply(roi); if (roiMaps.size() > 1) { for (var subMap : childMaps) { - var childROI = subMap.get(entry.getKey()); + var childROI = subMap.get(label); if (childROI != null) { var childObject = createObjectFun.apply(childROI); pathObject.addChildObject(childObject); } } } + for (int i = 1; i < output.length; i++) { + // todo: handle paired logits and class labels + handleAuxOutput( + pathObject, + auxiliaryValues.get(i).getOrDefault(label, null), + outputTensors.get(i), + outputNameToClasses.get(outputTensors.get(i).getName()) + ); + } pathObjects.add(pathObject); } } @@ -115,18 +164,108 @@ public List convertToObjects(Parameters params, Mat output return pathObjects; } + private Map> fetchOutputClasses(List outputTensors) { + Map> out = new HashMap<>(); + // todo: loop through and check type + // if there's only one output, or if there's no pairing, then return nothing + if (outputTensors.size() == 1) { + return out; + } + + var classOutputs = outputTensors.stream().filter(ot -> ot.getName().startsWith("detection_classes")).toList(); + var logitOutputs = outputTensors.stream().filter(ot -> ot.getName().startsWith("detection_logits")).toList(); + + var classTypeToClassNames = classOutputs.stream() + .collect( + Collectors.toMap( + co -> co.getName(), + co -> getClassNames(co) + ) + ); + out.putAll(classTypeToClassNames); + // try to find logits that correspond to classes (eg, detection_classes_foo and detection_logits_foo) + var logitNameToClassName = logitOutputs.stream().collect(Collectors.toMap( + lo -> lo.getName(), + lo -> { + var matchingClassNames = classTypeToClassNames.entrySet().stream() + .filter(es -> { + return es.getKey().replace("detection_classes_", "").equals(lo.getName().replace("detection_logits_", "")); + }).toList(); + if (matchingClassNames.size() > 1) { + logger.warn("More than one matching class name for logits {}, choosing the first", lo.getName()); + } else if (matchingClassNames.size() == 0) { + // try to get default class names anyway... + return getClassNames(lo); + } + return matchingClassNames.getFirst().getValue(); + })); + out.putAll(logitNameToClassName); + return out; + } + + private List getClassNames(OutputTensor outputTensor) { + var description = outputTensor.getDataDescription(); + List outputClasses; + if (description != null && description instanceof Tensors.NominalOrOrdinalDataDescription dataDescription) { + outputClasses = dataDescription.getValues().stream().map(Object::toString).toList(); + } else { + outputClasses = new ArrayList<>(); + int nClasses = outputTensor.getShape().getShape()[2]; // output axes are batch, index, class + for (int i = 0; i < nClasses; i++) { + outputClasses.add("Class" + i); + } + } + return outputClasses; + } + + private static void handleAuxOutput(PathObject pathObject, double[] values, OutputTensor outputTensor, List outputClasses) { + if (values == null) + return; + var outputType = InstanSegModel.OutputTensorType.fromString(outputTensor.getName()); + if (outputType.isEmpty()) { + return; + } + switch(outputType.get()) { + case DETECTION_LOGITS -> { + // we could also assign classes here, but assume for now this is handled internally and supplied as binary output + try (var ml = pathObject.getMeasurementList()) { + for (int i = 0; i < values.length; i++) { + double val = values[i]; + ml.put("Logit " + outputClasses.get(i), val); + } + } + } + case DETECTION_CLASSES -> { + for (double val : values) { + pathObject.setPathClass(PathClass.fromString(outputClasses.get((int) val))); + } + } + case DETECTION_EMBEDDINGS -> { + try (var ml = pathObject.getMeasurementList()) { + for (int i = 0; i < values.length; i++) { + double val = values[i]; + ml.put("Embedding " + i, val); + } + } + } + case SEMANTIC_SEGMENTATION -> { + throw new UnsupportedOperationException("No idea what to do here!"); + } + } + } + + /** * Assign a random color to a PathObject and all descendants, returning the object. - * @param pathObject - * @param rng - * @return + * + * @param pathObject The PathObject + * @param rng A random number generator. */ - private static PathObject assignRandomColor(PathObject pathObject, Random rng) { + private static void assignRandomColor(PathObject pathObject, Random rng) { pathObject.setColor(randomRGB(rng)); for (var child : pathObject.getChildObjects()) { assignRandomColor(child, rng); } - return pathObject; } private static ROI geometryToFilledROI(Geometry geom, ImagePlane plane) { @@ -150,8 +289,8 @@ else if (Objects.equals(PathTileObject.class, preferredObjectClass)) /** * Create annotations that are locked by default, to reduce the risk of editing them accidentally. - * @param roi - * @return + * @param roi The region of interest + * @return A locked annotation object. */ private static PathObject createLockedAnnotation(ROI roi) { var annotation = PathObjects.createAnnotationObject(roi); diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java b/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java index 72fd4ca..9d1eb9a 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegResults.java @@ -21,7 +21,7 @@ public record InstanSegResults( /** * Get an empty instance of InstanSegResults. - * @return + * @return An empty instance with default values. */ public static InstanSegResults emptyInstance() { return EMPTY; diff --git a/src/main/java/qupath/ext/instanseg/core/MatTranslator.java b/src/main/java/qupath/ext/instanseg/core/MatTranslator.java index 6b8f1cf..50880c8 100644 --- a/src/main/java/qupath/ext/instanseg/core/MatTranslator.java +++ b/src/main/java/qupath/ext/instanseg/core/MatTranslator.java @@ -4,20 +4,20 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import org.bytedeco.opencv.opencv_core.Mat; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import qupath.ext.djl.DjlTools; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import org.bytedeco.opencv.opencv_core.Mat; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import qupath.ext.djl.DjlTools; -class MatTranslator implements Translator { +class MatTranslator implements Translator { private static final Logger logger = LoggerFactory.getLogger(MatTranslator.class); private final String inputLayoutNd; @@ -104,9 +104,15 @@ private static List sanitizeOptionalArgs(Map optionalArgs, N } @Override - public Mat processOutput(TranslatorContext ctx, NDList list) { + public Mat[] processOutput(TranslatorContext ctx, NDList list) { var array = list.getFirst(); - return DjlTools.ndArrayToMat(array, outputLayoutNd); + var labels = DjlTools.ndArrayToMat(array, outputLayoutNd); + var output = new Mat[list.size()]; + output[0] = labels; + for (int i = 1; i < list.size(); i++) { + output[i] = DjlTools.ndArrayToMat(list.get(i), "HW"); + } + return output; } } diff --git a/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java index fc6bdea..fa23712 100644 --- a/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java +++ b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java @@ -41,7 +41,7 @@ public boolean handleOutput(Parameters params, U output) { List newObjects = converter.convertToObjects(params, output); if (newObjects == null) return false; - // If using a proxy object (eg tile), + // If using a proxy object (e.g. tile), // we want to remove things touching the tile boundary, // then add the objects to the proxy rather than the parent var parentOrProxy = params.getParentOrProxy(); diff --git a/src/main/java/qupath/ext/instanseg/core/PytorchManager.java b/src/main/java/qupath/ext/instanseg/core/PytorchManager.java index 53b5eb4..f6b4380 100644 --- a/src/main/java/qupath/ext/instanseg/core/PytorchManager.java +++ b/src/main/java/qupath/ext/instanseg/core/PytorchManager.java @@ -65,7 +65,7 @@ public static Collection getAvailableDevices() { /** * Query if the PyTorch engine is already available, without a need to download. - * @return + * @return whether PyTorch is available. */ public static boolean hasPyTorchEngine() { return getEngineOffline() != null; @@ -86,10 +86,10 @@ static Engine getEngineOffline() { /** * Call a function with the "offline" property set to true (to block automatic downloads). - * @param callable - * @return - * @param - * @throws Exception + * @param callable The function to be called. + * @return The return value of the callable. + * @param The return type of the callable. + * @throws Exception If the callable does. */ private static T callOffline(Callable callable) throws Exception { return callWithTempProperty("ai.djl.offline", "true", callable); @@ -97,10 +97,10 @@ private static T callOffline(Callable callable) throws Exception { /** * Call a function with the "offline" property set to false (to allow automatic downloads). - * @param callable - * @return - * @param - * @throws Exception + * @param callable The function to be called. + * @return The return value of the callable. + * @param The return type of the callable. + * @throws Exception If the callable does. */ private static T callOnline(Callable callable) throws Exception { return callWithTempProperty("ai.djl.offline", "false", callable); diff --git a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java index 53af730..305015f 100644 --- a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java +++ b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java @@ -32,11 +32,11 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -class TilePredictionProcessor implements Processor { +class TilePredictionProcessor implements Processor { private static final Logger logger = LoggerFactory.getLogger(TilePredictionProcessor.class); - private final BlockingQueue> predictors; + private final BlockingQueue> predictors; private final int inputWidth; private final int inputHeight; @@ -59,7 +59,7 @@ class TilePredictionProcessor implements Processor { */ private final Map normalization = Collections.synchronizedMap(new WeakHashMap<>()); - TilePredictionProcessor(BlockingQueue> predictors, + TilePredictionProcessor(BlockingQueue> predictors, Collection channels, int inputWidth, int inputHeight, boolean doPadding) { this.predictors = predictors; @@ -91,7 +91,7 @@ public int getTilesFailedCount() { * The number of channels does not influence the result. *

* One use of this is to help assess the impact of padding on the processing time. - * @return the pixels that were processed + * @return the pixels that were processed. */ public long getPixelsProcessedCount() { return nPixelsProcessed.get(); @@ -101,14 +101,14 @@ public long getPixelsProcessedCount() { * Check if the processing was interrupted. * This can be used to determine if the processing was stopped prematurely, * and failed tiles were not necessarily errors. - * @return + * @return whether processing was interrupted. */ public boolean wasInterrupted() { return wasInterrupted.get(); } @Override - public Mat process(Parameters params) throws IOException { + public Mat[] process(Parameters params) throws IOException { var mat = params.getImage(); @@ -136,7 +136,7 @@ public Mat process(Parameters params) throws IOException { mat = mat2; } - Predictor predictor = null; + Predictor predictor = null; try { predictor = predictors.take(); logger.debug("Predicting tile {}", mat); @@ -148,9 +148,12 @@ public Mat process(Parameters params) throws IOException { OpenCVTools.matToImagePlus("Output " + params.getRegionRequest(), matOutput).show(); } - matOutput.convertTo(matOutput, opencv_core.CV_32S); + // Handle the first output (labels) + // There may be other outputs (classiications, features), but we don't handle those here + matOutput[0].convertTo(matOutput[0], opencv_core.CV_32S); if (padding != null) - matOutput = OpenCVTools.crop(matOutput, padding); + matOutput[0] = OpenCVTools.crop(matOutput[0], padding); + return matOutput; } catch (TranslateException e) { nTilesFailed.incrementAndGet(); @@ -189,7 +192,13 @@ public Mat process(Parameters params) throws IOException { * @return Percentile-based normalisation based on the bounding box, * or default tile-based percentile normalisation if that fails. */ - private static ImageOp getNormalization(ImageData imageData, ROI roi, Collection channels, double lowPerc, double highPerc) { + private static ImageOp getNormalization( + ImageData imageData, + ROI roi, + Collection channels, + double lowPerc, + double highPerc) { + var defaults = ImageOps.Normalize.percentile(lowPerc, highPerc, true, 1e-6); try { BufferedImage image; diff --git a/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java b/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java index fb0fdbf..e848d31 100644 --- a/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java +++ b/src/main/java/qupath/ext/instanseg/ui/CheckModelCache.java @@ -5,9 +5,7 @@ import javafx.beans.property.SimpleObjectProperty; import javafx.beans.value.ObservableValue; import org.controlsfx.control.CheckComboBox; -import org.controlsfx.control.CheckModel; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.WeakHashMap; @@ -22,7 +20,7 @@ public class CheckModelCache { private final ObjectProperty value = new SimpleObjectProperty<>(); private final CheckComboBox checkbox; - private final Map> lastChecks = new WeakHashMap<>(); + private final Map> lastChecks = new WeakHashMap<>(); private CheckModelCache(ObservableValue value, CheckComboBox checkbox) { this.value.bind(value); @@ -36,11 +34,11 @@ private CheckModelCache(ObservableValue value, CheckComboBox checkbox) { *

* This can then be used to restore the checks later, if needed. * - * @param value - * @param checkBox - * @return - * @param - * @param + * @param value the observable value that this cache corresponds to (e.g., a specific model's name). + * @param checkBox the CheckComboBox containing the possibly checked items. + * @return A {@link CheckModelCache} + * @param The observable value type. + * @param The type of checked item. */ public static CheckModelCache create(ObservableValue value, CheckComboBox checkBox) { return new CheckModelCache<>(value, checkBox); @@ -48,13 +46,13 @@ public static CheckModelCache create(ObservableValue value, Chec private void handleValueChange(ObservableValue observable, S oldValue, S newValue) { if (oldValue != null) { - lastChecks.put(oldValue, List.copyOf(checkbox.getCheckModel().getCheckedItems())); + lastChecks.put(oldValue, List.copyOf(checkbox.getCheckModel().getCheckedIndices())); } } /** * Get the value property. - * @return + * @return The value property. */ public ReadOnlyObjectProperty valueProperty() { return value; @@ -72,11 +70,11 @@ public ReadOnlyObjectProperty valueProperty() { * @return true if the checks were restored, false otherwise */ public boolean restoreChecks() { - List checks = lastChecks.get(value.get()); - if (checks != null && new HashSet<>(checkbox.getItems()).containsAll(checks)) { + List checks = lastChecks.get(value.get()); + if (checks != null && checks.stream().allMatch(i -> i < checkbox.getItems().size())) { var checkModel = checkbox.getCheckModel(); checkModel.clearChecks(); - checks.forEach(checkModel::check); + checks.forEach(checkModel::checkIndices); return true; } else { return false; @@ -103,12 +101,12 @@ public boolean resetChecks() { * Create a snapshot of the checks currently associated with the observable value. * This is useful in case some other checkbox manipulation is required without switching the value, * and we want to restore the checks later (e.g. changing the items). - * @return + * @return whether the snapshot was possible. */ public boolean snapshotChecks() { var val = value.get(); if (val != null) { - lastChecks.put(val, List.copyOf(checkbox.getCheckModel().getCheckedItems())); + lastChecks.put(val, List.copyOf(checkbox.getCheckModel().getCheckedIndices())); return true; } else { return false; diff --git a/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java b/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java index 4a477a6..3376024 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/InputChannelItem.java @@ -63,8 +63,8 @@ String getConstructor() { /** * Get a list of available channels for the given image data. - * @param imageData - * @return + * @param imageData The image data. + * @return A list of channels, including color deconvolution channels if present. */ static List getAvailableChannels(ImageData imageData) { List list = new ArrayList<>(); diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 4db4e1d..b769a70 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -143,7 +143,7 @@ public class InstanSegController extends BorderPane { private List remoteModels; - private BooleanProperty requestingPyTorch = new SimpleBooleanProperty(false); + private final BooleanProperty requestingPyTorch = new SimpleBooleanProperty(false); // Listener for property changes in the current ImageData; these can be required to update the input channels private final PropertyChangeListener imageDataPropertyChangeListener = this::handleImageDataPropertyChange; @@ -249,9 +249,7 @@ private void configureInfoButton() { infoButton.disableProperty().bind(selectedModelIsAvailable.not()); WebView webView = WebViews.create(true); PopOver infoPopover = new PopOver(webView); - infoButton.setOnAction(e -> { - parseMarkdown(selectedModel.get(), webView, infoButton, infoPopover); - }); + infoButton.setOnAction(e -> parseMarkdown(selectedModel.get(), webView, infoButton, infoPopover)); } private void configureDownloadButton() { @@ -382,7 +380,6 @@ private void updateInputChannels(ImageData imageData) { comboInputChannels.getCheckModel().checkIndices(0, 1, 2); } } - } } @@ -514,7 +511,7 @@ private void refreshModelChoice() { /** * Try to download the currently-selected model in another thread. - * @return + * @return A Future that is completed when the download finishes. */ private CompletableFuture downloadSelectedModelAsync() { var model = selectedModel.get(); @@ -526,8 +523,8 @@ private CompletableFuture downloadSelectedModelAsync() { /** * Try to download the specified model in another thread. - * @param model - * @return + * @param model The model + * @return A future that is completed when the download finishes. */ private CompletableFuture downloadModelAsync(InstanSegModel model) { var modelDir = InstanSegUtils.getModelDirectory().orElse(null); @@ -541,9 +538,9 @@ private CompletableFuture downloadModelAsync(InstanSegModel mode /** * Try to download the specified model to the given directory in the current thread. - * @param model - * @param modelDir - * @return + * @param model The model. + * @param modelDir The model directory. + * @return The downloaded model. */ private InstanSegModel downloadModel(InstanSegModel model, Path modelDir) { Objects.requireNonNull(modelDir); @@ -577,10 +574,8 @@ private static void parseMarkdown(InstanSegModel model, WebView webView, Button // If the markdown doesn't start with a title, pre-pending the model title & description (if available) if (!body.startsWith("#")) { - var sb = new StringBuilder(); - sb.append("## ").append(model.getName()).append("\n\n"); - sb.append("----\n\n"); - doc.prependChild(parser.parse(sb.toString())); + String sb = "## " + model.getName() + "\n\n----\n\n"; + doc.prependChild(parser.parse(sb)); } webView.getEngine().loadContent(HtmlRenderer.builder().build().render(doc)); infoPopover.show(infoButton); @@ -626,6 +621,7 @@ private static List getRemoteModels() { } } InputStream in = InstanSegController.class.getResourceAsStream("model-index.json"); + assert in != null; String cont = new BufferedReader(new InputStreamReader(in)) .lines() .collect(Collectors.joining("\n")); diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java index 493dd20..47760a0 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegPreferences.java @@ -3,7 +3,6 @@ import javafx.beans.property.BooleanProperty; import javafx.beans.property.IntegerProperty; import javafx.beans.property.ObjectProperty; -import javafx.beans.property.Property; import javafx.beans.property.StringProperty; import qupath.lib.common.GeneralTools; import qupath.lib.gui.prefs.PathPrefs; @@ -56,7 +55,7 @@ enum OnlinePermission { /** * MPS should work reliably (and much faster) on Apple Silicon, so set as default. * Everywhere else, use CPU as we can't count on a GPU/CUDA being available. - * @return + * @return The default device string. */ private static String getDefaultDevice() { if (GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) { diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java index 8132bd2..0d4df5b 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegTask.java @@ -14,7 +14,6 @@ import qupath.lib.plugins.workflow.DefaultScriptableWorkflowStep; import java.awt.image.BufferedImage; -import java.io.File; import java.nio.file.Path; import java.util.Arrays; import java.util.List; @@ -62,6 +61,7 @@ protected Void call() { return null; } // TODO: HANDLE OUTPUT CHANNELS! + // todo: Unclear what this means int nOutputs = model.getOutputChannels().orElse(1); int[] outputChannels = new int[0]; if (nOutputs <= 0) { diff --git a/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java b/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java index bb253f4..f6233a4 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java +++ b/src/main/java/qupath/ext/instanseg/ui/ModelListCell.java @@ -12,10 +12,10 @@ public class ModelListCell extends ListCell { - private ResourceBundle resources = InstanSegResources.getResources(); + private final ResourceBundle resources = InstanSegResources.getResources(); - private Glyph web = createOnlineIcon(); - private Tooltip tooltip; + private final Glyph web = createOnlineIcon(); + private final Tooltip tooltip; public ModelListCell() { super(); diff --git a/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java b/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java index 2fab682..c1bfaf3 100644 --- a/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/OutputChannelItem.java @@ -40,7 +40,7 @@ static List getOutputsForChannelCount(int nChannels) { /** * Get the index of the output channel. - * @return + * @return The index of the output channel */ public int getIndex() { return index; diff --git a/src/main/java/qupath/ext/instanseg/ui/Watcher.java b/src/main/java/qupath/ext/instanseg/ui/Watcher.java index e2502da..b7d9448 100644 --- a/src/main/java/qupath/ext/instanseg/ui/Watcher.java +++ b/src/main/java/qupath/ext/instanseg/ui/Watcher.java @@ -64,7 +64,7 @@ private Watcher() { /** * Get a singleton instance of the Watcher. - * @return + * @return The watcher. */ static Watcher getInstance() { return instance; @@ -176,7 +176,7 @@ private void processEvents() { /** * Add the model, after checking it won't be a duplicate. - * @param model + * @param model The model. */ private void ensureModelInList(InstanSegModel model) { for (var existingModel : models) { @@ -232,7 +232,7 @@ private void refreshAllModelPaths() { /** * Update all models in response to a change in the model paths. - * @param change + * @param change The change event. */ private void handleModelPathsChanged(ListChangeListener.Change change) { if (!Platform.isFxApplicationThread()) @@ -250,13 +250,12 @@ private void handleModelPathsChanged(ListChangeListener.Change c /** * Get all the local models in a directory. - * @param dir - * @return + * @param dir The model directory. + * @return A list of model paths. */ private static List getModelPathsInDir(Path dir) { - try { - return Files.list(dir) - .filter(InstanSegModel::isValidModel) + try (var paths = Files.list(dir)) { + return paths.filter(InstanSegModel::isValidModel) .toList(); } catch (IOException e) { logger.error("Unable to list files in directory", e); @@ -272,7 +271,7 @@ private static List getModelPathsInDir(Path dir) { * Any calling code needs to figure out of models are really the same or different, which requires some * decision-making (e.g. is a model that has been downloaded the same as the local representation...? * Or are two models the same if the directories are duplicated, so that one has a different path...?). - * @return + * @return A list of models. */ ObservableList getModels() { return modelsUnmodifiable;