From efa76fec3961e306364de9cf855edf2ec26c12e5 Mon Sep 17 00:00:00 2001 From: rly Date: Thu, 11 Jul 2024 16:10:32 -0400 Subject: [PATCH] Fix probeinterface converter shape keys --- src/pynwb/ndx_extracellular_channels/io.py | 40 ++- .../test_example_usage_probeinterface.py | 241 ++++++++++++++---- 2 files changed, 227 insertions(+), 54 deletions(-) diff --git a/src/pynwb/ndx_extracellular_channels/io.py b/src/pynwb/ndx_extracellular_channels/io.py index 0044ee6..bb9365f 100644 --- a/src/pynwb/ndx_extracellular_channels/io.py +++ b/src/pynwb/ndx_extracellular_channels/io.py @@ -67,7 +67,24 @@ def from_probeinterface( def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinterface.Probe: """ - Construct a probeinterface.Probe from a ndx_extracellular_channels.Probe + Construct a probeinterface.Probe from a ndx_extracellular_channels.Probe. + + ndx_extracellular_channels.Probe.name -> probeinterface.Probe.name + ndx_extracellular_channels.Probe.identifier -> probeinterface.Probe.serial_number + ndx_extracellular_channels.Probe.probe_model.name -> probeinterface.Probe.model_name + ndx_extracellular_channels.Probe.probe_model.manufacturer -> probeinterface.Probe.manufacturer + ndx_extracellular_channels.Probe.probe_model.ndim -> probeinterface.Probe.ndim + ndx_extracellular_channels.Probe.probe_model.planar_contour_in_um -> probeinterface.Probe.probe_planar_contour + ndx_extracellular_channels.Probe.probe_model.contacts_table["relative_position_in_mm"] -> + probeinterface.Probe.contact_positions + ndx_extracellular_channels.Probe.probe_model.contacts_table["shape"] -> probeinterface.Probe.contact_shapes + ndx_extracellular_channels.Probe.probe_model.contacts_table["contact_id"] -> probeinterface.Probe.contact_ids + ndx_extracellular_channels.Probe.probe_model.contacts_table["device_channel"] -> + probeinterface.Probe.device_channel_indices + ndx_extracellular_channels.Probe.probe_model.contacts_table["shank_id"] -> probeinterface.Probe.shank_ids + ndx_extracellular_channels.Probe.probe_model.contacts_table["plane_axes"] -> probeinterface.Probe.contact_plane_axes + ndx_extracellular_channels.Probe.probe_model.contacts_table["radius_in_um"] -> probeinterface.Probe.contact_shapes["radius"] + Parameters ---------- @@ -89,12 +106,11 @@ def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinter shapes = [] contact_ids = None - shape_params = None shank_ids = None plane_axes = None device_channel_indices = None - possible_shape_keys = ["radius", "width", "height"] + possible_shape_keys = ["radius_in_um", "width_in_um", "height_in_um"] contacts_table = ndx_probe.probe_model.contacts_table positions.append(contacts_table["relative_position_in_mm"][:]) @@ -115,11 +131,6 @@ def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinter if shank_ids is None: shank_ids = [] shank_ids.append(contacts_table["shank_id"][:]) - for possible_shape_key in possible_shape_keys: - if possible_shape_key in contacts_table.colnames: - if shape_params is None: - shape_params = [] - shape_params.append([{possible_shape_key: val} for val in contacts_table[possible_shape_key][:]]) positions = [item for sublist in positions for item in sublist] shapes = [item for sublist in shapes for item in sublist] @@ -128,13 +139,22 @@ def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinter contact_ids = [item for sublist in contact_ids for item in sublist] if plane_axes is not None: plane_axes = [item for sublist in plane_axes for item in sublist] - if shape_params is not None: - shape_params = [item for sublist in shape_params for item in sublist] if shank_ids is not None: shank_ids = [item for sublist in shank_ids for item in sublist] if device_channel_indices is not None: device_channel_indices = [item for sublist in device_channel_indices for item in sublist] + # if there are multiple shape keys, e.g., radius, width, and height + # we need to create a list of dicts, one for each contact + shape_params = [dict() for _ in range(len(contacts_table))] + for i in range(len(contacts_table)): + for possible_shape_key in possible_shape_keys: + if possible_shape_key in contacts_table.colnames: + new_key = possible_shape_key.replace("_in_um", "") + shape_params[i][new_key] = contacts_table[possible_shape_key][i] + + print(shape_params) + probeinterface_probe = probeinterface.Probe( ndim=ndx_probe.probe_model.ndim, si_units="um", diff --git a/src/pynwb/tests/test_example_usage_probeinterface.py b/src/pynwb/tests/test_example_usage_probeinterface.py index d816ee7..e4b92d1 100644 --- a/src/pynwb/tests/test_example_usage_probeinterface.py +++ b/src/pynwb/tests/test_example_usage_probeinterface.py @@ -1,6 +1,7 @@ import datetime import ndx_extracellular_channels import numpy as np +import numpy.testing as npt import probeinterface import pynwb import uuid @@ -31,7 +32,7 @@ def test_from_probeinterface(): polygon = [(-20.0, -30.0), (20.0, -110.0), (60.0, -30.0), (60.0, 190.0), (-20.0, 190.0)] probe0.set_planar_contour(polygon) - probe1 = probeinterface.generate_dummy_probe(elec_shapes="circle") + probe1 = probeinterface.generate_dummy_probe(elec_shapes="circle") # no name set probe1.serial_number = "1000" probe1.model_name = "Dummy Neuropixels 1.0" probe1.manufacturer = "IMEC" @@ -45,6 +46,7 @@ def test_from_probeinterface(): probe2.move([500, -90]) probe3 = probeinterface.generate_dummy_probe(elec_shapes="circle") + probe3.name = "probe3" probe3.serial_number = "1002" probe3.model_name = "Dummy Neuropixels 3.0" probe3.manufacturer = "IMEC" @@ -61,7 +63,8 @@ def test_from_probeinterface(): ndx_probes.extend(model0) model1 = ndx_extracellular_channels.from_probeinterface(probe1, name="probe1") # override name of probe ndx_probes.extend(model1) - group_probes = ndx_extracellular_channels.from_probeinterface(probegroup, name=[None, "probe3"]) + # override name of probe3 + group_probes = ndx_extracellular_channels.from_probeinterface(probegroup, name=[None, "renamed_probe3"]) ndx_probes.extend(group_probes) nwbfile = pynwb.NWBFile( @@ -79,13 +82,13 @@ def test_from_probeinterface(): io.write(nwbfile) # read the file and check the content - with pynwb.NWBHDF5IO("test_probeinterface.nwb", "r", load_namespaces=True) as io: + with pynwb.NWBHDF5IO("test_probeinterface.nwb", "r") as io: nwbfile = io.read() assert set(nwbfile.devices.keys()) == { "probe0", "probe1", "probe2", - "probe3", + "renamed_probe3", "a1x32-edge-5mm-20-177_H32", "Dummy Neuropixels 1.0", "Dummy Neuropixels 2.0", @@ -96,7 +99,7 @@ def test_from_probeinterface(): assert isinstance(nwbfile.devices["probe0"], ndx_extracellular_channels.Probe) assert isinstance(nwbfile.devices["probe1"], ndx_extracellular_channels.Probe) assert isinstance(nwbfile.devices["probe2"], ndx_extracellular_channels.Probe) - assert isinstance(nwbfile.devices["probe3"], ndx_extracellular_channels.Probe) + assert isinstance(nwbfile.devices["renamed_probe3"], ndx_extracellular_channels.Probe) assert isinstance(nwbfile.devices["a1x32-edge-5mm-20-177_H32"], ndx_extracellular_channels.ProbeModel) assert isinstance(nwbfile.devices["Dummy Neuropixels 1.0"], ndx_extracellular_channels.ProbeModel) assert isinstance(nwbfile.devices["Dummy Neuropixels 2.0"], ndx_extracellular_channels.ProbeModel) @@ -107,23 +110,23 @@ def test_from_probeinterface(): assert nwbfile.devices["probe0"].probe_model.name == "a1x32-edge-5mm-20-177_H32" assert nwbfile.devices["probe0"].probe_model.manufacturer == "Neuronexus" assert nwbfile.devices["probe0"].probe_model.ndim == 2 - assert np.all(nwbfile.devices["probe0"].probe_model.planar_contour_in_um == polygon) - assert np.allclose(nwbfile.devices["probe0"].probe_model.contacts_table.relative_position_in_mm, positions) - assert np.all(nwbfile.devices["probe0"].probe_model.contacts_table["shape"].data[:] == "circle") - assert np.all(nwbfile.devices["probe0"].probe_model.contacts_table["radius_in_um"].data[:] == 5.0) + npt.assert_array_equal(nwbfile.devices["probe0"].probe_model.planar_contour_in_um, polygon) + npt.assert_allclose(nwbfile.devices["probe0"].probe_model.contacts_table.relative_position_in_mm, positions) + npt.assert_array_equal(nwbfile.devices["probe0"].probe_model.contacts_table["shape"].data[:], "circle") + npt.assert_array_equal(nwbfile.devices["probe0"].probe_model.contacts_table["radius_in_um"].data[:], 5.0) assert nwbfile.devices["probe1"].name == "probe1" assert nwbfile.devices["probe1"].identifier == "1000" assert nwbfile.devices["probe1"].probe_model.name == "Dummy Neuropixels 1.0" assert nwbfile.devices["probe1"].probe_model.manufacturer == "IMEC" assert nwbfile.devices["probe1"].probe_model.ndim == 2 - assert np.allclose(nwbfile.devices["probe1"].probe_model.planar_contour_in_um, probe1.probe_planar_contour) - assert np.allclose( + npt.assert_allclose(nwbfile.devices["probe1"].probe_model.planar_contour_in_um, probe1.probe_planar_contour) + npt.assert_allclose( nwbfile.devices["probe1"].probe_model.contacts_table.relative_position_in_mm, probe1.contact_positions ) - assert np.all(nwbfile.devices["probe1"].probe_model.contacts_table["shape"].data[:] == "circle") - assert np.all( - nwbfile.devices["probe1"].probe_model.contacts_table["radius_in_um"].data[:] == probe1.to_numpy()["radius"] + npt.assert_array_equal(nwbfile.devices["probe1"].probe_model.contacts_table["shape"].data[:], "circle") + npt.assert_array_equal( + nwbfile.devices["probe1"].probe_model.contacts_table["radius_in_um"].data[:], probe1.to_numpy()["radius"] ) assert nwbfile.devices["probe2"].name == "probe2" @@ -131,40 +134,190 @@ def test_from_probeinterface(): assert nwbfile.devices["probe2"].probe_model.name == "Dummy Neuropixels 2.0" assert nwbfile.devices["probe2"].probe_model.manufacturer == "IMEC" assert nwbfile.devices["probe2"].probe_model.ndim == 2 - assert np.allclose(nwbfile.devices["probe2"].probe_model.planar_contour_in_um, probe2.probe_planar_contour) - assert np.allclose( + npt.assert_allclose(nwbfile.devices["probe2"].probe_model.planar_contour_in_um, probe2.probe_planar_contour) + npt.assert_allclose( nwbfile.devices["probe2"].probe_model.contacts_table.relative_position_in_mm, probe2.contact_positions ) - assert np.all(nwbfile.devices["probe2"].probe_model.contacts_table["shape"].data[:] == "square") - assert np.all( - nwbfile.devices["probe2"].probe_model.contacts_table["width_in_um"].data[:] == probe2.to_numpy()["width"] + npt.assert_array_equal(nwbfile.devices["probe2"].probe_model.contacts_table["shape"].data[:], "square") + npt.assert_array_equal( + nwbfile.devices["probe2"].probe_model.contacts_table["width_in_um"].data[:], probe2.to_numpy()["width"] ) - assert nwbfile.devices["probe3"].name == "probe3" - assert nwbfile.devices["probe3"].identifier == "1002" - assert nwbfile.devices["probe3"].probe_model.name == "Dummy Neuropixels 3.0" - assert nwbfile.devices["probe3"].probe_model.manufacturer == "IMEC" - assert nwbfile.devices["probe3"].probe_model.ndim == 2 - assert np.allclose(nwbfile.devices["probe3"].probe_model.planar_contour_in_um, probe3.probe_planar_contour) - assert np.allclose( - nwbfile.devices["probe3"].probe_model.contacts_table.relative_position_in_mm, probe3.contact_positions + assert nwbfile.devices["renamed_probe3"].name == "renamed_probe3" + assert nwbfile.devices["renamed_probe3"].identifier == "1002" + assert nwbfile.devices["renamed_probe3"].probe_model.name == "Dummy Neuropixels 3.0" + assert nwbfile.devices["renamed_probe3"].probe_model.manufacturer == "IMEC" + assert nwbfile.devices["renamed_probe3"].probe_model.ndim == 2 + npt.assert_allclose( + nwbfile.devices["renamed_probe3"].probe_model.planar_contour_in_um, probe3.probe_planar_contour ) - assert np.all(nwbfile.devices["probe3"].probe_model.contacts_table["shape"].data[:] == "circle") - assert np.all( - nwbfile.devices["probe3"].probe_model.contacts_table["radius_in_um"].data[:] == probe3.to_numpy()["radius"] + npt.assert_allclose( + nwbfile.devices["renamed_probe3"].probe_model.contacts_table.relative_position_in_mm, + probe3.contact_positions, ) + npt.assert_array_equal(nwbfile.devices["renamed_probe3"].probe_model.contacts_table["shape"].data[:], "circle") + npt.assert_array_equal( + nwbfile.devices["renamed_probe3"].probe_model.contacts_table["radius_in_um"].data[:], + probe3.to_numpy()["radius"] + ) + + +def test_to_probeinterface(): + + # create a NWB file with a few probes + nwbfile = pynwb.NWBFile( + session_description="A description of my session", + identifier=str(uuid.uuid4()), + session_start_time=datetime.datetime.now(datetime.timezone.utc), + ) + + # create a probe model + probe_model0 = ndx_extracellular_channels.ProbeModel( + name="a1x32-edge-5mm-20-177_H32", + model="a1x32-edge-5mm-20-177_H32", + manufacturer="Neuronexus", + ndim=2, + planar_contour_in_um=[(-20.0, -30.0), (20.0, -110.0), (60.0, -30.0), (60.0, 190.0), (-20.0, 190.0)], + contacts_table=ndx_extracellular_channels.ContactsTable( + name="contacts_table", + description="a table with electrode contacts", + columns=[ + pynwb.core.VectorData( + name="relative_position_in_mm", + description="the relative position of the contact in mm", + data=[ + (0.0, 0.0), + (0.0, 20.0), + (0.0, 40.0), + (0.0, 60.0), + (0.0, 80.0), + (0.0, 100.0), + (0.0, 120.0), + (0.0, 140.0), + (20.0, 0.0), + (20.0, 20.0), + (20.0, 40.0), + (20.0, 60.0), + (20.0, 80.0), + (20.0, 100.0), + (20.0, 120.0), + (20.0, 140.0), + (40.0, 0.0), + (40.0, 20.0), + (40.0, 40.0), + (40.0, 60.0), + (40.0, 80.0), + (40.0, 100.0), + (40.0, 120.0), + (40.0, 140.0), + ], + ), + pynwb.core.VectorData( + name="shape", + description="the shape of the contact", + data=["circle"] * 24, + ), + pynwb.core.VectorData( + name="radius_in_um", + description="the radius of the contact in um", + data=[5.0] * 24, + ), + ], + ), + ) + + # create a probe + probe0 = ndx_extracellular_channels.Probe( + name="probe0", + identifier="0123", + probe_model=probe_model0, + ) - # for device in nwbfile.devices.values(): - # print("-------------------") - # print(device) - # if isinstance(device, ndx_extracellular_channels.ProbeModel): - # print(device.name) - # print(device.manufacturer) - # print(device.ndim) - # print(device.planar_contour_in_um) - # print(device.contacts_table.to_dataframe()) - # if isinstance(device, ndx_extracellular_channels.Probe): - # pi_probe = ndx_extracellular_channels.to_probeinterface(device) - # print(pi_probe) - - # TODO add more tests for other probeinterface IO functions + pi_probe0 = ndx_extracellular_channels.to_probeinterface(probe0) + assert pi_probe0.ndim == 2 + assert pi_probe0.si_units == "um" + assert pi_probe0.name == "probe0" + assert pi_probe0.serial_number == "0123" + assert pi_probe0.model_name == "a1x32-edge-5mm-20-177_H32" + assert pi_probe0.manufacturer == "Neuronexus" + npt.assert_array_equal(pi_probe0.contact_positions, probe_model0.contacts_table.relative_position_in_mm) + npt.assert_array_equal(pi_probe0.contact_shapes, "circle") + npt.assert_array_equal(pi_probe0.to_numpy()["radius"], 5.0) + + ct2 = ndx_extracellular_channels.ContactsTable( + description="Test contacts table", + ) + + # for testing, mix and match different shapes. np.nan means the radius/width/height does not apply + ct2.add_row( + relative_position_in_mm=[10.0, 10.0], + shape="circle", + contact_id="C1", + shank_id="shank0", + plane_axes=[[0.0, 1.0], [1.0, 0.0]], # TODO make realistic + radius_in_um=10.0, + width_in_um=np.nan, + height_in_um=np.nan, + device_channel=1, + ) + ct2.add_row( + relative_position_in_mm=[20.0, 10.0], + shape="square", + contact_id="C2", + shank_id="shank0", + plane_axes=[[0.0, 1.0], [1.0, 0.0]], # TODO make realistic + radius_in_um=np.nan, + width_in_um=10.0, + height_in_um=10.0, + device_channel=2, + ) + probe_model1 = ndx_extracellular_channels.ProbeModel( + name="Neuropixels 1.0", + description="A neuropixels probe", + model="Neuropixels 1.0", + manufacturer="IMEC", + planar_contour_in_um=[[-10.0, -10.0], [10.0, -10.0], [10.0, 10.0], [-10.0, 10.0]], + contacts_table=ct2, + ) + + # create a probe + probe1 = ndx_extracellular_channels.Probe( + name="probe1", + identifier="7890", + probe_model=probe_model1, + ) + + pi_probe1 = ndx_extracellular_channels.to_probeinterface(probe1) + assert pi_probe1.ndim == 2 + assert pi_probe1.si_units == "um" + assert pi_probe1.name == "probe1" + assert pi_probe1.serial_number == "7890" + assert pi_probe1.model_name == "Neuropixels 1.0" + assert pi_probe1.manufacturer == "IMEC" + npt.assert_array_equal(pi_probe1.contact_positions, probe_model1.contacts_table.relative_position_in_mm) + npt.assert_array_equal(pi_probe1.contact_shapes, ["circle", "square"]) + npt.assert_array_equal(pi_probe1.to_numpy()["radius"], [10.0, np.nan]) + npt.assert_array_equal(pi_probe1.to_numpy()["width"], [np.nan, 10.0]) + npt.assert_array_equal(pi_probe1.to_numpy()["height"], [np.nan, 10.0]) + + # add Probe as NWB Devices + nwbfile.add_device(probe_model0) + nwbfile.add_device(probe0) + + with pynwb.NWBHDF5IO("test_probeinterface.nwb", "w") as io: + io.write(nwbfile) + + # read the file and test whether the read probe can be converted back to probeinterface correctly + with pynwb.NWBHDF5IO("test_probeinterface.nwb", "r") as io: + nwbfile = io.read() + read_probe = nwbfile.devices["probe0"] + pi_probe = ndx_extracellular_channels.to_probeinterface(read_probe) + assert pi_probe.ndim == 2 + assert pi_probe.si_units == "um" + assert pi_probe.name == "probe0" + assert pi_probe.serial_number == "0123" + assert pi_probe.model_name == "a1x32-edge-5mm-20-177_H32" + assert pi_probe.manufacturer == "Neuronexus" + npt.assert_array_equal(pi_probe.contact_positions, probe_model0.contacts_table.relative_position_in_mm) + npt.assert_array_equal(pi_probe.to_numpy()["radius"], 5.0) + npt.assert_array_equal(pi_probe.contact_shapes, "circle")