Skip to content

Commit

Permalink
Add x0 and sampling_freq arguments to plot
Browse files Browse the repository at this point in the history
- Bump to 0.12.0
  • Loading branch information
smathot committed Dec 4, 2023
1 parent 88114e9 commit c5c2730
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions time_series_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import random
from collections import namedtuple

__version__ = '0.11.2'
__version__ = '0.12.0'
DEFAULT_HUE_COLORMAP = 'Dark2'
DEFAULT_ANNOTATION_COLORMAP = 'brg'
DEEP_ORANGE = ['#bf360c', '#e64a19', '#ff5722', '#ff8a65', '#ffccbc']
Expand Down Expand Up @@ -327,7 +327,7 @@ def lmer_permutation_test(dm, formula, groups, re_formula=None, winlen=1,
def plot(dm, dv, hue_factor, results=None, linestyle_factor=None, hues=None,
linestyles=None, alpha_level=.05, annotate_intercept=False,
annotation_hues=None, annotation_linestyle=':', legend_kwargs=None,
annotation_legend_kwargs=None):
annotation_legend_kwargs=None, x0=0, sampling_freq=1):
"""Visualizes a time series, where the signal is plotted as a function of
sample number on the x-axis. One fixed effect is indicated by the hue
(color) of the lines. An optional second fixed effect is indicated by the
Expand Down Expand Up @@ -370,6 +370,10 @@ def plot(dm, dv, hue_factor, results=None, linestyle_factor=None, hues=None,
annotation_legend_kwargs: None or dict, optional
Optional keywords to be passed to `plt.legend()` for the annotation
legend.
x0: int, float
The starting value on the x-axis.
sampling_freq: int, float
The sampling frequency.
"""
cols = [dv]
if hue_factor is not None:
Expand All @@ -383,6 +387,9 @@ def plot(dm, dv, hue_factor, results=None, linestyle_factor=None, hues=None,
hues = _colors(hues, dm[hue_factor].count)
if linestyles is None:
linestyles = LINESTYLES
# Adjust x axis
x = np.linspace(x0, x0 + (dm[dv].depth - 1) / sampling_freq, dm[dv].depth)
plt.xlim(x.min(), x.max())
# Plot the annotations
annotation_elements = []
if results is not None:
Expand All @@ -397,16 +404,14 @@ def plot(dm, dv, hue_factor, results=None, linestyle_factor=None, hues=None,
if result.p >= alpha_level:
continue
hue = annotation_hues[i % len(annotation_hues)]
x_hit = x0 + np.mean(list(result.samples)) / sampling_freq
annotation_elements.append(
plt.axvline(np.mean(list(result.samples)),
plt.axvline(x_hit,
linestyle=annotation_linestyle,
color=hue,
label='{}: p = {:.4f}'.format(effect, result.p)))
i += 1
# Adjust x axis
plt.xlim(0, dm[dv].depth)
# Plot the traces
x = np.arange(0, dm[dv].depth)
for i1, (f1, dm1) in enumerate(ops.split(dm[hue_factor])):
hue = hues[i1 % len(hues)]
if linestyle_factor is None:
Expand All @@ -416,7 +421,7 @@ def plot(dm, dv, hue_factor, results=None, linestyle_factor=None, hues=None,
ymin = y - yerr
ymax = y + yerr
plt.fill_between(x, ymin, ymax, color=hue, alpha=.2)
plt.plot(y, color=hue, linestyle=linestyles[0])
plt.plot(x, y, color=hue, linestyle=linestyles[0])
else:
for i2, (f2, dm2) in enumerate(ops.split(dm1[linestyle_factor])):
linestyle = linestyles[i2 % len(linestyles)]
Expand All @@ -426,7 +431,7 @@ def plot(dm, dv, hue_factor, results=None, linestyle_factor=None, hues=None,
ymin = y - yerr
ymax = y + yerr
plt.fill_between(x, ymin, ymax, color=hue, alpha=.2)
plt.plot(y, color=hue, linestyle=linestyle)
plt.plot(x, y, color=hue, linestyle=linestyle)
# Implement legend
if annotation_elements:
if annotation_legend_kwargs is not None:
Expand Down

0 comments on commit c5c2730

Please sign in to comment.