Skip to content

Commit

Permalink
Code cleanup.
Browse files Browse the repository at this point in the history
Fixed lsl_viewer_v2.
  • Loading branch information
kowalej committed May 13, 2018
1 parent 530737a commit 104fb06
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 88 deletions.
2 changes: 1 addition & 1 deletion muselsl/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
LSL_BUFFER = 360

VIEW_SUBSAMPLE = 2
VIEW_BUFFER = 12
VIEW_BUFFER = 12
7 changes: 3 additions & 4 deletions muselsl/lsl_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ def record(duration, filename=None, dejitter=False):

print("Start acquiring data")
inlet = StreamInlet(streams[0], max_chunklen=12)
eeg_time_correction = inlet.time_correction()
# eeg_time_correction = inlet.time_correction()

print("looking for a Markers stream...")
marker_streams = resolve_byprop('name', 'Markers', timeout=2)

if marker_streams:
inlet_marker = StreamInlet(marker_streams[0])
marker_time_correction = inlet_marker.time_correction()
# marker_time_correction = inlet_marker.time_correction()
else:
inlet_marker = False
print("Can't find Markers stream")

info = inlet.info()
description = info.desc()

freq = info.nominal_srate()
# freq = info.nominal_srate()
Nchan = info.channel_count()

ch = description.child('channels').first_child()
Expand Down Expand Up @@ -86,7 +86,6 @@ def record(duration, filename=None, dejitter=False):
for marker in markers:
# find index of markers
ix = np.argmin(np.abs(marker[1] - timestamps))
val = timestamps[ix]
for ii in range(n_markers):
data.loc[ix, 'Marker%d' % ii] = marker[0][ii]

Expand Down
14 changes: 7 additions & 7 deletions muselsl/lsl_viewer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, lfilter_zi, firwin
from scipy.signal import lfilter, lfilter_zi, firwin
from time import sleep
from pylsl import StreamInlet, resolve_byprop
import seaborn as sns
Expand All @@ -13,7 +13,7 @@ def view(window, scale, refresh, figure):

figsize = np.int16(figure.split('x'))

print("looking for an EEG stream...")
print("Looking for an EEG stream...")
streams = resolve_byprop('type', 'EEG', timeout=2)

if len(streams) == 0:
Expand All @@ -37,7 +37,7 @@ def view(window, scale, refresh, figure):


class LSLViewer():
def __init__(self, stream, fig, axes, window, scale, dejitter=True):
def __init__(self, stream, fig, axes, window, scale, dejitter=True):
"""Init"""
self.stream = stream
self.window = window
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self, stream, fig, axes, window, scale, dejitter=True):
sns.despine(left=True)

self.data = np.zeros((self.n_samples, self.n_chan))
self.times = np.arange(-self.window, 0, 1./self.sfreq)
self.times = np.arange(-self.window, 0, 1. / self.sfreq)
impedances = np.std(self.data, axis=0)
lines = []

Expand All @@ -93,12 +93,12 @@ def __init__(self, stream, fig, axes, window, scale, dejitter=True):
for ii in range(self.n_chan)]
axes.set_yticklabels(ticks_labels)

self.display_every = int(0.2 / (12/self.sfreq))
self.display_every = int(0.2 / (12 / self.sfreq))

# self.bf, self.af = butter(4, np.array([1, 40])/(self.sfreq/2.),
# 'bandpass')

self.bf = firwin(32, np.array([1, 40])/(self.sfreq/2.), width=0.05,
self.bf = firwin(32, np.array([1, 40]) / (self.sfreq / 2.), width=0.05,
pass_zero=False)
self.af = [1.0]

Expand All @@ -115,7 +115,7 @@ def update_plot(self):
if self.dejitter:
timestamps = np.float64(np.arange(len(timestamps)))
timestamps /= self.sfreq
timestamps += self.times[-1] + 1./self.sfreq
timestamps += self.times[-1] + 1. / self.sfreq
self.times = np.concatenate([self.times, timestamps])
self.n_samples = int(self.sfreq * self.window)
self.times = self.times[-self.n_samples:]
Expand Down
116 changes: 59 additions & 57 deletions muselsl/lsl_viewer_V2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@
varying vec2 v_position;
varying vec4 v_ab;
void main() {
float nrows = u_size.x;
float ncols = u_size.y;
float n_rows = u_size.x;
float n_cols = u_size.y;
// Compute the x coordinate from the time index.
float x = -1 + 2*a_index.z / (u_n-1);
vec2 position = vec2(x - (1 - 1 / u_scale.x), a_position);
// Find the affine transformation for the subplots.
vec2 a = vec2(1./ncols, 1./nrows)*.9;
vec2 b = vec2(-1 + 2*(a_index.x+.5) / ncols,
-1 + 2*(a_index.y+.5) / nrows);
vec2 a = vec2(1./n_cols, 1./n_rows)*.9;
vec2 b = vec2(-1 + 2*(a_index.x+.5) / n_cols,
-1 + 2*(a_index.y+.5) / n_rows);
// Apply the static subplot transformation + scaling.
gl_Position = vec4(a*u_scale*position+b, 0.0, 1.0);
v_color = vec4(a_color, 1.);
Expand Down Expand Up @@ -74,6 +74,7 @@
}
"""


def view():
print("looking for an EEG stream...")
streams = resolve_byprop('type', 'EEG', timeout=2)
Expand All @@ -83,68 +84,69 @@ def view():
print("Start acquiring data")

inlet = StreamInlet(streams[0], max_chunklen=12)
Canvas(inlet)
app.run()

info = inlet.info()
description = info.desc()

window = 10
sfreq = info.nominal_srate()
n_samples = int(sfreq * window)
n_chan = info.channel_count()
class Canvas(app.Canvas):
def __init__(self, lsl_inlet, scale=500, filt=True):
app.Canvas.__init__(self, title='EEG - Use your wheel to zoom!',
keys='interactive')

ch = description.child('channels').first_child()
ch_names = [ch.child_value('label')]
self.inlet = lsl_inlet
info = self.inlet.info()
description = info.desc()

for i in range(n_chan):
ch = ch.next_sibling()
ch_names.append(ch.child_value('label'))
window = 10
self.sfreq = info.nominal_srate()
n_samples = int(self.sfreq * window)
self.n_chans = info.channel_count()

# Number of cols and rows in the table.
nrows = n_chan
ncols = 1
ch = description.child('channels').first_child()
ch_names = [ch.child_value('label')]

# Number of signals.
m = nrows*ncols
for i in range(self.n_chans):
ch = ch.next_sibling()
ch_names.append(ch.child_value('label'))

# Number of samples per signal.
n = n_samples
# Number of cols and rows in the table.
n_rows = self.n_chans
n_cols = 1

# Various signal amplitudes.
amplitudes = np.zeros((m, n)).astype(np.float32)
gamma = np.ones((m, n)).astype(np.float32)
# Generate the signals as a (m, n) array.
y = amplitudes
# Number of signals.
m = n_rows * n_cols

color = color_palette("RdBu_r", nrows)
# Number of samples per signal.
n = n_samples

color = np.repeat(color, n, axis=0).astype(np.float32)
# Signal 2D index of each vertex (row and col) and x-index (sample index
# within each signal).
index = np.c_[np.repeat(np.repeat(np.arange(ncols), nrows), n),
np.repeat(np.tile(np.arange(nrows), ncols), n),
np.tile(np.arange(n), m)].astype(np.float32)
c = Canvas()
app.run()
# Various signal amplitudes.
amplitudes = np.zeros((m, n)).astype(np.float32)
# gamma = np.ones((m, n)).astype(np.float32)
# Generate the signals as a (m, n) array.
y = amplitudes

class Canvas(app.Canvas):
def __init__(self, scale=500, filt=True):
app.Canvas.__init__(self, title='EEG - Use your wheel to zoom!',
keys='interactive')
color = color_palette("RdBu_r", n_rows)

color = np.repeat(color, n, axis=0).astype(np.float32)
# Signal 2D index of each vertex (row and col) and x-index (sample index
# within each signal).
index = np.c_[np.repeat(np.repeat(np.arange(n_cols), n_rows), n),
np.repeat(np.tile(np.arange(n_rows), n_cols), n),
np.tile(np.arange(n), m)].astype(np.float32)

self.program = gloo.Program(VERT_SHADER, FRAG_SHADER)
self.program['a_position'] = y.reshape(-1, 1)
self.program['a_color'] = color
self.program['a_index'] = index
self.program['u_scale'] = (1., 1.)
self.program['u_size'] = (nrows, ncols)
self.program['u_size'] = (n_rows, n_cols)
self.program['u_n'] = n


# text
self.font_size = 48.
self.names = []
self.quality = []
for ii in range(n_chan):
for ii in range(self.n_chans):
text = visuals.TextVisual(ch_names[ii], bold=True, color='white')
self.names.append(text)
text = visuals.TextVisual('', bold=True, color='white')
Expand All @@ -157,14 +159,14 @@ def __init__(self, scale=500, filt=True):
self.filt = filt
self.af = [1.0]

self.data_f = np.zeros((n_samples, n_chan))
self.data = np.zeros((n_samples, n_chan))
self.data_f = np.zeros((n_samples, self.n_chans))
self.data = np.zeros((n_samples, self.n_chans))

self.bf = create_filter(self.data_f.T, sfreq, 3, 40.,
self.bf = create_filter(self.data_f.T, self.sfreq, 3, 40.,
method='fir', fir_design='firwin')

zi = lfilter_zi(self.bf, self.af)
self.filt_state = np.tile(zi, (n_chan, 1)).transpose()
self.filt_state = np.tile(zi, (self.n_chans, 1)).transpose()

self._timer = app.Timer('auto', connect=self.on_timer, start=True)
gloo.set_viewport(0, 0, *self.physical_size)
Expand All @@ -186,23 +188,23 @@ def on_key_press(self, event):
else:
dx = 0.05
scale_x, scale_y = self.program['u_scale']
scale_x_new, scale_y_new = (scale_x * math.exp(1.0*dx),
scale_y * math.exp(0.0*dx))
scale_x_new, scale_y_new = (scale_x * math.exp(1.0 * dx),
scale_y * math.exp(0.0 * dx))
self.program['u_scale'] = (max(1, scale_x_new), max(1, scale_y_new))
self.update()

def on_mouse_wheel(self, event):
dx = np.sign(event.delta[1]) * .05
scale_x, scale_y = self.program['u_scale']
scale_x_new, scale_y_new = (scale_x * math.exp(0.0*dx),
scale_y * math.exp(2.0*dx))
scale_x_new, scale_y_new = (scale_x * math.exp(0.0 * dx),
scale_y * math.exp(2.0 * dx))
self.program['u_scale'] = (max(1, scale_x_new), max(0.01, scale_y_new))
self.update()

def on_timer(self, event):
"""Add some data at the end of each signal (real-time signals)."""

samples, timestamps = inlet.pull_chunk(timeout=0.0,
samples, timestamps = self.inlet.pull_chunk(timeout=0.0,
max_samples=100)
if timestamps:
samples = np.array(samples)[:, ::-1]
Expand All @@ -219,9 +221,9 @@ def on_timer(self, event):
elif not self.filt:
plot_data = (self.data - self.data.mean(axis=0)) / self.scale

sd = np.std(plot_data[-int(sfreq):], axis=0)[::-1] * self.scale
sd = np.std(plot_data[-int(self.sfreq):], axis=0)[::-1] * self.scale
co = np.int32(np.tanh((sd - 30) / 15)*5 + 5)
for ii in range(n_chan):
for ii in range(self.n_chans):
self.quality[ii].text = '%.2f' % (sd[ii])
self.quality[ii].color = self.quality_colors[co[ii]]
self.quality[ii].font_size = 12 + co[ii]
Expand All @@ -239,11 +241,11 @@ def on_resize(self, event):

for ii, t in enumerate(self.names):
t.transforms.configure(canvas=self, viewport=vp)
t.pos = (self.size[0] * 0.025, ((ii + 0.5)/n_chan) * self.size[1])
t.pos = (self.size[0] * 0.025, ((ii + 0.5) / self.n_chans) * self.size[1])

for ii, t in enumerate(self.quality):
t.transforms.configure(canvas=self, viewport=vp)
t.pos = (self.size[0] * 0.975, ((ii + 0.5)/n_chan) * self.size[1])
t.pos = (self.size[0] * 0.975, ((ii + 0.5) / self.n_chans) * self.size[1])

def on_draw(self, event):
gloo.clear()
Expand Down
12 changes: 4 additions & 8 deletions muselsl/muse-lsl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#!/usr/bin/python
import sys
import getopt
import argparse
import re
import os
import configparser


class Program:
Expand Down Expand Up @@ -77,8 +73,8 @@ def stream(self):
parser = argparse.ArgumentParser(
description='Start an LSL stream from Muse headset.')
parser.add_argument("-a", "--address",
dest="address", type=str, default=None,
help="device MAC address.")
dest="address", type=str, default=None,
help="device MAC address.")
parser.add_argument("-n", "--name",
dest="name", type=str, default=None,
help="name of the device.")
Expand Down Expand Up @@ -135,8 +131,8 @@ def viewlsl(self):
help="viewer version (1 or 2) - 1 is the default stable version, 2 is in development (and takes no arguments).")
args = parser.parse_args(sys.argv[2:])
if args.version == 2:
import lsl_viewer_V2
lsl_viewer_V2.view()
import lsl_viewer_v2
lsl_viewer_v2.view()
else:
import lsl_viewer
lsl_viewer.view(args.window, args.scale, args.refresh, args.figure)
Expand Down
10 changes: 4 additions & 6 deletions muselsl/muse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import bitstring
import pygatt
import numpy as np
from time import time, sleep
from time import time
from sys import platform
import subprocess

Expand Down Expand Up @@ -64,7 +64,7 @@ def connect(self, interface=None, backend='auto'):
self.interface = self.interface or 'hci0'
self.adapter = pygatt.GATTToolBackend(self.interface)
else:
self.adapter = pygatt.BGAPIBackend(serial_port=self.interface)
self.adapter = pygatt.BGAPIBackend(serial_port=self.interface)

self.adapter.start()
self.device = self.adapter.connect(self.address)
Expand All @@ -86,12 +86,11 @@ def connect(self, interface=None, backend='auto'):
self._subscribe_gyro()

return True

except (pygatt.exceptions.NotConnectedError, pygatt.exceptions.NotificationTimeout):
print('Connection to', self.address, 'failed')
return False


def _write_cmd(self, cmd):
"""Wrapper to write a command to the Muse device.
cmd -- list of bytes"""
Expand Down Expand Up @@ -212,7 +211,7 @@ def _init_timestamp_correction(self):
# initial params for the timestamp correction
# the time it started + the inverse of sampling rate
self.sample_index = 0
self.reg_params = np.array([self.time_func(), 1./256])
self.reg_params = np.array([self.time_func(), 1. / 256])

def _update_timestamp_correction(self, x, y):
"""Update regression for dejittering
Expand Down Expand Up @@ -317,7 +316,6 @@ def _handle_telemetry(self, handle, packet):
pattern = "uint:16,uint:16,uint:16,uint:16,uint:16" # The rest is 0 padding
data = bit_decoder.unpack(pattern)

packet_index = data[0]
battery = data[1] / 512
fuel_gauge = data[2] * 2.2
adc_volt = data[3]
Expand Down
Loading

0 comments on commit 104fb06

Please sign in to comment.