Skip to content

Commit

Permalink
Normalisation and tiling (#14)
Browse files Browse the repository at this point in the history
* Use global thumbnail normalisation (use the thumbnail if nResolutions > 1)
* Use more complex tiling logic:
  - Split ROI into overlapping tiles
  - Run model on tiles
  - Remove objects within N pixels of tile margin (except those on the edge of the image)
  - Merge objects if they have an IoU >= 0.2
  • Loading branch information
alanocallaghan authored May 2, 2024
1 parent 5fe05d1 commit aa881f9
Show file tree
Hide file tree
Showing 9 changed files with 333 additions and 99 deletions.
15 changes: 10 additions & 5 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ base {
ext.qupathVersion = gradle.ext.qupathVersion

// Should be Java 17 for QuPath v0.5.0
ext.qupathJavaVersion = 17
ext.qupathJavaVersion = 21

/**
* Define dependencies.
Expand All @@ -35,10 +35,12 @@ ext.qupathJavaVersion = 17
* but shouldn't be bundled up for use in the extension.
*/
dependencies {

// Main QuPath user interface jar.
// Automatically includes other QuPath jars as subdependencies.
shadow "io.github.qupath:qupath-gui-fx:${qupathVersion}"

// used for co-development, should generally be left out
// implementation project(":qupath-gui-fx")

// For logging - the version comes from QuPath's version catalog at
// https://github.com/qupath/qupath/blob/main/gradle/libs.versions.toml
Expand All @@ -47,9 +49,9 @@ dependencies {
shadow libs.slf4j

// If you aren't using Groovy, this can be removed
shadow libs.deepJavaLibrary
shadow libs.qupath.fxtras
shadow libs.bioimageio.spec
shadow libs.deepJavaLibrary

implementation 'io.github.qupath:qupath-extension-djl:0.3.0'

Expand Down Expand Up @@ -150,8 +152,11 @@ tasks.named('test') {
// but helps overcome some gradle trouble when including this as a subproject
// within QuPath itself (which is useful during development).
repositories {
// Add this if you need access to dependencies only installed locally
// mavenLocal()

if (findProperty("use-maven-local")) {
logger.warn("Using Maven local")
mavenLocal()
}

mavenCentral()

Expand Down
3 changes: 1 addition & 2 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pluginManagement {
}

rootProject.name = 'qupath-extension-instanseg'
gradle.ext.qupathVersion = "0.5.0"
gradle.ext.qupathVersion = "0.6.0"

dependencyResolutionManagement {

Expand All @@ -21,7 +21,6 @@ dependencyResolutionManagement {
}

repositories {

mavenCentral()

// Add scijava - which is where QuPath's jars are hosted
Expand Down
80 changes: 48 additions & 32 deletions src/main/java/qupath/ext/instanseg/core/InstanSegTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
import org.slf4j.LoggerFactory;
import qupath.fx.dialogs.Dialogs;
import qupath.lib.experimental.pixels.OpenCVProcessor;
import qupath.lib.experimental.pixels.OutputHandler;
import qupath.lib.images.servers.ColorTransforms;
import qupath.lib.images.servers.PixelType;
import qupath.lib.objects.utils.ObjectMerger;
import qupath.lib.objects.utils.Tiler;
import qupath.lib.scripting.QP;
import qupath.opencv.ops.ImageOps;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
Expand Down Expand Up @@ -57,16 +58,20 @@ private static void printResourceCount(String title, BaseNDManager manager) {

@Override
protected Void call() throws Exception {
logger.info("Using $nThreads threads");
int nPredictors = 1;

// TODO: Set path! (unsure what path this comment refers to, so not removing...)
var imageData = QP.getCurrentImageData();

int inputWidth = tileSize;
// int inputWidth = 256;
int inputHeight = inputWidth;
int padding = 16;

int padding = 80; // todo: setting? or just based on tile size. Should discuss.
int boundary = 25;
if (tileSize == 128) {
padding = 50;
boundary = 20;
}
// Optionally pad images to the required size
boolean padToInputSize = true;
String layout = "CHW";
Expand All @@ -76,6 +81,7 @@ protected Void call() throws Exception {

var device = Device.fromName(deviceName);


try (var model = Criteria.builder()
.setTypes(Mat.class, Mat.class)
.optModelUrls(String.valueOf(modelPath))
Expand All @@ -86,46 +92,56 @@ protected Void call() throws Exception {
.loadModel()) {

BaseNDManager baseManager = (BaseNDManager)model.getNDManager();

printResourceCount("Resource count before prediction", (BaseNDManager)baseManager.getParentManager());
printResourceCount("Resource count before prediction",
(BaseNDManager)baseManager.getParentManager());
baseManager.debugDump(2);

BlockingQueue<Predictor<Mat, Mat>> predictors = new ArrayBlockingQueue<>(nPredictors);

try {
for (int i = 0; i < nPredictors; i++)
for (int i = 0; i < nPredictors; i++) {
predictors.put(model.newPredictor());
}

printResourceCount("Resource count after creating predictors", (BaseNDManager)baseManager.getParentManager());

var preprocessing = ImageOps.Core.sequential(
ImageOps.Core.ensureType(PixelType.FLOAT32),
ImageOps.Normalize.percentile(1, 99, true, 1e-6)
);
var predictionProcessor = new TilePredictionProcessor(predictors, baseManager,
layout, layoutOutput, preprocessing, inputWidth, inputHeight, padToInputSize);
var processor = OpenCVProcessor.builder(predictionProcessor)
.imageSupplier((parameters) -> ImageOps.buildImageDataOp(channels).apply(parameters.getImageData(), parameters.getRegionRequest()))
.tiler(Tiler.builder((int)(downsample * inputWidth-padding*2), (int)(downsample * inputHeight-padding*2))
.alignTopLeft()
.cropTiles(false)
.build()
)
.outputHandler(OutputHandler.createObjectOutputHandler(new OutputToObjectConvert()))
.padding(padding)
.mergeSharedBoundaries(0.25)
.downsample(downsample)
.build();
var runner = createTaskRunner(nThreads);
processor.processObjects(runner, imageData, QP.getSelectedObjects());
printResourceCount("Resource count after creating predictors",
(BaseNDManager)baseManager.getParentManager());


for (var object: QP.getSelectedObjects()) {
var norm = ImageOps.Normalize.percentile(1, 99);

if (imageData.isFluorescence()) {
norm = InstanSegUtils.getNormalization(imageData, object, channels);
}
var preprocessing = ImageOps.Core.sequential(
ImageOps.Core.ensureType(PixelType.FLOAT32),
norm,
ImageOps.Core.clip(-0.5, 1.5)
);

var predictionProcessor = new TilePredictionProcessor(predictors, baseManager,
layout, layoutOutput, preprocessing, inputWidth, inputHeight, padToInputSize);
var processor = OpenCVProcessor.builder(predictionProcessor)
.imageSupplier((parameters) -> ImageOps.buildImageDataOp(channels).apply(parameters.getImageData(), parameters.getRegionRequest()))
.tiler(Tiler.builder((int)(downsample * inputWidth), (int)(downsample * (inputHeight)))
.alignCenter()
.cropTiles(false)
.build()
)
.outputHandler(new OutputToObjectConverter.PruneObjectOutputHandler<>(new OutputToObjectConverter(), boundary))
.padding(padding)
.merger(ObjectMerger.createIoUMerger(0.2))
.downsample(downsample)
.build();
var runner = createTaskRunner(nThreads);
processor.processObjects(runner, imageData, Collections.singleton(object));
}
} finally {
for (var predictor: predictors) {
predictor.close();
}
}
printResourceCount("Resource count after prediction", (BaseNDManager)baseManager.getParentManager());
} catch (ModelNotFoundException | MalformedModelException |
IOException | InterruptedException ex) {
} catch (ModelNotFoundException | MalformedModelException | IOException ex) {
Dialogs.showErrorMessage("Unable to run InstanSeg", ex);
logger.error("Unable to run InstanSeg", ex);
}
Expand Down
73 changes: 71 additions & 2 deletions src/main/java/qupath/ext/instanseg/core/InstanSegUtils.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package qupath.ext.instanseg.core;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.experimental.pixels.MeasurementProcessor;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ColorTransforms;
import qupath.lib.objects.PathObject;
import qupath.lib.regions.RegionRequest;
import qupath.opencv.ops.ImageOp;
import qupath.opencv.ops.ImageOps;


import java.awt.image.BufferedImage;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public class InstanSegUtils {
private static final Logger logger = LoggerFactory.getLogger(InstanSegUtils.class);
Expand All @@ -15,6 +22,68 @@ private InstanSegUtils() {
throw new AssertionError("Do not instantiate this class");
}

/**
 * Try to fetch percentile normalisation factors from the image, using a
 * large downsample if the input pathObject is large. Uses the
 * bounding box of the pathObject so hopefully allows comparable output
 * to the same image through InstanSeg in Python as a full image.
 *
 * @param imageData ImageData for the current image.
 * @param pathObject The object that we'll be doing segmentation in.
 * @param channels The channels/color transforms that the segmentation
 *                 will be restricted to.
 * @return Percentile-based normalisation based on the bounding box,
 *         or default tile-based percentile normalisation if that fails.
 */
static ImageOp getNormalization(ImageData<BufferedImage> imageData, PathObject pathObject, List<ColorTransforms.ColorTransform> channels) {
    // Fallback: per-tile percentile normalisation, used if anything below fails.
    var defaults = ImageOps.Normalize.percentile(1, 99, true, 1e-6);
    try {
        // Read the bounding box of the current object.
        var roi = pathObject.getROI();
        // Rough pixel count of the bounding box across all requested channels,
        // used to choose a downsample that keeps the read below ~5e7 pixels.
        double nPix = roi.getBoundsWidth() * roi.getBoundsHeight() * channels.size();

        BufferedImage image;
        if (imageData.getServer().nResolutions() > 1) {
            // Pyramidal image: use the default thumbnail, which should already
            // be a manageable size. NOTE(review): assumes the thumbnail
            // intensities are representative of the full image — confirm.
            image = imageData.getServer().getDefaultThumbnail(0, 0);
        } else {
            // Single-resolution image: downsample the ROI bounding box so the
            // region read stays around or below ~5e7 pixels.
            double downsample = Math.max(nPix / 5e7, 1);
            var request = RegionRequest.createInstance(imageData.getServerPath(), downsample, roi);
            image = imageData.getServer().readRegion(request);
        }

        // Small stabiliser keeping the scale finite when hi == lo.
        // (The original code had a branch for eps == 0.0 that was unreachable,
        // since eps is a non-zero constant; it has been removed.)
        double eps = 1e-6;
        var params = channels.stream().map(colorTransform -> {
            float[] fpix = colorTransform.extractChannel(imageData.getServer(), image, null);
            // Widen to double for the percentile functions.
            double[] pixels = new double[fpix.length];
            for (int j = 0; j < pixels.length; j++) {
                pixels[j] = fpix[j];
            }
            var lo = MeasurementProcessor.Functions.percentile(1).apply(pixels);
            var hi = MeasurementProcessor.Functions.percentile(99).apply(pixels);
            if (hi == lo) {
                // Degenerate channel (constant within the 1-99 percentile range);
                // eps keeps the scale finite, though very large.
                logger.warn("Normalization percentiles give the same value ({}); scale will be 1/{}", lo, eps);
            }
            double scale = 1.0 / (hi - lo + eps);
            double offset = -lo * scale;
            return new double[]{offset, scale};
        }).toList();

        // Per-channel affine transform: (pixel * scale) + offset,
        // mapping the [lo, hi] percentile range approximately onto [0, 1].
        return ImageOps.Core.sequential(
                ImageOps.Core.multiply(params.stream().mapToDouble(e -> e[1]).toArray()),
                ImageOps.Core.add(params.stream().mapToDouble(e -> e[0]).toArray())
        );

    } catch (Exception e) {
        // The failure may come from the thumbnail fetch, the region read,
        // or the percentile computation — log accordingly and fall back.
        logger.error("Error computing normalization from image; using default tile normalization", e);
    }
    return defaults;
}

public static boolean isValidModel(Path path) {
// return path.toString().endsWith(".pt"); // if just looking at pt files
if (Files.isDirectory(path)) {
Expand Down
41 changes: 0 additions & 41 deletions src/main/java/qupath/ext/instanseg/core/OutputToObjectConvert.java

This file was deleted.

Loading

0 comments on commit aa881f9

Please sign in to comment.