diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java new file mode 100644 index 0000000..9a8ce2d --- /dev/null +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -0,0 +1,409 @@ +package qupath.ext.instanseg.core; + +import ai.djl.Device; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import qupath.lib.images.ImageData; +import qupath.lib.images.servers.ColorTransforms; +import qupath.lib.objects.PathAnnotationObject; +import qupath.lib.objects.PathCellObject; +import qupath.lib.objects.PathDetectionObject; +import qupath.lib.objects.PathObject; +import qupath.lib.plugins.TaskRunner; +import qupath.lib.plugins.TaskRunnerUtils; + +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.IntStream; + +public class InstanSeg { + private static final Logger logger = LoggerFactory.getLogger(InstanSeg.class); + + private final int tileDims; + private final double downsample; + private final int padding; + private final int boundary; + private final int numOutputChannels; + private final ImageData imageData; + private final Collection channels; + private final InstanSegModel model; + private final Device device; + private final TaskRunner taskRunner; + private final List> outputClasses; + + + /** + * Create a builder object for InstanSeg. + * @return A builder, which may not be valid. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Run inference for the currently selected PathObjects. + */ + public void detectObjects() { + detectObjects(imageData.getHierarchy().getSelectionModel().getSelectedObjects()); + } + + /** + * Run inference for a collection of PathObjects. 
+ */ + public void detectObjects(Collection pathObjects) { + model.runInstanSeg( + imageData, + pathObjects, + channels, + tileDims, + downsample, + padding, + boundary, + device, + numOutputChannels == 1, + outputClasses, + taskRunner + ); + } + + + /** + * A builder class for InstanSeg. + */ + public static final class Builder { + private int tileDims = 512; + private double downsample = 1; + private int padding = 40; + private int boundary = 20; + private int numOutputChannels = 2; + private Device device = Device.fromName("cpu"); + private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(); + private ImageData imageData; + private Collection channels; + private InstanSegModel model; + private List> outputClasses; + + Builder() {} + + /** + * Set the width and height of tiles + * @param tileDims The tile width and height + * @return A modified builder + */ + public Builder tileDims(int tileDims) { + this.tileDims = tileDims; + return this; + } + + /** + * Set the downsample to be used in region requests + * @param downsample The downsample to be used + * @return A modified builder + */ + public Builder downsample(double downsample) { + this.downsample = downsample; + return this; + } + + /** + * Set the padding (overlap) between tiles + * @param padding The extra size added to tiles to allow overlap + * @return A modified builder + */ + public Builder interTilePadding(int padding) { + this.padding = padding; + return this; + } + + /** + * Set the size of the overlap region between tiles + * @param boundary The width in pixels that overlaps between tiles + * @return A modified builder + */ + public Builder tileBoundary(int boundary) { + this.boundary = boundary; + return this; + } + + /** + * Set the number of output channels + * @param numOutputChannels The number of output channels (1 or 2 currently) + * @return A modified builder + */ + public Builder numOutputChannels(int numOutputChannels) { + this.numOutputChannels = 
numOutputChannels; + return this; + } + + /** + * Set the imageData to be used + * @param imageData An imageData instance + * @return A modified builder + */ + public Builder imageData(ImageData imageData) { + this.imageData = imageData; + return this; + } + + /** + * Set the channels to be used in inference + * @param channels A collection of channels to be used in inference + * @return A modified builder + */ + public Builder channels(Collection channels) { + this.channels = channels; + return this; + } + + /** + * Set the channels to be used in inference + * @param channels Channels to be used in inference + * @return A modified builder + */ + public Builder channels(ColorTransforms.ColorTransform channel, ColorTransforms.ColorTransform... channels) { + var l = new ArrayList<>(Arrays.asList(channels)); + l.add(channel); + this.channels = l; + return this; + } + + /** + * Set the model to use all channels for inference + * @return A modified builder + */ + public Builder allChannels() { + // assignment is just to suppress IDE suggestion for void return val + var tmp = channelIndices( + IntStream.range(0, imageData.getServer().nChannels()) + .boxed() + .toList()); + return this; + } + + /** + * Set the channels using indices + * @param channels Integers used to specify the channels used + * @return A modified builder + */ + public Builder channelIndices(Collection channels) { + this.channels = channels.stream() + .map(ColorTransforms::createChannelExtractor) + .toList(); + return this; + } + + /** + * Set the channels using indices + * @param channels Integers used to specify the channels used + * @return A modified builder + */ + public Builder channelIndices(int channel, int... 
channels) { + List l = new ArrayList<>(); + l.add(ColorTransforms.createChannelExtractor(channel)); + for (int i: channels) { + l.add(ColorTransforms.createChannelExtractor(i)); + } + this.channels = l; + return this; + } + + /** + * Set the channel names to be used + * @param channels A set of channel names + * @return A modified builder + */ + public Builder channelNames(Collection channels) { + this.channels = channels.stream() + .map(ColorTransforms::createChannelExtractor) + .toList(); + return this; + } + + /** + * Set the channel names to be used + * @param channels A set of channel names + * @return A modified builder + */ + public Builder channelNames(String channel, String... channels) { + List l = new ArrayList<>(); + l.add(ColorTransforms.createChannelExtractor(channel)); + for (String s: channels) { + l.add(ColorTransforms.createChannelExtractor(s)); + } + this.channels = l; + return this; + } + + /** + * Set the number of threads used + * @param nThreads The number of threads to be used + * @return A modified builder + */ + public Builder nThreads(int nThreads) { + this.taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(nThreads); + return this; + } + + /** + * Set the TaskRunner + * @param taskRunner An object that will run tasks and show progress + * @return A modified builder + */ + public Builder taskRunner(TaskRunner taskRunner) { + this.taskRunner = taskRunner; + return this; + } + + /** + * Set the specific model to be used + * @param model An already instantiated InstanSeg model. + * @return A modified builder + */ + public Builder model(InstanSegModel model) { + this.model = model; + return this; + } + + /** + * Set the specific model by path + * @param path A path on disk to create an InstanSeg model from. 
+ * @return A modified builder + */ + public Builder modelPath(Path path) throws IOException { + return model(InstanSegModel.fromPath(path)); + } + + /** + * Set the specific model by path + * @param path A path on disk to create an InstanSeg model from. + * @return A modified builder + */ + public Builder modelPath(String path) throws IOException { + return modelPath(Path.of(path)); + } + + /** + * Set the specific model to be used + * @param name The name of a built-in model + * @return A modified builder + */ + public Builder modelName(String name) { + return model(InstanSegModel.fromName(name)); + } + + /** + * Set the device to be used + * @param deviceName The name of the device to be used (eg, "gpu", "mps"). + * @return A modified builder + */ + public Builder device(String deviceName) { + this.device = Device.fromName(deviceName); + return this; + } + + /** + * Set the device to be used + * @param device The {@link Device} to be used + * @return A modified builder + */ + public Builder device(Device device) { + this.device = device; + return this; + } + + /** + * Specify the output class(es) + * @param outputClasses A list specifying what type the output should be. + * eg, [PathDetectionObject.class, PathAnnotationObject.class] + * specifies to create detections nested inside annotations. 
+ * @return A modified builder + */ + public Builder outputClasses(List> outputClasses) { + this.outputClasses = outputClasses; + return this; + } + + /** + * Specify cells as the output class, possibly without nuclei + * @return A modified builder + */ + public Builder outputCells() { + this.outputClasses = List.of(PathCellObject.class); + return this; + } + + /** + * Specify (possibly nested) detections as the output class + * @return A modified builder + */ + public Builder outputDetections() { + this.outputClasses = List.of(PathDetectionObject.class); + return this; + } + + /** + * Specify (possibly nested) annotations as the output class + * @return A modified builder + */ + public Builder outputAnnotations() { + this.outputClasses = List.of(PathAnnotationObject.class); + return this; + } + + /** + * Build the InstanSeg instance. + * @return An InstanSeg instance ready for object detection. + */ + public InstanSeg build() { + if (imageData == null) { + throw new IllegalStateException("imageData cannot be null!"); + } + if (channels == null) { + // assignment is just to suppress IDE suggestion for void return + var tmp = allChannels(); + } + if (outputClasses == null) { + var tmp = outputCells(); + } + if (outputClasses.size() > 1 && numOutputChannels == 1) { + throw new IllegalArgumentException("Cannot have multiple output types when using only one output channel."); + } + return new InstanSeg( + this.tileDims, + this.downsample, + this.padding, + this.boundary, + this.numOutputChannels, + this.imageData, + this.channels, + this.model, + this.device, + this.taskRunner, + this.outputClasses); + } + + } + + + + private InstanSeg(int tileDims, double downsample, int padding, int boundary, int numOutputChannels, ImageData imageData, + Collection channels, InstanSegModel model, Device device, TaskRunner taskRunner, + List> outputClasses) { + this.tileDims = tileDims; + this.downsample = downsample; + this.padding = padding; + this.boundary = boundary; + 
this.numOutputChannels = numOutputChannels; + this.imageData = imageData; + this.channels = channels; + this.model = model; + this.device = device; + this.taskRunner = taskRunner; + this.outputClasses = outputClasses; + } +} diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java index e96adc2..613f67a 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSegModel.java @@ -1,11 +1,9 @@ package qupath.ext.instanseg.core; import ai.djl.Device; -import ai.djl.MalformedModelException; import ai.djl.inference.Predictor; import ai.djl.ndarray.BaseNDManager; import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.training.util.ProgressBar; import com.google.gson.internal.LinkedTreeMap; import org.bytedeco.opencv.opencv_core.Mat; @@ -31,6 +29,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -50,15 +49,32 @@ private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) { this.name = model.getName(); } - public InstanSegModel(URL modelURL, String name) { + + private InstanSegModel(URL modelURL, String name) { this.modelURL = modelURL; this.name = name; } - public static InstanSegModel createModel(Path path) throws IOException { + /** + * Create an InstanSeg model from an existing path. + * @param path The path to the folder that contains the model .pt file and the config YAML file. + * @return A handle on the model that can be used for inference. + * @throws IOException If the directory can't be found or isn't a valid model directory. 
+ */ + public static InstanSegModel fromPath(Path path) throws IOException { return new InstanSegModel(BioimageIoSpec.parseModel(path.toFile())); } + /** + * Request an InstanSeg model from the set of available models + * @param name The model name + * @return The specified model. + */ + public static InstanSegModel fromName(String name) { + // todo: instantiate built-in models somehow + throw new UnsupportedOperationException("Fetching models by name is not yet implemented!"); + } + public BioimageIoSpec.BioimageIoModel getModel() { if (model == null) { try { @@ -128,26 +144,24 @@ public String toString() { return getName(); } - public void runInstanSeg( - Collection pathObjects, + void runInstanSeg( ImageData imageData, + Collection pathObjects, Collection channels, - int tileSize, + int tileDims, double downsample, - String deviceName, + int padding, + int boundary, + Device device, boolean nucleiOnly, - TaskRunner taskRunner) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException { + List> outputClasses, + TaskRunner taskRunner) { nFailed = 0; Path modelPath = getPath().resolve("instanseg.pt"); int nPredictors = 1; // todo: change me? - int padding = 40; // todo: setting? or just based on tile size. Should discuss. - int boundary = 20; - if (tileSize == 128) { - padding = 25; - boundary = 15; - } + // Optionally pad images to the required size boolean padToInputSize = true; String layout = "CHW"; @@ -155,7 +169,6 @@ public void runInstanSeg( // TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant? 
String layoutOutput = "CHW"; - var device = Device.fromName(deviceName); try (var model = Criteria.builder() .setTypes(Mat.class, Mat.class) @@ -166,6 +179,7 @@ public void runInstanSeg( .build() .loadModel()) { + BaseNDManager baseManager = (BaseNDManager)model.getNDManager(); printResourceCount("Resource count before prediction", (BaseNDManager)baseManager.getParentManager()); @@ -180,9 +194,9 @@ public void runInstanSeg( printResourceCount("Resource count after creating predictors", (BaseNDManager)baseManager.getParentManager()); - int sizeWithoutPadding = (int) Math.ceil(downsample * (tileSize - (double) padding)); + int sizeWithoutPadding = (int) Math.ceil(downsample * (tileDims - (double) padding)); var predictionProcessor = new TilePredictionProcessor(predictors, baseManager, - layout, layoutOutput, channels, tileSize, tileSize, padToInputSize); + layout, layoutOutput, channels, tileDims, tileDims, padToInputSize); var processor = OpenCVProcessor.builder(predictionProcessor) .imageSupplier((parameters) -> ImageOps.buildImageDataOp(channels).apply(parameters.getImageData(), parameters.getRegionRequest())) .tiler(Tiler.builder(sizeWithoutPadding) @@ -190,7 +204,7 @@ public void runInstanSeg( .cropTiles(false) .build() ) - .outputHandler(new OutputToObjectConverter.PruneObjectOutputHandler<>(new OutputToObjectConverter(), boundary)) + .outputHandler(new PruneObjectOutputHandler<>(new InstansegOutputToObjectConverter(outputClasses), boundary)) .padding(padding) .merger(ObjectMerger.createIoUMerger(0.2)) .downsample(downsample) diff --git a/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java new file mode 100644 index 0000000..85b4271 --- /dev/null +++ b/src/main/java/qupath/ext/instanseg/core/InstansegOutputToObjectConverter.java @@ -0,0 +1,144 @@ +package qupath.ext.instanseg.core; + +import org.bytedeco.opencv.opencv_core.Mat; +import org.slf4j.Logger; 
+import org.slf4j.LoggerFactory; +import qupath.lib.analysis.images.ContourTracing; +import qupath.lib.common.ColorTools; +import qupath.lib.experimental.pixels.OutputHandler; +import qupath.lib.experimental.pixels.Parameters; +import qupath.lib.objects.PathAnnotationObject; +import qupath.lib.objects.PathCellObject; +import qupath.lib.objects.PathDetectionObject; +import qupath.lib.objects.PathObject; +import qupath.lib.objects.PathObjects; +import qupath.lib.roi.interfaces.ROI; +import qupath.opencv.tools.OpenCVTools; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +class InstansegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter { + private static final Logger logger = LoggerFactory.getLogger(InstansegOutputToObjectConverter.class); + + private static final long seed = 1243; + private final List> classes; + + InstansegOutputToObjectConverter(List> outputClasses) { + this.classes = outputClasses; + } + + @Override + public List convertToObjects(Parameters params, Mat output) { + if (output == null) { + return List.of(); + } + int nChannels = output.channels(); + if (nChannels < 1 || nChannels > 2) + throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels); + + + List> roiMaps = new ArrayList<>(); + for (var mat : OpenCVTools.splitChannels(output)) { + var image = OpenCVTools.matToSimpleImage(mat, 0); + roiMaps.add( + ContourTracing.createROIs(image, params.getRegionRequest(), 1, -1) + ); + } + var rng = new Random(seed); + + BiFunction function; + if (classes.size() == 1) { + function = getOneClassBiFunction(classes, rng); + } else { + // if of length 2, then can be: + // detection <- annotation, annotation <- annotation, detection <- detection + assert classes.size() == 2; + function = getTwoClassBiFunction(classes, rng); + } + + if 
(roiMaps.size() == 1) { + // One-channel detected, represent using detection objects + return roiMaps.get(0).values().stream() + .map(roi -> function.apply(roi, null)) + .collect(Collectors.toList()); + } else { + // Two channels detected, represent using cell objects + // We assume that the labels are matched - and we can't have a nucleus without a cell + Map childROIs = roiMaps.get(0); + Map parentROIs = roiMaps.get(1); + List cells = new ArrayList<>(); + for (var entry : parentROIs.entrySet()) { + var parent = entry.getValue(); + var child = childROIs.getOrDefault(entry.getKey(), null); + var outputObject = function.apply(parent, child); + cells.add(outputObject); + } + return cells; + } + } + + private static BiFunction getOneClassBiFunction(List> classes, Random rng) { + // if of length 1, can be + // cellObject (with or without nucleus), annotations, detections + if (classes.get(0) == PathAnnotationObject.class) { + return createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng); + } else if (classes.get(0) == PathDetectionObject.class) { + return createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng); + } else if (classes.get(0) == PathCellObject.class) { + return createCellFun(rng); + } else { + logger.warn("Unknown output {}, defaulting to cells", classes.get(0)); + return createCellFun(rng); + } + } + + private static BiFunction getTwoClassBiFunction(List> classes, Random rng) { + Function fun0, fun1; + var knownClasses = List.of(PathDetectionObject.class, PathAnnotationObject.class); + if (!knownClasses.contains(classes.get(0)) || !knownClasses.contains(classes.get(1))) { + logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1)); + return createCellFun(rng); + } + if (classes.get(0) == PathDetectionObject.class) { + fun0 = PathObjects::createDetectionObject; + } else { + fun0 = PathObjects::createAnnotationObject; + } + if 
(classes.get(1) == PathDetectionObject.class) { + fun1 = PathObjects::createDetectionObject; + } else { + fun1 = PathObjects::createAnnotationObject; + } + return createObjectsFun(fun0, fun1, rng); + } + + private static BiFunction createCellFun(Random rng) { + return (parent, child) -> { + var cell = PathObjects.createCellObject(parent, child); + var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); + cell.setColor(color); + return cell; + }; + } + + private static BiFunction createObjectsFun(Function createParentFun, Function createChildFun, Random rng) { + return (parent, child) -> { + var parentObj = createParentFun.apply(parent); + var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255)); + parentObj.setColor(color); + if (child != null) { + var childObj = createChildFun.apply(child); + childObj.setColor(color); + parentObj.addChildObject(childObj); + } + return parentObj; + }; + } + +} diff --git a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java b/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java deleted file mode 100644 index 0127717..0000000 --- a/src/main/java/qupath/ext/instanseg/core/OutputToObjectConverter.java +++ /dev/null @@ -1,214 +0,0 @@ -package qupath.ext.instanseg.core; - -import org.bytedeco.opencv.opencv_core.Mat; -import org.locationtech.jts.geom.Envelope; -import qupath.lib.analysis.images.ContourTracing; -import qupath.lib.experimental.pixels.OutputHandler; -import qupath.lib.experimental.pixels.Parameters; -import qupath.lib.experimental.pixels.PixelProcessorUtils; -import qupath.lib.objects.PathObject; -import qupath.lib.objects.PathObjects; -import qupath.lib.roi.GeometryTools; -import qupath.lib.roi.interfaces.ROI; -import qupath.opencv.tools.OpenCVTools; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.stream.Collectors; - -class OutputToObjectConverter implements 
OutputHandler.OutputToObjectConverter { - - private static final long seed = 1243; - - @Override - public List convertToObjects(Parameters params, Mat output) { - if (output == null) { - return List.of(); - } - int nChannels = output.channels(); - if (nChannels < 1 || nChannels > 2) - throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels); - - List> roiMaps = new ArrayList<>(); - for (var mat : OpenCVTools.splitChannels(output)) { - var image = OpenCVTools.matToSimpleImage(mat, 0); - roiMaps.add( - ContourTracing.createROIs(image, params.getRegionRequest(), 1, -1) - ); - } - var rng = new Random(seed); - if (roiMaps.size() == 1) { - // One-channel detected, represent using detection objects - return roiMaps.get(0).values().stream() - .map(p -> { - var obj = PathObjects.createDetectionObject(p); - obj.setColor( - rng.nextInt(255), - rng.nextInt(255), - rng.nextInt(255) - ); - return obj; - }) - .collect(Collectors.toList()); - } else { - // Two channels detected, represent using cell objects - // We assume that the labels are matched - and we can't have a nucleus without a cell - Map nucleusROIs = roiMaps.get(0); - Map cellROIs = roiMaps.get(1); - List cells = new ArrayList<>(); - for (var entry : cellROIs.entrySet()) { - var cell = entry.getValue(); - var nucleus = nucleusROIs.getOrDefault(entry.getKey(), null); - var cellObject = PathObjects.createCellObject(cell, nucleus); - cellObject.setColor( - rng.nextInt(255), - rng.nextInt(255), - rng.nextInt(255) - ); - cells.add(cellObject); - } - return cells; - } - } - - static class PruneObjectOutputHandler implements OutputHandler { - - private final OutputToObjectConverter converter; - private final int boundaryThreshold; - - PruneObjectOutputHandler(OutputToObjectConverter converter, int boundaryThreshold) { - this.converter = converter; - this.boundaryThreshold = boundaryThreshold; - } - - @Override - public boolean handleOutput(Parameters params, U output) { - if (output == null) - 
return false; - else { - List newObjects = converter.convertToObjects(params, output); - if (newObjects == null) - return false; - // If using a proxy object (eg tile), - // we want to remove things touching the tile boundary, - // then add the objects to the proxy rather than the parent - var parentOrProxy = params.getParentOrProxy(); - parentOrProxy.clearChildObjects(); - - // remove features within N pixels of the region request boundaries - var bounds = GeometryTools.createRectangle( - params.getRegionRequest().getX(), params.getRegionRequest().getY(), - params.getRegionRequest().getWidth(), params.getRegionRequest().getHeight()); - - int width = params.getServer().getWidth(); - int height = params.getServer().getHeight(); - - newObjects = newObjects.parallelStream() - .filter(p -> doesntTouchBoundaries(p.getROI().getGeometry().getEnvelopeInternal(), bounds.getEnvelopeInternal(), boundaryThreshold, width, height)) - .toList(); - - if (!newObjects.isEmpty()) { - // since we're using IoU to merge objects, we want to keep anything that is within the overall object bounding box - var parent = params.getParent().getROI(); - newObjects = newObjects.parallelStream() - .flatMap(p -> PixelProcessorUtils.maskObject(parent, p).stream()) - .toList(); - } - parentOrProxy.addChildObjects(newObjects); - parentOrProxy.setLocked(true); - return true; - } - } - - - /** - * Tests if a detection is near the boundary of a parent region. - * It first checks if the detection is on the edge of the overall image, in which case it should be kept, - * unless it is at the edge of the image and the perpendicular edge of the parent region. - * For example, on the left side of the image, but on the top/bottom edge of the parent region. - * Then, it checks if the detection is on the boundary of the parent region. - * @param det The detection object. - * @param region The region containing all detection objects. 
- * @param boundaryPixels The size of the boundary, in pixels, to use for removing object. - * @param imageWidth The width of the image, in pixels. - * @param imageHeight The height of the image, in pixels. - * @return Whether the detection object should be removed, based on these criteria. - */ - private boolean doesntTouchBoundaries(Envelope det, Envelope region, int boundaryPixels, int imageWidth, int imageHeight) { - // keep any objects at the boundary of the annotation, except the stuff around region boundaries - if (touchesLeftOfImage(det, boundaryPixels)) { - if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { - return true; - } - if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, boundaryPixels))) { - return true; - } - } - if (touchesTopOfImage(det, boundaryPixels)) { - if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { - return true; - } - if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { - return true; - } - } - - if (touchesRightOfImage(det, imageWidth, boundaryPixels)) { - if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { - return true; - } - if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, boundaryPixels))) { - return true; - } - } - if (touchesBottomOfImage(det, imageHeight, boundaryPixels)) { - if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { - return true; - } - if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { - return true; - } - } - - // remove any objects at other region boundaries - return !(touchesLeftOfRegion(det, region, boundaryPixels) - || touchesRightOfRegion(det, region, boundaryPixels) - || touchesBottomOfRegion(det, 
region, boundaryPixels) - || touchesTopOfRegion(det, region, boundaryPixels)); - } - } - - private static boolean touchesLeftOfImage(Envelope det, int boundary) { - return det.getMinX() < boundary; - } - - private static boolean touchesRightOfImage(Envelope det, int width, int boundary) { - return width - det.getMaxX() < boundary; - } - - private static boolean touchesTopOfImage(Envelope det, int boundary) { - return det.getMinY() < boundary; - } - - private static boolean touchesBottomOfImage(Envelope det, int height, int boundary) { - return height - det.getMaxY() < boundary; - } - - private static boolean touchesLeftOfRegion(Envelope det, Envelope region, int boundary) { - return det.getMinX() - region.getMinX() < boundary; - } - - private static boolean touchesRightOfRegion(Envelope det, Envelope region, int boundary) { - return region.getMaxX() - det.getMaxX() < boundary; - } - - private static boolean touchesTopOfRegion(Envelope det, Envelope region, int boundary) { - return det.getMinY() - region.getMinY() < boundary; - } - - private static boolean touchesBottomOfRegion(Envelope det, Envelope region, int boundary) { - return region.getMaxY() - det.getMaxY() < boundary; - } -} diff --git a/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java new file mode 100644 index 0000000..493d54a --- /dev/null +++ b/src/main/java/qupath/ext/instanseg/core/PruneObjectOutputHandler.java @@ -0,0 +1,164 @@ +package qupath.ext.instanseg.core; + +import org.locationtech.jts.geom.Envelope; +import qupath.lib.experimental.pixels.OutputHandler; +import qupath.lib.experimental.pixels.Parameters; +import qupath.lib.experimental.pixels.PixelProcessorUtils; +import qupath.lib.objects.PathObject; +import qupath.lib.roi.GeometryTools; + +import java.util.List; + +class PruneObjectOutputHandler implements OutputHandler { + + private final OutputToObjectConverter converter; + private final int 
boundaryThreshold; + + /** + * An output handler that prunes the output, removing any objects that are + * within a certain distance (in pixels) to the tile boundaries, leaving + * all objects on the border of the image. + *

+ * Relies on having a relatively large overlap between tiles. + *

+ * Useful if you want to use for example IoU to merge objects between tiles, + * where the general QuPath approach of merging objects with shared + * boundaries won't work. + * @param converter An output to object converter. + * @param boundaryThreshold The size of the boundary, in pixels, to use for removing objects. + * See {@link #doesntTouchBoundaries} for more details. + */ + PruneObjectOutputHandler(OutputToObjectConverter converter, int boundaryThreshold) { + this.converter = converter; + this.boundaryThreshold = boundaryThreshold; + } + + @Override + public boolean handleOutput(Parameters params, U output) { + if (output == null) + return false; + else { + List newObjects = converter.convertToObjects(params, output); + if (newObjects == null) + return false; + // If using a proxy object (eg tile), + // we want to remove things touching the tile boundary, + // then add the objects to the proxy rather than the parent + var parentOrProxy = params.getParentOrProxy(); + parentOrProxy.clearChildObjects(); + + // remove features within N pixels of the region request boundaries + var bounds = GeometryTools.createRectangle( + params.getRegionRequest().getX(), params.getRegionRequest().getY(), + params.getRegionRequest().getWidth(), params.getRegionRequest().getHeight()); + + int width = params.getServer().getWidth(); + int height = params.getServer().getHeight(); + + newObjects = newObjects.parallelStream() + .filter(p -> doesntTouchBoundaries(p.getROI().getGeometry().getEnvelopeInternal(), bounds.getEnvelopeInternal(), boundaryThreshold, width, height)) + .toList(); + + if (!newObjects.isEmpty()) { + // since we're using IoU to merge objects, we want to keep anything that is within the overall object bounding box + var parent = params.getParent().getROI(); + newObjects = newObjects.parallelStream() + .flatMap(p -> PixelProcessorUtils.maskObject(parent, p).stream()) + .toList(); + } + parentOrProxy.addChildObjects(newObjects); + parentOrProxy.setLocked(true); + 
return true; + } + } + + + /** + * Tests if a detection is near the boundary of a parent region. + * It first checks if the detection is on the edge of the overall image, in which case it should be kept, + * unless it is at the edge of the image and the perpendicular edge of the parent region. + * For example, on the left side of the image, but on the top/bottom edge of the parent region. + * Then, it checks if the detection is on the boundary of the parent region. + * + * @param det The detection object. + * @param region The region containing all detection objects. + * @param boundaryPixels The size of the boundary, in pixels, to use for removing objects. + * @param imageWidth The width of the image, in pixels. + * @param imageHeight The height of the image, in pixels. + * @return Whether the detection object should be removed, based on these criteria. + */ + private boolean doesntTouchBoundaries(Envelope det, Envelope region, int boundaryPixels, int imageWidth, int imageHeight) { + // keep any objects at the boundary of the annotation, except the stuff around region boundaries + if (touchesLeftOfImage(det, boundaryPixels)) { + if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { + return true; + } + if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, boundaryPixels))) { + return true; + } + } + if (touchesTopOfImage(det, boundaryPixels)) { + if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { + return true; + } + if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { + return true; + } + } + + if (touchesRightOfImage(det, imageWidth, boundaryPixels)) { + if (touchesTopOfImage(det, boundaryPixels) || touchesBottomOfImage(det, imageHeight, boundaryPixels)) { + return true; + } + if (!(touchesBottomOfRegion(det, region, boundaryPixels) || touchesTopOfRegion(det, region, 
boundaryPixels))) { + return true; + } + } + if (touchesBottomOfImage(det, imageHeight, boundaryPixels)) { + if (touchesLeftOfImage(det, boundaryPixels) || touchesRightOfImage(det, imageWidth, boundaryPixels)) { + return true; + } + if (!(touchesLeftOfRegion(det, region, boundaryPixels) || touchesRightOfRegion(det, region, boundaryPixels))) { + return true; + } + } + + // remove any objects at other region boundaries + return !(touchesLeftOfRegion(det, region, boundaryPixels) + || touchesRightOfRegion(det, region, boundaryPixels) + || touchesBottomOfRegion(det, region, boundaryPixels) + || touchesTopOfRegion(det, region, boundaryPixels)); + } + + private static boolean touchesLeftOfImage(Envelope det, int boundary) { + return det.getMinX() < boundary; + } + + private static boolean touchesRightOfImage(Envelope det, int width, int boundary) { + return width - det.getMaxX() < boundary; + } + + private static boolean touchesTopOfImage(Envelope det, int boundary) { + return det.getMinY() < boundary; + } + + private static boolean touchesBottomOfImage(Envelope det, int height, int boundary) { + return height - det.getMaxY() < boundary; + } + + private static boolean touchesLeftOfRegion(Envelope det, Envelope region, int boundary) { + return det.getMinX() - region.getMinX() < boundary; + } + + private static boolean touchesRightOfRegion(Envelope det, Envelope region, int boundary) { + return region.getMaxX() - det.getMaxX() < boundary; + } + + private static boolean touchesTopOfRegion(Envelope det, Envelope region, int boundary) { + return det.getMinY() - region.getMinY() < boundary; + } + + private static boolean touchesBottomOfRegion(Envelope det, Envelope region, int boundary) { + return region.getMaxY() - det.getMaxY() < boundary; + } +} diff --git a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java index b5f2dcf..eaf7bf5 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java 
+++ b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java @@ -1,7 +1,11 @@ package qupath.ext.instanseg.ui; +import qupath.lib.color.ColorDeconvolutionStains; import qupath.lib.images.servers.ColorTransforms; +import java.util.Collection; +import java.util.stream.Collectors; + /** * Super simple class to deal with channel selection dropdown items that have different display and selection names. * e.g., the first channel in non-RGB images is shown as "Channel 1 (C1)" but the actual name is "Channel 1". @@ -9,14 +13,24 @@ class ChannelSelectItem { private final String name; private final ColorTransforms.ColorTransform transform; + private final String constructor; + ChannelSelectItem(String name) { this.name = name; this.transform = ColorTransforms.createChannelExtractor(name); + this.constructor = String.format("ColorTransforms.createChannelExtractor(\"%s\")", name); } - ChannelSelectItem(String name, ColorTransforms.ColorTransform transform) { + ChannelSelectItem(String name, int i) { this.name = name; - this.transform = transform; + this.transform = ColorTransforms.createChannelExtractor(i); + this.constructor = String.format("ColorTransforms.createChannelExtractor(%d)", i); + } + + ChannelSelectItem(ColorDeconvolutionStains stains, int i) { + this.name = stains.getStain(i).getName(); + this.transform = ColorTransforms.createColorDeconvolvedChannel(stains, i); + this.constructor = String.format("ColorTransforms.createColorDeconvolvedChannel(stains, %d)", i); } @Override @@ -24,11 +38,19 @@ public String toString() { return this.name; } - public String getName() { + String getName() { return name; } - public ColorTransforms.ColorTransform getTransform() { + ColorTransforms.ColorTransform getTransform() { return transform; } + + String getConstructor() { + return this.constructor; + } + + static String toConstructorString(Collection items) { + return "[" + items.stream().map(ChannelSelectItem::getConstructor).collect(Collectors.joining(", ")) + "]"; + } } 
diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index 11e4a35..dde7626 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -1,12 +1,11 @@ package qupath.ext.instanseg.ui; -import ai.djl.MalformedModelException; -import ai.djl.repository.zoo.ModelNotFoundException; import javafx.application.Platform; import javafx.beans.binding.Bindings; import javafx.beans.property.ObjectProperty; import javafx.beans.property.SimpleObjectProperty; import javafx.beans.property.StringProperty; +import javafx.beans.value.ObservableValue; import javafx.collections.FXCollections; import javafx.collections.ListChangeListener; import javafx.concurrent.Task; @@ -30,6 +29,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import qupath.ext.instanseg.core.DetectionMeasurer; +import qupath.ext.instanseg.core.InstanSeg; import qupath.ext.instanseg.core.InstanSegModel; import qupath.fx.dialogs.Dialogs; import qupath.fx.dialogs.FileChoosers; @@ -37,13 +37,12 @@ import qupath.lib.common.ThreadTools; import qupath.lib.display.ChannelDisplayInfo; import qupath.lib.gui.QuPathGUI; -import qupath.lib.gui.scripting.QPEx; +import qupath.lib.gui.TaskRunnerFX; import qupath.lib.gui.tools.GuiTools; import qupath.lib.images.ImageData; -import qupath.lib.images.servers.ColorTransforms; import qupath.lib.images.servers.ImageServer; import qupath.lib.objects.PathObject; -import qupath.lib.scripting.QP; +import qupath.lib.plugins.workflow.DefaultScriptableWorkflowStep; import java.awt.image.BufferedImage; import java.io.File; @@ -167,6 +166,11 @@ private void addSetFromVisible(CheckComboBox comboChannels) { var channelNames = activeChannels.stream() .map(ChannelDisplayInfo::getName) .toList(); + if (qupath.getImageData() != null && !qupath.getImageData().getServer().isRGB()) { + channelNames = channelNames.stream() 
+ .map(s -> s.replaceAll(" \\(C\\d+\\)$", "")) + .toList(); + } var comboItems = comboChannels.getItems(); for (int i = 0; i < comboItems.size(); i++) { if (channelNames.contains(comboItems.get(i).getName())) { @@ -192,9 +196,9 @@ private static Collection getAvailableChannels(ImageData i var server = imageData.getServer(); int i = 1; boolean hasDuplicates = false; + ChannelSelectItem item; for (var channel : server.getMetadata().getChannels()) { var name = channel.getName(); - var transform = ColorTransforms.createChannelExtractor(name); if (names.contains(name)) { logger.warn("Found duplicate channel name! Channel " + i + " (name '" + name + "')."); logger.warn("Using channel indices instead of names because of duplicated channel names."); @@ -202,19 +206,17 @@ private static Collection getAvailableChannels(ImageData i } names.add(name); if (hasDuplicates) { - transform = ColorTransforms.createChannelExtractor(i - 1); - } - if (!server.isRGB()) { - name += " (C" + i + ")"; + item = new ChannelSelectItem(name, i - 1); + } else { + item = new ChannelSelectItem(name); } - list.add(new ChannelSelectItem(name, transform)); + list.add(item); i++; } var stains = imageData.getColorDeconvolutionStains(); if (stains != null) { for (i = 1; i < 4; i++) { - var transform = ColorTransforms.createColorDeconvolvedChannel(stains, i); - list.add(new ChannelSelectItem(transform.getName(), transform)); + list.add(new ChannelSelectItem(stains, i)); } } return list; @@ -334,7 +336,7 @@ static void addModelsFromPath(String dir, ComboBox box) { try (var ps = Files.list(path)) { for (var file: ps.toList()) { if (InstanSegModel.isValidModel(file)) { - box.getItems().add(InstanSegModel.createModel(file)); + box.getItems().add(InstanSegModel.fromPath(file)); } } } catch (IOException e) { @@ -358,50 +360,14 @@ private void runInstanSeg() { var model = modelChoiceBox.getSelectionModel().getSelectedItem(); ImageServer server = qupath.getImageData().getServer(); - List selectedChannels = 
comboChannels + // todo: how to record this in workflow? + List selectedChannels = comboChannels .getCheckModel().getCheckedItems() .stream() .filter(Objects::nonNull) - .map(ChannelSelectItem::getTransform) .toList(); - var task = new Task() { - @Override - protected Void call() { - // Ensure PyTorch engine is available - if (!PytorchManager.hasPyTorchEngine()) { - downloadPyTorch(); - } - var objects = QP.getSelectedObjects(); - var imageData = QP.getCurrentImageData(); - try { - model.runInstanSeg( - objects, - imageData, - selectedChannels, - InstanSegPreferences.tileSizeProperty().get(), - model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), - deviceChoices.getSelectionModel().getSelectedItem(), - nucleiOnlyCheckBox.isSelected(), - QPEx.createTaskRunner(InstanSegPreferences.numThreadsProperty().getValue())); - } catch (ModelNotFoundException | MalformedModelException | - IOException | InterruptedException e) { - Dialogs.showErrorMessage("Unable to run InstanSeg", e); - logger.error("Unable to run InstanSeg", e); - } - for (PathObject po: objects) { - makeMeasurements(imageData, po.getChildObjects(), model); - } - QP.fireHierarchyUpdate(); - if (model.nFailed() > 0) { - var errorMessage = String.format(resources.getString("error.tiles-failed"), model.nFailed()); - logger.error(errorMessage); - Dialogs.showErrorMessage(resources.getString("title"), - errorMessage); - } - return null; - } - }; + var task = new InstanSegTask(server, model, selectedChannels); pendingTask.set(task); // Reset the pending task when it completes (either successfully or not) task.stateProperty().addListener((observable, oldValue, newValue) -> { @@ -412,6 +378,84 @@ protected Void call() { }); } + private class InstanSegTask extends Task { + + private final List channels; + private final ImageServer server; + private final InstanSegModel model; + + InstanSegTask(ImageServer server, InstanSegModel model, List channels) { + this.server = server; + 
this.model = model; + this.channels = channels; + } + + + @Override + protected Void call() { + // Ensure PyTorch engine is available + if (!PytorchManager.hasPyTorchEngine()) { + downloadPyTorch(); + } + String cmd = String.format(""" + import qupath.ext.instanseg.core.InstanSeg + + def channels = %s; + def instanSeg = InstanSeg.builder() + .modelPath("%s") + .device("%s") + .numOutputChannels(%d) + .channels(channels) + .tileDims(%d) + .imageData(QP.getCurrentImageData()) + .downsample(%f) + .nThreads(QPEx.createTaskRunner(%d)) + .build(); + instanSeg.detectObjects(); + """, + ChannelSelectItem.toConstructorString(channels), + model.getPath(), + deviceChoices.getSelectionModel().getSelectedItem(), + nucleiOnlyCheckBox.isSelected() ? 1 : 2, + InstanSegPreferences.tileSizeProperty().get(), + model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize(), + InstanSegPreferences.numThreadsProperty().getValue() + ); + qupath.getImageData().getHistoryWorkflow() + .addStep( + new DefaultScriptableWorkflowStep(resources.getString("workflow.title"), cmd) + ); + var taskRunner = new TaskRunnerFX( + QuPathGUI.getInstance(), + InstanSegPreferences.numThreadsProperty().getValue()); + + var imageData = qupath.getImageData(); + var selectedObjects = qupath.getImageData().getHierarchy().getSelectionModel().getSelectedObjects(); + var instanSeg = InstanSeg.builder() + .model(model) + .imageData(imageData) + .device(deviceChoices.getSelectionModel().getSelectedItem()) + .numOutputChannels(nucleiOnlyCheckBox.isSelected() ? 
1 : 2) + .channels(channels.stream().map(ChannelSelectItem::getTransform).toList()) + .tileDims(InstanSegPreferences.tileSizeProperty().get()) + .downsample(model.getPixelSizeX() / (double) server.getPixelCalibration().getAveragedPixelSize()) + .taskRunner(taskRunner) + .build(); + instanSeg.detectObjects(selectedObjects); + qupath.getImageData().getHierarchy().fireHierarchyChangedEvent(this); + for (PathObject po: selectedObjects) { + makeMeasurements(imageData, po.getChildObjects(), model); + } + if (model.nFailed() > 0) { + var errorMessage = String.format(resources.getString("error.tiles-failed"), model.nFailed()); + logger.error(errorMessage); + Dialogs.showErrorMessage(resources.getString("title"), + errorMessage); + } + return null; + } + } + private void downloadPyTorch() { Platform.runLater(() -> Dialogs.showInfoNotification(resources.getString("title"), resources.getString("ui.pytorch-downloading"))); PytorchManager.getEngineOnline(); @@ -429,12 +473,14 @@ public void makeMeasurements(ImageData imageData, Collection