From 3a2fcc9c6f38e0f236e17ef16dca78c6be77d259 Mon Sep 17 00:00:00 2001
From: Ryan Harvey
Date: Mon, 25 Nov 2024 09:24:50 -0500
Subject: [PATCH] Add binary downsampling utilities to neuro_py.raw

Add downsample_binary (zero-phase Butterworth low-pass + decimation) and
downsample_binary_ultrafast (causal FIR evaluated by a parallel numba
kernel) for converting raw .dat recordings to .lfp files.

---
 neuro_py/raw/__init__.pyi     |   4 +-
 neuro_py/raw/preprocessing.py | 152 +++++++++++++++++++++++++++++++++-
 2 files changed, 150 insertions(+), 6 deletions(-)

diff --git a/neuro_py/raw/__init__.pyi b/neuro_py/raw/__init__.pyi
index fccd211..1c091e9 100644
--- a/neuro_py/raw/__init__.pyi
+++ b/neuro_py/raw/__init__.pyi
@@ -1,3 +1,3 @@
-__all__ = ["remove_artifacts"]
+__all__ = ["remove_artifacts", "downsample_binary", "downsample_binary_ultrafast"]
 
-from .preprocessing import remove_artifacts
+from .preprocessing import remove_artifacts, downsample_binary, downsample_binary_ultrafast
diff --git a/neuro_py/raw/preprocessing.py b/neuro_py/raw/preprocessing.py
index 535d2d6..20273bf 100644
--- a/neuro_py/raw/preprocessing.py
+++ b/neuro_py/raw/preprocessing.py
@@ -1,9 +1,11 @@
 import gc
 import os
 import warnings
-from typing import List, Tuple, Optional
+from typing import List, Optional, Tuple
 
+import numba as nb
 import numpy as np
+from scipy.signal import butter, firwin, sosfiltfilt
 
 
 def remove_artifacts(
@@ -96,9 +98,7 @@
                         data[start, ch],
                         data[end, ch],
                         end - start,
-                    ).astype(
-                        data.dtype
-                    )  # Ensure consistent dtype
+                    ).astype(data.dtype)  # Ensure consistent dtype
                     data[start:end, ch] = interpolated
                 else:
                     warnings.warn(
@@ -172,3 +172,147 @@
             f.write(f"Zeroed intervals: {zero_intervals.tolist()}\n")
     except Exception as e:
         warnings.warn(f"Failed to create log file: {e}")
+
+
+def downsample_binary(
+    filepath: str,
+    n_channels: int,
+    original_fs: int = 20000,
+    target_fs: int = 1250,
+    precision: str = "int16",
+    filter_order: int = 4,
+) -> str:
+    """
+    Downsample a flat [samples x channels] binary recording to an .lfp file.
+
+    Applies a zero-phase Butterworth low-pass (sosfiltfilt) chunk by chunk,
+    then decimates by original_fs // target_fs. Returns the output path.
+    """
+    if original_fs % target_fs != 0:
+        raise ValueError(
+            "Original sampling frequency must be an integer multiple of the target frequency."
+        )
+
+    downsample_factor = original_fs // target_fs
+    nyquist = target_fs / 2
+
+    # Design a stable low-pass filter (second-order sections avoid the
+    # numerical issues of high-order transfer-function coefficients)
+    sos = butter(filter_order, nyquist / (original_fs / 2), btype="low", output="sos")
+
+    downsampled_filepath = os.path.splitext(filepath)[0] + ".lfp"
+
+    bytes_size = np.dtype(precision).itemsize
+    # chunk_size must stay a multiple of downsample_factor so the decimation
+    # phase is aligned across chunks; filtering each chunk independently does
+    # leave small edge transients at chunk boundaries.
+    chunk_size = 10_000
+    with open(filepath, "rb") as infile, open(downsampled_filepath, "wb") as outfile:
+        infile.seek(0, 2)
+        n_samples = infile.tell() // (n_channels * bytes_size)
+        infile.seek(0, 0)
+
+        for start_idx in range(0, n_samples, chunk_size):
+            end_idx = min(start_idx + chunk_size, n_samples)
+            n_chunk_samples = end_idx - start_idx
+
+            # Load chunk
+            data = np.fromfile(
+                infile, dtype=precision, count=n_chunk_samples * n_channels
+            )
+            data = data.reshape((n_chunk_samples, n_channels))
+
+            # Filter (zero-phase) and downsample
+            filtered_data = sosfiltfilt(sos, data, axis=0)
+            downsampled_data = filtered_data[::downsample_factor, :]
+
+            # Write to output file
+            downsampled_data.astype(precision).tofile(outfile)
+
+            del data, filtered_data, downsampled_data
+            gc.collect()
+
+    return downsampled_filepath
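+
+
+# A minimal usage sketch for review (hypothetical file name; assumes a flat
+# int16 binary laid out as [samples x channels]):
+#
+#     lfp_path = downsample_binary("amplifier.dat", n_channels=128)
+#     lfp = np.memmap(lfp_path, dtype="int16", mode="r").reshape(-1, 128)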
+
+
+@nb.jit(nopython=True, parallel=True, fastmath=True)
+def filter_and_downsample(data, fir_coeffs, downsample_factor):
+    """
+    JIT-compiled function to filter and downsample data.
+
+    The causal FIR filter is evaluated only at the samples that survive
+    decimation (polyphase-style), written as an explicit loop because
+    numba's np.convolve does not accept a ``mode`` argument.
+    """
+    n_samples, n_channels = data.shape
+    n_taps = fir_coeffs.shape[0]
+    n_output_samples = n_samples // downsample_factor
+    output = np.zeros((n_output_samples, n_channels), dtype=data.dtype)
+
+    for ch in nb.prange(n_channels):
+        for i in range(n_output_samples):
+            # y[n] = sum_k b[k] * x[n - k], evaluated at n = i * downsample_factor
+            n = i * downsample_factor
+            acc = 0.0
+            for k in range(min(n_taps, n + 1)):
+                acc += fir_coeffs[k] * data[n - k, ch]
+            output[i, ch] = acc
+
+    return output
+
+
+def downsample_binary_ultrafast(
+    filepath: str,
+    n_channels: int,
+    original_fs: int = 20000,
+    target_fs: int = 1250,
+    precision: str = "int16",
+    filter_order: int = 64,
+) -> str:
+    """
+    Faster variant of downsample_binary: a causal FIR low-pass evaluated by a
+    parallel numba kernel only at the kept samples.
+
+    Unlike the zero-phase sosfiltfilt version, the single-pass FIR delays the
+    output by filter_order / 2 samples at the original rate.
+    """
+    if original_fs % target_fs != 0:
+        raise ValueError(
+            "Original sampling frequency must be an integer multiple of the target frequency."
+        )
+
+    downsample_factor = original_fs // target_fs
+    nyquist = target_fs / 2
+
+    # Design a linear-phase FIR low-pass filter
+    fir_coeffs = firwin(filter_order + 1, nyquist / (original_fs / 2), pass_zero="lowpass")
+
+    # Output file
+    downsampled_filepath = os.path.splitext(filepath)[0] + ".lfp"
+
+    # Chunked I/O setup; chunk_size must stay a multiple of downsample_factor
+    bytes_size = np.dtype(precision).itemsize
+    chunk_size = 10_000_000  # Process 10M samples at a time for I/O efficiency
+    with open(filepath, "rb") as infile, open(downsampled_filepath, "wb") as outfile:
+        infile.seek(0, 2)
+        n_samples = infile.tell() // (n_channels * bytes_size)
+        infile.seek(0, 0)
+
+        for start_idx in range(0, n_samples, chunk_size):
+            end_idx = min(start_idx + chunk_size, n_samples)
+            n_chunk_samples = end_idx - start_idx
+
+            # Load chunk
+            data = np.fromfile(infile, dtype=precision, count=n_chunk_samples * n_channels)
+            data = data.reshape((n_chunk_samples, n_channels))
+
+            # Filter and downsample (each chunk starts without filter history,
+            # so the first filter_order samples of a chunk carry a transient)
+            downsampled_data = filter_and_downsample(data, fir_coeffs, downsample_factor)
+
+            # Write to output file
+            downsampled_data.astype(precision).tofile(outfile)
+
+            del data, downsampled_data
+            gc.collect()
+
+    return downsampled_filepath
+
+
+if __name__ == "__main__":
+    # Quick timing harness
+    import time
+
+    start = time.time()
+    downsample_binary_ultrafast(
+        filepath=r"U:\data\hpc_ctx_project\HP13\HP13_day1_20241030\HP13_probe_241030_111814\amplifier - Copy.dat",
+        n_channels=128,
+        original_fs=20000,
+        target_fs=1250,
+        precision="int16",
+        filter_order=64,
+    )
+    print(f"Elapsed time: {time.time() - start:.2f} s")
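
A quick sanity check of the numba kernel against a pure-SciPy reference (a
sketch for review, not part of the patch; assumes the patched module imports
as neuro_py.raw.preprocessing, and allows ~1 LSB of slack for the kernel's
int16 truncation):

    import numpy as np
    from scipy.signal import firwin, lfilter
    from neuro_py.raw.preprocessing import filter_and_downsample

    rng = np.random.default_rng(0)
    x = rng.integers(-3000, 3000, size=(20000, 4)).astype(np.int16)
    taps = firwin(65, (1250 / 2) / (20000 / 2), pass_zero="lowpass")

    fast = filter_and_downsample(x, taps, 16)                      # numba kernel
    ref = lfilter(taps, 1.0, x.astype(np.float64), axis=0)[::16]   # causal FIR reference
    assert np.allclose(fast, ref, atol=1.0)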