diff --git a/doc/changes/devel/12655.newfeature.rst b/doc/changes/devel/12655.newfeature.rst new file mode 100644 index 00000000000..b2711d2f2ed --- /dev/null +++ b/doc/changes/devel/12655.newfeature.rst @@ -0,0 +1,2 @@ +Added support for passing ``axes`` to :func:`mne.viz.plot_head_positions` when +``mode='field'``, by `Eric Larson`_. diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index d2097b219fc..0d389645f3f 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -130,6 +130,7 @@ def plot_head_positions( mode="traces", cmap="viridis", direction="z", + *, show=True, destination=None, info=None, @@ -169,9 +170,11 @@ def plot_head_positions( .. versionadded:: 0.16 axes : array-like, shape (3, 2) - The matplotlib axes to use. Only used for ``mode == 'traces'``. + The matplotlib axes to use. .. versionadded:: 0.16 + .. versionchanged:: 1.8 + Added support for making use of this argument when ``mode="field"``. Returns ------- @@ -193,7 +196,9 @@ def plot_head_positions( if not isinstance(pos, (list, tuple)): pos = [pos] + pos = list(pos) # make our own mutable copy for ii, p in enumerate(pos): + _validate_type(p, np.ndarray, f"pos[{ii}]") p = np.array(p, float) if p.ndim != 2 or p.shape[1] != 10: raise ValueError( @@ -315,9 +320,15 @@ def plot_head_positions( from mpl_toolkits.mplot3d import Axes3D # noqa: F401, analysis:ignore from mpl_toolkits.mplot3d.art3d import Line3DCollection - fig, ax = plt.subplots( - 1, subplot_kw=dict(projection="3d"), layout="constrained" - ) + _validate_type(axes, (Axes3D, None), "ax", extra="when mode='field'") + if axes is None: + _, ax = plt.subplots( + 1, subplot_kw=dict(projection="3d"), layout="constrained" + ) + else: + ax = axes + fig = ax.get_figure() + del axes # First plot the trajectory as a colormap: # http://matplotlib.org/examples/pylab_examples/multicolored_line.html diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 5109becb645..584e70432c4 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -104,13 +104,20 @@ def test_plot_head_positions(): pos = np.random.RandomState(0).randn(4, 10) pos[:, 0] = np.arange(len(pos)) destination = (0.0, 0.0, 0.04) - with _record_warnings(): # old MPL will cause a warning - plot_head_positions(pos) - plot_head_positions(pos, mode="field", info=info, destination=destination) - plot_head_positions([pos, pos]) # list support - pytest.raises(ValueError, plot_head_positions, ["pos"]) - pytest.raises(ValueError, plot_head_positions, pos[:, :9]) - pytest.raises(ValueError, plot_head_positions, pos, "foo") + plot_head_positions(pos) + plot_head_positions(pos, mode="field", info=info, destination=destination) + plot_head_positions([pos, pos]) # list support + fig, ax = plt.subplots() + with pytest.raises(TypeError, match="instance of Axes3D"): + plot_head_positions(pos, mode="field", info=info, axes=ax) + fig, ax = plt.subplots(subplot_kw=dict(projection="3d")) + plot_head_positions(pos, mode="field", info=info, axes=ax) + with pytest.raises(TypeError, match="must be an instance of ndarray"): + plot_head_positions(["foo"]) + with pytest.raises(ValueError, match="must be dim"): + plot_head_positions(pos[:, :9]) + with pytest.raises(ValueError, match="Allowed values"): + plot_head_positions(pos, "foo") with pytest.raises(ValueError, match="shape"): plot_head_positions(pos, axes=1.0) diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh index 2f532b73220..a5a11dba3e6 100755 --- a/tools/install_pre_requirements.sh +++ b/tools/install_pre_requirements.sh @@ -46,7 +46,7 @@ python -m pip install $STD_ARGS vtk python -c "import vtk" echo "PyVista" -python -m pip install $STD_ARGS "git+https://github.com/adeak/pyvista.git@fix_numpy_2" # pyvista/pyvista +python -m pip install $STD_ARGS "git+https://github.com/pyvista/pyvista" echo "picard" python -m pip install $STD_ARGS git+https://github.com/pierreablin/picard