diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 237c9402f..55ee0fa9a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,11 @@ `Unreleased `_ ---------- +- Bug Fixes: + + - Allow ``color`` to be passed as an extra kwarg to ``plot_1d`` and + ``plot_1d_to_axis``. Previously this caused a ``TypeError``. + `v1.2.0 `_ ------ diff --git a/euphonic/plot.py b/euphonic/plot.py index 0cd4173b5..25a3a0705 100644 --- a/euphonic/plot.py +++ b/euphonic/plot.py @@ -79,10 +79,11 @@ def plot_1d_to_axis(spectra: Union[Spectrum1D, Spectrum1DCollection], # Only add legend label to the first segment label = None color = p[-1].get_color() - + # Allow user kwargs to take priority + plot_kwargs = {**{'color': color, 'label': label}, **mplargs} p = ax.plot(spectrum.get_bin_centres().magnitude[x0:x1], spectrum.y_data.magnitude[x0:x1], - color=color, label=label, **mplargs) + **plot_kwargs) # Update legend if it exists, in case new labels have been added legend = ax.get_legend() diff --git a/tests_and_analysis/test/euphonic_test/test_plot.py b/tests_and_analysis/test/euphonic_test/test_plot.py index a36795ddd..c9872b4a4 100644 --- a/tests_and_analysis/test/euphonic_test/test_plot.py +++ b/tests_and_analysis/test/euphonic_test/test_plot.py @@ -186,6 +186,20 @@ def test_incorrect_length_labels_raises_value_error( with pytest.raises(ValueError): plot_1d_to_axis(spec, axes, labels=labels) + @pytest.mark.parametrize('kwargs', [ + ({'ls': '--'}), + ({'color': 'r', 'ms': '+'}) + ]) + def test_extra_plot_kwargs(self, mocker, axes, kwargs): + mock = mocker.patch('matplotlib.axes.Axes.plot', + return_value=None) + spec = Spectrum1D(np.array([0., 1., 2.])*ureg('meV'), + np.array([2., 3., 2.])*ureg('angstrom^-2')) + plot_1d_to_axis(spec, axes, **kwargs) + + expected_kwargs = {**{'color': None, 'label': None}, **kwargs} + assert mock.call_args[1] == expected_kwargs + class TestPlot1D: def teardown_method(self): @@ -292,6 +306,21 @@ def test_plot_with_incorrect_labels_raises_valueerror(self, band_segments): with pytest.raises(ValueError): fig = plot_1d(band_segments, labels=['Band A', 'Band B']) + @pytest.mark.parametrize('spec, kwargs', [ + (Spectrum1D(*spec1d_args), {'ls': '.-'}), + (Spectrum1D(*spec1d_args), {'ms': '*', 'color': 'g'}), + (Spectrum1DCollection(*spec1dcol_args), {'ms': '*', 'color': 'g'}), + (Spectrum1DCollection(*spec1dcol_args), + {'label': 'Line A', 'color': 'k'}) + ]) + def test_plot_kwargs(self, mocker, spec, kwargs): + mock = mocker.patch('matplotlib.axes.Axes.plot', + return_value=None) + plot_1d(spec, **kwargs) + + expected_kwargs = {**{'color': None, 'label': None}, **kwargs} + for mock_call_args in mock.call_args_list: + assert mock_call_args[1] == expected_kwargs class TestPlot2D: