Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass optional arguments #121

Merged
merged 8 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
Expand All @@ -57,11 +59,13 @@ public class InstanSeg {
private final Device device;
private final TaskRunner taskRunner;
private final Class<? extends PathObject> preferredOutputClass;
private final Map<String, Object> optionalArgs = new LinkedHashMap<>();

// This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently.
// However we might want to reinstate it, possibly as a proportion of the padding amount.
// However, we might want to reinstate it, possibly as a proportion of the padding amount.
private final int boundaryThreshold = 1;


private InstanSeg(Builder builder) {
this.tileDims = builder.tileDims;
this.downsample = builder.downsample; // Optional... and not advised (use the model spec instead); set <= 0 to ignore
Expand All @@ -74,6 +78,7 @@ private InstanSeg(Builder builder) {
this.preferredOutputClass = builder.preferredOutputClass;
this.randomColors = builder.randomColors;
this.makeMeasurements = builder.makeMeasurements;
this.optionalArgs.putAll(builder.optionalArgs);
}

/**
Expand Down Expand Up @@ -215,7 +220,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
.optModelUrls(String.valueOf(modelPath.toUri()))
.optProgress(new ProgressBar())
.optDevice(device) // Remove this line if devices are problematic!
.optTranslator(new MatTranslator(layout, layoutOutput, outputChannelArray))
.optTranslator(new MatTranslator(layout, layoutOutput, outputChannelArray, optionalArgs))
.build()
.loadModel()) {

Expand Down Expand Up @@ -392,6 +397,7 @@ public static final class Builder {
private Collection<? extends ColorTransforms.ColorTransform> channels;
private InstanSegModel model;
private Class<? extends PathObject> preferredOutputClass;
private Map<String, Object> optionalArgs;

Builder() {}

Expand Down Expand Up @@ -653,6 +659,17 @@ public Builder outputAnnotations() {
return this;
}

/**
* Set a number of optional arguments
* @param optionalArgs The argument names and values.
* @return A modified builder.
*/
public Builder args(Map<String, ?> optionalArgs) {
this.optionalArgs.putAll(optionalArgs);
return this;
}


/**
* Request to make measurements from the objects created by InstanSeg.
* @return this builder
Expand Down
51 changes: 48 additions & 3 deletions src/main/java/qupath/ext/instanseg/core/MatTranslator.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
package qupath.ext.instanseg.core;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.djl.DjlTools;

import java.util.Arrays;


class MatTranslator implements Translator<Mat, Mat> {
private static final Logger logger = LoggerFactory.getLogger(MatTranslator.class);

private final String inputLayoutNd;
private final String outputLayoutNd;
private final int[] outputChannels;
private final Map<String, Object> optionalArgs = new HashMap<>();

/**
* Create a translator from InstanSeg input to output.
Expand All @@ -23,10 +32,11 @@ class MatTranslator implements Translator<Mat, Mat> {
* @param outputChannels Array of channels to output; if null or empty, output all channels.
* Values should be true for channels to output, false for channels to ignore.
*/
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels) {
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels, Map<String, ?> optionalArgs) {
this.inputLayoutNd = inputLayoutNd;
this.outputLayoutNd = outputLayoutNd;
this.outputChannels = convertBooleanArray(outputChannels);
this.optionalArgs.putAll(optionalArgs);
}

private static int[] convertBooleanArray(boolean[] array) {
Expand Down Expand Up @@ -55,9 +65,44 @@ public NDList processInput(TranslatorContext ctx, Mat input) {
var arrayCPU = array.toDevice(Device.cpu(), false);
out.add(arrayCPU);
}
List<NDArray> args = sanitizeOptionalArgs(optionalArgs, manager);
out.addAll(args);
return out;
}

private static List<NDArray> sanitizeOptionalArgs(Map<String, ?> optionalArgs, NDManager manager) {
List<NDArray> arrays = new ArrayList<>();
for (var es : optionalArgs.entrySet()) {
var val = es.getValue();
NDArray array = null;
switch (val) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so that's how you do it... wanted to use a switch originally, but didn't know the syntax :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can thank IntelliJ's linting, I just alt-enter stuff and keep it if I think it's better

case NDArray ndarray -> array = ndarray;
case String s -> array = manager.create(s);
case Boolean bool -> array = manager.create(bool);
case Byte b -> array = manager.create(b);
case Integer integer -> array = manager.create(integer);
case Long l -> array = manager.create(l);
case Number num ->
// Default to float for all other numbers
// (not double, which would fail with MPS)
array = manager.create(num.floatValue());
case boolean[] arr -> array = manager.create(arr);
case byte[] arr -> array = manager.create(arr);
case int[] arr -> array = manager.create(arr);
case long[] arr -> array = manager.create(arr);
case float[] arr -> array = manager.create(arr);
case null, default ->
logger.warn("Unsupported optional argument: name={}, type={}",
es.getKey(), val == null ? "null" : val.getClass());
}
if (array != null) {
array.setName("args." + es.getKey());
arrays.add(array);
}
}
return arrays;
}

@Override
public Mat processOutput(TranslatorContext ctx, NDList list) {
var array = list.getFirst();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,19 +481,21 @@ private void configureRunning() {
* This may be called when the selected model is changed, or an existing model is downloaded.
*/
private void refreshModelChoice() {
var modelDir = InstanSegUtils.getModelDirectory().orElse(null);
if (modelDir == null)
return;

var model = selectedModel.get();
if (model == null)
return;

var modelDir = InstanSegUtils.getModelDirectory().orElse(null);
try {
model.checkIfDownloaded(modelDir.resolve("downloaded"), false);
} catch (IOException e) {
logger.debug("Error checking zip or RDF file(s); this shouldn't happen", e);
Dialogs.showErrorNotification(resources.getString("title"), resources.getString("error.checkingModel"));
}
boolean isDownloaded = modelDir != null && model.isValid();
if (!isDownloaded || qupath.getImageData() == null) {
if (!model.isValid() || qupath.getImageData() == null) {
return;
}
var numChannels = model.getNumChannels();
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/qupath/ext/instanseg/ui/Watcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ private void handleModelPathsChanged(ListChangeListener.Change<? extends Path> c
try {
set.add(InstanSegModel.fromPath(modelPath));
} catch (IOException e) {
logger.error("Unable to load model from path", e);
logger.error("Unable to load model from {}", modelPath, e);
}
}
models.setAll(set);
Expand Down
Loading