Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass optional arguments #121

Merged
merged 8 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ai.djl.ndarray.BaseNDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import java.util.HashMap;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
Expand Down Expand Up @@ -37,6 +38,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
Expand All @@ -57,11 +59,19 @@ public class InstanSeg {
private final Device device;
private final TaskRunner taskRunner;
private final Class<? extends PathObject> preferredOutputClass;
private final Map<String, Object> optionalArgs;

// This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently.
// However we might want to reinstate it, possibly as a proportion of the padding amount.
private final int boundaryThreshold = 1;

/**
alanocallaghan marked this conversation as resolved.
Show resolved Hide resolved
* Run inference for the currently selected PathObjects in the current image.
*/
public InstanSegResults detectObjects() {
return detectObjects(QP.getCurrentImageData());
}

private InstanSeg(Builder builder) {
this.tileDims = builder.tileDims;
this.downsample = builder.downsample; // Optional... and not advised (use the model spec instead); set <= 0 to ignore
Expand All @@ -74,13 +84,7 @@ private InstanSeg(Builder builder) {
this.preferredOutputClass = builder.preferredOutputClass;
this.randomColors = builder.randomColors;
this.makeMeasurements = builder.makeMeasurements;
}

/**
* Run inference for the currently selected PathObjects in the current image.
*/
public InstanSegResults detectObjects() {
return detectObjects(QP.getCurrentImageData());
this.optionalArgs = builder.optionalArgs;
alanocallaghan marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down Expand Up @@ -215,7 +219,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
.optModelUrls(String.valueOf(modelPath.toUri()))
.optProgress(new ProgressBar())
.optDevice(device) // Remove this line if devices are problematic!
.optTranslator(new MatTranslator(layout, layoutOutput, outputChannelArray))
.optTranslator(new MatTranslator(layout, layoutOutput, outputChannelArray, optionalArgs))
.build()
.loadModel()) {

Expand Down Expand Up @@ -392,6 +396,7 @@ public static final class Builder {
private Collection<? extends ColorTransforms.ColorTransform> channels;
private InstanSegModel model;
private Class<? extends PathObject> preferredOutputClass;
private final Map<String, Object> optionalArgs = new HashMap<>();

Builder() {}

Expand Down Expand Up @@ -653,6 +658,27 @@ public Builder outputAnnotations() {
return this;
}

/**
* Set a number of optional arguments
* @param optionalArgs The argument names and values.
* @return A modified builder.
*/
public Builder args(Map<String, Object> optionalArgs) {
alanocallaghan marked this conversation as resolved.
Show resolved Hide resolved
this.optionalArgs.putAll(optionalArgs);
return this;
}

/**
* Set a number of optional arguments
* @param name The argument name.
* @param value The argument value.
* @return A modified builder.
*/
public Builder arg(String name, Object value) {
optionalArgs.put(name, value);
return this;
}

/**
* Request to make measurements from the objects created by InstanSeg.
* @return this builder
Expand Down
25 changes: 24 additions & 1 deletion src/main/java/qupath/ext/instanseg/core/MatTranslator.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package qupath.ext.instanseg.core;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.bytedeco.opencv.opencv_core.Mat;
import qupath.ext.djl.DjlTools;

Expand All @@ -15,6 +21,7 @@ class MatTranslator implements Translator<Mat, Mat> {
private final String inputLayoutNd;
private final String outputLayoutNd;
private final int[] outputChannels;
private final Map<String, Object> optionalArgs;

/**
* Create a translator from InstanSeg input to output.
Expand All @@ -23,10 +30,11 @@ class MatTranslator implements Translator<Mat, Mat> {
* @param outputChannels Array of channels to output; if null or empty, output all channels.
* Values should be true for channels to output, false for channels to ignore.
*/
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels) {
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels, Map<String, Object> optionalArgs) {
this.inputLayoutNd = inputLayoutNd;
this.outputLayoutNd = outputLayoutNd;
this.outputChannels = convertBooleanArray(outputChannels);
this.optionalArgs = optionalArgs;
alanocallaghan marked this conversation as resolved.
Show resolved Hide resolved
}

private static int[] convertBooleanArray(boolean[] array) {
Expand Down Expand Up @@ -55,9 +63,24 @@ public NDList processInput(TranslatorContext ctx, Mat input) {
var arrayCPU = array.toDevice(Device.cpu(), false);
out.add(arrayCPU);
}
List<NDArray> args = sanitizeOptionalArgs(optionalArgs, manager);
out.addAll(args);
return out;
}

private static List<NDArray> sanitizeOptionalArgs(Map<String, Object> optionalArgs, NDManager manager) {
List<NDArray> arrays = new ArrayList<>();
for (var es : optionalArgs.entrySet()) {
var val = es.getValue();
if (val instanceof Double || val instanceof BigDecimal) {
alanocallaghan marked this conversation as resolved.
Show resolved Hide resolved
NDArray array = manager.create(((Number) val).floatValue());
array.setName("args." + es.getKey());
arrays.add(array);
}
}
return arrays;
}

@Override
public Mat processOutput(TranslatorContext ctx, NDList list) {
var array = list.getFirst();
Expand Down
Loading