From 39c418aa65e702702fd4e46fc36f56835d20d4b1 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 29 Nov 2023 13:21:52 -0500 Subject: [PATCH 01/28] chore: hide stacktrace from interrupted ConnectedComponents refer to: https://github.com/imglib/imglib2-algorithm/issues/98 --- .../paintera/control/tools/paint/SamTool.kt | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt index 6a6f8ad9f..54cd119bd 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt @@ -48,6 +48,7 @@ import net.imglib2.type.volatiles.VolatileUnsignedLongType import net.imglib2.util.Intervals import paintera.net.imglib2.view.BundleView import net.imglib2.view.Views +import org.apache.commons.io.output.NullPrintStream import org.apache.http.HttpException import org.apache.http.client.methods.HttpPost import org.apache.http.entity.ContentType @@ -552,6 +553,15 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty = ArrayImgs.unsignedLongs(*predictionMask.dimensionsAsLongArray()) + /* FIXME: This is annoying, but I don't see a better way around it at the moment. + * `labelAllConnectedComponents` can be interrupted, but doing so causes an + * internal method to `printStackTrace()` on the error. So even when + * It's intentionally and interrupted and handeled, the consol still logs the + * stacktrace to stderr. We temporarily wrap stderr to swalleow it. + * When [https://github.com/imglib/imglib2-algorithm/issues/98] is resolved, + * hopefully this will be as well */ + val stdErr = System.err + System.setErr(NullPrintStream()) try { ConnectedComponents.labelAllConnectedComponents( filter, @@ -559,9 +569,12 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty Date: Wed, 10 Jan 2024 16:42:13 -0500 Subject: [PATCH 02/28] build: saalfx snapshot dep --- pom.xml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index f056b9375..a24998885 100644 --- a/pom.xml +++ b/pom.xml @@ -53,7 +53,7 @@ true ${javadoc.skip} - 1.0.0 + 1.1.0-SNAPSHOT 3.0.7 1.4.0 @@ -809,8 +809,8 @@ maven-compiler-plugin ${maven-compiler-plugin.version} - ${maven.compiler.source} - ${maven.compiler.target} + 10 + 10 From dfc2e38d4463d3510ce6f01b01ebf4c6d930460e Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:17 -0500 Subject: [PATCH 03/28] docs --- src/main/java/bdv/fx/viewer/ViewerPanelFX.java | 7 +++++++ .../saalfeldlab/paintera/control/paint/ViewerMask.kt | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/src/main/java/bdv/fx/viewer/ViewerPanelFX.java b/src/main/java/bdv/fx/viewer/ViewerPanelFX.java index f338c0c1b..5c742c5d3 100644 --- a/src/main/java/bdv/fx/viewer/ViewerPanelFX.java +++ b/src/main/java/bdv/fx/viewer/ViewerPanelFX.java @@ -507,6 +507,13 @@ public ReadOnlyDoubleProperty getMouseYProperty() { return mouseTracker.getMouseYProperty(); } + /** + * + * Creates a ObservablePosition which refers to either the location on ViewerPanelFX under the mouse, OR the center + * if the mouse is not on the ViewerPanelFX. + * + * @return the observable position + */ public ObservablePosition createMousePositionOrCenterBinding() { var xBinding = Bindings.createDoubleBinding( diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt index 03f97d2d1..646e1ec8c 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt @@ -462,6 +462,14 @@ class ViewerMask private constructor( } ?: paintera.baseView.orthogonalViews().requestRepaint() } + /** + * Returns the screen interval in the ViewerMask space, based on the given width and height. + * + * @param width The width of the interval. Defaults to the width of the viewer. + * @param height The height of the interval. Defaults to the height of the viewer. + * + * @return The screen interval. + */ @JvmOverloads fun getScreenInterval(width: Long = viewer.width.toLong(), height: Long = viewer.height.toLong()): Interval { val (x: Long, y: Long) = displayPointToInitialMaskPoint(0, 0) From bb239698b729cdaebb7e2fd101cda660b6aa598d Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:21 -0500 Subject: [PATCH 04/28] feat!: improve drag action logic --- .../paintera/viewer3d/Scene3DHandler.java | 40 ++++++++++--------- .../fx/actions/PainteraActionSet.kt | 5 ++- .../fx/actions/PainteraDragActionSet.kt | 3 +- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/viewer3d/Scene3DHandler.java b/src/main/java/org/janelia/saalfeldlab/paintera/viewer3d/Scene3DHandler.java index dbd01bb31..d87b62641 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/viewer3d/Scene3DHandler.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/viewer3d/Scene3DHandler.java @@ -163,8 +163,8 @@ public TranslateXY(String name) { getDragDetectedAction().setFilter(true); getDragAction().setFilter(true); getDragReleaseAction().setFilter(true); - /* don't update XY*/ - setUpdateXY(false); + /* each drag is relative to previous */ + setRelative(true); onDrag(this::drag); } @@ -172,18 +172,15 @@ private void drag(MouseEvent event) { synchronized (affine) { LOG.trace("drag - translate"); + final Affine target = affine.clone(); final double dX = event.getX() - getStartX(); final double dY = event.getY() - getStartY(); - LOG.trace("dx " + dX + " dy: " + dY); - final Affine target = affine.clone(); target.prependTranslation(2 * dX / viewer.getHeight(), 2 * dY / viewer.getHeight()); + LOG.trace("target: {}", target); InvokeOnJavaFXApplicationThread.invoke(() -> setAffine(target)); - - setStartX(getStartX() + dX); - setStartY(getStartY() + dY); } } } @@ -208,24 +205,14 @@ public Rotate3DView(String name) { super(name); LOG.trace(name); verify(MouseEvent::isPrimaryButtonDown); - setUpdateXY(false); + setRelative(false); onDragDetected(this::dragDetected); onDrag(this::drag); } private void dragDetected(MouseEvent event) { - factor = NORMAL_FACTOR; - - if (event.isShiftDown()) { - if (event.isControlDown()) { - factor = SLOW_FACTOR; - } else { - factor = FAST_FACTOR; - } - } - - speed = baseSpeed * factor; + updateSpeed(event); synchronized (affine) { affineDragStart.setToTransform(affine); @@ -234,6 +221,8 @@ private void dragDetected(MouseEvent event) { private void drag(MouseEvent event) { + updateSpeed(event); + synchronized (affine) { LOG.trace("drag - rotate"); final Affine target = new Affine(affineDragStart); @@ -249,6 +238,19 @@ private void drag(MouseEvent event) { InvokeOnJavaFXApplicationThread.invoke(() -> setAffine(target)); } } + + //TODO Caleb: Use same speed control mechanism as NavigationControlMode uses (to be toggleable) + private void updateSpeed(MouseEvent event) { + factor = NORMAL_FACTOR; + + if (event.isControlDown()) { + factor = SLOW_FACTOR; + } else if (event.isShiftDown()) { + factor = FAST_FACTOR; + } + + speed = baseSpeed * factor; + } } public void resetAffine() { diff --git a/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraActionSet.kt b/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraActionSet.kt index bd405e190..94e469ce6 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraActionSet.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraActionSet.kt @@ -47,14 +47,15 @@ fun painteraDragActionSet( actionType: ActionType? = null, ignoreDisable: Boolean = false, filter: Boolean = true, + consumeMouseClicked: Boolean = false, apply: (DragActionSet.() -> Unit)? ): DragActionSet { - return DragActionSet(name, { paintera.keyTracker }, filter).apply { + return DragActionSet(name, { paintera.keyTracker }, filter, consumeMouseClicked).apply { verifyPermission(actionType) if (!ignoreDisable) { verifyPainteraNotDisabled() } - apply?.let { it() } + apply?.invoke(this) } } diff --git a/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraDragActionSet.kt b/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraDragActionSet.kt index d17d05a5e..801be2886 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraDragActionSet.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/fx/actions/PainteraDragActionSet.kt @@ -8,8 +8,9 @@ open class PainteraDragActionSet @JvmOverloads constructor( val actionType: ActionType?, name: String, filter: Boolean = true, + consumeMouseClicked: Boolean = false, apply: (DragActionSet.() -> Unit)? = null -) : DragActionSet(name, { paintera.keyTracker }, filter, apply) { +) : DragActionSet(name, { paintera.keyTracker }, filter, consumeMouseClicked, apply) { override fun preInvokeCheck(action: Action, event: E): Boolean { val actionAllowed = actionType?.let { paintera.baseView.allowedActionsProperty().get().isAllowed(actionType) } ?: true From ccbef773953972841c4fb375be10cb801260acd5 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:25 -0500 Subject: [PATCH 05/28] feat!: provide transform supplier to enable changes --- .../RandomAccessibleIntervalDataSource.java | 29 ++++++++++--------- .../state/RandomAccessibleIntervalBackend.kt | 2 +- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/data/RandomAccessibleIntervalDataSource.java b/src/main/java/org/janelia/saalfeldlab/paintera/data/RandomAccessibleIntervalDataSource.java index c5c966c6c..946a5a227 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/data/RandomAccessibleIntervalDataSource.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/data/RandomAccessibleIntervalDataSource.java @@ -1,6 +1,7 @@ package org.janelia.saalfeldlab.paintera.data; import bdv.viewer.Interpolation; +import javafx.scene.transform.Affine; import mpicbg.spim.data.sequence.VoxelDimensions; import net.imglib2.RandomAccessible; import net.imglib2.RandomAccessibleInterval; @@ -21,6 +22,7 @@ import java.util.Objects; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -28,7 +30,7 @@ public class RandomAccessibleIntervalDataSource, T extends Typ private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); - private final AffineTransform3D[] mipmapTransforms; + private final Supplier getMipmapTransforms; private final RandomAccessibleInterval[] sources; @@ -78,7 +80,7 @@ public RandomAccessibleIntervalDataSource( this( dataWithInvalidate.data, dataWithInvalidate.viewData, - dataWithInvalidate.transforms, + () -> dataWithInvalidate.transforms, dataWithInvalidate.invalidate, dataInterpolation, interpolation, @@ -92,14 +94,14 @@ public RandomAccessibleIntervalDataSource( final Function>> interpolation, final String name) { - this(data.getA(), data.getB(), data.getC(), invalidate, dataInterpolation, interpolation, name); + this(data.getA(), data.getB(), data::getC, invalidate, dataInterpolation, interpolation, name); } @SuppressWarnings("unchecked") public RandomAccessibleIntervalDataSource( final RandomAccessibleInterval dataSource, final RandomAccessibleInterval source, - final AffineTransform3D mipmapTransform, + final Supplier mipmapTransform, final Invalidate invalidate, final Function>> dataInterpolation, final Function>> interpolation, @@ -108,7 +110,7 @@ public RandomAccessibleIntervalDataSource( this( new RandomAccessibleInterval[]{dataSource}, new RandomAccessibleInterval[]{source}, - new AffineTransform3D[]{mipmapTransform}, + () -> new AffineTransform3D[]{mipmapTransform.get()}, invalidate, dataInterpolation, interpolation, @@ -118,7 +120,7 @@ public RandomAccessibleIntervalDataSource( public RandomAccessibleIntervalDataSource( final RandomAccessibleInterval[] dataSources, final RandomAccessibleInterval[] sources, - final AffineTransform3D[] mipmapTransforms, + final Supplier getMipmapTransforms, final Invalidate invalidate, final Function>> dataInterpolation, final Function>> interpolation, @@ -127,7 +129,7 @@ public RandomAccessibleIntervalDataSource( this( dataSources, sources, - mipmapTransforms, + getMipmapTransforms, invalidate, dataInterpolation, interpolation, @@ -140,7 +142,7 @@ public RandomAccessibleIntervalDataSource( public RandomAccessibleIntervalDataSource( final RandomAccessibleInterval[] dataSources, final RandomAccessibleInterval[] sources, - final AffineTransform3D[] mipmapTransforms, + final Supplier getMipmapTransforms, final Invalidate invalidate, final Function>> dataInterpolation, final Function>> interpolation, @@ -149,7 +151,7 @@ public RandomAccessibleIntervalDataSource( final String name) { super(); - this.mipmapTransforms = mipmapTransforms; + this.getMipmapTransforms = getMipmapTransforms; this.dataSources = dataSources; this.sources = sources; this.invalidate = invalidate; @@ -201,8 +203,9 @@ public RealRandomAccessible getInterpolatedSource(final int t, final int leve @Override public void getSourceTransform(final int t, final int level, final AffineTransform3D transform) { - LOG.trace("Requesting mipmap transform for level {} at time {}: {}", level, t, mipmapTransforms[level]); - transform.set(mipmapTransforms[level]); + final AffineTransform3D[] transforms = getMipmapTransforms.get(); + LOG.trace("Requesting mipmap transform for level {} at time {}: {}", level, t, transforms[level]); + transform.set(transforms[level]); } @Override @@ -227,7 +230,7 @@ public VoxelDimensions getVoxelDimensions() { @Override public int getNumMipmapLevels() { - return mipmapTransforms.length; + return getMipmapTransforms.get().length; } @Override @@ -258,7 +261,7 @@ public RandomAccessibleIntervalDataSource copy() { return new RandomAccessibleIntervalDataSource<>( dataSources, sources, - mipmapTransforms, + getMipmapTransforms, invalidate, dataInterpolation, interpolation, diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/state/RandomAccessibleIntervalBackend.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/state/RandomAccessibleIntervalBackend.kt index 8c7efbbe8..18ac0be66 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/state/RandomAccessibleIntervalBackend.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/state/RandomAccessibleIntervalBackend.kt @@ -123,7 +123,7 @@ abstract class RandomAccessibleIntervalBackend( return RandomAccessibleIntervalDataSource( dataSources.toTypedArray(), volatileSources.toTypedArray(), - transforms, + { transforms }, NO_OP_INVALIDATE, { NearestNeighborInterpolatorFactory() }, { NearestNeighborInterpolatorFactory() }, From edf462cbe1d719a8476bf02666513257d8e90bd0 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:32 -0500 Subject: [PATCH 06/28] feat: global view interval 1 pixel in viewerspace --- .../janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt index 646e1ec8c..ec56c7ced 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/paint/ViewerMask.kt @@ -499,7 +499,7 @@ class ViewerMask private constructor( @JvmStatic fun ViewerPanelFX.getGlobalViewerInterval(): RealInterval { val zeroGlobal = doubleArrayOf(0.0, 0.0, 0.0).also { displayToGlobalCoordinates(it) } - val sizeGlobal = doubleArrayOf(width, height, 0.0).also { displayToGlobalCoordinates(it) } + val sizeGlobal = doubleArrayOf(width, height, 1.0).also { displayToGlobalCoordinates(it) } return FinalRealInterval(zeroGlobal, sizeGlobal) } From 7bf93b6cf41bc50528dd7ad19010f7c893042c28 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:35 -0500 Subject: [PATCH 07/28] feat: more convenience methods --- .../saalfeldlab/util/Imglib2Extensions.kt | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/util/Imglib2Extensions.kt b/src/main/kotlin/org/janelia/saalfeldlab/util/Imglib2Extensions.kt index 95aa4bd30..ef435576e 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/util/Imglib2Extensions.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/util/Imglib2Extensions.kt @@ -10,15 +10,12 @@ import net.imglib2.realtransform.RealViews import net.imglib2.type.BooleanType import net.imglib2.type.Type import net.imglib2.type.numeric.IntegerType -import net.imglib2.type.numeric.NumericType import net.imglib2.type.numeric.RealType import net.imglib2.util.Intervals -import net.imglib2.view.ExtendedRandomAccessibleInterval import net.imglib2.view.IntervalView import net.imglib2.view.RandomAccessibleOnRealRandomAccessible import net.imglib2.view.Views import org.janelia.saalfeldlab.paintera.util.IntervalHelpers.Companion.smallestContainingInterval -import tmp.net.imglib2.outofbounds.OutOfBoundsConstantValueFactory import kotlin.math.floor import kotlin.math.roundToLong @@ -40,6 +37,7 @@ fun > F.interpolate(interpolatorFactory: Interpolator fun RandomAccessible.interpolateNearestNeighbor(): RealRandomAccessible = interpolate(NearestNeighborInterpolatorFactory()) fun RandomAccessibleInterval.interpolateNearestNeighbor(): RealRandomAccessibleRealInterval = interpolate(NearestNeighborInterpolatorFactory()).realInterval(this) fun RandomAccessibleInterval.forEach(loop: (T) -> Unit) = Views.iterable(this).forEach(loop) +fun RandomAccessibleInterval.asIterable() = Views.iterable(this) operator fun RandomAccessible.get(vararg pos: Long): T = getAt(*pos) operator fun RandomAccessible.get(vararg pos: Int): T = getAt(*pos) operator fun RandomAccessible.get(pos: Localizable): T = getAt(pos) @@ -83,8 +81,13 @@ fun > RealRandomAccessible.convertWith(other: RealRandomAcc return Converters.convert(this, other, converter, type) } -/* RealPoint Extensions */ +fun RandomAccessibleInterval.addDimension() = Views.addDimension(this) + +fun RandomAccessibleInterval.addDimension(minOfNewDim : Long, maxOfNewDim : Long) = Views.addDimension(this, minOfNewDim, maxOfNewDim) +fun RealRandomAccessible.addDimension() = RealViews.addDimension(this) + +/* RealPoint Extensions */ fun RealPoint.floor(): Point { val pointVals = LongArray(this.numDimensions()) for (i in 0 until this.numDimensions()) { @@ -118,6 +121,16 @@ fun RealPoint.toPoint(): Point { return Point(*pointVals) } +fun RealPoint.scale(vararg scales: Double, inplace : Boolean = false): RealPoint { + assert(scales.isNotEmpty()) + val scaledPoint = if (inplace) this else RealPoint(numDimensions()) + for (i in 0 until scaledPoint.numDimensions()) { + val scale = if (scales.size > 1) scales[i] else scales[0] + scaledPoint.setPosition(this.getDoublePosition(i) * scale, i) + } + return scaledPoint +} + inline fun RealPoint.get(i: Int): T { return when (T::class) { Double::class -> getDoublePosition(i) From 87fb5af2c97cfbe66cc72c924b155498dc80f897 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:39 -0500 Subject: [PATCH 08/28] feat: support modifier only code --- .../saalfeldlab/paintera/BindingKeys.kt | 22 +++++++++---------- .../config/input/KeyAndMouseConfigNode.kt | 3 ++- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt index c6607c9c8..d1194bd78 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt @@ -3,11 +3,13 @@ package org.janelia.saalfeldlab.paintera import javafx.scene.input.KeyCode import javafx.scene.input.KeyCode.* import javafx.scene.input.KeyCodeCombination +import javafx.scene.input.KeyCombination import javafx.scene.input.KeyCombination.* import org.janelia.saalfeldlab.fx.actions.NamedKeyCombination infix fun String.byKeyCombo(keyCode: KeyCode) = NamedKeyCombination(this, KeyCodeCombination(keyCode)) -infix fun String.byKeyCombo(combo: KeyCodeCombination) = NamedKeyCombination(this, combo) +infix fun String.byKeyCombo(modifier: Modifier) = NamedKeyCombination(this, NamedKeyCombination.OnlyModifierKeyCombination(modifier)) +infix fun String.byKeyCombo(combo: KeyCombination) = NamedKeyCombination(this, combo) operator fun ArrayList.plus(keyCode: KeyCode) = KeyCodeCombination(keyCode, *this.toTypedArray()) operator fun ArrayList.plus(modifier: Modifier) = this.apply { add(modifier) } @@ -105,7 +107,8 @@ object LabelSourceStateKeys { const val REFRESH_MESHES = "refresh meshes" const val CANCEL = "cancel" const val TOGGLE_NON_SELECTED_LABELS_VISIBILITY = "toggle non-selected labels visibility" - const val SEGMENT_ANYTHING_MODE = "Segment Anything Mode" + const val ENTER_SEGMENT_ANYTHING_MODE = "segment anything: enter mode" + const val EXIT_SEGMENT_ANYTHING_MODE = "segment anything: exit mode" private val namedComboMap = NamedKeyCombination.CombinationMap( SELECT_ALL byKeyCombo A + CONTROL_DOWN, @@ -127,7 +130,8 @@ object LabelSourceStateKeys { REFRESH_MESHES byKeyCombo R, CANCEL byKeyCombo ESCAPE, TOGGLE_NON_SELECTED_LABELS_VISIBILITY byKeyCombo V + SHIFT_DOWN, - SEGMENT_ANYTHING_MODE byKeyCombo A + ENTER_SEGMENT_ANYTHING_MODE byKeyCombo A, + EXIT_SEGMENT_ANYTHING_MODE byKeyCombo ESCAPE ) fun namedCombinationsCopy() = namedComboMap.deepCopy @@ -149,12 +153,10 @@ object NavigationKeys { const val SET_ROTATION_AXIS_Y = "set rotation axis y" const val SET_ROTATION_AXIS_Z = "set rotation axis z" const val KEY_ROTATE_LEFT = "rotate left" - const val KEY_ROTATE_LEFT_FAST = "rotate left fast" - const val KEY_ROTATE_LEFT_SLOW = "rotate left slow" const val KEY_ROTATE_RIGHT = "rotate right" - const val KEY_ROTATE_RIGHT_FAST = "rotate right fast" - const val KEY_ROTATE_RIGHT_SLOW = "rotate right slow" const val REMOVE_ROTATION = "remove rotation" + const val KEY_MODIFIER_FAST = "fast-modifier" + const val KEY_MODIFIER_SLOW = "slow-modifier" private val namedComboMap = NamedKeyCombination.CombinationMap( BUTTON_TRANSLATE_ALONG_NORMAL_BACKWARD byKeyCombo COMMA, @@ -171,11 +173,9 @@ object NavigationKeys { SET_ROTATION_AXIS_Y byKeyCombo Y, SET_ROTATION_AXIS_Z byKeyCombo Z, KEY_ROTATE_LEFT byKeyCombo LEFT, - KEY_ROTATE_LEFT_FAST byKeyCombo LEFT + SHIFT_DOWN, - KEY_ROTATE_LEFT_SLOW byKeyCombo LEFT + CONTROL_DOWN, KEY_ROTATE_RIGHT byKeyCombo RIGHT, - KEY_ROTATE_RIGHT_FAST byKeyCombo RIGHT + SHIFT_DOWN, - KEY_ROTATE_RIGHT_SLOW byKeyCombo RIGHT + CONTROL_DOWN, + KEY_MODIFIER_FAST byKeyCombo SHIFT_DOWN, + KEY_MODIFIER_SLOW byKeyCombo CONTROL_DOWN, REMOVE_ROTATION byKeyCombo Z + SHIFT_DOWN ) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseConfigNode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseConfigNode.kt index 9e3a68d15..3e39b99fb 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseConfigNode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseConfigNode.kt @@ -12,6 +12,7 @@ import javafx.scene.Node import javafx.scene.control.* import javafx.scene.control.cell.PropertyValueFactory import javafx.scene.input.KeyCodeCombination +import javafx.scene.input.KeyCombination import javafx.scene.layout.GridPane import javafx.scene.layout.HBox import javafx.scene.layout.Priority @@ -199,7 +200,7 @@ class KeyAndMouseConfigNode( val nameColumn = TableColumn("Name").apply { cellValueFactory = Callback { SimpleStringProperty(it.value) } } - val bindingColumn = TableColumn("Binding").apply { + val bindingColumn = TableColumn("Binding").apply { cellValueFactory = Callback { bindings[it.value]?.primaryCombinationProperty() } } From 7c5029eaa0f1eb5a4bbc8b194a705088167ed1c6 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:43 -0500 Subject: [PATCH 09/28] refactor!: source from RAI methods --- .../org/janelia/saalfeldlab/paintera/PainteraBaseView.java | 4 ++-- .../janelia/saalfeldlab/paintera/PainteraBaseViewTest.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/PainteraBaseView.java b/src/main/java/org/janelia/saalfeldlab/paintera/PainteraBaseView.java index 615eae50b..36ab0ed07 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/PainteraBaseView.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/PainteraBaseView.java @@ -372,7 +372,7 @@ public void addGenericState(final SourceState state) { * @param Viewer type of {@code state} * @return the {@link ConnectomicsRawState} that was built from the inputs and added to the viewer */ - public & NativeType, T extends AbstractVolatileNativeRealType> ConnectomicsRawState addConnectomicsRawSource( + public & NativeType, T extends AbstractVolatileNativeRealType> ConnectomicsRawState addMultiscaleConnectomicsRawSource( final RandomAccessibleInterval[] data, final double[][] resolution, final double[][] offset, @@ -413,7 +413,7 @@ public & NativeType, T extends AbstractVolatileNativeR final double max, final String name) { - return addConnectomicsRawSource( + return addMultiscaleConnectomicsRawSource( new RandomAccessibleInterval[]{data}, new double[][]{resolution}, new double[][]{offset}, diff --git a/src/test/java/org/janelia/saalfeldlab/paintera/PainteraBaseViewTest.java b/src/test/java/org/janelia/saalfeldlab/paintera/PainteraBaseViewTest.java index 579ce9612..b33881267 100644 --- a/src/test/java/org/janelia/saalfeldlab/paintera/PainteraBaseViewTest.java +++ b/src/test/java/org/janelia/saalfeldlab/paintera/PainteraBaseViewTest.java @@ -104,7 +104,7 @@ public void testAddMultiScaleConnectomicsRawSource() { }); final PainteraBaseView viewer = Paintera.getPaintera().getBaseView(); - var raw = viewer.addConnectomicsRawSource( + var raw = viewer.addMultiscaleConnectomicsRawSource( multiscale.images, multiscale.resolutions, multiscale.translations, From 940d37bd3ff4d8f626d08d7789d0b4f0adb1517a Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:42:54 -0500 Subject: [PATCH 10/28] refactor: cleanup --- .../paintera/control/modes/ShapeInterpolationMode.kt | 2 +- .../paintera/control/tools/paint/PaintBrushTool.kt | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt index a45f91512..e2dc3971b 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt @@ -314,7 +314,7 @@ class ShapeInterpolationMode>(val controller: ShapeInterpolat activeViewerProperty.get()?.viewer()?.let { viewer -> painteraMidiActionSet("midi paint tool switch actions", device, viewer, PaintActionType.Paint) { val toggleToolActionMap = mutableMapOf() - activeToolProperty.addListener { obs, old, new -> + activeToolProperty.addListener { _, old, new -> toggleToolActionMap[old]?.updateControlSilently(MCUButtonControl.TOGGLE_OFF) toggleToolActionMap[new]?.updateControlSilently(MCUButtonControl.TOGGLE_ON) } diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/PaintBrushTool.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/PaintBrushTool.kt index 9117451fb..9d14e2430 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/PaintBrushTool.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/PaintBrushTool.kt @@ -40,6 +40,9 @@ import org.janelia.saalfeldlab.paintera.paintera import org.janelia.saalfeldlab.paintera.state.SourceState import java.lang.Double.min +internal const val CHANGE_BRUSH_DEPTH = "change brush depth" +internal const val START_BACKGROUND_ERASE = "start background erase" + open class PaintBrushTool(activeSourceStateProperty: SimpleObjectProperty?>, mode: ToolMode? = null) : PaintTool(activeSourceStateProperty, mode) { @@ -261,10 +264,10 @@ open class PaintBrushTool(activeSourceStateProperty: SimpleObjectProperty Date: Wed, 10 Jan 2024 16:42:59 -0500 Subject: [PATCH 11/28] refactor: simplify constructors --- .../config/input/KeyAndMouseBindings.kt | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseBindings.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseBindings.kt index f3098dce0..9e45f6414 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseBindings.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/config/input/KeyAndMouseBindings.kt @@ -3,20 +3,7 @@ package org.janelia.saalfeldlab.paintera.config.input import org.janelia.saalfeldlab.fx.actions.NamedKeyCombination import org.janelia.saalfeldlab.fx.actions.NamedMouseCombination -class KeyAndMouseBindings( - val keyCombinations: NamedKeyCombination.CombinationMap, - val mouseCombinations: NamedMouseCombination.CombinationMap -) { - - constructor() : this(NamedMouseCombination.CombinationMap()) - - constructor(defaultKeyCombinations: NamedKeyCombination.CombinationMap) : this( - defaultKeyCombinations, - NamedMouseCombination.CombinationMap() - ) - - constructor(defaultMouseCombinations: NamedMouseCombination.CombinationMap) : this( - NamedKeyCombination.CombinationMap(), - defaultMouseCombinations - ) -} +class KeyAndMouseBindings @JvmOverloads constructor( + val keyCombinations: NamedKeyCombination.CombinationMap = NamedKeyCombination.CombinationMap(), + val mouseCombinations: NamedMouseCombination.CombinationMap = NamedMouseCombination.CombinationMap() +) \ No newline at end of file From 089a70d4186ea5ec143ab393b35c11e89a340112 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:43:10 -0500 Subject: [PATCH 12/28] feat!: speed modifier refactor --- .../navigation/ButtonRotationSpeedConfig.java | 1 + .../control/navigation/KeyRotate.java | 110 ---------- .../paintera/control/navigation/Zoom.java | 31 +-- .../control/modes/NavigationControlMode.kt | 188 +++++++----------- 4 files changed, 92 insertions(+), 238 deletions(-) delete mode 100644 src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/KeyRotate.java diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/ButtonRotationSpeedConfig.java b/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/ButtonRotationSpeedConfig.java index be2b6e9d7..d142b7647 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/ButtonRotationSpeedConfig.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/ButtonRotationSpeedConfig.java @@ -2,6 +2,7 @@ import javafx.beans.property.SimpleDoubleProperty; +/* TODO: Using this for all speed modifier, not just rotation; should rename, but will require serialization changes. */ public class ButtonRotationSpeedConfig { private static final double DEFAULT_SLOW = 0.5; diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/KeyRotate.java b/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/KeyRotate.java deleted file mode 100644 index 384389deb..000000000 --- a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/KeyRotate.java +++ /dev/null @@ -1,110 +0,0 @@ -package org.janelia.saalfeldlab.paintera.control.navigation; - -import javafx.beans.binding.DoubleExpression; -import javafx.beans.binding.ObjectExpression; -import net.imglib2.realtransform.AffineTransform3D; -import org.janelia.saalfeldlab.paintera.state.GlobalTransformManager; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.lang.invoke.MethodHandles; -import java.util.function.Consumer; - -//TOOD Caleb: Consider refactoring to use existing `Rotate` logic -public class KeyRotate { - - private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); - - public enum Axis { - X(0), - Y(1), - Z(2); - - private final int axis; - - Axis(final int axis) { - - this.axis = axis; - } - } - - private final ObjectExpression axis; - - private final DoubleExpression step; - - private final AffineTransformWithListeners displayTransformUpdater; - - private final AffineTransformWithListeners globalToViewerTransformUpdater; - - private final AffineTransform3D globalTransform = new AffineTransform3D(); - private final Consumer submit; - - private final GlobalTransformManager manager; - - private final TranslationController.TransformTracker globalTransformTracker; - private final TranslationController.TransformTracker globalToViewerTransformTracker; - - private final TranslationController.TransformTracker displayTransformTracker; - - private final AffineTransform3D globalToViewerTransform = new AffineTransform3D(); - - private final AffineTransform3D displayTransform = new AffineTransform3D(); - - public KeyRotate( - final ObjectExpression axis, - final DoubleExpression step, - final AffineTransformWithListeners displayTransformUpdater, - final AffineTransformWithListeners globalToViewerTransformUpdater, - final GlobalTransformManager globalTransformManager, - final Consumer submit) { - - super(); - this.axis = axis; - this.step = step; - this.displayTransformUpdater = displayTransformUpdater; - this.globalToViewerTransformUpdater = globalToViewerTransformUpdater; - this.globalTransformTracker = new TranslationController.TransformTracker(globalTransform, globalTransformManager); - this.globalToViewerTransformTracker = new TranslationController.TransformTracker(globalToViewerTransform, globalTransformManager); - this.displayTransformTracker = new TranslationController.TransformTracker(displayTransform, globalTransformManager); - this.submit = submit; - this.manager = globalTransformManager; - listenOnTransformChanges(); - } - - public void listenOnTransformChanges() { - - this.manager.addListener(this.globalTransformTracker); - this.displayTransformUpdater.addListener(this.displayTransformTracker); - this.globalToViewerTransformUpdater.addListener(this.globalToViewerTransformTracker); - } - - public void stopListeningOnTransformChanges() { - - this.manager.removeListener(this.globalTransformTracker); - this.displayTransformUpdater.removeListener(this.displayTransformTracker); - this.globalToViewerTransformUpdater.removeListener(this.globalToViewerTransformTracker); - } - - public void rotate(final double x, final double y) { - - final AffineTransform3D concatenated = displayTransform.copy() - .concatenate(globalToViewerTransform) - .concatenate(globalTransform); - - concatenated.set(concatenated.get(0, 3) - x, 0, 3); - concatenated.set(concatenated.get(1, 3) - y, 1, 3); - LOG.debug("Rotating {} around axis={} by angle={}", concatenated, axis, step); - concatenated.rotate(axis.get().axis, step.get()); - concatenated.set(concatenated.get(0, 3) + x, 0, 3); - concatenated.set(concatenated.get(1, 3) + y, 1, 3); - - submit.accept( - displayTransform.copy() - .concatenate(globalToViewerTransform) - .inverse() - .concatenate(concatenated) - ); - - } - -} diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Zoom.java b/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Zoom.java index d1737b5b7..24c85c7d7 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Zoom.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Zoom.java @@ -1,34 +1,39 @@ package org.janelia.saalfeldlab.paintera.control.navigation; -import javafx.beans.binding.DoubleExpression; +import javafx.util.Duration; import net.imglib2.realtransform.AffineTransform3D; import org.janelia.saalfeldlab.paintera.state.GlobalTransformManager; public class Zoom { - private final DoubleExpression speed; - private final AffineTransform3D global = new AffineTransform3D(); private final AffineTransform3D concatenated; private final GlobalTransformManager manager; + private boolean busy = false; + public Zoom( - final DoubleExpression speed, final GlobalTransformManager manager, final AffineTransform3D concatenated) { - this.speed = speed; this.manager = manager; this.concatenated = concatenated; this.manager.addListener(global::set); } - public void zoomCenteredAt(final double delta, final double x, final double y) { + public void zoomCenteredAt(final double scaleFactor, final double x, final double y) { + if (Math.abs(1 - scaleFactor) > .2) { + zoomCenteredAt(scaleFactor, x, y, Duration.millis(100)); + } else { + zoomCenteredAt(scaleFactor, x, y, null); + } + } + public void zoomCenteredAt(final double scaleFactor, final double x, final double y, final Duration animate) { - if (delta == 0.0) { + if (scaleFactor == 0.0 || busy) { return; } @@ -40,17 +45,19 @@ public void zoomCenteredAt(final double delta, final double x, final double y) { concatenated.applyInverse(location, location); global.apply(location, location); - final double dScale = speed.get(); - final double scale = delta > 0 ? 1.0 / dScale : dScale; - for (int d = 0; d < location.length; ++d) { global.set(global.get(d, 3) - location[d], d, 3); } - global.scale(scale); + global.scale(scaleFactor); for (int d = 0; d < location.length; ++d) { global.set(global.get(d, 3) + location[d], d, 3); } - manager.setTransform(global); + if (animate != null) { + this.busy = true; + manager.setTransform(global, animate, () -> this.busy = false); + } else { + manager.setTransform(global); + } } } diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt index 329c17a59..01ba03e6b 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt @@ -1,9 +1,6 @@ package org.janelia.saalfeldlab.paintera.control.modes import de.jensd.fx.glyphs.fontawesome.FontAwesomeIconView -import javafx.beans.binding.BooleanExpression -import javafx.beans.binding.DoubleExpression -import javafx.beans.binding.ObjectExpression import javafx.beans.property.SimpleBooleanProperty import javafx.beans.property.SimpleDoubleProperty import javafx.beans.property.SimpleObjectProperty @@ -14,6 +11,7 @@ import javafx.scene.control.ButtonType import javafx.scene.control.Label import javafx.scene.input.KeyCode import javafx.scene.input.KeyEvent.KEY_PRESSED +import javafx.scene.input.KeyEvent.KEY_RELEASED import javafx.scene.input.ScrollEvent import javafx.scene.layout.GridPane import javafx.scene.layout.Priority @@ -24,10 +22,7 @@ import net.imglib2.realtransform.AffineTransform3D import org.janelia.saalfeldlab.control.VPotControl.DisplayType import org.janelia.saalfeldlab.fx.ObservablePosition import org.janelia.saalfeldlab.fx.actions.* -import org.janelia.saalfeldlab.fx.extensions.LazyForeignMap -import org.janelia.saalfeldlab.fx.extensions.LazyForeignValue -import org.janelia.saalfeldlab.fx.extensions.invoke -import org.janelia.saalfeldlab.fx.extensions.nullable +import org.janelia.saalfeldlab.fx.extensions.* import org.janelia.saalfeldlab.fx.midi.MidiActionSet import org.janelia.saalfeldlab.fx.midi.MidiButtonEvent import org.janelia.saalfeldlab.fx.midi.MidiPotentiometerEvent @@ -36,20 +31,23 @@ import org.janelia.saalfeldlab.fx.ui.SpatialField import org.janelia.saalfeldlab.fx.util.InvokeOnJavaFXApplicationThread import org.janelia.saalfeldlab.paintera.DeviceManager import org.janelia.saalfeldlab.paintera.NavigationKeys +import org.janelia.saalfeldlab.paintera.NavigationKeys.KEY_MODIFIER_FAST +import org.janelia.saalfeldlab.paintera.NavigationKeys.KEY_MODIFIER_SLOW +import org.janelia.saalfeldlab.paintera.NavigationKeys.KEY_ROTATE_LEFT +import org.janelia.saalfeldlab.paintera.NavigationKeys.KEY_ROTATE_RIGHT import org.janelia.saalfeldlab.paintera.config.input.KeyAndMouseBindings import org.janelia.saalfeldlab.paintera.control.ControlUtils import org.janelia.saalfeldlab.paintera.control.actions.AllowedActions import org.janelia.saalfeldlab.paintera.control.actions.NavigationActionType import org.janelia.saalfeldlab.paintera.control.navigation.* -import org.janelia.saalfeldlab.paintera.control.navigation.KeyRotate.Axis +import org.janelia.saalfeldlab.paintera.control.navigation.Rotate.Axis import org.janelia.saalfeldlab.paintera.control.tools.Tool import org.janelia.saalfeldlab.paintera.control.tools.ViewerTool import org.janelia.saalfeldlab.paintera.paintera import org.janelia.saalfeldlab.paintera.properties -import org.janelia.saalfeldlab.paintera.state.GlobalTransformManager import org.janelia.saalfeldlab.paintera.ui.PainteraAlerts -import java.util.function.Consumer import kotlin.math.absoluteValue +import kotlin.math.max import kotlin.math.sign /** @@ -94,14 +92,15 @@ object NavigationTool : ViewerTool() { AffineTransform3D().apply { globalTransformManager.addListener { set(it) } } } - private val zoomSpeed = SimpleDoubleProperty(1.05) - - private val rotationSpeed = SimpleDoubleProperty(1.0) val allowRotationsProperty = SimpleBooleanProperty(true) private val buttonRotationSpeedConfig = ButtonRotationSpeedConfig() + private val speedProperty = SimpleDoubleProperty(buttonRotationSpeedConfig.regular.value) + + private val speed: Double by speedProperty.nonnull() + override fun activate() { with(properties.navigationConfig) { @@ -145,7 +144,7 @@ object NavigationTool : ViewerTool() { } val zoomController by LazyForeignValue({ activeViewerAndTransforms }) { - Zoom(zoomSpeed, globalTransformManager, viewerTransform) + Zoom(globalTransformManager, viewerTransform) } val keyRotationAxis by LazyForeignValue({ activeViewerAndTransforms }) { SimpleObjectProperty(Axis.Z) @@ -162,13 +161,12 @@ object NavigationTool : ViewerTool() { override val actionSets by LazyForeignMap({ activeViewerAndTransforms }) { viewerAndTransforms -> viewerAndTransforms?.run { - - globalToViewerTransform().addListener { globalToViewerTransform.set(it) } val actionSets = mutableListOf() + actionSets += speedModifierActions() actionSets += translateAlongNormalActions(translationController!!) actionSets += translateInPlaneActions(translationController!!) actionSets += zoomActions(zoomController, targetPositionObservable!!) @@ -179,6 +177,28 @@ object NavigationTool : ViewerTool() { } ?: mutableListOf() } + private fun speedModifierActions() = painteraActionSet("speed-modifier", ignoreDisable = true) { + mapOf( + KEY_MODIFIER_FAST to buttonRotationSpeedConfig.fast, + KEY_MODIFIER_SLOW to buttonRotationSpeedConfig.slow + ).forEach { (keys, speed) -> + KEY_PRESSED(keyBindings, keys, keysExclusive = false) { + consume = false + onAction { + speedProperty.unbind() + speedProperty.bind(speed) + } + } + KEY_RELEASED(keyBindings, keys, keysExclusive = false) { + consume = false + onAction { + speedProperty.unbind() + speedProperty.bind(buttonRotationSpeedConfig.regular) + } + } + } + } + private fun translateAlongNormalActions(translationController: TranslationController): List { @@ -313,8 +333,11 @@ object NavigationTool : ViewerTool() { ).map { keys -> ScrollEvent.SCROLL { verifyEventNotNull() + verify("scroll size at least 1 pixel") { max(it!!.deltaX.absoluteValue, it.deltaY.absoluteValue) > 1.0 } keysDown(*keys) - onAction { zoomController.zoomCenteredAt(-ControlUtils.getBiggestScroll(it!!), it.x, it.y) } + onAction { + val scale = 1 + ControlUtils.getBiggestScroll(it!!) / 1_000 + zoomController.zoomCenteredAt(scale, it.x, it.y) } } } } @@ -327,10 +350,13 @@ object NavigationTool : ViewerTool() { 1.0 to NavigationKeys.BUTTON_ZOOM_OUT2, -1.0 to NavigationKeys.BUTTON_ZOOM_IN, -1.0 to NavigationKeys.BUTTON_ZOOM_IN2 - ).map { (delta, key) -> - KEY_PRESSED { - onAction { zoomController.zoomCenteredAt(delta, targetPositionObservable.x, targetPositionObservable.y) } - keyMatchesBinding(keyBindings, key) + ).map { (direction, key) -> + KEY_PRESSED(keyBindings, key, keysExclusive = false) { + onAction { + val delta = speed / 100 + val scale = 1 - direction * delta + zoomController.zoomCenteredAt(scale, targetPositionObservable.x, targetPositionObservable.y) + } } } } @@ -381,61 +407,34 @@ object NavigationTool : ViewerTool() { } val setRotationAxis = painteraActionSet("set rotation axis", NavigationActionType.Rotate) { - arrayOf( - Axis.X to NavigationKeys.SET_ROTATION_AXIS_X, - Axis.Y to NavigationKeys.SET_ROTATION_AXIS_Y, - Axis.Z to NavigationKeys.SET_ROTATION_AXIS_Z - ).map { (axis, key) -> - KEY_PRESSED { - onAction { keyRotationAxis.set(axis) } - keyMatchesBinding(keyBindings, key) - } - } + KEY_PRESSED(keyBindings, NavigationKeys.SET_ROTATION_AXIS_X) { onAction { keyRotationAxis.set(Axis.X) } } + KEY_PRESSED(keyBindings, NavigationKeys.SET_ROTATION_AXIS_Y) { onAction { keyRotationAxis.set(Axis.Y) } } + KEY_PRESSED(keyBindings, NavigationKeys.SET_ROTATION_AXIS_Z) { onAction { keyRotationAxis.set(Axis.Z) } } } - fun newDragRotationAction(name: String, speed: Double, keyDown: KeyCode? = null) = - baseRotationAction( - name, - allowRotationsProperty, - rotationSpeed.multiply(speed), - displayTransform, - globalToViewerTransform, - globalTransformManager - ) { globalTransformManager.transform = it }.apply { - keyDown?.let { - dragDetectedAction.keysDown(keyDown) - dragAction.keysDown(keyDown) - } ?: let { - dragDetectedAction.verifyNoKeysDown() - dragAction.verifyNoKeysDown() - } + val rotationController = Rotate(displayTransform, globalToViewerTransform, globalTransformManager) + + val mouseRotation = painteraDragActionSet("mousde-drag-rotate", NavigationActionType.Rotate) { + verify { it.isPrimaryButtonDown } + dragDetectedAction.verify { NavigationTool.allowRotationsProperty() } + onDragDetected { rotationController.initialize(targetPositionObservable.x, targetPositionObservable.y) } + onDrag { + rotationController.setSpeed(speed / buttonRotationSpeedConfig.regular.value) + rotationController.rotate3D(it.x, it.y, startX, startY) } + } + + + val keyRotation = painteraActionSet("key-rotate", NavigationActionType.Rotate) { + mapOf(-1 to KEY_ROTATE_LEFT, 1 to KEY_ROTATE_RIGHT).forEach { (direction, key) -> - val mouseRotation = newDragRotationAction("rotate", DEFAULT) - val fastMouseRotation = newDragRotationAction("rotate fast", FAST, KeyCode.SHIFT) - val slowMouseRotation = newDragRotationAction("rotate slow", SLOW, KeyCode.CONTROL) - - val rotationKeyActions = painteraActionSet("rotate", NavigationActionType.Rotate) { - mapOf( - buttonRotationSpeedConfig.regular.multiply(-1) to NavigationKeys.KEY_ROTATE_LEFT, - buttonRotationSpeedConfig.regular to NavigationKeys.KEY_ROTATE_RIGHT, - - buttonRotationSpeedConfig.fast.multiply(-1) to NavigationKeys.KEY_ROTATE_LEFT_FAST, - buttonRotationSpeedConfig.fast to NavigationKeys.KEY_ROTATE_RIGHT_FAST, - - buttonRotationSpeedConfig.slow.multiply(-1) to NavigationKeys.KEY_ROTATE_LEFT_SLOW, - buttonRotationSpeedConfig.slow to NavigationKeys.KEY_ROTATE_RIGHT_SLOW, - ).forEach { (speed, key) -> - addKeyRotationHandler( - key, keyBindings, - targetPositionObservable, - allowRotationsProperty, - keyRotationAxis, - speed.multiply(Math.PI / 180.0), - displayTransform, - globalToViewerTransform, - globalTransformManager - ) { globalTransformManager.transform = it } + KEY_PRESSED(keyBindings, key, keysExclusive = false) { + verify { allowRotationsProperty() } + onAction { + rotationController.setSpeed(direction * speed) + rotationController.rotateAroundAxis(targetPositionObservable.x, targetPositionObservable.y, keyRotationAxis.get()) + } + } } } @@ -445,10 +444,7 @@ object NavigationTool : ViewerTool() { rotationActions += removeRotationActions rotationActions += setRotationAxis rotationActions += mouseRotation - rotationActions += fastMouseRotation - rotationActions += slowMouseRotation - rotationActions += rotationKeyActions - midiRotationActions()?.let { rotationActions += it } + rotationActions += keyRotation return rotationActions.filterNotNull() } @@ -576,44 +572,4 @@ object NavigationTool : ViewerTool() { } } } - - private fun baseRotationAction( - name: String, - allowRotations: BooleanExpression, - speed: DoubleExpression, - displayTransform: AffineTransformWithListeners, - globalToViewerTransform: AffineTransformWithListeners, - manager: GlobalTransformManager, - submitTransform: Consumer - ): DragActionSet { - val rotate = Rotate(speed, displayTransform, globalToViewerTransform, manager, submitTransform) - - return painteraDragActionSet(name, NavigationActionType.Rotate) { - verify { it.isPrimaryButtonDown } - dragDetectedAction.verify { allowRotations() } - onDragDetected { rotate.initialize() } - onDrag { rotate.rotate(it.x, it.y, startX, startY) } - } - } - - private fun ActionSet.addKeyRotationHandler( - name: String, - keyBindings: NamedKeyCombination.CombinationMap, - targetPositionObservable: ObservablePosition, - allowRotations: BooleanExpression, - axis: ObjectExpression, - step: DoubleExpression, - displayTransformSupplier: AffineTransformWithListeners, - globalToViewerTransform: AffineTransformWithListeners, - globalTransformManager: GlobalTransformManager, - submitTransform: Consumer - ) { - val rotate = KeyRotate(axis, step, displayTransformSupplier, globalToViewerTransform, globalTransformManager, submitTransform) - - KEY_PRESSED { - verify { allowRotations() } - onAction { rotate.rotate(targetPositionObservable.x, targetPositionObservable.y) } - keyMatchesBinding(keyBindings, name) - } - } } From 1da568726b7d8318a23277926016f2f21247d78d Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:43:19 -0500 Subject: [PATCH 13/28] refactor!: combine rotate logic, and use speed modifer changes --- .../paintera/control/navigation/Rotate.java | 94 +++++-- .../control/modes/SegmentAnythingMode.kt | 258 ++++++++++++++++++ .../control/tools/paint/SamPredictor.kt | 249 +++++++++++++++++ 3 files changed, 573 insertions(+), 28 deletions(-) create mode 100644 src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt create mode 100644 src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Rotate.java b/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Rotate.java index 4cab29f02..de36e3283 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Rotate.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/control/navigation/Rotate.java @@ -1,16 +1,19 @@ package org.janelia.saalfeldlab.paintera.control.navigation; -import javafx.beans.binding.DoubleExpression; +import javafx.util.Duration; import net.imglib2.realtransform.AffineTransform3D; import org.janelia.saalfeldlab.paintera.state.GlobalTransformManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.util.function.Consumer; +import java.lang.invoke.MethodHandles; public class Rotate { - private static final double ROTATION_STEP = Math.PI / 180; + private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); - private final DoubleExpression speed; + private static final double DEFAULT_STEP = Math.PI / 180; + private double speed = 1.0; private final AffineTransform3D globalTransform = new AffineTransform3D(); @@ -18,8 +21,6 @@ public class Rotate { private final AffineTransformWithListeners globalToViewerTransformUpdater; - private final Consumer submitTransform; - private final GlobalTransformManager manager; private final AffineTransform3D affineDragStart = new AffineTransform3D(); @@ -33,16 +34,14 @@ public class Rotate { private final AffineTransform3D globalToViewerTransform = new AffineTransform3D(); private final AffineTransform3D displayTransform = new AffineTransform3D(); + private boolean busy = false; public Rotate( - final DoubleExpression speed, final AffineTransformWithListeners displayTransformUpdater, final AffineTransformWithListeners globalToViewerTransformUpdater, - final GlobalTransformManager globalTransformManager, - final Consumer submitTransform) { + final GlobalTransformManager globalTransformManager) { super(); - this.speed = speed; this.globalTransformTracker = new TranslationController.TransformTracker(globalTransform, globalTransformManager); this.globalToViewerTransformTracker = new TranslationController.TransformTracker(globalToViewerTransform, globalTransformManager); @@ -50,12 +49,17 @@ public Rotate( this.displayTransformUpdater = displayTransformUpdater; this.globalToViewerTransformUpdater = globalToViewerTransformUpdater; - this.submitTransform = submitTransform; this.manager = globalTransformManager; listenOnTransformChanges(); } - public void initialize() { + public void setSpeed(double speed) { + this.speed = speed; + } + + + + public void initialize(final double displayStartX, final double displayStartY) { synchronized (manager) { affineDragStart.set(globalTransform); @@ -76,19 +80,53 @@ public void stopListeningOnTransformChanges() { this.globalToViewerTransformUpdater.removeListener(this.globalToViewerTransformTracker); } - public void rotate(final double x, final double y, final double startX, final double startY) { + public enum Axis { + X, Y, Z + } + + public synchronized void rotateAroundAxis(final double x, final double y, final Axis axis) { + if (busy) return; + + final AffineTransform3D concatenated = displayTransform.copy() + .concatenate(globalToViewerTransform) + .concatenate(globalTransform); + + concatenated.set(concatenated.get(0, 3) - x, 0, 3); + concatenated.set(concatenated.get(1, 3) - y, 1, 3); + + final var step = speed * DEFAULT_STEP; + LOG.debug("Rotating {} around axis={} by angley={}", concatenated, axis, step); + concatenated.rotate(axis.ordinal(), step); + concatenated.set(concatenated.get(0, 3) + x, 0, 3); + concatenated.set(concatenated.get(1, 3) + y, 1, 3); + + final AffineTransform3D rotatedTransform = displayTransform.copy() + .concatenate(globalToViewerTransform) + .inverse() + .concatenate(concatenated); + + final double magnitude = Math.abs(speed); + if (magnitude > 1) { + busy = true; + manager.setTransform(rotatedTransform, Duration.millis(Math.log10(magnitude) * 100), () -> this.busy = false); + } else { + manager.setTransform(rotatedTransform); + } + } + + public void rotate3D(final double x, final double y, final double startX, final double startY) { - final AffineTransform3D affine = new AffineTransform3D(); + final AffineTransform3D rotated = new AffineTransform3D(); synchronized (manager) { - final double v = ROTATION_STEP * this.speed.get(); - affine.set(affineDragStart); - final double[] point = new double[]{x, y, 0}; - final double[] origin = new double[]{startX, startY, 0}; + final double v = DEFAULT_STEP * this.speed; + rotated.set(affineDragStart); + final double[] start = new double[]{startX, startY, 0}; + final double[] end = new double[]{x, y, 0}; - displayTransform.applyInverse(point, point); - displayTransform.applyInverse(origin, origin); + displayTransform.applyInverse(end, end); + displayTransform.applyInverse(start, start); - final double[] delta = new double[]{point[0] - origin[0], point[1] - origin[1], 0}; + final double[] delta = new double[]{end[0] - start[0], end[1] - start[1], 0}; // TODO do scaling separately. need to swap .get( 0, 0 ) and // .get( 1, 1 ) ? final double[] rotation = new double[]{ @@ -96,24 +134,24 @@ public void rotate(final double x, final double y, final double startX, final do -delta[0] * v * displayTransform.get(1, 1), 0}; - globalToViewerTransform.applyInverse(origin, origin); + globalToViewerTransform.applyInverse(start, start); globalToViewerTransform.applyInverse(rotation, rotation); // center shift - for (int d = 0; d < origin.length; ++d) { - affine.set(affine.get(d, 3) - origin[d], d, 3); + for (int d = 0; d < start.length; ++d) { + rotated.set(rotated.get(d, 3) - start[d], d, 3); } for (int d = 0; d < rotation.length; ++d) { - affine.rotate(d, rotation[d]); + rotated.rotate(d, rotation[d]); } // center un-shift - for (int d = 0; d < origin.length; ++d) { - affine.set(affine.get(d, 3) + origin[d], d, 3); + for (int d = 0; d < start.length; ++d) { + rotated.set(rotated.get(d, 3) + start[d], d, 3); } - submitTransform.accept(affine); + manager.setTransform(rotated); } } } diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt new file mode 100644 index 000000000..98fd7ab18 --- /dev/null +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt @@ -0,0 +1,258 @@ +package org.janelia.saalfeldlab.paintera.control.modes + +import ai.onnxruntime.OnnxTensor +import bdv.util.Affine3DHelpers +import de.jensd.fx.glyphs.fontawesome.FontAwesomeIconView +import javafx.beans.value.ChangeListener +import javafx.collections.FXCollections +import javafx.collections.ObservableList +import javafx.scene.input.KeyEvent.KEY_PRESSED +import javafx.scene.input.KeyEvent.KEY_RELEASED +import javafx.scene.input.MouseEvent.MOUSE_PRESSED +import net.imglib2.Interval +import net.imglib2.realtransform.AffineTransform3D +import org.janelia.saalfeldlab.fx.actions.ActionSet +import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.installActionSet +import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.removeActionSet +import org.janelia.saalfeldlab.fx.actions.painteraActionSet +import org.janelia.saalfeldlab.fx.extensions.LazyForeignValue +import org.janelia.saalfeldlab.fx.extensions.addWithListener +import org.janelia.saalfeldlab.fx.ortho.OrthogonalViews +import org.janelia.saalfeldlab.paintera.LabelSourceStateKeys +import org.janelia.saalfeldlab.paintera.control.actions.AllowedActions +import org.janelia.saalfeldlab.paintera.control.actions.PaintActionType +import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask +import org.janelia.saalfeldlab.paintera.control.tools.Tool +import org.janelia.saalfeldlab.paintera.control.tools.paint.* +import org.janelia.saalfeldlab.paintera.data.mask.MaskedSource +import org.janelia.saalfeldlab.paintera.paintera + +class SegmentAnythingMode(val previousMode: ControlMode) : AbstractToolMode() { + + override val defaultTool: Tool? by lazy { samTool } + + private val samTool: SamTool = object : SamTool(activeSourceStateProperty, this@SegmentAnythingMode) { + + private var lastEmbedding: OnnxTensor? = null + private var globalTransformAtEmbedding = AffineTransform3D() + + init { + activeViewerProperty.unbind() + activeViewerProperty.bind(mode!!.activeViewerProperty) + } + + override fun activate() { + maskedSource?.resetMasks(false) + providedEmbedding = if (Affine3DHelpers.equals(paintera.baseView.manager().transform, globalTransformAtEmbedding)) lastEmbedding else null + super.activate() + } + + override fun deactivate() { + super.deactivate() + lastEmbedding = getImageEmbeddingTask.get()!! + globalTransformAtEmbedding.set(paintera.baseView.manager().transform) + } + + override fun setCurrentLabelToSelection() { + currentLabelToPaint = statePaintContext!!.selectedIds.lastSelection + } + } + + private val paintBrushTool = object : PaintBrushTool(activeSourceStateProperty, this@SegmentAnythingMode) { + + override val actionSets: MutableList by LazyForeignValue({ activeViewerAndTransforms }) { + mutableListOf( + *getBrushActions().filterNot { it.name == CHANGE_BRUSH_DEPTH }.toTypedArray(), + *getPaintActions().filterNot { it.name == START_BACKGROUND_ERASE }.toTypedArray(), + segmentAnythingPaintBrushActions(), + *(midiBrushActions() ?: arrayOf()) + ) + } + + override fun activate() { + super.activate() + /* Don't allow painting with depth during shape interpolation */ + brushProperties?.brushDepth = 1.0 + paintClickOrDrag!!.provideMask(samTool.viewerMask!!) + } + + override fun deactivate() { + paintClickOrDrag?.release() + super.deactivate() + } + } + + private val fill2DTool = object : Fill2DTool(activeSourceStateProperty, this@SegmentAnythingMode) { + + + private val samPredictionOnFill = ChangeListener { _, _, new -> + new?.let { + switchTool(samTool) + samTool.requestPrediction() + } + } + + override fun activate() { + super.activate() + /* Don't allow filling with depth during shape interpolation */ + brushProperties?.brushDepth = 1.0 + fillLabel = { statePaintContext!!.selectedIds.lastSelection } + brushProperties?.brushDepth = 1.0 + fill2D.provideMask(samTool.viewerMask!!) + fill2D.maskIntervalProperty.addListener(samPredictionOnFill) + } + + override fun deactivate() { + fill2D.maskIntervalProperty.removeListener(samPredictionOnFill) + super.deactivate() + } + + override val actionSets: MutableList by LazyForeignValue({ activeViewerAndTransforms }) { + super.actionSets.also { it += segmentAnythingFloodFillActions(this) } + } + + } + + override val modeActions by lazy { modeActions() } + + override val allowedActions = AllowedActions.AllowedActionsBuilder() + .add(PaintActionType.Paint, PaintActionType.Erase, PaintActionType.SetBrushSize, PaintActionType.Fill) + .create() + + private val toolTriggerListener = ChangeListener { _, old, new -> + new?.viewer()?.apply { modeActions.forEach { installActionSet(it) } } + old?.viewer()?.apply { modeActions.forEach { removeActionSet(it) } } + } + + override val tools: ObservableList by lazy { FXCollections.observableArrayList(paintBrushTool, fill2DTool, samTool) } + + override fun enter() { + activeViewerProperty.addListener(toolTriggerListener) + super.enter() + /* unbind the activeViewerProperty, since we disabled other viewers during ShapeInterpolation mode*/ + activeViewerProperty.unbind() + /* Try to initialize the tool, if state is valid. If not, change back to previous mode. */ + activeViewerProperty.get()?.viewer()?.let { + disableUnfocusedViewers() + switchTool(samTool) + } ?: paintera.baseView.changeMode(previousMode) + } + + override fun exit() { + super.exit() + enableAllViewers() + activeViewerProperty.removeListener(toolTriggerListener) + } + + private fun modeActions(): List { + val keyCombinations = paintera.baseView.keyAndMouseBindings.getConfigFor(activeSourceStateProperty.value!!).keyCombinations + return mutableListOf( + painteraActionSet(LabelSourceStateKeys.EXIT_SEGMENT_ANYTHING_MODE) { + + verifyAll(KEY_PRESSED, "Sam Tool is Active ") { activeTool == samTool } + KEY_PRESSED { + graphic = { FontAwesomeIconView().apply { styleClass += listOf("toolbar-tool", "reject", "reject-segment-anything") } } + keyMatchesBinding(keyCombinations, LabelSourceStateKeys.EXIT_SEGMENT_ANYTHING_MODE) + onAction { + paintera.baseView.changeMode(previousMode) + } + } + }, + painteraActionSet("paint during segment anything", PaintActionType.Paint) { + KEY_PRESSED(*paintBrushTool.keyTrigger.toTypedArray()) { + name = "switch to paint tool" + val getViewerMask = { (activeSourceStateProperty.get()?.dataSource as? MaskedSource<*, *>)?.currentMask as? ViewerMask } + verify { getViewerMask() != null } + onAction { + switchTool(paintBrushTool) + } + } + + KEY_RELEASED(*paintBrushTool.keyTrigger.toTypedArray()) { + name = "switch back to segment anything tool from paint brush" + filter = true + verify { activeTool is PaintBrushTool } + onAction { switchTool(samTool) } + } + + KEY_PRESSED(*fill2DTool.keyTrigger.toTypedArray()) { + name = "switch to fill2d tool" + verify { activeSourceStateProperty.get()?.dataSource is MaskedSource<*, *> } + onAction { switchTool(fill2DTool) } + } + KEY_RELEASED(*fill2DTool.keyTrigger.toTypedArray()) { + name = "switch to segment anything tool from fill2d" + filter = true + verify { activeTool is Fill2DTool } + onAction { + switchTool(samTool) + } + } + } + ) + } + + /** + * Additional paint brush actions for Segment Anything + * + * @receiver the tool to add the actions to + * @return the additional action sets + */ + private fun PaintBrushTool.segmentAnythingPaintBrushActions(): ActionSet { + + return painteraActionSet("Segment Anything Paint Brush Actions", PaintActionType.SegmentAnything) { + MOUSE_PRESSED { + name = "provide SAM tool mask to paint brush" + filter = true + consume = false + verify { activeTool == this@segmentAnythingPaintBrushActions } + onAction { + /* On click, generate a new mask, */ + (activeSourceStateProperty.get()?.dataSource as? MaskedSource<*, *>)?.let { source -> + paintClickOrDrag!!.let { paintController -> + source.resetMasks(true) + paintController.provideMask(samTool.viewerMask!!) + } + } + } + } + } + } + + /** + * Additional fill actions for Segment Anything + * + * @param floodFillTool + * @return the additional ActionSet + * + * */ + private fun segmentAnythingFloodFillActions(floodFillTool: Fill2DTool): ActionSet { + return painteraActionSet("Segment Anything Fill 2D Actions", PaintActionType.SegmentAnything) { + MOUSE_PRESSED { + name = "provide SAM tool mask to fill 2d" + filter = true + consume = false + verify { activeTool == floodFillTool } + onAction { + /* On click, provide the mask, setup the task listener */ + (activeSourceStateProperty.get()?.dataSource as? MaskedSource<*, *>)?.let { source -> + source.resetMasks(true) + val mask = samTool.viewerMask!! + fill2DTool.run { + fillTaskProperty.addWithListener { obs, _, task -> + task?.let { + task.onCancelled(true) { _, _ -> + source.resetMasks(true) + mask.requestRepaint() + } + task.onEnd(true) { obs?.removeListener(this) } + } ?: obs?.removeListener(this) + } + fill2D.provideMask(mask) + } + } + } + } + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt new file mode 100644 index 000000000..acbe312e3 --- /dev/null +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt @@ -0,0 +1,249 @@ +package org.janelia.saalfeldlab.paintera.control.tools.paint + +import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.OnnxTensorLike +import ai.onnxruntime.OrtEnvironment +import ai.onnxruntime.OrtSession +import io.github.oshai.kotlinlogging.KotlinLogging +import net.imglib2.RandomAccessibleInterval +import net.imglib2.RealPoint +import net.imglib2.img.array.ArrayImgs +import net.imglib2.type.NativeType +import net.imglib2.type.numeric.integer.UnsignedLongType +import net.imglib2.type.numeric.real.FloatType +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.FloatBuffer + +private fun allocateDirectFloatBuffer(size: Int, order: ByteOrder = ByteOrder.nativeOrder()): FloatBuffer { + return ByteBuffer.allocateDirect(size * Float.SIZE_BYTES).order(order).asFloatBuffer() +} + +private val LOG = KotlinLogging.logger { } + +class SamPredictor( + private val environment: OrtEnvironment, + private val session: OrtSession, + var embedding: OnnxTensor, + val originalImgSize: Pair +) { + companion object { + + /** + * SAM Embedding converts image to max dim length of 1024. + * Instead of letting the service do it, we just do it ahead of time. + * Helps ensure the image we send is as small as possible. + */ + internal const val MAX_DIM_TARGET = 1024 + + const val LOW_RES_MASK_DIM = 256 + + const val IMAGE_EMBEDDINGS = "image_embeddings" + const val ORIG_IM_SIZE = "orig_im_size" + const val POINT_COORDS = "point_coords" + const val POINT_LABELS = "point_labels" + const val MASK_INPUT = "mask_input" + const val HAS_MASK_INPUT = "has_mask_input" + + /** + * Creates a prediction request with the given points. + * The points must be in the range [0, origImgSize) for each dimension, + * and have labels of either SparseLabel.OUT or SpareLabel.IN + * + * @param points The list of points. + * @return The created PredictionRequest object. + */ + fun points(points: List): PredictionRequest { + return SparsePrediction(points) + } + } + + /* TODO: Evaluate this is correct. I think we are supposed to introduce a half pixel offset somewhere...? + * */ + val imgEmbeddingScale = let { + val (xDim, yDim) = originalImgSize + if (xDim >= yDim) { + MAX_DIM_TARGET.toDouble() / xDim + } else { + MAX_DIM_TARGET.toDouble() / yDim + } + } + + + private val imgSizeBuffer = allocateDirectFloatBuffer(2).also { + originalImgSize.let { (width, height) -> + it.put(height.toFloat()) + it.put(width.toFloat()) + } + it.position(0) + } + private val imgSizeTensor = OnnxTensor.createTensor(environment, imgSizeBuffer, longArrayOf(2)) + + lateinit var lastPrediction: SamPrediction + lateinit var result: RandomAccessibleInterval> + + fun predict(vararg requests: PredictionRequest): SamPrediction { + /* Add the embedding and size */ + val params = mutableMapOf( + IMAGE_EMBEDDINGS to embedding, + ORIG_IM_SIZE to imgSizeTensor + ) + /* add the `no-mask` params. If a mask is present, they will be overwritten */ + params += MaskPrediction.noMaskParameters(environment) + + /* add the parameter maps */ + requests.map { it.mapParameters(this, environment) }.fold(params) { acc, map -> + acc += map + acc + } + /* run the prediction */ + synchronized(this) { + val predictionResult = session.run(params) + return SamPrediction(predictionResult, this).also { + lastPrediction = it + } + } + } + + + /** + * Converts coordinates in the original image space to embedded image coordinates. + * The image sent to be embedded is always scaled such that the longest dimension is 1024, + * while maintaining the aspect ratio. The coordinates are scaled to match the scaled image. + * + * + * @param coord The coordinate to be converted, within the bounds of the [SamPredictor.originalImgSize]. + * @return The converted coordinate within the bounds of the scaled image. + */ + private fun originalToEmbeddedImageCoord(coord: RealPoint): RealPoint { + return coord.positionAsDoubleArray() + .map { it * imgEmbeddingScale } + .toDoubleArray().let { + RealPoint.wrap(it) + } + } + + data class SamPrediction( + val masks: OnnxTensor? = null, + val iouPredictions: OnnxTensor, + val lowResMasks: OnnxTensor, + val predictor: SamPredictor + ) { + + var image: RandomAccessibleInterval = ArrayImgs.floats(lowResMasks.floatBuffer.array(), LOW_RES_MASK_DIM.toLong(), LOW_RES_MASK_DIM.toLong()) + + /* + * Binary segmentation mask of connected components, or null if not binarized. + * Interval is the smallest bounding box containing the segmentation. May be empty, or the entire original image size. + */ + var segmentation: RandomAccessibleInterval? = null + + constructor(result: OrtSession.Result, predictor: SamPredictor) : this( + result[MASKS].get() as OnnxTensor, + result[IOU_PREDICTIONS].get() as OnnxTensor, + result[LOW_RES_MASKS].get() as OnnxTensor, + predictor + ) + + companion object { + const val MASKS = "masks" + const val IOU_PREDICTIONS = "iou_predictions" + const val LOW_RES_MASKS = "low_res_masks" + } + } + + + enum class SparseLabel(val label: Float) { + OUT(0f), + IN(1f), + TOP_LEFT_BOX(2f), + BOTTOM_RIGHT_BOX(3f) + } + + + interface PredictionRequest { + + fun mapParameters(predictor: SamPredictor, environment: OrtEnvironment): Map + } + + /** + * Represents a prediction request for sparse data. + * Points must be x,y integers relative to the top left of the image used to generate the embedding. + * Points must be within (0 - [SamPredictor.originalImgSize]) for all dimensions. + * + * @property points A list of points. + */ + class SparsePrediction(val points: List) : PredictionRequest { + + override fun mapParameters(predictor: SamPredictor, environment: OrtEnvironment): Map { + + val numPoints = if (points.isEmpty()) 1 else points.size + val coordsBuffer = allocateDirectFloatBuffer(2 * numPoints) + val labelsBuffer = allocateDirectFloatBuffer(numPoints) + + points.ifEmpty { listOf(SamPoint(0.0, 0.0, SparseLabel.OUT)) }.forEach { + val (scaledX, scaledY) = it.centerScaledCoordinates(predictor.imgEmbeddingScale) + coordsBuffer.put(scaledX.toFloat()) + coordsBuffer.put(scaledY.toFloat()) + labelsBuffer.put(it.label.ordinal.toFloat()) + } + + coordsBuffer.position(0) + labelsBuffer.position(0) + + val onnxCoords = OnnxTensor.createTensor(environment, coordsBuffer, longArrayOf(1, numPoints.toLong(), 2)) + val onnxLabels = OnnxTensor.createTensor(environment, labelsBuffer, longArrayOf(1, numPoints.toLong())) + return mapOf( + POINT_COORDS to onnxCoords, + POINT_LABELS to onnxLabels + ) + } + } + + /** + * Represents a prediction request with a mask input. + * Mask Interval should be 256x256 + * + * @property mask A lowres mask dictating parts of the image that should be inside or outside the segmentation + */ + class MaskPrediction(val mask: RandomAccessibleInterval>) : PredictionRequest { + + companion object { + const val MASK_DIM = LOW_RES_MASK_DIM.toLong() + + /* 4D low-res mask, expected to be 256x256 */ + private val noMaskBuffer by lazy { + val maskDim = MASK_DIM.toInt() + allocateDirectFloatBuffer(1 * 1 * maskDim * maskDim) + } + private val maskShape = longArrayOf(1, 1, MASK_DIM, MASK_DIM) + + private val hasNoMaskInput by lazy { allocateDirectFloatBuffer(1) } + private val hasMaskFlagShape = longArrayOf(1) + + fun noMaskParameters(environment: OrtEnvironment): Map { + return mapOf( + MASK_INPUT to OnnxTensor.createTensor(environment, noMaskBuffer, maskShape), + HAS_MASK_INPUT to OnnxTensor.createTensor(environment, hasNoMaskInput, hasMaskFlagShape), + ) + } + } + + init { + assert(mask.dimension(0) == MASK_DIM) + assert(mask.dimension(1) == MASK_DIM) + } + + override fun mapParameters(predictor: SamPredictor, environment: OrtEnvironment): Map { + TODO("Not yet implemented") + } + } + + data class SamPoint(val x: Double, val y: Double, val label: SparseLabel) { + + fun centerScaledCoordinates(scale: Double): Pair { + return (x + .5) * scale to (y + .5) * scale + } + } + +} \ No newline at end of file From ab3c639ea5dca5aa300ff72f82548bfea782ab1c Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:43:36 -0500 Subject: [PATCH 14/28] feat,refactor!: support scale by different factors per dimension --- .../paintera/data/mask/MaskedSource.java | 2 +- .../paintera/util/IntervalHelpers.kt | 103 +++++++++++++----- 2 files changed, 78 insertions(+), 27 deletions(-) diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/data/mask/MaskedSource.java b/src/main/java/org/janelia/saalfeldlab/paintera/data/mask/MaskedSource.java index a6ced8582..fdf2a0d16 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/data/mask/MaskedSource.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/data/mask/MaskedSource.java @@ -1187,7 +1187,7 @@ private static > Set downsample( /* Views.tiles doesn't preserver intervals, so zeroMin prior to tiling */ final var zeroMinTarget = Views.zeroMin(target); - final var sourceInterval = IntervalHelpers.scaleBy(target, steps, true); + final var sourceInterval = IntervalHelpers.scale(target, steps, true); final IntervalView zeroMinSource = Views.zeroMin(Views.interval(source, sourceInterval)); final var tiledSource = Views.tiles(zeroMinSource, steps); diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/util/IntervalHelpers.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/util/IntervalHelpers.kt index b96034fbb..21a0e2f89 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/util/IntervalHelpers.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/util/IntervalHelpers.kt @@ -1,13 +1,12 @@ package org.janelia.saalfeldlab.paintera.util -import net.imglib2.FinalInterval -import net.imglib2.FinalRealInterval -import net.imglib2.Interval -import net.imglib2.RealInterval +import net.imglib2.* import net.imglib2.algorithm.util.Grids import net.imglib2.realtransform.RealTransform import net.imglib2.util.Intervals -import java.util.Arrays +import org.janelia.saalfeldlab.paintera.util.IntervalHelpers.Companion.scale +import org.janelia.saalfeldlab.util.shape +import java.util.* import kotlin.math.max import kotlin.math.min @@ -59,7 +58,7 @@ class IntervalHelpers { @JvmStatic @JvmOverloads - fun Interval.scaleBy(scaleFactor: Int, scaleMin: Boolean = false): Interval { + fun Interval.scale(scaleFactor: Int, scaleMin: Boolean = false): Interval { val newMin = minAsLongArray().also { if (scaleMin) { it.forEachIndexed { idx, min -> it[idx] = min * scaleFactor } @@ -70,7 +69,8 @@ class IntervalHelpers { @JvmStatic @JvmOverloads - fun Interval.scaleBy(vararg scaleFactors: Int, scaleMin: Boolean = false): Interval { + fun Interval.scale(vararg scaleFactors: Int, scaleMin: Boolean = false): Interval { + /* FIXME: Look at the modified [RealInterval.scale] below for reference on how this should behave */ assert(scaleFactors.size == nDim) val newMin = minAsLongArray().also { if (scaleMin) { @@ -80,31 +80,45 @@ class IntervalHelpers { return FinalInterval(newMin, newMin.copyOf().apply { forEachIndexed { idx, min -> this[idx] = min - 1 + dimension(idx) * scaleFactors[idx] } }) } + /** + * Scale the real interval by the provided scale factors. + * + * If [scaleMin] is true, the size of the resulting interval will be the initial size * [scaleFactors] + * per dimension, and the min position of the resulting interval will also be scaled based on [scaleFactors]. + * + * If [scaleMin] is false, the size will still be the result of the size * [scaleFactors] per dimension, + * but the min position will be the same as the original. + * + * @param scaleFactors to scale the interval by + * @param scaleMin whether to scale the min position, or just the size. + * @return the scaled RealInterval + */ @JvmStatic @JvmOverloads - fun RealInterval.scaleBy(scaleFactor: Double, scaleMin: Boolean = false): RealInterval { - val newMin = minAsDoubleArray().also { - if (scaleMin) { - it.forEachIndexed { idx, min -> it[idx] = min * scaleFactor } - } + fun RealInterval.scale(vararg scaleFactors: Double, scaleMin: Boolean = true): RealInterval { + /* TODO: WRITE TESTS */ + var scales = when(scaleFactors.size) { + nDim -> scaleFactors + 1 -> DoubleArray(nDim) {scaleFactors[0]} + else -> throw IndexOutOfBoundsException("Provided scales of size ${scaleFactors.size} cannot be used to scale interval with nDim of $nDim") } - return FinalRealInterval( - newMin, - newMin.copyOf().apply { forEachIndexed { idx, min -> this[idx] = min - 1 + ((min - realMin(idx)) * scaleFactor) } }) - } - @JvmStatic - @JvmOverloads - fun RealInterval.scaleBy(vararg scaleFactors: Double, scaleMin: Boolean = false): RealInterval { - assert(scaleFactors.size == nDim) - val newMin = minAsDoubleArray().also { - if (scaleMin) { - it.forEachIndexed { idx, min -> it[idx] = min * scaleFactors[idx] } + val newMin = minAsDoubleArray() + val newMax = maxAsDoubleArray() + if (scaleMin) { + newMin.also { + it.forEachIndexed { idx, min -> it[idx] = min * scales[idx] } + } + newMax.also { + it.forEachIndexed { idx, max -> it[idx] = max * scales[idx] } + } + } else { + shape().forEachIndexed {idx, dimLen -> + val scaledDimLen = dimLen * scales[idx] + newMax[idx] = newMin[idx] + scaledDimLen - 1 } } - return FinalRealInterval( - newMin, - newMin.copyOf().apply { forEachIndexed { idx, min -> this[idx] = min - 1 + ((min - realMin(idx)) * scaleFactors[idx]) } }) + return FinalRealInterval( newMin, newMax) } @@ -128,3 +142,40 @@ class IntervalHelpers { } + +fun main() { + //TODO Caleb: Move to test + val zeroMin = Intervals.createMinMaxReal(0.0, 0.0, 99.0, 99.0) + zeroMin.shape().contentEquals(doubleArrayOf(100.0, 100.0)) + + + val zeroMinDoubleTrue = zeroMin.scale(2.0, scaleMin = true) + zeroMinDoubleTrue.shape().contentEquals(doubleArrayOf(200.0, 200.0)) + + val zeroMinDoubleFalse = zeroMin.scale(2.0, scaleMin = false) + zeroMinDoubleFalse.shape().contentEquals(doubleArrayOf(200.0, 200.0)) + + val zeroMinDoubleFalse5 = zeroMin.scale(5.0, scaleMin = false) + zeroMinDoubleFalse5.shape().contentEquals(doubleArrayOf(500.0, 500.0)) + + val zeroMinHalfTrue: RealInterval = zeroMin.scale(.5, scaleMin = true) + zeroMinHalfTrue.shape().contentEquals(doubleArrayOf(50.0, 50.0)) + + val zeroMinHalfFalse = zeroMin.scale(.5, scaleMin = false) + zeroMinHalfFalse.shape().contentEquals(doubleArrayOf(50.0, 50.0)) + + val min = Intervals.createMinMaxReal(50.0, 50.0, 99.0, 99.0) + zeroMinDoubleTrue.shape().contentEquals(doubleArrayOf(200.0, 200.0)) + + val minDoubleTrue = min.scale(2.0, scaleMin = true) + minDoubleTrue.shape().contentEquals(doubleArrayOf(100.0, 100.0)) + + val minDoubleFalse = min.scale(2.0, scaleMin = false) + minDoubleFalse.shape().contentEquals(doubleArrayOf(100.0, 100.0)) + + val minHalfTrue = min.scale(.5, scaleMin = true) + minHalfTrue.shape().contentEquals(doubleArrayOf(25.0, 25.0)) + + val minHalfFalse = min.scale(.5, scaleMin = false) + minHalfFalse.shape().contentEquals(doubleArrayOf(25.0, 25.0)) +} From 1156dd142c017268ff746420aa01feb2ae860a85 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:43:39 -0500 Subject: [PATCH 15/28] feat!: sam refactor --- .../control/actions/PaintActionType.java | 3 +- .../paintera/control/tools/paint/SamTool.kt | 651 ++++++++++++------ src/main/resources/style/sam.css | 17 +- 3 files changed, 455 insertions(+), 216 deletions(-) diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/control/actions/PaintActionType.java b/src/main/java/org/janelia/saalfeldlab/paintera/control/actions/PaintActionType.java index 2bb269e82..baaa3a57f 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/control/actions/PaintActionType.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/control/actions/PaintActionType.java @@ -10,7 +10,8 @@ public enum PaintActionType implements ActionType { Intersect, SetBrushSize, SetBrushDepth, - ShapeInterpolation; + ShapeInterpolation, + SegmentAnything; public static EnumSet of(final PaintActionType first, final PaintActionType... rest) { diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt index 54cd119bd..816688209 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt @@ -1,6 +1,8 @@ package org.janelia.saalfeldlab.paintera.control.tools.paint -import ai.onnxruntime.* +import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.OrtEnvironment +import ai.onnxruntime.OrtException import bdv.cache.SharedQueue import bdv.fx.viewer.ViewerPanelFX import bdv.fx.viewer.render.RenderUnit @@ -24,29 +26,31 @@ import javafx.scene.input.KeyCode import javafx.scene.input.KeyEvent.KEY_PRESSED import javafx.scene.input.KeyEvent.KEY_RELEASED import javafx.scene.input.MouseButton +import javafx.scene.input.MouseEvent import javafx.scene.input.MouseEvent.MOUSE_CLICKED import javafx.scene.input.MouseEvent.MOUSE_MOVED -import javafx.scene.input.ScrollEvent +import javafx.scene.input.ScrollEvent.SCROLL import javafx.scene.shape.Circle +import javafx.scene.shape.Rectangle +import net.imglib2.FinalInterval import net.imglib2.Interval -import net.imglib2.Point import net.imglib2.RandomAccessibleInterval -import net.imglib2.RealPoint import net.imglib2.algorithm.labeling.ConnectedComponents import net.imglib2.algorithm.labeling.ConnectedComponents.StructuringElement import net.imglib2.converter.Converters +import net.imglib2.histogram.Real1dBinMapper import net.imglib2.img.array.ArrayImgs +import net.imglib2.interpolation.randomaccess.NLinearInterpolatorFactory +import net.imglib2.iterator.IntervalIterator import net.imglib2.loops.LoopBuilder import net.imglib2.parallel.TaskExecutors -import net.imglib2.realtransform.AffineTransform3D -import net.imglib2.realtransform.Scale3D -import net.imglib2.realtransform.Translation3D +import net.imglib2.realtransform.* import net.imglib2.type.logic.BoolType import net.imglib2.type.numeric.integer.UnsignedLongType import net.imglib2.type.numeric.real.FloatType +import net.imglib2.type.volatiles.VolatileFloatType import net.imglib2.type.volatiles.VolatileUnsignedLongType import net.imglib2.util.Intervals -import paintera.net.imglib2.view.BundleView import net.imglib2.view.Views import org.apache.commons.io.output.NullPrintStream import org.apache.http.HttpException @@ -57,55 +61,56 @@ import org.apache.http.impl.client.HttpClients import org.apache.http.util.EntityUtils import org.janelia.saalfeldlab.fx.Tasks import org.janelia.saalfeldlab.fx.UtilityTask -import org.janelia.saalfeldlab.fx.actions.painteraActionSet -import org.janelia.saalfeldlab.fx.actions.painteraMidiActionSet -import org.janelia.saalfeldlab.fx.actions.verifyPainteraNotDisabled +import org.janelia.saalfeldlab.fx.actions.* +import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.installActionSet import org.janelia.saalfeldlab.fx.event.KeyTracker import org.janelia.saalfeldlab.fx.extensions.LazyForeignValue import org.janelia.saalfeldlab.fx.extensions.nonnull import org.janelia.saalfeldlab.fx.extensions.nullable -import org.janelia.saalfeldlab.fx.extensions.position import org.janelia.saalfeldlab.fx.midi.MidiButtonEvent +import org.janelia.saalfeldlab.fx.util.InvokeOnJavaFXApplicationThread import org.janelia.saalfeldlab.labels.Label import org.janelia.saalfeldlab.paintera.DeviceManager import org.janelia.saalfeldlab.paintera.PainteraBaseView +import org.janelia.saalfeldlab.paintera.composition.ARGBCompositeAlphaAdd import org.janelia.saalfeldlab.paintera.composition.CompositeProjectorPreMultiply import org.janelia.saalfeldlab.paintera.control.actions.PaintActionType import org.janelia.saalfeldlab.paintera.control.modes.PaintLabelMode -import org.janelia.saalfeldlab.paintera.control.modes.ShapeInterpolationMode import org.janelia.saalfeldlab.paintera.control.modes.ToolMode import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask.Companion.createViewerMask +import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask.Companion.getGlobalViewerInterval +import org.janelia.saalfeldlab.paintera.control.tools.paint.SamPredictor.Companion.LOW_RES_MASK_DIM +import org.janelia.saalfeldlab.paintera.control.tools.paint.SamPredictor.SamPoint import org.janelia.saalfeldlab.paintera.data.mask.MaskInfo import org.janelia.saalfeldlab.paintera.data.mask.MaskedSource import org.janelia.saalfeldlab.paintera.paintera import org.janelia.saalfeldlab.paintera.properties import org.janelia.saalfeldlab.paintera.state.SourceState +import org.janelia.saalfeldlab.paintera.state.predicate.threshold.Bounds import org.janelia.saalfeldlab.paintera.util.IntervalHelpers import org.janelia.saalfeldlab.paintera.util.IntervalHelpers.Companion.asRealInterval import org.janelia.saalfeldlab.paintera.util.IntervalHelpers.Companion.extendBy import org.janelia.saalfeldlab.paintera.util.IntervalHelpers.Companion.smallestContainingInterval +import org.janelia.saalfeldlab.paintera.util.algorithms.otsuThresholdPrediction import org.janelia.saalfeldlab.util.* import org.slf4j.LoggerFactory +import paintera.net.imglib2.view.BundleView import java.io.PipedInputStream import java.io.PipedOutputStream import java.lang.invoke.MethodHandles import java.nio.ByteBuffer -import java.nio.FloatBuffer +import java.nio.ByteOrder import java.nio.file.Files import java.nio.file.Paths import java.util.concurrent.Executors import java.util.concurrent.LinkedBlockingQueue import javax.imageio.ImageIO -import kotlin.collections.component1 -import kotlin.collections.component2 import kotlin.collections.set -import kotlin.math.absoluteValue -import kotlin.math.max -import kotlin.math.min -import kotlin.math.sign +import kotlin.math.* import kotlin.properties.Delegates + open class SamTool(activeSourceStateProperty: SimpleObjectProperty?>, mode: ToolMode? = null) : PaintTool(activeSourceStateProperty, mode) { override val graphic = { FontAwesomeIconView().also { it.styleClass += listOf("toolbar-tool", "sam-select") } } @@ -179,6 +184,11 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty(null) var lastPrediction by lastPredictionProperty.nullable() private set - private val includePoints = mutableListOf() + private val points = mutableListOf() - private val excludePoints = mutableListOf() + private var clearPoints = true; - private var threshold = 2.5 + + private var thresholdBounds = Bounds(-40.0, 30.0) + private var threshold = 0.0 set(value) { - field = value.coerceAtLeast(0.0) + field = value.coerceIn(thresholdBounds.min, thresholdBounds.max) } init { - setCursorWhenDoneApplying = ChangeListener { observable, _, isApplying -> + setCursorWhenDoneApplying = ChangeListener { observable, _, _ -> observable.removeListener(setCursorWhenDoneApplying) } } @@ -230,16 +242,18 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty() + private val predictionToOriginalImageScaleWithoutPadding: Double + get() = max(imgWidth, imgHeight).toDouble() / LOW_RES_MASK_DIM private var predictionImagePngInputStream = PipedInputStream() private var predictionImagePngOutputStream = PipedOutputStream(predictionImagePngInputStream) override fun activate() { super.activate() - if (mode !is ShapeInterpolationMode<*>) { + if (mode is PaintLabelMode) { PaintLabelMode.disableUnfocusedViewers() } controlMode = false - threshold = 5.0 + threshold = 0.0 setCurrentLabelToSelection() statePaintContext?.selectedIds?.apply { addListener(selectedIdListener) } setViewer = activeViewer @@ -256,11 +270,9 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty @@ -272,6 +284,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty) { + if (mode is PaintLabelMode) { PaintLabelMode.enableAllViewers() } super.deactivate() @@ -303,25 +316,136 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty imgHeight) { + lowResWidth = LOW_RES_MASK_DIM.toLong() + lowResHeight = ceil(imgHeight / predictionToOriginalImageScaleWithoutPadding).toLong() + } else { + lowResHeight = LOW_RES_MASK_DIM.toLong() + lowResWidth = ceil(imgWidth / predictionToOriginalImageScaleWithoutPadding).toLong() + } + + val highResPrediction = ArrayImgs.floats(currentPrediction!!.masks!!.floatBuffer.array(), imgWidth.toLong(), imgHeight.toLong()) + val lowResPrediction = currentPrediction!!.image + + val name: String + val (mask, maskRai) = if (toggle) { + toggle = false + name = "high res" + highResPrediction to highResPrediction.interval(Intervals.createMinSize(0, 0, imgWidth.toLong(), imgHeight.toLong())) + } else { + toggle = true + name = "low res" + lowResPrediction to lowResPrediction.interval(Intervals.createMinSize(0, 0, lowResWidth, lowResHeight)) + } + + val (max, mean, std) = maskRai.let { + var sum = 0.0 + var sumSquared = 0.0 + var max = Float.MIN_VALUE + it.forEach { float -> + val floatVal = float.get() + sum += floatVal + sumSquared += floatVal.pow(2) + if (max < floatVal) max = floatVal + } + val area = Intervals.numElements(it) + val mean = sum / area + val stddev = sqrt(sumSquared / area - mean.pow(2)) + doubleArrayOf(max.toDouble(), mean, stddev) + } + val min = (mean - std).toFloat() + val zeroMinValue = mask.convert(FloatType()) { input, output -> output.set(input.get() - min) } + val predictionSource = paintera.baseView.addConnectomicsRawSource( + zeroMinValue.let { + val prediction3D = Views.addDimension(it) + val interval3D = Intervals.createMinMax(*it.minAsLongArray(), 0, *it.maxAsLongArray(), 0) + prediction3D.interval(interval3D) + }, + doubleArrayOf(1.0, 1.0, 1.0), + doubleArrayOf(0.0, 0.0, 0.0), + 0.0, max - min, + "$name prediction" + ) + + val transform = object : AffineTransform3D() { + override fun set(value: Double, row: Int, column: Int) { + super.set(value, row, column) + predictionSource.backend.updateTransform(this) + setViewer!!.requestRepaint() + } + } + + + setViewer!!.getGlobalViewerInterval().also { + val width = it.realMax(0) - it.realMin(0) + val height = it.realMax(1) - it.realMin(1) + val depth = it.realMax(2) - it.realMin(2) + transform.set( + *AffineTransform3D() + .concatenate(Translation3D(it.realMin(0), it.realMin(1), it.realMin(2))) + .concatenate(Scale3D(width / maskRai.shape()[0], height / maskRai.shape()[1], depth)) + .concatenate(Translation3D(.5, .5, 0.0)) //half-pixel offset + .inverse() + .rowPackedCopy + ) + } + predictionSource.backend.updateTransform(transform) + +// Stage().apply { +// val makeFields: (Int) -> SpatialField = { idx -> +// SpatialField.doubleField(0.0, { true }, Region.USE_COMPUTED_SIZE, ObjectField.SubmitOn.ENTER_PRESSED, ObjectField.SubmitOn.FOCUS_LOST).also { +// it.showHeader = false +// if (idx == 3) { +// it.setValues(transform[0, idx], transform[1, idx], transform[2, idx]) +// it.x.valueProperty().addListener { _, _, new -> transform.set(new.toDouble(), 0, idx) } +// it.y.valueProperty().addListener { _, _, new -> transform.set(new.toDouble(), 1, idx) } +// it.z.valueProperty().addListener { _, _, new -> transform.set(new.toDouble(), 2, idx) } +// } else { +// it.setValues(transform[idx, 0], transform[idx, 1], transform[idx, 2]) +// it.x.valueProperty().addListener { _, _, new -> transform.set(new.toDouble(), idx, 0) } +// it.y.valueProperty().addListener { _, _, new -> transform.set(new.toDouble(), idx, 1) } +// it.z.valueProperty().addListener { _, _, new -> transform.set(new.toDouble(), idx, 2) } +// } +// } +// } +// val sf1 = makeFields(0) +// val sf2 = makeFields(1) +// val sf3 = makeFields(2) +// val sf4 = makeFields(3) +// scene = Scene(VBox(sf1.node, sf2.node, sf3.node, sf4.node), 450.0, 800.0) +// }.show() + + predictionSource.composite = ARGBCompositeAlphaAdd() + setViewer!!.requestRepaint() } } KEY_PRESSED(KeyCode.CONTROL) { @@ -331,14 +455,24 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty 1.0 } verify { controlMode } verifyEventNotNull() verifyPainteraNotDisabled() - onAction { - val delta = arrayOf(it!!.deltaX, it.deltaY).maxBy { it.absoluteValue } - threshold += (delta.sign * .1) - requestPrediction(includePoints, excludePoints, true) + onAction { scroll -> + /* ScrollEvent deltas are internally multiplied to correspond to some estimate of pixels-per-unit-scroll. + * For example, on the platform I'm using now, it's `40` for both x and y. But our threshold is NOT + * in pixel units, so we divide by the multiplier, and specify our own. */ + val delta = with(scroll!!) { + when { + deltaY.absoluteValue > deltaX.absoluteValue -> deltaY / multiplierY + else -> deltaX / multiplierX + } + } + val increment = (thresholdBounds.max - thresholdBounds.min) / 100.0 + threshold += delta * increment + requestPrediction(points, true) } } @@ -349,11 +483,9 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty - Platform.runLater { - viewer.children += Circle(5.0).apply { - translateX = it!!.x - viewer.width / 2 - translateY = it.y - viewer.height / 2 - styleClass += SamPointStyle.POINT - styleClass += SamPointStyle.INCLUDE - } - } - } - requestPrediction(includePoints, excludePoints) + if (clearPoints) + resetPredictionPoints() + clearPoints = false + val point = SamPoint(it!!.x * screenScale, it.y * screenScale, SamPredictor.SparseLabel.IN) + drawCircle(it, point, SamPointStyle.Include) + points += point + requestPrediction(points) } } - MOUSE_CLICKED(MouseButton.SECONDARY) { + MOUSE_CLICKED ( MouseButton.SECONDARY) { name = "exclude point" verifyEventNotNull() verifyPainteraNotDisabled() verify { controlMode } onAction { - excludePoints += it!!.position.toPoint() - setViewer?.let { viewer -> - Platform.runLater { - viewer.children += Circle(5.0).apply { - translateX = it!!.x - viewer.width / 2 - translateY = it.y - viewer.height / 2 - styleClass += SamPointStyle.POINT - styleClass += SamPointStyle.EXCLUDE - } - } - } - requestPrediction(includePoints, excludePoints) + if (clearPoints) + resetPredictionPoints() + clearPoints = false + val point = SamPoint(it!!.x * screenScale, it.y * screenScale, SamPredictor.SparseLabel.OUT) + drawCircle(it, point, SamPointStyle.Exclude) + points += point + requestPrediction(points) } } }, @@ -412,11 +534,82 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty + newBox.styleClass += SAM_BOX_OVERLAY_STYLE + setViewer?.let { viewer -> + val oldBox = boxOverlay + InvokeOnJavaFXApplicationThread { + oldBox?.let { viewer.children -= oldBox } + viewer.children += newBox + } + } + } + } + onDrag { mouse -> + val box = boxOverlay!! + setViewer?.let { viewer -> + InvokeOnJavaFXApplicationThread { + val (minX, maxX) = (if (startX < mouse.x) startX to mouse.x else mouse.x to startX) + val (minY, maxY) = (if (startY < mouse.y) startY to mouse.y else mouse.y to startY) + + box.width = maxX - minX + box.height = maxY - minY + box.translateX = maxX - (box.width + viewer.width) / 2 + box.translateY = maxY - (box.height + viewer.height) / 2 + + + points.removeIf { it.label > SamPredictor.SparseLabel.IN } + points += SamPoint(minX * screenScale, minY * screenScale, SamPredictor.SparseLabel.TOP_LEFT_BOX) + points += SamPoint(maxX * screenScale, maxY * screenScale, SamPredictor.SparseLabel.BOTTOM_RIGHT_BOX) + requestPrediction(points) + } + } + } } ) - private fun clearInsideOutsideCircles() = setViewer?.let { viewer -> - Platform.runLater { viewer.children.removeIf { child -> SamPointStyle.POINT in child.styleClass } } + private fun drawCircle(it: MouseEvent, point: SamPoint, samStyle: SamPointStyle) { + setViewer?.let { viewer -> + Platform.runLater { + viewer.children += Circle(5.0).apply { + translateX = it.x - viewer.width / 2 + translateY = it.y - viewer.height / 2 + styleClass += samStyle.styles + + /* If clicked again, remove it */ + painteraActionSet("remove-circle", ignoreDisable = true) { + MOUSE_CLICKED { + onAction { + points -= point + viewer.children -= this@apply + requestPrediction(points) + } + } + }.also { installActionSet(it) } + } + } + } + } + + private fun resetPredictionPoints() { + points.clear() + clearCircles() + clearBox() + clearPoints = true + } + + private fun clearCircles() = setViewer?.let { viewer -> + Platform.runLater { viewer.children.removeIf { SAM_POINT_STYLE in it.styleClass } } + } + + private fun clearBox() = setViewer?.let { viewer -> + Platform.runLater { viewer.children.removeIf { SAM_BOX_OVERLAY_STYLE in it.styleClass } } } private fun SamTaskInfo.submitPrediction() { @@ -447,19 +640,15 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty - private val predictionQueue = LinkedBlockingQueue(1) - - private data class PredictionRequest(val includePoints: List, val excludePoints: List, val refresh: Boolean = false) + private val predictionQueue = LinkedBlockingQueue>(1) - private fun requestPrediction(includePoints: List, excludePoints: List, refresh: Boolean = false) { + fun requestPrediction(points: List = emptyList(), refresh: Boolean = false) { if (predictionTask == null || predictionTask?.isCancelled == true) { startPredictionTask() } - val include = MutableList(includePoints.size) { includePoints[it] } - val exclude = MutableList(excludePoints.size) { excludePoints[it] } synchronized(predictionQueue) { predictionQueue.clear() - predictionQueue.put(PredictionRequest(include, exclude, refresh)) + predictionQueue.put(SamPredictor.points(listOf(*points.toTypedArray())) to refresh) } } @@ -482,11 +671,13 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty? = null + private var currentPrediction: SamPredictor.SamPrediction? = null + private fun startPredictionTask() { val maskSource = maskedSource ?: return val task = Tasks.createTask { task -> val session = createOrtSessionTask.get() val embedding = providedEmbedding ?: getImageEmbeddingTask.get() + val predictor = SamPredictor(ortEnv, session, embedding, imgWidth to imgHeight) while (!task.isCancelled) { - val (pointsIn, pointsOut, refresh) = predictionQueue.take() - val predictionMask = if (refresh && currentPredictionMask != null) currentPredictionMask!! else runPredictionWithRetry(pointsIn, pointsOut, session, embedding) - currentPredictionMask = predictionMask + val (predictionRequest, refresh) = predictionQueue.take() + + val lowResWidth: Long + val lowResHeight: Long + if (imgWidth > imgHeight) { + lowResWidth = LOW_RES_MASK_DIM.toLong() + lowResHeight = ceil(imgHeight / predictionToOriginalImageScaleWithoutPadding).toLong() + } else { + lowResHeight = LOW_RES_MASK_DIM.toLong() + lowResWidth = ceil(imgWidth / predictionToOriginalImageScaleWithoutPadding).toLong() + } + val lowResIntervalWithoutPadding = Intervals.createMinSize(0, 0, lowResWidth, lowResHeight) + + val newPredictionRequest = !refresh || currentPrediction == null + if (newPredictionRequest) { + currentPrediction = runPredictionWithRetry(predictor, predictionRequest) + } + if (!refresh) { + setBestEstimatedThreshold(lowResIntervalWithoutPadding) + } val paintMask = viewerMask!! - val predictionMaskInterval = RealPoint(imgWidth!!.toDouble(), imgHeight!!.toDouble()) - .scaledPoint(1.0 / screenScale) - .toPoint() - .let { scaledPoint -> - paintMask.getScreenInterval(scaledPoint[0], scaledPoint[1]) - } - val minPoint = longArrayOf(Long.MAX_VALUE, Long.MAX_VALUE, 0) - val maxPoint = longArrayOf(Long.MIN_VALUE, Long.MIN_VALUE, 0) + val minPointInLowResMask = longArrayOf(Long.MAX_VALUE, Long.MAX_VALUE) + val maxPointInLowResMask = longArrayOf(Long.MIN_VALUE, Long.MIN_VALUE) var noneAccepted = true - val filter = Converters.convert( - BundleView(predictionMask), - { sourceRa, output -> - val type = sourceRa.get() - val accept = type.get() >= threshold + val lowResFilter = Converters.convert( + BundleView(currentPrediction!!.image), + { predictionMaskRA, output -> + val predictionType = predictionMaskRA.get() + val predictionValue = predictionType.get() + val accept = predictionValue >= threshold output.set(accept) if (accept) { noneAccepted = false - val pos = sourceRa.positionAsLongArray() - minPoint[0] = min(minPoint[0], pos[0]) - minPoint[1] = min(minPoint[1], pos[1]) - - maxPoint[0] = max(maxPoint[0], pos[0]) - maxPoint[1] = max(maxPoint[1], pos[1]) + /* TODO Caleb: Check if `localize()` is worth optimizing here. + * Should be easy, as long as it's thread safe and/or not parallel. */ + val pos = predictionMaskRA.positionAsLongArray() + minPointInLowResMask[0] = min(minPointInLowResMask[0], pos[0]) + minPointInLowResMask[1] = min(minPointInLowResMask[1], pos[1]) + + maxPointInLowResMask[0] = max(maxPointInLowResMask[0], pos[0]) + maxPointInLowResMask[1] = max(maxPointInLowResMask[1], pos[1]) } }, BoolType() ) - val connectedComponents: RandomAccessibleInterval = ArrayImgs.unsignedLongs(*predictionMask.dimensionsAsLongArray()) + val lowResConnectedComponents: RandomAccessibleInterval = ArrayImgs.unsignedLongs(*lowResIntervalWithoutPadding.dimensionsAsLongArray()) /* FIXME: This is annoying, but I don't see a better way around it at the moment. * `labelAllConnectedComponents` can be interrupted, but doing so causes an * internal method to `printStackTrace()` on the error. So even when @@ -564,8 +764,8 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty + val originalImgToPredictionScale = 1 / predictionToOriginalImageScaleWithoutPadding + val (x, y) = highResPoint.centerScaledCoordinates(originalImgToPredictionScale) + x.toLong() to y.toLong() + } + ?.filter { (x, y) -> lowResFilter.getAt(x, y).get() } + ?.map { (x, y) -> lowResConnectedComponents.getAt(x, y).get() } + ?.toMutableSet() ?: mutableSetOf() - val componentsUnderPointsIn = pointsIn - .map { point -> point.scaledPoint(screenScale) } - .filter { point -> filter.getAt(*point.positionAsLongArray(), 0).get() } - .map { point -> connectedComponents.getAt(*point.positionAsLongArray(), 0).get() } - .toSet() - val selectedComponents = Converters.convertRAI( - connectedComponents, - { source, output -> output.set(source.get() in componentsUnderPointsIn) }, - BoolType() + + predictionPoints?.firstOrNull { it.label == SamPredictor.SparseLabel.TOP_LEFT_BOX }?.let { topLeft -> + predictionPoints.firstOrNull { it.label == SamPredictor.SparseLabel.BOTTOM_RIGHT_BOX }?.let { bottomRight -> + val originalImgToPredictionScale = 1 / predictionToOriginalImageScaleWithoutPadding + val minPos = topLeft.centerScaledCoordinates(originalImgToPredictionScale).let { + longArrayOf(it.first.toLong(), it.second.toLong()) + } + val maxPos = bottomRight.centerScaledCoordinates(originalImgToPredictionScale).let { + longArrayOf(it.first.toLong(), it.second.toLong()) + } + val intervalIter = IntervalIterator(FinalInterval(minPos, maxPos)) + val posInBox = LongArray(2) + while (intervalIter.hasNext()) { + intervalIter.fwd() + intervalIter.localize(posInBox) + if (lowResFilter.getAt(*posInBox).get()) { + acceptedComponents += lowResConnectedComponents.getAt(*posInBox).get() + } + } + } + } + + val lowResSelectedComponents = Converters.convertRAI( + lowResConnectedComponents, + { source, output -> + output.set(if (source.get() in acceptedComponents) 1.0f else 0.0f) + }, + FloatType() ) - val alignToMask = AffineTransform3D() - .concatenate(Translation3D(*predictionMaskInterval.minAsDoubleArray())) - .concatenate(Scale3D(screenScale, screenScale, 2.0).inverse()) - val maskAlignedSelectedComponents = selectedComponents - .extendValue(Label.INVALID) - .interpolateNearestNeighbor() - .affineReal(alignToMask) + val predictionToViewerScale = Scale2D(setViewer!!.width / lowResWidth, setViewer!!.height / lowResHeight) + val halfPixelOffset = Translation2D(.5, .5) + val translationToViewer = Translation2D(*paintMask.displayPointToInitialMaskPoint(0, 0).positionAsDoubleArray()) + val predictionToViewerTransform = AffineTransform2D().concatenate(translationToViewer).concatenate(predictionToViewerScale).concatenate(halfPixelOffset) + val maskAlignedSelectedComponents = lowResSelectedComponents + .extendValue(0.0) + .interpolate(NLinearInterpolatorFactory()) + .affineReal(predictionToViewerTransform) + .convert(UnsignedLongType(Label.INVALID)) { source, output -> output.set(if (source.get() in .9..1.1) currentLabelToPaint else Label.INVALID) } + .addDimension() .raster() .interval(paintMask.viewerImg) @@ -611,7 +846,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty val overlayVal = overlay.get() composite.set( - if (overlayVal) currentLabelToPaint else original.get() + if (overlayVal == currentLabelToPaint) currentLabelToPaint else original.get() ) }, UnsignedLongType(Label.INVALID) @@ -622,7 +857,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty var checkOriginal = false val overlayVal = overlay.get() - if (overlayVal) { + if (overlayVal == currentLabelToPaint) { composite.get().set(currentLabelToPaint) composite.isValid = true } else checkOriginal = true @@ -642,22 +877,55 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty, pointsOut: List, session: OrtSession, embedding: OnnxTensor): RandomAccessibleInterval { - return try { - runPrediction(pointsIn, pointsOut, session, embedding) - } catch (e: OrtException) { - LOG.trace(e.message) - runPredictionWithRetry(pointsIn, pointsOut, session, embedding) + private fun setBestEstimatedThreshold(lowResIntervalWithoutPadding: FinalInterval?) { + val binMapper = Real1dBinMapper(-40.0, 30.0, 256, false) + val histogram = LongArray(binMapper.binCount.toInt()) + + val histogramInterval = points.filter { it.label > SamPredictor.SparseLabel.IN }.let { + if (it.size == 2) { + val (x1, y1) = it[0].run { (x / predictionToOriginalImageScaleWithoutPadding).toLong() to (y / predictionToOriginalImageScaleWithoutPadding).toLong() } + val (x2, y2) = it[1].run { (x / predictionToOriginalImageScaleWithoutPadding).toLong() to (y / predictionToOriginalImageScaleWithoutPadding).toLong() } + FinalInterval(longArrayOf(x1, y1), longArrayOf(x2, y2)) + } else null + } ?: lowResIntervalWithoutPadding + LoopBuilder.setImages(Views.interval(currentPrediction!!.image, histogramInterval)) + .forEachPixel { + val binIdx = binMapper.map(it).toInt() + if (binIdx != -1) + histogram[binIdx]++ + } + + + val binVar = FloatType() + val minThreshold = histogram.indexOfFirst { it > 0 }.let { + if (it == -1) return@let thresholdBounds.min + binMapper.getLowerBound(it.toLong(), binVar) + binVar.get().toDouble() + } + val maxThreshold = histogram.indexOfLast { it > 0 }.let { + if (it == -1) return@let thresholdBounds.max + binMapper.getUpperBound(it.toLong(), binVar) + binVar.get().toDouble() } + val otsuIdx = otsuThresholdPrediction(histogram) + binMapper.getUpperBound(otsuIdx, binVar) + + thresholdBounds = Bounds(minThreshold, maxThreshold) + threshold = binVar.get().toDouble() + } + + private fun runPredictionWithRetry(predictor: SamPredictor, vararg predictionRequest: SamPredictor.PredictionRequest): SamPredictor.SamPrediction { /* FIXME: This is a bit hacky, but works for now until a better solution is found. * Some explenation. When running the SAM predictions, occasionally the following OrtException is thrown: * [E:onnxruntime:, sequential_executor.cc:494 ExecuteKernel] @@ -667,69 +935,31 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty, pointsOut: List, session: OrtSession, embedding: OnnxTensor): RandomAccessibleInterval { - val coordsArray = FloatArray(2 * (pointsIn.size + pointsOut.size)) - val labels = FloatArray(coordsArray.size / 2) - var idx = 0 - - mapOf(pointsIn to 1f, pointsOut to 0f).forEach { (points, label) -> - points.forEach { - val convertedCoord = convertCoordinate(RealPoint(it.scaledPoint(screenScale))) - labels[idx / 2] = label - coordsArray[idx++] = convertedCoord.getFloatPosition(0) - coordsArray[idx++] = convertedCoord.getFloatPosition(1) - } + return try { + predictor.predict(*predictionRequest) + } catch (e: OrtException) { + LOG.trace(e.message) + runPredictionWithRetry(predictor, *predictionRequest) } - - val coordsBuffer = FloatBuffer.wrap(coordsArray) - val onnxCoords = OnnxTensor.createTensor(ortEnv, coordsBuffer, longArrayOf(1, labels.size.toLong(), 2)) - - val labelsBuffer = FloatBuffer.wrap(labels.map { it }.toFloatArray()) - val onnxLabels = OnnxTensor.createTensor(ortEnv, labelsBuffer, longArrayOf(1, labels.size.toLong())) - - /* NOTE: This is (height, width) */ - val onnxImgSize = OnnxTensor.createTensor(ortEnv, FloatBuffer.wrap(floatArrayOf(imgHeight!!, imgWidth!!)), longArrayOf(2)) - - val maskInput = OnnxTensor.createTensor(ortEnv, ByteBuffer.allocateDirect(1 * 1 * 256 * 256 * 4).asFloatBuffer(), longArrayOf(1, 1, 256, 256)) - val hasMaskInput = OnnxTensor.createTensor(ortEnv, ByteBuffer.allocateDirect(4).asFloatBuffer(), longArrayOf(1)) - val mask = session.run( - mapOf( - "image_embeddings" to embedding, - "point_coords" to onnxCoords, - "point_labels" to onnxLabels, - "orig_im_size" to onnxImgSize, - "mask_input" to maskInput, - "has_mask_input" to hasMaskInput, - ) - ).get("masks").get() as OnnxTensor - val maskImg = ArrayImgs.floats(mask.floatBuffer.array(), imgWidth!!.toLong(), imgHeight!!.toLong()) - return Views.addDimension(maskImg, 0, 0) } - private var imgWidth: Float? = null - private var imgHeight: Float? = null - + /** + * Calculates the target screen scale factor based on the highest screen scale and the viewer's dimensions. + * The resulting scale factor will always be the smaller of either: + * 1. the highest explicitly specified factor, or + * 2. [SamPredictor.MAX_DIM_TARGET] / `max(width, height)` + * + * This means if the `scaleFactor * maxEdge` is less than [SamPredictor.MAX_DIM_TARGET] it will be used, + * but if the `scaleFactor * maxEdge` is still larger than [SamPredictor.MAX_DIM_TARGET], then a more + * aggressive scale factor will be returned. See [SamPredictor.MAX_DIM_TARGET] for more information. + * + * @return The calculated scale factor. + */ private fun calculateTargetScreenScaleFactor(): Double { - val currentScreenScale = setViewer!!.renderUnit.screenScalesProperty.get()!![0] + val highestScreenScale = setViewer!!.renderUnit.screenScalesProperty.get().max() val (width, height) = setViewer!!.width to setViewer!!.height - val maxEdge = max(width, height) * currentScreenScale - return min(currentScreenScale, 1024.0 / maxEdge) - } - - private fun convertCoordinate(coord: RealPoint): RealPoint { - val (height, width) = imgHeight!! to imgWidth!! - val x = coord.getFloatPosition(0) - val y = coord.getFloatPosition(1) - val target = 1024 - val scale = target * (1.0 / max(height, width)) - val (scaledWidth, scaledHeight) = ((width * scale) + 0.5).toInt() to ((height * scale) + 0.5).toInt() - val (scaledX, scaledY) = x * (scaledWidth / width) to y * (scaledHeight / height) - - coord.setPosition(floatArrayOf(scaledX, scaledY)) - - return coord + val maxEdge = max(ceil(width * highestScreenScale), ceil(height * highestScreenScale)) + return min(highestScreenScale, SamPredictor.MAX_DIM_TARGET / maxEdge) } private fun saveActiveViewerImageFromRenderer() { @@ -786,10 +1016,6 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty result.image?.let { img -> - imgWidth = img.width.toFloat() - imgHeight = img.height.toFloat() - - ImageIO.write(SwingFXUtils.fromFXImage(img, null), "png", predictionImagePngOutputStream) predictionImagePngOutputStream.close() } @@ -800,10 +1026,13 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty) { + Include(arrayOf(SAM_POINT_STYLE, "sam-include")), + Exclude(arrayOf(SAM_POINT_STYLE, "sam-exclude")), + Box(arrayOf(SAM_POINT_STYLE, "sam-box")) } private val LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()) @@ -833,7 +1062,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty, val maskInterval: Interval) + } -} +} \ No newline at end of file diff --git a/src/main/resources/style/sam.css b/src/main/resources/style/sam.css index 5145dd5d1..d8eac788b 100644 --- a/src/main/resources/style/sam.css +++ b/src/main/resources/style/sam.css @@ -3,14 +3,23 @@ .sam-point { -fx-stroke: white; -fx-stroke-width: 2.0; + -include-fill: green; + -exclude-fill: red; } -.sam-point.sam-include-point { - -fx-fill: -accept; +.sam-box-overlay { + -fx-stroke-type: inside; + -fx-stroke: #cec9c6;; + -fx-stroke-width: 2; + -fx-fill: transparent; } -.sam-point.sam-exclude-point { - -fx-fill: -reject; +.sam-point.sam-include { + -fx-fill: -include-fill; +} + +.sam-point.sam-exclude { + -fx-fill: -exclude-fill; } .glyph-icon.sam-select { From 167c461511d13890d7f0cfe07a8426ba76438f73 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Wed, 10 Jan 2024 16:44:16 -0500 Subject: [PATCH 16/28] feat: add OtsuThresholdingPrediction algorithm --- .../algorithms/OtsuThresholdPrediction.kt | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 src/main/kotlin/org/janelia/saalfeldlab/paintera/util/algorithms/OtsuThresholdPrediction.kt diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/util/algorithms/OtsuThresholdPrediction.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/util/algorithms/OtsuThresholdPrediction.kt new file mode 100644 index 000000000..4ee63c2df --- /dev/null +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/util/algorithms/OtsuThresholdPrediction.kt @@ -0,0 +1,49 @@ + +package org.janelia.saalfeldlab.paintera.util.algorithms + +import kotlin.math.pow + +/** + * Otsu threshold prediction, adapted for kotlin from: + * [ComputeOtsuThreshold.java](https://github.com/imagej/imagej-ops/blob/master/src/main/java/net/imagej/ops/threshold/otsu/ComputeOtsuThreshold.java) + * + * @param histogram array of bins, whose values are frequensies at that bin + * @return index of the bin to use for threshold + * + * @author Caleb Hulbert + * @author Barry DeZonia + * @author Gabriel Landini + */ +fun otsuThresholdPrediction(histogram: LongArray): Long { + + val (histogramIntensity, numPoints) = histogram.foldIndexed(LongArray(2)) { idx, acc, freq -> + acc[0] += idx * freq + acc[1] += freq + acc + } + + var intensitySumBelowThreshold: Long = 0 + var numPointsBelowThreshold = histogram[0] + + var interClassVariance: Double + var maxInterClassVariance = 0.0 + var predictedThresholdIdx = 0 + + for (i in 1 until histogram.size - 1) { + intensitySumBelowThreshold += i * histogram[i] + numPointsBelowThreshold += histogram[i] + + val denom = numPointsBelowThreshold.toDouble() * (numPoints - numPointsBelowThreshold) + + if (denom != 0.0) { + val num = ((numPointsBelowThreshold.toDouble() / numPoints) * histogramIntensity - intensitySumBelowThreshold).pow(2) + interClassVariance = num / denom + } else interClassVariance = 0.0 + + if (interClassVariance >= maxInterClassVariance) { + maxInterClassVariance = interClassVariance + predictedThresholdIdx = i + } + } + return predictedThresholdIdx.toLong() +} \ No newline at end of file From 3b90224cef853e1168c0cbd21129ddcbcebaea6d Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Tue, 16 Jan 2024 15:23:47 -0500 Subject: [PATCH 17/28] refactor!: rename SAM binding key, revert mode logic --- .../saalfeldlab/paintera/BindingKeys.kt | 6 +- .../paintera/control/modes/PaintLabelMode.kt | 4 +- .../control/modes/SegmentAnythingMode.kt | 258 ------------------ 3 files changed, 4 insertions(+), 264 deletions(-) delete mode 100644 src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt index d1194bd78..9cb020fca 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt @@ -107,8 +107,7 @@ object LabelSourceStateKeys { const val REFRESH_MESHES = "refresh meshes" const val CANCEL = "cancel" const val TOGGLE_NON_SELECTED_LABELS_VISIBILITY = "toggle non-selected labels visibility" - const val ENTER_SEGMENT_ANYTHING_MODE = "segment anything: enter mode" - const val EXIT_SEGMENT_ANYTHING_MODE = "segment anything: exit mode" + const val SEGMENT_ANYTHING = "Segment Anything Mode" private val namedComboMap = NamedKeyCombination.CombinationMap( SELECT_ALL byKeyCombo A + CONTROL_DOWN, @@ -130,8 +129,7 @@ object LabelSourceStateKeys { REFRESH_MESHES byKeyCombo R, CANCEL byKeyCombo ESCAPE, TOGGLE_NON_SELECTED_LABELS_VISIBILITY byKeyCombo V + SHIFT_DOWN, - ENTER_SEGMENT_ANYTHING_MODE byKeyCombo A, - EXIT_SEGMENT_ANYTHING_MODE byKeyCombo ESCAPE + SEGMENT_ANYTHING byKeyCombo A ) fun namedCombinationsCopy() = namedComboMap.deepCopy diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/PaintLabelMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/PaintLabelMode.kt index b4802d899..594698d69 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/PaintLabelMode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/PaintLabelMode.kt @@ -27,7 +27,7 @@ import org.janelia.saalfeldlab.fx.util.InvokeOnJavaFXApplicationThread import org.janelia.saalfeldlab.paintera.DeviceManager import org.janelia.saalfeldlab.paintera.LabelSourceStateKeys import org.janelia.saalfeldlab.paintera.LabelSourceStateKeys.ENTER_SHAPE_INTERPOLATION_MODE -import org.janelia.saalfeldlab.paintera.LabelSourceStateKeys.SEGMENT_ANYTHING_MODE +import org.janelia.saalfeldlab.paintera.LabelSourceStateKeys.SEGMENT_ANYTHING import org.janelia.saalfeldlab.paintera.control.ShapeInterpolationController import org.janelia.saalfeldlab.paintera.control.actions.AllowedActions import org.janelia.saalfeldlab.paintera.control.actions.LabelActionType @@ -162,7 +162,7 @@ object PaintLabelMode : AbstractToolMode() { } } - private val activeSamTool = painteraActionSet(SEGMENT_ANYTHING_MODE, PaintActionType.Paint) { + private val activeSamTool = painteraActionSet(SEGMENT_ANYTHING, PaintActionType.Paint) { KEY_PRESSED(*samTool.keyTrigger.toTypedArray()) { verify { activeSourceStateProperty.get() is ConnectomicsLabelState<*, *> } verify { activeTool !is SamTool } diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt deleted file mode 100644 index 98fd7ab18..000000000 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/SegmentAnythingMode.kt +++ /dev/null @@ -1,258 +0,0 @@ -package org.janelia.saalfeldlab.paintera.control.modes - -import ai.onnxruntime.OnnxTensor -import bdv.util.Affine3DHelpers -import de.jensd.fx.glyphs.fontawesome.FontAwesomeIconView -import javafx.beans.value.ChangeListener -import javafx.collections.FXCollections -import javafx.collections.ObservableList -import javafx.scene.input.KeyEvent.KEY_PRESSED -import javafx.scene.input.KeyEvent.KEY_RELEASED -import javafx.scene.input.MouseEvent.MOUSE_PRESSED -import net.imglib2.Interval -import net.imglib2.realtransform.AffineTransform3D -import org.janelia.saalfeldlab.fx.actions.ActionSet -import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.installActionSet -import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.removeActionSet -import org.janelia.saalfeldlab.fx.actions.painteraActionSet -import org.janelia.saalfeldlab.fx.extensions.LazyForeignValue -import org.janelia.saalfeldlab.fx.extensions.addWithListener -import org.janelia.saalfeldlab.fx.ortho.OrthogonalViews -import org.janelia.saalfeldlab.paintera.LabelSourceStateKeys -import org.janelia.saalfeldlab.paintera.control.actions.AllowedActions -import org.janelia.saalfeldlab.paintera.control.actions.PaintActionType -import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask -import org.janelia.saalfeldlab.paintera.control.tools.Tool -import org.janelia.saalfeldlab.paintera.control.tools.paint.* -import org.janelia.saalfeldlab.paintera.data.mask.MaskedSource -import org.janelia.saalfeldlab.paintera.paintera - -class SegmentAnythingMode(val previousMode: ControlMode) : AbstractToolMode() { - - override val defaultTool: Tool? by lazy { samTool } - - private val samTool: SamTool = object : SamTool(activeSourceStateProperty, this@SegmentAnythingMode) { - - private var lastEmbedding: OnnxTensor? = null - private var globalTransformAtEmbedding = AffineTransform3D() - - init { - activeViewerProperty.unbind() - activeViewerProperty.bind(mode!!.activeViewerProperty) - } - - override fun activate() { - maskedSource?.resetMasks(false) - providedEmbedding = if (Affine3DHelpers.equals(paintera.baseView.manager().transform, globalTransformAtEmbedding)) lastEmbedding else null - super.activate() - } - - override fun deactivate() { - super.deactivate() - lastEmbedding = getImageEmbeddingTask.get()!! - globalTransformAtEmbedding.set(paintera.baseView.manager().transform) - } - - override fun setCurrentLabelToSelection() { - currentLabelToPaint = statePaintContext!!.selectedIds.lastSelection - } - } - - private val paintBrushTool = object : PaintBrushTool(activeSourceStateProperty, this@SegmentAnythingMode) { - - override val actionSets: MutableList by LazyForeignValue({ activeViewerAndTransforms }) { - mutableListOf( - *getBrushActions().filterNot { it.name == CHANGE_BRUSH_DEPTH }.toTypedArray(), - *getPaintActions().filterNot { it.name == START_BACKGROUND_ERASE }.toTypedArray(), - segmentAnythingPaintBrushActions(), - *(midiBrushActions() ?: arrayOf()) - ) - } - - override fun activate() { - super.activate() - /* Don't allow painting with depth during shape interpolation */ - brushProperties?.brushDepth = 1.0 - paintClickOrDrag!!.provideMask(samTool.viewerMask!!) - } - - override fun deactivate() { - paintClickOrDrag?.release() - super.deactivate() - } - } - - private val fill2DTool = object : Fill2DTool(activeSourceStateProperty, this@SegmentAnythingMode) { - - - private val samPredictionOnFill = ChangeListener { _, _, new -> - new?.let { - switchTool(samTool) - samTool.requestPrediction() - } - } - - override fun activate() { - super.activate() - /* Don't allow filling with depth during shape interpolation */ - brushProperties?.brushDepth = 1.0 - fillLabel = { statePaintContext!!.selectedIds.lastSelection } - brushProperties?.brushDepth = 1.0 - fill2D.provideMask(samTool.viewerMask!!) - fill2D.maskIntervalProperty.addListener(samPredictionOnFill) - } - - override fun deactivate() { - fill2D.maskIntervalProperty.removeListener(samPredictionOnFill) - super.deactivate() - } - - override val actionSets: MutableList by LazyForeignValue({ activeViewerAndTransforms }) { - super.actionSets.also { it += segmentAnythingFloodFillActions(this) } - } - - } - - override val modeActions by lazy { modeActions() } - - override val allowedActions = AllowedActions.AllowedActionsBuilder() - .add(PaintActionType.Paint, PaintActionType.Erase, PaintActionType.SetBrushSize, PaintActionType.Fill) - .create() - - private val toolTriggerListener = ChangeListener { _, old, new -> - new?.viewer()?.apply { modeActions.forEach { installActionSet(it) } } - old?.viewer()?.apply { modeActions.forEach { removeActionSet(it) } } - } - - override val tools: ObservableList by lazy { FXCollections.observableArrayList(paintBrushTool, fill2DTool, samTool) } - - override fun enter() { - activeViewerProperty.addListener(toolTriggerListener) - super.enter() - /* unbind the activeViewerProperty, since we disabled other viewers during ShapeInterpolation mode*/ - activeViewerProperty.unbind() - /* Try to initialize the tool, if state is valid. If not, change back to previous mode. */ - activeViewerProperty.get()?.viewer()?.let { - disableUnfocusedViewers() - switchTool(samTool) - } ?: paintera.baseView.changeMode(previousMode) - } - - override fun exit() { - super.exit() - enableAllViewers() - activeViewerProperty.removeListener(toolTriggerListener) - } - - private fun modeActions(): List { - val keyCombinations = paintera.baseView.keyAndMouseBindings.getConfigFor(activeSourceStateProperty.value!!).keyCombinations - return mutableListOf( - painteraActionSet(LabelSourceStateKeys.EXIT_SEGMENT_ANYTHING_MODE) { - - verifyAll(KEY_PRESSED, "Sam Tool is Active ") { activeTool == samTool } - KEY_PRESSED { - graphic = { FontAwesomeIconView().apply { styleClass += listOf("toolbar-tool", "reject", "reject-segment-anything") } } - keyMatchesBinding(keyCombinations, LabelSourceStateKeys.EXIT_SEGMENT_ANYTHING_MODE) - onAction { - paintera.baseView.changeMode(previousMode) - } - } - }, - painteraActionSet("paint during segment anything", PaintActionType.Paint) { - KEY_PRESSED(*paintBrushTool.keyTrigger.toTypedArray()) { - name = "switch to paint tool" - val getViewerMask = { (activeSourceStateProperty.get()?.dataSource as? MaskedSource<*, *>)?.currentMask as? ViewerMask } - verify { getViewerMask() != null } - onAction { - switchTool(paintBrushTool) - } - } - - KEY_RELEASED(*paintBrushTool.keyTrigger.toTypedArray()) { - name = "switch back to segment anything tool from paint brush" - filter = true - verify { activeTool is PaintBrushTool } - onAction { switchTool(samTool) } - } - - KEY_PRESSED(*fill2DTool.keyTrigger.toTypedArray()) { - name = "switch to fill2d tool" - verify { activeSourceStateProperty.get()?.dataSource is MaskedSource<*, *> } - onAction { switchTool(fill2DTool) } - } - KEY_RELEASED(*fill2DTool.keyTrigger.toTypedArray()) { - name = "switch to segment anything tool from fill2d" - filter = true - verify { activeTool is Fill2DTool } - onAction { - switchTool(samTool) - } - } - } - ) - } - - /** - * Additional paint brush actions for Segment Anything - * - * @receiver the tool to add the actions to - * @return the additional action sets - */ - private fun PaintBrushTool.segmentAnythingPaintBrushActions(): ActionSet { - - return painteraActionSet("Segment Anything Paint Brush Actions", PaintActionType.SegmentAnything) { - MOUSE_PRESSED { - name = "provide SAM tool mask to paint brush" - filter = true - consume = false - verify { activeTool == this@segmentAnythingPaintBrushActions } - onAction { - /* On click, generate a new mask, */ - (activeSourceStateProperty.get()?.dataSource as? MaskedSource<*, *>)?.let { source -> - paintClickOrDrag!!.let { paintController -> - source.resetMasks(true) - paintController.provideMask(samTool.viewerMask!!) - } - } - } - } - } - } - - /** - * Additional fill actions for Segment Anything - * - * @param floodFillTool - * @return the additional ActionSet - * - * */ - private fun segmentAnythingFloodFillActions(floodFillTool: Fill2DTool): ActionSet { - return painteraActionSet("Segment Anything Fill 2D Actions", PaintActionType.SegmentAnything) { - MOUSE_PRESSED { - name = "provide SAM tool mask to fill 2d" - filter = true - consume = false - verify { activeTool == floodFillTool } - onAction { - /* On click, provide the mask, setup the task listener */ - (activeSourceStateProperty.get()?.dataSource as? MaskedSource<*, *>)?.let { source -> - source.resetMasks(true) - val mask = samTool.viewerMask!! - fill2DTool.run { - fillTaskProperty.addWithListener { obs, _, task -> - task?.let { - task.onCancelled(true) { _, _ -> - source.resetMasks(true) - mask.requestRepaint() - } - task.onEnd(true) { obs?.removeListener(this) } - } ?: obs?.removeListener(this) - } - fill2D.provideMask(mask) - } - } - } - } - } - } -} \ No newline at end of file From 078ea56b57f750f1d1e69be7a0a6bb42335f235b Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Tue, 16 Jan 2024 15:26:35 -0500 Subject: [PATCH 18/28] feat: expose high and low res masks from SamPredictor, prefer highres mask for SamTool --- .../control/tools/paint/SamPredictor.kt | 38 +++- .../paintera/control/tools/paint/SamTool.kt | 165 ++++++++---------- 2 files changed, 100 insertions(+), 103 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt index acbe312e3..0e8fb0efe 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamPredictor.kt @@ -5,15 +5,19 @@ import ai.onnxruntime.OnnxTensorLike import ai.onnxruntime.OrtEnvironment import ai.onnxruntime.OrtSession import io.github.oshai.kotlinlogging.KotlinLogging +import net.imglib2.Interval import net.imglib2.RandomAccessibleInterval import net.imglib2.RealPoint import net.imglib2.img.array.ArrayImgs import net.imglib2.type.NativeType -import net.imglib2.type.numeric.integer.UnsignedLongType import net.imglib2.type.numeric.real.FloatType +import net.imglib2.util.Intervals +import org.janelia.saalfeldlab.util.interval import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.FloatBuffer +import kotlin.math.ceil +import kotlin.math.max private fun allocateDirectFloatBuffer(size: Int, order: ByteOrder = ByteOrder.nativeOrder()): FloatBuffer { return ByteBuffer.allocateDirect(size * Float.SIZE_BYTES).order(order).asFloatBuffer() @@ -124,19 +128,18 @@ class SamPredictor( } data class SamPrediction( - val masks: OnnxTensor? = null, + val masks: OnnxTensor, val iouPredictions: OnnxTensor, val lowResMasks: OnnxTensor, val predictor: SamPredictor ) { - var image: RandomAccessibleInterval = ArrayImgs.floats(lowResMasks.floatBuffer.array(), LOW_RES_MASK_DIM.toLong(), LOW_RES_MASK_DIM.toLong()) - - /* - * Binary segmentation mask of connected components, or null if not binarized. - * Interval is the smallest bounding box containing the segmentation. May be empty, or the entire original image size. - */ - var segmentation: RandomAccessibleInterval? = null + val image: RandomAccessibleInterval = ArrayImgs.floats(masks.floatBuffer.array(), predictor.originalImgSize.first.toLong(), predictor.originalImgSize.second.toLong()) + val lowResImage: RandomAccessibleInterval by lazy { + ArrayImgs.floats(lowResMasks.floatBuffer.array(), LOW_RES_MASK_DIM.toLong(), LOW_RES_MASK_DIM.toLong()).interval(lowResIntervalWithoutPadding) + } + val lowToHighResScale: Double + private val lowResIntervalWithoutPadding : Interval constructor(result: OrtSession.Result, predictor: SamPredictor) : this( result[MASKS].get() as OnnxTensor, @@ -145,6 +148,23 @@ class SamPredictor( predictor ) + init { + with(predictor) { + val lowResWidth: Long + val lowResHeight: Long + val (imgWidth, imgHeight) = originalImgSize + lowToHighResScale = max(imgWidth, imgHeight).toDouble() / LOW_RES_MASK_DIM + if (imgWidth > imgHeight) { + lowResWidth = LOW_RES_MASK_DIM.toLong() + lowResHeight = ceil(imgHeight / lowToHighResScale).toLong() + } else { + lowResHeight = LOW_RES_MASK_DIM.toLong() + lowResWidth = ceil(imgWidth / lowToHighResScale).toLong() + } + lowResIntervalWithoutPadding = Intervals.createMinSize(0, 0, lowResWidth, lowResHeight) + } + } + companion object { const val MASKS = "masks" const val IOU_PREDICTIONS = "iou_predictions" diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt index 816688209..d34ea4b9b 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt @@ -61,8 +61,11 @@ import org.apache.http.impl.client.HttpClients import org.apache.http.util.EntityUtils import org.janelia.saalfeldlab.fx.Tasks import org.janelia.saalfeldlab.fx.UtilityTask -import org.janelia.saalfeldlab.fx.actions.* import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.installActionSet +import org.janelia.saalfeldlab.fx.actions.painteraActionSet +import org.janelia.saalfeldlab.fx.actions.painteraDragActionSet +import org.janelia.saalfeldlab.fx.actions.painteraMidiActionSet +import org.janelia.saalfeldlab.fx.actions.verifyPainteraNotDisabled import org.janelia.saalfeldlab.fx.event.KeyTracker import org.janelia.saalfeldlab.fx.extensions.LazyForeignValue import org.janelia.saalfeldlab.fx.extensions.nonnull @@ -80,7 +83,6 @@ import org.janelia.saalfeldlab.paintera.control.modes.ToolMode import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask.Companion.createViewerMask import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask.Companion.getGlobalViewerInterval -import org.janelia.saalfeldlab.paintera.control.tools.paint.SamPredictor.Companion.LOW_RES_MASK_DIM import org.janelia.saalfeldlab.paintera.control.tools.paint.SamPredictor.SamPoint import org.janelia.saalfeldlab.paintera.data.mask.MaskInfo import org.janelia.saalfeldlab.paintera.data.mask.MaskedSource @@ -242,8 +244,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty() - private val predictionToOriginalImageScaleWithoutPadding: Double - get() = max(imgWidth, imgHeight).toDouble() / LOW_RES_MASK_DIM + private var predictionImagePngInputStream = PipedInputStream() private var predictionImagePngOutputStream = PipedOutputStream(predictionImagePngInputStream) @@ -343,28 +344,18 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty imgHeight) { - lowResWidth = LOW_RES_MASK_DIM.toLong() - lowResHeight = ceil(imgHeight / predictionToOriginalImageScaleWithoutPadding).toLong() - } else { - lowResHeight = LOW_RES_MASK_DIM.toLong() - lowResWidth = ceil(imgWidth / predictionToOriginalImageScaleWithoutPadding).toLong() - } - - val highResPrediction = ArrayImgs.floats(currentPrediction!!.masks!!.floatBuffer.array(), imgWidth.toLong(), imgHeight.toLong()) - val lowResPrediction = currentPrediction!!.image + val highResPrediction = currentPrediction!!.image + val lowResPrediction = currentPrediction!!.lowResImage val name: String - val (mask, maskRai) = if (toggle) { + val maskRai = if (toggle) { toggle = false name = "high res" - highResPrediction to highResPrediction.interval(Intervals.createMinSize(0, 0, imgWidth.toLong(), imgHeight.toLong())) + highResPrediction } else { toggle = true name = "low res" - lowResPrediction to lowResPrediction.interval(Intervals.createMinSize(0, 0, lowResWidth, lowResHeight)) + lowResPrediction } val (max, mean, std) = maskRai.let { @@ -383,7 +374,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty output.set(input.get() - min) } + val zeroMinValue = maskRai.convert(FloatType()) { input, output -> output.set(input.get() - min) } val predictionSource = paintera.baseView.addConnectomicsRawSource( zeroMinValue.let { val prediction3D = Views.addDimension(it) @@ -506,7 +497,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty imgHeight) { - lowResWidth = LOW_RES_MASK_DIM.toLong() - lowResHeight = ceil(imgHeight / predictionToOriginalImageScaleWithoutPadding).toLong() - } else { - lowResHeight = LOW_RES_MASK_DIM.toLong() - lowResWidth = ceil(imgWidth / predictionToOriginalImageScaleWithoutPadding).toLong() - } - val lowResIntervalWithoutPadding = Intervals.createMinSize(0, 0, lowResWidth, lowResHeight) - val newPredictionRequest = !refresh || currentPrediction == null if (newPredictionRequest) { currentPrediction = runPredictionWithRetry(predictor, predictionRequest) } + val prediction = currentPrediction!! + if (!refresh) { - setBestEstimatedThreshold(lowResIntervalWithoutPadding) + val thresholdPredictorInterval = if (points.all { it.label > SamPredictor.SparseLabel.IN }) intervalOfBox(prediction) else null + setBestEstimatedThreshold(thresholdPredictorInterval) } val paintMask = viewerMask!! - val minPointInLowResMask = longArrayOf(Long.MAX_VALUE, Long.MAX_VALUE) - val maxPointInLowResMask = longArrayOf(Long.MIN_VALUE, Long.MIN_VALUE) + val minPoint = longArrayOf(Long.MAX_VALUE, Long.MAX_VALUE) + val maxPoint = longArrayOf(Long.MIN_VALUE, Long.MIN_VALUE) + + val predictedImage = currentPrediction!!.image var noneAccepted = true - val lowResFilter = Converters.convert( - BundleView(currentPrediction!!.image), + val thresholdFilter = Converters.convert( + BundleView(predictedImage), { predictionMaskRA, output -> val predictionType = predictionMaskRA.get() val predictionValue = predictionType.get() @@ -739,20 +724,18 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty = ArrayImgs.unsignedLongs(*lowResIntervalWithoutPadding.dimensionsAsLongArray()) + val connectedComponents: RandomAccessibleInterval = ArrayImgs.unsignedLongs(*predictedImage.dimensionsAsLongArray()) /* FIXME: This is annoying, but I don't see a better way around it at the moment. * `labelAllConnectedComponents` can be interrupted, but doing so causes an * internal method to `printStackTrace()` on the error. So even when @@ -764,8 +747,8 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty - val originalImgToPredictionScale = 1 / predictionToOriginalImageScaleWithoutPadding - val (x, y) = highResPoint.centerScaledCoordinates(originalImgToPredictionScale) - x.toLong() to y.toLong() - } - ?.filter { (x, y) -> lowResFilter.getAt(x, y).get() } - ?.map { (x, y) -> lowResConnectedComponents.getAt(x, y).get() } + ?.map { it.x.toLong() to it.y.toLong() } + ?.filter { (x, y) -> thresholdFilter.getAt(x, y).get() } + ?.map { (x, y) -> connectedComponents.getAt(x, y).get() } ?.toMutableSet() ?: mutableSetOf() - predictionPoints?.firstOrNull { it.label == SamPredictor.SparseLabel.TOP_LEFT_BOX }?.let { topLeft -> predictionPoints.firstOrNull { it.label == SamPredictor.SparseLabel.BOTTOM_RIGHT_BOX }?.let { bottomRight -> - val originalImgToPredictionScale = 1 / predictionToOriginalImageScaleWithoutPadding - val minPos = topLeft.centerScaledCoordinates(originalImgToPredictionScale).let { - longArrayOf(it.first.toLong(), it.second.toLong()) - } - val maxPos = bottomRight.centerScaledCoordinates(originalImgToPredictionScale).let { - longArrayOf(it.first.toLong(), it.second.toLong()) - } - val intervalIter = IntervalIterator(FinalInterval(minPos, maxPos)) + + val minPos = longArrayOf(topLeft.x.toLong(), topLeft.y.toLong()) + val maxPos = longArrayOf(bottomRight.x.toLong(), bottomRight.y.toLong()) + val boxIterator = IntervalIterator(FinalInterval(minPos, maxPos)) val posInBox = LongArray(2) - while (intervalIter.hasNext()) { - intervalIter.fwd() - intervalIter.localize(posInBox) - if (lowResFilter.getAt(*posInBox).get()) { - acceptedComponents += lowResConnectedComponents.getAt(*posInBox).get() + while (boxIterator.hasNext()) { + boxIterator.fwd() + boxIterator.localize(posInBox) + if (thresholdFilter.getAt(*posInBox).get()) { + acceptedComponents += connectedComponents.getAt(*posInBox).get() } } } } - val lowResSelectedComponents = Converters.convertRAI( - lowResConnectedComponents, + val selectedComponents = Converters.convertRAI( + connectedComponents, { source, output -> output.set(if (source.get() in acceptedComponents) 1.0f else 0.0f) }, FloatType() ) - val predictionToViewerScale = Scale2D(setViewer!!.width / lowResWidth, setViewer!!.height / lowResHeight) + val (width, height) = predictedImage.dimensionsAsLongArray() + val predictionToViewerScale = Scale2D(setViewer!!.width / width, setViewer!!.height / height) val halfPixelOffset = Translation2D(.5, .5) val translationToViewer = Translation2D(*paintMask.displayPointToInitialMaskPoint(0, 0).positionAsDoubleArray()) val predictionToViewerTransform = AffineTransform2D().concatenate(translationToViewer).concatenate(predictionToViewerScale).concatenate(halfPixelOffset) - val maskAlignedSelectedComponents = lowResSelectedComponents + val maskAlignedSelectedComponents = selectedComponents .extendValue(0.0) .interpolate(NLinearInterpolatorFactory()) .affineReal(predictionToViewerTransform) - .convert(UnsignedLongType(Label.INVALID)) { source, output -> output.set(if (source.get() in .9..1.1) currentLabelToPaint else Label.INVALID) } + .convert(UnsignedLongType(Label.INVALID)) { source, output -> output.set(if (source.get() > .8) currentLabelToPaint else Label.INVALID) } .addDimension() .raster() .interval(paintMask.viewerImg) - val compositeMask = Converters.convertRAI( - originalBackingImage, maskAlignedSelectedComponents, - { original, overlay, composite -> + val compositeMask = originalBackingImage!! + .extendValue(Label.INVALID) + .convertWith(maskAlignedSelectedComponents, UnsignedLongType(Label.INVALID)) { original, overlay, composite -> val overlayVal = overlay.get() composite.set( if (overlayVal == currentLabelToPaint) currentLabelToPaint else original.get() ) - }, - UnsignedLongType(Label.INVALID) - ) + }.interval(maskAlignedSelectedComponents) + - val compositeVolatileMask = Converters.convertRAI( - originalVolatileBackingImage, maskAlignedSelectedComponents, - { original, overlay, composite -> + val compositeVolatileMask = originalVolatileBackingImage!! + .extendValue(VolatileUnsignedLongType(Label.INVALID)) + .convertWith(maskAlignedSelectedComponents, VolatileUnsignedLongType(Label.INVALID)) { original, overlay, composite -> var checkOriginal = false val overlayVal = overlay.get() if (overlayVal == currentLabelToPaint) { @@ -868,16 +842,14 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty(-40.0, 30.0, 256, false) val histogram = LongArray(binMapper.binCount.toInt()) - val histogramInterval = points.filter { it.label > SamPredictor.SparseLabel.IN }.let { - if (it.size == 2) { - val (x1, y1) = it[0].run { (x / predictionToOriginalImageScaleWithoutPadding).toLong() to (y / predictionToOriginalImageScaleWithoutPadding).toLong() } - val (x2, y2) = it[1].run { (x / predictionToOriginalImageScaleWithoutPadding).toLong() to (y / predictionToOriginalImageScaleWithoutPadding).toLong() } - FinalInterval(longArrayOf(x1, y1), longArrayOf(x2, y2)) - } else null - } ?: lowResIntervalWithoutPadding - LoopBuilder.setImages(Views.interval(currentPrediction!!.image, histogramInterval)) + val predictionRAI = interval?.let { currentPrediction!!.image.interval(it) } ?: currentPrediction!!.image + LoopBuilder.setImages(predictionRAI) .forEachPixel { val binIdx = binMapper.map(it).toInt() if (binIdx != -1) @@ -925,6 +891,17 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty SamPredictor.SparseLabel.IN }.let { + if (it.size == 2) { + val scale = if (lowRes) samPrediction.lowToHighResScale else 1.0 + val (x1, y1) = it[0].run { (x / scale).toLong() to (y / scale).toLong() } + val (x2, y2) = it[1].run { (x / scale).toLong() to (y / scale).toLong() } + FinalInterval(longArrayOf(x1, y1), longArrayOf(x2, y2)) + } else null + } + } + private fun runPredictionWithRetry(predictor: SamPredictor, vararg predictionRequest: SamPredictor.PredictionRequest): SamPredictor.SamPrediction { /* FIXME: This is a bit hacky, but works for now until a better solution is found. * Some explenation. When running the SAM predictions, occasionally the following OrtException is thrown: From a415db79a98f4a9b1a193a3ceeb4286fae6e1388 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Tue, 16 Jan 2024 15:27:07 -0500 Subject: [PATCH 19/28] refactor!: pull rotation controller out as field --- .../control/modes/NavigationControlMode.kt | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt index 01ba03e6b..3fb1d719a 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt @@ -127,7 +127,7 @@ object NavigationTool : ViewerTool() { override val name: String = "Navigation" override val keyTrigger = null /* This is typically the default, so no binding to actively switch to it. */ - //TODO Caleb: should standardize the `TransformTracker` and updater concept. Refer to Rotate/KeyRotate/TranslationController + //TODO Caleb: should standardize the `TransformTracker` and updater concept. Refer to Rotate/TranslationController val globalToViewerTransform by LazyForeignMap({ activeViewerAndTransforms }) { AffineTransform3D() } val viewerTransform by LazyForeignValue({ activeViewerAndTransforms }) { viewerAndTransforms -> @@ -146,9 +146,15 @@ object NavigationTool : ViewerTool() { val zoomController by LazyForeignValue({ activeViewerAndTransforms }) { Zoom(globalTransformManager, viewerTransform) } + val keyRotationAxis by LazyForeignValue({ activeViewerAndTransforms }) { SimpleObjectProperty(Axis.Z) } + + val rotationController by LazyForeignValue({ activeViewerAndTransforms }) { + Rotate(it!!.displayTransform(), it.globalToViewerTransform(), globalTransformManager) + } + val resetRotationController by LazyForeignValue({ activeViewerAndTransforms }) { RemoveRotation(viewerTransform, globalTransform, { globalTransformManager.setTransform(it, Duration(300.0)) @@ -337,7 +343,8 @@ object NavigationTool : ViewerTool() { keysDown(*keys) onAction { val scale = 1 + ControlUtils.getBiggestScroll(it!!) / 1_000 - zoomController.zoomCenteredAt(scale, it.x, it.y) } + zoomController.zoomCenteredAt(scale, it.x, it.y) + } } } } @@ -412,8 +419,6 @@ object NavigationTool : ViewerTool() { KEY_PRESSED(keyBindings, NavigationKeys.SET_ROTATION_AXIS_Z) { onAction { keyRotationAxis.set(Axis.Z) } } } - val rotationController = Rotate(displayTransform, globalToViewerTransform, globalTransformManager) - val mouseRotation = painteraDragActionSet("mousde-drag-rotate", NavigationActionType.Rotate) { verify { it.isPrimaryButtonDown } dragDetectedAction.verify { NavigationTool.allowRotationsProperty() } @@ -445,6 +450,7 @@ object NavigationTool : ViewerTool() { rotationActions += setRotationAxis rotationActions += mouseRotation rotationActions += keyRotation + midiRotationActions()?.let { rotationActions += it } return rotationActions.filterNotNull() } @@ -469,10 +475,6 @@ object NavigationTool : ViewerTool() { DeviceManager.xTouchMini?.let { device -> targetPositionObservable?.let { targetPosition -> val target = vat.viewer() - val submitTransform: (AffineTransform3D) -> Unit = { t -> globalTransformManager.transform = t } - val step = SimpleDoubleProperty(5 * Math.PI / 180.0) - val axisProperty = SimpleObjectProperty(keyRotationAxis.get()) - val rotate = KeyRotate(axisProperty, step, vat.displayTransform(), vat.globalToViewerTransform(), globalTransformManager, submitTransform) data class MidiRotationStruct(val handle: Int, val axis: Axis) listOf( @@ -487,11 +489,9 @@ object NavigationTool : ViewerTool() { verifyEventNotNull() verify { allowRotationsProperty() } onAction { - InvokeOnJavaFXApplicationThread { - axisProperty.set(axis) - step.set(step.value.absoluteValue * it!!.value.sign) - rotate.rotate(targetPosition.x, targetPosition.y) - } + val direction = it!!.value.sign + rotationController.setSpeed(direction * speed) + rotationController.rotateAroundAxis(targetPosition.x, targetPosition.y, axis) } } } From 056c8742797e23e4999e44a23961426809088be4 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 14:41:21 -0500 Subject: [PATCH 20/28] fix(test): suppliers for transforms --- .../janelia/saalfeldlab/paintera/state/SourceInfoTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/java/org/janelia/saalfeldlab/paintera/state/SourceInfoTest.java b/src/test/java/org/janelia/saalfeldlab/paintera/state/SourceInfoTest.java index b1174cd03..97de5e75e 100644 --- a/src/test/java/org/janelia/saalfeldlab/paintera/state/SourceInfoTest.java +++ b/src/test/java/org/janelia/saalfeldlab/paintera/state/SourceInfoTest.java @@ -44,7 +44,7 @@ public void invalidateAll(long parallelismThreshold) { RandomAccessibleIntervalDataSource<>( rai, rai, - new AffineTransform3D(), + () -> new AffineTransform3D(), NO_OP_INVALIDATE, i -> new NearestNeighborInterpolatorFactory<>(), i -> new NearestNeighborInterpolatorFactory<>(), @@ -55,7 +55,7 @@ public void invalidateAll(long parallelismThreshold) { RandomAccessibleIntervalDataSource<>( rai, rai, - new AffineTransform3D(), + () -> new AffineTransform3D(), NO_OP_INVALIDATE, i -> new NearestNeighborInterpolatorFactory<>(), i -> new NearestNeighborInterpolatorFactory<>(), From aa4d0d6cb478ce445ef119eaa89d8bc53a3cf858 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 14:41:38 -0500 Subject: [PATCH 21/28] perf: improve transform animation logic --- .../state/GlobalTransformManager.java | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/state/GlobalTransformManager.java b/src/main/java/org/janelia/saalfeldlab/paintera/state/GlobalTransformManager.java index e90abd771..9937943cd 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/state/GlobalTransformManager.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/state/GlobalTransformManager.java @@ -54,10 +54,24 @@ public synchronized void setTransform(final AffineTransform3D affine, final Dura }); } + private Timeline animateSetTransform = null; + + /** + * Set the global transform to {@code affine} with an animation, over {@code duration} amount of time. When + * the animation is stopped, either due to it finishing, or being stopped early, {@code runAfterAnimation} will be triggered. + * + * The animation can be stopped early either by passing in a `duration` of `0` milliseconds or less, OR by setting `duration` to null. + * In both cases of the animation stopping early, the `runAfterrAnimation` will be triggered. + * + * @param affine to set the global transform to + * @param duration to animate the transform update over + * @param runAfterAnimation to run when the animation stops. This could either be when the global transform equals {@code affine} or if it was stopped early. + */ public synchronized void setTransform(final AffineTransform3D affine, final Duration duration, final Runnable runAfterAnimation) { - if (duration.toMillis() == 0.0) { + if (duration == null || duration.toMillis() == 0.0) { setTransform(affine); + runAfterAnimation.run(); return; } final Timeline timeline = new Timeline(60.0); @@ -75,10 +89,18 @@ public synchronized void setTransform(final AffineTransform3D affine, final Dura public synchronized void setTransform(final AffineTransform3D affine) { + resetTransformAnimation(); this.affine.set(affine); notifyListeners(); } + private void resetTransformAnimation() { + if (animateSetTransform != null) { + animateSetTransform.stop(); + animateSetTransform = null; + } + } + public void addListener(final TransformListener listener) { this.listeners.add(listener); From f6ccca2ea29d2ea9d02aedfba69334cb10e71022 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 14:41:41 -0500 Subject: [PATCH 22/28] perf: rendering performance --- .../paintera/stream/AbstractHighlightingARGBStream.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/janelia/saalfeldlab/paintera/stream/AbstractHighlightingARGBStream.java b/src/main/java/org/janelia/saalfeldlab/paintera/stream/AbstractHighlightingARGBStream.java index fe30f850a..edd65a5c5 100644 --- a/src/main/java/org/janelia/saalfeldlab/paintera/stream/AbstractHighlightingARGBStream.java +++ b/src/main/java/org/janelia/saalfeldlab/paintera/stream/AbstractHighlightingARGBStream.java @@ -87,7 +87,7 @@ public AbstractHighlightingARGBStream( } protected TLongIntHashMap argbCache = new TLongIntHashMap( - Constants.DEFAULT_CAPACITY, + Constants.DEFAULT_CAPACITY * 10, Constants.DEFAULT_LOAD_FACTOR, Label.TRANSPARENT, 0 From ee28ad48e5c658502a4fed425dbc6d00880e894c Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 14:41:47 -0500 Subject: [PATCH 23/28] style: unused variable --- .../janelia/saalfeldlab/paintera/BorderPaneWithStatusBars.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/BorderPaneWithStatusBars.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/BorderPaneWithStatusBars.kt index 4d2155aac..b6fac407d 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/BorderPaneWithStatusBars.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/BorderPaneWithStatusBars.kt @@ -142,7 +142,7 @@ class BorderPaneWithStatusBars(paintera: PainteraMainWindow) { LOG.debug("Init {}", BorderPaneWithStatusBars::class.java.name) initCrossHairs() toggleOnMenuBarConfigMode(menuBar) - paintera.baseView.activeModeProperty.addListener { _, old, new -> + paintera.baseView.activeModeProperty.addListener { _, _, new -> (new as? ToolMode)?.also { toolMode -> val toolBar = toolMode.createToolBar() toolBar.visibleProperty().bind(painteraProperties.toolBarConfig.isVisibleProperty) From 6743f2cf14d41b71e35d4ca25abcd5668f1fa7fa Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 14:41:53 -0500 Subject: [PATCH 24/28] feat: expose applyPrediction for overriding --- .../control/modes/ShapeInterpolationMode.kt | 34 +++++-------------- .../paintera/control/tools/paint/SamTool.kt | 18 +++++----- 2 files changed, 17 insertions(+), 35 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt index e2dc3971b..db6d325ce 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/ShapeInterpolationMode.kt @@ -183,6 +183,14 @@ class ShapeInterpolationMode>(val controller: ShapeInterpolat globalTransformAtEmbedding.set(paintera.baseView.manager().transform) } + override fun applyPrediction() { + lastPrediction?.let { + super.applyPrediction() + controller.paint(it.maskInterval) + switchTool(shapeInterpolationTool) + } + } + override fun setCurrentLabelToSelection() { currentLabelToPaint = controller.interpolationId } @@ -535,35 +543,11 @@ class ShapeInterpolationMode>(val controller: ShapeInterpolat * */ private fun additionalSamActions(samTool: SamTool): ActionSet { return painteraActionSet("Shape Interpolation SAM Actions", PaintActionType.ShapeInterpolation) { - KEY_PRESSED(KeyCode.ENTER) { - name = "submit sam mask to shape interpolation controller" - verify { activeTool == samTool } - onAction { - samTool.lastPrediction?.let { prediction -> - controller.paint(prediction.maskInterval) - } - switchTool(shapeInterpolationTool) - } - } KEY_PRESSED(KeyCode.ESCAPE) { name = "toggle off sam tool, back to shapeinterpolation " filter = true verify { activeTool == samTool } - onAction { - switchTool(shapeInterpolationTool) - } - } - MOUSE_CLICKED(MouseButton.PRIMARY) { - name = "submit sam mask to shape interpolation controller" - verifyEventNotNull() - verify("Control cannot be down") { it?.isControlDown == false } - verify { activeTool == samTool } - onAction { - samTool.lastPrediction?.let { prediction -> - controller.paint(prediction.maskInterval) - } - switchTool(shapeInterpolationTool) - } + onAction { switchTool(shapeInterpolationTool) } } switchAndApplyShapeInterpolation() } diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt index d34ea4b9b..817f92c05 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/tools/paint/SamTool.kt @@ -321,20 +321,14 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty Platform.runLater { @@ -1008,8 +1007,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty) { Include(arrayOf(SAM_POINT_STYLE, "sam-include")), - Exclude(arrayOf(SAM_POINT_STYLE, "sam-exclude")), - Box(arrayOf(SAM_POINT_STYLE, "sam-box")) + Exclude(arrayOf(SAM_POINT_STYLE, "sam-exclude")) } private val LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()) From 86d2cf7543b9203699a51e76ce8d056151a37244 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 14:42:00 -0500 Subject: [PATCH 25/28] fix: midi controller logic to javafx thread --- .../paintera/control/modes/NavigationControlMode.kt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt index 3fb1d719a..b874c2cb2 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt @@ -489,9 +489,11 @@ object NavigationTool : ViewerTool() { verifyEventNotNull() verify { allowRotationsProperty() } onAction { - val direction = it!!.value.sign - rotationController.setSpeed(direction * speed) - rotationController.rotateAroundAxis(targetPosition.x, targetPosition.y, axis) + InvokeOnJavaFXApplicationThread { + val direction = it!!.value.sign + rotationController.setSpeed(direction * speed) + rotationController.rotateAroundAxis(targetPosition.x, targetPosition.y, axis) + } } } } From 85ea87acdc28bd87b95c4f0e3255eefe3f007de7 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 14:42:04 -0500 Subject: [PATCH 26/28] feat!: initial work on automated slice segment with SAM --- .../control/ShapeInterpolationController.kt | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/ShapeInterpolationController.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/ShapeInterpolationController.kt index ffc68366e..3b258f84a 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/ShapeInterpolationController.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/ShapeInterpolationController.kt @@ -365,7 +365,7 @@ class ShapeInterpolationController>( } } - fun selectAndMoveToSlice(sliceInfo: SliceInfo) { + private fun selectAndMoveToSlice(sliceInfo: SliceInfo) { controllerState = ControllerState.Moving InvokeOnJavaFXApplicationThread { paintera().manager().setTransform(sliceInfo.globalTransform, Duration(300.0)) { @@ -606,16 +606,15 @@ class ShapeInterpolationController>( } private val globalToViewerTransform: AffineTransform3D get() = AffineTransform3D().also { activeViewer!!.state.getViewerTransform(it) } - fun getMask(): ViewerMask { + fun getMask(targetMipMapLevel: Int = currentBestMipMapLevel): ViewerMask { - val currentLevel = currentBestMipMapLevel /* If we have a mask, get it; else create a new one */ currentViewerMask = sliceAtCurrentDepth?.let { oldSlice -> val oldMask = oldSlice.mask if (oldMask.xScaleChange == 1.0) return@let oldMask - val maskInfo = MaskInfo(0, currentLevel) + val maskInfo = MaskInfo(0, targetMipMapLevel) val newMask = source.createViewerMask(maskInfo, activeViewer!!, paintDepth = null, setMask = false) val oldToNewMask = ViewerMask.maskToMaskTransformation(oldMask, newMask) @@ -642,7 +641,7 @@ class ShapeInterpolationController>( newMask.viewerImg.wrappedSource = oldInNew newMask.volatileViewerImg.wrappedSource = oldInNewVolatile - /* then we pop the `newMask` back in front, as a writable layer */ + /* then we push the `newMask` back in front, as a writable layer */ newMask.pushNewImageLayer(newImg to newVolatileImg) /* Replace old slice info */ @@ -659,7 +658,7 @@ class ShapeInterpolationController>( slicesAndInterpolants.add(currentDepth, newSlice) newMask } ?: let { - val maskInfo = MaskInfo(0, currentLevel) + val maskInfo = MaskInfo(0, targetMipMapLevel) source.createViewerMask(maskInfo, activeViewer!!, paintDepth = null, setMask = false) } currentViewerMask?.setViewerMaskOnSource() @@ -1047,7 +1046,7 @@ class ShapeInterpolationController>( class InterpolantInfo(val dataInterpolant: RealRandomAccessible) - class SliceInfo( + private class SliceInfo( var mask: ViewerMask, val globalTransform: AffineTransform3D, selectionInterval: RealInterval From 18948618cb253f929fbc18932c8e1b75a94d515d Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 16:01:00 -0500 Subject: [PATCH 27/28] fix: midi zoom controls --- .../paintera/control/modes/NavigationControlMode.kt | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt index b874c2cb2..8dd618ba3 100644 --- a/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt +++ b/src/main/kotlin/org/janelia/saalfeldlab/paintera/control/modes/NavigationControlMode.kt @@ -385,7 +385,13 @@ object NavigationTool : ViewerTool() { setDisplayType(DisplayType.TRIM) verifyEventNotNull() onAction { - InvokeOnJavaFXApplicationThread { zoomController.zoomCenteredAt(-it!!.value.toDouble(), target.width / 2.0, target.height / 2.0) } + InvokeOnJavaFXApplicationThread { + val delta = speed / 100 + val potVal = it!!.value.toDouble() + val scale = 1 - delta * potVal + val (x,y) = targetPositionObservable!!.let { it.x to it.y } + zoomController.zoomCenteredAt(scale, x, y) + } } } } @@ -483,7 +489,7 @@ object NavigationTool : ViewerTool() { MidiRotationStruct(7, Axis.Z), ).map { (handle, axis) -> painteraMidiActionSet("rotate", device, target, NavigationActionType.Rotate) { - MidiPotentiometerEvent.POTENTIOMETER_RELATIVE(handle) { + MidiPotentiometerEvent.POTENTIOMETER_RELATIVE ( handle) { name = "midi_rotate_${axis.name.lowercase()}" setDisplayType(DisplayType.TRIM) verifyEventNotNull() From 3c2b4b85a666dc41cde5c50a130a80fb243d6d25 Mon Sep 17 00:00:00 2001 From: Caleb Hulbert Date: Mon, 22 Jan 2024 16:03:41 -0500 Subject: [PATCH 28/28] build: use saalfx-1.1.0 --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index a24998885..60b8429ab 100644 --- a/pom.xml +++ b/pom.xml @@ -10,7 +10,7 @@ org.janelia.saalfeldlab paintera - 1.0.2-SNAPSHOT + 1.1.0-SNAPSHOT Paintera New Era Painting and annotation tool @@ -53,7 +53,7 @@ true ${javadoc.skip} - 1.1.0-SNAPSHOT + 1.1.0 3.0.7 1.4.0