Skip to content

Commit

Permalink
Add split and merge events support for xml TrackMate import
Browse files Browse the repository at this point in the history
  • Loading branch information
hadim committed Dec 2, 2015
1 parent 7449969 commit eff771b
Showing 1 changed file with 47 additions and 14 deletions.
61 changes: 47 additions & 14 deletions spindle_tracker/io/trackmate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import itertools
import xml.etree.cElementTree as et

import networkx as nx
import pandas as pd
import numpy as np

Expand Down Expand Up @@ -27,15 +29,14 @@ def trackmate_peak_import(trackmate_xml_path, get_tracks=False):
'ESTIMATED_DIAMETER': 'w',
'QUALITY': 'q',
'ID': 'spot_id',
# 'MEAN_INTENSITY': 'mean_intensity',
'MEAN_INTENSITY': 'mean_intensity',
'MEDIAN_INTENSITY': 'median_intensity',
'MIN_INTENSITY': 'min_intensity',
'MAX_INTENSITY': 'max_intensity',
'TOTAL_INTENSITY': 'total_intensity',
'STANDARD_DEVIATION': 'std_intensity',
'CONTRAST': 'contrast',
'SNR': 'snr',}

'SNR': 'snr'}

features = root.find('Model').find('FeatureDeclarations').find('SpotFeatures')
features = [c.get('feature') for c in features.getchildren()] + ['ID']
Expand Down Expand Up @@ -80,27 +81,59 @@ def trackmate_peak_import(trackmate_xml_path, get_tracks=False):
if get_tracks:
filtered_track_ids = [int(track.get('TRACK_ID')) for track in root.find('Model').find('FilteredTracks').findall('TrackID')]

trajs['label'] = np.nan
new_trajs = pd.DataFrame()
label_id = 0
trajs = trajs.set_index('spot_id')

tracks = root.find('Model').find('AllTracks')
for track in tracks.findall('Track'):
track_id = int(track.get("TRACK_ID"))

track_id = int(track.get("TRACK_ID"))
if track_id in filtered_track_ids:

spot_ids = []
for edge in track.findall('Edge'):
spot_ids.append(int(edge.get('SPOT_SOURCE_ID')))
spot_ids.append(int(edge.get('SPOT_TARGET_ID')))
spot_ids = [(edge.get('SPOT_SOURCE_ID'), edge.get('SPOT_TARGET_ID'), edge.get('EDGE_TIME')) for edge in track.findall('Edge')]
spot_ids = np.array(spot_ids).astype('float')
spot_ids = pd.DataFrame(spot_ids, columns=['source', 'target', 'time'])
spot_ids = spot_ids.sort_values(by='time')
spot_ids = spot_ids.set_index('time')

# Build graph
graph = nx.Graph()
for t, spot in spot_ids.iterrows():
graph.add_edge(int(spot['source']), int(spot['target']), attr_dict=dict(t=t))

# Find graph extremities by checking if number of neighbors is equal to 1
tracks_extremities = [node for node in graph.nodes() if len(graph.neighbors(node)) == 1]

paths = []
# Find all possible paths between extremities
for source, target in itertools.combinations(tracks_extremities, 2):

# Find all path between two nodes
for path in nx.all_simple_paths(graph, source=source, target=target):

# Now we need to check wether this path respect the time logic contraint
# edges can only go in one direction of the time

# Build times vector according to path
t = []
for i, node_srce in enumerate(path[:-1]):
node_trgt = path[i+1]
t.append(graph.edge[node_srce][node_trgt]['t'])

# Will be equal to 1 if going to one time direction
if len(np.unique(np.sign(np.diff(t)))) == 1:
paths.append(path)

spot_ids = np.unique(spot_ids)
trajs.loc[spot_ids, 'label'] = track_id
# Add each individual trajectory to a new DataFrame called new_trajs
for path in paths:
traj = trajs.loc[path].copy()
traj['label'] = label_id
label_id += 1

trajs = trajs.reset_index()
new_trajs = new_trajs.append(traj)

# Remove spot without labels
trajs = trajs.dropna(subset=['label'])
trajs = new_trajs

trajs.set_index(['t_stamp', 'label'], inplace=True)
trajs = trajs.sort_index()
Expand Down

0 comments on commit eff771b

Please sign in to comment.