Explore handling multiple outputs #100

Closed
wants to merge 2 commits
24 changes: 14 additions & 10 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java
@@ -12,6 +12,7 @@
import qupath.lib.experimental.pixels.OpenCVProcessor;
import qupath.lib.experimental.pixels.OutputHandler;
import qupath.lib.experimental.pixels.Parameters;
import qupath.lib.experimental.pixels.PixelProcessor;
import qupath.lib.experimental.pixels.Processor;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ColorTransforms;
@@ -208,7 +209,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
var inputChannels = getInputChannels(imageData);

try (var model = Criteria.builder()
.setTypes(Mat.class, Mat.class)
.setTypes(Mat.class, Mat[].class)
.optModelUrls(String.valueOf(modelPath.toUri()))
.optProgress(new ProgressBar())
.optDevice(device) // Remove this line if devices are problematic!
@@ -220,7 +221,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
printResourceCount("Resource count before prediction",
(BaseNDManager)baseManager.getParentManager());
baseManager.debugDump(2);
BlockingQueue<Predictor<Mat, Mat>> predictors = new ArrayBlockingQueue<>(nPredictors);
BlockingQueue<Predictor<Mat, Mat[]>> predictors = new ArrayBlockingQueue<>(nPredictors);

try {
for (int i = 0; i < nPredictors; i++) {
@@ -234,8 +235,9 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
var predictionProcessor = createProcessor(predictors, inputChannels, tileDims, padToInputSize);
var outputHandler = createOutputHandler(preferredOutputClass, randomColors, boundaryThreshold);
var postProcessor = createPostProcessor();

var processor = OpenCVProcessor.builder(predictionProcessor)
var processor = new PixelProcessor.Builder<Mat, Mat, Mat[]>()
.processor(predictionProcessor)
.maskSupplier(OpenCVProcessor.createMatMaskSupplier())
.imageSupplier((parameters) -> ImageOps.buildImageDataOp(inputChannels)
.apply(parameters.getImageData(), parameters.getRegionRequest()))
.tiler(tiler)
@@ -271,6 +273,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
}
}


/**
* Check if we are requesting tiles for debugging purposes.
* When this is true, we should create objects that represent the tiles - not the objects to be detected.
@@ -280,15 +283,15 @@ private static boolean debugTiles() {
return System.getProperty("instanseg.debug.tiles", "false").strip().equalsIgnoreCase("true");
}

private static Processor<Mat, Mat, Mat> createProcessor(BlockingQueue<Predictor<Mat, Mat>> predictors,
private static Processor<Mat, Mat, Mat[]> createProcessor(BlockingQueue<Predictor<Mat, Mat[]>> predictors,
Collection<? extends ColorTransforms.ColorTransform> inputChannels,
int tileDims, boolean padToInputSize) {
if (debugTiles())
return InstanSeg::createOnes;
return new TilePredictionProcessor(predictors, inputChannels, tileDims, tileDims, padToInputSize);
}

private static Mat createOnes(Parameters<Mat, Mat> parameters) {
private static Mat[] createOnes(Parameters<Mat, Mat> parameters) {
var tileRequest = parameters.getTileRequest();
int width, height;
if (tileRequest == null) {
@@ -299,13 +302,14 @@ private static Mat createOnes(Parameters<Mat, Mat> parameters) {
width = tileRequest.getTileWidth();
height = tileRequest.getTileHeight();
}
return Mat.ones(height, width, opencv_core.CV_8UC1).asMat();
return new Mat[]{Mat.ones(height, width, opencv_core.CV_8UC1).asMat()};
}

private static OutputHandler<Mat, Mat, Mat> createOutputHandler(Class<? extends PathObject> preferredOutputClass, boolean randomColors,
private static OutputHandler<Mat, Mat, Mat[]> createOutputHandler(Class<? extends PathObject> preferredOutputClass, boolean randomColors,
int boundaryThreshold) {
if (debugTiles())
return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter());
// TODO: Reinstate this for Mat[] output (it was written for Mat output)
// if (debugTiles())
// return OutputHandler.createUnmaskedObjectOutputHandler(OpenCVProcessor.createAnnotationConverter());
return new PruneObjectOutputHandler<>(
new InstanSegOutputToObjectConverter(preferredOutputClass, randomColors), boundaryThreshold);
}
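Note on the change above: with the processor now parameterized as Processor<Mat, Mat, Mat[]>, each tile prediction can return more than one Mat. A minimal illustrative sketch, not part of this PR (the 256 x 256 size and the second element are invented, and the same imports as InstanSeg.java are assumed):

Processor<Mat, Mat, Mat[]> debugProcessor = parameters -> new Mat[]{
        Mat.ones(256, 256, opencv_core.CV_8UC1).asMat(),   // element 0: label image, as in createOnes above
        Mat.zeros(1, 3, opencv_core.CV_32FC1).asMat()      // element 1: hypothetical row of class scores
};

The first element is always the label image; any later elements are optional extra outputs (e.g. classifications), which is why createOnes now wraps its single Mat in a Mat[].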
src/main/java/qupath/ext/instanseg/core/InstanSegOutputToObjectConverter.java
@@ -14,21 +14,23 @@
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjects;
import qupath.lib.objects.PathTileObject;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.regions.ImagePlane;
import qupath.lib.roi.GeometryTools;
import qupath.lib.roi.interfaces.ROI;
import qupath.opencv.tools.OpenCVTools;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;

class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter<Mat, Mat, Mat> {
class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter<Mat, Mat, Mat[]> {

private static final Logger logger = LoggerFactory.getLogger(InstanSegOutputToObjectConverter.class);

@@ -46,18 +48,19 @@ class InstanSegOutputToObjectConverter implements OutputHandler.OutputToObjectCo
}

@Override
public List<PathObject> convertToObjects(Parameters<Mat, Mat> params, Mat output) {
public List<PathObject> convertToObjects(Parameters<Mat, Mat> params, Mat[] output) {
if (output == null) {
return List.of();
}
int nChannels = output.channels();
var matLabels = output[0];
int nChannels = matLabels.channels();
if (nChannels < 1 || nChannels > 2)
throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels);


List<Map<Number, ROI>> roiMaps = new ArrayList<>();
ImagePlane plane = params.getRegionRequest().getImagePlane();
for (var mat : OpenCVTools.splitChannels(output)) {
for (var mat : OpenCVTools.splitChannels(matLabels)) {
var image = OpenCVTools.matToSimpleImage(mat, 0);
var geoms = ContourTracing.createGeometries(image, params.getRegionRequest(), 1, -1);
roiMaps.add(geoms.entrySet().stream()
@@ -68,6 +71,17 @@ public List<PathObject> convertToObjects(Parameters<Mat, Mat> params, Mat output
);
}

// If we have two outputs, the second may give classifications - arrange by row
Map<Number, double[]> classifications = new HashMap<>();
if (output.length > 1) {
var matClass = output[1];
int nRows = matClass.rows();
for (int r = 0; r < nRows; r++) {
double[] doubles = OpenCVTools.extractDoubles(matClass.row(r));
classifications.put(r+1, doubles);
}
}

// We reverse the order because the smaller output (e.g. nucleus) comes before the larger output (e.g. cell)
// and we want to iterate in the opposite order. If this changes (or becomes inconsistent) we may need to
// sum pixels or areas.
@@ -83,26 +97,31 @@ public List<PathObject> convertToObjects(Parameters<Mat, Mat> params, Mat output
Map<Number, ROI> childROIs = roiMaps.size() >= 2 ? roiMaps.get(1) : Collections.emptyMap();
pathObjects = parentROIs.entrySet().stream().map(entry -> {
var parent = entry.getValue();
var child = childROIs.getOrDefault(entry.getKey(), null);
return PathObjects.createCellObject(parent, child);
var label = entry.getKey();
var child = childROIs.getOrDefault(label, null);
var cell = PathObjects.createCellObject(parent, child);
assignClassificationsIfAvailable(cell, classifications.getOrDefault(label, null));
return cell;
}).toList();
} else {
Function<ROI, PathObject> createObjectFun = createObjectFun(preferredObjectClass);
pathObjects = new ArrayList<>();
Map<Number, ROI> parentMap = roiMaps.getFirst();
List<Map<Number, ROI>> childMaps = roiMaps.size() == 1 ? Collections.emptyList() : roiMaps.subList(1, roiMaps.size());
for (var entry : parentMap.entrySet()) {
var label = entry.getKey();
var roi = entry.getValue();
var pathObject = createObjectFun.apply(roi);
if (roiMaps.size() > 1) {
for (var subMap : childMaps) {
var childROI = subMap.get(entry.getKey());
var childROI = subMap.get(label);
if (childROI != null) {
var childObject = createObjectFun.apply(childROI);
pathObject.addChildObject(childObject);
}
}
}
assignClassificationsIfAvailable(pathObject, classifications.getOrDefault(label, null));
pathObjects.add(pathObject);
}
}
@@ -115,6 +134,25 @@ public List<PathObject> convertToObjects(Parameters<Mat, Mat> params, Mat output
return pathObjects;
}

private static void assignClassificationsIfAvailable(PathObject pathObject, double[] values) {
if (values == null)
return;
try (var ml = pathObject.getMeasurementList()) {
int maxInd = 0;
double maxVal = values[0];
for (int i = 0; i < values.length; i++) {
double val = values[i];
if (val > maxVal) {
maxVal = val;
maxInd = i;
}
ml.put("Prediction " + i, val);
}
pathObject.setPathClass(PathClass.fromString("Class " + maxInd));
}
}


/**
* Assign a random color to a PathObject and all descendants, returning the object.
* @param pathObject
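To make the row convention concrete, a small worked sketch (the values, the 2-object/3-class shape, and the FloatPointer construction are illustrative, not from the PR; org.bytedeco.javacpp.FloatPointer is assumed in addition to the imports above):

// Hypothetical second output: 2 detected objects x 3 class scores; row r corresponds to label r + 1
Mat matClass = new Mat(2, 3, opencv_core.CV_32FC1, new FloatPointer(
        0.1f, 0.7f, 0.2f,    // label 1: argmax at index 1 -> PathClass "Class 1"
        0.8f, 0.1f, 0.1f));  // label 2: argmax at index 0 -> PathClass "Class 0"
double[] scoresForLabel1 = OpenCVTools.extractDoubles(matClass.row(0));   // {0.1, 0.7, 0.2}

assignClassificationsIfAvailable then stores each score as a "Prediction i" measurement and sets the class from the argmax, as shown above.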
14 changes: 9 additions & 5 deletions src/main/java/qupath/ext/instanseg/core/MatTranslator.java
@@ -7,10 +7,8 @@
import org.bytedeco.opencv.opencv_core.Mat;
import qupath.ext.djl.DjlTools;

import java.util.Arrays;


class MatTranslator implements Translator<Mat, Mat> {
class MatTranslator implements Translator<Mat, Mat[]> {

private final String inputLayoutNd;
private final String outputLayoutNd;
@@ -59,9 +57,15 @@ public NDList processInput(TranslatorContext ctx, Mat input) {
}

@Override
public Mat processOutput(TranslatorContext ctx, NDList list) {
public Mat[] processOutput(TranslatorContext ctx, NDList list) {
var array = list.getFirst();
return DjlTools.ndArrayToMat(array, outputLayoutNd);
var labels = DjlTools.ndArrayToMat(array, outputLayoutNd);
var output = new Mat[list.size()];
output[0] = labels;
for (int i = 1; i < list.size(); i++) {
output[i] = DjlTools.ndArrayToMat(list.get(i), "HW");
}
return output;
}

}
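A hedged usage sketch of the updated translator (the helper name predictLabels is invented, DJL's Predictor and TranslateException imports are assumed, and the idea that auxiliary outputs hold per-object scores follows the converter above):

static Mat predictLabels(Predictor<Mat, Mat[]> predictor, Mat input) throws TranslateException {
    Mat[] outputs = predictor.predict(input);   // element 0: label image in the configured output layout
    if (outputs.length > 1) {
        Mat extra = outputs[1];                 // later NDArrays come back as 2D ("HW") Mats, e.g. class scores
        // handled downstream by InstanSegOutputToObjectConverter
    }
    return outputs[0];
}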
src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java
@@ -32,11 +32,11 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

class TilePredictionProcessor implements Processor<Mat, Mat, Mat> {
class TilePredictionProcessor implements Processor<Mat, Mat, Mat[]> {

private static final Logger logger = LoggerFactory.getLogger(TilePredictionProcessor.class);

private final BlockingQueue<Predictor<Mat, Mat>> predictors;
private final BlockingQueue<Predictor<Mat, Mat[]>> predictors;

private final int inputWidth;
private final int inputHeight;
@@ -59,7 +59,7 @@ class TilePredictionProcessor implements Processor<Mat, Mat, Mat> {
*/
private final Map<ROI, ImageOp> normalization = Collections.synchronizedMap(new WeakHashMap<>());

TilePredictionProcessor(BlockingQueue<Predictor<Mat, Mat>> predictors,
TilePredictionProcessor(BlockingQueue<Predictor<Mat, Mat[]>> predictors,
Collection<? extends ColorTransforms.ColorTransform> channels,
int inputWidth, int inputHeight, boolean doPadding) {
this.predictors = predictors;
@@ -108,7 +108,7 @@ public boolean wasInterrupted() {
}

@Override
public Mat process(Parameters<Mat, Mat> params) throws IOException {
public Mat[] process(Parameters<Mat, Mat> params) throws IOException {

var mat = params.getImage();

@@ -136,7 +136,7 @@ public Mat process(Parameters<Mat, Mat> params) throws IOException {
mat = mat2;
}

Predictor<Mat, Mat> predictor = null;
Predictor<Mat, Mat[]> predictor = null;
try {
predictor = predictors.take();
logger.debug("Predicting tile {}", mat);
@@ -148,9 +148,12 @@ public Mat process(Parameters<Mat, Mat> params) throws IOException {
OpenCVTools.matToImagePlus("Output " + params.getRegionRequest(), matOutput).show();
}

matOutput.convertTo(matOutput, opencv_core.CV_32S);
// Handle the first output (labels)
// There may be other outputs (classifications, features), but we don't handle those here
matOutput[0].convertTo(matOutput[0], opencv_core.CV_32S);
if (padding != null)
matOutput = OpenCVTools.crop(matOutput, padding);
matOutput[0] = OpenCVTools.crop(matOutput[0], padding);

return matOutput;
} catch (TranslateException e) {
nTilesFailed.incrementAndGet();