diff --git a/spindle_tracker/io/trackmate.py b/spindle_tracker/io/trackmate.py index 7298641..e081ecc 100644 --- a/spindle_tracker/io/trackmate.py +++ b/spindle_tracker/io/trackmate.py @@ -1,5 +1,7 @@ +import itertools import xml.etree.cElementTree as et +import networkx as nx import pandas as pd import numpy as np @@ -27,15 +29,14 @@ def trackmate_peak_import(trackmate_xml_path, get_tracks=False): 'ESTIMATED_DIAMETER': 'w', 'QUALITY': 'q', 'ID': 'spot_id', - # 'MEAN_INTENSITY': 'mean_intensity', + 'MEAN_INTENSITY': 'mean_intensity', 'MEDIAN_INTENSITY': 'median_intensity', 'MIN_INTENSITY': 'min_intensity', 'MAX_INTENSITY': 'max_intensity', 'TOTAL_INTENSITY': 'total_intensity', 'STANDARD_DEVIATION': 'std_intensity', 'CONTRAST': 'contrast', - 'SNR': 'snr',} - + 'SNR': 'snr'} features = root.find('Model').find('FeatureDeclarations').find('SpotFeatures') features = [c.get('feature') for c in features.getchildren()] + ['ID'] @@ -80,27 +81,59 @@ def trackmate_peak_import(trackmate_xml_path, get_tracks=False): if get_tracks: filtered_track_ids = [int(track.get('TRACK_ID')) for track in root.find('Model').find('FilteredTracks').findall('TrackID')] - trajs['label'] = np.nan + new_trajs = pd.DataFrame() + label_id = 0 trajs = trajs.set_index('spot_id') tracks = root.find('Model').find('AllTracks') for track in tracks.findall('Track'): - track_id = int(track.get("TRACK_ID")) + track_id = int(track.get("TRACK_ID")) if track_id in filtered_track_ids: - spot_ids = [] - for edge in track.findall('Edge'): - spot_ids.append(int(edge.get('SPOT_SOURCE_ID'))) - spot_ids.append(int(edge.get('SPOT_TARGET_ID'))) + spot_ids = [(edge.get('SPOT_SOURCE_ID'), edge.get('SPOT_TARGET_ID'), edge.get('EDGE_TIME')) for edge in track.findall('Edge')] + spot_ids = np.array(spot_ids).astype('float') + spot_ids = pd.DataFrame(spot_ids, columns=['source', 'target', 'time']) + spot_ids = spot_ids.sort_values(by='time') + spot_ids = spot_ids.set_index('time') + + # Build graph + graph = nx.Graph() + for t, spot in spot_ids.iterrows(): + graph.add_edge(int(spot['source']), int(spot['target']), attr_dict=dict(t=t)) + + # Find graph extremities by checking if number of neighbors is equal to 1 + tracks_extremities = [node for node in graph.nodes() if len(graph.neighbors(node)) == 1] + + paths = [] + # Find all possible paths between extremities + for source, target in itertools.combinations(tracks_extremities, 2): + + # Find all path between two nodes + for path in nx.all_simple_paths(graph, source=source, target=target): + + # Now we need to check wether this path respect the time logic contraint + # edges can only go in one direction of the time + + # Build times vector according to path + t = [] + for i, node_srce in enumerate(path[:-1]): + node_trgt = path[i+1] + t.append(graph.edge[node_srce][node_trgt]['t']) + + # Will be equal to 1 if going to one time direction + if len(np.unique(np.sign(np.diff(t)))) == 1: + paths.append(path) - spot_ids = np.unique(spot_ids) - trajs.loc[spot_ids, 'label'] = track_id + # Add each individual trajectory to a new DataFrame called new_trajs + for path in paths: + traj = trajs.loc[path].copy() + traj['label'] = label_id + label_id += 1 - trajs = trajs.reset_index() + new_trajs = new_trajs.append(traj) - # Remove spot without labels - trajs = trajs.dropna(subset=['label']) + trajs = new_trajs trajs.set_index(['t_stamp', 'label'], inplace=True) trajs = trajs.sort_index()