Support multi-output models with bioimageio spec 0.5 #117

Draft
wants to merge 12 commits into base: main
68 changes: 43 additions & 25 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java
@@ -9,9 +9,11 @@
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.bioimageio.spec.tensor.OutputTensor;
import qupath.lib.experimental.pixels.OpenCVProcessor;
import qupath.lib.experimental.pixels.OutputHandler;
import qupath.lib.experimental.pixels.Parameters;
import qupath.lib.experimental.pixels.PixelProcessor;
import qupath.lib.experimental.pixels.Processor;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ColorTransforms;
@@ -58,7 +60,7 @@ public class InstanSeg {
private final InstanSegModel model;
private final Device device;
private final TaskRunner taskRunner;
private final Class<? extends PathObject> preferredOutputClass;
private final Class<? extends PathObject> preferredOutputType;
private final Map<String, Object> optionalArgs = new LinkedHashMap<>();

// This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently.
@@ -75,7 +77,7 @@ private InstanSeg(Builder builder) {
this.model = builder.model;
this.device = builder.device;
this.taskRunner = builder.taskRunner;
this.preferredOutputClass = builder.preferredOutputClass;
this.preferredOutputType = builder.preferredOutputClass;
this.randomColors = builder.randomColors;
this.makeMeasurements = builder.makeMeasurements;
this.optionalArgs.putAll(builder.optionalArgs);
@@ -160,13 +162,23 @@ private void makeMeasurements(ImageData<BufferedImage> imageData, Collection<? e

private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collection<? extends PathObject> pathObjects) {
long startTime = System.currentTimeMillis();

Optional<Path> oModelPath = model.getPath();
if (oModelPath.isEmpty()) {
return InstanSegResults.emptyInstance();
}
Path modelPath = oModelPath.get().resolve("instanseg.pt");

Optional<List<OutputTensor>> oOutputTensors = this.model.getOutputs();
if (oOutputTensors.isEmpty()) {
throw new IllegalArgumentException("No output tensors available even though model is available");
}
var outputTensors = oOutputTensors.get();

List<String> outputClasses = this.model.getClasses();
if (outputClasses.isEmpty() && outputTensors.size() > 1) {
logger.warn("No output classes available, classes will be set as 'Class 1' etc.");
}

// Provide some way to change the number of predictors, even if this can't be specified through the UI
// See https://forum.image.sc/t/instanseg-under-utilizing-cpu-only-2-3-cores/104496/7
int nPredictors = Integer.parseInt(System.getProperty("instanseg.numPredictors", "1"));
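A minimal sketch of bumping this value, assuming the property is set before the processor is built; the same effect comes from launching with -Dinstanseg.numPredictors=4:

// Must run before the property is read above; 4 is an arbitrary example value
System.setProperty("instanseg.numPredictors", "4");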
@@ -179,8 +191,6 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
logger.warn("Padding to input size is turned on - this is likely to be slower (but could help fix any issues)");
}
String layout = "CHW";

// TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant?
String layoutOutput = "CHW";

// Get the downsample - this may be specified by the user, or determined from the model spec
@@ -202,6 +212,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
// Convert the requested output channel indices into a boolean mask of channels to use
boolean[] outputChannelArray = null;
if (outputChannels != null && outputChannels.length > 0) {
//noinspection OptionalGetWithoutIsPresent
outputChannelArray = new boolean[model.getOutputChannels().get()]; // safe to call get because of previous checks
for (int c : outputChannels) {
if (c < 0 || c >= outputChannelArray.length) {
@@ -215,7 +226,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
var inputChannels = getInputChannels(imageData);

try (var model = Criteria.builder()
.setTypes(Mat.class, Mat.class)
.setTypes(Mat.class, Mat[].class)
.optModelUrls(String.valueOf(modelPath.toUri()))
.optProgress(new ProgressBar())
.optDevice(device) // Remove this line if devices are problematic!
@@ -227,7 +238,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
printResourceCount("Resource count before prediction",
(BaseNDManager)baseManager.getParentManager());
baseManager.debugDump(2);
BlockingQueue<Predictor<Mat, Mat>> predictors = new ArrayBlockingQueue<>(nPredictors);
BlockingQueue<Predictor<Mat, Mat[]>> predictors = new ArrayBlockingQueue<>(nPredictors);

try {
for (int i = 0; i < nPredictors; i++) {
@@ -239,10 +250,11 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec

var tiler = createTiler(downsample, tileDims, padding);
var predictionProcessor = createProcessor(predictors, inputChannels, tileDims, padToInputSize);
var outputHandler = createOutputHandler(preferredOutputClass, randomColors, boundaryThreshold);
var outputHandler = createOutputHandler(preferredOutputType, randomColors, boundaryThreshold, outputTensors);
var postProcessor = createPostProcessor();

var processor = OpenCVProcessor.builder(predictionProcessor)
var processor = new PixelProcessor.Builder<Mat, Mat, Mat[]>()
.processor(predictionProcessor)
.maskSupplier(OpenCVProcessor.createMatMaskSupplier())
.imageSupplier((parameters) -> ImageOps.buildImageDataOp(inputChannels)
.apply(parameters.getImageData(), parameters.getRegionRequest()))
.tiler(tiler)
@@ -279,24 +291,25 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
}
}


/**
* Check if we are requesting tiles for debugging purposes.
* When this is true, we should create objects that represent the tiles - not the objects to be detected.
* @return
* @return Whether the system debugging property is set.
*/
private static boolean debugTiles() {
return System.getProperty("instanseg.debug.tiles", "false").strip().equalsIgnoreCase("true");
}
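A quick sketch of enabling this debug mode, assuming the property is set before processing starts:

// Each tile is then returned as a placeholder object instead of being segmented
System.setProperty("instanseg.debug.tiles", "true");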

private static Processor<Mat, Mat, Mat> createProcessor(BlockingQueue<Predictor<Mat, Mat>> predictors,
private static Processor<Mat, Mat, Mat[]> createProcessor(BlockingQueue<Predictor<Mat, Mat[]>> predictors,
Collection<? extends ColorTransforms.ColorTransform> inputChannels,
int tileDims, boolean padToInputSize) {
if (debugTiles())
return InstanSeg::createOnes;
return new TilePredictionProcessor(predictors, inputChannels, tileDims, tileDims, padToInputSize);
}

private static Mat createOnes(Parameters<Mat, Mat> parameters) {
private static Mat[] createOnes(Parameters<Mat, Mat> parameters) {
var tileRequest = parameters.getTileRequest();
int width, height;
if (tileRequest == null) {
@@ -307,15 +320,20 @@ private static Mat createOnes(Parameters<Mat, Mat> parameters) {
width = tileRequest.getTileWidth();
height = tileRequest.getTileHeight();
}
return Mat.ones(height, width, opencv_core.CV_8UC1).asMat();
try (var ones = Mat.ones(height, width, opencv_core.CV_8UC1)) {
return new Mat[]{ones.asMat()};
}
}

private static OutputHandler<Mat, Mat, Mat> createOutputHandler(Class<? extends PathObject> preferredOutputClass,
boolean randomColors,
int boundaryThreshold) {
if (debugTiles())
return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter());
var converter = new InstanSegOutputToObjectConverter(preferredOutputClass, randomColors);

private static OutputHandler<Mat, Mat, Mat[]> createOutputHandler(Class<? extends PathObject> preferredOutputClass,
boolean randomColors,
int boundaryThreshold,
List<OutputTensor> outputTensors) {
// TODO: Reinstate this for Mat[] output (it was written for Mat output)
// if (debugTiles())
// return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter());
var converter = new InstanSegOutputToObjectConverter(outputTensors, preferredOutputClass, randomColors);
if (boundaryThreshold >= 0) {
return new PruneObjectOutputHandler<>(converter, boundaryThreshold);
} else {
@@ -334,8 +352,8 @@ private static Tiler createTiler(double downsample, int tileDims, int padding) {

/**
* Get the input channels to use; if we don't have any specified, use all of them
* @param imageData
* @return
* @param imageData The image data
* @return The possible input channels.
*/
private List<ColorTransforms.ColorTransform> getInputChannels(ImageData<BufferedImage> imageData) {
if (inputChannels == null || inputChannels.isEmpty()) {
@@ -364,8 +382,8 @@ private static ObjectProcessor createPostProcessor() {
/**
* Print resource count for debugging purposes.
* If we are not logging at debug level, do nothing.
* @param title
* @param manager
* @param title The name to be used in the log.
* @param manager The NDManager to print from.
*/
private static void printResourceCount(String title, BaseNDManager manager) {
if (logger.isDebugEnabled()) {
@@ -555,7 +573,7 @@ public Builder randomColors() {

/**
* Optionally request that random colors be used for the output objects.
* @param doRandomColors
* @param doRandomColors Whether to use random colors for output objects.
* @return this builder
*/
public Builder randomColors(boolean doRandomColors) {
77 changes: 65 additions & 12 deletions src/main/java/qupath/ext/instanseg/core/InstanSegModel.java
@@ -2,12 +2,15 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.bioimageio.spec.BioimageIoSpec;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;

import qupath.bioimageio.spec.Model;
import qupath.bioimageio.spec.Resource;
import qupath.bioimageio.spec.tensor.OutputTensor;
import qupath.lib.common.GeneralTools;
import qupath.lib.images.servers.PixelCalibration;

@@ -20,13 +23,17 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.Objects;

import static qupath.bioimageio.spec.tensor.axes.Axes.getAxesString;


public class InstanSegModel {

private static final Logger logger = LoggerFactory.getLogger(InstanSegModel.class);
@@ -39,10 +46,10 @@ public class InstanSegModel {
public static final int ANY_CHANNELS = -1;

private Path path = null;
private BioimageIoSpec.BioimageIoModel model = null;
private Model model = null;
private final String name;

private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) {
private InstanSegModel(Model bioimageIoModel) {
this.model = bioimageIoModel;
this.path = Paths.get(model.getBaseURI());
this.version = model.getVersion();
@@ -62,7 +69,7 @@ private InstanSegModel(String name, String version, URL modelURL) {
* @throws IOException If the directory can't be found or isn't a valid model directory.
*/
public static InstanSegModel fromPath(Path path) throws IOException {
return new InstanSegModel(BioimageIoSpec.parseModel(path));
return new InstanSegModel(Model.parseModel(path));
}

/**
Expand Down Expand Up @@ -103,7 +110,7 @@ public void checkIfDownloaded(Path downloadedModelDir, boolean downloadIfNotVali
downloadIfNotValid);
this.path = unzipIfNeeded(zipFile);
if (this.path != null) {
this.model = BioimageIoSpec.parseModel(path.toFile());
this.model = Model.parseModel(path.toFile());
this.version = model.getVersion();
}
}
Expand Down Expand Up @@ -212,7 +219,7 @@ public Optional<Path> getPath() {
public String toString() {
String name = getName();
String parent = getPath().map(Path::getParent).map(Path::getFileName).map(Path::toString).orElse(null);
String version = getModel().map(BioimageIoSpec.BioimageIoModel::getVersion).orElse(this.version);
String version = getModel().map(Resource::getVersion).orElse(this.version);
if (parent != null && !parent.equals(name)) {
name = parent + "/" + name;
}
@@ -251,8 +258,54 @@ public Optional<Integer> getNumChannels() {
return getModel().flatMap(model -> Optional.of(extractChannelNum(model)));
}

private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) {
int ind = model.getInputs().getFirst().getAxes().toLowerCase().indexOf("c");
/**
* Try to get the output tensors from the model spec.
* @return The output tensors if the model is downloaded, otherwise empty.
*/
public Optional<List<OutputTensor>> getOutputs() {
return getModel().flatMap(model -> Optional.ofNullable(model.getOutputs()));
}

/**
* Try to get the output classes from the model spec.
* @return The output classes if the model is downloaded and the classes are present, otherwise an empty list.
*/
public List<String> getClasses() {
var config = model.getConfig().getOrDefault("qupath", null);
if (config instanceof Map<?, ?> configMap) {
List<String> classes = new ArrayList<>();
var el = configMap.get("classes");
if (el instanceof List<?> elList) {
for (var t: elList) {
classes.add(t.toString());
}
}
return classes;
}
return List.of();
}
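For illustration, a minimal sketch of the config shape this method expects; the "qupath" and "classes" keys follow the code above, while the class names are made-up placeholders:

// Hypothetical parsed model config; getClasses() would return ["Tumour", "Stroma", "Immune"]
Map<String, Object> config = Map.of(
"qupath", Map.of("classes", List.of("Tumour", "Stroma", "Immune"))
);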

public enum OutputType {
// "instance segmentation" "cell embeddings" "cell classes" "cell probabilities" "semantic segmentation"
INSTANCE_SEGMENTATION("instance_segmentation"),
DETECTION_EMBEDDINGS("detection_embeddings"),
DETECTION_LOGITS("detection_logits"),
DETECTION_CLASSES("detection_classes"),
SEMANTIC_SEGMENTATION("semantic_segmentation");

private final String type;
OutputType(String type) {
this.type = type;
}
@Override
public String toString() {
return type;
}
}
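As a sketch of how these values might be matched against output-tensor names from the spec; the fromString helper below is hypothetical and not part of this change:

// Hypothetical lookup: "semantic_segmentation" -> OutputType.SEMANTIC_SEGMENTATION
static OutputType fromString(String name) {
for (var type : OutputType.values()) {
if (type.toString().equalsIgnoreCase(name.strip())) {
return type;
}
}
throw new IllegalArgumentException("Unknown output type: " + name);
}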

private static int extractChannelNum(Model model) {
String axes = getAxesString(model.getInputs().getFirst().getAxes());
int ind = axes.indexOf("c");
var shape = model.getInputs().getFirst().getShape();
if (shape.getShapeStep()[ind] == 1) {
return ANY_CHANNELS;
@@ -265,7 +318,7 @@ private static int extractChannelNum(BioimageIoSpec.BioimageIoModel model) {
* Retrieve the BioImage model spec.
* @return The BioImageIO model spec for this InstanSeg model.
*/
private Optional<BioimageIoSpec.BioimageIoModel> getModel() {
private Optional<Model> getModel() {
return Optional.ofNullable(model);
}

Expand All @@ -292,7 +345,7 @@ private static boolean isDownloadedAlready(Path zipFile) {
return false;
}
try {
BioimageIoSpec.parseModel(zipFile.toFile());
Model.parseModel(zipFile.toFile());
} catch (IOException e) {
logger.warn("Invalid zip file", e);
return false;
@@ -304,7 +357,7 @@ private Path unzipIfNeeded(Path zipFile) throws IOException {
if (zipFile == null) {
return null;
}
var zipSpec = BioimageIoSpec.parseModel(zipFile);
var zipSpec = Model.parseModel(zipFile);
String version = zipSpec.getVersion();
var outdir = zipFile.resolveSibling(getFolderName(zipSpec.getName(), version));
if (!isUnpackedAlready(outdir)) {
@@ -399,7 +452,7 @@ private Optional<Map<String, Double>> getPixelSize() {
public Optional<Integer> getOutputChannels() {
return getModel().map(model -> {
var output = model.getOutputs().getFirst();
String axes = output.getAxes().toLowerCase();
String axes = getAxesString(output.getAxes());
int ind = axes.indexOf("c");
var shape = output.getShape().getShape();
if (shape != null && shape.length > ind)