diff --git a/navis/transforms/elastix.py b/navis/transforms/elastix.py index 743187cd..aa2beff8 100644 --- a/navis/transforms/elastix.py +++ b/navis/transforms/elastix.py @@ -20,6 +20,7 @@ import subprocess import shutil import tempfile +import platform import numpy as np import pandas as pd @@ -27,10 +28,10 @@ from .base import BaseTransform from ..utils import make_iterable -_search_path = [i for i in os.environ['PATH'].split(os.pathsep) if len(i) > 0] +_search_path = [i for i in os.environ["PATH"].split(os.pathsep) if len(i) > 0] -def find_elastixbin(tool: str = 'transformix') -> str: +def find_elastixbin(tool: str = "transformix") -> str: """Find directory with elastix binaries.""" for path in _search_path: path = pathlib.Path(path) @@ -45,7 +46,12 @@ def find_elastixbin(tool: str = 'transformix') -> str: raise -_elastixbin = find_elastixbin() +if platform.system() == "Windows": + # On Windows, we have to search for `transformix.exe` + # We can still invoke it as `transformix` via the command line though + _elastixbin = find_elastixbin("transformix.exe") +else: + _elastixbin = find_elastixbin("transformix") def setup_elastix(): @@ -53,7 +59,7 @@ def setup_elastix(): Briefly: elastix requires the `LD_LIBRARY_PATH` (Linux) or `LDY_LIBRARY_PATH` (OSX) environment variables to (also) point to the directory with the - elastix `lib` directory. For reasons unknown to me, these varibles do not + elastix `lib` directory. For reasons unknown to me, these variables do not make it into the Python session. Hence, we have to set them here explicitly. Above info is based on: https://github.com/jasper-tms/pytransformix @@ -64,18 +70,18 @@ def setup_elastix(): return # Check if this variable already exists - var = os.environ.get('LD_LIBRARY_PATH', os.environ.get('LDY_LIBRARY_PATH', '')) + var = os.environ.get("LD_LIBRARY_PATH", os.environ.get("LDY_LIBRARY_PATH", "")) # Get the actual path - path = (_elastixbin.parent / 'lib').absolute() + path = (_elastixbin.parent / "lib").absolute() if str(path) not in var: - var = f'{path}{os.pathsep}{var}' if var else str(path) + var = f"{path}{os.pathsep}{var}" if var else str(path) # Note that `LD_LIBRARY_PATH` works for both Linux and OSX - os.environ['LD_LIBRARY_PATH'] = var + os.environ["LD_LIBRARY_PATH"] = var # As per navis/issues/112 - os.environ['DYLD_LIBRARY_PATH'] = var + os.environ["DYLD_LIBRARY_PATH"] = var setup_elastix() @@ -83,37 +89,40 @@ def setup_elastix(): def requires_elastix(func): """Check if elastix is available.""" + @functools.wraps(func) def wrapper(*args, **kwargs): if not _elastixbin: - raise ValueError("Could not find elastix binaries. Please download " - "the releases page at https://github.com/SuperElastix/elastix, " - "unzip at a convenient location and add that " - "location to your PATH variable. Note that you " - "will also have to set a LD_LIBRARY_PATH (Linux) " - "or DYLD_LIBRARY_PATH (OSX) variable. See the " - "elastic manual (release page) for details.") + raise ValueError( + "Could not find elastix binaries. Please download " + "the releases page at https://github.com/SuperElastix/elastix, " + "unzip at a convenient location and add that " + "location to your PATH variable. Note that you " + "will also have to set a LD_LIBRARY_PATH (Linux) " + "or DYLD_LIBRARY_PATH (OSX) variable. See the " + "elastic manual (release page) for details." + ) return func(*args, **kwargs) + return wrapper @requires_elastix def elastix_version(as_string=False): """Get elastix version.""" - p = subprocess.run([_elastixbin / 'elastix', '--version'], - capture_output=True) + p = subprocess.run([_elastixbin / "elastix", "--version"], capture_output=True) if p.stderr: - raise BaseException(f'Error running elastix:\n{p.stderr.decode()}') + raise BaseException(f"Error running elastix:\n{p.stderr.decode()}") - version = p.stdout.decode('utf-8').rstrip() + version = p.stdout.decode("utf-8").rstrip() # Extract version from "elastix version: 5.0.1" - version = version.split(':')[-1] + version = version.split(":")[-1] if as_string: return version else: - return tuple(int(v) for v in version.split('.')) + return tuple(int(v) for v in version.split(".")) class ElastixTransform(BaseTransform): @@ -146,44 +155,48 @@ def __init__(self, file: str, copy_files=[]): self.file = pathlib.Path(file) self.copy_files = copy_files - def __eq__(self, other: 'ElastixTransform') -> bool: + def __eq__(self, other: "ElastixTransform") -> bool: """Implement equality comparison.""" if isinstance(other, ElastixTransform): if self.file == other.file: return True return False - def check_if_possible(self, on_error: str = 'raise'): + def check_if_possible(self, on_error: str = "raise"): """Check if this transform is possible.""" if not _elastixbin: - msg = 'Folder with elastix binaries not found. Make sure the ' \ - 'directory is in your PATH environment variable.' - if on_error == 'raise': + msg = ( + "Folder with elastix binaries not found. Make sure the " + "directory is in your PATH environment variable." + ) + if on_error == "raise": raise BaseException(msg) return msg if not self.file.is_file(): - msg = f'Transformation file {self.file} not found.' - if on_error == 'raise': + msg = f"Transformation file {self.file} not found." + if on_error == "raise": raise BaseException(msg) return msg - def copy(self) -> 'ElastixTransform': + def copy(self) -> "ElastixTransform": """Return copy.""" # Attributes not to copy no_copy = [] # Generate new empty transform x = self.__class__(self.file) # Override with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) return x def write_input_file(self, points, filepath): """Write a numpy array in format required by transformix.""" - with open(filepath, 'w') as f: - f.write('point\n{}\n'.format(len(points))) + with open(filepath, "w") as f: + f.write("point\n{}\n".format(len(points))) for x, y, z in points: - f.write(f'{x:f} {y:f} {z:f}\n') + f.write(f"{x:f} {y:f} {z:f}\n") def read_output_file(self, filepath) -> np.ndarray: """Load output file. @@ -200,10 +213,10 @@ def read_output_file(self, filepath) -> np.ndarray: """ points = [] - with open(filepath, 'r') as f: + with open(filepath, "r") as f: for line in f.readlines(): - output = line.split('OutputPoint = [ ')[1].split(' ]')[0] - points.append([float(i) for i in output.split(' ')]) + output = line.split("OutputPoint = [ ")[1].split(" ]")[0] + points.append([float(i) for i in output.split(" ")]) return np.array(points) def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray: @@ -223,16 +236,20 @@ def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray: Transformed points. """ - self.check_if_possible(on_error='raise') + self.check_if_possible(on_error="raise") if isinstance(points, pd.DataFrame): # Make sure x/y/z columns are present - if np.any([c not in points for c in ['x', 'y', 'z']]): - raise ValueError('points DataFrame must have x/y/z columns.') - points = points[['x', 'y', 'z']].values - elif not (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 3): - raise TypeError('`points` must be numpy array of shape (N, 3) or ' - 'pandas DataFrame with x/y/z columns') + if np.any([c not in points for c in ["x", "y", "z"]]): + raise ValueError("points DataFrame must have x/y/z columns.") + points = points[["x", "y", "z"]].values + elif not ( + isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 3 + ): + raise TypeError( + "`points` must be numpy array of shape (N, 3) or " + "pandas DataFrame with x/y/z columns" + ) # Everything happens in a temporary directory with tempfile.TemporaryDirectory() as tempdir: @@ -244,13 +261,21 @@ def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray: _ = pathlib.Path(shutil.copy(f, p)) # Write points to file - in_file = p / 'inputpoints.txt' + in_file = p / "inputpoints.txt" self.write_input_file(points, in_file) - out_file = p / 'outputpoints.txt' + out_file = p / "outputpoints.txt" # Prepare the command - command = [_elastixbin / 'transformix', '-out', str(p), '-tp', str(self.file), '-def', str(in_file)] + command = [ + _elastixbin / "transformix", + "-out", + str(p), + "-tp", + str(self.file), + "-def", + str(in_file), + ] # Keep track of current working directory cwd = os.getcwd() @@ -269,16 +294,18 @@ def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray: os.chdir(cwd) if return_logs: - logfile = p / 'transformix.log' + logfile = p / "transformix.log" if not logfile.is_file(): - raise FileNotFoundError('No log file found.') + raise FileNotFoundError("No log file found.") with open(logfile) as f: logs = f.read() return logs if not out_file.is_file(): - raise FileNotFoundError('Elastix transform did not produce any ' - f'output:\n {proc.stdout.decode()}') + raise FileNotFoundError( + "Elastix transform did not produce any " + f"output:\n {proc.stdout.decode()}" + ) points_xf = self.read_output_file(out_file)