From 8b51b038147cdfdce86d3c3235284916241fffcb Mon Sep 17 00:00:00 2001
From: AnniekStok
Date: Fri, 29 Nov 2024 15:50:33 +0100
Subject: [PATCH] use pandas instead of DictReader to read the csv, and automatically infer the datatypes

---
 .../import_export/load_tracks.py | 113 +++++++-----
 1 file changed, 45 insertions(+), 68 deletions(-)

diff --git a/src/motile_plugin/import_export/load_tracks.py b/src/motile_plugin/import_export/load_tracks.py
index a509abc..07bd954 100644
--- a/src/motile_plugin/import_export/load_tracks.py
+++ b/src/motile_plugin/import_export/load_tracks.py
@@ -1,36 +1,11 @@
-from csv import DictReader
-from typing import Any
-
 import networkx as nx
 import numpy as np
+import pandas as pd
 from motile_toolbox.candidate_graph import NodeAttr
 
 from motile_plugin.data_model import SolutionTracks
 
 
-def convert_value(value: Any):
-    """Converts the given value to float or int if possible, otherwise returns it as is"""
-
-    try:
-        # Try to convert to integer
-        int_value = int(value)
-        if str(int_value) == value:
-            return int_value
-    except ValueError:
-        pass
-
-    try:
-        # Try to convert to float
-        float_value = float(value)
-        if str(float_value) == value:
-            return float_value
-    except ValueError:
-        pass
-
-    # If conversion to int or float fails, return the original value
-    return value
-
-
 def tracks_from_csv(
     csvfile: str,
     selected_columns: dict,
@@ -56,49 +31,51 @@
         Tracks: a tracks object
     """
     graph = nx.DiGraph()
-    with open(csvfile) as f:
-        reader = DictReader(f)
-        for row in reader:
-            _id = row[selected_columns["id"]]
-            if (
-                selected_columns.get("z") is not None
-                and selected_columns.get("z") != "Select Column"
-            ):
-                attrs = {
-                    "pos": [
-                        float(row[selected_columns["z"]]),
-                        float(row[selected_columns["y"]]),
-                        float(selected_columns["x"]),
-                    ],
-                    "time": int(row[selected_columns["t"]]),
-                }
-                ndims = 4
-                scale = [1, *scale]  # assumes 1 for time step
-            else:
-                attrs = {
-                    "pos": [
-                        float(row[selected_columns["y"]]),
-                        float(row[selected_columns["x"]]),
-                    ],
-                    "time": int(row[selected_columns["t"]]),
-                }
-                ndims = 3
-                scale = [1, scale[1], scale[2]]
-            if selected_columns["seg_id"] != "Select Column":
-                attrs["seg_id"] = int(row[selected_columns["seg_id"]])
-            for key in extra_columns:
-                if extra_columns[key] != "Select Column":
-                    attrs[key] = convert_value(
-                        row[extra_columns[key]]
-                    )  # try to convert strings to numerical values if possible
+    df = pd.read_csv(csvfile)
+
+    for _, row in df.iterrows():
+        _id = row["id"]
+
+        if (
+            selected_columns.get("z") is not None
+            and selected_columns.get("z") != "Select Column"
+        ):
+            attrs = {
+                "pos": [
+                    row[selected_columns["z"]],
+                    row[selected_columns["y"]],
+                    row[selected_columns["x"]],
+                ],
+                "time": row[selected_columns["t"]],
+            }
+            ndims = 4
+            scale = [1, *scale]  # assumes 1 for time step
+        else:
+            attrs = {
+                "pos": [
+                    row[selected_columns["y"]],
+                    row[selected_columns["x"]],
+                ],
+                "time": row[selected_columns["t"]],
+            }
+            ndims = 3
+            scale = [1, scale[1], scale[2]]
+
+        if selected_columns["seg_id"] != "Select Column":
+            attrs["seg_id"] = row[selected_columns["seg_id"]]
+
+        for key in extra_columns:
+            if extra_columns[key] != "Select Column":
+                attrs[key] = row[extra_columns[key]]
+
+        graph.add_node(_id, **attrs)
+        parent_id = row[selected_columns["parent_id"]]
+        if not pd.isna(parent_id):
+            parent_id = parent_id
+            if parent_id != -1:
+                assert parent_id in graph.nodes, f"{_id} {parent_id}"
+                graph.add_edge(parent_id, _id)
-            graph.add_node(_id, **attrs)
-            parent_id = row[selected_columns["parent_id"]].strip()
-            if parent_id != "":
-                parent_id = parent_id
-                if parent_id != -1:
-                    assert parent_id in graph.nodes, f"{_id} {parent_id}"
-                    graph.add_edge(parent_id, _id)
 
     if segmentation is not None:
         segmentation = np.expand_dims(
             segmentation, axis=1
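
Note on the motivation behind the change (a minimal sketch, separate from the patch itself; the CSV snippet and column names below are made up for illustration): csv.DictReader hands every field back as a string, which is why the removed code needed the convert_value() helper and the explicit float()/int() casts, whereas pandas.read_csv infers a dtype per column and turns empty cells into NaN, which pd.isna() can then detect.

import io
from csv import DictReader

import pandas as pd

# Hypothetical CSV content; the column names are illustrative only.
csv_text = "id,t,y,x,parent_id\n1,0,10.5,3.2,\n2,1,11.0,3.5,1\n"

# DictReader: every value comes back as a string, including the empty parent_id cell.
rows = list(DictReader(io.StringIO(csv_text)))
print(type(rows[0]["t"]), repr(rows[0]["parent_id"]))   # <class 'str'> ''

# pandas: dtypes are inferred per column, and empty cells become NaN.
df = pd.read_csv(io.StringIO(csv_text))
print(df.dtypes["t"], df.dtypes["y"])                   # int64 float64
print(pd.isna(df.loc[0, "parent_id"]))                  # True

This is why the new loop can store row values directly as node attributes and compare parent_id against -1 without wrapping each field in float()/int().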