diff --git a/.github/workflows/build-main.yml b/.github/workflows/build-main.yml index 5ef5692..f09ae4e 100644 --- a/.github/workflows/build-main.yml +++ b/.github/workflows/build-main.yml @@ -13,11 +13,15 @@ jobs: steps: - uses: actions/checkout@v2 + + - name: Install Blosc + run: sudo apt install -y libblosc1 + - name: Set up Java uses: actions/setup-java@v3 with: - java-version: '8' - distribution: 'zulu' + java-version: '21' + distribution: 'temurin' cache: 'maven' - name: Set up CI environment run: .github/setup.sh diff --git a/.github/workflows/build-pr.yml b/.github/workflows/build-pr.yml index 925b576..ac4686d 100644 --- a/.github/workflows/build-pr.yml +++ b/.github/workflows/build-pr.yml @@ -11,11 +11,15 @@ jobs: steps: - uses: actions/checkout@v2 + + - name: Install Blosc + run: sudo apt install -y libblosc1 + - name: Set up Java uses: actions/setup-java@v3 with: - java-version: '8' - distribution: 'zulu' + java-version: '21' + distribution: 'temurin' cache: 'maven' - name: Set up CI environment run: .github/setup.sh diff --git a/pom.xml b/pom.xml index acb3fe2..a12e8ad 100644 --- a/pom.xml +++ b/pom.xml @@ -4,12 +4,12 @@ org.scijava pom-scijava - 31.1.0 + 37.0.0 org.janelia.saalfeldlab label-utilities-spark - 0.9.5-SNAPSHOT + 1.0.0-SNAPSHOT N5-Label-Multisets-Spark Spark based tools for label data. @@ -18,7 +18,7 @@ Saalfeld Lab - http://saalfeldlab.janelia.org/ + https://saalfeldlab.janelia.org/ @@ -30,7 +30,7 @@ ImageJ Forum - http://image.sc/ + https://image.sc/ @@ -38,7 +38,7 @@ hanslovsky Philipp Hanslovsky - http://imagej.net/User:Hanslovsky + https://imagej.net/User:Hanslovsky founder lead @@ -46,8 +46,19 @@ debugger reviewer support + + + + hulbertc + Caleb Hulbert + hulbertc@janelia.hhmi.org + HHMI Janelia + https://janelia.org/ + + developer maintainer + -5 @@ -83,7 +94,30 @@ sign,deploy-to-scijava org.janelia.saalfeldlab.label.spark - 1.0.0-beta-13 + 21 + 3.13.0 + true + ${javadoc.skip} + + 3.2.0 + 2.2.0 + 4.1.0 + 4.1.2 + 1.3.3 + 7.0.0 + 1.5.0 + 0.13.2 + 1.7.36 + + 0.5.1 + 0.3.2 + 0.1.2 + + 3.5.1 + ${spark.version} + 2.15.4 + + true @@ -97,27 +131,26 @@ org.janelia.saalfeldlab label-utilities - 0.5.0 + ${label-utilities.version} org.janelia.saalfeldlab label-utilities-n5 - 0.3.1 + ${label-utilities-n5.version} org.janelia.saalfeldlab imglib2-mutex-watershed - 0.1.2 + ${imglib2-mutex-watershed.version} org.apache.spark spark-core_2.12 - 3.2.1 + ${spark-core_2.12.version} net.imglib2 imglib2 - 5.13.0 net.imglib2 @@ -129,18 +162,7 @@ org.janelia.saalfeldlab - n5 - 2.2.0 - - - org.janelia.saalfeldlab - n5-imglib2 - 4.3.0 - - - org.janelia.saalfeldlab - n5-hdf5 - 1.0.4 + n5-universe net.imglib2 @@ -162,13 +184,18 @@ throwing-function 1.5.1 + + + + ch.qos.logback + logback-core + - com.fasterxml.jackson.core - jackson-databind - 2.12.7.1 + ch.qos.logback + logback-classic - + junit junit @@ -177,6 +204,29 @@ + + + + org.apache.maven.plugins + maven-compiler-plugin + + + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED + + + + + + fatWithSpark @@ -185,7 +235,6 @@ org.apache.maven.plugins maven-shade-plugin - 3.2.4 @@ -244,7 +293,6 @@ org.apache.maven.plugins maven-shade-plugin - 3.2.4 diff --git a/src/.editorconfig b/src/.editorconfig index 046838e..b374b26 100644 --- a/src/.editorconfig +++ b/src/.editorconfig @@ -1,4 +1,4 @@ [*] -indent_style = tab -tab_width = unset +indent_style = tab +tab_width = unset max_line_length = 140 diff --git 
a/src/main/java/org/janelia/saalfeldlab/label/spark/LabelTools.java b/src/main/java/org/janelia/saalfeldlab/label/spark/LabelTools.java index 85a5b5d..ebc9847 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/LabelTools.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/LabelTools.java @@ -1,9 +1,5 @@ package org.janelia.saalfeldlab.label.spark; -import java.util.Arrays; -import java.util.Optional; -import java.util.concurrent.Callable; - import org.janelia.saalfeldlab.label.spark.LabelTools.Tool.FromString; import org.janelia.saalfeldlab.label.spark.affinities.AverageAffinities; import org.janelia.saalfeldlab.label.spark.affinities.MakePredictionMask; @@ -12,7 +8,6 @@ import org.janelia.saalfeldlab.label.spark.uniquelabels.ExtractUniqueLabelsPerBlock; import org.janelia.saalfeldlab.label.spark.uniquelabels.LabelToBlockMapping; import org.janelia.saalfeldlab.label.spark.uniquelabels.downsample.LabelListDownsampler; - import org.janelia.saalfeldlab.label.spark.watersheds.SparkWatersheds; import picocli.CommandLine; import picocli.CommandLine.Command; @@ -20,84 +15,82 @@ import picocli.CommandLine.Option; import picocli.CommandLine.Parameters; -public class LabelTools -{ - - public enum Tool - { - CONVERT( ConvertToLabelMultisetType::run ), - DOWNSAMPLE( SparkDownsampler::run ), - EXTRACT_UNIQUE_LABELS( ExtractUniqueLabelsPerBlock::run ), - DOWNSAMPLE_UNIQUE_LABELS( LabelListDownsampler::run ), - LABEL_TO_BLOCK_MAPPING( LabelToBlockMapping::run ), - WATERSHEDS( SparkWatersheds::run ), - MAKE_PREDICTION_MASK( MakePredictionMask::run ), - AVERAGE_AFFINITIES( AverageAffinities::run ); - - private interface ExceptionConsumer< T > - { - public void accept( T t ) throws Exception; +import java.util.Arrays; +import java.util.Optional; +import java.util.concurrent.Callable; + +public class LabelTools { + + public enum Tool { + CONVERT(ConvertToLabelMultisetType::run), + DOWNSAMPLE(SparkDownsampler::run), + EXTRACT_UNIQUE_LABELS(ExtractUniqueLabelsPerBlock::run), + DOWNSAMPLE_UNIQUE_LABELS(LabelListDownsampler::run), + LABEL_TO_BLOCK_MAPPING(LabelToBlockMapping::run), + WATERSHEDS(SparkWatersheds::run), + MAKE_PREDICTION_MASK(MakePredictionMask::run), + AVERAGE_AFFINITIES(AverageAffinities::run); + + private interface ExceptionConsumer<T> { + public void accept(T t) throws Exception; } - private final ExceptionConsumer< String[] > run; + private final ExceptionConsumer<String[]> run; + + private Tool(final ExceptionConsumer<String[]> run) { - private Tool( final ExceptionConsumer< String[] > run ) - { this.run = run; } - public String getCmdLineRepresentation() - { + public String getCmdLineRepresentation() { + return this.name().toLowerCase(); } - public static Tool fromCmdLineRepresentation( final String representation ) - { - return Tool.valueOf( representation.replace( "-", "_" ).toUpperCase() ); + public static Tool fromCmdLineRepresentation(final String representation) { + + return Tool.valueOf(representation.replace("-", "_").toUpperCase()); } - public static class FromString implements ITypeConverter< Tool > - { + public static class FromString implements ITypeConverter<Tool> { @Override - public Tool convert( final String str ) throws Exception - { - return Tool.fromCmdLineRepresentation( str ); + public Tool convert(final String str) throws Exception { + + return Tool.fromCmdLineRepresentation(str); } } } - @Command( name = "label-tools" ) - public static class CommandLineParameters implements Callable< Boolean > - { + @Command(name = "label-tools") + public static class CommandLineParameters 
implements Callable<Boolean> { @Parameters( index = "0", paramLabel = "TOOL", converter = FromString.class, - description = "Tool to run. Run multiset-tools --help/-h for specific help message. Current options are convert, downsample, extract-unique-labels, downsample-unique-labels, label-to-block-mapping" ) + description = "Tool to run. Run multiset-tools --help/-h for specific help message. Current options are convert, downsample, extract-unique-labels, downsample-unique-labels, label-to-block-mapping") private Tool tool; - @Option( names = { "-h", "--help" }, usageHelp = true, description = "display a help message" ) + @Option(names = {"-h", "--help"}, usageHelp = true, description = "display a help message") private boolean helpRequested; @Override - public Boolean call() throws Exception - { + public Boolean call() throws Exception { + return true; } } - public static void main( final String[] args ) throws Exception - { + public static void main(final String[] args) throws Exception { + final CommandLineParameters params = new CommandLineParameters(); - final Boolean paramsParsedSuccessfully = Optional.ofNullable( CommandLine.call( params, System.err, args.length > 0 ? args[ 0 ] : "--help" ) ).orElse( false ); - if ( paramsParsedSuccessfully ) - { - params.tool.run.accept( Arrays.copyOfRange( args, 1, args.length ) ); + final Boolean paramsParsedSuccessfully = Optional.ofNullable(CommandLine.call(params, System.err, args.length > 0 ? args[0] : "--help")).orElse(false); + if (paramsParsedSuccessfully) { + params.tool.run.accept(Arrays.copyOfRange(args, 1, args.length)); } } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/N5Helpers.java b/src/main/java/org/janelia/saalfeldlab/label/spark/N5Helpers.java index 7d1a82d..cfb6f21 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/N5Helpers.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/N5Helpers.java @@ -1,22 +1,19 @@ package org.janelia.saalfeldlab.label.spark; +import org.janelia.saalfeldlab.n5.N5Exception; +import org.janelia.saalfeldlab.n5.N5Reader; +import org.janelia.saalfeldlab.n5.N5Writer; +import org.janelia.saalfeldlab.n5.universe.N5Factory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.Arrays; import java.util.Optional; import java.util.regex.Pattern; -import org.janelia.saalfeldlab.n5.N5FSReader; -import org.janelia.saalfeldlab.n5.N5FSWriter; -import org.janelia.saalfeldlab.n5.N5Reader; -import org.janelia.saalfeldlab.n5.N5Writer; -import org.janelia.saalfeldlab.n5.hdf5.N5HDF5Reader; -import org.janelia.saalfeldlab.n5.hdf5.N5HDF5Writer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class N5Helpers -{ +public class N5Helpers { public static final String LABEL_MULTISETTYPE_KEY = "isLabelMultiset"; @@ -26,166 +23,153 @@ public class N5Helpers public static final String MAX_ID_KEY = "maxId"; - private static final Logger LOG = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() ); + private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + public static N5Reader n5Reader(final String base, final int... defaultCellDimensions) throws IOException { - public static N5Reader n5Reader( final String base, final int... defaultCellDimensions ) throws IOException - { - return isHDF( base ) ? 
new N5HDF5Reader( base, defaultCellDimensions ) : new N5FSReader( base ); + final var factory = new N5Factory(); + factory.hdf5DefaultBlockSize(defaultCellDimensions); + return factory.openReader(base); } - public static N5Writer n5Writer( final String base, final int... defaultCellDimensions ) throws IOException - { - return isHDF( base ) ? new N5HDF5Writer( base, defaultCellDimensions ) : new N5FSWriter( base ); + public static N5Writer n5Writer(final String base, final int... defaultCellDimensions) throws IOException { + + final var factory = new N5Factory(); + factory.hdf5DefaultBlockSize(defaultCellDimensions); + return factory.openWriter(base); } - public static boolean isHDF( final String base ) - { - LOG.debug( "Checking {} for HDF", base ); - final boolean isHDF = Pattern.matches( "^h5://", base ) || Pattern.matches( "^.*\\.(hdf|h5)$", base ); - LOG.debug( "{} is hdf5? {}", base, isHDF ); + public static boolean isHDF(final String base) { + + LOG.debug("Checking {} for HDF", base); + final boolean isHDF = Pattern.matches("^h5://", base) || Pattern.matches("^.*\\.(hdf|h5)$", base); + LOG.debug("{} is hdf5? {}", base, isHDF); return isHDF; } - public static long[] blockPos( final long[] position, final int[] blockSize ) - { - final long[] blockPos = new long[ position.length ]; - Arrays.setAll( blockPos, d -> position[ d ] / blockSize[ d ] ); + public static long[] blockPos(final long[] position, final int[] blockSize) { + + final long[] blockPos = new long[position.length]; + Arrays.setAll(blockPos, d -> position[d] / blockSize[d]); return blockPos; } - public static boolean isMultiScale( final N5Reader reader, final String dataset ) throws IOException - { - return Optional.ofNullable( reader.getAttribute( dataset, MULTISCALE_KEY, Boolean.class ) ).orElse( false ); + public static boolean isMultiScale(final N5Reader reader, final String dataset) throws IOException { + + return Optional.ofNullable(reader.getAttribute(dataset, MULTISCALE_KEY, Boolean.class)).orElse(false); } - public static String[] listScaleDatasets( final N5Reader n5, final String group ) throws IOException - { + public static String[] listScaleDatasets(final N5Reader n5, final String group) throws IOException { + final String[] scaleDirs = Arrays - .stream( n5.list( group ) ) - .filter( s -> s.matches( "^s\\d+$" ) ) - .filter( s -> { - try - { - return n5.datasetExists( group + "/" + s ); - } - catch ( final IOException e ) - { + .stream(n5.list(group)) + .filter(s -> s.matches("^s\\d+$")) + .filter(s -> { + try { + return n5.datasetExists(group + "/" + s); + } catch (final N5Exception e) { return false; } - } ) - .toArray( String[]::new ); + }) + .toArray(String[]::new); - LOG.debug( "Found these scale dirs: {}", Arrays.toString( scaleDirs ) ); + LOG.debug("Found these scale dirs: {}", Arrays.toString(scaleDirs)); return scaleDirs; } - public static String[] listAndSortScaleDatasets( final N5Reader n5, final String group ) throws IOException - { - final String[] scaleDirs = listScaleDatasets( n5, group ); - sortScaleDatasets( scaleDirs ); + public static String[] listAndSortScaleDatasets(final N5Reader n5, final String group) throws IOException { - LOG.debug( "Sorted scale dirs: {}", Arrays.toString( scaleDirs ) ); + final String[] scaleDirs = listScaleDatasets(n5, group); + sortScaleDatasets(scaleDirs); + + LOG.debug("Sorted scale dirs: {}", Arrays.toString(scaleDirs)); return scaleDirs; } - public static void sortScaleDatasets( final String[] scaleDatasets ) - { - Arrays.sort( scaleDatasets, ( f1, f2 ) -> { + 
public static void sortScaleDatasets(final String[] scaleDatasets) { + + Arrays.sort(scaleDatasets, (f1, f2) -> { return Integer.compare( - Integer.parseInt( f1.replaceAll( "[^\\d]", "" ) ), - Integer.parseInt( f2.replaceAll( "[^\\d]", "" ) ) ); - } ); + Integer.parseInt(f1.replaceAll("[^\\d]", "")), + Integer.parseInt(f2.replaceAll("[^\\d]", ""))); + }); } - public static < T > T revertInplaceAndReturn( final T t, final boolean revert ) - { - if ( !revert ) { return t; } - - if ( t instanceof boolean[] ) - { - final boolean[] arr = ( boolean[] ) t; - for ( int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k ) - { - final boolean v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + public static <T> T reverseInplaceAndReturn(final T t, final boolean reverse) { + + if (!reverse) { + return t; + } + + if (t instanceof boolean[]) { + final boolean[] arr = (boolean[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final boolean v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } - if ( t instanceof byte[] ) - { - final byte[] arr = ( byte[] ) t; - for ( int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k ) - { - final byte v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + if (t instanceof byte[]) { + final byte[] arr = (byte[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final byte v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } - if ( t instanceof char[] ) - { - final char[] arr = ( char[] ) t; - for ( int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k ) - { - final char v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + if (t instanceof char[]) { + final char[] arr = (char[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final char v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } - if ( t instanceof short[] ) - { - final short[] arr = ( short[] ) t; - for ( int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k ) - { - final short v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + if (t instanceof short[]) { + final short[] arr = (short[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final short v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } - if ( t instanceof int[] ) - { - final int[] arr = ( int[] ) t; - for ( int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k ) - { - final int v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + if (t instanceof int[]) { + final int[] arr = (int[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final int v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } - if ( t instanceof long[] ) - { - final long[] arr = ( long[] ) t; - for ( int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k ) - { - final long v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + if (t instanceof long[]) { + final long[] arr = (long[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final long v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } - if ( t instanceof float[] ) - { - final float[] arr = ( float[] ) t; - for ( int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k ) - { - final float v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + if (t instanceof float[]) { + final float[] arr = (float[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final float v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } - if ( t instanceof double[] ) - { - final double[] arr = ( double[] ) t; - for ( int i = 0, k = arr.length - 1; i < 
arr.length / 2; ++i, --k ) - { - final double v = arr[ i ]; - arr[ i ] = arr[ k ]; - arr[ k ] = v; + if (t instanceof double[]) { + final double[] arr = (double[])t; + for (int i = 0, k = arr.length - 1; i < arr.length / 2; ++i, --k) { + final double v = arr[i]; + arr[i] = arr[k]; + arr[k] = v; } } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/AverageAffinities.java b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/AverageAffinities.java index 4ab32e5..a2d5699 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/AverageAffinities.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/AverageAffinities.java @@ -122,7 +122,7 @@ private static class Args implements Callable { int[] blockSize = null; @Expose - @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split=",") + @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split = ",") int[] blocksPerTask = {1, 1, 1}; @CommandLine.Option(names = "--json-pretty-print", defaultValue = "true") @@ -175,14 +175,15 @@ public Void call() throws Exception { gliaMaskThreshold = gliaMaskThreshold >= gliaMaskMax ? Double.NEGATIVE_INFINITY : gliaMaskThreshold <= gliaMaskMin - ? Double.POSITIVE_INFINITY - : gliaMaskMax - gliaMaskThreshold; + ? Double.POSITIVE_INFINITY + : gliaMaskMax - gliaMaskThreshold; System.out.println("Inverse glia mask threshold is " + gliaMaskThreshold); return null; } public Offset[] enumeratedOffsets() { + final Offset[] enumeratedOffsets = new Offset[this.offsets.length]; for (int i = 0; i < offsets.length; ++i) { final Offset o = this.offsets[i]; @@ -195,6 +196,7 @@ public Offset[] enumeratedOffsets() { } public static void main(String[] argv) throws IOException { + run(argv); } @@ -203,8 +205,6 @@ public static void run(String[] argv) throws IOException { final Args args = new Args(); CommandLine.call(args, argv); - - final N5WriterSupplier n5InSupplier = new N5WriterSupplier(args.inputContainer, false, false); final DatasetAttributes inputAttributes = n5InSupplier.get().getDatasetAttributes(args.affinities); final N5WriterSupplier n5OutSupplier = new N5WriterSupplier(args.outputContainer, args.prettyPrint, args.disbaleHtmlEscape); @@ -261,160 +261,161 @@ private static void run( LOG.info("Parallelizing over blocks {}", blocks); - Map, Boolean> returnCodes = sc .parallelize(blocks) - .mapToPair(block -> new Tuple2<>(block, (RandomAccessibleInterval) N5Utils.open(n5InSupplier.get(), affinities))) + .mapToPair(block -> new Tuple2<>(block, (RandomAccessibleInterval)N5Utils.open(n5InSupplier.get(), affinities))) + .mapToPair(p -> { boolean wasSuccessful = false; - final long[] min = p._1()._1(); - final long[] max = p._1()._2(); - final RandomAccessible maskRA = maskSupplier.get(); - final RandomAccessibleInterval averagedAffinities = ArrayImgs.doubles(Intervals.dimensionsAsLongArray(new FinalInterval(min, max))); - final RandomAccessibleInterval slice1 = Views.translate(averagedAffinities, min); - final UnsignedByteType zero = new UnsignedByteType(0); - final Consumer invalidAction = nanOoobAffinities - ? 
t -> t.add(NAN) - : t -> {}; - for (final Offset offset : enumeratedOffsets) { - final RandomAccessible affs = Views.extendZero(Views.hyperSlice(p._2(), min.length, (long) offset.channelIndex())); - final IntervalView expanded1 = Views.interval(Views.extendZero(slice1), expandAsNeeded(slice1, offset.offset())); - final IntervalView expanded2 = Views.interval(Views.offset(Views.extendZero(slice1), offset.offset()), expanded1); - - LOG.debug( - "Averaging {} voxels for offset {} : [{}:{}] ({})", - Intervals.numElements(expanded1), - offset, - Intervals.minAsLongArray(expanded1), - Intervals.minAsLongArray(expanded1), - Intervals.dimensionsAsLongArray(expanded1)); - - final Cursor source = Views.flatIterable(Views.interval(Converters.convert(affs, new RealDoubleConverter<>(), new DoubleType()), expanded1)).cursor(); - final Cursor mask = Views.flatIterable(Views.interval(maskRA, expanded1)).cursor(); - final Cursor target1 = Views.flatIterable(expanded1).cursor(); - final Cursor target2 = Views.flatIterable(expanded2).cursor(); - - final StopWatch sw = StopWatch.createAndStart(); - final double[] minMax = {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}; - while (source.hasNext()) { - final DoubleType s = source.next(); - minMax[0] = Math.min(s.getRealDouble(), minMax[0]); - minMax[1] = Math.max(s.getRealDouble(), minMax[1]); - final boolean isInvalid = mask.next().valueEquals(zero); - target1.fwd(); - target2.fwd(); - - if (isInvalid) { - invalidAction.accept(target1.get()); - invalidAction.accept(target2.get()); - } else if (Double.isFinite(s.getRealDouble())) { - target1.get().add(s); - target2.get().add(s); - } + final long[] min = p._1()._1(); + final long[] max = p._1()._2(); + final RandomAccessible maskRA = maskSupplier.get(); + final RandomAccessibleInterval averagedAffinities = ArrayImgs.doubles(Intervals.dimensionsAsLongArray(new FinalInterval(min, max))); + final RandomAccessibleInterval slice1 = Views.translate(averagedAffinities, min); + final UnsignedByteType zero = new UnsignedByteType(0); + final Consumer invalidAction = nanOoobAffinities + ? 
t -> t.add(NAN) + : t -> { + }; + for (final Offset offset : enumeratedOffsets) { + final RandomAccessible affs = Views.extendZero(Views.hyperSlice(p._2(), min.length, (long)offset.channelIndex())); + final IntervalView expanded1 = Views.interval(Views.extendZero(slice1), expandAsNeeded(slice1, offset.offset())); + final IntervalView expanded2 = Views.interval(Views.offset(Views.extendZero(slice1), offset.offset()), expanded1); + + LOG.debug( + "Averaging {} voxels for offset {} : [{}:{}] ({})", + Intervals.numElements(expanded1), + offset, + Intervals.minAsLongArray(expanded1), + Intervals.minAsLongArray(expanded1), + Intervals.dimensionsAsLongArray(expanded1)); + + final Cursor source = Views.flatIterable(Views.interval(Converters.convert(affs, new RealDoubleConverter<>(), new DoubleType()), expanded1)).cursor(); + final Cursor mask = Views.flatIterable(Views.interval(maskRA, expanded1)).cursor(); + final Cursor target1 = Views.flatIterable(expanded1).cursor(); + final Cursor target2 = Views.flatIterable(expanded2).cursor(); + + final StopWatch sw = StopWatch.createAndStart(); + final double[] minMax = {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}; + while (source.hasNext()) { + final DoubleType s = source.next(); + minMax[0] = Math.min(s.getRealDouble(), minMax[0]); + minMax[1] = Math.max(s.getRealDouble(), minMax[1]); + final boolean isInvalid = mask.next().valueEquals(zero); + target1.fwd(); + target2.fwd(); + + if (isInvalid) { + invalidAction.accept(target1.get()); + invalidAction.accept(target2.get()); + } else if (Double.isFinite(s.getRealDouble())) { + target1.get().add(s); + target2.get().add(s); } - sw.stop(); - - LOG.debug( - "Averaged {} voxels for offset {} : [{}:{}] ({}) in {}s", - Intervals.numElements(expanded1), - offset, - Intervals.minAsLongArray(expanded1), - Intervals.minAsLongArray(expanded1), - Intervals.dimensionsAsLongArray(expanded1), - sw.nanoTime() * 1e-9); - LOG.debug("Min max: {}", minMax); - - // TODO LoopBuilder does not work in Spark -// LoopBuilder -// .setImages(Views.interval(Converters.convert(affs, new RealDoubleConverter<>(), new DoubleType()), expanded1), expanded1, expanded2) -// .forEachPixel((a, s1, s2) -> { -// if (Double.isFinite(a.getRealDouble())) { -// s1.add(a); -// s2.add(a); -// } -// }); } + sw.stop(); + + LOG.debug( + "Averaged {} voxels for offset {} : [{}:{}] ({}) in {}s", + Intervals.numElements(expanded1), + offset, + Intervals.minAsLongArray(expanded1), + Intervals.minAsLongArray(expanded1), + Intervals.dimensionsAsLongArray(expanded1), + sw.nanoTime() * 1e-9); + LOG.debug("Min max: {}", minMax); + + // TODO LoopBuilder does not work in Spark + // LoopBuilder + // .setImages(Views.interval(Converters.convert(affs, new RealDoubleConverter<>(), new DoubleType()), expanded1), expanded1, expanded2) + // .forEachPixel((a, s1, s2) -> { + // if (Double.isFinite(a.getRealDouble())) { + // s1.add(a); + // s2.add(a); + // } + // }); + } - // TODO combine the three loops into a single loop. Probably not that much overhead, though - // TODO only write out ROI of mask. Outside doesn't exist anyway! 
- final double factor = 0.5 / enumeratedOffsets.length; - Views.iterable(slice1).forEach(px -> px.mul(factor)); - - final RandomAccessible invertedGliaMask = invertedGliaMaskSupplier.get(); -// final IntervalView translatedSlice = Views.translate(slice1, min); - Views.interval(Views.pair(invertedGliaMask, slice1), slice1).forEach(pair -> pair.getB().mul(pair.getA().getRealDouble())); - LOG.debug("Glia mask threshold is {}", gliaMaskThreshold); - if (!Double.isNaN(gliaMaskThreshold)) { - LOG.debug("Setting values with inverted glia mask values < {} to NaN", gliaMaskThreshold); - Views.interval(Views.pair(invertedGliaMask, slice1), slice1) - .forEach(pair -> pair.getB().set(pair.getA().getRealDouble() <= gliaMaskThreshold ? Double.NaN : pair.getB().getRealDouble())); - } + // TODO combine the three loops into a single loop. Probably not that much overhead, though + // TODO only write out ROI of mask. Outside doesn't exist anyway! + final double factor = 0.5 / enumeratedOffsets.length; + Views.iterable(slice1).forEach(px -> px.mul(factor)); + + final RandomAccessible invertedGliaMask = invertedGliaMaskSupplier.get(); + // final IntervalView translatedSlice = Views.translate(slice1, min); + Views.interval(Views.pair(invertedGliaMask, slice1), slice1).forEach(pair -> pair.getB().mul(pair.getA().getRealDouble())); + LOG.debug("Glia mask threshold is {}", gliaMaskThreshold); + if (!Double.isNaN(gliaMaskThreshold)) { + LOG.debug("Setting values with inverted glia mask values < {} to NaN", gliaMaskThreshold); + Views.interval(Views.pair(invertedGliaMask, slice1), slice1) + .forEach(pair -> pair.getB().set(pair.getA().getRealDouble() <= gliaMaskThreshold ? Double.NaN : pair.getB().getRealDouble())); + } - final N5Writer n5 = n5OutSupplier.get(); - final DatasetAttributes attributes = n5.getDatasetAttributes(averaged); - final long[] gridOffset = min.clone(); - Arrays.setAll(gridOffset, d -> gridOffset[d] / attributes.getBlockSize()[d]); - - final List saveTheseBlocks = Grids.collectAllContainedIntervals(min, max, attributes.getBlockSize()); - final CellGrid grid = new CellGrid(attributes.getDimensions(), attributes.getBlockSize()); - final boolean[] success = new boolean[saveTheseBlocks.size()]; - final N5Writer n5out = n5OutSupplier.get(); - for (int attempt = 0; attempt < 4; ++attempt) { - for (int i = 0; i < saveTheseBlocks.size(); ++i) { - - if (success[i]) continue; - - final Interval saveThisBlock = saveTheseBlocks.get(i); - final long[] saveThisBlockAt = Intervals.minAsLongArray(saveThisBlock); - grid.getCellPosition(saveThisBlockAt, saveThisBlockAt); - final int[] size = Intervals.dimensionsAsIntArray(saveThisBlock); - final DataBlock block = (DataBlock) DataType.FLOAT32.createDataBlock(size, saveThisBlockAt); - final Cursor c = Views.flatIterable(Views.interval(slice1, saveThisBlock)).cursor(); - for (int k = 0; c.hasNext(); ++k) { - block.getData()[k] = c.next().getRealFloat(); - } + final N5Writer n5 = n5OutSupplier.get(); + final DatasetAttributes attributes = n5.getDatasetAttributes(averaged); + final long[] gridOffset = min.clone(); + Arrays.setAll(gridOffset, d -> gridOffset[d] / attributes.getBlockSize()[d]); + + final List saveTheseBlocks = Grids.collectAllContainedIntervals(min, max, attributes.getBlockSize()); + final CellGrid grid = new CellGrid(attributes.getDimensions(), attributes.getBlockSize()); + final boolean[] success = new boolean[saveTheseBlocks.size()]; + final N5Writer n5out = n5OutSupplier.get(); + for (int attempt = 0; attempt < 4; ++attempt) { + for (int i = 0; i < 
saveTheseBlocks.size(); ++i) { + + if (success[i]) + continue; + + final Interval saveThisBlock = saveTheseBlocks.get(i); + final long[] saveThisBlockAt = Intervals.minAsLongArray(saveThisBlock); + grid.getCellPosition(saveThisBlockAt, saveThisBlockAt); + final int[] size = Intervals.dimensionsAsIntArray(saveThisBlock); + final DataBlock block = (DataBlock)DataType.FLOAT32.createDataBlock(size, saveThisBlockAt); + final Cursor c = Views.flatIterable(Views.interval(slice1, saveThisBlock)).cursor(); + for (int k = 0; c.hasNext(); ++k) { + block.getData()[k] = c.next().getRealFloat(); + } + success[i] = false; + try { + n5out.writeBlock(averaged, attributes, block); + final DataBlock reloaded = (DataBlock)n5out.readBlock(averaged, attributes, saveThisBlockAt); + success[i] = Arrays.equals(block.getData(), reloaded.getData()); + } catch (Exception e) { success[i] = false; - try { - n5out.writeBlock(averaged, attributes, block); - final DataBlock reloaded = (DataBlock) n5out.readBlock(averaged, attributes, saveThisBlockAt); - success[i] = Arrays.equals(block.getData(), reloaded.getData()); - } catch (Exception e) { - success[i] = false; - } } } - final List failedBlocks = IntStream - .range(0, saveTheseBlocks.size()) - .filter(idx -> !success[idx]) - .mapToObj(saveTheseBlocks::get) - .collect(Collectors.toList()); + } + final List failedBlocks = IntStream + .range(0, saveTheseBlocks.size()) + .filter(idx -> !success[idx]) + .mapToObj(saveTheseBlocks::get) + .collect(Collectors.toList()); try { if (failedBlocks.size() > 0) throw new RuntimeException("Unable to save these blocks in 4 attempts: " + failedBlocks); else wasSuccessful = true; -// N5Utils.saveBlock( -// Converters.convert(slice1, new RealFloatConverter<>(), new FloatType()), -// n5OutSupplier.get(), -// averaged, -// attributes, -// gridOffset); -// wasSuccessful = true; -// final RandomAccessibleInterval reloaded = Views.interval(N5Utils.open(n5OutSupplier.get(), averaged), new FinalInterval(min, max)); -// final Cursor r = Views.flatIterable(reloaded).cursor(); -// final Cursor s = Views.flatIterable(Converters.convert(slice1, new RealFloatConverter<>(), new FloatType())).cursor(); -// while (r.hasNext() && wasSuccessful) { -// wasSuccessful = r.next().valueEquals(s.next()); -// } -// if (!wasSuccessful) -// throw new RuntimeException("Not successful for block " + Arrays.toString(min) + " " + Arrays.toString(max)); -// else -// LOG.info("Successfully saved block {}", gridOffset); + // N5Utils.saveBlock( + // Converters.convert(slice1, new RealFloatConverter<>(), new FloatType()), + // n5OutSupplier.get(), + // averaged, + // attributes, + // gridOffset); + // wasSuccessful = true; + // final RandomAccessibleInterval reloaded = Views.interval(N5Utils.open(n5OutSupplier.get(), averaged), new FinalInterval(min, max)); + // final Cursor r = Views.flatIterable(reloaded).cursor(); + // final Cursor s = Views.flatIterable(Converters.convert(slice1, new RealFloatConverter<>(), new FloatType())).cursor(); + // while (r.hasNext() && wasSuccessful) { + // wasSuccessful = r.next().valueEquals(s.next()); + // } + // if (!wasSuccessful) + // throw new RuntimeException("Not successful for block " + Arrays.toString(min) + " " + Arrays.toString(max)); + // else + // LOG.info("Successfully saved block {}", gridOffset); } catch (final Exception e) { wasSuccessful = false; -// throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e); + // throw e instanceof RuntimeException ? 
(RuntimeException) e : new RuntimeException(e); } return new Tuple2<>(p._1(), wasSuccessful); }) @@ -439,18 +440,21 @@ private static void run( } private static long[] ignoreLast(final long[] dims) { + final long[] newDims = new long[dims.length - 1]; Arrays.setAll(newDims, d -> dims[d]); return newDims; } private static Interval translate(Interval interval, final long[] translation) { + for (int d = 0; d < translation.length; ++d) interval = Intervals.translate(interval, translation[d], d); return interval; } private static long[] abs(long... array) { + final long[] abs = new long[array.length]; Arrays.setAll(abs, d -> Math.abs(array[d])); return abs; @@ -460,6 +464,7 @@ private static Interval expandAsNeeded( final Interval source, final long[] offsets ) { + final long[] min = Intervals.minAsLongArray(source); final long[] max = Intervals.maxAsLongArray(source); @@ -476,34 +481,37 @@ else if (offset > 0) } private static int[] subArray(final int[] array, int start, int stop) { + final int[] result = new int[stop - start]; Arrays.setAll(result, d -> array[d + start]); return result; } - private static class MaskSupplier implements Serializable{ + private static class MaskSupplier implements Serializable { private final N5WriterSupplier container; private final String dataset; private MaskSupplier(N5WriterSupplier container, String dataset, long[] fovDiff) { + this.container = container; this.dataset = dataset; } public RandomAccessible get() throws IOException { + RandomAccessibleInterval rai = container.get().getDatasetAttributes(dataset).getDataType() == DataType.UINT8 ? N5Utils.open(container.get(), dataset) : getAsUnsignedByteType(); // TODO this assumes zero offset in the affinities, consider affinity offset instead! - final double[] resolution = Optional.ofNullable(container.get().getAttribute(dataset, "resolution", double[].class)).orElse(new double[] {1.0, 1.0, 1.0}); - final double[] offset = Optional.ofNullable(container.get().getAttribute(dataset, "offset", double[].class)).orElse(new double[] {0.0, 0.0, 0.0}); + final double[] resolution = Optional.ofNullable(container.get().getAttribute(dataset, "resolution", double[].class)).orElse(new double[]{1.0, 1.0, 1.0}); + final double[] offset = Optional.ofNullable(container.get().getAttribute(dataset, "offset", double[].class)).orElse(new double[]{0.0, 0.0, 0.0}); final long[] longOffset = new long[3]; for (int d = 0; d < longOffset.length; ++d) { final double r = offset[d] / resolution[d]; - final long l = (long) r; + final long l = (long)r; longOffset[d] = l; assert r == l; } @@ -511,6 +519,7 @@ public RandomAccessible get() throws IOException { } private & NativeType> RandomAccessibleInterval getAsUnsignedByteType() throws IOException { + final RandomAccessibleInterval rai = N5Utils.open(container.get(), dataset); return Converters.convert(rai, (s, t) -> t.setInteger(s.getIntegerLong()), new UnsignedByteType()); } @@ -523,15 +532,18 @@ private static class ConstantValueRandomAccessibleSupplier implements Serializab private final int nDim; private ConstantValueRandomAccessibleSupplier(double value) { + this.value = value; this.nDim = 3; } public RandomAccessible get() { + return getConstantMask(); } public RandomAccessible getConstantMask() { + final FloatType ft = new FloatType(); ft.setReal(value); return ConstantUtils.constantRandomAccessible(ft, nDim); @@ -549,6 +561,7 @@ private static class GliaMaskSupplier implements Serializable, Supplier getChecked() throws IOException { + final RandomAccessibleInterval data = 
readAndConvert(container.get()); final FloatType extension = new FloatType(); extension.setReal(minBound); @@ -565,16 +579,18 @@ public RandomAccessible getChecked() throws IOException { @Override public RandomAccessible get() { + return ThrowingSupplier.unchecked(this::getChecked).get(); } private RandomAccessibleInterval getAndConvertIfNecessary(final N5Reader reader) throws IOException { + final DataType dtype = reader.getDatasetAttributes(dataset).getDataType(); switch (dtype) { - case FLOAT32: - return N5Utils.open(reader, dataset); - default: - return readAndConvert(reader); + case FLOAT32: + return N5Utils.open(reader, dataset); + default: + return readAndConvert(reader); } } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/MakePredictionMask.java b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/MakePredictionMask.java index 2f74371..32965a5 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/MakePredictionMask.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/MakePredictionMask.java @@ -1,7 +1,11 @@ package org.janelia.saalfeldlab.label.spark.affinities; import com.google.gson.annotations.Expose; -import net.imglib2.*; +import net.imglib2.FinalInterval; +import net.imglib2.FinalRealInterval; +import net.imglib2.Interval; +import net.imglib2.RandomAccessible; +import net.imglib2.RandomAccessibleInterval; import net.imglib2.algorithm.util.Grids; import net.imglib2.converter.Converters; import net.imglib2.img.array.ArrayImg; @@ -16,7 +20,11 @@ import net.imglib2.view.Views; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.janelia.saalfeldlab.n5.*; +import org.janelia.saalfeldlab.n5.DataType; +import org.janelia.saalfeldlab.n5.DatasetAttributes; +import org.janelia.saalfeldlab.n5.GzipCompression; +import org.janelia.saalfeldlab.n5.N5FSReader; +import org.janelia.saalfeldlab.n5.N5Writer; import org.janelia.saalfeldlab.n5.imglib2.N5Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,7 +64,7 @@ private static class Args implements Callable { String inputDataset = null; @Expose - @CommandLine.Option(names = "--input-dataset-size", paramLabel = "INPUT_DATASET_SIZE", description = "In voxels. One of INPUT_DATASET_SIZE and INPUT_DATASET must be specified", split=",") + @CommandLine.Option(names = "--input-dataset-size", paramLabel = "INPUT_DATASET_SIZE", description = "In voxels. One of INPUT_DATASET_SIZE and INPUT_DATASET must be specified", split = ",") long[] inputDatasetSize; @Expose @@ -64,11 +72,11 @@ private static class Args implements Callable { Boolean inputIsMask; @Expose - @CommandLine.Option(names = "--mask-dataset", paramLabel = "MASK_DATASET", description = "Path to mask dataset in mask container. Will be written into", required=true) + @CommandLine.Option(names = "--mask-dataset", paramLabel = "MASK_DATASET", description = "Path to mask dataset in mask container. Will be written into", required = true) String maskDataset = null; @Expose - @CommandLine.Option(names = "--output-dataset-size", paramLabel = "OUTPUT_DATASET_SIZE", description = "In voxels. ", required = true, split=",") + @CommandLine.Option(names = "--output-dataset-size", paramLabel = "OUTPUT_DATASET_SIZE", description = "In voxels. 
", required = true, split = ",") long[] outputDatasetSIze; @Expose @@ -99,29 +107,28 @@ private static class Args implements Callable { @CommandLine.Option(names = "--blocks-per-task", defaultValue = "1,1,1", split = ",") int[] blocksPerTask; -// @Expose -// @CommandLine.Option(names = "--raw-container", paramLabel = "RAW_CONTAINER", description = "Path to raw container. Defaults to MASK_CONTAINER.") -// String rawContainer = null; -// -// @Expose -// @CommandLine.Option(names = "prediction-container", paramLabel = "PREDICTION_CONTAINER", description = "Path to prediction container. Defaults to MASK_CONTAINER") -// String predictionContainer = null; -// -// @Expose -// @CommandLine.Option(names = "--raw-dataset", paramLabel = "RAW_DATASET", description = "Path of raw dataset dataset in RAW_CONTAINER.", defaultValue = "volumes/raw") -// String rawDataset; -// -// @Expose -// @CommandLine.Option(names = "--prediction-dataset", paramLabel = "PREDICTION_DATASET", description = "Path to prediction dataset in PREDICTION_CONTAINER", defaultValue = "volumes/affinities/predictions") -// String predictionDataset; - + // @Expose + // @CommandLine.Option(names = "--raw-container", paramLabel = "RAW_CONTAINER", description = "Path to raw container. Defaults to MASK_CONTAINER.") + // String rawContainer = null; + // + // @Expose + // @CommandLine.Option(names = "prediction-container", paramLabel = "PREDICTION_CONTAINER", description = "Path to prediction container. Defaults to MASK_CONTAINER") + // String predictionContainer = null; + // + // @Expose + // @CommandLine.Option(names = "--raw-dataset", paramLabel = "RAW_DATASET", description = "Path of raw dataset dataset in RAW_CONTAINER.", defaultValue = "volumes/raw") + // String rawDataset; + // + // @Expose + // @CommandLine.Option(names = "--prediction-dataset", paramLabel = "PREDICTION_DATASET", description = "Path to prediction dataset in PREDICTION_CONTAINER", defaultValue = "volumes/affinities/predictions") + // String predictionDataset; @Override public Void call() throws Exception { -// rawContainer = rawContainer == null ? maskContainer : rawContainer; -// -// predictionContainer = predictionContainer == null ? maskContainer : predictionContainer; + // rawContainer = rawContainer == null ? maskContainer : rawContainer; + // + // predictionContainer = predictionContainer == null ? 
maskContainer : predictionContainer; if (inputDatasetSize == null && inputDataset == null) throw new Exception("One of input dataset size or input dataset must be specified!"); @@ -129,15 +136,13 @@ public Void call() throws Exception { inputDatasetSize = new N5FSReader(inputContainer).getDatasetAttributes(inputDataset).getDimensions(); for (int d = 0; d < inputOffset.length; ++d) - if (inputOffset[d] / inputResolution[d] != (int) (inputOffset[d] / inputResolution[d])) + if (inputOffset[d] / inputResolution[d] != (int)(inputOffset[d] / inputResolution[d])) throw new Exception("Offset not integer multiple of resolution!"); for (int d = 0; d < outputOffset.length; ++d) - if (outputOffset[d] / outputResolution[d] != (int) (outputOffset[d] / outputResolution[d])) + if (outputOffset[d] / outputResolution[d] != (int)(outputOffset[d] / outputResolution[d])) throw new Exception("Offset not integer multiple of resolution!"); - - return null; } @@ -148,6 +153,7 @@ public int[] blockSize() { } public Supplier> inputMaskSupplier() { + if (inputDataset == null || !inputIsMask) { return new MaskProviderFromDims(inputDatasetSize); } else { @@ -158,39 +164,46 @@ public Supplier> inputMaskSupplier() { } public double[] snapDimensionsToBlockSize(final double[] dimensionsWorld, final double[] blockSizeWorld) { + final double[] snapped = new double[dimensionsWorld.length]; Arrays.setAll(snapped, d -> Math.ceil(dimensionsWorld[d] / blockSizeWorld[d]) * blockSizeWorld[d]); return snapped; } public double[] inputSizeWorld() { + final double[] inputSizeWorld = new double[inputDatasetSize.length]; Arrays.setAll(inputSizeWorld, d -> inputDatasetSize[d] * inputResolution[d]); return inputSizeWorld; } public double[] inputSizeWorldSnapped() { + return snapDimensionsToBlockSize(inputSizeWorld(), networkOutputSizeWorld); } public double[] networkSizeDiffWorld() { + final double[] diff = new double[this.networkInputSizeWorld.length]; Arrays.setAll(diff, d -> networkInputSizeWorld[d] - networkOutputSizeWorld[d]); return diff; } public double[] networkSizeDiffHalfWorld() { + final double[] diff = networkSizeDiffWorld(); Arrays.setAll(diff, d -> diff[d] / 2); return diff; } public double[] networkOutputSize() { + return divide(this.networkInputSizeWorld, outputResolution); } } public static void main(String[] argv) throws IOException { + run(argv); } @@ -205,8 +218,8 @@ public static void run(String[] argv) throws IOException { final double[] networkSizeDiffHalfWorld = args.networkSizeDiffHalfWorld(); final long[] networkSizeDiff = asLong(divide(networkSizeDiffWorld, args.outputResolution)); -// final double[] validMin = networkSizeDiffHalfWorld.clone(); -// final double[] validMax = subtract(inputSizeWorld, networkSizeDiffHalfWorld); + // final double[] validMin = networkSizeDiffHalfWorld.clone(); + // final double[] validMax = subtract(inputSizeWorld, networkSizeDiffHalfWorld); final double[] outputDatasetSizeDouble = divide(inputSizeWorldSnappedToOutput, args.outputResolution); final long[] outputDatasetSize = args.outputDatasetSIze; @@ -220,7 +233,7 @@ public static void run(String[] argv) throws IOException { n5out.get().setAttribute(args.maskDataset, "offset", args.outputOffset); n5out.get().setAttribute(args.maskDataset, "min", 0); n5out.get().setAttribute(args.maskDataset, "max", 1); - n5out.get().setAttribute(args.maskDataset, "value_range", new double[] {0, 1}); + n5out.get().setAttribute(args.maskDataset, "value_range", new double[]{0, 1}); run( n5out, @@ -247,6 +260,7 @@ private static void run( final Supplier> 
inputMask, final long[] outputDatasetSize, final int[] blockSize) { + final SparkConf conf = new SparkConf().setAppName(MethodHandles.lookup().lookupClass().getName()); try (final JavaSparkContext sc = new JavaSparkContext(conf)) { final List, long[]>> blocks = Grids @@ -264,8 +278,8 @@ private static void run( final DatasetAttributes attributes = new DatasetAttributes(outputDatasetSize, blockSize, DataType.UINT8, new GzipCompression()); final double[] minReal = LongStream.of(min).asDoubleStream().toArray(); final double[] maxReal = LongStream.of(max).asDoubleStream().toArray(); -// final Scale outputScale = new Scale(outputVoxelSize); -// final Scale inputScale = new Scale(inputVoxelSize); + // final Scale outputScale = new Scale(outputVoxelSize); + // final Scale inputScale = new Scale(inputVoxelSize); final AffineTransform3D outputTransform = new AffineTransform3D(); outputTransform.set(outputVoxelSize[0], 0, 0); outputTransform.set(outputVoxelSize[1], 1, 1); @@ -294,10 +308,10 @@ private static void run( } final ArrayImg outputMask = ArrayImgs.unsignedBytes(Intervals.dimensionsAsLongArray(interval)); - Arrays.fill(outputMask.update(null).getCurrentStorageArray(), isForeground ? (byte) 1 : 0); + Arrays.fill(outputMask.update(null).getCurrentStorageArray(), isForeground ? (byte)1 : 0); // this is bad alignment area in cremi sample_A+ -// for (long z = Math.max(min[2], 344); z < Math.min(max[2], 357); ++z) -// Views.hyperSlice(Views.translate(outputMask, min), 2, z).forEach(UnsignedByteType::setZero); + // for (long z = Math.max(min[2], 344); z < Math.min(max[2], 357); ++z) + // Views.hyperSlice(Views.translate(outputMask, min), 2, z).forEach(UnsignedByteType::setZero); N5Utils.saveBlock( outputMask, @@ -308,63 +322,70 @@ private static void run( }); } - } private static long[] asLong(double[] array) { + return convertAsLong(array, d -> d); } private static int[] asInt(double[] array) { + return convertAsInt(array, d -> d); } private static long[] convertAsLong(double[] array, DoubleUnaryOperator converter) { + final long[] ceil = new long[array.length]; - Arrays.setAll(ceil, d -> (long) converter.applyAsDouble(array[d])); + Arrays.setAll(ceil, d -> (long)converter.applyAsDouble(array[d])); return ceil; } private static int[] convertAsInt(double[] array, DoubleUnaryOperator converter) { + final int[] ceil = new int[array.length]; - Arrays.setAll(ceil, d -> (int) converter.applyAsDouble(array[d])); + Arrays.setAll(ceil, d -> (int)converter.applyAsDouble(array[d])); return ceil; } - private static double[] convert(double[] array, DoubleUnaryOperator converter) { + final double[] ceil = new double[array.length]; Arrays.setAll(ceil, d -> converter.applyAsDouble(array[d])); return ceil; } private static double[] subtract(final double[] minuend, final double[] subtrahend) { + final double[] difference = new double[minuend.length]; Arrays.setAll(difference, d -> minuend[d] - subtrahend[d]); return difference; } private static double[] divide(final double[] a, final double[] b) { + final double[] quotient = new double[a.length]; Arrays.setAll(quotient, d -> a[d] / b[d]); return quotient; } private static T toMinMaxTuple(final Interval interval, BiFunction toTuple) { + return toTuple.apply(Intervals.minAsLongArray(interval), Intervals.maxAsLongArray(interval)); } private static class MaskProviderFromDims implements Supplier>, Serializable { - private final long[] dims; private MaskProviderFromDims(long[] dims) { + this.dims = dims; } @Override public RandomAccessible get() { + return 
Views.extendZero(ConstantUtils.constantRandomAccessibleInterval(new UnsignedByteType(1), dims.length, new FinalInterval(dims))); } } @@ -376,23 +397,19 @@ private static class MaskProviderFromN5 implements Supplier get() { - try { - final RandomAccessibleInterval> img = (RandomAccessibleInterval) N5Utils.open(n5.get(), dataset); - final RandomAccessibleInterval convertedImg = Converters.convert(img, (s, t) -> t.setInteger(s.getIntegerLong()), new UnsignedByteType()); - return Views.extendValue(convertedImg, new UnsignedByteType(0)); - } catch (IOException e) { - throw new RuntimeException(e); - } + + final RandomAccessibleInterval> img = (RandomAccessibleInterval)N5Utils.open(n5.get(), dataset); + final RandomAccessibleInterval convertedImg = Converters.convert(img, (s, t) -> t.setInteger(s.getIntegerLong()), new UnsignedByteType()); + return Views.extendValue(convertedImg, new UnsignedByteType(0)); } } - - } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/N5WriterSupplier.java b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/N5WriterSupplier.java index df25eb3..4545119 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/N5WriterSupplier.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/N5WriterSupplier.java @@ -5,7 +5,6 @@ import org.janelia.saalfeldlab.n5.N5Writer; import org.janelia.saalfeldlab.n5.hdf5.N5HDF5Writer; -import java.io.IOException; import java.io.Serializable; import java.nio.file.Files; import java.nio.file.Paths; @@ -23,6 +22,7 @@ class N5WriterSupplier implements Supplier, Serializable { private final boolean serializeSpecialFloatingPointValues = true; N5WriterSupplier(final String container, final boolean withPrettyPrinting, final boolean disableHtmlEscaping) { + this.container = container; this.withPrettyPrinting = withPrettyPrinting; this.disableHtmlEscaping = disableHtmlEscaping; @@ -31,32 +31,33 @@ class N5WriterSupplier implements Supplier, Serializable { @Override public N5Writer get() { - try { - return Files.isDirectory(Paths.get(container)) - ? new N5FSWriter(container, createaBuilder()) - : new N5HDF5Writer(container); - } catch (final IOException e) { - throw new RuntimeException(e); - } + return Files.isDirectory(Paths.get(container)) + ? new N5FSWriter(container, createaBuilder()) + : new N5HDF5Writer(container); } private GsonBuilder createaBuilder() { + return serializeSpecialFloatingPointValues(withPrettyPrinting(disableHtmlEscaping(new GsonBuilder()))); } private GsonBuilder serializeSpecialFloatingPointValues(final GsonBuilder builder) { + return with(builder, this.serializeSpecialFloatingPointValues, GsonBuilder::serializeSpecialFloatingPointValues); } private GsonBuilder withPrettyPrinting(final GsonBuilder builder) { + return with(builder, this.withPrettyPrinting, GsonBuilder::setPrettyPrinting); } private GsonBuilder disableHtmlEscaping(final GsonBuilder builder) { + return with(builder, this.disableHtmlEscaping, GsonBuilder::disableHtmlEscaping); } private static GsonBuilder with(final GsonBuilder builder, boolean applyAction, Function action) { + return applyAction ? 
action.apply(builder) : builder; } } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/Offset.java b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/Offset.java index f2a6301..3b5ec7d 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/Offset.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/Offset.java @@ -1,7 +1,7 @@ package org.janelia.saalfeldlab.label.spark.affinities; import com.google.gson.annotations.Expose; -import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringBuilder; import picocli.CommandLine; import java.io.Serializable; @@ -17,20 +17,24 @@ public class Offset implements Serializable { private final long[] offset; public Offset(final int channelIndex, final long... offset) { + this.channelIndex = channelIndex; this.offset = offset; } public long[] offset() { + return offset.clone(); } public int channelIndex() { + return channelIndex; } @Override public String toString() { + return new ToStringBuilder(this) .append("channelIndex", channelIndex) .append("offset", Arrays.toString(offset)) @@ -38,6 +42,7 @@ public String toString() { } public static Offset parseOffset(final String representation) { + final String[] split = representation.split(":"); return new Offset( split.length > 1 ? Integer.parseInt(split[1]) : -1, @@ -48,6 +53,7 @@ public static class Converter implements CommandLine.ITypeConverter { @Override public Offset convert(String s) { + return Offset.parseOffset(s); } } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/SparkRain.java b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/SparkRain.java index 6d1621d..d910f05 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/SparkRain.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/affinities/SparkRain.java @@ -10,7 +10,6 @@ import net.imglib2.Cursor; import net.imglib2.FinalInterval; import net.imglib2.Interval; -import net.imglib2.IterableInterval; import net.imglib2.Point; import net.imglib2.RandomAccessibleInterval; import net.imglib2.algorithm.gauss3.Gauss3; @@ -38,7 +37,7 @@ import net.imglib2.util.Util; import net.imglib2.view.IntervalView; import net.imglib2.view.Views; -import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringBuilder; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.PairFunction; @@ -103,20 +102,24 @@ private static class Offset implements Serializable { private final long[] offset; public Offset(final int channelIndex, final long... offset) { + this.channelIndex = channelIndex; this.offset = offset; } public long[] offset() { + return offset.clone(); } public int channelIndex() { + return channelIndex; } @Override public String toString() { + return new ToStringBuilder(this) .append("channelIndex", channelIndex) .append("offset", Arrays.toString(offset)) @@ -124,6 +127,7 @@ public String toString() { } public static Offset parseOffset(final String representation) { + final String[] split = representation.split(":"); return new Offset( split.length > 1 ? 
Integer.parseInt(split[1]) : -1, @@ -134,6 +138,7 @@ public static class Converter implements CommandLine.ITypeConverter { @Override public Offset convert(String s) { + return Offset.parseOffset(s); } } @@ -190,11 +195,11 @@ private static class Args implements Serializable, Callable { int[] blockSize = {64, 64, 64}; @Expose - @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split=",") + @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split = ",") int[] blocksPerTask = {1, 1, 1}; @Expose - @CommandLine.Option(names = "--halo", paramLabel = "HALO", description = "Include halo region to run connected components/watersheds", split=",") + @CommandLine.Option(names = "--halo", paramLabel = "HALO", description = "Include halo region to run connected components/watersheds", split = ",") int[] halo = {0, 0, 0}; @Expose @@ -214,8 +219,8 @@ private static class Args implements Serializable, Callable { Boolean relabel; @Expose - @CommandLine.Option(names = "--revert-array-attributes", paramLabel = "RELABEL", description = "Revert all array attributes (that are not dataset attributes)", defaultValue = "false") - Boolean revertArrayAttributes; + @CommandLine.Option(names = "--reverse-array-attributes", paramLabel = "RELABEL", description = "Reverse all array attributes (that are not dataset attributes)", defaultValue = "false") + Boolean reverseArrayAttributes; @Expose @CommandLine.Option(names = "--smooth-affinities", paramLabel = "SIGMA", description = "Smooth affinities before watersheds (if SIGMA > 0)", defaultValue = "0.0") @@ -231,7 +236,7 @@ private static class Args implements Serializable, Callable { @CommandLine.Option(names = "--json-disable-html-escape", defaultValue = "true") transient Boolean disbaleHtmlEscape; - @CommandLine.Option(names = { "-h", "--help"}, usageHelp = true, description = "Display this help and exit") + @CommandLine.Option(names = {"-h", "--help"}, usageHelp = true, description = "Display this help and exit") private Boolean help; @Override @@ -253,6 +258,7 @@ public Void call() throws Exception { } public Offset[] enumeratedOffsets() { + final Offset[] enumeratedOffsets = new Offset[this.offsets.length]; for (int i = 0; i < offsets.length; ++i) { final Offset o = this.offsets[i]; @@ -292,21 +298,20 @@ public static void run(final String... argv) throws IOException { labelUtilitiesSparkAttributes.put(VERSION_KEY, Version.VERSION_STRING); final Map attributes = with(new HashMap<>(), LABEL_UTILITIES_SPARK_KEY, labelUtilitiesSparkAttributes); - final int[] taskBlockSize = IntStream.range(0, args.blockSize.length).map(d -> args.blockSize[d] * args.blocksPerTask[d]).toArray(); final boolean hasHalo = Arrays.stream(args.halo).filter(h -> h != 0).count() > 0; if (hasHalo) throw new UnsupportedOperationException("Halo currently not supported, please omit halo option!"); String[] uint64Datasets = args.minSize > 0 - ? new String[] {args.watersheds, args.merged, args.seededWatersheds, args.sizeFiltered} - : new String[] {args.watersheds, args.merged, args.seededWatersheds}; + ? 
new String[]{args.watersheds, args.merged, args.seededWatersheds, args.sizeFiltered} + : new String[]{args.watersheds, args.merged, args.seededWatersheds}; String[] uint8Datasets = {args.watershedSeeds}; - String[] float32Datasets = args.smoothAffinitiesSigma > 0 ? new String[] {args.smoothedAffinities} : new String[] {}; + String[] float32Datasets = args.smoothAffinitiesSigma > 0 ? new String[]{args.smoothedAffinities} : new String[]{}; - final double[] resolution = reverted(Optional.ofNullable(n5in.get().getAttribute(args.affinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.revertArrayAttributes); - final double[] offset = reverted(Optional.ofNullable(n5in.get().getAttribute(args.affinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.revertArrayAttributes); + final double[] resolution = reversed(Optional.ofNullable(n5in.get().getAttribute(args.affinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.reverseArrayAttributes); + final double[] offset = reversed(Optional.ofNullable(n5in.get().getAttribute(args.affinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.reverseArrayAttributes); attributes.put(RESOLUTION_KEY, resolution); attributes.put(OFFSET_KEY, offset); @@ -331,7 +336,6 @@ public static void run(final String... argv) throws IOException { final Offset[] offsets = args.enumeratedOffsets(); - final SparkConf conf = new SparkConf().setAppName(MethodHandles.lookup().lookupClass().getName()); try (final JavaSparkContext sc = new JavaSparkContext(conf)) { run( @@ -393,13 +397,13 @@ public static void run( .stream() .map(i -> new Tuple2<>(Intervals.minAsLongArray(i), Intervals.maxAsLongArray(i))) .collect(Collectors.toList()); - ; + ; final long[] negativeHalo = new long[halo.length]; Arrays.setAll(negativeHalo, d -> -halo[d]); final List, Integer>> idCounts = sc .parallelize(watershedBlocks) - .map(t -> (Interval) new FinalInterval(t._1(), t._2())) + .map(t -> (Interval)new FinalInterval(t._1(), t._2())) .mapToPair(new CropAffinities(n5in, affinities, invertAffinitiesAxis, halo, smoothAffinitiesSigma)) .mapValues(affs -> { // TODO how to avoid looking outside interval? @@ -427,12 +431,12 @@ public static void run( final int[] symmetricOrder = new int[offsets.length]; Arrays.setAll(symmetricOrder, d -> offsets.length - 1 - d); // LoopBuilder issues in this call! -// final RandomAccessibleInterval symmetricAffinities = Watersheds.constructAffinities( -// uncollapsedAffinities, -// offsets, -// new ArrayImgFactory<>(new FloatType()), -// symmetricOrder -// ); + // final RandomAccessibleInterval symmetricAffinities = Watersheds.constructAffinities( + // uncollapsedAffinities, + // offsets, + // new ArrayImgFactory<>(new FloatType()), + // symmetricOrder + // ); final RandomAccessibleInterval symmetricAffinities = constructAffinitiesWithCopy( uncollapsedAffinities, new ArrayImgFactory<>(new FloatType()), @@ -442,11 +446,11 @@ public static void run( final RandomAccessibleInterval symmetricSmoothedAffinities = uncollapsedSmoothedAffinities == null ? 
symmetricAffinities : Watersheds.constructAffinities( - uncollapsedSmoothedAffinities, - offsets, - new ArrayImgFactory<>(new FloatType()), - symmetricOrder); -// } + uncollapsedSmoothedAffinities, + offsets, + new ArrayImgFactory<>(new FloatType()), + symmetricOrder); + // } final long[][] symmetricOffsets = Watersheds.symmetricOffsets(Watersheds.SymmetricOffsetOrder.ABCCBA, offsets); final Pair parentsAndRoots = Watersheds.letItRain( @@ -478,8 +482,8 @@ public static void run( final ArrayImg um = ArrayImgs.bits(dims); final IntArrayUnionFind uf = new IntArrayUnionFind(roots.length); - final RandomAccessibleInterval mask = Converters.convert((RandomAccessibleInterval) labels, (s, tgt) -> tgt.set(s.getIntegerLong() > 0), new BitType()); - final ConnectedComponents.ToIndex toIndex = (it, index) -> parents[(int) index]; + final RandomAccessibleInterval mask = Converters.convert((RandomAccessibleInterval)labels, (s, tgt) -> tgt.set(s.getIntegerLong() > 0), new BitType()); + final ConnectedComponents.ToIndex toIndex = (it, index) -> parents[(int)index]; ConnectedComponents.unionFindFromSymmetricAffinities( Views.extendValue(mask, new BitType(false)), Views.collapseReal(uncollapsedAffinities), @@ -496,7 +500,6 @@ public static void run( n5out.get().writeBlock(merged, watershedAttributes, dataBlock); } - final TIntIntHashMap counts = new TIntIntHashMap(); for (final UnsignedLongType vx : Views.iterable(labels)) { final int v = vx.getInteger(); @@ -530,16 +533,16 @@ public static void run( Watersheds.seedsFromMask(Views.extendValue(labels, new UnsignedLongType(Label.OUTSIDE)), watershedSeedsMaskImg, symmetricOffsets); final List seeds = Watersheds.collectSeeds(watershedSeedsMaskImg); LOG.debug("Found watershed seeds {}", seeds); - final RandomAccessibleInterval watershedSeedsMaskImgUint8 = Converters.convert(watershedSeedsMaskImg, (src,tgt) -> tgt.set(src.get() ? 1 : 0), new UnsignedByteType()); + final RandomAccessibleInterval watershedSeedsMaskImgUint8 = Converters.convert(watershedSeedsMaskImg, (src, tgt) -> tgt.set(src.get() ? 1 : 0), new UnsignedByteType()); final DatasetAttributes croppedWatershedSeedsAtributes = new DatasetAttributes(outputDims, blockSize, DataType.UINT8, new GzipCompression()); N5Utils.saveBlock(Views.interval(watershedSeedsMaskImgUint8, relevantInterval), n5out.get(), hasHalo ? 
String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, croppedWatershedSeedsAtributes, blockOffset); if (hasHalo) { throw new UnsupportedOperationException("Need to implement halo support!"); -// final DataBlock dataBlock = new LongArrayDataBlock(Intervals.dimensionsAsIntArray(watershedSeedsMaskImg), watershedsBlockOffset, labels.update(null).getCurrentStorageArray()); -// n5out.get().writeBlock(watershedSeeds, watershedAttributes, dataBlock); + // final DataBlock dataBlock = new LongArrayDataBlock(Intervals.dimensionsAsIntArray(watershedSeedsMaskImg), watershedsBlockOffset, labels.update(null).getCurrentStorageArray()); + // n5out.get().writeBlock(watershedSeeds, watershedAttributes, dataBlock); } - LOG.debug("Starting seeded watersheds with offsets {}", (Object) symmetricOffsets); + LOG.debug("Starting seeded watersheds with offsets {}", (Object)symmetricOffsets); Watersheds.seededFromAffinities( Views.collapseReal(symmetricSmoothedAffinities), labels, @@ -556,8 +559,7 @@ public static void run( return new Tuple2<>(new Tuple2<>(Intervals.minAsLongArray(t._1()), Intervals.maxAsLongArray(t._1())), roots.length - 1); }) - .collect() - ; + .collect(); long startIndex = 1; final List, Long>> idOffsets = new ArrayList<>(); @@ -606,6 +608,7 @@ private static void relabel( final String dataset, final Interval interval, final long addIfNotZero) throws IOException { + SparkRain.relabel(n5, dataset, interval, (src, tgt) -> { final long val = src.getIntegerLong(); tgt.set(val == 0 ? 0 : val + addIfNotZero); @@ -617,6 +620,7 @@ private static & NativeType> void relabel( final String dataset, final Interval interval, final BiConsumer idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(dataset); final CellGrid grid = new CellGrid(attributes.getDimensions(), attributes.getBlockSize()); final RandomAccessibleInterval data = Views.interval(N5Utils.open(n5, dataset), interval); @@ -624,7 +628,7 @@ private static & NativeType> void relabel( for (net.imglib2.util.Pair p : Views.interval(Views.pair(Views.zeroMin(data), copy), Views.zeroMin(data))) idMapping.accept(p.getA(), p.getB()); // LoopBuilder class loader issues in spark -// LoopBuilder.setImages(data, copy).forEachPixel(idMapping); + // LoopBuilder.setImages(data, copy).forEachPixel(idMapping); final long[] blockPos = Intervals.minAsLongArray(interval); grid.getCellPosition(blockPos, blockPos); N5Utils.saveBlock(copy, n5, dataset, attributes, blockPos); @@ -635,6 +639,7 @@ private static void relabel( final String dataset, final long[] blockPos, final long addIfNonZero) throws IOException { + relabel(n5, dataset, blockPos, id -> id == 0 ? 
0 : id + addIfNonZero); } @@ -643,8 +648,9 @@ private static void relabel( final String dataset, final long[] blockPos, final LongUnaryOperator idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(dataset); - final LongArrayDataBlock block = ((LongArrayDataBlock) n5.readBlock(dataset, attributes, blockPos)); + final LongArrayDataBlock block = ((LongArrayDataBlock)n5.readBlock(dataset, attributes, blockPos)); final long[] data = block.getData(); for (int i = 0; i < data.length; ++i) { data[i] = idMapping.applyAsLong(data[i]); @@ -666,12 +672,14 @@ private static void prepareOutputDataset( final String dataset, final DatasetAttributes attributes, final Map additionalAttributes) throws IOException { + n5.createDataset(dataset, attributes); for (Map.Entry entry : additionalAttributes.entrySet()) n5.setAttribute(dataset, entry.getKey(), entry.getValue()); } private static Map with(Map map, K key, V value) { + map.put(key, value); return map; } @@ -687,6 +695,7 @@ private static class N5WriterSupplier implements Supplier, Serializabl private final boolean serializeSpecialFloatingPointValues = true; private N5WriterSupplier(final String container, final boolean withPrettyPrinting, final boolean disableHtmlEscaping) { + this.container = container; this.withPrettyPrinting = withPrettyPrinting; this.disableHtmlEscaping = disableHtmlEscaping; @@ -695,43 +704,46 @@ private N5WriterSupplier(final String container, final boolean withPrettyPrintin @Override public N5Writer get() { - try { - return Files.isDirectory(Paths.get(container)) - ? new N5FSWriter(container, createaBuilder()) - : new N5HDF5Writer(container); - } catch (final IOException e) { - throw new RuntimeException(e); - } + return Files.isDirectory(Paths.get(container)) + ? new N5FSWriter(container, createaBuilder()) + : new N5HDF5Writer(container); } private GsonBuilder createaBuilder() { + return serializeSpecialFloatingPointValues(withPrettyPrinting(disableHtmlEscaping(new GsonBuilder()))); } private GsonBuilder serializeSpecialFloatingPointValues(final GsonBuilder builder) { + return with(builder, this.serializeSpecialFloatingPointValues, GsonBuilder::serializeSpecialFloatingPointValues); } private GsonBuilder withPrettyPrinting(final GsonBuilder builder) { + return with(builder, this.withPrettyPrinting, GsonBuilder::setPrettyPrinting); } private GsonBuilder disableHtmlEscaping(final GsonBuilder builder) { + return with(builder, this.disableHtmlEscaping, GsonBuilder::disableHtmlEscaping); } private static GsonBuilder with(final GsonBuilder builder, boolean applyAction, Function action) { + return applyAction ? action.apply(builder) : builder; } } private static double[] ones(final int length) { + double[] ones = new double[length]; Arrays.fill(ones, 1.0); return ones; } private static Interval addDimension(final Interval interval, final long m, final long M) { + long[] min = new long[interval.numDimensions() + 1]; long[] max = new long[interval.numDimensions() + 1]; for (int d = 0; d < interval.numDimensions(); ++d) { @@ -744,14 +756,17 @@ private static Interval addDimension(final Interval interval, final long m, fina } private static String toString(final Interval interval) { + return String.format("(%s %s)", Arrays.toString(Intervals.minAsLongArray(interval)), Arrays.toString(Intervals.maxAsLongArray(interval))); } - private static double[] reverted(final double[] array, final boolean revert) { - return revert ? 
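The single-block relabel overload above is the core of every id-remapping pass: read the block's long[] payload, run each voxel id through a LongUnaryOperator, and write the block back. The inner loop, isolated from the N5 I/O (the payload values here are made up):

// Container-free sketch of the relabel inner loop above; N5 I/O stubbed out.
import java.util.Arrays;
import java.util.function.LongUnaryOperator;

public class RelabelSketch {
    public static void main(String[] args) {
        final long[] data = {0, 1, 1, 3, 0, 2}; // hypothetical block payload
        // background 0 stays 0, everything else is shifted, as in the call sites above
        final LongUnaryOperator idMapping = id -> id == 0 ? 0 : id + 100;
        for (int i = 0; i < data.length; ++i)
            data[i] = idMapping.applyAsLong(data[i]);
        System.out.println(Arrays.toString(data)); // [0, 101, 101, 103, 0, 102]
    }
}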
reverted(array) : array; + private static double[] reversed(final double[] array, final boolean reverse) { + + return reverse ? reversed(array) : array; } - private static double[] reverted(final double[] array) { + private static double[] reversed(final double[] array) { + final double[] copy = new double[array.length]; for (int i = 0, k = copy.length - 1; i < copy.length; ++i, --k) { copy[i] = array[k]; @@ -764,6 +779,7 @@ private static > ArrayImg smooth( final Interval interval, final int channelDim, double sigma) { + final ArrayImg img = ArrayImgs.floats(Intervals.dimensionsAsLongArray(interval)); for (long channel = interval.min(channelDim); channel <= interval.max(channelDim); ++channel) { @@ -780,6 +796,7 @@ private static > void invalidateOutOfBlockAffinities( final T invalid, final long[]... offsets ) { + for (int index = 0; index < offsets.length; ++index) { final IntervalView slice = Views.hyperSlice(affs, affs.numDimensions() - 1, index); for (int d = 0; d < offsets[index].length; ++d) { @@ -810,6 +827,7 @@ private CropAffinities( final boolean invertAffinitiesAxis, final long[] halo, final double smoothAffinitiesSigma) { + this.n5in = n5in; this.affinities = affinities; this.invertAffinitiesAxis = invertAffinitiesAxis; @@ -819,6 +837,7 @@ private CropAffinities( @Override public Tuple2, RandomAccessibleInterval>> call(final Interval interval) throws Exception { + RandomAccessibleInterval affs = N5Utils.open(n5in.get(), affinities); affs = invertAffinitiesAxis ? Views.zeroMin(Views.invertAxis(affs, affs.numDimensions() - 1)) : affs; @@ -831,10 +850,10 @@ public Tuple2, RandomAccess while (target.hasNext()) target.next().set(source.next()); // Class loader issues with loop builder on spark (non-local) -// LoopBuilder.setImages(affinityCrop, Views.interval(Views.extendValue(affs, new FloatType(Float.NaN)), withHaloAndChannels)).forEachPixel(FloatType::set); + // LoopBuilder.setImages(affinityCrop, Views.interval(Views.extendValue(affs, new FloatType(Float.NaN)), withHaloAndChannels)).forEachPixel(FloatType::set); return smoothAffinitiesSigma > 0.0 ? 
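N5WriterSupplier builds its Gson configuration through the small with(builder, flag, action) combinator above, so each JSON feature toggles independently. A runnable distillation against the real Gson API (assuming Gson on the classpath, as in this project):

// Sketch of the conditional-configuration pattern used by N5WriterSupplier.
import com.google.gson.GsonBuilder;
import java.util.function.Function;

public class GsonConfigSketch {
    static GsonBuilder with(final GsonBuilder builder, final boolean apply, final Function<GsonBuilder, GsonBuilder> action) {
        // apply the configuration step only when its flag is set
        return apply ? action.apply(builder) : builder;
    }

    public static void main(String[] args) {
        GsonBuilder builder = new GsonBuilder();
        builder = with(builder, true, GsonBuilder::serializeSpecialFloatingPointValues); // NaN/Infinity survive
        builder = with(builder, true, GsonBuilder::setPrettyPrinting);
        builder = with(builder, true, GsonBuilder::disableHtmlEscaping);
        System.out.println(builder.create().toJson(new double[]{1.0, Double.NaN}));
    }
}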
new Tuple2<>(interval, new Tuple2<>(affinityCrop, smooth(affs, withHaloAndChannels, withHaloAndChannels.numDimensions() - 1, smoothAffinitiesSigma))) - : new Tuple2<>(interval, new Tuple2<>(affinityCrop, (ArrayImg) null)); + : new Tuple2<>(interval, new Tuple2<>(affinityCrop, (ArrayImg)null)); } } @@ -861,12 +880,12 @@ private static > RandomAccessibleInterval constructAffi for (int offsetIndex = 0; offsetIndex < offsets.length; ++offsetIndex) { final int targetIndex = offsets.length + order[offsetIndex]; - final IntervalView targetSlice = Views.hyperSlice(symmetricAffinities, dims.length - 1, (long) targetIndex); + final IntervalView targetSlice = Views.hyperSlice(symmetricAffinities, dims.length - 1, (long)targetIndex); final IntervalView sourceSlice = Views.interval(Views.translate( Views.extendValue(Views.hyperSlice( zeroMinAffinities, dims.length - 1, - (long) offsetIndex), nanExtension), + (long)offsetIndex), nanExtension), offsets[offsetIndex]), targetSlice); Cursor source = Views.flatIterable(sourceSlice).cursor(); diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/convert/ConvertToLabelMultisetType.java b/src/main/java/org/janelia/saalfeldlab/label/spark/convert/ConvertToLabelMultisetType.java index 5095f29..bc4b3ce 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/convert/ConvertToLabelMultisetType.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/convert/ConvertToLabelMultisetType.java @@ -1,34 +1,7 @@ package org.janelia.saalfeldlab.label.spark.convert; -import java.io.IOException; -import java.lang.invoke.MethodHandles; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.concurrent.Callable; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import org.apache.commons.lang.time.DurationFormatUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.janelia.saalfeldlab.label.spark.N5Helpers; -import org.janelia.saalfeldlab.n5.Compression; -import org.janelia.saalfeldlab.n5.CompressionAdapter; -import org.janelia.saalfeldlab.n5.DataType; -import org.janelia.saalfeldlab.n5.GzipCompression; -import org.janelia.saalfeldlab.n5.N5Reader; -import org.janelia.saalfeldlab.n5.N5Writer; -import org.janelia.saalfeldlab.n5.imglib2.N5LabelMultisets; -import org.janelia.saalfeldlab.n5.imglib2.N5Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.google.gson.Gson; import com.google.gson.GsonBuilder; - import net.imglib2.FinalInterval; import net.imglib2.Interval; import net.imglib2.RandomAccessibleInterval; @@ -40,15 +13,44 @@ import net.imglib2.type.numeric.IntegerType; import net.imglib2.util.Intervals; import net.imglib2.view.Views; +import org.apache.commons.lang3.time.DurationFormatUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.janelia.saalfeldlab.label.spark.N5Helpers; +import org.janelia.saalfeldlab.n5.Compression; +import org.janelia.saalfeldlab.n5.CompressionAdapter; +import org.janelia.saalfeldlab.n5.DataType; +import org.janelia.saalfeldlab.n5.DatasetAttributes; +import org.janelia.saalfeldlab.n5.GsonUtils; +import org.janelia.saalfeldlab.n5.GzipCompression; +import org.janelia.saalfeldlab.n5.N5Reader; +import org.janelia.saalfeldlab.n5.N5Writer; +import org.janelia.saalfeldlab.n5.imglib2.N5LabelMultisets; +import org.janelia.saalfeldlab.n5.imglib2.N5Utils; +import 
org.janelia.saalfeldlab.n5.zarr.ZarrKeyValueReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import picocli.CommandLine; import picocli.CommandLine.Option; import picocli.CommandLine.Parameters; import scala.Tuple2; -public class ConvertToLabelMultisetType -{ +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class ConvertToLabelMultisetType { - private static final Logger LOG = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() ); + private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); public static final String LABEL_MULTISETTYPE_KEY = "isLabelMultiset"; @@ -68,54 +70,52 @@ public class ConvertToLabelMultisetType // TODO make this parallizable/spark and not hdf-to-n5 but convert between // various instances of n5 instead - static public class CommandLineParameters implements Callable< Void > - { - @Option( names = { "--input-n5", "-i" }, paramLabel = "INPUT_N5", required = true, description = "Input N5 container. Currently supports N5 and HDF5." ) + static public class CommandLineParameters implements Callable { + @Option(names = {"--input-n5", "-i"}, paramLabel = "INPUT_N5", required = true, description = "Input N5 container. Currently supports N5 and HDF5.") private String inputN5; - @Option( names = { "--output-n5", "-o" }, paramLabel = "OUTPUT_N5", description = "Output N5 container. Defaults to INPUT_N5" ) + @Option(names = {"--output-n5", "-o"}, paramLabel = "OUTPUT_N5", description = "Output N5 container. Defaults to INPUT_N5") private String outputN5; - @Option( names = { "--dataset", "-d" }, paramLabel = "INPUT_DATASET", required = true, description = "Input dataset name (relative to INPUT_N5" ) + @Option(names = {"--dataset", "-d"}, paramLabel = "INPUT_DATASET", required = true, description = "Input dataset name (relative to INPUT_N5") private String inputDataset; - @Parameters( arity = "1", paramLabel = "OUTPUT_DATASET", description = "Output dataset name (relative to OUTPUT_N5)" ) + @Parameters(arity = "1", paramLabel = "OUTPUT_DATASET", description = "Output dataset name (relative to OUTPUT_N5)") private String outputDatasetName; - @Option( names = { "--block-size", "-b" }, paramLabel = "BLOCK_SIZE", description = "Size of cells to use in the output N5 dataset. Defaults to 64. Either single integer value for isotropic block size or comma-seperated list of block size per dimension", split = "," ) + @Option(names = {"--block-size", "-b"}, paramLabel = "BLOCK_SIZE", description = "Size of cells to use in the output N5 dataset. Defaults to 64. 
Either single integer value for isotropic block size or comma-separated list of block size per dimension", split = ",") private int[] blockSize; - @Option( names = { "--compression", "-c" }, paramLabel = "COMPRESSION", description = "Compression type to use in output N5 dataset" ) + @Option(names = {"--compression", "-c"}, paramLabel = "COMPRESSION", description = "Compression type to use in output N5 dataset") public String compressionType = "{\"type\":\"gzip\",\"level\":-1}"; @Option( - names = { "--revert-array-attributes" }, + names = {"--reverse-array-attributes"}, required = false, - description = "When copying, revert all additional array attributes that are not dataset attributes. E.g. [3,2,1] -> [1,2,3]" ) - private boolean revertArrayAttributes; + description = "When copying, reverse all additional array attributes that are not dataset attributes. E.g. [3,2,1] -> [1,2,3]") + private boolean reverseArrayAttributes; @Override - public Void call() throws IOException - { - this.blockSize = this.blockSize == null || this.blockSize.length == 0 ? new int[] { DEFAULT_BLOCK_SIZE } : this.blockSize; + public Void call() throws IOException { + + this.blockSize = this.blockSize == null || this.blockSize.length == 0 ? new int[]{DEFAULT_BLOCK_SIZE} : this.blockSize; this.outputN5 = this.outputN5 == null ? this.inputN5 : this.outputN5; final Gson gson = new GsonBuilder() - .registerTypeHierarchyAdapter( Compression.class, CompressionAdapter.getJsonAdapter() ) + .registerTypeHierarchyAdapter(Compression.class, CompressionAdapter.getJsonAdapter()) .create(); final Compression compression = new GzipCompression();// .fromJson( - // compressionType, - // Compression.class - // ); - final int nDim = N5Helpers.n5Reader( this.inputN5 ).getDatasetAttributes( this.inputDataset ).getNumDimensions(); - final int[] blockSize = this.blockSize.length < nDim ? IntStream.generate( () -> this.blockSize[ 0 ] ).limit( nDim ).toArray() : this.blockSize; + // compressionType, + // Compression.class + // ); + final int nDim = N5Helpers.n5Reader(this.inputN5).getDatasetAttributes(this.inputDataset).getNumDimensions(); + final int[] blockSize = this.blockSize.length < nDim ? IntStream.generate(() -> this.blockSize[0]).limit(nDim).toArray() : this.blockSize; final long startTime = System.currentTimeMillis(); - final SparkConf conf = new SparkConf().setAppName( MethodHandles.lookup().lookupClass().getName() ); + final SparkConf conf = new SparkConf().setAppName(MethodHandles.lookup().lookupClass().getName()); - try (final JavaSparkContext sc = new JavaSparkContext( conf )) - { + try (final JavaSparkContext sc = new JavaSparkContext(conf)) { convertToLabelMultisetType( sc, inputN5, @@ -124,34 +124,34 @@ public Void call() throws IOException outputN5, outputDatasetName, compression, - revertArrayAttributes ); + reverseArrayAttributes); } final long endTime = System.currentTimeMillis(); - final String formattedTime = DurationFormatUtils.formatDuration( endTime - startTime, "HH:mm:ss.SSS" ); - System.out.println( "Converted " + inputN5 + " to N5 dataset at " + outputN5 + " with name " + outputDatasetName + - " in " + formattedTime ); + final String formattedTime = DurationFormatUtils.formatDuration(endTime - startTime, "HH:mm:ss.SSS"); + System.out.println("Converted " + inputN5 + " to N5 dataset at " + outputN5 + " with name " + outputDatasetName + + " in " + formattedTime); return null; } } - public static void main( final String... args ) throws IOException - { - run( args ); + public static void main(final String... 
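A single --block-size value is broadcast to an isotropic per-dimension size exactly as in the call() body above; checked in isolation:

// Stand-alone check of the block-size broadcast in call() above.
import java.util.Arrays;
import java.util.stream.IntStream;

public class BlockSizeSketch {
    public static void main(String[] args) {
        final int[] blockSize = {64}; // single value given on the command line
        final int nDim = 3;
        // repeat the single value once per dimension; a full list passes through
        final int[] expanded = blockSize.length < nDim
                ? IntStream.generate(() -> blockSize[0]).limit(nDim).toArray()
                : blockSize;
        System.out.println(Arrays.toString(expanded)); // [64, 64, 64]
    }
}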
args) throws IOException { + + run(args); } - public static void run( final String... args ) throws IOException - { - System.out.println( "Command line arguments: " + Arrays.toString( args ) ); - LOG.debug( "Command line arguments: ", Arrays.toString( args ) ); - CommandLine.call( new CommandLineParameters(), System.err, args ); + public static void run(final String... args) throws IOException { + + System.out.println("Command line arguments: " + Arrays.toString(args)); + LOG.debug("Command line arguments: {}", Arrays.toString(args)); + CommandLine.call(new CommandLineParameters(), System.err, args); } - public static < I extends IntegerType< I > & NativeType< I > > void convertToLabelMultisetType( + public static & NativeType> void convertToLabelMultisetType( final JavaSparkContext sc, final String inputGroup, final String inputDataset, @@ -159,91 +159,93 @@ public static < I extends IntegerType< I > & NativeType< I > > void convertToLab final String outputGroupName, final String outputDatasetName, final Compression compression, - final boolean revert ) throws IOException - { - final N5Reader reader = N5Helpers.n5Reader( inputGroup, blockSize ); - final int[] inputBlockSize = reader.getDatasetAttributes( inputDataset ).getBlockSize(); - final RandomAccessibleInterval< I > img = N5Utils.open( reader, inputDataset ); - final Map< String, Class< ? > > attributeNames = reader.listAttributes( inputDataset ); - Arrays.asList( - LABEL_MULTISETTYPE_KEY, - DATA_TYPE_KEY, - COMPRESSION_KEY, - BLOCK_SIZE_KEY, - DIMENSIONS_KEY ) - .forEach( attributeNames::remove ); + final boolean reverse) throws IOException { + + final N5Reader reader = N5Helpers.n5Reader(inputGroup, blockSize); + final DatasetAttributes inputDataAttrs = reader.getDatasetAttributes(inputDataset); + final int[] inputBlockSize = inputDataAttrs.getBlockSize(); + final RandomAccessibleInterval img = N5Utils.open(reader, inputDataset); + final Map> attributeNames; + if (reader instanceof ZarrKeyValueReader) { + attributeNames = Optional.of(reader) + .map(ZarrKeyValueReader.class::cast) + .map(it -> it.getZAttributes(inputDataset)) + .map(GsonUtils::listAttributes) + .orElseGet(HashMap::new); + } else { + attributeNames = reader.listAttributes(inputDataset); + List.of( + LABEL_MULTISETTYPE_KEY, + DATA_TYPE_KEY, + COMPRESSION_KEY, + BLOCK_SIZE_KEY, + DIMENSIONS_KEY + ).forEach(attributeNames::remove); + } final int nDim = img.numDimensions(); - final long[] dimensions = new long[ nDim ]; - img.dimensions( dimensions ); + final long[] dimensions = new long[nDim]; + img.dimensions(dimensions); - final N5Writer writer = N5Helpers.n5Writer( outputGroupName, blockSize ); - final boolean outputDatasetExisted = writer.datasetExists( outputDatasetName ); - if ( outputDatasetExisted ) - { - final int[] existingBlockSize = writer.getDatasetAttributes( outputDatasetName ).getBlockSize(); - if ( !Arrays.equals( blockSize, existingBlockSize ) ) - throw new RuntimeException( "Cannot overwrite existing dataset when the block sizes are not the same." ); - } - writer.createDataset( outputDatasetName, dimensions, blockSize, DataType.UINT8, compression ); - writer.setAttribute( outputDatasetName, LABEL_MULTISETTYPE_KEY, true ); - for ( final Entry< String, Class< ? 
> > entry : attributeNames.entrySet() ) - writer.setAttribute( outputDatasetName, entry.getKey(), N5Helpers.revertInplaceAndReturn( reader.getAttribute( inputDataset, entry.getKey(), entry.getValue() ), revert ) ); - - final int[] parallelizeBlockSize = new int[ blockSize.length ]; - if ( Intervals.numElements( blockSize ) >= Intervals.numElements( inputBlockSize ) ) - { - Arrays.setAll( parallelizeBlockSize, d -> blockSize[ d ] ); - LOG.debug( "Output block size {} is the same or bigger than the input block size {}, parallelizing over output blocks of size {}", blockSize, inputBlockSize, parallelizeBlockSize ); + final N5Writer writer = N5Helpers.n5Writer(outputGroupName, blockSize); + final boolean outputDatasetExisted = writer.datasetExists(outputDatasetName); + if (outputDatasetExisted) { + final int[] existingBlockSize = writer.getDatasetAttributes(outputDatasetName).getBlockSize(); + if (!Arrays.equals(blockSize, existingBlockSize)) + throw new RuntimeException("Cannot overwrite existing dataset when the block sizes are not the same."); } - else - { - Arrays.setAll( parallelizeBlockSize, d -> ( int ) Math.max( Math.round( ( double ) inputBlockSize[ d ] / blockSize[ d ] ), 1 ) * blockSize[ d ] ); - LOG.debug( "Output block size {} is smaller than the input block size {}, parallelizing over adjusted input blocks of size {}", blockSize, inputBlockSize, parallelizeBlockSize ); + writer.createDataset(outputDatasetName, dimensions, blockSize, DataType.UINT8, compression); + writer.setAttribute(outputDatasetName, LABEL_MULTISETTYPE_KEY, true); + for (final Entry> entry : attributeNames.entrySet()) + writer.setAttribute(outputDatasetName, entry.getKey(), N5Helpers.reverseInplaceAndReturn(reader.getAttribute(inputDataset, entry.getKey(), entry.getValue()), reverse)); + + final int[] parallelizeBlockSize = new int[blockSize.length]; + if (Intervals.numElements(blockSize) >= Intervals.numElements(inputBlockSize)) { + Arrays.setAll(parallelizeBlockSize, d -> blockSize[d]); + LOG.debug("Output block size {} is the same or bigger than the input block size {}, parallelizing over output blocks of size {}", blockSize, inputBlockSize, parallelizeBlockSize); + } else { + Arrays.setAll(parallelizeBlockSize, d -> (int)Math.max(Math.round((double)inputBlockSize[d] / blockSize[d]), 1) * blockSize[d]); + LOG.debug("Output block size {} is smaller than the input block size {}, parallelizing over adjusted input blocks of size {}", blockSize, inputBlockSize, parallelizeBlockSize); } - final List< Tuple2< long[], long[] > > intervals = Grids.collectAllContainedIntervals( dimensions, parallelizeBlockSize ) + final List> intervals = Grids.collectAllContainedIntervals(dimensions, parallelizeBlockSize) .stream() - .map( interval -> new Tuple2<>( Intervals.minAsLongArray( interval ), Intervals.maxAsLongArray( interval ) ) ) - .collect( Collectors.toList() ); + .map(interval -> new Tuple2<>(Intervals.minAsLongArray(interval), Intervals.maxAsLongArray(interval))) + .collect(Collectors.toList()); final long maxId = sc - .parallelize( intervals, Math.min( intervals.size(), MAX_PARTITIONS ) ) - .map( intervalMinMax -> { - final Interval interval = new FinalInterval( intervalMinMax._1(), intervalMinMax._2() ); + .parallelize(intervals, Math.min(intervals.size(), MAX_PARTITIONS)) + .map(intervalMinMax -> { + final Interval interval = new FinalInterval(intervalMinMax._1(), intervalMinMax._2()); - @SuppressWarnings("unchecked") - final RandomAccessibleInterval< I > blockImg = Views.interval( - ( RandomAccessibleInterval< I > 
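The parallelizeBlockSize computation above picks the task granularity: when output blocks are at least as large as input blocks they are used directly, otherwise each task covers an input-aligned multiple of the output block size. A worked check with made-up sizes:

// Stand-alone check of the parallelizeBlockSize rounding above.
import java.util.Arrays;

public class ParallelizeBlockSizeSketch {
    public static void main(String[] args) {
        final int[] inputBlockSize = {128, 128, 13}; // hypothetical input blocks
        final int[] blockSize = {64, 64, 64};        // output block size
        final int[] parallelizeBlockSize = new int[blockSize.length];
        // round the input/output ratio, never below 1, then scale back up
        Arrays.setAll(parallelizeBlockSize,
                d -> (int)Math.max(Math.round((double)inputBlockSize[d] / blockSize[d]), 1) * blockSize[d]);
        System.out.println(Arrays.toString(parallelizeBlockSize)); // [128, 128, 64]
    }
}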
) N5Utils.open( N5Helpers.n5Reader( inputGroup, blockSize ), inputDataset ), + @SuppressWarnings("unchecked") final RandomAccessibleInterval blockImg = Views.interval( + (RandomAccessibleInterval)N5Utils.open(N5Helpers.n5Reader(inputGroup, blockSize), inputDataset), interval - ); + ); - final FromIntegerTypeConverter< I > converter = new FromIntegerTypeConverter<>(); + final FromIntegerTypeConverter converter = new FromIntegerTypeConverter<>(); final LabelMultisetType type = FromIntegerTypeConverter.getAppropriateType(); long blockMaxId = Long.MIN_VALUE; - for ( final I i : Views.iterable( blockImg ) ) - { + for (final I i : Views.iterable(blockImg)) { final long il = i.getIntegerLong(); - blockMaxId = Math.max( il, blockMaxId ); + blockMaxId = Math.max(il, blockMaxId); } - final RandomAccessibleInterval< LabelMultisetType > converted = Converters.convert( blockImg, converter, type ); - + final RandomAccessibleInterval converted = Converters.convert(blockImg, converter, type); - final N5Writer localWriter = N5Helpers.n5Writer( outputGroupName, blockSize ); + final N5Writer localWriter = N5Helpers.n5Writer(outputGroupName, blockSize); - if ( outputDatasetExisted ) - { + if (outputDatasetExisted) { // Empty blocks will not be written out. Delete blocks to avoid remnant blocks if overwriting. - N5Utils.deleteBlock( converted, localWriter, outputDatasetName ); + N5Utils.deleteBlock(converted, localWriter, outputDatasetName); } - N5LabelMultisets.saveLabelMultisetNonEmptyBlock( converted, localWriter, outputDatasetName ); + N5LabelMultisets.saveLabelMultisetNonEmptyBlock(converted, localWriter, outputDatasetName); return blockMaxId; - } ) - .max( Comparator.naturalOrder() ); + }) + .max(Comparator.naturalOrder()); - writer.setAttribute( outputDatasetName, MAX_ID_KEY, maxId ); + writer.setAttribute(outputDatasetName, MAX_ID_KEY, maxId); } } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/MinToInterval.java b/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/MinToInterval.java index ec42381..bd74b1e 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/MinToInterval.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/MinToInterval.java @@ -1,32 +1,30 @@ package org.janelia.saalfeldlab.label.spark.downsample; -import java.util.Arrays; - -import org.apache.spark.api.java.function.Function; - import net.imglib2.FinalInterval; import net.imglib2.Interval; +import org.apache.spark.api.java.function.Function; -public class MinToInterval implements Function< long[], Interval > -{ +import java.util.Arrays; + +public class MinToInterval implements Function { private final long[] max; private final int[] blockSize; - public MinToInterval( final long[] max, final int[] blockSize ) - { + public MinToInterval(final long[] max, final int[] blockSize) { + super(); this.max = max; this.blockSize = blockSize; } @Override - public Interval call( final long[] min ) - { - final long[] max = new long[ min.length ]; - Arrays.setAll( max, d -> Math.min( min[ d ] + blockSize[ d ] - 1, this.max[ d ] ) ); - return new FinalInterval( min, max ); + public Interval call(final long[] min) { + + final long[] max = new long[min.length]; + Arrays.setAll(max, d -> Math.min(min[d] + blockSize[d] - 1, this.max[d])); + return new FinalInterval(min, max); } } diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/SparkDownsampleFunction.java b/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/SparkDownsampleFunction.java index 
02c052a..3686f87 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/SparkDownsampleFunction.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/downsample/SparkDownsampleFunction.java @@ -1,18 +1,5 @@ package org.janelia.saalfeldlab.label.spark.downsample; -import java.util.Arrays; -import java.util.List; - -import net.imglib2.img.array.ArrayImg; -import org.apache.spark.api.java.function.VoidFunction; -import org.janelia.saalfeldlab.n5.ByteArrayDataBlock; -import org.janelia.saalfeldlab.n5.DatasetAttributes; -import org.janelia.saalfeldlab.n5.N5FSReader; -import org.janelia.saalfeldlab.n5.N5FSWriter; -import org.janelia.saalfeldlab.n5.N5Reader; -import org.janelia.saalfeldlab.n5.N5Writer; -import org.janelia.saalfeldlab.n5.imglib2.N5LabelMultisetCacheLoader; - import net.imglib2.Interval; import net.imglib2.RandomAccess; import net.imglib2.algorithm.util.Grids; @@ -31,9 +18,19 @@ import net.imglib2.util.Intervals; import net.imglib2.util.Util; import net.imglib2.view.Views; +import org.apache.spark.api.java.function.VoidFunction; +import org.janelia.saalfeldlab.n5.ByteArrayDataBlock; +import org.janelia.saalfeldlab.n5.DatasetAttributes; +import org.janelia.saalfeldlab.n5.N5FSReader; +import org.janelia.saalfeldlab.n5.N5FSWriter; +import org.janelia.saalfeldlab.n5.N5Reader; +import org.janelia.saalfeldlab.n5.N5Writer; +import org.janelia.saalfeldlab.n5.imglib2.N5LabelMultisetCacheLoader; -public class SparkDownsampleFunction implements VoidFunction< Interval > -{ +import java.util.Arrays; +import java.util.List; + +public class SparkDownsampleFunction implements VoidFunction { private static final long serialVersionUID = 1384028449836651390L; @@ -55,8 +52,8 @@ public SparkDownsampleFunction( final int[] factor, final String outputGroupName, final String outputDatasetName, - final int maxNumEntries ) - { + final int maxNumEntries) { + this.inputGroupName = inputGroupName; this.inputDatasetName = inputDatasetName; this.factor = factor; @@ -66,109 +63,105 @@ public SparkDownsampleFunction( } @Override - public void call( final Interval interval ) throws Exception - { + public void call(final Interval interval) throws Exception { - final N5Reader reader = new N5FSReader( inputGroupName ); - final DatasetAttributes attr = reader.getDatasetAttributes( inputDatasetName ); + final N5Reader reader = new N5FSReader(inputGroupName); + final DatasetAttributes attr = reader.getDatasetAttributes(inputDatasetName); final long[] sourceDimensions = attr.getDimensions(); final int[] sourceBlockSize = attr.getBlockSize(); final int nDim = attr.getNumDimensions(); - final long[] blockMinInTarget = Intervals.minAsLongArray( interval ); - final int[] blockSizeInTarget = Intervals.dimensionsAsIntArray( interval ); + final long[] blockMinInTarget = Intervals.minAsLongArray(interval); + final int[] blockSizeInTarget = Intervals.dimensionsAsIntArray(interval); - final long[] blockMinInSource = new long[ nDim ]; - final long[] blockMaxInSource = new long[ nDim ]; - final long[] blockSizeInSource = new long[ nDim ]; - Arrays.setAll( blockMinInSource, i -> factor[ i ] * blockMinInTarget[ i ] ); - Arrays.setAll( blockMaxInSource, i -> Math.min( blockMinInSource[ i ] + factor[ i ] * blockSizeInTarget[ i ] - 1, sourceDimensions[ i ] - 1 ) ); - Arrays.setAll( blockSizeInSource, i -> blockMaxInSource[ i ] - blockMinInSource[ i ] + 1 ); + final long[] blockMinInSource = new long[nDim]; + final long[] blockMaxInSource = new long[nDim]; + final long[] blockSizeInSource = new long[nDim]; + 
Arrays.setAll(blockMinInSource, i -> factor[i] * blockMinInTarget[i]); + Arrays.setAll(blockMaxInSource, i -> Math.min(blockMinInSource[i] + factor[i] * blockSizeInTarget[i] - 1, sourceDimensions[i] - 1)); + Arrays.setAll(blockSizeInSource, i -> blockMaxInSource[i] - blockMinInSource[i] + 1); - final CachedCellImg< LabelMultisetType, VolatileLabelMultisetArray > source = - getSource( new N5LabelMultisetCacheLoader( reader, inputDatasetName ), sourceDimensions, sourceBlockSize ); + final CachedCellImg source = + getSource(new N5LabelMultisetCacheLoader(reader, inputDatasetName), sourceDimensions, sourceBlockSize); final CellGrid sourceGrid = source.getCellGrid(); - final int[] sourceCellDimensions = new int[ sourceGrid.numDimensions() ]; - Arrays.setAll( sourceCellDimensions, sourceGrid::cellDimension ); - final List< long[] > cellPositions = Grids.collectAllOffsets( + final int[] sourceCellDimensions = new int[sourceGrid.numDimensions()]; + Arrays.setAll(sourceCellDimensions, sourceGrid::cellDimension); + final List cellPositions = Grids.collectAllOffsets( blockMinInSource, blockMaxInSource, - sourceCellDimensions ); - - final LazyCells< Cell< VolatileLabelMultisetArray > > cells = source.getCells(); - final RandomAccess< Cell< VolatileLabelMultisetArray > > cellsAccess = cells.randomAccess(); - for ( final long[] pos : cellPositions ) - { - Arrays.setAll( pos, d -> pos[ d ] / sourceCellDimensions[ d ] ); - cellsAccess.setPosition( pos ); + sourceCellDimensions); + + final LazyCells> cells = source.getCells(); + final RandomAccess> cellsAccess = cells.randomAccess(); + for (final long[] pos : cellPositions) { + Arrays.setAll(pos, d -> pos[d] / sourceCellDimensions[d]); + cellsAccess.setPosition(pos); } source.getCellGrid(); // TODO Should this be passed from outside? Needs to load one additional // block for (almost) all tasks int eachCount = 0; - for ( final Entry< Label > e : Util.getTypeFromInterval( source ).entrySet() ) - { + for (final Entry affinitiesSlice = Views.hyperSlice(affinitiesInterval, numDims, (long) channel); - final RandomAccessibleInterval affinitiesCopySlice = Views.hyperSlice(affinitiesCopy, numDims, (long) channel); + final RandomAccessibleInterval affinitiesSlice = Views.hyperSlice(affinitiesInterval, numDims, (long)channel); + final RandomAccessibleInterval affinitiesCopySlice = Views.hyperSlice(affinitiesCopy, numDims, (long)channel); final RandomAccessibleInterval maskTo = Views.interval(mask, Intervals.translate(block, channelOffset)); final RandomAccessibleInterval gliaTo = Views.interval(glia, Intervals.translate(block, channelOffset)); @@ -430,15 +441,15 @@ G extends RealType> void runMutexWatersheds( } sw1.stop(); System.out.println("Prepared affinities in " + StopWatch.secondsToString(sw1.seconds())); -// return new Tuple2<>(bwo, affinitiesCopy); -// }) -// .mapToPair(t -> { + // return new Tuple2<>(bwo, affinitiesCopy); + // }) + // .mapToPair(t -> { final RandomAccessibleInterval affinities = affinitiesCopy; final RandomAccessibleInterval target = ArrayImgs.unsignedLongs(Intervals.dimensionsAsLongArray(block)); final double[] edgeProbabilities = Stream .of(offsets) - .mapToDouble(o -> squaredSum(o) <= 1 ? 1.0:0.01) + .mapToDouble(o -> squaredSum(o) <= 1 ? 
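The target-to-source mapping above scales a downsampled block's min by the per-dimension factor and clamps its max to the source extent. Checked stand-alone with hypothetical numbers:

// Stand-alone check of the source-range computation in call() above.
import java.util.Arrays;

public class SourceRangeSketch {
    public static void main(String[] args) {
        final int[] factor = {2, 2, 2};
        final long[] sourceDimensions = {100, 100, 100};
        final long[] blockMinInTarget = {48, 0, 0};
        final int[] blockSizeInTarget = {16, 16, 16};
        final long[] blockMinInSource = new long[3];
        final long[] blockMaxInSource = new long[3];
        // scale the min, then clamp the max to the last valid source index
        Arrays.setAll(blockMinInSource, i -> factor[i] * blockMinInTarget[i]);
        Arrays.setAll(blockMaxInSource, i -> Math.min(blockMinInSource[i] + factor[i] * blockSizeInTarget[i] - 1, sourceDimensions[i] - 1));
        System.out.println(Arrays.toString(blockMinInSource) + " .. " + Arrays.toString(blockMaxInSource));
        // [96, 0, 0] .. [99, 31, 31]  (dimension 0 clamped at the source border)
    }
}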
1.0 : 0.01) .toArray(); final long seed = IntervalIndexer.positionToIndex(min, outputSize); @@ -463,7 +474,7 @@ G extends RealType> void runMutexWatersheds( rng::nextDouble); } sw.stop(); -// System.out.println("Ran mutex watersheds in " + StopWatch.secondsToString(sw.seconds())); + // System.out.println("Ran mutex watersheds in " + StopWatch.secondsToString(sw.seconds())); final TLongLongHashMap counts = new TLongLongHashMap(); Views.flatIterable(target).forEach(px -> { @@ -471,7 +482,6 @@ G extends RealType> void runMutexWatersheds( counts.put(v, counts.get(v) + 1); }); - long index = 1L; final TLongLongMap mapping = new TLongLongHashMap(); mapping.put(0, 0); @@ -620,7 +630,6 @@ G extends RealType> void runMutexWatersheds( t.setInteger(mapping.containsKey(k) ? mapping.get(k) : k); }); - final long[] blockOffset = new long[bwo.min.length]; Arrays.setAll(blockOffset, d -> bwo.min[d] / blockSize[d]); final DatasetAttributes attributes = new DatasetAttributes(outputSize, blockSize, DataType.UINT64, new GzipCompression()); @@ -631,10 +640,10 @@ G extends RealType> void runMutexWatersheds( // Success!! outputContainer.get().setAttribute(mutexWatershedMergedDataset, "completedSuccessfully", true); - } private static TLongLongMap argMaxes(final TLongObjectMap counts) { + final TLongLongMap argMaxes = new TLongLongHashMap(); counts.forEachEntry((key, value) -> { long argMax = -1; @@ -656,6 +665,7 @@ private static TLongLongMap argMaxes(final TLongObjectMap counts) } private static void incrementFor(final TLongObjectMap counts, final long from, final long to) { + if (!counts.containsKey(from)) counts.put(from, new TLongLongHashMap()); final TLongLongMap toCounts = counts.get(from); @@ -663,6 +673,7 @@ private static void incrementFor(final TLongObjectMap counts, fina } private static long squaredSum(long... values) { + long sum = 0; for (long v : values) sum += v * v; @@ -670,6 +681,7 @@ private static long squaredSum(long... values) { } private static boolean[] attractiveEdges(long[]... 
offsets) { + final boolean[] attractiveEdges = new boolean[offsets.length]; for (int d = 0; d < offsets.length; ++d) attractiveEdges[d] = squaredSum(offsets[d]) <= 1L; @@ -685,16 +697,14 @@ private static class N5WriterSupplier implements SerializableSupplier & NumericType< private ZeroExtendedSupplier( final SerializableSupplier container, final String dataset) { + this.container = container; this.dataset = dataset; } @Override public RandomAccessible get() { - try { - final N5Reader container = this.container.get(); - final double[] offset = container.getAttribute(dataset, "offset", double[].class); - if (offset == null || Arrays.stream(offset).allMatch(d -> d ==0.0)) - return Views.extendZero(N5Utils.open(container, dataset)); - final double[] resolution = Optional.ofNullable(container.getAttribute(dataset, "resolution", double[].class)).orElse(new double[] {1.0, 1.0, 1.0}); - final long[] offsetInVoxels = new long[offset.length]; - Arrays.setAll(offsetInVoxels, d -> (long) (offset[d] / resolution[d])); - return Views.extendZero(Views.translate(N5Utils.open(container, dataset), offsetInVoxels)); - } catch (final IOException e) { - throw new RuntimeException(e); - } + + final N5Reader container = this.container.get(); + final double[] offset = container.getAttribute(dataset, "offset", double[].class); + if (offset == null || Arrays.stream(offset).allMatch(d -> d == 0.0)) + return Views.extendZero(N5Utils.open(container, dataset)); + final double[] resolution = Optional.ofNullable(container.getAttribute(dataset, "resolution", double[].class)).orElse(new double[]{1.0, 1.0, 1.0}); + final long[] offsetInVoxels = new long[offset.length]; + Arrays.setAll(offsetInVoxels, d -> (long)(offset[d] / resolution[d])); + return Views.extendZero(Views.translate(N5Utils.open(container, dataset), offsetInVoxels)); } } @@ -752,32 +758,38 @@ private static class IntervalWithOffset implements Serializable { public final long[] max; public IntervalWithOffset(long[] min, long[] max) { + this.min = min; this.max = max; } public IntervalWithOffset(final Interval interval) { + this(Intervals.minAsLongArray(interval), Intervals.maxAsLongArray(interval)); } @Override public int hashCode() { + return Arrays.hashCode(this.min); } @Override public boolean equals(final Object other) { + return other instanceof IntervalWithOffset && Arrays.equals(min, ((IntervalWithOffset)other).min); } } private static > boolean anyZero(final RandomAccessibleInterval rai) { + final T zero = Util.getTypeFromInterval(rai); zero.setZero(); return anyMatches(Views.iterable(rai), zero); } private static > boolean anyMatches(final Iterable iterable, final T comp) { + for (final T t : iterable) if (t.valueEquals(comp)) return true; diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatersheds.java b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatersheds.java index 463ab58..e7a055d 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatersheds.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatersheds.java @@ -138,7 +138,7 @@ private static class Args implements Serializable, Callable { int[] blockSize = {64, 64, 64}; @Expose - @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split=",") + @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to 
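squaredSum and attractiveEdges above classify offsets by squared length: offsets with squared length <= 1 count as attractive and, in the edgeProbabilities stream earlier, receive probability 1.0 versus 0.01 for long-range offsets. Stand-alone:

// Stand-alone check of the squared-length classification above.
import java.util.Arrays;

public class AttractiveEdgesSketch {
    static long squaredSum(final long... values) {
        long sum = 0;
        for (final long v : values)
            sum += v * v; // squared Euclidean length of the offset
        return sum;
    }

    public static void main(String[] args) {
        final long[][] offsets = {{-1, 0, 0}, {0, -1, 0}, {-5, 0, 0}};
        for (final long[] o : offsets)
            System.out.println(Arrays.toString(o) + " -> " + (squaredSum(o) <= 1 ? 1.0 : 0.01));
        // unit offsets get 1.0, the long-range offset gets 0.01
    }
}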
combine for watersheds/connected components (one value per dimension)", split = ",") int[] blocksPerTask = {1, 1, 1}; @Expose @@ -150,7 +150,7 @@ private static class Args implements Serializable, Callable { Double minimumAffinity = Double.NEGATIVE_INFINITY; @Expose - @CommandLine.Option(names = "--halo", paramLabel = "HALO", description = "Include halo region to run connected components/watersheds", split=",") + @CommandLine.Option(names = "--halo", paramLabel = "HALO", description = "Include halo region to run connected components/watersheds", split = ",") int[] halo = {0, 0, 0}; @Expose @@ -162,8 +162,8 @@ private static class Args implements Serializable, Callable { Boolean relabel; @Expose - @CommandLine.Option(names = "--revert-array-attributes", paramLabel = "RELABEL", description = "Revert all array attributes (that are not dataset attributes)", defaultValue = "false") - Boolean revertArrayAttributes; + @CommandLine.Option(names = "--reverse-array-attributes", paramLabel = "REVERSE", description = "Reverse all array attributes (that are not dataset attributes)", defaultValue = "false") + Boolean reverseArrayAttributes; @CommandLine.Option(names = "--json-pretty-print", defaultValue = "true") transient Boolean prettyPrint; @@ -171,7 +171,7 @@ private static class Args implements Serializable, Callable { @CommandLine.Option(names = "--json-disable-html-escape", defaultValue = "true") transient Boolean disbaleHtmlEscape; - @CommandLine.Option(names = { "-h", "--help"}, usageHelp = true, description = "Display this help and exit") + @CommandLine.Option(names = {"-h", "--help"}, usageHelp = true, description = "Display this help and exit") private Boolean help; @Override @@ -212,7 +212,6 @@ public static void run(final String... argv) throws IOException { labelUtilitiesSparkAttributes.put(VERSION_KEY, Version.VERSION_STRING); final Map attributes = with(new HashMap<>(), LABEL_UTILITIES_SPARK_KEY, labelUtilitiesSparkAttributes); - final int[] taskBlockSize = IntStream.range(0, args.blockSize.length).map(d -> args.blockSize[d] * args.blocksPerTask[d]).toArray(); final boolean hasHalo = Arrays.stream(args.halo).filter(h -> h != 0).count() > 0; if (hasHalo) @@ -221,8 +220,8 @@ public static void run(final String... argv) throws IOException { String[] uint64Datasets = {args.merged, args.seededWatersheds, args.watershedSeeds, args.blockMerged}; String[] uint8Datasets = {}; - final double[] resolution = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.revertArrayAttributes); - final double[] offset = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.revertArrayAttributes); + final double[] resolution = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.reverseArrayAttributes); + final double[] offset = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.reverseArrayAttributes); attributes.put(RESOLUTION_KEY, resolution); attributes.put(OFFSET_KEY, offset); @@ -254,7 +253,7 @@ public static void run(final String... 
argv) throws IOException { outputDims, args.minimumAffinity, IntStream.of(args.halo).mapToLong(i -> i).toArray(), -// new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), + // new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), new SerializableMergeWatershedsMedianThresholdSupplier(args.threshold), args.threshold, args.averagedAffinities, @@ -296,13 +295,13 @@ public static void run( .stream() .map(i -> new Tuple2<>(Intervals.minAsLongArray(i), Intervals.maxAsLongArray(i))) .collect(Collectors.toList()); - ; + ; final long[] negativeHalo = new long[halo.length]; Arrays.setAll(negativeHalo, d -> -halo[d]); final List, Integer>> idCounts = sc .parallelize(watershedBlocks) - .map(t -> (Interval) new FinalInterval(t._1(), t._2())) + .map(t -> (Interval)new FinalInterval(t._1(), t._2())) .mapToPair(new CropAffinities(n5in, averagedAffinities, minimumWatershedAffinity, halo)) .mapToPair(t -> { final Interval block = t._1(); @@ -337,8 +336,6 @@ public static void run( } } - - final Interval relevantInterval = Intervals.expand(labels, negativeHalo); final DatasetAttributes croppedAttributes = new DatasetAttributes(outputDims, blockSize, DataType.UINT64, new GzipCompression()); @@ -349,8 +346,8 @@ public static void run( N5Utils.saveBlock(Views.interval(labels, relevantInterval), n5out.get(), hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, croppedAttributes, blockOffset); if (hasHalo) { throw new UnsupportedOperationException("Need to implement halo support!"); -// final DataBlock dataBlock = new LongArrayDataBlock(Intervals.dimensionsAsIntArray(watershedSeedsMaskImg), watershedsBlockOffset, labels.update(null).getCurrentStorageArray()); -// n5out.get().writeBlock(watershedSeeds, watershedAttributes, dataBlock); + // final DataBlock dataBlock = new LongArrayDataBlock(Intervals.dimensionsAsIntArray(watershedSeedsMaskImg), watershedsBlockOffset, labels.update(null).getCurrentStorageArray()); + // n5out.get().writeBlock(watershedSeeds, watershedAttributes, dataBlock); } N5Utils.saveBlock(Views.interval(labels, relevantInterval), n5out.get(), hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, croppedAttributes, blockOffset); @@ -407,8 +404,7 @@ public static void run( return new Tuple2<>(new Tuple2<>(Intervals.minAsLongArray(t._1()), Intervals.maxAsLongArray(t._1())), ids.size()); }) - .collect() - ; + .collect(); long startIndex = 1; final List, Long>> idOffsets = new ArrayList<>(); @@ -424,9 +420,9 @@ public static void run( LOG.debug("Relabeling block ({} {}) starting at id {}", t._1()._1(), t._1()._2(), t._2()); final N5Writer n5 = n5out.get(); final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); relabel(n5, hasHalo ? String.format(croppedDatasetPattern, merged) : merged, merged, interval, t._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); + // relabel(n5, hasHalo ? 
String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); if (hasHalo) throw new UnsupportedOperationException("Halo relabeling not implemented yet!"); @@ -448,7 +444,7 @@ public static void run( if (maxId + 2 > Integer.MAX_VALUE) throw new RuntimeException("Currently only Integer.MAX_VALUE labels supported"); -// final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); + // final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); final LongHashMapUnionFind uf = findOverlappingLabelsThresholdMedianEdgeAffinities( sc, n5out, @@ -473,10 +469,9 @@ public static void run( : (maxId); LOG.debug("Max label = {} min label = {} for block ({} {})", maxLabel, minLabel, idOffset._1()._1(), idOffset._1()._2()); - final long[] keys = new long[(int) (maxLabel - minLabel)]; + final long[] keys = new long[(int)(maxLabel - minLabel)]; final long[] vals = new long[keys.length]; - for (int i = 0; i < keys.length; ++i) { final long k = i + minLabel; final long root = uf.findRoot(k); @@ -496,9 +491,9 @@ public static void run( final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); final TLongLongMap mapping = new TLongLongHashMap(t._2()._1(), t._2()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); relabel(n5out.get(), hasHalo ? String.format(croppedDatasetPattern, merged) : merged, blockMerged, interval, mapping); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); if (hasHalo) throw new UnsupportedOperationException("Halo relabeling not implemented yet!"); }); @@ -513,6 +508,7 @@ private static void relabel( final String target, final Interval interval, final TLongLongMap mapping) throws IOException { + SparkWatersheds.relabel(n5, source, target, interval, (src, tgt) -> { final long val = mapping.get(src.getIntegerLong()); if (val != 0) @@ -526,6 +522,7 @@ private static void relabel( final String target, final Interval interval, final long addIfNotZero) throws IOException { + final CachedMapper mapper = new CachedMapper(addIfNotZero); SparkWatersheds.relabel(n5, source, target, interval, (src, tgt) -> tgt.set(mapper.applyAsLong(src.getIntegerLong()))); } @@ -536,6 +533,7 @@ private static & NativeType> void relabel( final String target, final Interval interval, final BiConsumer idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(source); final CellGrid grid = new CellGrid(attributes.getDimensions(), attributes.getBlockSize()); final RandomAccessibleInterval data = Views.interval(N5Utils.open(n5, source), interval); @@ -561,6 +559,7 @@ private static void relabel( final String dataset, final long[] blockPos, final long addIfNonZero) throws IOException { + relabel(n5, dataset, blockPos, id -> id == 0 ? 
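findOverlappingLabelsThresholdMedianEdgeAffinities above feeds a union-find structure whose roots then drive the final relabeling (the keys/vals arrays built from uf.findRoot). A toy stand-in for that join/findRoot contract; the project itself uses the IntArrayUnionFind and LongHashMapUnionFind classes visible above, not this class:

// Toy union-find illustrating the join/findRoot usage above; not project code.
public class UnionFindSketch {
    private final int[] parent;

    UnionFindSketch(final int size) {
        parent = new int[size];
        for (int i = 0; i < size; ++i)
            parent[i] = i; // every label starts as its own root
    }

    int findRoot(final int id) {
        int root = id;
        while (parent[root] != root)
            root = parent[root];
        parent[id] = root; // path shortening
        return root;
    }

    void join(final int a, final int b) {
        parent[findRoot(a)] = findRoot(b);
    }

    public static void main(String[] args) {
        final UnionFindSketch uf = new UnionFindSketch(5);
        uf.join(1, 3); // labels 1 and 3 overlap across a block border
        uf.join(3, 4);
        System.out.println(uf.findRoot(1) + " " + uf.findRoot(4)); // same root: 4 4
    }
}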
0 : id + addIfNonZero); } @@ -569,14 +568,16 @@ private static void relabel( final String dataset, final long[] blockPos, final LongUnaryOperator idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(dataset); - final LongArrayDataBlock block = ((LongArrayDataBlock) n5.readBlock(dataset, attributes, blockPos)); + final LongArrayDataBlock block = ((LongArrayDataBlock)n5.readBlock(dataset, attributes, blockPos)); final long[] data = block.getData(); for (int i = 0; i < data.length; ++i) { data[i] = idMapping.applyAsLong(data[i]); } n5.writeBlock(dataset, attributes, new LongArrayDataBlock(block.getSize(), data, block.getGridPosition())); } + private static void prepareOutputDatasets( final N5Writer n5, final Map datasets, @@ -591,12 +592,14 @@ private static void prepareOutputDataset( final String dataset, final DatasetAttributes attributes, final Map additionalAttributes) throws IOException { + n5.createDataset(dataset, attributes); for (Map.Entry entry : additionalAttributes.entrySet()) n5.setAttribute(dataset, entry.getKey(), entry.getValue()); } private static Map with(Map map, K key, V value) { + map.put(key, value); return map; } @@ -612,6 +615,7 @@ private static class N5WriterSupplier implements Supplier, Serializabl private final boolean serializeSpecialFloatingPointValues = true; private N5WriterSupplier(final String container, final boolean withPrettyPrinting, final boolean disableHtmlEscaping) { + this.container = container; this.withPrettyPrinting = withPrettyPrinting; this.disableHtmlEscaping = disableHtmlEscaping; @@ -620,43 +624,46 @@ private N5WriterSupplier(final String container, final boolean withPrettyPrintin @Override public N5Writer get() { - try { - return Files.isDirectory(Paths.get(container)) - ? new N5FSWriter(container, createaBuilder()) - : new N5HDF5Writer(container); - } catch (final IOException e) { - throw new RuntimeException(e); - } + return Files.isDirectory(Paths.get(container)) + ? new N5FSWriter(container, createaBuilder()) + : new N5HDF5Writer(container); } private GsonBuilder createaBuilder() { + return serializeSpecialFloatingPointValues(withPrettyPrinting(disableHtmlEscaping(new GsonBuilder()))); } private GsonBuilder serializeSpecialFloatingPointValues(final GsonBuilder builder) { + return with(builder, this.serializeSpecialFloatingPointValues, GsonBuilder::serializeSpecialFloatingPointValues); } private GsonBuilder withPrettyPrinting(final GsonBuilder builder) { + return with(builder, this.withPrettyPrinting, GsonBuilder::setPrettyPrinting); } private GsonBuilder disableHtmlEscaping(final GsonBuilder builder) { + return with(builder, this.disableHtmlEscaping, GsonBuilder::disableHtmlEscaping); } private static GsonBuilder with(final GsonBuilder builder, boolean applyAction, Function action) { + return applyAction ? 
action.apply(builder) : builder; } } private static double[] ones(final int length) { + double[] ones = new double[length]; Arrays.fill(ones, 1.0); return ones; } private static Interval addDimension(final Interval interval, final long m, final long M) { + long[] min = new long[interval.numDimensions() + 1]; long[] max = new long[interval.numDimensions() + 1]; for (int d = 0; d < interval.numDimensions(); ++d) { @@ -669,14 +676,17 @@ private static Interval addDimension(final Interval interval, final long m, fina } private static String toString(final Interval interval) { + return String.format("(%s %s)", Arrays.toString(Intervals.minAsLongArray(interval)), Arrays.toString(Intervals.maxAsLongArray(interval))); } - private static double[] reverted(final double[] array, final boolean revert) { - return revert ? reverted(array) : array; + private static double[] reversed(final double[] array, final boolean reverse) { + + return reverse ? reversed(array) : array; } - private static double[] reverted(final double[] array) { + private static double[] reversed(final double[] array) { + final double[] copy = new double[array.length]; for (int i = 0, k = copy.length - 1; i < copy.length; ++i, --k) { copy[i] = array[k]; @@ -689,6 +699,7 @@ private static > ArrayImg smooth( final Interval interval, final int channelDim, double sigma) { + final ArrayImg img = ArrayImgs.floats(Intervals.dimensionsAsLongArray(interval)); for (long channel = interval.min(channelDim); channel <= interval.max(channelDim); ++channel) { @@ -705,6 +716,7 @@ private static > void invalidateOutOfBlockAffinities( final T invalid, final long[]... offsets ) { + for (int index = 0; index < offsets.length; ++index) { final IntervalView slice = Views.hyperSlice(affs, affs.numDimensions() - 1, index); for (int d = 0; d < offsets[index].length; ++d) { @@ -755,7 +767,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg final RandomAccessibleInterval thisBlockLabels = Views.interval(labels, minMax._1(), thisBlockMax); final RandomAccessibleInterval affinities = N5Utils.open(affinitiesContainer.get(), affinitiesDataset); - final TLongSet ignoreTheseSet = new TLongHashSet(ignoreThese); final TLongLongHashMap mapping = new TLongLongHashMap(); @@ -813,8 +824,7 @@ private static UF findOverlappingLabelsThresholdMedianEdg if (thisLabel < thatLabel) { e1 = thisLabel; e2 = thatLabel; - } - else { + } else { e1 = thatLabel; e2 = thisLabel; } @@ -835,7 +845,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg LOG.info("Edge affinities: {}", affinitiesByEdge); - affinitiesByEdge.forEachEntry((k, v) -> { TLongObjectIterator edgeIt = v.iterator(); while (edgeIt.hasNext()) { @@ -878,7 +887,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg return uf; - } private static UF findOverlappingLabelsArgMaxNoHalo( @@ -964,13 +972,12 @@ private static UF findOverlappingLabelsArgMaxNoHalo( addOne(thisMap.get(thisLabel), thatLabel); addOne(thatMap.get(thatLabel), thisLabel); -// thatArgMax.forEachEntry((k, v) -> { -//// if (thatArgMax.get(v) == k) -// localUF.join(localUF.findRoot(v), localUF.findRoot(k)); -//// mapping.put(k, v); -// return true; -// }); - + // thatArgMax.forEachEntry((k, v) -> { + //// if (thatArgMax.get(v) == k) + // localUF.join(localUF.findRoot(v), localUF.findRoot(k)); + //// mapping.put(k, v); + // return true; + // }); } @@ -1013,7 +1020,6 @@ private static UF findOverlappingLabelsArgMaxNoHalo( return uf; - } private static class CropAffinities implements PairFunction> { @@ -1031,6 +1037,7 @@ private 
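/* Note: CropAffinities reads the affinity block expanded by `halo`
 * (Intervals.expand(interval, halo)), giving each task some context beyond its block border
 * before watersheds run; that the halo voxels are discarded again afterwards is an inference
 * from the croppedDatasetPattern handling elsewhere in this patch, not stated here.
 */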
CropAffinities( final String affinities, final double minimumAffinity, final long[] halo) { + this.n5in = n5in; this.affinities = affinities; this.minmimumAffinity = minimumAffinity; @@ -1039,6 +1046,7 @@ private CropAffinities( @Override public Tuple2> call(final Interval interval) throws Exception { + final RandomAccessibleInterval affs = N5Utils.open(n5in.get(), affinities); final Interval withHalo = Intervals.expand(interval, halo); @@ -1059,14 +1067,17 @@ public Tuple2> call(final Interval } private static T toMinMaxTuple(final Interval interval, BiFunction toTuple) { + return toTuple.apply(Intervals.minAsLongArray(interval), Intervals.maxAsLongArray(interval)); } private static void addOne(final TLongIntMap countMap, final long label) { + countMap.put(label, countMap.get(label) + 1); } private static TLongLongMap argMaxCounts(final TLongObjectMap counts) { + final TLongLongMap mapping = new TLongLongHashMap(); counts.forEachEntry((k, v) -> { mapping.put(k, argMaxCount(v)); @@ -1076,6 +1087,7 @@ private static TLongLongMap argMaxCounts(final TLongObjectMap count } private static long argMaxCount(final TLongIntMap counts) { + long maxCount = Long.MIN_VALUE; long argMaxCount = 0; for (final TLongIntIterator it = counts.iterator(); it.hasNext(); ) { @@ -1085,18 +1097,19 @@ private static long argMaxCount(final TLongIntMap counts) { maxCount = v; argMaxCount = it.key(); } - }; + } + ; return argMaxCount; } private static class CachedMapper implements LongUnaryOperator { - private long nextId; private final TLongLongMap cache = new TLongLongHashMap(); private CachedMapper(final long firstId) { + this.nextId = firstId; } @@ -1114,6 +1127,7 @@ public long applyAsLong(long l) { } private static T computeIfAbsent(final TLongObjectMap map, final long key, final LongFunction mappingFactory) { + final T value = map.get(key); if (value != null) return value; diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransform.java b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransform.java index 1f79751..361afd5 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransform.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransform.java @@ -148,7 +148,7 @@ private static class Args implements Serializable, Callable { int[] blockSize = {64, 64, 64}; @Expose - @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split=",") + @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split = ",") int[] blocksPerTask = {1, 1, 1}; @Expose @@ -168,8 +168,8 @@ private static class Args implements Serializable, Callable { Boolean relabel; @Expose - @CommandLine.Option(names = "--revert-array-attributes", paramLabel = "RELABEL", description = "Revert all array attributes (that are not dataset attributes)", defaultValue = "false") - Boolean revertArrayAttributes; + @CommandLine.Option(names = "--reverse-array-attributes", paramLabel = "RELABEL", description = "Reverse all array attributes (that are not dataset attributes)", defaultValue = "false") + Boolean reverseArrayAttributes; @CommandLine.Option(names = "--json-pretty-print", defaultValue = "true") transient Boolean 
prettyPrint; @@ -177,7 +177,7 @@ private static class Args implements Serializable, Callable { @CommandLine.Option(names = "--json-disable-html-escape", defaultValue = "true") transient Boolean disbaleHtmlEscape; - @CommandLine.Option(names = { "-h", "--help"}, usageHelp = true, description = "Display this help and exit") + @CommandLine.Option(names = {"-h", "--help"}, usageHelp = true, description = "Display this help and exit") private Boolean help; @Override @@ -219,12 +219,11 @@ public static void run(final String... argv) throws IOException { labelUtilitiesSparkAttributes.put(VERSION_KEY, Version.VERSION_STRING); final Map attributes = with(new HashMap<>(), LABEL_UTILITIES_SPARK_KEY, labelUtilitiesSparkAttributes); - String[] uint64Datasets = {args.merged, args.seededWatersheds, args.watershedSeeds, args.blockMerged}; String[] float64Datasets = {args.distanceTransform}; - final double[] resolution = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.revertArrayAttributes); - final double[] offset = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.revertArrayAttributes); + final double[] resolution = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.reverseArrayAttributes); + final double[] offset = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.reverseArrayAttributes); attributes.put(RESOLUTION_KEY, resolution); attributes.put(OFFSET_KEY, offset); @@ -250,7 +249,7 @@ public static void run(final String... argv) throws IOException { DoubleStream.of(resolution).map(d -> d * d).toArray(), // TODO maybe pass these as parameters through CLI instead? args.affinityThresholdDistanceTransform, args.seedDistance, -// new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), + // new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), new SerializableMergeWatershedsMedianThresholdSupplier(args.mergeFragmentThreshold), args.mergeFragmentThreshold, args.averagedAffinities, @@ -292,11 +291,11 @@ public static void run( .stream() .map(i -> new Tuple2<>(Intervals.minAsLongArray(i), Intervals.maxAsLongArray(i))) .collect(Collectors.toList()); - ; + ; final List, Integer>> idCounts = sc .parallelize(watershedBlocks) - .map(t -> (Interval) new FinalInterval(t._1(), t._2())) + .map(t -> (Interval)new FinalInterval(t._1(), t._2())) .mapToPair(new CropAffinitiesToDistanceTransform(n5in, averagedAffinities, threshold, distanceTransformWeights)) .mapToPair(t -> { final Interval block = t._1(); @@ -376,8 +375,7 @@ public static void run( return new Tuple2<>(new Tuple2<>(Intervals.minAsLongArray(t._1()), Intervals.maxAsLongArray(t._1())), ids.size()); }) - .collect() - ; + .collect(); long startIndex = 1; final List, Long>> idOffsets = new ArrayList<>(); @@ -393,9 +391,9 @@ public static void run( LOG.debug("Relabeling block ({} {}) starting at id {}", t._1()._1(), t._1()._2(), t._2()); final N5Writer n5 = n5out.get(); final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); + // relabel(n5, hasHalo ? 
String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); relabel(n5, merged, merged, interval, t._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); // TODO do halo relabeling @@ -409,7 +407,7 @@ public static void run( n5out.get().setAttribute(merged, "maxId", maxId); n5out.get().setAttribute(seededWatersheds, "maxId", maxId); -// final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); + // final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); final LongHashMapUnionFind uf = findOverlappingLabelsThresholdMedianEdgeAffinities( sc, n5out, @@ -434,10 +432,9 @@ public static void run( : (maxId); LOG.debug("Max label = {} min label = {} for block ({} {})", maxLabel, minLabel, idOffset._1()._1(), idOffset._1()._2()); - final long[] keys = new long[(int) (maxLabel - minLabel)]; + final long[] keys = new long[(int)(maxLabel - minLabel)]; final long[] vals = new long[keys.length]; - for (int i = 0; i < keys.length; ++i) { final long k = i + minLabel; final long root = uf.findRoot(k); @@ -457,9 +454,9 @@ public static void run( final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); final TLongLongMap mapping = new TLongLongHashMap(t._2()._1(), t._2()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); - relabel(n5out.get(), merged, blockMerged, interval, mapping); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); + relabel(n5out.get(), merged, blockMerged, interval, mapping); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); }); } @@ -472,6 +469,7 @@ private static void relabel( final String target, final Interval interval, final TLongLongMap mapping) throws IOException { + SparkWatershedsOnDistanceTransform.relabel(n5, source, target, interval, (src, tgt) -> { final long val = mapping.get(src.getIntegerLong()); if (val != 0) @@ -485,6 +483,7 @@ private static void relabel( final String target, final Interval interval, final long addIfNotZero) throws IOException { + final CachedMapper mapper = new CachedMapper(addIfNotZero); SparkWatershedsOnDistanceTransform.relabel(n5, source, target, interval, (src, tgt) -> tgt.set(mapper.applyAsLong(src.getIntegerLong()))); } @@ -495,6 +494,7 @@ private static & NativeType> void relabel( final String target, final Interval interval, final BiConsumer idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(source); final CellGrid grid = new CellGrid(attributes.getDimensions(), attributes.getBlockSize()); final RandomAccessibleInterval data = Views.interval(N5Utils.open(n5, source), interval); @@ -520,6 +520,7 @@ private static void relabel( final String dataset, final long[] blockPos, final long addIfNonZero) throws IOException { + relabel(n5, dataset, blockPos, id -> id == 0 ? 
0 : id + addIfNonZero); } @@ -528,14 +529,16 @@ private static void relabel( final String dataset, final long[] blockPos, final LongUnaryOperator idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(dataset); - final LongArrayDataBlock block = ((LongArrayDataBlock) n5.readBlock(dataset, attributes, blockPos)); + final LongArrayDataBlock block = ((LongArrayDataBlock)n5.readBlock(dataset, attributes, blockPos)); final long[] data = block.getData(); for (int i = 0; i < data.length; ++i) { data[i] = idMapping.applyAsLong(data[i]); } n5.writeBlock(dataset, attributes, new LongArrayDataBlock(block.getSize(), data, block.getGridPosition())); } + private static void prepareOutputDatasets( final N5Writer n5, final Map datasets, @@ -550,12 +553,14 @@ private static void prepareOutputDataset( final String dataset, final DatasetAttributes attributes, final Map additionalAttributes) throws IOException { + n5.createDataset(dataset, attributes); for (Map.Entry entry : additionalAttributes.entrySet()) n5.setAttribute(dataset, entry.getKey(), entry.getValue()); } private static Map with(Map map, K key, V value) { + map.put(key, value); return map; } @@ -571,6 +576,7 @@ private static class N5WriterSupplier implements Supplier, Serializabl private final boolean serializeSpecialFloatingPointValues = true; private N5WriterSupplier(final String container, final boolean withPrettyPrinting, final boolean disableHtmlEscaping) { + this.container = container; this.withPrettyPrinting = withPrettyPrinting; this.disableHtmlEscaping = disableHtmlEscaping; @@ -579,43 +585,46 @@ private N5WriterSupplier(final String container, final boolean withPrettyPrintin @Override public N5Writer get() { - try { - return Files.isDirectory(Paths.get(container)) - ? new N5FSWriter(container, createaBuilder()) - : new N5HDF5Writer(container); - } catch (final IOException e) { - throw new RuntimeException(e); - } + return Files.isDirectory(Paths.get(container)) + ? new N5FSWriter(container, createaBuilder()) + : new N5HDF5Writer(container); } private GsonBuilder createaBuilder() { + return serializeSpecialFloatingPointValues(withPrettyPrinting(disableHtmlEscaping(new GsonBuilder()))); } private GsonBuilder serializeSpecialFloatingPointValues(final GsonBuilder builder) { + return with(builder, this.serializeSpecialFloatingPointValues, GsonBuilder::serializeSpecialFloatingPointValues); } private GsonBuilder withPrettyPrinting(final GsonBuilder builder) { + return with(builder, this.withPrettyPrinting, GsonBuilder::setPrettyPrinting); } private GsonBuilder disableHtmlEscaping(final GsonBuilder builder) { + return with(builder, this.disableHtmlEscaping, GsonBuilder::disableHtmlEscaping); } private static GsonBuilder with(final GsonBuilder builder, boolean applyAction, Function action) { + return applyAction ? 
action.apply(builder) : builder; } } private static double[] ones(final int length) { + double[] ones = new double[length]; Arrays.fill(ones, 1.0); return ones; } private static Interval addDimension(final Interval interval, final long m, final long M) { + long[] min = new long[interval.numDimensions() + 1]; long[] max = new long[interval.numDimensions() + 1]; for (int d = 0; d < interval.numDimensions(); ++d) { @@ -628,14 +637,17 @@ private static Interval addDimension(final Interval interval, final long m, fina } private static String toString(final Interval interval) { + return String.format("(%s %s)", Arrays.toString(Intervals.minAsLongArray(interval)), Arrays.toString(Intervals.maxAsLongArray(interval))); } - private static double[] reverted(final double[] array, final boolean revert) { - return revert ? reverted(array) : array; + private static double[] reversed(final double[] array, final boolean reverse) { + + return reverse ? reversed(array) : array; } - private static double[] reverted(final double[] array) { + private static double[] reversed(final double[] array) { + final double[] copy = new double[array.length]; for (int i = 0, k = copy.length - 1; i < copy.length; ++i, --k) { copy[i] = array[k]; @@ -648,6 +660,7 @@ private static > ArrayImg smooth( final Interval interval, final int channelDim, double sigma) { + final ArrayImg img = ArrayImgs.floats(Intervals.dimensionsAsLongArray(interval)); for (long channel = interval.min(channelDim); channel <= interval.max(channelDim); ++channel) { @@ -664,6 +677,7 @@ private static > void invalidateOutOfBlockAffinities( final T invalid, final long[]... offsets ) { + for (int index = 0; index < offsets.length; ++index) { final IntervalView slice = Views.hyperSlice(affs, affs.numDimensions() - 1, index); for (int d = 0; d < offsets[index].length; ++d) { @@ -714,7 +728,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg final RandomAccessibleInterval thisBlockLabels = Views.interval(labels, minMax._1(), thisBlockMax); final RandomAccessibleInterval affinities = N5Utils.open(affinitiesContainer.get(), affinitiesDataset); - final TLongSet ignoreTheseSet = new TLongHashSet(ignoreThese); final TLongLongHashMap mapping = new TLongLongHashMap(); @@ -773,8 +786,7 @@ private static UF findOverlappingLabelsThresholdMedianEdg if (thisLabel < thatLabel) { e1 = thisLabel; e2 = thatLabel; - } - else { + } else { e1 = thatLabel; e2 = thisLabel; } @@ -795,7 +807,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg LOG.info("Edge affinities: {}", affinitiesByEdge); - affinitiesByEdge.forEachEntry((k, v) -> { TLongObjectIterator edgeIt = v.iterator(); while (edgeIt.hasNext()) { @@ -838,7 +849,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg return uf; - } private static UF findOverlappingLabelsArgMaxNoHalo( @@ -925,13 +935,12 @@ private static UF findOverlappingLabelsArgMaxNoHalo( addOne(thisMap.get(thisLabel), thatLabel); addOne(thatMap.get(thatLabel), thisLabel); -// thatArgMax.forEachEntry((k, v) -> { -//// if (thatArgMax.get(v) == k) -// localUF.join(localUF.findRoot(v), localUF.findRoot(k)); -//// mapping.put(k, v); -// return true; -// }); - + // thatArgMax.forEachEntry((k, v) -> { + //// if (thatArgMax.get(v) == k) + // localUF.join(localUF.findRoot(v), localUF.findRoot(k)); + //// mapping.put(k, v); + // return true; + // }); } @@ -974,7 +983,6 @@ private static UF findOverlappingLabelsArgMaxNoHalo( return uf; - } private static class CropAffinitiesToDistanceTransform implements PairFunction, 
RandomAccessibleInterval>> { @@ -992,6 +1000,7 @@ private CropAffinitiesToDistanceTransform( final String affinities, final double threshold, final double[] weights) { + this.n5in = n5in; this.affinities = affinities; this.threshold = threshold; @@ -1000,6 +1009,7 @@ private CropAffinitiesToDistanceTransform( @Override public Tuple2, RandomAccessibleInterval>> call(final Interval interval) throws Exception { + final RandomAccessibleInterval affsImg = N5Utils.open(n5in.get(), affinities); final RandomAccessible affs = Views.extendValue(affsImg, new FloatType(Float.NaN)); final long[] min = Intervals.minAsLongArray(interval); @@ -1017,7 +1027,7 @@ public Tuple2, RandomAccess DistanceTransform.DISTANCE_TYPE.EUCLIDIAN, weights); // TODO should we actually rewrite those? - double[] minMax = new double[] {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}; + double[] minMax = new double[]{Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}; LoopBuilder .setImages(Views.interval(affs, withContext), distanceTransform) .forEachPixel((a, d) -> { @@ -1038,14 +1048,17 @@ public Tuple2, RandomAccess } private static T toMinMaxTuple(final Interval interval, BiFunction toTuple) { + return toTuple.apply(Intervals.minAsLongArray(interval), Intervals.maxAsLongArray(interval)); } private static void addOne(final TLongIntMap countMap, final long label) { + countMap.put(label, countMap.get(label) + 1); } private static TLongLongMap argMaxCounts(final TLongObjectMap counts) { + final TLongLongMap mapping = new TLongLongHashMap(); counts.forEachEntry((k, v) -> { mapping.put(k, argMaxCount(v)); @@ -1055,6 +1068,7 @@ private static TLongLongMap argMaxCounts(final TLongObjectMap count } private static long argMaxCount(final TLongIntMap counts) { + long maxCount = Long.MIN_VALUE; long argMaxCount = 0; for (final TLongIntIterator it = counts.iterator(); it.hasNext(); ) { @@ -1064,18 +1078,19 @@ private static long argMaxCount(final TLongIntMap counts) { maxCount = v; argMaxCount = it.key(); } - }; + } + ; return argMaxCount; } private static class CachedMapper implements LongUnaryOperator { - private long nextId; private final TLongLongMap cache = new TLongLongHashMap(); private CachedMapper(final long firstId) { + this.nextId = firstId; } @@ -1093,6 +1108,7 @@ public long applyAsLong(long l) { } private static T computeIfAbsent(final TLongObjectMap map, final long key, final LongFunction mappingFactory) { + final T value = map.get(key); if (value != null) return value; diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunction.java b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunction.java index cc29e65..1aa0302 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunction.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunction.java @@ -150,7 +150,7 @@ private static class Args implements Serializable, Callable { int[] blockSize = {64, 64, 64}; @Expose - @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split=",") + @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split = ",") int[] 
blocksPerTask = {1, 1, 1}; @Expose @@ -170,8 +170,8 @@ private static class Args implements Serializable, Callable { Boolean relabel; @Expose - @CommandLine.Option(names = "--revert-array-attributes", paramLabel = "RELABEL", description = "Revert all array attributes (that are not dataset attributes)", defaultValue = "false") - Boolean revertArrayAttributes; + @CommandLine.Option(names = "--reverse-array-attributes", paramLabel = "RELABEL", description = "Reverse all array attributes (that are not dataset attributes)", defaultValue = "false") + Boolean reverseArrayAttributes; @CommandLine.Option(names = "--json-pretty-print", defaultValue = "true") transient Boolean prettyPrint; @@ -179,7 +179,7 @@ private static class Args implements Serializable, Callable { @CommandLine.Option(names = "--json-disable-html-escape", defaultValue = "true") transient Boolean disbaleHtmlEscape; - @CommandLine.Option(names = { "-h", "--help"}, usageHelp = true, description = "Display this help and exit") + @CommandLine.Option(names = {"-h", "--help"}, usageHelp = true, description = "Display this help and exit") private Boolean help; @Override @@ -221,12 +221,11 @@ public static void run(final String... argv) throws IOException { labelUtilitiesSparkAttributes.put(VERSION_KEY, Version.VERSION_STRING); final Map attributes = with(new HashMap<>(), LABEL_UTILITIES_SPARK_KEY, labelUtilitiesSparkAttributes); - String[] uint64Datasets = {args.merged, args.seededWatersheds, args.watershedSeeds, args.blockMerged}; String[] float64Datasets = {args.distanceTransform}; - final double[] resolution = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.revertArrayAttributes); - final double[] offset = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.revertArrayAttributes); + final double[] resolution = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.reverseArrayAttributes); + final double[] offset = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.reverseArrayAttributes); attributes.put(RESOLUTION_KEY, resolution); attributes.put(OFFSET_KEY, offset); @@ -259,7 +258,7 @@ public static void run(final String... argv) throws IOException { outputDims, DoubleStream.of(resolution).map(d -> d / DoubleStream.of(resolution).min().getAsDouble()).map(d -> args.weightDistanceTransform * d * d).toArray(), // TODO maybe pass these as parameters through CLI instead? 
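/* Note: the weights array built just above normalizes each voxel dimension by the smallest
 * resolution component, squares it, and scales by args.weightDistanceTransform, so
 * anisotropic voxels contribute proportionally to the squared Euclidean distance transform.
 * This reading is inferred from the expression itself; the patch does not spell it out.
 */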
args.seedDistance, -// new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), + // new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), new SerializableMergeWatershedsMedianThresholdSupplier(args.mergeFragmentThreshold), args.mergeFragmentThreshold, args.averagedAffinities, @@ -301,11 +300,11 @@ public static void run( .stream() .map(i -> new Tuple2<>(Intervals.minAsLongArray(i), Intervals.maxAsLongArray(i))) .collect(Collectors.toList()); - ; + ; final List, Integer>> idCounts = sc .parallelize(watershedBlocks) - .map(t -> (Interval) new FinalInterval(t._1(), t._2())) + .map(t -> (Interval)new FinalInterval(t._1(), t._2())) .mapToPair(new CropAffinitiesToDistanceTransform(n5in, averagedAffinities, distanceTransformWeights)) .mapToPair(t -> { final Interval block = t._1(); @@ -383,18 +382,17 @@ public static void run( N5Utils.saveBlock(labels, n5out.get(), merged, attributes.apply(DataType.UINT64), blockOffset); final boolean wasSuccessful = true; -// final IntervalView reloaded = Views.interval(N5Utils.open(n5out.get(), merged), block); -// final Cursor r = Views.flatIterable(reloaded).cursor(); -// final Cursor m = Views.flatIterable(labels).cursor(); -// boolean wasSuccessful = true; -// while(r.hasNext() && wasSuccessful) { -// wasSuccessful = r.next().valueEquals(m.next()); -// } + // final IntervalView reloaded = Views.interval(N5Utils.open(n5out.get(), merged), block); + // final Cursor r = Views.flatIterable(reloaded).cursor(); + // final Cursor m = Views.flatIterable(labels).cursor(); + // boolean wasSuccessful = true; + // while(r.hasNext() && wasSuccessful) { + // wasSuccessful = r.next().valueEquals(m.next()); + // } return new Tuple2<>(new Tuple2<>(Intervals.minAsLongArray(t._1()), Intervals.maxAsLongArray(t._1())), wasSuccessful ? ids.size() : -1); }) - .collect() - ; + .collect(); if (idCounts.stream().mapToInt(Tuple2::_2).anyMatch(c -> c < 0)) { // TODO log failed blocks. Right now, just throw exception @@ -415,9 +413,9 @@ public static void run( LOG.debug("Relabeling block ({} {}) starting at id {}", t._1()._1(), t._1()._2(), t._2()); final N5Writer n5 = n5out.get(); final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); relabel(n5, merged, merged, interval, t._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); + // relabel(n5, hasHalo ? 
String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); // TODO do halo relabeling @@ -431,7 +429,7 @@ public static void run( n5out.get().setAttribute(merged, "maxId", maxId); n5out.get().setAttribute(seededWatersheds, "maxId", maxId); -// final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); + // final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); final LongHashMapUnionFind uf = findOverlappingLabelsThresholdMedianEdgeAffinities( sc, n5out, @@ -456,10 +454,9 @@ public static void run( : (maxId); LOG.debug("Max label = {} min label = {} for block ({} {})", maxLabel, minLabel, idOffset._1()._1(), idOffset._1()._2()); - final long[] keys = new long[(int) (maxLabel - minLabel)]; + final long[] keys = new long[(int)(maxLabel - minLabel)]; final long[] vals = new long[keys.length]; - for (int i = 0; i < keys.length; ++i) { final long k = i + minLabel; final long root = uf.findRoot(k); @@ -479,9 +476,9 @@ public static void run( final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); final TLongLongMap mapping = new TLongLongHashMap(t._2()._1(), t._2()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); return relabel(n5out.get(), merged, blockMerged, interval, mapping); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); + // relabel(n5, hasHalo ? 
String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); }) .collect(); if (returnCodes.stream().anyMatch(r -> !r)) @@ -497,6 +494,7 @@ private static boolean relabel( final String target, final Interval interval, final TLongLongMap mapping) throws IOException { + return SparkWatershedsOnDistanceTransformOfSampledFunction.relabel(n5, source, target, interval, (src, tgt) -> { final long val = mapping.get(src.getIntegerLong()); if (val != 0) @@ -510,6 +508,7 @@ private static void relabel( final String target, final Interval interval, final long addIfNotZero) throws IOException { + final CachedMapper mapper = new CachedMapper(addIfNotZero); SparkWatershedsOnDistanceTransformOfSampledFunction.relabel(n5, source, target, interval, (src, tgt) -> tgt.set(mapper.applyAsLong(src.getIntegerLong()))); } @@ -520,6 +519,7 @@ private static & NativeType> boolean relabel( final String target, final Interval interval, final BiConsumer idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(source); final CellGrid grid = new CellGrid(attributes.getDimensions(), attributes.getBlockSize()); final RandomAccessibleInterval data = Views.interval(N5Utils.open(n5, source), interval); @@ -538,11 +538,11 @@ private static & NativeType> boolean relabel( final long[] blockPos = Intervals.minAsLongArray(interval); grid.getCellPosition(blockPos, blockPos); N5Utils.saveBlock(copy, n5, target, attributes, blockPos); -// final Cursor reloaded = Views.flatIterable(Views.interval(N5Utils.open(n5, target), interval)).cursor(); -// final Cursor c = Views.flatIterable(copy).cursor(); -// while (c.hasNext()) -// if (c.next().getIntegerLong() != reloaded.next().getIntegerLong()) -// return false; + // final Cursor reloaded = Views.flatIterable(Views.interval(N5Utils.open(n5, target), interval)).cursor(); + // final Cursor c = Views.flatIterable(copy).cursor(); + // while (c.hasNext()) + // if (c.next().getIntegerLong() != reloaded.next().getIntegerLong()) + // return false; return true; } @@ -551,6 +551,7 @@ private static void relabel( final String dataset, final long[] blockPos, final long addIfNonZero) throws IOException { + relabel(n5, dataset, blockPos, id -> id == 0 ? 
0 : id + addIfNonZero); } @@ -559,14 +560,16 @@ private static void relabel( final String dataset, final long[] blockPos, final LongUnaryOperator idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(dataset); - final LongArrayDataBlock block = ((LongArrayDataBlock) n5.readBlock(dataset, attributes, blockPos)); + final LongArrayDataBlock block = ((LongArrayDataBlock)n5.readBlock(dataset, attributes, blockPos)); final long[] data = block.getData(); for (int i = 0; i < data.length; ++i) { data[i] = idMapping.applyAsLong(data[i]); } n5.writeBlock(dataset, attributes, new LongArrayDataBlock(block.getSize(), data, block.getGridPosition())); } + private static void prepareOutputDatasets( final N5Writer n5, final Map datasets, @@ -581,12 +584,14 @@ private static void prepareOutputDataset( final String dataset, final DatasetAttributes attributes, final Map additionalAttributes) throws IOException { + n5.createDataset(dataset, attributes); for (Map.Entry entry : additionalAttributes.entrySet()) n5.setAttribute(dataset, entry.getKey(), entry.getValue()); } private static Map with(Map map, K key, V value) { + map.put(key, value); return map; } @@ -602,6 +607,7 @@ private static class N5WriterSupplier implements Supplier, Serializabl private final boolean serializeSpecialFloatingPointValues = true; private N5WriterSupplier(final String container, final boolean withPrettyPrinting, final boolean disableHtmlEscaping) { + this.container = container; this.withPrettyPrinting = withPrettyPrinting; this.disableHtmlEscaping = disableHtmlEscaping; @@ -610,43 +616,46 @@ private N5WriterSupplier(final String container, final boolean withPrettyPrintin @Override public N5Writer get() { - try { - return Files.isDirectory(Paths.get(container)) - ? new N5FSWriter(container, createaBuilder()) - : new N5HDF5Writer(container); - } catch (final IOException e) { - throw new RuntimeException(e); - } + return Files.isDirectory(Paths.get(container)) + ? new N5FSWriter(container, createaBuilder()) + : new N5HDF5Writer(container); } private GsonBuilder createaBuilder() { + return serializeSpecialFloatingPointValues(withPrettyPrinting(disableHtmlEscaping(new GsonBuilder()))); } private GsonBuilder serializeSpecialFloatingPointValues(final GsonBuilder builder) { + return with(builder, this.serializeSpecialFloatingPointValues, GsonBuilder::serializeSpecialFloatingPointValues); } private GsonBuilder withPrettyPrinting(final GsonBuilder builder) { + return with(builder, this.withPrettyPrinting, GsonBuilder::setPrettyPrinting); } private GsonBuilder disableHtmlEscaping(final GsonBuilder builder) { + return with(builder, this.disableHtmlEscaping, GsonBuilder::disableHtmlEscaping); } private static GsonBuilder with(final GsonBuilder builder, boolean applyAction, Function action) { + return applyAction ? 
action.apply(builder) : builder; } } private static double[] ones(final int length) { + double[] ones = new double[length]; Arrays.fill(ones, 1.0); return ones; } private static Interval addDimension(final Interval interval, final long m, final long M) { + long[] min = new long[interval.numDimensions() + 1]; long[] max = new long[interval.numDimensions() + 1]; for (int d = 0; d < interval.numDimensions(); ++d) { @@ -659,14 +668,17 @@ private static Interval addDimension(final Interval interval, final long m, fina } private static String toString(final Interval interval) { + return String.format("(%s %s)", Arrays.toString(Intervals.minAsLongArray(interval)), Arrays.toString(Intervals.maxAsLongArray(interval))); } - private static double[] reverted(final double[] array, final boolean revert) { - return revert ? reverted(array) : array; + private static double[] reversed(final double[] array, final boolean reverse) { + + return reverse ? reversed(array) : array; } - private static double[] reverted(final double[] array) { + private static double[] reversed(final double[] array) { + final double[] copy = new double[array.length]; for (int i = 0, k = copy.length - 1; i < copy.length; ++i, --k) { copy[i] = array[k]; @@ -679,6 +691,7 @@ private static > ArrayImg smooth( final Interval interval, final int channelDim, double sigma) { + final ArrayImg img = ArrayImgs.floats(Intervals.dimensionsAsLongArray(interval)); for (long channel = interval.min(channelDim); channel <= interval.max(channelDim); ++channel) { @@ -695,6 +708,7 @@ private static > void invalidateOutOfBlockAffinities( final T invalid, final long[]... offsets ) { + for (int index = 0; index < offsets.length; ++index) { final IntervalView slice = Views.hyperSlice(affs, affs.numDimensions() - 1, index); for (int d = 0; d < offsets[index].length; ++d) { @@ -745,7 +759,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg final RandomAccessibleInterval thisBlockLabels = Views.interval(labels, minMax._1(), thisBlockMax); final RandomAccessibleInterval affinities = N5Utils.open(affinitiesContainer.get(), affinitiesDataset); - final TLongSet ignoreTheseSet = new TLongHashSet(ignoreThese); final TLongLongHashMap mapping = new TLongLongHashMap(); @@ -804,8 +817,7 @@ private static UF findOverlappingLabelsThresholdMedianEdg if (thisLabel < thatLabel) { e1 = thisLabel; e2 = thatLabel; - } - else { + } else { e1 = thatLabel; e2 = thisLabel; } @@ -826,7 +838,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg LOG.info("Edge affinities: {}", affinitiesByEdge); - affinitiesByEdge.forEachEntry((k, v) -> { TLongObjectIterator edgeIt = v.iterator(); while (edgeIt.hasNext()) { @@ -869,7 +880,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg return uf; - } private static UF findOverlappingLabelsArgMaxNoHalo( @@ -956,13 +966,12 @@ private static UF findOverlappingLabelsArgMaxNoHalo( addOne(thisMap.get(thisLabel), thatLabel); addOne(thatMap.get(thatLabel), thisLabel); -// thatArgMax.forEachEntry((k, v) -> { -//// if (thatArgMax.get(v) == k) -// localUF.join(localUF.findRoot(v), localUF.findRoot(k)); -//// mapping.put(k, v); -// return true; -// }); - + // thatArgMax.forEachEntry((k, v) -> { + //// if (thatArgMax.get(v) == k) + // localUF.join(localUF.findRoot(v), localUF.findRoot(k)); + //// mapping.put(k, v); + // return true; + // }); } @@ -1005,7 +1014,6 @@ private static UF findOverlappingLabelsArgMaxNoHalo( return uf; - } private static class CropAffinitiesToDistanceTransform implements PairFunction, 
RandomAccessibleInterval>> { @@ -1020,6 +1028,7 @@ private CropAffinitiesToDistanceTransform( final Supplier n5in, final String affinities, final double[] weights) { + this.n5in = n5in; this.affinities = affinities; this.weights = weights; @@ -1030,11 +1039,13 @@ private static class ReplaceNaNWith> implements Converter< private final double replacement; private ReplaceNaNWith(double replacement) { + this.replacement = replacement; } @Override public void convert(T src, T tgt) { + final double t = src.getRealDouble(); tgt.setReal(Double.isNaN(t) ? replacement : t); } @@ -1042,6 +1053,7 @@ public void convert(T src, T tgt) { @Override public Tuple2, RandomAccessibleInterval>> call(final Interval interval) throws Exception { + final RandomAccessibleInterval affsImg = N5Utils.open(n5in.get(), affinities); final RandomAccessible affs = Converters.convert(Views.extendValue(affsImg, new FloatType(0.0f)), new ReplaceNaNWith<>(0.0), new FloatType()); @@ -1060,7 +1072,7 @@ public Tuple2, RandomAccess weights ); // TODO should we actually rewrite those? - double[] minMax = new double[] {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}; + double[] minMax = new double[]{Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}; LoopBuilder .setImages(Views.interval(affs, withContext), distanceTransform) .forEachPixel((a, d) -> { @@ -1081,14 +1093,17 @@ public Tuple2, RandomAccess } private static T toMinMaxTuple(final Interval interval, BiFunction toTuple) { + return toTuple.apply(Intervals.minAsLongArray(interval), Intervals.maxAsLongArray(interval)); } private static void addOne(final TLongIntMap countMap, final long label) { + countMap.put(label, countMap.get(label) + 1); } private static TLongLongMap argMaxCounts(final TLongObjectMap counts) { + final TLongLongMap mapping = new TLongLongHashMap(); counts.forEachEntry((k, v) -> { mapping.put(k, argMaxCount(v)); @@ -1098,6 +1113,7 @@ private static TLongLongMap argMaxCounts(final TLongObjectMap count } private static long argMaxCount(final TLongIntMap counts) { + long maxCount = Long.MIN_VALUE; long argMaxCount = 0; for (final TLongIntIterator it = counts.iterator(); it.hasNext(); ) { @@ -1107,18 +1123,19 @@ private static long argMaxCount(final TLongIntMap counts) { maxCount = v; argMaxCount = it.key(); } - }; + } + ; return argMaxCount; } private static class CachedMapper implements LongUnaryOperator { - private long nextId; private final TLongLongMap cache = new TLongLongHashMap(); private CachedMapper(final long firstId) { + this.nextId = firstId; } @@ -1136,6 +1153,7 @@ public long applyAsLong(long l) { } private static T computeIfAbsent(final TLongObjectMap map, final long key, final LongFunction mappingFactory) { + final T value = map.get(key); if (value != null) return value; diff --git a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunctionSeedOnlyOnEdge.java b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunctionSeedOnlyOnEdge.java index d1aa63c..860a0e3 100644 --- a/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunctionSeedOnlyOnEdge.java +++ b/src/main/java/org/janelia/saalfeldlab/label/spark/watersheds/SparkWatershedsOnDistanceTransformOfSampledFunctionSeedOnlyOnEdge.java @@ -150,7 +150,7 @@ private static class Args implements Serializable, Callable { int[] blockSize = {64, 64, 64}; @Expose - @CommandLine.Option(names = "--blocks-per-task", paramLabel = 
"BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split=",") + @CommandLine.Option(names = "--blocks-per-task", paramLabel = "BLOCKS_PER_TASK", description = "How many blocks to combine for watersheds/connected components (one value per dimension)", split = ",") int[] blocksPerTask = {1, 1, 1}; @Expose @@ -170,8 +170,8 @@ private static class Args implements Serializable, Callable { Boolean relabel; @Expose - @CommandLine.Option(names = "--revert-array-attributes", paramLabel = "RELABEL", description = "Revert all array attributes (that are not dataset attributes)", defaultValue = "false") - Boolean revertArrayAttributes; + @CommandLine.Option(names = "--reverse-array-attributes", paramLabel = "RELABEL", description = "Reverse all array attributes (that are not dataset attributes)", defaultValue = "false") + Boolean reverseArrayAttributes; @CommandLine.Option(names = "--json-pretty-print", defaultValue = "true") transient Boolean prettyPrint; @@ -179,7 +179,7 @@ private static class Args implements Serializable, Callable { @CommandLine.Option(names = "--json-disable-html-escape", defaultValue = "true") transient Boolean disbaleHtmlEscape; - @CommandLine.Option(names = { "-h", "--help"}, usageHelp = true, description = "Display this help and exit") + @CommandLine.Option(names = {"-h", "--help"}, usageHelp = true, description = "Display this help and exit") private Boolean help; @Override @@ -221,12 +221,11 @@ public static void run(final String... argv) throws IOException { labelUtilitiesSparkAttributes.put(VERSION_KEY, Version.VERSION_STRING); final Map attributes = with(new HashMap<>(), LABEL_UTILITIES_SPARK_KEY, labelUtilitiesSparkAttributes); - String[] uint64Datasets = {args.merged, args.seededWatersheds, args.watershedSeeds, args.blockMerged}; String[] float64Datasets = {args.distanceTransform}; - final double[] resolution = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.revertArrayAttributes); - final double[] offset = reverted(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.revertArrayAttributes); + final double[] resolution = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, RESOLUTION_KEY, double[].class)).orElse(ones(outputDims.length)), args.reverseArrayAttributes); + final double[] offset = reversed(Optional.ofNullable(n5in.get().getAttribute(args.averagedAffinities, OFFSET_KEY, double[].class)).orElse(new double[outputDims.length]), args.reverseArrayAttributes); attributes.put(RESOLUTION_KEY, resolution); attributes.put(OFFSET_KEY, offset); @@ -259,7 +258,7 @@ public static void run(final String... argv) throws IOException { outputDims, DoubleStream.of(resolution).map(d -> d / DoubleStream.of(resolution).min().getAsDouble()).map(d -> args.weightDistanceTransform * d * d).toArray(), // TODO maybe pass these as parameters through CLI instead? 
args.seedDistance, -// new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), + // new SerializableMergeWatershedsMinThresholdSupplier(args.threshold), new SerializableMergeWatershedsMedianThresholdSupplier(args.mergeFragmentThreshold), args.mergeFragmentThreshold, args.averagedAffinities, @@ -304,7 +303,7 @@ public static void run( final List, Integer>> idCounts = sc .parallelize(watershedBlocks) - .map(t -> (Interval) new FinalInterval(t._1(), t._2())) + .map(t -> (Interval)new FinalInterval(t._1(), t._2())) .mapToPair(new CropAffinitiesToDistanceTransform(n5in, averagedAffinities, distanceTransformWeights)) .mapToPair(t -> { final Interval block = t._1(); @@ -409,18 +408,17 @@ public static void run( N5Utils.saveBlock(labels, n5out.get(), merged, attributes.apply(DataType.UINT64), blockOffset); final boolean wasSuccessful = true; -// final IntervalView reloaded = Views.interval(N5Utils.open(n5out.get(), merged), block); -// final Cursor r = Views.flatIterable(reloaded).cursor(); -// final Cursor m = Views.flatIterable(labels).cursor(); -// boolean wasSuccessful = true; -// while(r.hasNext() && wasSuccessful) { -// wasSuccessful = r.next().valueEquals(m.next()); -// } + // final IntervalView reloaded = Views.interval(N5Utils.open(n5out.get(), merged), block); + // final Cursor r = Views.flatIterable(reloaded).cursor(); + // final Cursor m = Views.flatIterable(labels).cursor(); + // boolean wasSuccessful = true; + // while(r.hasNext() && wasSuccessful) { + // wasSuccessful = r.next().valueEquals(m.next()); + // } return new Tuple2<>(new Tuple2<>(Intervals.minAsLongArray(t._1()), Intervals.maxAsLongArray(t._1())), wasSuccessful ? ids.size() : -1); }) - .collect() - ; + .collect(); if (idCounts.stream().mapToInt(Tuple2::_2).anyMatch(c -> c < 0)) { // TODO log failed blocks. Right now, just throw exception @@ -441,9 +439,9 @@ public static void run( LOG.debug("Relabeling block ({} {}) starting at id {}", t._1()._1(), t._1()._2(), t._2()); final N5Writer n5 = n5out.get(); final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, t._2()); relabel(n5, merged, merged, interval, t._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); + // relabel(n5, hasHalo ? 
String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, t._2()); // TODO do halo relabeling @@ -457,7 +455,7 @@ public static void run( n5out.get().setAttribute(merged, "maxId", maxId); n5out.get().setAttribute(seededWatersheds, "maxId", maxId); -// final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); + // final IntArrayUnionFind uf = findOverlappingLabelsArgMaxNoHalo(sc, n5out, merged, new IntArrayUnionFind((int) (maxId + 2)), outputDims, blockSize, blocksPerTask, 0); final LongHashMapUnionFind uf = findOverlappingLabelsThresholdMedianEdgeAffinities( sc, n5out, @@ -482,10 +480,9 @@ public static void run( : (maxId); LOG.debug("Max label = {} min label = {} for block ({} {})", maxLabel, minLabel, idOffset._1()._1(), idOffset._1()._2()); - final long[] keys = new long[(int) (maxLabel - minLabel)]; + final long[] keys = new long[(int)(maxLabel - minLabel)]; final long[] vals = new long[keys.length]; - for (int i = 0; i < keys.length; ++i) { final long k = i + minLabel; final long root = uf.findRoot(k); @@ -505,9 +502,9 @@ public static void run( final Interval interval = new FinalInterval(t._1()._1(), t._1()._2()); final TLongLongMap mapping = new TLongLongHashMap(t._2()._1(), t._2()._2()); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); + // relabel(n5, hasHalo ? String.format(croppedDatasetPattern, watershedSeeds) : watershedSeeds, interval, mapping); return relabel(n5out.get(), merged, blockMerged, interval, mapping); -// relabel(n5, hasHalo ? String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); + // relabel(n5, hasHalo ? 
String.format(croppedDatasetPattern, seededWatersheds) : seededWatersheds, interval, mapping); }) .collect(); if (returnCodes.stream().anyMatch(r -> !r)) @@ -523,6 +520,7 @@ private static boolean relabel( final String target, final Interval interval, final TLongLongMap mapping) throws IOException { + return SparkWatershedsOnDistanceTransformOfSampledFunctionSeedOnlyOnEdge.relabel(n5, source, target, interval, (src, tgt) -> { final long val = mapping.get(src.getIntegerLong()); if (val != 0) @@ -536,6 +534,7 @@ private static void relabel( final String target, final Interval interval, final long addIfNotZero) throws IOException { + final CachedMapper mapper = new CachedMapper(addIfNotZero); SparkWatershedsOnDistanceTransformOfSampledFunctionSeedOnlyOnEdge.relabel(n5, source, target, interval, (src, tgt) -> tgt.set(mapper.applyAsLong(src.getIntegerLong()))); } @@ -546,6 +545,7 @@ private static & NativeType> boolean relabel( final String target, final Interval interval, final BiConsumer idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(source); final CellGrid grid = new CellGrid(attributes.getDimensions(), attributes.getBlockSize()); final RandomAccessibleInterval data = Views.interval(N5Utils.open(n5, source), interval); @@ -564,11 +564,11 @@ private static & NativeType> boolean relabel( final long[] blockPos = Intervals.minAsLongArray(interval); grid.getCellPosition(blockPos, blockPos); N5Utils.saveBlock(copy, n5, target, attributes, blockPos); -// final Cursor reloaded = Views.flatIterable(Views.interval(N5Utils.open(n5, target), interval)).cursor(); -// final Cursor c = Views.flatIterable(copy).cursor(); -// while (c.hasNext()) -// if (c.next().getIntegerLong() != reloaded.next().getIntegerLong()) -// return false; + // final Cursor reloaded = Views.flatIterable(Views.interval(N5Utils.open(n5, target), interval)).cursor(); + // final Cursor c = Views.flatIterable(copy).cursor(); + // while (c.hasNext()) + // if (c.next().getIntegerLong() != reloaded.next().getIntegerLong()) + // return false; return true; } @@ -577,6 +577,7 @@ private static void relabel( final String dataset, final long[] blockPos, final long addIfNonZero) throws IOException { + relabel(n5, dataset, blockPos, id -> id == 0 ? 
0 : id + addIfNonZero); } @@ -585,14 +586,16 @@ private static void relabel( final String dataset, final long[] blockPos, final LongUnaryOperator idMapping) throws IOException { + final DatasetAttributes attributes = n5.getDatasetAttributes(dataset); - final LongArrayDataBlock block = ((LongArrayDataBlock) n5.readBlock(dataset, attributes, blockPos)); + final LongArrayDataBlock block = ((LongArrayDataBlock)n5.readBlock(dataset, attributes, blockPos)); final long[] data = block.getData(); for (int i = 0; i < data.length; ++i) { data[i] = idMapping.applyAsLong(data[i]); } n5.writeBlock(dataset, attributes, new LongArrayDataBlock(block.getSize(), data, block.getGridPosition())); } + private static void prepareOutputDatasets( final N5Writer n5, final Map datasets, @@ -607,12 +610,14 @@ private static void prepareOutputDataset( final String dataset, final DatasetAttributes attributes, final Map additionalAttributes) throws IOException { + n5.createDataset(dataset, attributes); for (Map.Entry entry : additionalAttributes.entrySet()) n5.setAttribute(dataset, entry.getKey(), entry.getValue()); } private static Map with(Map map, K key, V value) { + map.put(key, value); return map; } @@ -628,6 +633,7 @@ private static class N5WriterSupplier implements Supplier, Serializabl private final boolean serializeSpecialFloatingPointValues = true; private N5WriterSupplier(final String container, final boolean withPrettyPrinting, final boolean disableHtmlEscaping) { + this.container = container; this.withPrettyPrinting = withPrettyPrinting; this.disableHtmlEscaping = disableHtmlEscaping; @@ -636,43 +642,46 @@ private N5WriterSupplier(final String container, final boolean withPrettyPrintin @Override public N5Writer get() { - try { - return Files.isDirectory(Paths.get(container)) - ? new N5FSWriter(container, createaBuilder()) - : new N5HDF5Writer(container); - } catch (final IOException e) { - throw new RuntimeException(e); - } + return Files.isDirectory(Paths.get(container)) + ? new N5FSWriter(container, createaBuilder()) + : new N5HDF5Writer(container); } private GsonBuilder createaBuilder() { + return serializeSpecialFloatingPointValues(withPrettyPrinting(disableHtmlEscaping(new GsonBuilder()))); } private GsonBuilder serializeSpecialFloatingPointValues(final GsonBuilder builder) { + return with(builder, this.serializeSpecialFloatingPointValues, GsonBuilder::serializeSpecialFloatingPointValues); } private GsonBuilder withPrettyPrinting(final GsonBuilder builder) { + return with(builder, this.withPrettyPrinting, GsonBuilder::setPrettyPrinting); } private GsonBuilder disableHtmlEscaping(final GsonBuilder builder) { + return with(builder, this.disableHtmlEscaping, GsonBuilder::disableHtmlEscaping); } private static GsonBuilder with(final GsonBuilder builder, boolean applyAction, Function action) { + return applyAction ? 
action.apply(builder) : builder; } } private static double[] ones(final int length) { + double[] ones = new double[length]; Arrays.fill(ones, 1.0); return ones; } private static Interval addDimension(final Interval interval, final long m, final long M) { + long[] min = new long[interval.numDimensions() + 1]; long[] max = new long[interval.numDimensions() + 1]; for (int d = 0; d < interval.numDimensions(); ++d) { @@ -685,14 +694,17 @@ private static Interval addDimension(final Interval interval, final long m, fina } private static String toString(final Interval interval) { + return String.format("(%s %s)", Arrays.toString(Intervals.minAsLongArray(interval)), Arrays.toString(Intervals.maxAsLongArray(interval))); } - private static double[] reverted(final double[] array, final boolean revert) { - return revert ? reverted(array) : array; + private static double[] reversed(final double[] array, final boolean reverse) { + + return reverse ? reversed(array) : array; } - private static double[] reverted(final double[] array) { + private static double[] reversed(final double[] array) { + final double[] copy = new double[array.length]; for (int i = 0, k = copy.length - 1; i < copy.length; ++i, --k) { copy[i] = array[k]; @@ -705,6 +717,7 @@ private static > ArrayImg smooth( final Interval interval, final int channelDim, double sigma) { + final ArrayImg img = ArrayImgs.floats(Intervals.dimensionsAsLongArray(interval)); for (long channel = interval.min(channelDim); channel <= interval.max(channelDim); ++channel) { @@ -721,6 +734,7 @@ private static > void invalidateOutOfBlockAffinities( final T invalid, final long[]... offsets ) { + for (int index = 0; index < offsets.length; ++index) { final IntervalView slice = Views.hyperSlice(affs, affs.numDimensions() - 1, index); for (int d = 0; d < offsets[index].length; ++d) { @@ -771,7 +785,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg final RandomAccessibleInterval thisBlockLabels = Views.interval(labels, minMax._1(), thisBlockMax); final RandomAccessibleInterval affinities = N5Utils.open(affinitiesContainer.get(), affinitiesDataset); - final TLongSet ignoreTheseSet = new TLongHashSet(ignoreThese); final TLongLongHashMap mapping = new TLongLongHashMap(); @@ -830,8 +843,7 @@ private static UF findOverlappingLabelsThresholdMedianEdg if (thisLabel < thatLabel) { e1 = thisLabel; e2 = thatLabel; - } - else { + } else { e1 = thatLabel; e2 = thisLabel; } @@ -852,7 +864,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg LOG.info("Edge affinities: {}", affinitiesByEdge); - affinitiesByEdge.forEachEntry((k, v) -> { TLongObjectIterator edgeIt = v.iterator(); while (edgeIt.hasNext()) { @@ -895,7 +906,6 @@ private static UF findOverlappingLabelsThresholdMedianEdg return uf; - } private static UF findOverlappingLabelsArgMaxNoHalo( @@ -982,13 +992,12 @@ private static UF findOverlappingLabelsArgMaxNoHalo( addOne(thisMap.get(thisLabel), thatLabel); addOne(thatMap.get(thatLabel), thisLabel); -// thatArgMax.forEachEntry((k, v) -> { -//// if (thatArgMax.get(v) == k) -// localUF.join(localUF.findRoot(v), localUF.findRoot(k)); -//// mapping.put(k, v); -// return true; -// }); - + // thatArgMax.forEachEntry((k, v) -> { + //// if (thatArgMax.get(v) == k) + // localUF.join(localUF.findRoot(v), localUF.findRoot(k)); + //// mapping.put(k, v); + // return true; + // }); } @@ -1031,7 +1040,6 @@ private static UF findOverlappingLabelsArgMaxNoHalo( return uf; - } private static class CropAffinitiesToDistanceTransform implements PairFunction, 
@@ -1046,6 +1054,7 @@ private CropAffinitiesToDistanceTransform(
 				final Supplier<N5Writer> n5in,
 				final String affinities,
 				final double[] weights) {
+
 			this.n5in = n5in;
 			this.affinities = affinities;
 			this.weights = weights;
@@ -1056,11 +1065,13 @@ private static class ReplaceNaNWith<T extends RealType<T>> implements Converter<
 			private final double replacement;

 			private ReplaceNaNWith(double replacement) {
+
 				this.replacement = replacement;
 			}

 			@Override
 			public void convert(T src, T tgt) {
+
 				final double t = src.getRealDouble();
 				tgt.setReal(Double.isNaN(t) ? replacement : t);
 			}
@@ -1068,6 +1079,7 @@ public void convert(T src, T tgt) {
 		@Override
 		public Tuple2<Tuple2<long[], long[]>, RandomAccessibleInterval<FloatType>> call(final Interval interval) throws Exception {
+
 			final RandomAccessibleInterval<FloatType> affsImg = N5Utils.open(n5in.get(), affinities);
 			final RandomAccessible<FloatType> affs =
 					Converters.convert(Views.extendValue(affsImg, new FloatType(0.0f)), new ReplaceNaNWith<>(0.0), new FloatType());
@@ -1086,7 +1098,7 @@ public Tuple2<Tuple2<long[], long[]>, RandomAccess
 					weights
 			);
 			// TODO should we actually rewrite those?
-			double[] minMax = new double[] {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY};
+			double[] minMax = new double[]{Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY};
 			LoopBuilder
 					.setImages(Views.interval(affs, withContext), distanceTransform)
 					.forEachPixel((a, d) -> {
@@ -1107,14 +1119,17 @@
 		}

 		private static <T> T toMinMaxTuple(final Interval interval, BiFunction<long[], long[], T> toTuple) {
+
 			return toTuple.apply(Intervals.minAsLongArray(interval), Intervals.maxAsLongArray(interval));
 		}

 	private static void addOne(final TLongIntMap countMap, final long label) {
+
 		countMap.put(label, countMap.get(label) + 1);
 	}

 	private static TLongLongMap argMaxCounts(final TLongObjectMap<TLongIntMap> counts) {
+
 		final TLongLongMap mapping = new TLongLongHashMap();
 		counts.forEachEntry((k, v) -> {
 			mapping.put(k, argMaxCount(v));
@@ -1124,6 +1139,7 @@
 	}

 	private static long argMaxCount(final TLongIntMap counts) {
+
 		long maxCount = Long.MIN_VALUE;
 		long argMaxCount = 0;
 		for (final TLongIntIterator it = counts.iterator(); it.hasNext(); ) {
@@ -1133,18 +1149,19 @@
 				maxCount = v;
 				argMaxCount = it.key();
 			}
-		};
+		}
+		;
 		return argMaxCount;
 	}

 	private static class CachedMapper implements LongUnaryOperator {
-
 		private long nextId;

 		private final TLongLongMap cache = new TLongLongHashMap();

 		private CachedMapper(final long firstId) {
+
 			this.nextId = firstId;
 		}
@@ -1162,6 +1179,7 @@ public long applyAsLong(long l) {
 	}

 	private static <T> T computeIfAbsent(final TLongObjectMap<T> map, final long key, final LongFunction<T> mappingFactory) {
+
 		final T value = map.get(key);
 		if (value != null)
 			return value;
diff --git a/src/test/java/org/janelia/saalfeldlab/label/spark/UniqueLabelsAndMappingTest.java b/src/test/java/org/janelia/saalfeldlab/label/spark/UniqueLabelsAndMappingTest.java
index 1e643ed..7c10ab3 100644
--- a/src/test/java/org/janelia/saalfeldlab/label/spark/UniqueLabelsAndMappingTest.java
+++ b/src/test/java/org/janelia/saalfeldlab/label/spark/UniqueLabelsAndMappingTest.java
@@ -1,17 +1,13 @@
 package org.janelia.saalfeldlab.label.spark;

-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.file.Files;
-import java.nio.file.Paths;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
+import net.imglib2.AbstractInterval;
+import net.imglib2.Interval;
+import net.imglib2.algorithm.util.Grids;
+import net.imglib2.img.array.ArrayImg;
+import net.imglib2.img.array.ArrayImgs;
+import net.imglib2.img.basictypeaccess.array.LongArray;
+import net.imglib2.type.numeric.integer.UnsignedLongType;
+import net.imglib2.util.Intervals;
 import org.apache.commons.io.FileUtils;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -34,17 +30,19 @@
 import org.junit.Before;
 import org.junit.Test;

-import net.imglib2.AbstractInterval;
-import net.imglib2.Interval;
-import net.imglib2.algorithm.util.Grids;
-import net.imglib2.img.array.ArrayImg;
-import net.imglib2.img.array.ArrayImgs;
-import net.imglib2.img.basictypeaccess.array.LongArray;
-import net.imglib2.type.numeric.integer.UnsignedLongType;
-import net.imglib2.util.Intervals;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;

-public class UniqueLabelsAndMappingTest
-{
+public class UniqueLabelsAndMappingTest {

 	private final String labelDataset = "labels";
@@ -54,9 +52,9 @@ public class UniqueLabelsAndMappingTest {

 	private final String tmpDir;

-	private final long[] dims = { 4, 3, 2 };
+	private final long[] dims = {4, 3, 2};

-	private final int[] blockSize = { 3, 2, 2 };
+	private final int[] blockSize = {3, 2, 2};

 	private final long[] labels = {
 			1, 1, 1, 1,
@@ -68,161 +66,142 @@
 			3, 3, 3, 2
 	};

-	ArrayImg< UnsignedLongType, LongArray > labelImg = ArrayImgs.unsignedLongs( labels, dims );
+	ArrayImg<UnsignedLongType, LongArray> labelImg = ArrayImgs.unsignedLongs(labels, dims);
+
+	Map<String, Set<ComparableFinalInterval>> groundTruthMapping = new HashMap<>();

-	Map< String, Set< ComparableFinalInterval > > groundTruthMapping = new HashMap<>();
 	{
-		groundTruthMapping.put( "1", new HashSet<>( Arrays.asList(
-				new ComparableFinalInterval( new long[] { 0, 0, 0 }, new long[] { 2, 1, 1 } ),
-				new ComparableFinalInterval( new long[] { 3, 0, 0 }, new long[] { 3, 1, 1 } ),
-				new ComparableFinalInterval( new long[] { 0, 2, 0 }, new long[] { 2, 2, 1 } ) ) ) );
+		groundTruthMapping.put("1", new HashSet<>(Arrays.asList(
+				new ComparableFinalInterval(new long[]{0, 0, 0}, new long[]{2, 1, 1}),
+				new ComparableFinalInterval(new long[]{3, 0, 0}, new long[]{3, 1, 1}),
+				new ComparableFinalInterval(new long[]{0, 2, 0}, new long[]{2, 2, 1}))));

-		groundTruthMapping.put( "2", new HashSet<>( Arrays.asList(
-				new ComparableFinalInterval( new long[] { 0, 0, 0 }, new long[] { 2, 1, 1 } ),
-				new ComparableFinalInterval( new long[] { 3, 0, 0 }, new long[] { 3, 1, 1 } ),
-				new ComparableFinalInterval( new long[] { 0, 2, 0 }, new long[] { 2, 2, 1 } ),
-				new ComparableFinalInterval( new long[] { 3, 2, 0 }, new long[] { 3, 2, 1 } ) ) ) );
+		groundTruthMapping.put("2", new HashSet<>(Arrays.asList(
+				new ComparableFinalInterval(new long[]{0, 0, 0}, new long[]{2, 1, 1}),
+				new ComparableFinalInterval(new long[]{3, 0, 0}, new long[]{3, 1, 1}),
+				new ComparableFinalInterval(new long[]{0, 2, 0}, new long[]{2, 2, 1}),
+				new ComparableFinalInterval(new long[]{3, 2, 0}, new long[]{3, 2, 1}))));

-		groundTruthMapping.put( "3", new HashSet<>( Arrays.asList(
-				new ComparableFinalInterval( new long[] { 0, 0, 0 }, new long[] { 2, 1, 1 } ),
-				new ComparableFinalInterval( new long[] { 0, 2, 0 }, new long[] { 2, 2, 1 } ) ) ) );
+		groundTruthMapping.put("3", new HashSet<>(Arrays.asList(
+				new ComparableFinalInterval(new long[]{0, 0, 0}, new long[]{2, 1, 1}),
+				new ComparableFinalInterval(new long[]{0, 2, 0}, new long[]{2, 2, 1}))));

-		groundTruthMapping.put( "4", new HashSet<>( Arrays.asList(
-				new ComparableFinalInterval( new long[] { 3, 0, 0 }, new long[] { 3, 1, 1 } ) ) ) );
+		groundTruthMapping.put("4", new HashSet<>(Arrays.asList(
+				new ComparableFinalInterval(new long[]{3, 0, 0}, new long[]{3, 1, 1}))));
 	}

-	public UniqueLabelsAndMappingTest() throws IOException
-	{
-		this.tmpDir = Files.createTempDirectory( "unique-labels-test" ).toAbsolutePath().toString();
-		this.labelToBlocksMappingDirectory = Paths.get( tmpDir, "label-to-block-mapping" ).toAbsolutePath().toString();
+	public UniqueLabelsAndMappingTest() throws IOException {
+
+		this.tmpDir = Files.createTempDirectory("unique-labels-test").toAbsolutePath().toString();
+		this.labelToBlocksMappingDirectory = Paths.get(tmpDir, "label-to-block-mapping").toAbsolutePath().toString();
 	}

 	@Before
-	public void setUp() throws IOException
-	{
+	public void setUp() throws IOException {

-		final N5FSWriter n5 = new N5FSWriter( this.tmpDir );
-		n5.createDataset( labelDataset, new DatasetAttributes( dims, blockSize, DataType.UINT64, new RawCompression() ) );
-		N5Utils.save( labelImg, n5, labelDataset, blockSize, new RawCompression() );
+		final N5FSWriter n5 = new N5FSWriter(this.tmpDir);
+		n5.createDataset(labelDataset, new DatasetAttributes(dims, blockSize, DataType.UINT64, new RawCompression()));
+		N5Utils.save(labelImg, n5, labelDataset, blockSize, new RawCompression());
 	}

 	@After
-	public void tearDown() throws IOException
-	{
-		FileUtils.deleteDirectory( new File( this.tmpDir ) );
+	public void tearDown() throws IOException {
+
+		FileUtils.deleteDirectory(new File(this.tmpDir));
 	}

 	@Test
-	public void test() throws InvalidDataType, IOException, InvalidN5Container, InvalidDataset, InputSameAsOutput
-	{
+	public void test() throws InvalidDataType, IOException, InvalidN5Container, InvalidDataset, InputSameAsOutput {
+
 		final SparkConf conf = new SparkConf()
-				.setAppName( getClass().getName() )
-				.setMaster( "local[*]" );
-		try (JavaSparkContext sc = new JavaSparkContext( conf ))
-		{
-			ExtractUniqueLabelsPerBlock.extractUniqueLabels( sc, tmpDir, tmpDir, labelDataset, uniqueLabelDataset );
-			LabelToBlockMapping.createMapping( sc, tmpDir, uniqueLabelDataset, labelToBlocksMappingDirectory );
+				.setAppName(getClass().getName())
+				.setMaster("local[*]");
+		try (JavaSparkContext sc = new JavaSparkContext(conf)) {
+			ExtractUniqueLabelsPerBlock.extractUniqueLabels(sc, tmpDir, tmpDir, labelDataset, uniqueLabelDataset);
+			LabelToBlockMapping.createMapping(sc, tmpDir, uniqueLabelDataset, labelToBlocksMappingDirectory);
 		}

-		final N5Reader n5 = new N5FSReader( tmpDir );
-		final DatasetAttributes uniqueLabelAttributes = n5.getDatasetAttributes( uniqueLabelDataset );
+		final N5Reader n5 = new N5FSReader(tmpDir);
+		final DatasetAttributes uniqueLabelAttributes = n5.getDatasetAttributes(uniqueLabelDataset);

-		Assert.assertArrayEquals( dims, uniqueLabelAttributes.getDimensions() );
-		Assert.assertArrayEquals( blockSize, uniqueLabelAttributes.getBlockSize() );
+		Assert.assertArrayEquals(dims, uniqueLabelAttributes.getDimensions());
+		Assert.assertArrayEquals(blockSize, uniqueLabelAttributes.getBlockSize());

-		final List< long[] > blocks = Grids.collectAllOffsets( dims, blockSize );
-		for ( final long[] block : blocks )
-		{
+		final List<long[]> blocks = Grids.collectAllOffsets(dims, blockSize);
+		for (final long[] block : blocks) {
 			final long[] blockPosition = block.clone();
-			Arrays.setAll( blockPosition, d -> blockPosition[ d ] / blockSize[ d ] );
+			Arrays.setAll(blockPosition, d -> blockPosition[d] / blockSize[d]);

-			final LongArrayDataBlock blockData = ( ( LongArrayDataBlock ) n5.readBlock( uniqueLabelDataset, uniqueLabelAttributes, blockPosition ) );
+			final LongArrayDataBlock blockData = ((LongArrayDataBlock)n5.readBlock(uniqueLabelDataset, uniqueLabelAttributes, blockPosition));
 			final long[] sortedContents = blockData.getData().clone();
-			Arrays.sort( sortedContents );
-
-			if ( Arrays.equals( blockPosition, new long[] { 0, 0, 0 } ) )
-			{
-				Assert.assertArrayEquals( new long[] { 1, 2, 3 }, sortedContents );
-			}
-
-			else if ( Arrays.equals( blockPosition, new long[] { 1, 0, 0 } ) )
-			{
-				Assert.assertArrayEquals( new long[] { 1, 2, 4 }, sortedContents );
-			}
-
-			else if ( Arrays.equals( blockPosition, new long[] { 0, 1, 0 } ) )
-			{
-				Assert.assertArrayEquals( new long[] { 1, 2, 3 }, sortedContents );
-			}
-
-			else if ( Arrays.equals( blockPosition, new long[] { 1, 1, 0 } ) )
-			{
-				Assert.assertArrayEquals( new long[] { 2 }, sortedContents );
-			}
-
-			else
-			{
-				Assert.fail( "Observed unexpected block position: " + Arrays.toString( blockPosition ) );
+			Arrays.sort(sortedContents);
+
+			if (Arrays.equals(blockPosition, new long[]{0, 0, 0})) {
+				Assert.assertArrayEquals(new long[]{1, 2, 3}, sortedContents);
+			} else if (Arrays.equals(blockPosition, new long[]{1, 0, 0})) {
+				Assert.assertArrayEquals(new long[]{1, 2, 4}, sortedContents);
+			} else if (Arrays.equals(blockPosition, new long[]{0, 1, 0})) {
+				Assert.assertArrayEquals(new long[]{1, 2, 3}, sortedContents);
+			} else if (Arrays.equals(blockPosition, new long[]{1, 1, 0})) {
+				Assert.assertArrayEquals(new long[]{2}, sortedContents);
+			} else {
+				Assert.fail("Observed unexpected block position: " + Arrays.toString(blockPosition));
 			}
 		}

-		final String[] containedFiles = new File( labelToBlocksMappingDirectory ).list();
-		Arrays.sort( containedFiles );
-		Assert.assertEquals( 4, containedFiles.length );
-		final Set< String > containedFilesSet = new HashSet<>( Arrays.asList( containedFiles ) );
-		Assert.assertEquals( new HashSet<>( Arrays.asList( "1", "2", "3", "4" ) ), containedFilesSet );
-		for ( final String file : containedFiles )
-		{
-			final Set< ComparableFinalInterval > storedIntervals = new HashSet<>();
-			final byte[] data = Files.readAllBytes( Paths.get( labelToBlocksMappingDirectory, file ) );
+		final String[] containedFiles = new File(labelToBlocksMappingDirectory).list();
+		Arrays.sort(containedFiles);
+		Assert.assertEquals(4, containedFiles.length);
+		final Set<String> containedFilesSet = new HashSet<>(Arrays.asList(containedFiles));
+		Assert.assertEquals(new HashSet<>(Arrays.asList("1", "2", "3", "4")), containedFilesSet);
+		for (final String file : containedFiles) {
+			final Set<ComparableFinalInterval> storedIntervals = new HashSet<>();
+			final byte[] data = Files.readAllBytes(Paths.get(labelToBlocksMappingDirectory, file));
 			// three dimensions, 2 arrays, long elements
-			Assert.assertEquals( 0, data.length % ( 3 * 2 * Long.BYTES ) );
-			final ByteBuffer bb = ByteBuffer.wrap( data );
-			while ( bb.hasRemaining() )
-			{
-				storedIntervals.add( new ComparableFinalInterval(
-						new long[] { bb.getLong(), bb.getLong(), bb.getLong() },
-						new long[] { bb.getLong(), bb.getLong(), bb.getLong() } ) );
+			Assert.assertEquals(0, data.length % (3 * 2 * Long.BYTES));
+			final ByteBuffer bb = ByteBuffer.wrap(data);
+			while (bb.hasRemaining()) {
+				storedIntervals.add(new ComparableFinalInterval(
+						new long[]{bb.getLong(), bb.getLong(), bb.getLong()},
+						new long[]{bb.getLong(), bb.getLong(), bb.getLong()}));
 			}
-			Assert.assertEquals( groundTruthMapping.get( file ), storedIntervals );
+			Assert.assertEquals(groundTruthMapping.get(file), storedIntervals);
 		}
 	}

-	private static final class ComparableFinalInterval extends AbstractInterval
-	{
+	private static final class ComparableFinalInterval extends AbstractInterval {

-		public ComparableFinalInterval( final long[] min, final long[] max )
-		{
-			super( min, max );
+		public ComparableFinalInterval(final long[] min, final long[] max) {
+
+			super(min, max);
 		}

-		public ComparableFinalInterval( final Interval interval )
-		{
-			super( interval );
+		public ComparableFinalInterval(final Interval interval) {
+
+			super(interval);
 		}

 		@Override
-		public boolean equals( final Object other )
-		{
-			if ( other instanceof Interval )
-			{
-				final Interval that = ( Interval ) other;
-				return Arrays.equals( Intervals.minAsLongArray( that ), min ) && Arrays.equals( Intervals.maxAsLongArray( that ), max );
+		public boolean equals(final Object other) {
+
+			if (other instanceof Interval) {
+				final Interval that = (Interval)other;
+				return Arrays.equals(Intervals.minAsLongArray(that), min) && Arrays.equals(Intervals.maxAsLongArray(that), max);
 			}
 			return false;
 		}

 		@Override
-		public int hashCode()
-		{
-			return Arrays.hashCode( min );
+		public int hashCode() {
+
+			return Arrays.hashCode(min);
 		}

 		@Override
-		public String toString()
-		{
-			return "(" + Arrays.toString( min ) + " " + Arrays.toString( max ) + ")";
+		public String toString() {
+
+			return "(" + Arrays.toString(min) + " " + Arrays.toString(max) + ")";
 		}
 	}
diff --git a/src/test/java/org/janelia/saalfeldlab/label/spark/affinities/MakeEmptyMask.java b/src/test/java/org/janelia/saalfeldlab/label/spark/affinities/MakeEmptyMask.java
index 4e7ad39..5f27a66 100644
--- a/src/test/java/org/janelia/saalfeldlab/label/spark/affinities/MakeEmptyMask.java
+++ b/src/test/java/org/janelia/saalfeldlab/label/spark/affinities/MakeEmptyMask.java
@@ -20,6 +20,7 @@ public class MakeEmptyMask {

 	public static void main(String[] args) throws IOException, ExecutionException, InterruptedException {
+
 		final N5FSWriter container = new N5FSWriter("/groups/saalfeld/home/hanslovskyp/data/cremi/sample_A+_padded_20160601-bs=64.n5");
 		final String rawPath = "volumes/raw/data/s0";
 		final String maskPath = "volumes/masks/raw";
@@ -33,15 +34,14 @@ public static void main(String[] args) throws IOException, ExecutionException, I
 				new GzipCompression());

 		container.createDataset(maskPath, maskAttributes);
-		container.setAttribute(maskPath, "value_range", new double[] {0.0, 1.0});
-		container.setAttribute(maskPath, "resolution", new double[] {4, 4, 40});
+		container.setAttribute(maskPath, "value_range", new double[]{0.0, 1.0});
+		container.setAttribute(maskPath, "resolution", new double[]{4, 4, 40});

 		final long[] center = new long[maskAttributes.getNumDimensions()];
 		Arrays.setAll(center, d -> (raw.min(d) + raw.max(d)) / 2);
 		final long[] radius = center.clone();
 		final double[] doubleCenterSquared = Arrays.stream(center).asDoubleStream().map(d -> d * d).toArray();
-
 		final RandomAccessible mask = new FunctionRandomAccessible<>(
 				center.length,
 				(pos, t) -> {
@@ -67,8 +67,6 @@ public static void main(String[] args) throws IOException, ExecutionException, I

 		es.shutdown();
-
-
 	}
 }
diff --git a/src/test/resources/logback.xml b/src/test/resources/logback.xml
new file mode 100644
index 0000000..248ac8d
--- /dev/null
+++ b/src/test/resources/logback.xml
@@ -0,0 +1,11 @@
+<configuration>
+	<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
+		<encoder>
+			<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} MDC=%X{user} - %msg%n</pattern>
+		</encoder>
+	</appender>
+
+	<root level="debug">
+		<appender-ref ref="STDOUT"/>
+	</root>
+</configuration>