Skip to content

Commit

Permalink
Use ? as type parameter and properly handle more values
Browse files Browse the repository at this point in the history
  • Loading branch information
alanocallaghan committed Jan 8, 2025
1 parent 39f5f0e commit 1bda28c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ public static final class Builder {
private Collection<? extends ColorTransforms.ColorTransform> channels;
private InstanSegModel model;
private Class<? extends PathObject> preferredOutputClass;
private Map<String, ?> optionalArgs;
private Map<String, Object> optionalArgs;

Builder() {}

Expand Down Expand Up @@ -664,7 +664,7 @@ public Builder outputAnnotations() {
* @return A modified builder.
*/
public Builder args(Map<String, ?> optionalArgs) {
this.optionalArgs = (optionalArgs);
this.optionalArgs.putAll(optionalArgs);
return this;
}

Expand Down
37 changes: 29 additions & 8 deletions src/main/java/qupath/ext/instanseg/core/MatTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.math.BigDecimal;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.djl.DjlTools;

import java.util.Arrays;


class MatTranslator implements Translator<Mat, Mat> {
private static final Logger logger = LoggerFactory.getLogger(MatTranslator.class);

private final String inputLayoutNd;
private final String outputLayoutNd;
private final int[] outputChannels;
private final Map<String, Object> optionalArgs;
private final Map<String, ?> optionalArgs;

/**
* Create a translator from InstanSeg input to output.
Expand All @@ -30,7 +31,7 @@ class MatTranslator implements Translator<Mat, Mat> {
* @param outputChannels Array of channels to output; if null or empty, output all channels.
* Values should be true for channels to output, false for channels to ignore.
*/
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels, Map<String, Object> optionalArgs) {
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels, Map<String, ?> optionalArgs) {
this.inputLayoutNd = inputLayoutNd;
this.outputLayoutNd = outputLayoutNd;
this.outputChannels = convertBooleanArray(outputChannels);
Expand Down Expand Up @@ -68,12 +69,32 @@ public NDList processInput(TranslatorContext ctx, Mat input) {
return out;
}

private static List<NDArray> sanitizeOptionalArgs(Map<String, Object> optionalArgs, NDManager manager) {
private static List<NDArray> sanitizeOptionalArgs(Map<String, ?> optionalArgs, NDManager manager) {
List<NDArray> arrays = new ArrayList<>();
for (var es : optionalArgs.entrySet()) {
var val = es.getValue();
if (val instanceof Double || val instanceof BigDecimal) {
NDArray array = manager.create(((Number) val).floatValue());
NDArray array = null;
switch (val) {
case NDArray ndarray -> array = ndarray;
case String s -> array = manager.create(s);
case Boolean bool -> array = manager.create(bool);
case Byte b -> array = manager.create(b);
case Integer integer -> array = manager.create(integer);
case Long l -> array = manager.create(l);
case Number num ->
// Default to float for all other numbers
// (not double, which would fail with MPS)
array = manager.create(num.floatValue());
case boolean[] arr -> array = manager.create(arr);
case byte[] arr -> array = manager.create(arr);
case int[] arr -> array = manager.create(arr);
case long[] arr -> array = manager.create(arr);
case float[] arr -> array = manager.create(arr);
case null, default ->
logger.warn("Unsupported optional argument: name={}, type={}",
es.getKey(), val == null ? "null" : val.getClass());
}
if (array != null) {
array.setName("args." + es.getKey());
arrays.add(array);
}
Expand Down

0 comments on commit 1bda28c

Please sign in to comment.