From 4209a5f10a179d49f0705f59a0989b70f76015a7 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Sun, 26 Jan 2025 16:36:26 -0500 Subject: [PATCH] Use IntensityTile in matching and block worker --- .../AffineIntensityCorrectionBlockWorker.java | 116 ++++++++---------- .../solvers/intensity/IntensityMatcher.java | 15 ++- 2 files changed, 57 insertions(+), 74 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/AffineIntensityCorrectionBlockWorker.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/AffineIntensityCorrectionBlockWorker.java index 0482f3b86..36226c155 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/AffineIntensityCorrectionBlockWorker.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/AffineIntensityCorrectionBlockWorker.java @@ -1,15 +1,11 @@ package org.janelia.render.client.newsolver.solvers.intensity; -import mpicbg.models.Affine1D; import mpicbg.models.AffineModel1D; -import mpicbg.models.ErrorStatistic; import mpicbg.models.IdentityModel; import mpicbg.models.InterpolatedAffineModel1D; import mpicbg.models.NoninvertibleModelException; import mpicbg.models.PointMatch; import mpicbg.models.Tile; -import mpicbg.models.TileConfiguration; -import mpicbg.models.TileUtil; import mpicbg.models.TranslationModel1D; import net.imglib2.util.ValuePair; @@ -76,14 +72,14 @@ public List, FIBSEMIntensityCorrectionParamet final List wrappedTiles = AdjustBlock.sortTileSpecs(blockData.rtsc()); - final HashMap>>> coefficientTiles = computeCoefficients(wrappedTiles); + final Map coefficientTiles = computeCoefficients(wrappedTiles); coefficientTiles.forEach((tileId, tiles) -> { final ArrayList models = new ArrayList<>(); - tiles.forEach(tile -> { - final AffineModel1D model = ((InterpolatedAffineModel1D) tile.getModel()).createAffineModel1D(); - models.add(model); - }); + for (int i = 0; i < tiles.nSubTiles(); i++) { + final InterpolatedAffineModel1D interpolatedModel = (InterpolatedAffineModel1D) tiles.getSubTile(i).getModel(); + models.add(interpolatedModel.createAffineModel1D()); + } blockData.getResults().recordModel(tileId, models); }); @@ -105,7 +101,7 @@ private void fetchResolvedTiles() blockData.getResults().init(rtsc); } - private HashMap>>> computeCoefficients(final List tiles) + private Map computeCoefficients(final List tiles) throws ExecutionException, InterruptedException { LOG.info("computeCoefficients: entry"); @@ -115,7 +111,7 @@ private HashMap>>> computeCoefficie ? ImageProcessorCache.DISABLED_CACHE : new ImageProcessorCache(parameters.maxNumberOfCachedPixels(), true, false); - final HashMap>>> coefficientTiles = splitIntoCoefficientTiles(tiles, imageProcessorCache); + final Map coefficientTiles = splitIntoCoefficientTiles(tiles, imageProcessorCache); if (tiles.size() > 1) { solveForGlobalCoefficients(coefficientTiles, ITERATIONS); @@ -126,7 +122,7 @@ private HashMap>>> computeCoefficie return coefficientTiles; } - private HashMap>>> splitIntoCoefficientTiles( + private HashMap splitIntoCoefficientTiles( final List tiles, final ImageProcessorCache imageProcessorCache ) throws InterruptedException, ExecutionException { @@ -139,8 +135,7 @@ private HashMap>>> splitIntoCoeffic LOG.info("splitIntoCoefficientTiles: entry, collecting pairs for {} patches with zDistance {}", tiles.size(), parameters.zDistance()); // generate coefficient tiles for all patches - final int nGridPoints = parameters.numCoefficients() * parameters.numCoefficients(); - final HashMap>>> coefficientTiles = generateCoefficientsTiles(tiles, nGridPoints); + final HashMap coefficientTiles = generateCoefficientsTiles(tiles); final List> patchPairs = findOverlappingPatches(tiles, parameters.zDistance()); @@ -187,24 +182,18 @@ private IntensityMatcher getIntensityMatcher( return new IntensityMatcher(filter, parameters, meshResolution, imageProcessorCache); } - private HashMap>>> generateCoefficientsTiles( - final Collection patches, - final int nGridPoints - ) { + private HashMap generateCoefficientsTiles(final Collection patches) { + final InterpolatedAffineModel1D, IdentityModel> modelTemplate = new InterpolatedAffineModel1D<>( new InterpolatedAffineModel1D<>( new AffineModel1D(), new TranslationModel1D(), parameters.lambdaTranslation()), new IdentityModel(), parameters.lambdaIdentity()); - final HashMap>>> coefficientTiles = new HashMap<>(); + final HashMap coefficientTiles = new HashMap<>(); for (final TileSpec p : patches) { - final ArrayList>> coefficientModels = new ArrayList<>(); - for (int i = 0; i < nGridPoints; ++i) { - final InterpolatedAffineModel1D model = modelTemplate.copy(); - coefficientModels.add(new Tile<>(model)); - } - coefficientTiles.put(p.getTileId(), coefficientModels); + final IntensityTile tile = new IntensityTile(modelTemplate::copy, parameters.numCoefficients(), 1); + coefficientTiles.put(p.getTileId(), tile); } return coefficientTiles; } @@ -237,34 +226,34 @@ private static ArrayList> findOverlappingPatches( @SuppressWarnings("SameParameterValue") private void solveForGlobalCoefficients( - final HashMap>>> coefficientTiles, + final Map coefficientTiles, final int iterations ) { - final Tile> equilibrationTile = new Tile<>(new IdentityModel()); + final IntensityTile equilibrationTile = new IntensityTile(IdentityModel::new, 1, 1); connectTilesWithinPatches(coefficientTiles, equilibrationTile); /* optimize */ - final TileConfiguration tc = new TileConfiguration(); - coefficientTiles.values().forEach(tc::addTiles); - - // anchor the equilibration tile - tc.addTile(equilibrationTile); - tc.fixTile(equilibrationTile); - - LOG.info("solveForGlobalCoefficients: optimizing {} tiles with {} threads", tc.getTiles().size(), numThreads); - try { - TileUtil.optimizeConcurrently(new ErrorStatistic(iterations + 1), 0.01f, iterations, iterations, 0.75f, tc, tc.getTiles(), tc.getFixedTiles(), numThreads); - } catch (final Exception e) { - throw new RuntimeException(e); + final List tiles = new ArrayList<>(coefficientTiles.values()); + final List fixedTiles = new ArrayList<>(); + + // anchor the equilibration tile if it is used, otherwise anchor a random tile (the first one) + if (blockData.solveTypeParameters().equilibrationWeight() > 0.0) { + tiles.add(equilibrationTile); + fixedTiles.add(equilibrationTile); + } else { + final IntensityTile firstTile = tiles.get(0); + fixedTiles.add(firstTile); } + LOG.info("solveForGlobalCoefficients: optimizing {} tiles with {} threads", tiles.size(), numThreads); + final IntensityTileOptimizer optimizer = new IntensityTileOptimizer(0.01, iterations, iterations, 0.75, numThreads); + optimizer.optimize(tiles, fixedTiles); + // TODO: this is not the right error measure, what is idToBlockErrorMap supposed to be exactly? - coefficientTiles.forEach((tileId, tiles) -> { - final Double error = tiles.stream().mapToDouble(t -> { - t.updateCost(); - return t.getDistance(); - }).average().orElse(Double.MAX_VALUE); + coefficientTiles.forEach((tileId, tile) -> { + tile.updateDistance(); + final double error = tile.getDistance(); final Map errorMap = new HashMap<>(); errorMap.put(tileId, error); blockData.getResults().recordAllErrors(tileId, errorMap); @@ -274,48 +263,39 @@ private void solveForGlobalCoefficients( } private void connectTilesWithinPatches( - final HashMap>>> coefficientTiles, - final Tile> equilibrationTile + final Map coefficientTiles, + final IntensityTile equilibrationTile ) { final Collection allTiles = blockData.rtsc().getTileSpecs(); final double equilibrationWeight = blockData.solveTypeParameters().equilibrationWeight(); final ResultContainer> results = blockData.getResults(); for (final TileSpec p : allTiles) { - final List> coefficientTile = coefficientTiles.get(p.getTileId()); + final IntensityTile coefficientTile = coefficientTiles.get(p.getTileId()); for (int i = 1; i < parameters.numCoefficients(); ++i) { for (int j = 0; j < parameters.numCoefficients(); ++j) { - final int left = getLinearIndex(i-1, j, parameters.numCoefficients()); - final int right = getLinearIndex(i, j, parameters.numCoefficients()); - final int top = getLinearIndex(j, i, parameters.numCoefficients()); - final int bot = getLinearIndex(j, i-1, parameters.numCoefficients()); + final Tile left = coefficientTile.getSubTile(i-1, j); + final Tile right = coefficientTile.getSubTile(i, j); + final Tile top = coefficientTile.getSubTile(j, i); + final Tile bot = coefficientTile.getSubTile(j, i-1); - identityConnect(coefficientTile.get(right), coefficientTile.get(left)); - identityConnect(coefficientTile.get(top), coefficientTile.get(bot)); + identityConnect(right, left); + identityConnect(top, bot); } } if (equilibrationWeight > 0.0) { final List averages = results.getAveragesFor(p.getTileId()); - for (int i = 0; i < parameters.numCoefficients(); i++) { - for (int j = 0; j < parameters.numCoefficients(); j++) { - final int idx = getLinearIndex(i, j, parameters.numCoefficients()); - equilibrateIntensity(coefficientTile.get(idx), - equilibrationTile, - averages.get(idx), - equilibrationWeight); - } + coefficientTile.connectTo(equilibrationTile); + for (int i = 0; i < coefficientTile.nSubTiles(); i++) { + equilibrateIntensity(coefficientTile.getSubTile(i), + equilibrationTile.getSubTile(0), + averages.get(i), + equilibrationWeight); } } } } - /** - * Get index of the (x,y) pixel in an n x n grid represented by a linear array - */ - private int getLinearIndex(final int x, final int y, final int n) { - return y * n + x; - } - private static void equilibrateIntensity(final Tile coefficientTile, final Tile equilibrationTile, final Double average, diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/IntensityMatcher.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/IntensityMatcher.java index ff5d5182a..8020d3b06 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/IntensityMatcher.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/IntensityMatcher.java @@ -2,7 +2,6 @@ import ij.process.ColorProcessor; import ij.process.FloatProcessor; -import mpicbg.models.Affine1D; import mpicbg.models.PointMatch; import mpicbg.models.Tile; import net.imglib2.util.Pair; @@ -55,7 +54,7 @@ public IntensityMatcher( this.imageProcessorCache = imageProcessorCache; } - public void match(final TileSpec p1, final TileSpec p2, final HashMap>>> coefficientTiles) { + public void match(final TileSpec p1, final TileSpec p2, final HashMap intensityTiles) { final StopWatch stopWatch = StopWatch.createAndStart(); @@ -115,24 +114,28 @@ public void match(final TileSpec p1, final TileSpec p2, final HashMap>> p1CoefficientTiles = coefficientTiles.get(p1.getTileId()); - final List>> p2CoefficientTiles = coefficientTiles.get(p2.getTileId()); + final IntensityTile p1IntensityTile = intensityTiles.get(p1.getTileId()); + final IntensityTile p2IntensityTile = intensityTiles.get(p2.getTileId()); int connectionCount = 0; for (int i = 0; i < nCoefficientTiles; ++i) { - final Tile t1 = p1CoefficientTiles.get(i); + final Tile t1 = p1IntensityTile.getSubTile(i); for (int j = 0; j < nCoefficientTiles; ++j) { final List matches = get(matrix, i, j, nCoefficientTiles); if (matches.isEmpty()) continue; - final Tile t2 = p2CoefficientTiles.get(j); + final Tile t2 = p2IntensityTile.getSubTile(j); t1.connect(t2, matches); connectionCount++; } } + if (connectionCount > 0) { + p1IntensityTile.connectTo(p2IntensityTile); + } + stopWatch.stop(); LOG.info("match: pair {} <-> {} has {} connections, matching took {}", p1.getTileId(), p2.getTileId(), connectionCount, stopWatch); }