diff --git a/docker_test/docker_test.py b/docker_test/docker_test.py
new file mode 100644
index 0000000..b289ecf
--- /dev/null
+++ b/docker_test/docker_test.py
@@ -0,0 +1,34 @@
+import spikeextractors as se
+import spikesorters as ss
+
+
+rec, _ = se.example_datasets.toy_example(dumpable=True)
+
+output_folder = "ms4_test_docker"
+
+sorting = ss.run_klusta(rec, output_folder=output_folder, use_docker=True)
+
+print(f"KL found #{len(sorting.get_unit_ids())} units")
+
+
+# output_folder = "kl_test_docker"
+#
+# sorting_KL = ssd.run_klusta(rec, output_folder=output_folder)
+#
+# print(f"KL found #{len(sorting_KL.get_unit_ids())} units")
+#
+# rec, _ = se.example_datasets.toy_example(dumpable=True)
+#
+# output_folder = "sc_test_docker"
+#
+# sorting_SC = ssd.run_spykingcircus(rec, output_folder=output_folder)
+#
+# print(f"SC found #{len(sorting_SC.get_unit_ids())} units")
+#
+# rec, _ = se.example_datasets.toy_example(dumpable=True)
+#
+# output_folder = "hs_test_docker"
+#
+# sorting_HS = ssd.run_herdingspikes(rec, output_folder=output_folder)
+#
+# print(f"HS found #{len(sorting_HS.get_unit_ids())} units")
diff --git a/spikesorters/docker_tools.py b/spikesorters/docker_tools.py
new file mode 100644
index 0000000..723ffee
--- /dev/null
+++ b/spikesorters/docker_tools.py
@@ -0,0 +1,95 @@
+import spikeextractors as se
+import time
+import numpy as np
+from pathlib import Path
+
+ss_folder = Path(__file__).parent
+
+try:
+    import hither2 as hither
+    import docker
+
+    HAVE_DOCKER = True
+
+    default_docker_images = {
+        "klusta": hither.DockerImageFromScript(name="klusta", dockerfile=str(ss_folder / "docker_images" / "v0.12" / "klusta" / "Dockerfile")),
+        "mountainsort4": hither.DockerImageFromScript(name="ms4", dockerfile=str(ss_folder / "docker_images" / "v0.12" / "mountainsort4" / "Dockerfile")),
+        "herdingspikes": hither.LocalDockerImage('spikeinterface/herdingspikes-si-0.12:0.3.7'),
+        "spykingcircus": hither.LocalDockerImage('spikeinterface/spyking-circus-si-0.12:1.0.7')
+    }
+
+except ImportError:
+    HAVE_DOCKER = False
+
+
+# keys of a dumped extractor dictionary that hold the path of the data
+_PATH_KEYS = ("file_path", "folder_path", "file_or_folder_path")
+
+
+def modify_input_folder(dump_dict, input_folder="/input"):
+    """
+    Rewrites the data path of a dumped extractor dictionary so that it points inside 'input_folder'.
+
+    The dictionary is modified in place. If it has a "kwargs" entry, the path is rewritten in that nested
+    dictionary; otherwise the first of "file_path", "folder_path" and "file_or_folder_path" that is found is
+    replaced by '<input_folder>/<last component of the original path>'.
+
+    Parameters
+    ----------
+    dump_dict: dict
+        Dumped extractor dictionary (or its "kwargs" sub-dictionary)
+    input_folder: str
+        Folder that replaces the parent folder of the original path (default '/input')
+
+    Returns
+    -------
+    dump_dict: dict
+        The input dictionary, with the path entry rewritten
+    folder_to_mount: Path
+        Parent folder of the original path
+
+    Raises
+    ------
+    ValueError
+        If none of the supported path keys is found
+    """
+    if "kwargs" in dump_dict:
+        # forward 'input_folder', so that a non-default folder is also applied to the nested dictionary
+        dump_dict["kwargs"], folder_to_mount = modify_input_folder(dump_dict["kwargs"], input_folder)
+        return dump_dict, folder_to_mount
+
+    for path_key in _PATH_KEYS:
+        if path_key in dump_dict:
+            local_path = Path(dump_dict[path_key])
+            folder_to_mount = local_path.parent
+            dump_dict[path_key] = f"{input_folder}/{local_path.relative_to(folder_to_mount)}"
+            return dump_dict, folder_to_mount
+
+    raise ValueError(f"None of {_PATH_KEYS} found in the dumped dictionary")
+
+
+def return_local_data_folder(recording, input_folder='/input'):
+    """
+    Returns a copy of the recording dictionary in which the file_path, folder_path, or file_or_folder path is
+    relative to the 'input_folder', together with the local folder that contains the data
+
+    Parameters
+    ----------
+    recording: se.RecordingExtractor
+        The recording to dump (it must be dumpable)
+    input_folder: str
+
+    Returns
+    -------
+    dump_dict: dict
+        Modified deep copy of the dictionary returned by 'recording.dump_to_dict()'
+    folder_to_mount: Path
+        Parent folder of the original data path
+    """
+    assert recording.is_dumpable
+    from copy import deepcopy
+
+    d = recording.dump_to_dict()
+    dcopy = deepcopy(d)
+
+    return modify_input_folder(dcopy, input_folder)
\ No newline at end of file
diff --git a/spikesorters/run_funtions/__init__.py b/spikesorters/run_funtions/__init__.py
new file mode 100644
index 0000000..741e613
--- /dev/null
+++ b/spikesorters/run_funtions/__init__.py
@@ -0,0 +1 @@
+from .run_functions import _run_sorter_local, _run_sorter_hither
\ No newline at end of file
diff --git a/spikesorters/run_funtions/run_functions.py b/spikesorters/run_funtions/run_functions.py
new file mode 100644
index 0000000..73539b4
--- /dev/null
+++ b/spikesorters/run_funtions/run_functions.py
@@ -0,0 +1,171 @@
+from ..docker_tools import HAVE_DOCKER
+
+# NOTE: 'sorter_dict' and 'sorter_full_list' are imported inside '_run_sorter_local' and not here: 'sorterlist'
+# imports this package before it defines them, so a module-level import fails on the half-initialized module.
+
+
+if HAVE_DOCKER:
+    # conditional definition of hither tools
+    import time
+    from copy import deepcopy
+    from pathlib import Path
+    import hither2 as hither
+    import spikeextractors as se
+    import numpy as np
+    import shutil
+    from ..docker_tools import modify_input_folder, default_docker_images
+
+    class SpikeSortingDockerHook(hither.RuntimeHook):
+        """
+        Runtime hook that binds the input and output folders and sets the docker image of the sorter
+        """
+
+        def __init__(self):
+            super().__init__()
+
+        def precontainer(self, context: hither.PreContainerContext):
+            # this gets run outside the container before the run, and we have a chance to mutate the kwargs,
+            # add bind mounts, and set the image
+            input_directory = context.kwargs['input_directory']
+            output_directory = context.kwargs['output_directory']
+
+            print("Input:", input_directory)
+            print("Output:", output_directory)
+            context.add_bind_mount(hither.BindMount(source=input_directory,
+                                                    target='/input', read_only=True))
+            context.add_bind_mount(hither.BindMount(source=output_directory,
+                                                    target='/output', read_only=False))
+            context.image = default_docker_images[context.kwargs['sorter_name']]
+            # the host folders are replaced by their mount points in the kwargs
+            context.kwargs['output_directory'] = '/output'
+            context.kwargs['input_directory'] = '/input'
+
+
+    @hither.function('run_sorter_docker_with_container',
+                     '0.1.0',
+                     image=True,
+                     runtime_hooks=[SpikeSortingDockerHook()])
+    def run_sorter_docker_with_container(
+            recording_dict, sorter_name, input_directory, output_directory, **kwargs
+    ):
+        # runs 'sorter_name' on the recording loaded from 'recording_dict' and saves the sorting to npz
+        recording = se.load_extractor_from_dict(recording_dict)
+        # set output folder within the container
+        kwargs["output_folder"] = f"{output_directory}/working"
+        t_start = time.time()
+        # run sorter
+        sorting = _run_sorter_local(sorter_name, recording, **kwargs)
+        t_stop = time.time()
+        print(f'{sorter_name} run time {np.round(t_stop - t_start)}s')
+        # save sorting to npz
+        se.NpzSortingExtractor.write_sorting(sorting, f"{output_directory}/sorting_docker.npz")
+
+    def _run_sorter_hither(sorter_name, recording, output_folder=None, delete_output_folder=False,
+                           grouping_property=None, parallel=False, verbose=False, raise_error=True,
+                           n_jobs=-1, joblib_backend='loky', **params):
+        """
+        Runs a sorter in a docker container via hither and returns the sorting loaded from the output folder.
+
+        Takes the same arguments as '_run_sorter_local', but the sorter has to be given by name and the
+        recording has to be dumpable.
+        """
+        assert recording.is_dumpable, "Cannot run not dumpable recordings in docker"
+        if output_folder is None:
+            output_folder = sorter_name + '_output'
+        output_folder = Path(output_folder).absolute()
+        output_folder.mkdir(exist_ok=True, parents=True)
+
+        with hither.Config(use_container=True, show_console=True):
+            # rewrite a copy, so that the dictionary dumped by the recording is never modified in place
+            recording_dict = deepcopy(recording.dump_to_dict())
+            dump_dict_container, input_directory = modify_input_folder(recording_dict, '/input')
+            print(dump_dict_container)
+            # the output folder is only deleted below, after the result is loaded: never by the job itself
+            kwargs = dict(recording_dict=dump_dict_container,
+                          sorter_name=sorter_name,
+                          output_folder=str(output_folder),
+                          delete_output_folder=False,
+                          grouping_property=grouping_property, parallel=parallel,
+                          verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
+                          joblib_backend=joblib_backend)
+
+            kwargs.update(params)
+            kwargs.update({'input_directory': str(input_directory), 'output_directory': str(output_folder)})
+            sorting_job = hither.Job(run_sorter_docker_with_container, kwargs)
+            sorting_job.wait()
+        sorting = se.NpzSortingExtractor(output_folder / "sorting_docker.npz")
+        if delete_output_folder:
+            # NOTE(review): assumes the extractor no longer needs the npz file once it is created - confirm
+            shutil.rmtree(output_folder)
+        return sorting
+else:
+    def _run_sorter_hither(sorter_name, recording, output_folder=None, delete_output_folder=False,
+                           grouping_property=None, parallel=False, verbose=False, raise_error=True,
+                           n_jobs=-1, joblib_backend='loky', **params):
+        raise ImportError("Cannot run sorters in docker: 'hither2' and 'docker' could not be imported")
+
+
+# generic launcher via function approach
+def _run_sorter_local(sorter_name_or_class, recording, output_folder=None, delete_output_folder=False,
+                      grouping_property=None, parallel=False, verbose=False, raise_error=True, n_jobs=-1,
+                      joblib_backend='loky', **params):
+    """
+    Generic function to run a sorter via function approach.
+
+    Two usages with name or class:
+
+    by name:
+       >>> sorting = run_sorter('tridesclous', recording)
+
+    by class:
+       >>> sorting = run_sorter(TridesclousSorter, recording)
+
+    Parameters
+    ----------
+    sorter_name_or_class: str or SorterClass
+        Name or class of the sorter to run
+    recording: RecordingExtractor
+        The recording extractor to be spike sorted
+    output_folder: str or Path
+        Path to output folder
+    delete_output_folder: bool
+        If True, output folder is deleted (default False)
+    grouping_property: str
+        Splits spike sorting by 'grouping_property' (e.g. 'groups')
+    parallel: bool
+        If True and spike sorting is by 'grouping_property', spike sorting jobs are launched in parallel
+    verbose: bool
+        If True, output is verbose
+    raise_error: bool
+        If True, an error is raised if spike sorting fails (default). If False, the process continues and the error is
+        logged in the log file.
+    n_jobs: int
+        Number of jobs when parallel=True (default=-1)
+    joblib_backend: str
+        joblib backend when parallel=True (default='loky')
+    **params: keyword args
+        Spike sorter specific arguments (they can be retrieved with 'get_default_params(sorter_name_or_class)')
+
+    Returns
+    -------
+    sortingextractor: SortingExtractor
+        The spike sorted data
+
+    """
+    # imported here and not at module level: 'sorterlist' itself imports this module while it is loading
+    from ..sorterlist import sorter_dict, sorter_full_list
+
+    if isinstance(sorter_name_or_class, str):
+        SorterClass = sorter_dict[sorter_name_or_class]
+    elif sorter_name_or_class in sorter_full_list:
+        SorterClass = sorter_name_or_class
+    else:
+        raise ValueError('Unknown sorter')
+
+    sorter = SorterClass(recording=recording, output_folder=output_folder, grouping_property=grouping_property,
+                         verbose=verbose, delete_output_folder=delete_output_folder)
+    sorter.set_params(**params)
+    sorter.run(raise_error=raise_error, parallel=parallel, n_jobs=n_jobs, joblib_backend=joblib_backend)
+    sortingextractor = sorter.get_result(raise_error=raise_error)
+
+    return sortingextractor
\ No newline at end of file
diff --git a/spikesorters/sorterlist.py b/spikesorters/sorterlist.py
index 5022507..e378b1d 100644
--- a/spikesorters/sorterlist.py
+++ b/spikesorters/sorterlist.py
@@ -13,6 +13,11 @@ from .yass import YassSorter
 from .combinato import CombinatoSorter
 
 
+
+from .run_funtions import _run_sorter_local, _run_sorter_hither
+from .docker_tools import HAVE_DOCKER
+
+
 sorter_full_list = [
     HDSortSorter,
     KlustaSorter,
@@ -33,10 +38,9 @@ sorter_dict = {s.sorter_name: s for s in sorter_full_list}
 
 
 
-# generic launcher via function approach
 def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_output_folder=False,
-               grouping_property=None, parallel=False, verbose=False, raise_error=True, n_jobs=-1, joblib_backend='loky',
-               **params):
+               grouping_property=None, use_docker=False, parallel=False, verbose=False, raise_error=True, n_jobs=-1,
+               joblib_backend='loky', **params):
     """
     Generic function to run a sorter via function approach.
 
@@ -58,6 +62,8 @@ def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_outpu
         Path to output folder
     delete_output_folder: bool
         If True, output folder is deleted (default False)
+    use_docker: bool
+        If True and docker backend is installed, spike sorting is run in a docker image
     grouping_property: str
         Splits spike sorting by 'grouping_property' (e.g. 'groups')
     parallel: bool
@@ -80,20 +86,26 @@ def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_outpu
         The spike sorted data
 
     """
-    if isinstance(sorter_name_or_class, str):
-        SorterClass = sorter_dict[sorter_name_or_class]
-    elif sorter_name_or_class in sorter_full_list:
-        SorterClass = sorter_name_or_class
+    if use_docker:
+        assert HAVE_DOCKER, "To run in docker, install docker and hither on your system and >>> pip install hither docker"
+
+        # we need sorter name here
+        if isinstance(sorter_name_or_class, str):
+            sorter_name = sorter_name_or_class
+        elif sorter_name_or_class in sorter_full_list:
+            sorter_name = sorter_name_or_class.sorter_name
+        else:
+            raise ValueError('Unknown sorter')
+        sorting = _run_sorter_hither(sorter_name, recording, output_folder=output_folder,
+                                     delete_output_folder=delete_output_folder, grouping_property=grouping_property,
+                                     parallel=parallel, verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
+                                     joblib_backend=joblib_backend, **params)
     else:
-        raise (ValueError('Unknown sorter'))
-
-    sorter = SorterClass(recording=recording, output_folder=output_folder, grouping_property=grouping_property,
-                         verbose=verbose, delete_output_folder=delete_output_folder)
-    sorter.set_params(**params)
-    sorter.run(raise_error=raise_error, parallel=parallel, n_jobs=n_jobs, joblib_backend=joblib_backend)
-    sortingextractor = sorter.get_result(raise_error=raise_error)
-
-    return sortingextractor
+        sorting = _run_sorter_local(sorter_name_or_class, recording, output_folder=output_folder,
+                                    delete_output_folder=delete_output_folder, grouping_property=grouping_property,
+                                    parallel=parallel, verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
+                                    joblib_backend=joblib_backend, **params)
+    return sorting
 
 
 def available_sorters():
@@ -110,6 +122,7 @@ def installed_sorters():
     l = sorted([s.sorter_name for s in sorter_full_list if s.is_installed()])
     return l
 
+
 def print_sorter_versions():
     """
     Prints versions of all installed sorters.