Skip to content

Commit

Permalink
Merge branch 'main' into measurements
Browse files Browse the repository at this point in the history
  • Loading branch information
alanocallaghan committed Aug 5, 2024
2 parents 239954d + 2178f9f commit 024c684
Show file tree
Hide file tree
Showing 8 changed files with 878 additions and 293 deletions.
409 changes: 409 additions & 0 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java

Large diffs are not rendered by default.

52 changes: 33 additions & 19 deletions src/main/java/qupath/ext/instanseg/core/InstanSegModel.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package qupath.ext.instanseg.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.util.ProgressBar;
import com.google.gson.internal.LinkedTreeMap;
import org.bytedeco.opencv.opencv_core.Mat;
Expand All @@ -31,6 +29,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
Expand All @@ -50,15 +49,32 @@ private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) {
this.name = model.getName();
}

public InstanSegModel(URL modelURL, String name) {

private InstanSegModel(URL modelURL, String name) {
this.modelURL = modelURL;
this.name = name;
}

public static InstanSegModel createModel(Path path) throws IOException {
/**
* Create an InstanSeg model from an existing path.
* @param path The path to the folder that contains the model .pt file and the config YAML file.
* @return A handle on the model that can be used for inference.
* @throws IOException If the directory can't be found or isn't a valid model directory.
*/
public static InstanSegModel fromPath(Path path) throws IOException {
return new InstanSegModel(BioimageIoSpec.parseModel(path.toFile()));
}

/**
* Request an InstanSeg model from the set of available models
* @param name The model name
* @return The specified model.
*/
public static InstanSegModel fromName(String name) {
// todo: instantiate built-in models somehow
throw new UnsupportedOperationException("Fetching models by name is not yet implemented!");
}

public BioimageIoSpec.BioimageIoModel getModel() {
if (model == null) {
try {
Expand Down Expand Up @@ -128,34 +144,31 @@ public String toString() {
return getName();
}

public void runInstanSeg(
Collection<PathObject> pathObjects,
void runInstanSeg(
ImageData<BufferedImage> imageData,
Collection<PathObject> pathObjects,
Collection<ColorTransforms.ColorTransform> channels,
int tileSize,
int tileDims,
double downsample,
String deviceName,
int padding,
int boundary,
Device device,
boolean nucleiOnly,
TaskRunner taskRunner) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException {
List<Class<? extends PathObject>> outputClasses,
TaskRunner taskRunner) {

nFailed = 0;
Path modelPath = getPath().resolve("instanseg.pt");
int nPredictors = 1; // todo: change me?

int padding = 40; // todo: setting? or just based on tile size. Should discuss.
int boundary = 20;
if (tileSize == 128) {
padding = 25;
boundary = 15;
}

// Optionally pad images to the required size
boolean padToInputSize = true;
String layout = "CHW";

// TODO: Remove C if not needed (added for instanseg_v0_2_0.pt) - still relevant?
String layoutOutput = "CHW";

var device = Device.fromName(deviceName);

try (var model = Criteria.builder()
.setTypes(Mat.class, Mat.class)
Expand All @@ -166,6 +179,7 @@ public void runInstanSeg(
.build()
.loadModel()) {


BaseNDManager baseManager = (BaseNDManager)model.getNDManager();
printResourceCount("Resource count before prediction",
(BaseNDManager)baseManager.getParentManager());
Expand All @@ -180,17 +194,17 @@ public void runInstanSeg(
printResourceCount("Resource count after creating predictors",
(BaseNDManager)baseManager.getParentManager());

int sizeWithoutPadding = (int) Math.ceil(downsample * (tileSize - (double) padding));
int sizeWithoutPadding = (int) Math.ceil(downsample * (tileDims - (double) padding));
var predictionProcessor = new TilePredictionProcessor(predictors, baseManager,
layout, layoutOutput, channels, tileSize, tileSize, padToInputSize);
layout, layoutOutput, channels, tileDims, tileDims, padToInputSize);
var processor = OpenCVProcessor.builder(predictionProcessor)
.imageSupplier((parameters) -> ImageOps.buildImageDataOp(channels).apply(parameters.getImageData(), parameters.getRegionRequest()))
.tiler(Tiler.builder(sizeWithoutPadding)
.alignCenter()
.cropTiles(false)
.build()
)
.outputHandler(new OutputToObjectConverter.PruneObjectOutputHandler<>(new OutputToObjectConverter(), boundary))
.outputHandler(new PruneObjectOutputHandler<>(new InstansegOutputToObjectConverter(outputClasses), boundary))
.padding(padding)
.merger(ObjectMerger.createIoUMerger(0.2))
.downsample(downsample)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package qupath.ext.instanseg.core;

import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.analysis.images.ContourTracing;
import qupath.lib.common.ColorTools;
import qupath.lib.experimental.pixels.OutputHandler;
import qupath.lib.experimental.pixels.Parameters;
import qupath.lib.objects.PathAnnotationObject;
import qupath.lib.objects.PathCellObject;
import qupath.lib.objects.PathDetectionObject;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjects;
import qupath.lib.roi.interfaces.ROI;
import qupath.opencv.tools.OpenCVTools;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

class InstansegOutputToObjectConverter implements OutputHandler.OutputToObjectConverter<Mat, Mat, Mat> {
private static final Logger logger = LoggerFactory.getLogger(InstansegOutputToObjectConverter.class);

private static final long seed = 1243;
private final List<Class<? extends PathObject>> classes;

InstansegOutputToObjectConverter(List<Class<? extends PathObject>> outputClasses) {
this.classes = outputClasses;
}

@Override
public List<PathObject> convertToObjects(Parameters<Mat, Mat> params, Mat output) {
if (output == null) {
return List.of();
}
int nChannels = output.channels();
if (nChannels < 1 || nChannels > 2)
throw new IllegalArgumentException("Expected 1 or 2 channels, but found " + nChannels);


List<Map<Number, ROI>> roiMaps = new ArrayList<>();
for (var mat : OpenCVTools.splitChannels(output)) {
var image = OpenCVTools.matToSimpleImage(mat, 0);
roiMaps.add(
ContourTracing.createROIs(image, params.getRegionRequest(), 1, -1)
);
}
var rng = new Random(seed);

BiFunction<ROI, ROI, PathObject> function;
if (classes.size() == 1) {
function = getOneClassBiFunction(classes, rng);
} else {
// if of length 2, then can be:
// detection <- annotation, annotation <- annotation, detection <- detection
assert classes.size() == 2;
function = getTwoClassBiFunction(classes, rng);
}

if (roiMaps.size() == 1) {
// One-channel detected, represent using detection objects
return roiMaps.get(0).values().stream()
.map(roi -> function.apply(roi, null))
.collect(Collectors.toList());
} else {
// Two channels detected, represent using cell objects
// We assume that the labels are matched - and we can't have a nucleus without a cell
Map<Number, ROI> childROIs = roiMaps.get(0);
Map<Number, ROI> parentROIs = roiMaps.get(1);
List<PathObject> cells = new ArrayList<>();
for (var entry : parentROIs.entrySet()) {
var parent = entry.getValue();
var child = childROIs.getOrDefault(entry.getKey(), null);
var outputObject = function.apply(parent, child);
cells.add(outputObject);
}
return cells;
}
}

private static BiFunction<ROI, ROI, PathObject> getOneClassBiFunction(List<Class<? extends PathObject>> classes, Random rng) {
// if of length 1, can be
// cellObject (with or without nucleus), annotations, detections
if (classes.get(0) == PathAnnotationObject.class) {
return createObjectsFun(PathObjects::createAnnotationObject, PathObjects::createAnnotationObject, rng);
} else if (classes.get(0) == PathDetectionObject.class) {
return createObjectsFun(PathObjects::createDetectionObject, PathObjects::createDetectionObject, rng);
} else if (classes.get(0) == PathCellObject.class) {
return createCellFun(rng);
} else {
logger.warn("Unknown output {}, defaulting to cells", classes.get(0));
return createCellFun(rng);
}
}

private static BiFunction<ROI, ROI, PathObject> getTwoClassBiFunction(List<Class<? extends PathObject>> classes, Random rng) {
Function<ROI, PathObject> fun0, fun1;
var knownClasses = List.of(PathDetectionObject.class, PathAnnotationObject.class);
if (!knownClasses.contains(classes.get(0)) || !knownClasses.contains(classes.get(1))) {
logger.warn("Unknown combination of outputs {} <- {}, defaulting to cells", classes.get(0), classes.get(1));
return createCellFun(rng);
}
if (classes.get(0) == PathDetectionObject.class) {
fun0 = PathObjects::createDetectionObject;
} else {
fun0 = PathObjects::createAnnotationObject;
}
if (classes.get(1) == PathDetectionObject.class) {
fun1 = PathObjects::createDetectionObject;
} else {
fun1 = PathObjects::createAnnotationObject;
}
return createObjectsFun(fun0, fun1, rng);
}

private static BiFunction<ROI, ROI, PathObject> createCellFun(Random rng) {
return (parent, child) -> {
var cell = PathObjects.createCellObject(parent, child);
var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255));
cell.setColor(color);
return cell;
};
}

private static BiFunction<ROI, ROI, PathObject> createObjectsFun(Function<ROI, PathObject> createParentFun, Function<ROI, PathObject> createChildFun, Random rng) {
return (parent, child) -> {
var parentObj = createParentFun.apply(parent);
var color = ColorTools.packRGB(rng.nextInt(255), rng.nextInt(255), rng.nextInt(255));
parentObj.setColor(color);
if (child != null) {
var childObj = createChildFun.apply(child);
childObj.setColor(color);
parentObj.addChildObject(childObj);
}
return parentObj;
};
}

}
Loading

0 comments on commit 024c684

Please sign in to comment.