Skip to content

Commit

Permalink
Support multi-output models with bioimageio spec 0.5 (#117)
Browse files Browse the repository at this point in the history
Add support for multi-output models using the bioimageio 0.5.x model spec.

We currently handle logits and embeddings by adding them to the measurements of output objects. We don't have a good way to name measurements yet; hopefully we can get axis names at some point in the future of the bioimageio spec. Detection classes are set as PathClass on the output objects.

Detection classes are supported with a NominalOrOrdinalDataDescr. If logits are present too, we try to match the names of outputs and logits. Otherwise, we default to "Class 1", "Logit 1" etc.

Semantic segmentation is recognised as an output type, but support is not currently implemented (suggest this could be a future PR for now?).

Superset of #100

---------

Co-authored-by: Pete <[email protected]>
  • Loading branch information
alanocallaghan and petebankhead authored Feb 12, 2025
1 parent c890cdf commit 6160def
Show file tree
Hide file tree
Showing 16 changed files with 342 additions and 133 deletions.
71 changes: 42 additions & 29 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.bioimageio.spec.tensor.OutputTensor;
import qupath.lib.experimental.pixels.OpenCVProcessor;
import qupath.lib.experimental.pixels.OutputHandler;
import qupath.lib.experimental.pixels.Parameters;
import qupath.lib.experimental.pixels.PixelProcessor;
import qupath.lib.experimental.pixels.Processor;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ColorTransforms;
Expand Down Expand Up @@ -58,7 +60,7 @@ public class InstanSeg {
private final InstanSegModel model;
private final Device device;
private final TaskRunner taskRunner;
private final Class<? extends PathObject> preferredOutputClass;
private final Class<? extends PathObject> preferredOutputType;
private final Map<String, Object> optionalArgs = new LinkedHashMap<>();

// This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently.
Expand All @@ -75,7 +77,7 @@ private InstanSeg(Builder builder) {
this.model = builder.model;
this.device = builder.device;
this.taskRunner = builder.taskRunner;
this.preferredOutputClass = builder.preferredOutputClass;
this.preferredOutputType = builder.preferredOutputType;
this.randomColors = builder.randomColors;
this.makeMeasurements = builder.makeMeasurements;
this.optionalArgs.putAll(builder.optionalArgs);
Expand Down Expand Up @@ -160,13 +162,18 @@ private void makeMeasurements(ImageData<BufferedImage> imageData, Collection<? e

private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collection<? extends PathObject> pathObjects) {
long startTime = System.currentTimeMillis();

Optional<Path> oModelPath = model.getPath();
if (oModelPath.isEmpty()) {
return InstanSegResults.emptyInstance();
}
Path modelPath = oModelPath.get().resolve("instanseg.pt");

Optional<List<OutputTensor>> oOutputTensors = this.model.getOutputs();
if (oOutputTensors.isEmpty()) {
throw new IllegalArgumentException("No output tensors available even though model is available");
}
var outputTensors = oOutputTensors.get();

// Provide some way to change the number of predictors, even if this can't be specified through the UI
// See https://forum.image.sc/t/instanseg-under-utilizing-cpu-only-2-3-cores/104496/7
int nPredictors = Integer.parseInt(System.getProperty("instanseg.numPredictors", "1"));
Expand All @@ -179,8 +186,6 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
logger.warn("Padding to input size is turned on - this is likely to be slower (but could help fix any issues)");
}
String layout = "CHW";

// TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant?
String layoutOutput = "CHW";

// Get the downsample - this may be specified by the user, or determined from the model spec
Expand All @@ -202,6 +207,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
// Create an int[] representing a boolean array of channels to use
boolean[] outputChannelArray = null;
if (outputChannels != null && outputChannels.length > 0) {
//noinspection OptionalGetWithoutIsPresent
outputChannelArray = new boolean[model.getOutputChannels().get()]; // safe to call get because of previous checks
for (int c : outputChannels) {
if (c < 0 || c >= outputChannelArray.length) {
Expand All @@ -215,7 +221,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
var inputChannels = getInputChannels(imageData);

try (var model = Criteria.builder()
.setTypes(Mat.class, Mat.class)
.setTypes(Mat.class, Mat[].class)
.optModelUrls(String.valueOf(modelPath.toUri()))
.optProgress(new ProgressBar())
.optDevice(device) // Remove this line if devices are problematic!
Expand All @@ -227,7 +233,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
printResourceCount("Resource count before prediction",
(BaseNDManager)baseManager.getParentManager());
baseManager.debugDump(2);
BlockingQueue<Predictor<Mat, Mat>> predictors = new ArrayBlockingQueue<>(nPredictors);
BlockingQueue<Predictor<Mat, Mat[]>> predictors = new ArrayBlockingQueue<>(nPredictors);

try {
for (int i = 0; i < nPredictors; i++) {
Expand All @@ -239,10 +245,11 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec

var tiler = createTiler(downsample, tileDims, padding);
var predictionProcessor = createProcessor(predictors, inputChannels, tileDims, padToInputSize);
var outputHandler = createOutputHandler(preferredOutputClass, randomColors, boundaryThreshold);
var outputHandler = createOutputHandler(preferredOutputType, randomColors, boundaryThreshold, outputTensors);
var postProcessor = createPostProcessor();

var processor = OpenCVProcessor.builder(predictionProcessor)
var processor = new PixelProcessor.Builder<Mat, Mat, Mat[]>()
.processor(predictionProcessor)
.maskSupplier(OpenCVProcessor.createMatMaskSupplier())
.imageSupplier((parameters) -> ImageOps.buildImageDataOp(inputChannels)
.apply(parameters.getImageData(), parameters.getRegionRequest()))
.tiler(tiler)
Expand Down Expand Up @@ -279,24 +286,25 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
}
}


/**
* Check if we are requesting tiles for debugging purposes.
* When this is true, we should create objects that represent the tiles - not the objects to be detected.
* @return
* @return Whether the system debugging property is set.
*/
private static boolean debugTiles() {
return System.getProperty("instanseg.debug.tiles", "false").strip().equalsIgnoreCase("true");
}

private static Processor<Mat, Mat, Mat> createProcessor(BlockingQueue<Predictor<Mat, Mat>> predictors,
private static Processor<Mat, Mat, Mat[]> createProcessor(BlockingQueue<Predictor<Mat, Mat[]>> predictors,
Collection<? extends ColorTransforms.ColorTransform> inputChannels,
int tileDims, boolean padToInputSize) {
if (debugTiles())
return InstanSeg::createOnes;
return new TilePredictionProcessor(predictors, inputChannels, tileDims, tileDims, padToInputSize);
}

private static Mat createOnes(Parameters<Mat, Mat> parameters) {
private static Mat[] createOnes(Parameters<Mat, Mat> parameters) {
var tileRequest = parameters.getTileRequest();
int width, height;
if (tileRequest == null) {
Expand All @@ -307,15 +315,20 @@ private static Mat createOnes(Parameters<Mat, Mat> parameters) {
width = tileRequest.getTileWidth();
height = tileRequest.getTileHeight();
}
return Mat.ones(height, width, opencv_core.CV_8UC1).asMat();
try (var ones = Mat.ones(height, width, opencv_core.CV_8UC1)) {
return new Mat[]{ones.asMat()};
}
}

private static OutputHandler<Mat, Mat, Mat> createOutputHandler(Class<? extends PathObject> preferredOutputClass,
boolean randomColors,
int boundaryThreshold) {
if (debugTiles())
return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter());
var converter = new InstanSegOutputToObjectConverter(preferredOutputClass, randomColors);

private static OutputHandler<Mat, Mat, Mat[]> createOutputHandler(Class<? extends PathObject> preferredOutputType,
boolean randomColors,
int boundaryThreshold,
List<OutputTensor> outputTensors) {
// TODO: Reinstate this for Mat[] output (it was written for Mat output)
// if (debugTiles())
// return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter());
var converter = new InstanSegOutputToObjectConverter(outputTensors, preferredOutputType, randomColors);
if (boundaryThreshold >= 0) {
return new PruneObjectOutputHandler<>(converter, boundaryThreshold);
} else {
Expand All @@ -334,8 +347,8 @@ private static Tiler createTiler(double downsample, int tileDims, int padding) {

/**
* Get the input channels to use; if we don't have any specified, use all of them
* @param imageData
* @return
* @param imageData The image data
* @return The possible input channels.
*/
private List<ColorTransforms.ColorTransform> getInputChannels(ImageData<BufferedImage> imageData) {
if (inputChannels == null || inputChannels.isEmpty()) {
Expand Down Expand Up @@ -364,8 +377,8 @@ private static ObjectProcessor createPostProcessor() {
/**
* Print resource count for debugging purposes.
* If we are not logging at debug level, do nothing.
* @param title
* @param manager
* @param title The name to be used in the log.
* @param manager The NDManager to print from.
*/
private static void printResourceCount(String title, BaseNDManager manager) {
if (logger.isDebugEnabled()) {
Expand Down Expand Up @@ -395,7 +408,7 @@ public static final class Builder {
private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner();
private Collection<? extends ColorTransforms.ColorTransform> channels;
private InstanSegModel model;
private Class<? extends PathObject> preferredOutputClass;
private Class<? extends PathObject> preferredOutputType;
private final Map<String, Object> optionalArgs = new LinkedHashMap<>();

Builder() {}
Expand Down Expand Up @@ -555,7 +568,7 @@ public Builder randomColors() {

/**
* Optionally request that random colors be used for the output objects.
* @param doRandomColors
* @param doRandomColors Whether to use random colors for output object.
* @return this builder
*/
public Builder randomColors(boolean doRandomColors) {
Expand Down Expand Up @@ -636,7 +649,7 @@ public Builder device(Device device) {
* @return this builder
*/
public Builder outputCells() {
this.preferredOutputClass = PathCellObject.class;
this.preferredOutputType = PathCellObject.class;
return this;
}

Expand All @@ -645,7 +658,7 @@ public Builder outputCells() {
* @return this builder
*/
public Builder outputDetections() {
this.preferredOutputClass = PathDetectionObject.class;
this.preferredOutputType = PathDetectionObject.class;
return this;
}

Expand All @@ -654,7 +667,7 @@ public Builder outputDetections() {
* @return this builder
*/
public Builder outputAnnotations() {
this.preferredOutputClass = PathAnnotationObject.class;
this.preferredOutputType = PathAnnotationObject.class;
return this;
}

Expand Down
77 changes: 65 additions & 12 deletions src/main/java/qupath/ext/instanseg/core/InstanSegModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.bioimageio.spec.BioimageIoSpec;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;

import qupath.bioimageio.spec.Model;
import qupath.bioimageio.spec.Resource;
import qupath.bioimageio.spec.tensor.OutputTensor;
import qupath.lib.common.GeneralTools;
import qupath.lib.images.servers.PixelCalibration;

Expand All @@ -27,6 +30,9 @@
import java.util.zip.ZipInputStream;
import java.util.Objects;

import static qupath.bioimageio.spec.tensor.axes.Axes.getAxesString;


public class InstanSegModel {

private static final Logger logger = LoggerFactory.getLogger(InstanSegModel.class);
Expand All @@ -39,10 +45,10 @@ public class InstanSegModel {
public static final int ANY_CHANNELS = -1;

private Path path = null;
private BioimageIoSpec.BioimageIoModel model = null;
private Model model = null;
private final String name;

private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) {
private InstanSegModel(Model bioimageIoModel) {
this.model = bioimageIoModel;
this.path = Paths.get(model.getBaseURI());
this.version = model.getVersion();
Expand All @@ -62,7 +68,7 @@ private InstanSegModel(String name, String version, URL modelURL) {
* @throws IOException If the directory can't be found or isn't a valid model directory.
*/
public static InstanSegModel fromPath(Path path) throws IOException {
return new InstanSegModel(BioimageIoSpec.parseModel(path));
return new InstanSegModel(Model.parse(path));
}

/**
Expand Down Expand Up @@ -103,7 +109,7 @@ public void checkIfDownloaded(Path downloadedModelDir, boolean downloadIfNotVali
downloadIfNotValid);
this.path = unzipIfNeeded(zipFile);
if (this.path != null) {
this.model = BioimageIoSpec.parseModel(path.toFile());
this.model = Model.parse(path.toFile());
this.version = model.getVersion();
}
}
Expand Down Expand Up @@ -212,7 +218,7 @@ public Optional<Path> getPath() {
public String toString() {
String name = getName();
String parent = getPath().map(Path::getParent).map(Path::getFileName).map(Path::toString).orElse(null);
String version = getModel().map(BioimageIoSpec.BioimageIoModel::getVersion).orElse(this.version);
String version = getModel().map(Resource::getVersion).orElse(this.version);
if (parent != null && !parent.equals(name)) {
name = parent + "/" + name;
}
Expand Down Expand Up @@ -251,8 +257,55 @@ public Optional<Integer> getNumChannels() {
return getModel().flatMap(model -> Optional.of(extractChannelNum(model)));
}

private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) {
int ind = model.getInputs().getFirst().getAxes().toLowerCase().indexOf("c");
/**
* Try to check the output tensors from the model spec.
* @return The output tensors if the model is downloaded, otherwise empty.
*/
public Optional<List<OutputTensor>> getOutputs() {
return getModel().flatMap(model -> Optional.ofNullable(model.getOutputs()));
}

/**
* Types of output tensors that may be supported by InstanSeg models.
*/
public enum OutputTensorType {
// "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation"
INSTANCE_SEGMENTATION("instance_segmentation"),
DETECTION_EMBEDDINGS("detection_embeddings"),
DETECTION_LOGITS("detection_logits"),
DETECTION_CLASSES("detection_classes"),
SEMANTIC_SEGMENTATION("semantic_segmentation");

private final String type;

OutputTensorType(String type) {
this.type = type;
}

@Override
public String toString() {
return type;
}

/**
* Get the output type from a string, ignoring case
* @param value the input String to be matched against the possible values
* @return the matching output type, or empty if no match
*/
public static Optional<OutputTensorType> fromString(String value) {
for (OutputTensorType t: values()) {
if (t.type.equalsIgnoreCase(value)) {
return Optional.of(t);
}
}
logger.error("Unknown output type {}", value);
return Optional.empty();
}
}

private static int extractChannelNum(Model model) {
String axes = getAxesString(model.getInputs().getFirst().getAxes());
int ind = axes.indexOf("c");
var shape = model.getInputs().getFirst().getShape();
if (shape.getShapeStep()[ind] == 1) {
return ANY_CHANNELS;
Expand All @@ -265,7 +318,7 @@ private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) {
* Retrieve the BioImage model spec.
* @return The BioImageIO model spec for this InstanSeg model.
*/
private Optional<BioimageIoSpec.BioimageIoModel> getModel() {
private Optional<Model> getModel() {
return Optional.ofNullable(model);
}

Expand All @@ -292,7 +345,7 @@ private static boolean isDownloadedAlready(Path zipFile) {
return false;
}
try {
BioimageIoSpec.parseModel(zipFile.toFile());
Model.parse(zipFile.toFile());
} catch (IOException e) {
logger.warn("Invalid zip file", e);
return false;
Expand All @@ -304,7 +357,7 @@ private Path unzipIfNeeded(Path zipFile) throws IOException {
if (zipFile == null) {
return null;
}
var zipSpec = BioimageIoSpec.parseModel(zipFile);
var zipSpec = Model.parse(zipFile);
String version = zipSpec.getVersion();
var outdir = zipFile.resolveSibling(getFolderName(zipSpec.getName(), version));
if (!isUnpackedAlready(outdir)) {
Expand Down Expand Up @@ -399,7 +452,7 @@ private Optional<Map<String, Double>> getPixelSize() {
public Optional<Integer> getOutputChannels() {
return getModel().map(model -> {
var output = model.getOutputs().getFirst();
String axes = output.getAxes().toLowerCase();
String axes = getAxesString(output.getAxes());
int ind = axes.indexOf("c");
var shape = output.getShape().getShape();
if (shape != null && shape.length > ind)
Expand Down
Loading

0 comments on commit 6160def

Please sign in to comment.