Skip to content

Commit

Permalink
use pandas instead of Dictreader to read the csv, and automatically i…
Browse files Browse the repository at this point in the history
…nfer the datatypes
  • Loading branch information
AnniekStok committed Nov 29, 2024
1 parent e2a6f62 commit 8b51b03
Showing 1 changed file with 45 additions and 68 deletions.
113 changes: 45 additions & 68 deletions src/motile_plugin/import_export/load_tracks.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,11 @@
from csv import DictReader
from typing import Any

import networkx as nx
import numpy as np
import pandas as pd
from motile_toolbox.candidate_graph import NodeAttr

from motile_plugin.data_model import SolutionTracks


def convert_value(value: Any):
"""Converts the given value to float or int if possible, otherwise returns it as is"""

try:
# Try to convert to integer
int_value = int(value)
if str(int_value) == value:
return int_value
except ValueError:
pass

try:
# Try to convert to float
float_value = float(value)
if str(float_value) == value:
return float_value
except ValueError:
pass

# If conversion to int or float fails, return the original value
return value


def tracks_from_csv(
csvfile: str,
selected_columns: dict,
Expand All @@ -56,49 +31,51 @@ def tracks_from_csv(
Tracks: a tracks object
"""
graph = nx.DiGraph()
with open(csvfile) as f:
reader = DictReader(f)
for row in reader:
_id = row[selected_columns["id"]]
if (
selected_columns.get("z") is not None
and selected_columns.get("z") != "Select Column"
):
attrs = {
"pos": [
float(row[selected_columns["z"]]),
float(row[selected_columns["y"]]),
float(selected_columns["x"]),
],
"time": int(row[selected_columns["t"]]),
}
ndims = 4
scale = [1, *scale] # assumes 1 for time step
else:
attrs = {
"pos": [
float(row[selected_columns["y"]]),
float(row[selected_columns["x"]]),
],
"time": int(row[selected_columns["t"]]),
}
ndims = 3
scale = [1, scale[1], scale[2]]
if selected_columns["seg_id"] != "Select Column":
attrs["seg_id"] = int(row[selected_columns["seg_id"]])
for key in extra_columns:
if extra_columns[key] != "Select Column":
attrs[key] = convert_value(
row[extra_columns[key]]
) # try to convert strings to numerical values if possible
df = pd.read_csv(csvfile)

for _, row in df.iterrows():
_id = row["id"]

if (
selected_columns.get("z") is not None
and selected_columns.get("z") != "Select Column"
):
attrs = {
"pos": [
row[selected_columns["z"]],
row[selected_columns["y"]],
row[selected_columns["x"]],
],
"time": row[selected_columns["t"]],
}
ndims = 4
scale = [1, *scale] # assumes 1 for time step
else:
attrs = {
"pos": [
row[selected_columns["y"]],
row[selected_columns["x"]],
],
"time": row[selected_columns["t"]],
}
ndims = 3
scale = [1, scale[1], scale[2]]

if selected_columns["seg_id"] != "Select Column":
attrs["seg_id"] = row[selected_columns["seg_id"]]

for key in extra_columns:
if extra_columns[key] != "Select Column":
attrs[key] = row[extra_columns[key]]

graph.add_node(_id, **attrs)
parent_id = row[selected_columns["parent_id"]]
if not pd.isna(parent_id):
parent_id = parent_id
if parent_id != -1:
assert parent_id in graph.nodes, f"{_id} {parent_id}"
graph.add_edge(parent_id, _id)

graph.add_node(_id, **attrs)
parent_id = row[selected_columns["parent_id"]].strip()
if parent_id != "":
parent_id = parent_id
if parent_id != -1:
assert parent_id in graph.nodes, f"{_id} {parent_id}"
graph.add_edge(parent_id, _id)
if segmentation is not None:
segmentation = np.expand_dims(
segmentation, axis=1
Expand Down

0 comments on commit 8b51b03

Please sign in to comment.