Skip to content

Commit

Permalink
ENH: Support axes arg for field mode (mne-tools#12655)
Browse files Browse the repository at this point in the history
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
3 people authored Jun 10, 2024
1 parent 07f429e commit b167573
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 12 deletions.
2 changes: 2 additions & 0 deletions doc/changes/devel/12655.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added support for passing ``axes`` to :func:`mne.viz.plot_head_positions` when
``mode='field'``, by `Eric Larson`_.
19 changes: 15 additions & 4 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def plot_head_positions(
mode="traces",
cmap="viridis",
direction="z",
*,
show=True,
destination=None,
info=None,
Expand Down Expand Up @@ -169,9 +170,11 @@ def plot_head_positions(
.. versionadded:: 0.16
axes : array-like, shape (3, 2)
The matplotlib axes to use. Only used for ``mode == 'traces'``.
The matplotlib axes to use.
.. versionadded:: 0.16
.. versionchanged:: 1.8
Added support for making use of this argument when ``mode="field"``.
Returns
-------
Expand All @@ -193,7 +196,9 @@ def plot_head_positions(

if not isinstance(pos, (list, tuple)):
pos = [pos]
pos = list(pos) # make our own mutable copy
for ii, p in enumerate(pos):
_validate_type(p, np.ndarray, f"pos[{ii}]")
p = np.array(p, float)
if p.ndim != 2 or p.shape[1] != 10:
raise ValueError(
Expand Down Expand Up @@ -315,9 +320,15 @@ def plot_head_positions(
from mpl_toolkits.mplot3d import Axes3D # noqa: F401, analysis:ignore
from mpl_toolkits.mplot3d.art3d import Line3DCollection

fig, ax = plt.subplots(
1, subplot_kw=dict(projection="3d"), layout="constrained"
)
_validate_type(axes, (Axes3D, None), "ax", extra="when mode='field'")
if axes is None:
_, ax = plt.subplots(
1, subplot_kw=dict(projection="3d"), layout="constrained"
)
else:
ax = axes
fig = ax.get_figure()
del axes

# First plot the trajectory as a colormap:
# http://matplotlib.org/examples/pylab_examples/multicolored_line.html
Expand Down
21 changes: 14 additions & 7 deletions mne/viz/tests/test_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,20 @@ def test_plot_head_positions():
pos = np.random.RandomState(0).randn(4, 10)
pos[:, 0] = np.arange(len(pos))
destination = (0.0, 0.0, 0.04)
with _record_warnings(): # old MPL will cause a warning
plot_head_positions(pos)
plot_head_positions(pos, mode="field", info=info, destination=destination)
plot_head_positions([pos, pos]) # list support
pytest.raises(ValueError, plot_head_positions, ["pos"])
pytest.raises(ValueError, plot_head_positions, pos[:, :9])
pytest.raises(ValueError, plot_head_positions, pos, "foo")
plot_head_positions(pos)
plot_head_positions(pos, mode="field", info=info, destination=destination)
plot_head_positions([pos, pos]) # list support
fig, ax = plt.subplots()
with pytest.raises(TypeError, match="instance of Axes3D"):
plot_head_positions(pos, mode="field", info=info, axes=ax)
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plot_head_positions(pos, mode="field", info=info, axes=ax)
with pytest.raises(TypeError, match="must be an instance of ndarray"):
plot_head_positions(["foo"])
with pytest.raises(ValueError, match="must be dim"):
plot_head_positions(pos[:, :9])
with pytest.raises(ValueError, match="Allowed values"):
plot_head_positions(pos, "foo")
with pytest.raises(ValueError, match="shape"):
plot_head_positions(pos, axes=1.0)

Expand Down
2 changes: 1 addition & 1 deletion tools/install_pre_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ python -m pip install $STD_ARGS vtk
python -c "import vtk"

echo "PyVista"
python -m pip install $STD_ARGS "git+https://github.com/adeak/pyvista.git@fix_numpy_2" # pyvista/pyvista
python -m pip install $STD_ARGS "git+https://github.com/pyvista/pyvista"

echo "picard"
python -m pip install $STD_ARGS git+https://github.com/pierreablin/picard
Expand Down

0 comments on commit b167573

Please sign in to comment.