Skip to content

Commit

Permalink
V0 migration model. Renaming - surgery required
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Nov 29, 2023
1 parent 84da68d commit ff29a57
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 587 deletions.
12 changes: 6 additions & 6 deletions src/spyglass/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@
from .prepopulate import populate_from_yaml, prepopulate_default

from spyglass.linearization.v0 import ( # isort:skip
IntervalLinearizationSelection,
IntervalLinearizedPosition,
LinearizationParameters,
LinearizationParams,
LinearizedSelection,
LinearizedV0,
TrackGraph,
)

Expand All @@ -89,8 +89,6 @@
"ElectrodeGroup",
"FirFilterParameters",
"Institution",
"IntervalLinearizationSelection",
"IntervalLinearizedPosition",
"IntervalList",
"IntervalPositionInfo",
"IntervalPositionInfoSelection",
Expand All @@ -101,7 +99,9 @@
"Lab",
"LabMember",
"LabTeam",
"LinearizationParameters",
"LinearizationParams",
"LinearizedV0",
"LinearizedSelection",
"Nwbfile",
"NwbfileKachery",
"PositionInfoParameters",
Expand Down
305 changes: 153 additions & 152 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,157 +496,6 @@ def _data_to_df(data, prefix="head_", add_frame_ind=False):
return df


# ------------------------------ Migrated Tables ------------------------------

from spyglass.linearization.v0 import main as linV0 # noqa: E402

(
LinearizationParameters,
TrackGraph,
IntervalLinearizationSelection,
IntervalLinearizedPosition,
) = deprecated_factory(
[
("LinearizationParameters", linV0.LinearizationParameters),
("TrackGraph", linV0.TrackGraph),
(
"IntervalLinearizationSelection",
linV0.IntervalLinearizationSelection,
),
(
"IntervalLinearizedPosition",
linV0.IntervalLinearizedPosition,
),
],
old_module=__name__,
)


class NodePicker:
"""Interactive creation of track graph by looking at video frames."""

def __init__(
self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100
):
if ax is None:
ax = plt.gca()
self.ax = ax
self.canvas = ax.get_figure().canvas
self.cid = None
self._nodes = []
self.node_color = node_color
self._nodes_plot = ax.scatter(
[], [], zorder=5, s=node_size, color=node_color
)
self.edges = [[]]
self.video_filename = video_filename

if video_filename is not None:
self.video = cv2.VideoCapture(video_filename)
frame = self.get_video_frame()
ax.imshow(frame, picker=True)
ax.set_title(
"Left click to place node.\nRight click to remove node."
"\nShift+Left click to clear nodes."
"\nCntrl+Left click two nodes to place an edge"
)

self.connect()

@property
def node_positions(self):
return np.asarray(self._nodes)

def connect(self):
if self.cid is None:
self.cid = self.canvas.mpl_connect(
"button_press_event", self.click_event
)

def disconnect(self):
if self.cid is not None:
self.canvas.mpl_disconnect(self.cid)
self.cid = None

def click_event(self, event):
if not event.inaxes:
return
if (event.key not in ["control", "shift"]) & (
event.button == 1
): # left click
self._nodes.append((event.xdata, event.ydata))
if (event.key not in ["control", "shift"]) & (
event.button == 3
): # right click
self.remove_point((event.xdata, event.ydata))
if (event.key == "shift") & (event.button == 1):
self.clear()
if (event.key == "control") & (event.button == 1):
point = (event.xdata, event.ydata)
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
if len(self.edges[-1]) < 2:
self.edges[-1].append(closest_node_ind)
else:
self.edges.append([closest_node_ind])

self.redraw()

def redraw(self):
# Draw Node Circles
if len(self.node_positions) > 0:
self._nodes_plot.set_offsets(self.node_positions)
else:
self._nodes_plot.set_offsets([])

# Draw Node Numbers
self.ax.texts = []
for ind, (x, y) in enumerate(self.node_positions):
self.ax.text(
x,
y,
ind,
zorder=6,
fontsize=12,
horizontalalignment="center",
verticalalignment="center",
clip_on=True,
bbox=None,
transform=self.ax.transData,
)
# Draw Edges
self.ax.lines = [] # clears the existing lines
for edge in self.edges:
if len(edge) > 1:
x1, y1 = self.node_positions[edge[0]]
x2, y2 = self.node_positions[edge[1]]
self.ax.plot(
[x1, x2], [y1, y2], color=self.node_color, linewidth=2
)

self.canvas.draw_idle()

def remove_point(self, point):
if len(self._nodes) > 0:
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
self._nodes.pop(closest_node_ind)

def clear(self):
self._nodes = []
self.edges = [[]]
self.redraw()

def get_video_frame(self):
is_grabbed, frame = self.video.read()
if is_grabbed:
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


@schema
class PositionVideo(dj.Computed):
"""Creates a video of the computed head position and orientation as well as
Expand All @@ -656,7 +505,6 @@ class PositionVideo(dj.Computed):

definition = """
-> IntervalPositionInfo
---
"""

def make(self, key):
Expand Down Expand Up @@ -887,6 +735,159 @@ def make_video(
cv2.destroyAllWindows()


# ------------------------------ Migrated Tables ------------------------------

from spyglass.linearization.v0 import main as linV0 # noqa: E402

(
LinearizationParameters,
TrackGraph,
IntervalLinearizationSelection,
IntervalLinearizedPosition,
) = deprecated_factory(
[
("LinearizationParameters", linV0.LinearizationParams),
("TrackGraph", linV0.TrackGraph),
(
"IntervalLinearizationSelection",
linV0.LinearizedSelection,
),
(
"IntervalLinearizedPosition",
linV0.LinearizedV0,
),
],
old_module=__name__,
)

# ------------------------ Helper classes and functions ------------------------


class NodePicker:
"""Interactive creation of track graph by looking at video frames."""

def __init__(
self, ax=None, video_filename=None, node_color="#1f78b4", node_size=100
):
if ax is None:
ax = plt.gca()
self.ax = ax
self.canvas = ax.get_figure().canvas
self.cid = None
self._nodes = []
self.node_color = node_color
self._nodes_plot = ax.scatter(
[], [], zorder=5, s=node_size, color=node_color
)
self.edges = [[]]
self.video_filename = video_filename

if video_filename is not None:
self.video = cv2.VideoCapture(video_filename)
frame = self.get_video_frame()
ax.imshow(frame, picker=True)
ax.set_title(
"Left click to place node.\nRight click to remove node."
"\nShift+Left click to clear nodes."
"\nCntrl+Left click two nodes to place an edge"
)

self.connect()

@property
def node_positions(self):
return np.asarray(self._nodes)

def connect(self):
if self.cid is None:
self.cid = self.canvas.mpl_connect(
"button_press_event", self.click_event
)

def disconnect(self):
if self.cid is not None:
self.canvas.mpl_disconnect(self.cid)
self.cid = None

def click_event(self, event):
if not event.inaxes:
return
if (event.key not in ["control", "shift"]) & (
event.button == 1
): # left click
self._nodes.append((event.xdata, event.ydata))
if (event.key not in ["control", "shift"]) & (
event.button == 3
): # right click
self.remove_point((event.xdata, event.ydata))
if (event.key == "shift") & (event.button == 1):
self.clear()
if (event.key == "control") & (event.button == 1):
point = (event.xdata, event.ydata)
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
if len(self.edges[-1]) < 2:
self.edges[-1].append(closest_node_ind)
else:
self.edges.append([closest_node_ind])

self.redraw()

def redraw(self):
# Draw Node Circles
if len(self.node_positions) > 0:
self._nodes_plot.set_offsets(self.node_positions)
else:
self._nodes_plot.set_offsets([])

# Draw Node Numbers
self.ax.texts = []
for ind, (x, y) in enumerate(self.node_positions):
self.ax.text(
x,
y,
ind,
zorder=6,
fontsize=12,
horizontalalignment="center",
verticalalignment="center",
clip_on=True,
bbox=None,
transform=self.ax.transData,
)
# Draw Edges
self.ax.lines = [] # clears the existing lines
for edge in self.edges:
if len(edge) > 1:
x1, y1 = self.node_positions[edge[0]]
x2, y2 = self.node_positions[edge[1]]
self.ax.plot(
[x1, x2], [y1, y2], color=self.node_color, linewidth=2
)

self.canvas.draw_idle()

def remove_point(self, point):
if len(self._nodes) > 0:
distance_to_nodes = np.linalg.norm(
self.node_positions - point, axis=1
)
closest_node_ind = np.argmin(distance_to_nodes)
self._nodes.pop(closest_node_ind)

def clear(self):
self._nodes = []
self.edges = [[]]
self.redraw()

def get_video_frame(self):
is_grabbed, frame = self.video.read()
if is_grabbed:
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _fix_col_names(spatial_df):
"""Renames columns in spatial dataframe according to previous norm
Expand Down
8 changes: 8 additions & 0 deletions src/spyglass/linearization/merge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datajoint as dj

from spyglass.linearization.v0.main import LinearizedV0 # noqa F401
from spyglass.linearization.v1.main import LinearizedV1 # noqa F401

from ..utils.dj_merge_tables import _Merge
Expand All @@ -15,6 +16,13 @@ class LinearizedOutput(_Merge):
source: varchar(32)
"""

class LinearizedV0(dj.Part): # noqa F811
definition = """
-> master
---
-> LinearizedV0
"""

class LinearizedV1(dj.Part): # noqa F811
definition = """
-> master
Expand Down
15 changes: 11 additions & 4 deletions src/spyglass/linearization/v0/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from .main import (
IntervalLinearizationSelection,
IntervalLinearizedPosition,
LinearizationParameters,
from spyglass.linearization.v0.main import (
LinearizationParams,
LinearizedSelection,
LinearizedV0,
TrackGraph,
)

__all__ = [
"LinearizationParams",
"Linearized0",
"LinearizedSelection",
"TrackGraph",
]
Loading

0 comments on commit ff29a57

Please sign in to comment.