Skip to content

Commit

Permalink
Add run API
Browse files Browse the repository at this point in the history
Add run API using a builder pattern.
Part of this enables specification of the output classes,
enabling users to output nested annotations, detections, or
just cells (with or without nuclei).
Enables the flexible selection of input channels with indices or names.

Also mildly restructure code around the output converter/output handler.
  • Loading branch information
alanocallaghan authored Aug 5, 2024
2 parents f73b755 + f612894 commit 2178f9f
Show file tree
Hide file tree
Showing 8 changed files with 873 additions and 289 deletions.
409 changes: 409 additions & 0 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java

Large diffs are not rendered by default.

53 changes: 33 additions & 20 deletions src/main/java/qupath/ext/instanseg/core/InstanSegModel.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package qupath.ext.instanseg.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.util.ProgressBar;
import com.google.gson.internal.LinkedTreeMap;
import org.bytedeco.opencv.opencv_core.Mat;
Expand Down Expand Up @@ -51,15 +49,32 @@ private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) {
this.name = model.getName();
}

public InstanSegModel(URL modelURL, String name) {

private InstanSegModel(URL modelURL, String name) {
this.modelURL = modelURL;
this.name = name;
}

public static InstanSegModel createModel(Path path) throws IOException {
/**
* Create an InstanSeg model from an existing path.
* @param path The path to the folder that contains the model .pt file and the config YAML file.
* @return A handle on the model that can be used for inference.
* @throws IOException If the directory can't be found or isn't a valid model directory.
*/
public static InstanSegModel fromPath(Path path) throws IOException {
return new InstanSegModel(BioimageIoSpec.parseModel(path.toFile()));
}

/**
* Request an InstanSeg model from the set of available models
* @param name The model name
* @return The specified model.
*/
public static InstanSegModel fromName(String name) {
// todo: instantiate built-in models somehow
throw new UnsupportedOperationException("Fetching models by name is not yet implemented!");
}

public BioimageIoSpec.BioimageIoModel getModel() {
if (model == null) {
try {
Expand Down Expand Up @@ -129,34 +144,31 @@ public String toString() {
return getName();
}

public void runInstanSeg(
Collection<PathObject> pathObjects,
void runInstanSeg(
ImageData<BufferedImage> imageData,
List<ColorTransforms.ColorTransform> channels,
int tileSize,
Collection<PathObject> pathObjects,
Collection<ColorTransforms.ColorTransform> channels,
int tileDims,
double downsample,
String deviceName,
int padding,
int boundary,
Device device,
boolean nucleiOnly,
TaskRunner taskRunner) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException {
List<Class<? extends PathObject>> outputClasses,
TaskRunner taskRunner) {

nFailed = 0;
Path modelPath = getPath().resolve("instanseg.pt");
int nPredictors = 1; // todo: change me?

int padding = 40; // todo: setting? or just based on tile size. Should discuss.
int boundary = 20;
if (tileSize == 128) {
padding = 25;
boundary = 15;
}

// Optionally pad images to the required size
boolean padToInputSize = true;
String layout = "CHW";

// TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant?
String layoutOutput = "CHW";

var device = Device.fromName(deviceName);

try (var model = Criteria.builder()
.setTypes(Mat.class, Mat.class)
Expand All @@ -167,6 +179,7 @@ public void runInstanSeg(
.build()
.loadModel()) {


BaseNDManager baseManager = (BaseNDManager)model.getNDManager();
printResourceCount("Resource count before prediction",
(BaseNDManager)baseManager.getParentManager());
Expand All @@ -181,17 +194,17 @@ public void runInstanSeg(
printResourceCount("Resource count after creating predictors",
(BaseNDManager)baseManager.getParentManager());

int sizeWithoutPadding = (int) Math.ceil(downsample * (tileSize - (double) padding));
int sizeWithoutPadding = (int) Math.ceil(downsample * (tileDims - (double) padding));
var predictionProcessor = new TilePredictionProcessor(predictors, baseManager,
layout, layoutOutput, channels, tileSize, tileSize, padToInputSize);
layout, layoutOutput, channels, tileDims, tileDims, padToInputSize);
var processor = OpenCVProcessor.builder(predictionProcessor)
.imageSupplier((parameters) -> ImageOps.buildImageDataOp(channels).apply(parameters.getImageData(), parameters.getRegionRequest()))
.tiler(Tiler.builder(sizeWithoutPadding)
.alignCenter()
.cropTiles(false)
.build()
)
.outputHandler(new OutputToObjectConverter.PruneObjectOutputHandler<>(new OutputToObjectConverter(), boundary))
.outputHandler(new PruneObjectOutputHandler<>(new InstansegOutputToObjectConverter(outputClasses), boundary))
.padding(padding)
.merger(ObjectMerger.createIoUMerger(0.2))
.downsample(downsample)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package qupath.ext.instanseg.core;

import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.analysis.images.ContourTracing;
import qupath.lib.common.ColorTools;
import qupath.lib.experimental.pixels.OutputHandler;
import qupath.lib.experimental.pixels.Parameters;
import qupath.lib.objects.PathAnnotationObject;
import qupath.lib.objects.PathCellObject;
import qupath.lib.objects.PathDetectionObject;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjects;
import qupath.lib.roi.interfaces.ROI;
import qupath.opencv.tools.OpenCVTools;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

class InstansegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter<Mat, Mat, Mat> {
private static final Logger logger = LoggerFactory.getLogger(InstansegOutputToObjectConverter.class);

private static final long seed = 1243;
private final List<Class<? extends PathObject>> classes;

InstansegOutputToObjectConverter(List<Class<? extends PathObject>> outputClasses) {
this.classes = outputClasses;
}

@Override
public List<PathObject> convertToObjects(Parameters<Mat, Mat> params, Mat output) {
if (output == null) {
return List.of();
}
int nChannels = output.channels();
if (nChannels < 1 || nChannels > 2)
throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels);


List<Map<Number, ROI>> roiMaps = new ArrayList<>();
for (var mat : OpenCVTools.splitChannels(output)) {
var image = OpenCVTools.matToSimpleImage(mat, 0);
roiMaps.add(
ContourTracing.createROIs(image, params.getRegionRequest(), 1, -1)
);
}
var rng = new Random(seed);

BiFunction<ROI, ROI, PathObject> function;
if (classes.size() == 1) {
function = getOneClassBiFunction(classes, rng);
} else {
// if of length 2, then can be:
// detection <- annotation, annotation <- annotation, detection <- detection
assert classes.size() == 2;
function = getTwoClassBiFunction(classes, rng);
}

if (roiMaps.size() == 1) {
// One-channel detected, represent using detection objects
return roiMaps.get(0).values().stream()
.map(roi -> function.apply(roi, null))
.collect(Collectors.toList());
} else {
// Two channels detected, represent using cell objects
// We assume that the labels are matched - and we can't have a nucleus without a cell
Map<Number, ROI> childROIs = roiMaps.get(0);
Map<Number, ROI> parentROIs = roiMaps.get(1);
List<PathObject> cells = new ArrayList<>();
for (var entry : parentROIs.entrySet()) {
var parent = entry.getValue();
var child = childROIs.getOrDefault(entry.getKey(), null);
var outputObject = function.apply(parent, child);
cells.add(outputObject);
}
return cells;
}
}

private static BiFunction<ROI, ROI, PathObject> getOneClassBiFunction(List<Class<? extends PathObject>> classes, Random rng) {
// if of length 1, can be
// cellObject (with or without nucleus), annotations, detections
if (classes.get(0) == PathAnnotationObject.class) {
return createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng);
} else if (classes.get(0) == PathDetectionObject.class) {
return createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng);
} else if (classes.get(0) == PathCellObject.class) {
return createCellFun(rng);
} else {
logger.warn("Unknown output {}, defaulting to cells", classes.get(0));
return createCellFun(rng);
}
}

private static BiFunction<ROI, ROI, PathObject> getTwoClassBiFunction(List<Class<? extends PathObject>> classes, Random rng) {
Function<ROI, PathObject> fun0, fun1;
var knownClasses = List.of(PathDetectionObject.class, PathAnnotationObject.class);
if (!knownClasses.contains(classes.get(0)) || !knownClasses.contains(classes.get(1))) {
logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1));
return createCellFun(rng);
}
if (classes.get(0) == PathDetectionObject.class) {
fun0 = PathObjects::createDetectionObject;
} else {
fun0 = PathObjects::createAnnotationObject;
}
if (classes.get(1) == PathDetectionObject.class) {
fun1 = PathObjects::createDetectionObject;
} else {
fun1 = PathObjects::createAnnotationObject;
}
return createObjectsFun(fun0, fun1, rng);
}

private static BiFunction<ROI, ROI, PathObject> createCellFun(Random rng) {
return (parent, child) -> {
var cell = PathObjects.createCellObject(parent, child);
var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255));
cell.setColor(color);
return cell;
};
}

private static BiFunction<ROI, ROI, PathObject> createObjectsFun(Function<ROI, PathObject> createParentFun, Function<ROI, PathObject> createChildFun, Random rng) {
return (parent, child) -> {
var parentObj = createParentFun.apply(parent);
var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255));
parentObj.setColor(color);
if (child != null) {
var childObj = createChildFun.apply(child);
childObj.setColor(color);
parentObj.addChildObject(childObj);
}
return parentObj;
};
}

}
Loading

0 comments on commit 2178f9f

Please sign in to comment.