diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..bdb0cabc --- /dev/null +++ b/.gitattributes @@ -0,0 +1,17 @@ +# Auto detect text files and perform LF normalization +* text=auto + +# Custom for Visual Studio +*.cs diff=csharp + +# Standard to msysgit +*.doc diff=astextplain +*.DOC diff=astextplain +*.docx diff=astextplain +*.DOCX diff=astextplain +*.dot diff=astextplain +*.DOT diff=astextplain +*.pdf diff=astextplain +*.PDF diff=astextplain +*.rtf diff=astextplain +*.RTF diff=astextplain diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..1d98e997 --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Windows image file caches +Thumbs.db +ehthumbs.db + +# Folder config file +Desktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Compiled source +*.mexa64 +*.mexw64 +*.asv + +# Windows Installer files +*.cab +*.msi +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# ========================= +# Operating System Files +# ========================= + +# OSX +# ========================= + +.DS_Store +.AppleDouble +.LSOverride + +# Thumbnails +._* + +# Files that might appear on external disk +.Spotlight-V100 +.Trashes + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk diff --git a/CUDA/mexGPUall.m b/CUDA/mexGPUall.m new file mode 100644 index 00000000..37b88a5a --- /dev/null +++ b/CUDA/mexGPUall.m @@ -0,0 +1,15 @@ +% mexGPUall. For these to complete succesfully, you need to configure the +% Matlab GPU library first (see README files for platform-specific +% information) + mexcuda -largeArrayDims mexMPmuFEAT.cu + mexcuda -largeArrayDims mexMPregMU.cu + mexcuda -largeArrayDims mexWtW2.cu + +% mex -largeArrayDims mexMPmuFEAT.cu +% mex -largeArrayDims mexMPregMU.cu +% mex -largeArrayDims mexWtW2.cu + +% If you get uninterpretable errors like "An unexpected error occurred during CUDA execution", and if you are using Pascal GPUs +% (GTX 10xx series), it might be necessary to upgrade to Matlab 2017a / CUDA 8.0. + + diff --git a/CUDA/mexMPmuFEAT.cu b/CUDA/mexMPmuFEAT.cu new file mode 100644 index 00000000..19375805 --- /dev/null +++ b/CUDA/mexMPmuFEAT.cu @@ -0,0 +1,426 @@ +/* + * Example of how to use the mxGPUArray API in a MEX file. This example shows + * how to write a MEX function that takes a gpuArray input and returns a + * gpuArray output, e.g. B=mexFunction(A). + * + * Copyright 2012 The MathWorks, Inc. + */ +#include +#include +#include +#include +#include +#include "mex.h" +#include "gpu/mxGPUArray.h" +#include +#include +#include +using namespace std; + +const int Nthreads = 1024, NchanMax = 128, block = 32, NrankMax = 3; +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ + volatile __shared__ float sW[81*NrankMax], sdata[(Nthreads+81)*NrankMax]; + float x; + int tid, tid0, bid, i, nid, Nrank, NT, Nfilt, nt0; + + tid = threadIdx.x; + bid = blockIdx.x; + Nfilt = (int) Params[1]; + NT = (int) Params[0]; + Nrank = (int) Params[6]; + nt0 = (int) Params[9]; + + if(tid0){ + for (i=0; i Cbest + 1e-6){ + Cbest = Cf; + xb = Ci - lam[i] * mu[i]; // /(lam[i] + 1); + ibest = i; + } + } + if (Cbest > Th*Th){ + err[tid0] = Cbest; + xbest[tid0] = xb; + ftype[tid0] = ibest; + } + } +} +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void cleanup_spikes(const double *Params, const float *xbest, const float *err, + const int *ftype, int *st, int *id, float *x, float *C, int *counter){ + int lockout, indx, maxFR, NTOT, tid, bid, NT, tid0, j; + volatile __shared__ float sdata[Nthreads+2*81+1]; + bool flag=0; + float err0; + + lockout = (int) Params[9] - 1; + tid = threadIdx.x; + bid = blockIdx.x; + + NT = (int) Params[0]; + maxFR = (int) Params[3]; + tid0 = bid * Nthreads; + + + if(tid01e-10){ + flag = 0; + for(j=-lockout;j<=lockout;j++) + if(sdata[tid+lockout+j]>err0){ + flag = 1; + break; + } + if(flag==0){ + indx = atomicAdd(&counter[0], 1); + if (indx=0 & tcurr>>(d_Params, d_data, d_W, d_dout); + for(int k=0;k<(int) Params[4];k++){ + cudaMemset(d_err, 0, NT * sizeof(float)); + cudaMemset(d_ftype, 0, NT * sizeof(int)); + cudaMemset(d_xbest, 0, NT * sizeof(float)); + + // compute the best filter + bestFilter<<>>( d_Params, + d_dout, d_mu, d_lam, d_nu, d_xbest, d_err, d_ftype); + + // ignore peaks that are smaller than another nearby peak + cleanup_spikes<<>>(d_Params, + d_xbest, d_err, d_ftype, d_st, d_id, d_x, d_C, d_counter); + + // add new spikes to 2nd counter + cudaMemcpy(counter, d_counter, sizeof(int), cudaMemcpyDeviceToHost); + if (counter[0]>maxFR){ + counter[0] = maxFR; + cudaMemcpy(d_counter, counter, sizeof(int), cudaMemcpyHostToDevice); + } + + // extract template features before subtraction + extractFEAT<<>>(d_Params, d_st, d_id, + d_x, d_counter, d_dout, d_WtW, d_lam, d_mu,d_feat); + + // subtract the detected spikes + subSpikes<<>>(d_Params, d_st, d_id, + d_x, d_counter, d_dout, d_WtW); + + // update 1st counter from 2nd counter + cudaMemcpy(d_counter+1, d_counter, sizeof(int), cudaMemcpyDeviceToDevice); + + if(counter[0]==maxFR) + break; + } + +// extractFEAT<<>>(d_Params, d_st, d_id, d_x, d_counter, d_dout, d_WtW, d_feat); + + float *x, *C, *feat; + int *st, *id; + int minSize; + if (counter[0] +#include +#include +#include +#include +#include "mex.h" +#include "gpu/mxGPUArray.h" +#include +#include +#include +using namespace std; + +const int Nthreads = 1024, NchanMax = 128, NrankMax = 3; +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ + __shared__ float sW[81*NrankMax], sdata[(Nthreads+81)*NrankMax]; + float x; + int tid, nt0, tid0, bid, i, nid, Nrank, NT, Nfilt; + + tid = threadIdx.x; + bid = blockIdx.x; + Nfilt = (int) Params[1]; + NT = (int) Params[0]; + Nrank = (int) Params[6]; + nt0 = (int) Params[9]; + + if(tid0){ + for (i=0; i Cbest){ + Cbest = Cf; + xb = Ci - mu[i] * lam[i]; /// (lam[i] + 1); + ibest = i; + } + } + if (Cbest > Th*Th){ + err[tid0] = Cbest; + xbest[tid0] = xb; + ftype[tid0] = ibest; + } + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void cleanup_spikes(const double *Params, const float *xbest, + const float *err, const int *ftype, const bool *UtU, int *st, int *id, float *x, + float *C, int *counter, float *nsp){ + int lockout, curr_token, indx, maxFR, Nfilt, NTOT, tid, bid, NT, tid0, j; + volatile __shared__ float sdata[Nthreads+2*81+1]; + volatile __shared__ int id_sh[Nthreads+2*81+1]; + bool flag=0; + float err0; + + lockout = (int) Params[9] - 1; + tid = threadIdx.x; + bid = blockIdx.x; + + NT = (int) Params[0]; + Nfilt = (int) Params[1]; + maxFR = (int) Params[3]; + tid0 = bid * Nthreads; + + if(tid01e-10){ + flag = 0; + for(j=-lockout;j<=lockout;j++) + if(sdata[tid+lockout+j]>err0) + if (UtU[curr_token*Nfilt + id_sh[tid+lockout+j]]){ + flag = 1; + break; + } + if(flag==0){ + indx = atomicAdd(&counter[0], 1); + if (indx>>(d_Params, d_data, d_W, d_dout); + bestFilter<<>>(d_Params, d_dout, d_mu, d_lam, d_nu, + d_xbest, d_err, d_ftype); + cleanup_spikes<<>>(d_Params, d_xbest, d_err, + d_ftype, d_UtU, d_st, d_id, d_x, d_C, d_counter, d_nsp); + + dim3 block(nt0, 1024/nt0); + average_snips<<>>( d_Params, d_st, d_id, d_x, d_counter, d_dataraw, d_dWU); + + cudaMemcpy(counter, d_counter, sizeof(int), cudaMemcpyDeviceToHost); + + plhs[0] = mxGPUCreateMxArrayOnGPU(dWU); + + float *x, *C; + int *st, *id; + int minSize; + if (counter[0] +#include +#include +#include +#include +#include "mex.h" +#include "gpu/mxGPUArray.h" +#include +#include +#include +using namespace std; + +const int Nthreads = 1024, nblock = 32; +////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void crossFilter(const double *Params, const float *W1, const float *W2, + const float *UtU, float *WtW){ + __shared__ float shW1[nblock*81], shW2[nblock*81]; + + float x; + int nt0, tidx, tidy , bidx, bidy, i, NT, Nfilt, t; + + tidx = threadIdx.x; + tidy = threadIdx.y; + bidx = blockIdx.x; + bidy = blockIdx.y; + + Nfilt = (int) Params[1]; + nt0 = (int) Params[9]; + + while(tidx>>(d_Params, d_W1, d_W2, d_UtU, d_WtW); + + plhs[0] = mxGPUCreateMxArrayOnGPU(WtW); + + cudaFree(d_Params); + mxGPUDestroyGPUArray(WtW); + mxGPUDestroyGPUArray(W1); + mxGPUDestroyGPUArray(W2); + mxGPUDestroyGPUArray(UtU); + +} diff --git a/CUDA/mex_CUDA_glnxa64.xml b/CUDA/mex_CUDA_glnxa64.xml new file mode 100644 index 00000000..ebb0d2fb --- /dev/null +++ b/CUDA/mex_CUDA_glnxa64.xml @@ -0,0 +1,75 @@ + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/CUDA/mex_CUDA_win64.xml b/CUDA/mex_CUDA_win64.xml new file mode 100644 index 00000000..b1a9c726 --- /dev/null +++ b/CUDA/mex_CUDA_win64.xml @@ -0,0 +1,198 @@ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Docs/phy_installation_with_templates.txt b/Docs/phy_installation_with_templates.txt new file mode 100644 index 00000000..e177d107 --- /dev/null +++ b/Docs/phy_installation_with_templates.txt @@ -0,0 +1,16 @@ +Follow the instructions at https://github.com/kwikteam/phy to install phy with the template-gui extension. + +Then run + +phy template-gui params.py (--debug, if you get an error) + +where params.py is a text file that KiloSort should produce. In case of errors, check that params.py contains information like this: + +dat_path = '20160330_ap_CAR.dat' % filename +n_channels_dat = 32 % number of channels +dtype = 'int16' +offset = 0 +sample_rate = 30000. +hp_filtered = False + + diff --git a/Docs/readme_mac.txt b/Docs/readme_mac.txt new file mode 100644 index 00000000..547b558a --- /dev/null +++ b/Docs/readme_mac.txt @@ -0,0 +1,52 @@ +***these instructions have not been tested in a long time, and Matlab appears to have simplified this process quite a bit on other platforms. +Try running mexGPUall right after installing everything. +If that fails, the rest of the instructions below might help. Particularly #7 appears required on Mac. *** + +Assuming gpu functions have been correctly compiled (see below), a "master_file.m" is available that you should copy to a local path and change for each of your experiments. +The logic is that the git folder might be updated, and when that happens all extraneous files in that folder will be deleted and any changes you made reverted. + +Mac (instructions from Dan Denman) +(w/ NVIDIA GeForce GTX 970; OS X 10.11.3, Xcode 7.2.1, CUDA 7.5.21) + +1. installed MATLAB R2015b +2. installed parallel computing toolbox license +3. modified ~.bashrc to include: +export CUDA_HOME=/usr/local/cuda +export PATH=/Developer/NVIDIA/CUDA­7.5/bin:$PATH +export DYLD_LIBRARY_PATH=/Developer/NVIDIA/CUDA­7.5/lib:$DYLD_LIBRARY_PATH +4. tried to run mexGPUall.m, MATLAB couldn’t find a compiler +updated Xcode to 7.2.XXX +5. did a bunch more stuff trying to specify MATLAB compiler, including adding this to ~.bashrc. not sure any of this is actually necessary. + +export CUDA_PATH=/usr/local/cuda +export MW_NVCC_PATH=/Developer/NVIDIA/CUDA­7.5/bin/nvcc:$MW_NVCC_PATH + +6. eventually, i found this thread, which led me to switch the ‘mex’ commands in mexGPUall.m to ‘mexcuda’ +I also had to modify nvcc_clang++.xml, which in my MATLAB is here: +/Applications/MATLAB_R2015b.app/toolbox/distcomp/gpu/extern/src/mex/maci64/nvcc_clang++.xml +changed Version = “7.0” to Version = “7.5” for my CUDA install +add some paths so it could find my Xcode install. +I also did this to nvcc_clang++_dynamic.xml. if you ever want more details I can provide. + +7. after that, mexcuda at least found nvcc. but, errors thrown during mexGPUall.m, like this one: + +/Users/administrator/Documents/MATLAB/KiloSort­master/mexWtW.cu:93:32: note: insert an +explicit cast to silence this issue +const mwSize dimsu[] = {Nfilt, Nfilt, ((2 * nt0) ­ 1)}; +^~~~~ +static_cast( ) + +Right, so i inserted static_cast( ) wherever necessary, which was in these places: + +file : line #s + +mexWtW.cu: 93 +mexMPreg.cu: 206,233 +mexMPsub.cu: 282,287,292 +mexMPmuFEAT.cu: 346,356 +mexMPregMU.cu: 231,236,264 +mexWtW2.cu: 97 + +8. ran mexGPUall.m + +...and MEX completed successfully! diff --git a/Docs/readme_win_linux.txt b/Docs/readme_win_linux.txt new file mode 100644 index 00000000..5af9591f --- /dev/null +++ b/Docs/readme_win_linux.txt @@ -0,0 +1,48 @@ +Assuming gpu functions have been correctly compiled (see below), a "master_file.m" is available that you should copy to a local path and change for each of your experiments. The logic is that the git folder might be updated, and when that happens all extraneous files in that folder will be deleted and any changes you made reverted. + +The following are instructions for setting up mex compilation of CUDA files with direct Matlab inputs. Note you need Matlab with the parallel computing toolbox. Instructions below are for Linux and Windows. Mac instructions are in a separate file (I haven't tested them though). If successful, you should be able to run mexGPUall. + +You should be able to run the code on a CPU without compiling any files, but it will be much, much slower than even on 250$ GPUs. For up to 32 channels, the CPU code might be fast enough. + +Windows + +Install Visual Studio Community (2012 or 2013) +Install CUDA 8.0 in Matlab R2017a (and 7.5 in R2016 etc; if in doubt, best to update to latest Matlab and CUDA). If you get a warning of not finding the GPU at the beginning of installation, this might be fine, unless you cannot run mexGPUall, in which case you should try a different combination of Nvidia/CUDA drivers. + +Try to run mexGPUall. If mexcuda gives an error, try the following: copy mex_CUDA_win64.xml (or nvcc_msvc120.xml, or a similarly named file, compatible with your Visual Studio installation version; 11 for 2012 and 12 for 2013) from here +matlabroot/toolbox/distcomp/gpu/extern/src/mex/win64 and into the KiloSort folder (or somewhere in your path). The included file with KiloSort will NOT be compatible with your local environment (unless you start changing paths inside it). + +If your video card is also driving your display, you need to disable the Windows watchdog that kills any process occupying the GPU for too long. +start regedit, +navigate to HKEY_LOCAL_MACHINE\System\CurrentControlSet\Control\GraphicsDrivers +create new DWORD key called TdrLevel, set value to 0, +restart PC. + +If you find that gpuDevice(1); takes more than 2 minutes each time you +run then you dont have the space to store its compiled files +GTX 1080 seems to have this issue. +In order to fix this set an environment variable "CUDA_CACHE_MAXSIZE " on the machine to +some high value like 1GB. By default "CUDA_CACHE_MAXSIZE" is 32MB. +In Windows you can do this in properties > advanced system settings > environment variables. +In order to set the cache to 1GB use CUDA_CACHE_MAXSIZE 1073741824. + + + +Linux + +UPDATE: for recent video cards/drivers, please see Carl Schoonover's instructions here https://groups.google.com/forum/#!topic/phy-users/g0FSHRI0Nao. + +Install CUDA (should ask for a compatible recent version of gcc, will install Nvidia drivers if necessary). + +Try to run mexGPUall. If mexcuda gives you an error, try something along the following lines + +Append to /home/marius/.bashrc then logout/login: +export CUDA_HOME=/usr/local/cuda-7.5 +export LD_LIBRARY_PATH=${CUDA_HOME}/lib64 +export PATH=${CUDA_HOME}/bin:${PATH} + +Copy mex_CUDA_glnxa64.xml (or a similarly named file, compatible with your Visual Studio installation) from here +matlabroot/toolbox/distcomp/gpu/extern/src/mex/ +and into the KiloSort folder (or somewhere in your path). The included file with KiloSort will NOT be compatible with your local environment (unless you start changing paths inside it). + + diff --git a/FwBwAddon/assembleBigDATA.m b/FwBwAddon/assembleBigDATA.m new file mode 100644 index 00000000..46084bdf --- /dev/null +++ b/FwBwAddon/assembleBigDATA.m @@ -0,0 +1,83 @@ +% assemble dataset + +probeName = {'K1', 'K2', 'K3', 'ZNP1', 'ZNP2', 'ZNP3', 'ZNP4', 'ZO'}; + +rootAlignments = '\\zserver.cortexlab.net\Data\Subjects\'; + +nt0 = 3e4; +NT = 2135; + +spks = []; +clu = []; +st = []; + +iMouse = 2; + +mname = {'Waksman', 'Krebs', 'Robbins'}; +datexp = {'2017-06-10', '2017-06-05', '2017-06-13'}; +rootZ = '\\zserver\Data\Subjects'; + +Nmax = 0; +for j = 1:length(probeName) + frootAlign = fullfile(rootAlignments, mname{iMouse}, datexp{iMouse}, 'alignments'); + + ops.dir_rez = 'G:\DATA\Spikes\'; + fname = sprintf('correct_ephys_%s_to_ephys_K1.npy', probeName{j}); + if j>1 + boff = readNPY(fullfile(frootAlign, fname)); + else + boff = [1 0]'; + end + + % save the result file + fname = fullfile(ops.dir_rez, sprintf('rez_%s_%s_%s.mat', mname{iMouse}, ... + datexp{iMouse}, probeName{j})); + load(fname); + + NN = rez.ops.Nfilt; + + t0 = ceil(rez.ops.trange(1) * ops.fs); + + nSpikes = numel(rez.st); + + st0 = t0 + rez.st; + spks(j).st = [st0(:)/ops.fs ones(nSpikes,1)] * boff; + + spks(j).clu = rez.clu(:); + + rez.W(7*374, 1) = 0; + spks(j).W = reshape(gather(rez.W), 7, 374, []); + [~, spks(j).Wheights] = max(sq(sum(spks(j).W.^2,1)), [], 1); + spks(j).wPCA = rez.wPCA; + + ycoords = rez.ycoords(rez.connected>0); + + spks(j).Wheights = ycoords(spks(j).Wheights); + + clu = cat(1, clu, Nmax + rez.clu(:)); + st = cat(1, st, spks(j).st(:)); + + Nmax = Nmax + max(rez.clu); + end + +save(fullfile('G:\DATA\Spikes\', sprintf('spks%s.mat', mname{iMouse})), 'spks') +%% + +S = sparse(max(1, ceil(st - rez.ops.trange(1))), clu, ones(1, numel(clu))); + +Sall = gpuArray(single(full(S))); +Sall = Sall(15:end-15, :); + + +%% +Slow = my_conv2(Sall,500,1); +rat = min(Slow, [], 1) ./max(Slow, [],1); +S0 = Sall(:, rat>.5); + +[U S V] = svdecon(S0 - mean(S0,1)); + +plot(U(:,3)) + +%% + +plot(U(:,4)) \ No newline at end of file diff --git a/FwBwAddon/configFileBench384.m b/FwBwAddon/configFileBench384.m new file mode 100644 index 00000000..bb1b4e0b --- /dev/null +++ b/FwBwAddon/configFileBench384.m @@ -0,0 +1,75 @@ + +ops.GPU = 1; % whether to run this code on an Nvidia GPU (much faster, mexGPUall first) +ops.parfor = 1; % whether to use parfor to accelerate some parts of the algorithm +ops.verbose = 1; % whether to print command line progress +ops.showfigures = 1; % whether to plot figures during optimization + +ops.datatype = 'bin'; % binary ('dat', 'bin') or 'openEphys' +ops.fbinary = 'F:\DATA\Spikes\Diego\smallFile.dat'; % will be created for 'openEphys' +ops.fproc = 'F:\DATA\Spikes\Diego\temp_wh.dat'; % residual from RAM of preprocessed data +ops.root = 'F:\DATA\Spikes\Diego\'; % 'openEphys' only: where raw files are + +ops.fs = 30000; % sampling rate +% ops.NchanTOT = 385; % total number of channels +% ops.Nchan = 344; % number of active channels +ops.Nfilt = 512*3; % number of filters to use (2-4 times more than Nchan, should be a multiple of 32) +ops.nNeighPC = 12; % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12) +ops.nNeigh = 16; % visualization only (Phy): number of neighboring templates to retain projections of (16) + +% options for channel whitening +ops.whitening = 'full'; % type of whitening (default 'full', for 'noSpikes' set options for spike detection below) +ops.nSkipCov = 10; % compute whitening matrix from every N-th batch (1) +ops.whiteningRange = 32; % how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=128) + +% define the channel map as a filename (string) or simply an array +% ops.chanMap = 'F:\DATA\Spikes\Diego\chanMap.mat'; % 'F:\DATA\Spikes\forPRBimecP3opt3.mat'; % make this file using createChannelMapFile.m +ops.chanMap = 'G:\DATA\Spikes\neuropixPhase3A_kilosortChanMap.mat'; +ops.criterionNoiseChannels = 0.2; % fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info). +% ops.chanMap = 1:ops.Nchan; % treated as linear probe if a chanMap file + +% other options for controlling the model and optimization +ops.Nrank = 3; % matrix rank of spike template model (3) +ops.nfullpasses = 6; % number of complete passes through data during optimization (6) +ops.maxFR = 20000; % maximum number of spikes to extract per batch (20000) +ops.fshigh = 150; % frequency for high pass filtering +ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection +ops.scaleproc = 200; % int16 scaling of whitened data +ops.NT = 64*1024+ ops.ntbuff;% this is the batch size (try decreasing if out of memory) +% for GPU should be multiple of 32 + ntbuff + +% the following options can improve/deteriorate results. +% when multiple values are provided for an option, the first two are beginning and ending anneal values, +% the third is the value used in the final pass. +ops.Th = [4 12 12]; % threshold for detecting spikes on template-filtered data ([4 10 10]) +ops.lam = [5 20 20]; % large means amplitudes are forced around the mean ([5 20 20]) +ops.nannealpasses = 4; % should be less than nfullpasses (4) +ops.momentum = 1./[20 400]; % start with high momentum and anneal (1./[20 400]) +ops.shuffle_clusters = 1; % allow merges and splits during optimization (1) +ops.mergeT = .1; % upper threshold for merging (.1) +ops.splitT = .1; % lower threshold for splitting (.1) + +% options for initializing spikes from data +ops.initialize = 'no'; %'fromData' or 'no' +ops.spkTh = -6; % spike threshold in standard deviations (-6) +ops.loc_range = [5 4]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) +ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) +ops.maskMaxChannels = 5; % how many channels to mask up/down ([5]) +ops.crit = .65; % upper criterion for discarding spike repeates (0.65) +ops.nFiltMax = 10000; % maximum "unique" spikes to consider (10000) + +% load predefined principal components (visualization only (Phy): used for features) +dd = load('PCspikes2.mat'); % you might want to recompute this from your own data +ops.wPCA = dd.Wi(:,1:7); % PCs + +% options for posthoc merges (under construction) +ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd) +ops.epu = Inf; + +ops.ForceMaxRAMforDat = 0e9; % maximum RAM the algorithm will try to use; on Windows it will autodetect. + +ops.Drift.chSmooth = 5; +ops.Drift.tSmooth = 20; +ops.doDriftCorrection = 1; +ops.sigDrift = 15; +ops.sigShift = 20; +ops.initialize = 'fromDriftCorrection'; \ No newline at end of file diff --git a/FwBwAddon/driftTopPCs.m b/FwBwAddon/driftTopPCs.m new file mode 100644 index 00000000..33e831f3 --- /dev/null +++ b/FwBwAddon/driftTopPCs.m @@ -0,0 +1,35 @@ + +S = sparse(ceil(rez.st/3e4), rez.clu, ones(1, numel(rez.clu))); +S(1, rez.ops.Nfilt) = 0; + +S = gpuArray(single(full(S))); + +% Shigh = S - mean(S,1); +% Shigh = S - my_conv2(S,500,1); + +Slow = my_conv2(S,500,1); + +rat = min(Slow, [], 1) ./max(Slow, [],1); + +S0 = S(:, rat>.5); + +% [U Sv V] = svdecon(zscore(Shigh, 1, 1)); +% [U Sv V] = svdecon(S0 - mean(S0, 1)); +[U Sv V] = svdecon(zscore(S0,1,1)); +%% +clf +plot(U(:,4)) +%% +imagesc(S0, [0 20]) + +%% + +W = reshape(rez.W, 7, 374, []); + +Wrec = reshape(rez.wPCA * W(:,:), 61, 374, []); + + +imagesc(Wrec(:,:, 244)') + +%% +1 \ No newline at end of file diff --git a/FwBwAddon/extractPCfeatures0.m b/FwBwAddon/extractPCfeatures0.m new file mode 100644 index 00000000..4adb5938 --- /dev/null +++ b/FwBwAddon/extractPCfeatures0.m @@ -0,0 +1,85 @@ +function [uprojDAT, call, muall, tall] = extractPCfeatures0(rez, wPCA) + +ops = rez.ops; +nPC = size(wPCA,2); + +call = zeros(1, 5e6); +muall = zeros(1, 5e6); +tall = zeros(1, 5e6); + +i0 = 0; + +% indices of channels relative to center +nCH = 8; +dc = [-nCH:nCH]; +% dc = -12:15; +%% +Nbatch = rez.temp.Nbatch; + + +NT = ops.NT; +batchstart = 0:NT:NT*Nbatch; + +uprojDAT = zeros(size(wPCA,2) * numel(dc), 1e7, 'single'); + +fid = fopen(ops.fproc, 'r'); +for ibatch = 1:Nbatch + offset = 2 * ops.Nchan*batchstart(ibatch); + fseek(fid, offset, 'bof'); + dat = fread(fid, [NT ops.Nchan], '*int16'); + + if ibatch==1 + ioffset = 0; + else + ioffset = ops.ntbuff; + end + toff = -ioffset + (NT-ops.ntbuff)*(ibatch-1); + + % move data to GPU and scale it + if ops.GPU + dataRAW = gpuArray(dat); + else + dataRAW = dat; + end + dataRAW = single(dataRAW); + dataRAW = dataRAW / ops.scaleproc; + + % find isolated spikes + [row, col, mu] = isolated_peaks_new(dataRAW, ops); + [~, isort] = sort(row); + row = row(isort); + col = col(isort); + mu = mu(isort); + + muall(1,i0 + (1:numel(col))) = gather(mu); + tall(1,i0 + (1:numel(row))) = gather(row-ops.nt0) + toff; + call(1,i0 + (1:numel(col))) = gather(col); + + % get clips from upsampled data + clips = get_SpikeSample(dataRAW, row, col, dc); + + % compute center of mass of each spike and add to height estimate + uS = reshape(wPCA' * clips(:, :), nPC , numel(dc), []); +% uS = permute(uS, [2 1 3]); + uS = reshape(uS, nPC*numel(dc), []); + + if i0+numel(row)>size(uprojDAT,2) + nleft = size(uprojDAT,2) - i0; + uprojDAT(:, i0 + (1:nleft)) = gather_try(uS(:, 1:nleft)); + i0 = i0 + nleft; + break; + end + + uprojDAT(:, i0 + (1:numel(row))) = gather_try(uS); + + i0 = i0 + numel(row); +end + +call(i0+1:end) = []; +tall(i0+1:end) = []; +muall(i0+1:end) = []; +uprojDAT(:, i0+1:end) = []; +fclose(fid); + + + diff --git a/FwBwAddon/extractPCfromSnippets.m b/FwBwAddon/extractPCfromSnippets.m new file mode 100644 index 00000000..252e28b5 --- /dev/null +++ b/FwBwAddon/extractPCfromSnippets.m @@ -0,0 +1,44 @@ +function wPCA = extractPCfromSnippets(rez, nPCs) + +ops = rez.ops; + +% Nchan = ops.Nchan; + +Nbatch = rez.temp.Nbatch; + +NT = ops.NT; +batchstart = 0:NT:NT*Nbatch; + +% extract the PCA projections +CC = zeros(61); +fid = fopen(ops.fproc, 'r'); + +for ibatch = 1:100:Nbatch + offset = 2 * ops.Nchan*batchstart(ibatch); + fseek(fid, offset, 'bof'); + dat = fread(fid, [NT ops.Nchan], '*int16'); + + % move data to GPU and scale it + if ops.GPU + dataRAW = gpuArray(dat); + else + dataRAW = dat; + end + dataRAW = single(dataRAW); + dataRAW = dataRAW / ops.scaleproc; + + + % find isolated spikes + [row, col, mu] = isolated_peaks_new(dataRAW, ops); + + clips = get_SpikeSample(dataRAW, row, col, 0); + + c = sq(clips(:, :)); + CC = CC + gather(c * c')/1e3; + +end +fclose(fid); + +[U Sv V] = svdecon(CC); + +wPCA = U(:, 1:nPCs); diff --git a/FwBwAddon/fwbwClustering.m b/FwBwAddon/fwbwClustering.m new file mode 100644 index 00000000..3b508617 --- /dev/null +++ b/FwBwAddon/fwbwClustering.m @@ -0,0 +1,165 @@ +function rez = fwbwClustering(rez) + +nPCs = rez.ops.nPCs; +Nfilt = rez.ops.Nfilt; +nfullpasses = rez.ops.nfullpasses; + +% extract PC projections here +tic +wPCA = extractPCfromSnippets(rez, nPCs); +fprintf('Obtained 7 PC waveforms in %2.2f seconds \n', toc) + +tic +[uprojDAT, call, muall, tsall] = extractPCfeatures0(rez, wPCA); +fprintf('Extracted %d spikes with %d features in %2.2f seconds \n', ... + size(uprojDAT,2), size(uprojDAT,1), toc) + +tic +nch = 8; +nCHmax = max(call) + nch; + +ioff = nPCs * gpuArray(int32(call - 8)); + +% split into batches for GPU processing +nS = size(uprojDAT,2); % number of spikes +nSamples = 1e4; % number of spikes per batch +[iBatch, nBatches] = makeBatches(nS, nSamples); +for k = 1:nBatches + [~, isort] = sort(call(iBatch{k})); + iBatch{k} = iBatch{k}(isort); +end + +% iBatch = iBatch(randperm(numel(iBatch))); + +[W, Wheights] = initializeW(Nfilt, nCHmax, nPCs); +iorig = 1:Nfilt; + +pm = exp(-1/200); % momentum term +Nnearest = 32; +Params = [nSamples size(uprojDAT,1) Nfilt pm size(W,1) 0 Nnearest]; + +irounds = [1:nBatches nBatches:-1:1]; +Boffset = 0; %ceil(nBatches/2); +niter = nfullpasses * 2*nBatches + Boffset; +ops.lam = [5 30]; +lami = exp(linspace(log(ops.lam(1)), log(ops.lam(2)), niter)); + +flag_resort = 0; +flag_initialized = 0; + +mu = gpuArray.ones(Nfilt, 1, 'single'); +dWU = gpuArray.zeros(size(W), 'single'); +M = NaN * gpuArray.ones(nS,1, 'single'); +iList = gpuArray(int32(ones(Nnearest, Nfilt))); +cfall = zeros(nS, Nnearest, 'single'); +nspfilt = zeros(1, Nfilt, 'single'); +nspmax = zeros(1, Nfilt, 'single'); +Cost = zeros(niter,1); +icl = zeros(nS, 1); + +tauD = 2; + +for i = 1:niter + k = irounds(rem(i-1 + Boffset, 2*nBatches)+1); + + uproj = gpuArray(uprojDAT(:, iBatch{k})); + Params(1) = size(uproj,2); + + % boolean variable: should we compute spike x filter + iW = abs(call(iBatch{k})' - Wheights) < 10; + ioffsub = ioff(iBatch{k}); + + % get iclust and update W + [dWU, iclust, cmax, cf, nsp] = mexClustering(Params, uproj, W, ioffsub, ... + iW, dWU, mu, iList-1); + +% nspfilt(iorig) = exp(-1/tauD) * nspfilt(iorig) + (1 - exp(-1/tauD)) * single(nsp'); +% nspmax = max(nspfilt, nspmax); + +% coefs(iBatch{k}, :) = cmax; + + M(iBatch{k}) = max(cmax, [], 2); + + icl(iBatch{k}) = iorig(iclust+1); + + % update W + if i==100 + flag_resort = 1; + flag_initialized = 1; + end + + if flag_initialized + mu = sum(dWU.^2,1).^.5; + W = dWU./(1e-5 + mu); + Params(6) = lami(i); + if ~flag_resort + cfall(iBatch{k}, :) = cf; + end + end + + if flag_resort + W = reshape(W, nPCs, nCHmax, Nfilt); + nW = sq(sum(W(1, :, :).^2,1)); + W = reshape(W, nCHmax * nPCs, Nfilt); + + [~, Wheights] = max(nW,[], 1); + [Wheights, isort] = sort(Wheights); + iorig = iorig(isort); + W = W(:, isort); + dWU = dWU(:, isort); + end + + + if i==niter-nBatches + flag_resort = 0; + cc = W' * W; + [~, isort] = sort(cc, 1, 'descend'); + iList = int32(gpuArray(isort(1:Nnearest, :))); + end + +% if rem(i,100)==1 +% p = p+1; +% Cost(p) = gather(nanmean(M.^2)); +% plot(Cost(1:p)) +% drawnow +% end +end + + +iresort(iorig) = 1:Nfilt; + +rez.cfall = cfall - cfall(:,1); +rez.st = tsall; +rez.clu = iresort(icl); +rez.wPCA = wPCA; +rez.W = W; +rez.iList = iList; +rez.call = call; + +fprintf('Optimization complete in %2.2f seconds \n', toc) + +%% +% S = sparse(ceil(((tsall -tsall(1))+1)/3e4), icl, ones(1, numel(icl))); +% S(1, Nfilt) = 0; +% +% S = gpuArray(single(full(S))); +% +% % Shigh = S - mean(S,1); +% % Shigh = S - my_conv2(S,500,1); +% +% Slow = my_conv2(S,500,1); +% +% rat = min(Slow, [], 1) ./max(Slow, [],1); +% +% S0 = S(:, rat>.5); +% +% % [U Sv V] = svdecon(zscore(Shigh, 1, 1)); +% [U Sv V] = svdecon(S0 - mean(S0, 1)); +% +% clf +% plot(U(:,1:4)) +% %% +% clear iresort +% iresort(iorig) = 1:Nfilt; +% imagesc(S(:, iresort), [0 20]) +% diff --git a/FwBwAddon/get_SpikeSample.m b/FwBwAddon/get_SpikeSample.m new file mode 100644 index 00000000..54b4aede --- /dev/null +++ b/FwBwAddon/get_SpikeSample.m @@ -0,0 +1,20 @@ +function clips = get_SpikeSample(dataRAW, row, col, dc) + +[nT, nChan] = size(dataRAW); + +% times around the peak to consider +dt = -21 + [1:61]; + +% temporal indices +indsT = repmat(row', numel(dt), 1) + repmat(dt', 1, numel(row)); +indsC = repmat(col', numel(dc), 1) + repmat(dc', 1, numel(col)); + +indsC(indsC<1) = 1; +indsC(indsC>nChan) = nChan; + +indsT = permute(indsT, [1 3 2]); +indsC = permute(indsC, [3 1 2]); +ix = indsT + (indsC-1) * nT; + +% extract only spatial indices within the col index +clips = reshape(dataRAW(ix), numel(dt), numel(dc), numel(row)); diff --git a/FwBwAddon/initializeW.m b/FwBwAddon/initializeW.m new file mode 100644 index 00000000..a9418ccc --- /dev/null +++ b/FwBwAddon/initializeW.m @@ -0,0 +1,14 @@ +function [W, Wheights] = initializeW(Nfilt, nCHmax, nPCs) +W = gpuArray.zeros(7, nCHmax, Nfilt, 'single'); +ic = round(linspace(1, nCHmax, Nfilt)); +W(7*(ic-1)+1 + 7 * nCHmax * [0:1:Nfilt-1]) = -100; +W = reshape(W, nCHmax * nPCs, Nfilt); +W = W + 1 * gpuArray.randn(nCHmax * nPCs, Nfilt, 'single'); +W = normc(W); +W = reshape(W, nPCs, nCHmax, Nfilt); +nW = sq(sum(W(1, :, :).^2,1)); +[~, Wheights] = max(nW,[], 1); +[~, isort] = sort(Wheights); +W = reshape(W, nCHmax * nPCs, Nfilt); +W = W(:, isort); +end \ No newline at end of file diff --git a/FwBwAddon/isolated_peaks.m b/FwBwAddon/isolated_peaks.m new file mode 100644 index 00000000..f4d0b87d --- /dev/null +++ b/FwBwAddon/isolated_peaks.m @@ -0,0 +1,24 @@ +function [row, col, mu] = isolated_peaks(S1, ops) +loc_range = ops.loc_range; +long_range = ops.long_range; +Th = ops.spkTh; +nt0 = ops.nt0; + +% loc_range = [3 1]; +% long_range = [30 6]; +smin = my_min(S1, loc_range, [1 2]); +peaks = single(S1nS) = []; diff --git a/FwBwAddon/masterFwBw.m b/FwBwAddon/masterFwBw.m new file mode 100644 index 00000000..f7095a67 --- /dev/null +++ b/FwBwAddon/masterFwBw.m @@ -0,0 +1,45 @@ +addpath(genpath('D:\CODE\GitHub\KiloSort')) % path to kilosort folder +addpath('D:\CODE\GitHub\FwBwAddon') +addpath('D:\CODE\GitHub\npy-matlab') + +% mexcuda('D:\CODE\GitHub\FwBwAddon\mexClustering.cu'); + +pathToYourConfigFile = 'D:\CODE\GitHub\FwBwAddon'; % take from Github folder and put it somewhere else (together with the master_file) +run(fullfile(pathToYourConfigFile, 'configFileBench384.m')) + +% common options for every probe +ops.chanMap = 'D:\CODE\GitHub\FwBwAddon\neuropixPhase3A_kilosortChanMap_385.mat'; +% ops.trange = [3300 Inf]; % TIME RANGE IN SECONDS TO PROCESS +% ops.trange = [3400 Inf]; % TIME RANGE IN SECONDS TO PROCESS +ops.trange = [3750 Inf]; % TIME RANGE IN SECONDS TO PROCESS + +ops.Nfilt = 512; % how many clusters to use +ops.nfullpasses = 6; % how many forward backward passes to do +ops.nPCs = 7; % how many PCs to project the spikes into + +ops.useRAM = 0; % whether to use RAM for all data, or no data +ops.spkTh = -4; % spike threshold +ops.nSkipCov = 5; % how many batches to skip when computing whitening matrix + +probeName = {'K1', 'K2', 'K3', 'ZNP1', 'ZNP2', 'ZNP3', 'ZNP4', 'ZO'}; +mname = 'Krebs'; %'Waksman'; %'Krebs'; %'Robbins'; +datexp = '2017-06-05'; %'2017-06-10'; %'2017-06-13'; +rootZ = '\\zserver\Data\Subjects'; + +for j = 1:length(probeName) + fname = sprintf('%s_%s_%s_g0_t0.imec.ap_CAR.bin', mname, datexp, probeName{j}); + ops.fbinary = fullfile(rootZ, mname, datexp, sprintf('ephys_%s', probeName{j}), fname); + ops.fproc = 'G:\DATA\Spikes\temp_wh.dat'; % residual from RAM of preprocessed data + ops.dir_rez = 'G:\DATA\Spikes\'; + + % preprocess data + rez = preprocessDataSub(ops); + + + % cluster the threshold crossings + rez = fwbwClustering(rez); + + % save the result file + fname = fullfile(ops.dir_rez, sprintf('rez_%s_%s_%s.mat', mname, datexp, probeName{j})); + save(fname, 'rez', '-v7.3'); +end diff --git a/FwBwAddon/mexClustering.cu b/FwBwAddon/mexClustering.cu new file mode 100644 index 00000000..0ae87fe1 --- /dev/null +++ b/FwBwAddon/mexClustering.cu @@ -0,0 +1,260 @@ +/* + * Example of how to use the mxGPUArray API in a MEX file. This example shows + * how to write a MEX function that takes a gpuArray input and returns a + * gpuArray output, e.g. B=mexFunction(A). + * + * Copyright 2012 The MathWorks, Inc. + */ +#include +#include +#include +#include +#include +#include "mex.h" +#include "gpu/mxGPUArray.h" +#include +#include +#include +using namespace std; + + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void computeCost(const double *Params, const float *uproj, const float *mu, const float *W, + const int *ioff, const bool *iW, float *cmax){ + + int tid, bid, Nspikes, Nfeatures, NfeatW, Nthreads, k; + float xsum = 0.0f, Ci, lam; + + Nspikes = (int) Params[0]; + Nfeatures = (int) Params[1]; + NfeatW = (int) Params[4]; + Nthreads = blockDim.x; + lam = (float) Params[5]; + + tid = threadIdx.x; + bid = blockIdx.x; + + while(tid max_running){ + id[tind] = ind; + max_running = cmax[tind + ind*Nspikes]; + } + + tind += Nblocks*Nthreads; + } +} +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void average_snips(const double *Params, const int *ioff, const int *id, const float *uproj, + const float *cmax, const int *iList, float *cf, float *WU){ + + int tid, bid, ind, Nspikes, Nfeatures, NfeatW, Nnearest, t; + float xsum = 0.0f, pm; + + Nspikes = (int) Params[0]; + Nfeatures = (int) Params[1]; + pm = (float) Params[3]; + NfeatW = (int) Params[4]; + Nnearest = (int) Params[6]; + + tid = threadIdx.x; + bid = blockIdx.x; + + for(ind=0; ind>>(d_Params, d_uproj, d_mu, d_W, d_ioff, + d_iW, d_cmax); + + // loop through cmax to find best template + bestFilter<<<40, 256>>>(d_Params, d_iW, d_cmax, d_id); + + // average all spikes for same template + average_snips<<>>(d_Params, d_ioff, d_id, d_uproj, + d_cmax, d_iList, d_cf, d_dWU); + + count_spikes<<<7, 256>>>(d_Params, d_id, d_nsp); + + // this is how dWU is pushed out + plhs[0] = mxGPUCreateMxArrayOnGPU(dWU); + + // this is how id is pushed out + int *id; + const mwSize dimst[] = {Nspikes,1}; + plhs[1] = mxCreateNumericArray(2, dimst, mxINT32_CLASS, mxREAL); + id = (int*) mxGetData(plhs[1]); + cudaMemcpy(id, d_id, Nspikes * sizeof(int), cudaMemcpyDeviceToHost); + + float *cmax; + const mwSize dimst2[] = {Nspikes,Nfilters}; + plhs[2] = mxCreateNumericArray(2, dimst2, mxSINGLE_CLASS, mxREAL); + cmax = (float*) mxGetData(plhs[2]); + cudaMemcpy(cmax, d_cmax, Nspikes * Nfilters* sizeof(float), cudaMemcpyDeviceToHost); + + float *cf; + const mwSize dimst3[] = {Nspikes,Nnearest}; + plhs[3] = mxCreateNumericArray(2, dimst3, mxSINGLE_CLASS, mxREAL); + cf = (float*) mxGetData(plhs[3]); + cudaMemcpy(cf, d_cf, Nspikes * Nnearest* sizeof(float), cudaMemcpyDeviceToHost); + + int *nsp; + const mwSize dimst4[] = {Nfilters,1}; + plhs[4] = mxCreateNumericArray(2, dimst4, mxINT32_CLASS, mxREAL); + nsp = (int*) mxGetData(plhs[4]); + cudaMemcpy(nsp, d_nsp, Nfilters * sizeof(int), cudaMemcpyDeviceToHost); + + //we are done, clear everything from the GPU + cudaFree(d_Params); + cudaFree(d_cmax); + cudaFree(d_id); + cudaFree(d_cf); + + //do this for the constant variables + mxGPUDestroyGPUArray(uproj); + mxGPUDestroyGPUArray(dWU); + mxGPUDestroyGPUArray(W); + mxGPUDestroyGPUArray(ioff); + mxGPUDestroyGPUArray(iW); + mxGPUDestroyGPUArray(mu); + mxGPUDestroyGPUArray(iList); + + +} diff --git a/FwBwAddon/neuropixPhase3A_kilosortChanMap_385.mat b/FwBwAddon/neuropixPhase3A_kilosortChanMap_385.mat new file mode 100644 index 00000000..b7248b71 Binary files /dev/null and b/FwBwAddon/neuropixPhase3A_kilosortChanMap_385.mat differ diff --git a/FwBwAddon/preprocessDataSub.m b/FwBwAddon/preprocessDataSub.m new file mode 100644 index 00000000..10ff3a84 --- /dev/null +++ b/FwBwAddon/preprocessDataSub.m @@ -0,0 +1,210 @@ +function [rez, DATA] = preprocessDataSub(ops) +tic; +ops.nt0 = getOr(ops, {'nt0'}, 61); + +if ~isempty(ops.chanMap) + if ischar(ops.chanMap) + load(ops.chanMap); + try + chanMapConn = chanMap(connected>1e-6); + xc = xcoords(connected>1e-6); + yc = ycoords(connected>1e-6); + catch + chanMapConn = 1+chanNums(connected>1e-6); + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; + end + ops.Nchan = getOr(ops, 'Nchan', sum(connected>1e-6)); + ops.NchanTOT = getOr(ops, 'NchanTOT', numel(connected)); + if exist('fs', 'var') + ops.fs = getOr(ops, 'fs', fs); + end + else + chanMap = ops.chanMap; + chanMapConn = ops.chanMap; + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; + connected = true(numel(chanMap), 1); + + ops.Nchan = numel(connected); + ops.NchanTOT = numel(connected); + end +else + chanMap = 1:ops.Nchan; + connected = true(numel(chanMap), 1); + + chanMapConn = 1:ops.Nchan; + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; +end +if exist('kcoords', 'var') + kcoords = kcoords(connected); +else + kcoords = ones(ops.Nchan, 1); +end +NchanTOT = ops.NchanTOT; +NT = ops.NT ; + +rez.ops = ops; +rez.xc = xc; +rez.yc = yc; +if exist('xcoords', 'var') + rez.xcoords = xcoords; + rez.ycoords = ycoords; +else + rez.xcoords = xc; + rez.ycoords = yc; +end +rez.connected = connected; +rez.ops.chanMap = chanMap; +rez.ops.kcoords = kcoords; + +d = dir(ops.fbinary); +nTimepoints = floor(d.bytes/NchanTOT/2); + +rez.ops.tstart = ceil(ops.trange(1) * ops.fs); +rez.ops.tend = min(nTimepoints, ceil(ops.trange(2) * ops.fs)); + +rez.ops.sampsToRead = rez.ops.tend-rez.ops.tstart; + +NTbuff = NT + 4*ops.ntbuff; +Nbatch = ceil(rez.ops.sampsToRead /(NT-ops.ntbuff)); + +% by how many bytes to offset all the batches +twind = rez.ops.tstart * NchanTOT*2; + +%% load data into patches, filter, compute covariance +if isfield(ops,'fslow')&&ops.fslow1e4 + break; + end +end +% +nS = nS(1:ncurr); +uBase = uBase(1:ncurr, :); + +[~, itsort] = sort(nS, 'descend'); + +%% initialize U +% compute covariance matrix +sigDrift = 15; +sigShift = 20; +chanDists= bsxfun(@minus, rez.yc, rez.yc').^2 + bsxfun(@minus, rez.xc, rez.xc').^2; +iCovChans = my_inv(exp(-chanDists/(2*sigDrift^2)), 1e-6); + +Nfilt = ops.Nfilt; +lam = ops.lam(1) * ones(Nfilt, 1, 'single'); + +U = gpuArray(uBase(itsort(1:Nfilt), :))'; +mu = sum(U.^2,1)'.^.5; +U = normc(U); +% +deltay = zeros(ops.Nchan, numel(indBatch)); + +for i = 1:10 + % resample spatial masks up and down + Uup = shift_data(reshape(U, ops.Nchan, []), sigShift, rez.yc, rez.xc, iCovChans, sigDrift, rez.Wrot); + Uup = reshape(Uup, size(U)); + Udown = shift_data(reshape(U, ops.Nchan, []), -sigShift, rez.yc, rez.xc, iCovChans, sigDrift, rez.Wrot); + Udown = reshape(Udown, size(U)); + + mu = repmat(mu, 3, 1); + lam = repmat(lam, 3, 1); + + dWU = zeros(Nfilt, nProj, 'single'); + nToT = gpuArray.zeros(Nfilt, 1, 'single'); + Cost = gpuArray(single(0)); + % + for ibatch = 1:numel(indBatch) + % find clusters + clips = gpuArray(uproj(indBatch{ibatch}, :))'; + nSpikes = size(clips,2); + clips = reshape(clips, ops.Nchan, []); + + % resample clips by the delta y + clips = shift_data(clips, deltay(:, ibatch), rez.yc, rez.xc, iCovChans, sigDrift, rez.Wrot); + clips = reshape(clips, size(U,1), [])'; + + ci = clips * [Udown U Uup]; + + ci = bsxfun(@plus, ci, (mu .* lam)'); + cf = bsxfun(@rdivide, ci.^2, 1 + lam'); + cf = bsxfun(@minus, cf, (mu.^2.*lam)'); + + cf = reshape(cf,[], Nfilt, 3); + [Mmax, imax] = max(cf, [], 2); + [~, i3max] = max(Mmax, [], 3); +% keyboard; + + % determine added correction to delta y + + % determine cluster assignment for this iteration + [max_cf, id] = max(cf(:,:,2), [], 2); + id = gather_try(id); + L = gpuArray.zeros(Nfilt, nSpikes, 'single'); + L(id' + [0:Nfilt:(Nfilt*nSpikes-1)]) = 1; + dWU = dWU + L * clips; + nToT = nToT + sum(L, 2); + Cost = Cost + mean(max_cf); + end + dWU = bsxfun(@rdivide, dWU, nToT); + + U = dWU'; + mu = sum(U.^2,1)'.^.5; + U = normc(U); + Cost = Cost/size(inds,2); + + % smooth out corrections + + % disp(Cost) + + % plot(sort(log(1+nToT))) + % drawnow +end +%% +Nchan = ops.Nchan; +Nfilt = ops.Nfilt; +wPCA = ops.wPCA(:,1:3); +Urec = reshape(U, Nchan, size(wPCA,2), Nfilt); + +nt0 = 61; +Urec= permute(Urec, [2 1 3]); +Wrec = reshape(wPCA * Urec(:,:), nt0, Nchan, Nfilt); + +Wrec = gather_try(Wrec); +Nrank = 3; +W = zeros(nt0, Nfilt, Nrank, 'single'); +U = zeros(Nchan, Nfilt, Nrank, 'single'); +for j = 1:Nfilt + [w sv u] = svd(Wrec(:,:,j)); + w = w * sv; + + Sv = diag(sv); + W(:,j,:) = w(:, 1:Nrank)/sum(Sv(1:ops.Nrank).^2).^.5; + U(:,j,:) = u(:, 1:Nrank); +end + +Uinit = U; +Winit = W; +mu = gather_try(single(mu)); +muinit = mu; + +WUinit = zeros(nt0, Nchan, Nfilt); +for j = 1:Nfilt + WUinit(:,:,j) = muinit(j) * Wrec(:,:,j); +end +WUinit = single(WUinit); +%% + + diff --git a/driftCorrection/collectRawClips.m b/driftCorrection/collectRawClips.m new file mode 100644 index 00000000..86271451 --- /dev/null +++ b/driftCorrection/collectRawClips.m @@ -0,0 +1,253 @@ +function [rez, uproj, indBatch] = collectRawClips(ops) +tic; +uproj = []; + + +if ~isempty(ops.chanMap) + if ischar(ops.chanMap) + load(ops.chanMap); + try + chanMapConn = chanMap(connected>1e-6); + xc = xcoords(connected>1e-6); + yc = ycoords(connected>1e-6); + catch + chanMapConn = 1+chanNums(connected>1e-6); + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; + end + else + chanMap = ops.chanMap; + chanMapConn = ops.chanMap; + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; + connected = true(numel(chanMap), 1); + end +else + chanMap = 1:ops.Nchan; + connected = true(numel(chanMap), 1); + + chanMapConn = 1:ops.Nchan; + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; +end +if ~exist('kcoords', 'var') + kcoords = ones(ops.Nchan, 1); +end +NchanTOT = ops.NchanTOT; +NT = ops.NT ; + +rez.xc = xc; +rez.yc = yc; +rez.connected = connected; +rez.ops = ops; +rez.ops.chanMap = chanMap; +rez.ops.kcoords = kcoords; + +d = dir(ops.fbinary); +ops.sampsToRead = floor(d.bytes/NchanTOT/2); + +if ispc + dmem = memory; + memfree = dmem.MemAvailableAllArrays/8; + memallocated = min(ops.ForceMaxRAMforDat, dmem.MemAvailableAllArrays) - memfree; + memallocated = max(0, memallocated); +else + memallocated = ops.ForceMaxRAMforDat; +end +nint16s = memallocated/2; + +NTbuff = NT + 4*ops.ntbuff; +Nbatch = ceil(d.bytes/2/NchanTOT /(NT-ops.ntbuff)); +Nbatch_buff = floor(4/5 * nint16s/ops.Nchan /(NT-ops.ntbuff)); % factor of 4/5 for storing PCs of spikes +Nbatch_buff = min(Nbatch_buff, Nbatch); + +%% load data into patches, filter, compute covariance +if isfield(ops,'fslow')&&ops.fslowsize(uproj,1) + uproj(1e6 + size(uproj,1), 1) = 0; + end + + uproj(i0 + (1:numel(row)), :) = gather_try(uS); + indBatch{ibatch} = i0 + (1:numel(row)); + i0 = i0 + numel(row); + end +end + +if strcmp(ops.initialize, 'fromData') + uproj(i0+1:end, :) = []; +end +Wrot = gather_try(Wrot); +rez.Wrot = Wrot; + +fclose(fid); +if ops.verbose + fprintf('Time %3.2f. Collected clips... \n', toc); +end + + +rez.temp.Nbatch = Nbatch; +rez.temp.Nbatch_buff = Nbatch_buff; + diff --git a/driftCorrection/get_Uupdown.m b/driftCorrection/get_Uupdown.m new file mode 100644 index 00000000..2a4905c0 --- /dev/null +++ b/driftCorrection/get_Uupdown.m @@ -0,0 +1,14 @@ + +function [Uup, Udown] = get_Uupdown(U, space_lag, CovChans, Wrot, ops) + +Uwh = (Wrot * (CovChans/Wrot)) * U; + + +Udown = reshape(U, 2, ops.Nchan/2, [], size(U,2)); +Udown(:,2:end,:,:) = Udown(:,1:end-1,:,:); +Udown = reshape(Udown, size(U)); +Udown = normc(Udown); +Uup = reshape(U, 2, ops.Nchan/2, [], size(U,2)); +Uup(:,1:end-1,:,:) = Uup(:,2:end,:,:); +Uup = reshape(Uup, size(U)); +Uup = normc(Uup); \ No newline at end of file diff --git a/driftCorrection/shift_data.m b/driftCorrection/shift_data.m new file mode 100644 index 00000000..5b467b3c --- /dev/null +++ b/driftCorrection/shift_data.m @@ -0,0 +1,7 @@ +function data = shift_data(data, dy, ycoords, xcoords, iCovChans, sigDrift, Wrot) + +shiftM = shift_matrix(dy, ycoords, xcoords, iCovChans, sigDrift, Wrot); +data = shiftM * data; + + + diff --git a/driftCorrection/shift_matrix.m b/driftCorrection/shift_matrix.m new file mode 100644 index 00000000..70313ed2 --- /dev/null +++ b/driftCorrection/shift_matrix.m @@ -0,0 +1,11 @@ +function shiftM = shift_matrix(dy, ycoords, xcoords, iCovChans, sig, Wrot) + + +yminusy = bsxfun(@minus, ycoords - dy, ycoords').^2 + ... + bsxfun(@minus, xcoords , xcoords').^2; + +newSamp = exp(- yminusy/(2*sig^2)); + +shiftM = Wrot * ((newSamp * iCovChans)/Wrot); + + diff --git a/eMouse/benchmark_simulation.m b/eMouse/benchmark_simulation.m new file mode 100644 index 00000000..711900d6 --- /dev/null +++ b/eMouse/benchmark_simulation.m @@ -0,0 +1,57 @@ +function benchmark_simulation(rez, GTfilepath) + +load(GTfilepath) + +try + testClu = 1 + rez.st3(:,5) ; % if the auto merges were performed + flag = 1; +catch + testClu = rez.st3(:,2) ;% no attempt to merge clusters + flag = 0; +end + +testRes = rez.st3(:,1) ; + +[allScores, allFPrates, allMissRates, allMerges] = ... + compareClustering2(gtClu, gtRes, testClu, testRes, []); + +% +clid = unique(gtClu); +clear gtimes +for k = 1:length(clid) + gtimes{k} = double(gtRes(gtClu==clid(k))); +end +%% + +figure + +plot(sort(cellfun(@(x) x(1), allFPrates)), '-*b', 'Linewidth', 2) +hold all +plot(sort(cellfun(@(x) x(1), allMissRates)), '-*r', 'Linewidth', 2) +plot(sort(cellfun(@(x) x(end), allFPrates)), 'b', 'Linewidth', 2) +plot(sort(cellfun(@(x) x(end), allMissRates)), 'r', 'Linewidth', 2) +ylim([0 1]) +box off + +finalScores = cellfun(@(x) x(end), allScores); +fprintf('%d / %d good cells, score > 0.8 (pre-merge) \n', sum(cellfun(@(x) x(1), allScores)>.8), numel(allScores)) +fprintf('%d / %d good cells, score > 0.8 (post-merge) \n', sum(cellfun(@(x) x(end), allScores)>.8), numel(allScores)) + +nMerges = cellfun(@(x) numel(x)-1, allMerges); +fprintf('Mean merges per good cell %2.2f \n', mean(nMerges(finalScores>.8))) + +% disp(cellfun(@(x) x(end), allScores)) + +xlabel('ground truth cluster') +ylabel('fractional error') + +legend('false positives (initial)', 'miss rates (initial)', 'false positives (best)', 'miss rates (best)') +legend boxoff +set(gca, 'Fontsize', 20) +set(gcf, 'Color', 'w') + +if flag==1 + title('After Kilosort AUTO merges') +else + title('Before Kilosort AUTO merges') +end diff --git a/eMouse/compareClustering2.m b/eMouse/compareClustering2.m new file mode 100644 index 00000000..938b28f0 --- /dev/null +++ b/eMouse/compareClustering2.m @@ -0,0 +1,141 @@ + + +function [allScores, allFPs, allMisses, allMerges] = compareClustering2(cluGT, resGT, cluTest, resTest, datFilename) +% function compareClustering(cluGT, resGT, cluTest, resTest[, datFilename]) +% - clu and res variables are length nSpikes, for ground truth (GT) and for +% the clustering to be evaluated (Test). + + +if nargin<5 + datFilename = []; +end + +GTcluIDs = unique(cluGT); +testCluIDs = unique(cluTest); +jitter = 12; + +nSp = zeros(max(testCluIDs), 1); +for j = 1:max(testCluIDs); + nSp(j) = max(1, sum(cluTest==j)); +end +nSp0 = nSp; + +for cGT = 1:length(GTcluIDs) +% fprintf(1,'ground truth cluster ID = %d (%d spikes)\n', GTcluIDs(cGT), sum(cluGT==GTcluIDs(cGT))); + + rGT = int32(resGT(cluGT==GTcluIDs(cGT))); + +% S = sparse(numel(rGT), max(testCluIDs)); + S = spalloc(numel(rGT), max(testCluIDs), numel(rGT) * 10); + % find the initial best match + mergeIDs = []; + scores = []; + falsePos = []; + missRate = []; + + igt = 1; + + nSp = nSp0; + nrGT = numel(rGT); + flag = false; + for j = 1:numel(cluTest) + while (resTest(j) > rGT(igt) + jitter) + % the curent spikes is now too large compared to GT, advance the GT + igt = igt + 1; + if igt>nrGT + flag = true; + break; + end + end + if flag + break; + end + + if resTest(j)>rGT(igt)-jitter + % we found a match, add a tick to the right cluster +% numMatch(cluTest(j)) = numMatch(cluTest(j)) + 1; + S(igt, cluTest(j)) = 1; + end + end + numMatch = sum(S,1)'; + misses = (nrGT-numMatch)/nrGT; % missed these spikes, as a proportion of the total true spikes + fps = (nSp-numMatch)./nSp; % number of comparison spikes not near a GT spike, as a proportion of the number of guesses + % + % for cTest = 1:length(testCluIDs) +% rTest = int32(resTest(cluTest==testCluIDs(cTest))); +% +% [miss, fp] = compareSpikeTimes(rTest, rGT); +% misses(cTest) = miss; +% fps(cTest) = fp; +% +% end +% + sc = 1-(fps+misses); + best = find(sc==max(sc),1); + mergeIDs(end+1) = best; + scores(end+1) = sc(best); + falsePos(end+1) = fps(best); + missRate(end+1) = misses(best); + +% fprintf(1, ' found initial best %d: score %.2f (%d spikes, %.2f FP, %.2f miss)\n', ... +% mergeIDs(1), scores(1), sum(cluTest==mergeIDs(1)), fps(best), misses(best)); + + S0 = S(:, best); + nSp = nSp + nSp0(best); + while scores(end)>0 && (length(scores)==1 || ( scores(end)>(scores(end-1) + 1*0.01) && scores(end)<=0.99 )) + % find the best match + S = bsxfun(@max, S, S0); + + numMatch = sum(S,1)'; + misses = (nrGT-numMatch)/nrGT; % missed these spikes, as a proportion of the total true spikes + fps = (nSp-numMatch)./nSp; % number of comparison spikes not near a GT spike, as a proportion of the number of guesses + + sc = 1-(fps+misses); + best = find(sc==max(sc),1); + mergeIDs(end+1) = best; + scores(end+1) = sc(best); + falsePos(end+1) = fps(best); + missRate(end+1) = misses(best); + +% fprintf(1, ' best merge with %d: score %.2f (%d/%d new/total spikes, %.2f FP, %.2f miss)\n', ... +% mergeIDs(end), scores(end), nSp0(best), nSp(best), fps(best), misses(best)); + + S0 = S(:, best); + nSp = nSp + nSp0(best); + + end + + if length(scores)==1 || scores(end)>(scores(end-1)+0.01) + % the last merge did help, so include it + allMerges{cGT} = mergeIDs(1:end); + allScores{cGT} = scores(1:end); + allFPs{cGT} = falsePos(1:end); + allMisses{cGT} = missRate(1:end); + else + % the last merge actually didn't help (or didn't help enough), so + % exclude it + allMerges{cGT} = mergeIDs(1:end-1); + allScores{cGT} = scores(1:end-1); + allFPs{cGT} = falsePos(1:end-1); + allMisses{cGT} = missRate(1:end-1); + end + +end + +initScore = zeros(1, length(GTcluIDs)); +finalScore = zeros(1, length(GTcluIDs)); +numMerges = zeros(1, length(GTcluIDs)); +fprintf(1, '\n\n--Results Summary--\n') +for cGT = 1:length(GTcluIDs) +% +% fprintf(1,'ground truth cluster ID = %d (%d spikes)\n', GTcluIDs(cGT), sum(cluGT==GTcluIDs(cGT))); +% fprintf(1,' initial score: %.2f\n', allScores{cGT}(1)); +% fprintf(1,' best score: %.2f (after %d merges)\n', allScores{cGT}(end), length(allScores{cGT})-1); +% + initScore(cGT) = allScores{cGT}(1); + finalScore(cGT) = allScores{cGT}(end); + numMerges(cGT) = length(allScores{cGT})-1; +end + +fprintf(1, 'median initial score: %.2f; median best score: %.2f\n', median(initScore), median(finalScore)); +fprintf(1, 'total merges required: %d\n', sum(numMerges)); diff --git a/eMouse/config_eMouse.m b/eMouse/config_eMouse.m new file mode 100644 index 00000000..16b15ecf --- /dev/null +++ b/eMouse/config_eMouse.m @@ -0,0 +1,65 @@ +clear ops +ops.GPU = useGPU; % whether to run this code on an Nvidia GPU (much faster, mexGPUall first) +ops.parfor = 0; % whether to use parfor to accelerate some parts of the algorithm +ops.verbose = 1; % whether to print command line progress +ops.showfigures = 1; % whether to plot figures during optimization + +ops.datatype = 'dat'; % binary ('dat', 'bin') or 'openEphys' +ops.fbinary = fullfile(fpath, 'sim_binary.dat'); % will be created for 'openEphys' +ops.fproc = fullfile(fpath, 'temp_wh.dat'); % residual from RAM of preprocessed data +ops.root = fpath; % 'openEphys' only: where raw files are +% define the channel map as a filename (string) or simply an array +ops.chanMap = fullfile(fpath, 'chanMap.mat'); % make this file using createChannelMapFile.m +% ops.chanMap = 1:ops.Nchan; % treated as linear probe if unavailable chanMap file + +ops.Nfilt = 64; % number of clusters to use (2-4 times more than Nchan, should be a multiple of 32) +ops.nNeighPC = 12; % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12) +ops.nNeigh = 16; % visualization only (Phy): number of neighboring templates to retain projections of (16) + +% options for channel whitening +ops.whitening = 'full'; % type of whitening (default 'full', for 'noSpikes' set options for spike detection below) +ops.nSkipCov = 1; % compute whitening matrix from every N-th batch (1) +ops.whiteningRange = 32; % how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32) + +ops.criterionNoiseChannels = 0.2; % fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info). + +% other options for controlling the model and optimization +ops.Nrank = 3; % matrix rank of spike template model (3) +ops.nfullpasses = 6; % number of complete passes through data during optimization (6) +ops.maxFR = 20000; % maximum number of spikes to extract per batch (20000) +ops.fshigh = 200; % frequency for high pass filtering +% ops.fslow = 2000; % frequency for low pass filtering (optional) +ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection +ops.scaleproc = 200; % int16 scaling of whitened data +ops.NT = 128*1024+ ops.ntbuff;% this is the batch size (try decreasing if out of memory) +% for GPU should be multiple of 32 + ntbuff + +% the following options can improve/deteriorate results. +% when multiple values are provided for an option, the first two are beginning and ending anneal values, +% the third is the value used in the final pass. +ops.Th = [4 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12]) +ops.lam = [5 5 5]; % large means amplitudes are forced around the mean ([10 30 30]) +ops.nannealpasses = 4; % should be less than nfullpasses (4) +ops.momentum = 1./[20 400]; % start with high momentum and anneal (1./[20 1000]) +ops.shuffle_clusters = 1; % allow merges and splits during optimization (1) +ops.mergeT = .1; % upper threshold for merging (.1) +ops.splitT = .1; % lower threshold for splitting (.1) + +% options for initializing spikes from data +ops.initialize = 'fromData'; %'fromData' or 'no' +ops.spkTh = -6; % spike threshold in standard deviations (4) +ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) +ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) +ops.maskMaxChannels = 5; % how many channels to mask up/down ([5]) +ops.crit = .65; % upper criterion for discarding spike repeates (0.65) +ops.nFiltMax = 10000; % maximum "unique" spikes to consider (10000) + +% load predefined principal components (visualization only (Phy): used for features) +dd = load('PCspikes2.mat'); % you might want to recompute this from your own data +ops.wPCA = dd.Wi(:,1:7); % PCs + +% options for posthoc merges (under construction) +ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd) +ops.epu = Inf; + +ops.ForceMaxRAMforDat = 20e9; % maximum RAM the algorithm will try to use; on Windows it will autodetect. diff --git a/eMouse/make_eMouseChannelMap.m b/eMouse/make_eMouseChannelMap.m new file mode 100644 index 00000000..74b006e8 --- /dev/null +++ b/eMouse/make_eMouseChannelMap.m @@ -0,0 +1,47 @@ +function make_eMouseChannelMap(fpath) +% create a channel Map file for simulated data (eMouse) + +% here I know a priori what order my channels are in. So I just manually +% make a list of channel indices (and give +% an index to dead channels too). chanMap(1) is the row in the raw binary +% file for the first channel. chanMap(1:2) = [33 34] in my case, which happen to +% be dead channels. + +chanMap = [33 34 8 10 12 14 16 18 20 22 24 26 28 30 32 ... + 7 9 11 13 15 17 19 21 23 25 27 29 31 1 2 3 4 5 6]; + +% the first thing Kilosort does is reorder the data with data = data(chanMap, :). +% Now we declare which channels are "connected" in this normal ordering, +% meaning not dead or used for non-ephys data + +connected = true(34, 1); connected(1:2) = 0; + +% now we define the horizontal (x) and vertical (y) coordinates of these +% 34 channels. For dead or nonephys channels the values won't matter. Again +% I will take this information from the specifications of the probe. These +% are in um here, but the absolute scaling doesn't really matter in the +% algorithm. + +xcoords = 20 * [NaN NaN 1 0 0 1 0 1 0 1 0 1 0 1 0 1 0 1 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0]; +ycoords = 20 * [NaN NaN 7 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 16 ... + 17 17 18 18 19 19 20 20 21 21 22 22 23 23 24]; + +% Often, multi-shank probes or tetrodes will be organized into groups of +% channels that cannot possibly share spikes with the rest of the probe. This helps +% the algorithm discard noisy templates shared across groups. In +% this case, we set kcoords to indicate which group the channel belongs to. +% In our case all channels are on the same shank in a single group so we +% assign them all to group 1. + +kcoords = [NaN NaN 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]; + +% at this point in Kilosort we do data = data(connected, :), ycoords = +% ycoords(connected), xcoords = xcoords(connected) and kcoords = +% kcoords(connected) and no more channel map information is needed (in particular +% no "adjacency graphs" like in KlustaKwik). +% Now we can save our channel map for the eMouse. + +% would be good to also save the sampling frequency here +fs = 25000; + +save(fullfile(fpath, 'chanMap.mat'), 'chanMap', 'connected', 'xcoords', 'ycoords', 'kcoords', 'fs') \ No newline at end of file diff --git a/eMouse/make_eMouseData.m b/eMouse/make_eMouseData.m new file mode 100644 index 00000000..c9ce4bdd --- /dev/null +++ b/eMouse/make_eMouseData.m @@ -0,0 +1,105 @@ +function make_eMouseData(fpath, useGPU) +% this script makes binary file of simulated eMouse recording + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% you can play with the parameters just below here to achieve a signal more similar to your own data!!! +mu_mean = 15; % mean of mean spike amplitudes. 15 should contain enough clustering errors to be instructive (in Phy). 20 is good quality data, 10 will miss half the neurons, +nn = 30; % number of simulated neurons (30) +t_record = 1000; % duration in seconds of simulation. longer is better (and slower!) (1000) +fr_bounds = [1 10]; % min and max of firing rates ([1 10]) +tsmooth = 3; % gaussian smooth the noise with sig = this many samples (increase to make it harder) (3) +chsmooth = 1; % smooth the noise across channels too, with this sig (increase to make it harder) (1) +amp_std = .25; % standard deviation of single spike amplitude variability (increase to make it harder, technically std of gamma random variable of mean 1) (.25) +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +rng('default'); +rng(101); % set the seed of the random number generator + +dat = load('simulation_parameters'); % these are inside Kilosort repo +fs = dat.fs; % sampling rate +wav = dat.waves; % mean waveforms for all neurons +wav = wav(:,:, randperm(size(wav,3), nn)); + +Nchan = numel(dat.xc) + 2; % we add two fake dead channels +NN = size(wav,3); % number of neurons + +chanMap = [33 34 8 10 12 14 16 18 20 22 24 26 28 30 32 ... + 7 9 11 13 15 17 19 21 23 25 27 29 31 1 2 3 4 5 6]; % this is the fake channel map I made + +invChanMap(chanMap) = [1:34]; % invert the channel map here + +mu = mu_mean * (1 + (rand(NN,1) - 0.5)); % create variability in mean amplitude +fr = fr_bounds(1) + (fr_bounds(2)-fr_bounds(1)) * rand(NN,1); % create variability in firing rates + +% totfr = sum(fr); % total firing rate + +spk_times = []; +clu = []; +for j = 1:length(fr) + dspks = int64(geornd(1/(fs/fr(j)), ceil(2*fr(j)*t_record),1)); + dspks(dspks0 + enoise(1:buff, :) = enoise_old(NT-buff + [1:buff], :); + end + + dat = enoise; + dat = my_conv2(dat, [tsmooth chsmooth], [1 2]); + dat = zscore(dat, 1, 1); + dat = gather_try(dat); + + if t_all>0 + dat(1:buff/2, :) = dat_old(NT-buff/2 + [1:buff/2], :); + end + + dat(:, [1 2]) = 0; % these are the "dead" channels + + % now we add spikes on non-dead channels. + ibatch = (spk_times >= t_all*fs) & (spk_times < t_all*fs+NT-buff); + ts = spk_times(ibatch) - t_all*fs; + ids = clu(ibatch); + am = amps(ibatch); + + for i = 1:length(ts) + dat(ts(i) + int64([1:82]), 2 + [1:32]) = dat(ts(i) + int64([1:82]), 2 + [1:32]) +... + mu(ids(i)) * am(i) * wav(:,:,ids(i)); + end + + dat_old = dat; + dat = int16(200 * dat); + fwrite(fidW, dat(1:(NT-buff),invChanMap)', 'int16'); + t_all = t_all + (NT-buff)/fs; + + enoise_old = enoise; +end + +fclose(fidW); % all done + +gtRes = spk_times + 42; % add back the time of the peak for the templates (answer to life and everything) +gtClu = clu; + +save(fullfile(fpath, 'eMouseGroundTruth'), 'gtRes', 'gtClu') \ No newline at end of file diff --git a/eMouse/master_eMouse.m b/eMouse/master_eMouse.m new file mode 100644 index 00000000..b99809b8 --- /dev/null +++ b/eMouse/master_eMouse.m @@ -0,0 +1,64 @@ +useGPU = 1; % do you have a GPU? Kilosorting 1000sec of 32chan simulated data takes 55 seconds on gtx 1080 + M2 SSD. + +fpath = 'F:\DATA\Spikes\eMouse\'; % where on disk do you want the simulation? ideally and SSD... +if ~exist(fpath, 'dir'); mkdir(fpath); end + +% This part adds paths +addpath(genpath('D:\CODE\GitHub\KiloSort')) % path to kilosort folder +addpath(genpath('D:\CODE\GitHub\npy-matlab')) % path to npy-matlab scripts +pathToYourConfigFile = 'D:\CODE\GitHub\KiloSort\eMouse'; % for this example it's ok to leave this path inside the repo, but for your own config file you *must* put it somewhere else! + +% Run the configuration file, it builds the structure of options (ops) +run(fullfile(pathToYourConfigFile, 'config_eMouse.m')) + +% This part makes the channel map for this simulation +make_eMouseChannelMap(fpath); + +% This part simulates and saves data. There are many options you can change inside this +% function, if you want to vary the SNR or firing rates, or number of cells etc. +% You can vary these to make the simulated data look more like your data. +% Currently it is set to relatively low SNR for illustration purposes in Phy. +make_eMouseData(fpath, useGPU); +% +% This part runs the normal Kilosort processing on the simulated data +[rez, DATA, uproj] = preprocessData(ops); % preprocess data and extract spikes for initialization +rez = fitTemplates(rez, DATA, uproj); % fit templates iteratively +rez = fullMPMU(rez, DATA);% extract final spike times (overlapping extraction) + +% This runs the benchmark script. It will report both 1) results for the +% clusters as provided by Kilosort (pre-merge), and 2) results after doing the best +% possible merges (post-merge). This last step is supposed to +% mimic what a user would do in Phy, and is the best achievable score +% without doing splits. +benchmark_simulation(rez, fullfile(fpath, 'eMouseGroundTruth.mat')); + +% save python results file for Phy +rezToPhy(rez, fpath); + +fprintf('Kilosort took %2.2f seconds vs 72.77 seconds on GTX 1080 + M2 SSD \n', toc) + +% now fire up Phy and check these results. There should still be manual +% work to be done (mostly merges, some refinements of contaminated clusters). +%% AUTO MERGES +% after spending quite some time with Phy checking on the results and understanding the merge and split functions, +% come back here and run Kilosort's automated merging strategy. This block +% will overwrite the previous results and python files. Load the results in +% Phy again: there should be no merges left to do (with the default simulation), but perhaps a few splits +% / cleanup. On realistic data (i.e. not this simulation) there will be drift also, which will usually +% mean there are merges left to do even after this step. +% Kilosort's AUTO merges should not be confused with the "best" merges done inside the +% benchmark (those are using the real ground truth!!!) + +rez = merge_posthoc2(rez); +benchmark_simulation(rez, fullfile(fpath, 'eMouseGroundTruth.mat')); + +% save python results file for Phy +rezToPhy(rez, fpath); + +%% save and clean up +% save matlab results file for future use (although you should really only be using the manually validated spike_clusters.npy file) +save(fullfile(fpath, 'rez.mat'), 'rez', '-v7.3'); + +% remove temporary file +delete(ops.fproc); +%% diff --git a/eMouse/readme_eMouse.txt b/eMouse/readme_eMouse.txt new file mode 100644 index 00000000..3e05a643 --- /dev/null +++ b/eMouse/readme_eMouse.txt @@ -0,0 +1,3 @@ +We made some scripts to generate artificial data with similar statistics to real recordings. Run master_eMouse to verify that Kilosort has been installed correctly, as well as to understand how the various options are being passed in and used, and what you should be seeing in Phy after running Kilosort on your own data. This example has been made intentionally hard, so that the are still some clustering errors which you can visualize (and correct for!) in Phy. + +The scripts also measure the accuracy of the algorithm at two different stages (before and after the AUTO merges). The accuracy is measured both in terms of the clusters found by Kilosort, and in terms of the best achievable accuracy after merging those clusters optimally (with knowledge of the ground truth). The last set of results is supposed to mimic the kind of results you would get if you were only doing merges in Phy, and you knew exactly what merges are best. \ No newline at end of file diff --git a/eMouse/simulation_parameters.mat b/eMouse/simulation_parameters.mat new file mode 100644 index 00000000..a2f0b4df Binary files /dev/null and b/eMouse/simulation_parameters.mat differ diff --git a/finalPass/cpuMPmuFEAT.m b/finalPass/cpuMPmuFEAT.m new file mode 100644 index 00000000..09d638e5 --- /dev/null +++ b/finalPass/cpuMPmuFEAT.m @@ -0,0 +1,79 @@ +function [sts, ids, xs, Costs, cprojall] = cpuMPmuFEAT(Params,data,fW,WtW, mu, lam1, nu, ops) + +nt0 = ops.nt0; + +WtW = permute(WtW, [1 3 2]); + +NT = Params(1); +nFilt = Params(2); +Th = Params(3); + +fdata = fft(data, [], 1); +proj = real(ifft(fdata .* fW(:,:), [], 1)); +if ops.Nrank > 1 + proj = sum(reshape(proj, NT, nFilt, ops.Nrank),3); +end +trange = int32([-(nt0-1):(nt0-1)]); + +xs = zeros(Params(4), 1, 'single'); +ids = zeros(Params(4), 1, 'int32'); +sts = zeros(Params(4), 1, 'int32'); +Costs = zeros(Params(4), 1, 'single'); +cprojall = zeros(Params(4), nFilt, 'single'); + +i0 = 0; +for k = 1:30 + Ci = bsxfun(@plus, proj, (mu.*lam1)'); + Ci = bsxfun(@rdivide, Ci.^2, 1 + lam1'); + Ci = bsxfun(@minus, Ci, (lam1 .* mu.^2)'); + + [mX, id] = max(Ci,[], 2); + + maX = -my_min(-mX, 31, 1); + id = int32(id); + + st = find((maX < mX + 1e-3) & mX > Th*Th); + st(st>NT-nt0 | st4 + spikeClusters = uint32(1+rez.st3(:,5)); +end +amplitudes = rez.st3(:,3); + +Nchan = rez.ops.Nchan; + +% try +% load(rez.ops.chanMap); +% catch +% chanMap0ind = [0:Nchan-1]'; +% connected = ones(Nchan, 1); +% xcoords = ones(Nchan, 1); +% ycoords = (1:Nchan)'; +% end +% chanMap0 = chanMap(connected>1e-6); + +connected = rez.connected(:); +xcoords = rez.xcoords(:); +ycoords = rez.ycoords(:); +chanMap = rez.ops.chanMap(:); +chanMap0ind = chanMap - 1; + +nt0 = size(rez.W,1); +U = rez.U; +W = rez.W; + +% for i = 1:length(chanMap0) +% chanMap0(i) = chanMap0(i) - sum(chanMap0(i) > chanMap(connected<1e-6)); +% end +% [~, invchanMap0] = sort(chanMap0); + +templates = zeros(Nchan, nt0, rez.ops.Nfilt, 'single'); +for iNN = 1:rez.ops.Nfilt + templates(:,:,iNN) = squeeze(U(:,iNN,:)) * squeeze(W(:,iNN,:))'; +end +templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels +templatesInds = repmat([0:size(templates,3)-1], size(templates,1), 1); % we include all channels so this is trivial + +templateFeatures = rez.cProj; +templateFeatureInds = uint32(rez.iNeigh); +pcFeatures = rez.cProjPC; +pcFeatureInds = uint32(rez.iNeighPC); + +if ~isempty(savePath) + + writeNPY(spikeTimes, fullfile(savePath, 'spike_times.npy')); + writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_templates.npy')); % -1 for zero indexing + if size(rez.st3,2)>4 + writeNPY(uint32(spikeClusters-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing + else + writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing + end + writeNPY(amplitudes, fullfile(savePath, 'amplitudes.npy')); + writeNPY(templates, fullfile(savePath, 'templates.npy')); + writeNPY(templatesInds, fullfile(savePath, 'templates_ind.npy')); + +% Fs = rez.ops.fs; + conn = logical(connected); + chanMap0ind = int32(chanMap0ind); + + writeNPY(chanMap0ind(conn), fullfile(savePath, 'channel_map.npy')); + %writeNPY(connected, fullfile(savePath, 'connected.npy')); +% writeNPY(Fs, fullfile(savePath, 'Fs.npy')); + writeNPY([xcoords(conn) ycoords(conn)], fullfile(savePath, 'channel_positions.npy')); + + writeNPY(templateFeatures, fullfile(savePath, 'template_features.npy')); + writeNPY(templateFeatureInds'-1, fullfile(savePath, 'template_feature_ind.npy'));% -1 for zero indexing + writeNPY(pcFeatures, fullfile(savePath, 'pc_features.npy')); + writeNPY(pcFeatureInds'-1, fullfile(savePath, 'pc_feature_ind.npy'));% -1 for zero indexing + + whiteningMatrix = rez.Wrot/200; + whiteningMatrixInv = whiteningMatrix^-1; + writeNPY(whiteningMatrix, fullfile(savePath, 'whitening_mat.npy')); + writeNPY(whiteningMatrixInv, fullfile(savePath, 'whitening_mat_inv.npy')); + + if isfield(rez, 'simScore') + similarTemplates = rez.simScore; + writeNPY(similarTemplates, fullfile(savePath, 'similar_templates.npy')); + end + + %make params file + if ~exist(fullfile(savePath,'params.py'),'file') + fid = fopen(fullfile(savePath,'params.py'), 'w'); + + [~, fname, ext] = fileparts(rez.ops.fbinary); + + fprintf(fid,['dat_path = ''',fname ext '''\n']); + fprintf(fid,'n_channels_dat = %i\n',rez.ops.NchanTOT); + fprintf(fid,'dtype = ''int16''\n'); + fprintf(fid,'offset = 0\n'); + if mod(rez.ops.fs,1) + fprintf(fid,'sample_rate = %i\n',rez.ops.fs); + else + fprintf(fid,'sample_rate = %i.\n',rez.ops.fs); + end + fprintf(fid,'hp_filtered = False'); + fclose(fid); + end +end diff --git a/fitTemplates.m b/fitTemplates.m new file mode 100644 index 00000000..a44e7451 --- /dev/null +++ b/fitTemplates.m @@ -0,0 +1,243 @@ +function rez = fitTemplates(rez, DATA, uproj) + +nt0 = rez.ops.nt0; +rez.ops.nt0min = ceil(20 * nt0/61); + +ops = rez.ops; + +rng('default'); +rng(1); + +Nbatch = rez.temp.Nbatch; +Nbatch_buff = rez.temp.Nbatch_buff; + +Nfilt = ops.Nfilt; %256+128; + +ntbuff = ops.ntbuff; +NT = ops.NT; + +Nrank = ops.Nrank; +Th = ops.Th; +maxFR = ops.maxFR; + +Nchan = ops.Nchan; + +batchstart = 0:NT:NT*(Nbatch-Nbatch_buff); + +delta = NaN * ones(Nbatch, 1); +iperm = randperm(Nbatch); + +switch ops.initialize + case 'fromData' + WUinit = optimizePeaks(ops,uproj);%does a scaled kmeans + dWU = WUinit(:,:,1:Nfilt); + % dWU = alignWU(dWU); + otherwise + initialize_waves0; + ipck = randperm(size(Winit,2), Nfilt); + W = []; + U = []; + for i = 1:Nrank + W = cat(3, W, Winit(:, ipck)/Nrank); + U = cat(3, U, Uinit(:, ipck)); + end + W = alignW(W, ops); + + dWU = zeros(nt0, Nchan, Nfilt, 'single'); + for k = 1:Nfilt + wu = squeeze(W(:,k,:)) * squeeze(U(:,k,:))'; + newnorm = sum(wu(:).^2).^.5; + W(:,k,:) = W(:,k,:)/newnorm; + + dWU(:,:,k) = 10 * wu; + end + WUinit = dWU; +end +[W, U, mu, UtU, nu] = decompose_dWU(ops, dWU, Nrank, rez.ops.kcoords); +W0 = W; +W0(NT, 1) = 0; +fW = fft(W0, [], 1); +fW = conj(fW); + +nspikes = zeros(Nfilt, Nbatch); +lam = ones(Nfilt, 1, 'single'); + +freqUpdate = 100 * 4; +iUpdate = 1:freqUpdate:Nbatch; + + +dbins = zeros(100, Nfilt); +dsum = 0; +miniorder = repmat(iperm, 1, ops.nfullpasses); +% miniorder = repmat([1:Nbatch Nbatch:-1:1], 1, ops.nfullpasses/2); + +i = 1; % first iteration + +epu = ops.epu; + + +%% +% pmi = exp(-1./exp(linspace(log(ops.momentum(1)), log(ops.momentum(2)), Nbatch*ops.nannealpasses))); +pmi = exp(-1./linspace(1/ops.momentum(1), 1/ops.momentum(2), Nbatch*ops.nannealpasses)); +% pmi = exp(-linspace(ops.momentum(1), ops.momentum(2), Nbatch*ops.nannealpasses)); + +% pmi = linspace(ops.momentum(1), ops.momentum(2), Nbatch*ops.nannealpasses); +Thi = linspace(ops.Th(1), ops.Th(2), Nbatch*ops.nannealpasses); +if ops.lam(1)==0 + lami = linspace(ops.lam(1), ops.lam(2), Nbatch*ops.nannealpasses); +else + lami = exp(linspace(log(ops.lam(1)), log(ops.lam(2)), Nbatch*ops.nannealpasses)); +end + +if Nbatch_buff1 && ismember(rem(i,Nbatch), iUpdate) %&& i>Nbatch + dWU = gather_try(dWU); + + % break bimodal clusters and remove low variance clusters + if ops.shuffle_clusters &&... + i>Nbatch && rem(rem(i,Nbatch), 4*400)==1 % iNfilt; + j = Nfilt -9; + end + plot(log(1+NSP(j + [0:1:9])), mu(j+ [0:1:9]), 'o'); + xlabel('log of number of spikes') + ylabel('amplitude of template') + hold all + end + axis tight; + title(sprintf('%d ', nswitch)); + subplot(2,2,2) + plot(W(:,:,1)) + title('timecourses of top PC') + + subplot(2,2,3) + imagesc(U(:,:,1)) + title('spatial mask of top PC') + + drawnow + end + % break if last iteration reached + if i>Nbatch * ops.nfullpasses; break; end + + % record the error function for this iteration + rez.errall(ceil(i/freqUpdate)) = nanmean(delta); + + end + + % select batch and load from RAM or disk + ibatch = miniorder(i); + if ibatch>Nbatch_buff + offset = 2 * ops.Nchan*batchstart(ibatch-Nbatch_buff); + fseek(fid, offset, 'bof'); + dat = fread(fid, [NT ops.Nchan], '*int16'); + else + dat = DATA(:,:,ibatch); + end + + % move data to GPU and scale it + if ops.GPU + dataRAW = gpuArray(dat); + else + dataRAW = dat; + end + dataRAW = single(dataRAW); + dataRAW = dataRAW / ops.scaleproc; + + % project data in low-dim space + data = dataRAW * U(:,:); + + if ops.GPU + % run GPU code to get spike times and coefficients + [dWU, ~, id, x,Cost, nsp] = ... + mexMPregMU(Params,dataRAW,W,data,UtU,mu, lam .* (20./mu).^2, dWU, nu); + else + [dWU, ~, id, x,Cost, nsp] = ... + mexMPregMUcpu(Params,dataRAW,fW,data,UtU,mu, lam .* (20./mu).^2, dWU, nu, ops); + end + + dbins = .9975 * dbins; % this is a hard-coded forgetting factor, needs to become an option + if ~isempty(id) + % compute numbers of spikes + nsp = gather_try(nsp(:)); + nspikes(:, ibatch) = nsp; + + % bin the amplitudes of the spikes + xround = min(max(1, int32(x)), 100); + + dbins(xround + id * size(dbins,1)) = dbins(xround + id * size(dbins,1)) + 1; + + % estimate cost function at this time step + delta(ibatch) = sum(Cost)/1e3; + end + + % update status + if ops.verbose && rem(i,20)==1 + nsort = sort(round(sum(nspikes,2)), 'descend'); + fprintf(repmat('\b', 1, numel(msg))); + msg = sprintf('Time %2.2f, batch %d/%d, mu %2.2f, neg-err %2.6f, NTOT %d, n100 %d, n200 %d, n300 %d, n400 %d\n', ... + toc, i,Nbatch* ops.nfullpasses,nanmean(mu(:)), nanmean(delta), round(sum(nsort)), ... + nsort(min(size(W,2), 100)), nsort(min(size(W,2), 200)), ... + nsort(min(size(W,2), 300)), nsort(min(size(W,2), 400))); + fprintf(msg); + end + + % increase iteration counter + i = i+1; +end + +% close the data file if it has been used +if Nbatch_buff100); + cr = mWtW .* (vld * vld'); + cr(isnan(cr)) = 0; + [~, iNgsort] = sort(cr, 1, 'descend'); + + % save full similarity score + rez.simScore = cr; + maskTT = zeros(Nfilt, 'single'); + rez.iNeigh = iNgsort(1:nNeigh, :); + for i = 1:Nfilt + maskTT(rez.iNeigh(:,i),i) = 1; + end +end +if ~isempty(ops.nNeighPC) + nNeighPC = ops.nNeighPC; + load PCspikes + ixt = round(linspace(1, size(Wi,1), ops.nt0)); + Wi = Wi(ixt, 1:3); + rez.cProjPC = zeros(5e6, 3*nNeighPC, 'single'); + + % sort best channels + [~, iNch] = sort(abs(U(:,:,1)), 1, 'descend'); + maskPC = zeros(Nchan, Nfilt, 'single'); + rez.iNeighPC = iNch(1:nNeighPC, :); + for i = 1:Nfilt + maskPC(rez.iNeighPC(:,i),i) = 1; + end + maskPC = repmat(maskPC, 3, 1); +end + +irun = 0; +i1nt0 = int32([1:nt0])'; +%% +LAM = lam .* (20./mu).^2; + +NT = ops.NT; +batchstart = 0:NT:NT*(Nbatch-Nbatch_buff); + +for ibatch = 1:Nbatch + if ibatch>Nbatch_buff + offset = 2 * ops.Nchan*batchstart(ibatch-Nbatch_buff); % - ioffset; + fseek(fid, offset, 'bof'); + dat = fread(fid, [NT ops.Nchan], '*int16'); + else + dat = DATA(:,:,ibatch); + end + if ops.GPU + dataRAW = gpuArray(dat); + else + dataRAW = dat; + end + dataRAW = single(dataRAW); + dataRAW = dataRAW / ops.scaleproc; + + % project data in low-dim space + if ops.GPU + data = gpuArray.zeros(NT, Nfilt, Nrank, 'single'); + else + data = zeros(NT, Nfilt, Nrank, 'single'); + end + for irank = 1:Nrank + data(:,:,irank) = dataRAW * U(:,:,irank); + end + data = reshape(data, NT, Nfilt*Nrank); + + if ops.GPU + [st, id, x, errC, PCproj] ... + = mexMPmuFEAT(Params,data,W,WtW, mu, lam .* (20./mu).^2, nu); + else + [st, id, x, errC, PCproj]= cpuMPmuFEAT(Params,data,fW,WtW, mu, lam .* (20./mu).^2, nu, ops); + end + + if ~isempty(st) + if ~isempty(ops.nNeighPC) + % PCA coefficients + inds = repmat(st', nt0, 1) + repmat(i1nt0, 1, numel(st)); + try datSp = dataRAW(inds(:), :); + catch + datSp = dataRAW(inds(:), :); + end + datSp = reshape(datSp, [size(inds) Nchan]); + coefs = reshape(Wi' * reshape(datSp, nt0, []), size(Wi,2), numel(st), Nchan); + coefs = reshape(permute(coefs, [3 1 2]), [], numel(st)); + coefs = coefs .* maskPC(:, id+1); + iCoefs = reshape(find(maskPC(:, id+1)>0), 3*nNeighPC, []); + rez.cProjPC(irun + (1:numel(st)), :) = gather_try(coefs(iCoefs)'); + end + if ~isempty(ops.nNeigh) + % template coefficients + % transform coefficients + PCproj = bsxfun(@rdivide, ... + bsxfun(@plus, PCproj, LAM.*mu), sqrt(1+LAM)); + + PCproj = maskTT(:, id+1) .* PCproj; + iPP = reshape(find(maskTT(:, id+1)>0), nNeigh, []); + rez.cProj(irun + (1:numel(st)), :) = PCproj(iPP)'; + end + % increment number of spikes + irun = irun + numel(st); + + if ibatch==1; + ioffset = 0; + else + ioffset = ops.ntbuff; + end + st = st - ioffset; + + % nspikes2(1:size(W,2)+1, ibatch) = histc(id, 0:1:size(W,2)); + STT = cat(2, ops.nt0min + double(st) +(NT-ops.ntbuff)*(ibatch-1), ... + double(id)+1, double(x), ibatch*ones(numel(x),1)); + st3 = cat(1, st3, STT); + end + if rem(ibatch,100)==1 +% nsort = sort(sum(nspikes2,2), 'descend'); + fprintf(repmat('\b', 1, numel(msg))); + msg = sprintf('Time %2.2f, batch %d/%d, NTOT %d\n', ... + toc, ibatch,Nbatch, size(st3,1)); + fprintf(msg); + + end +end +%% +[~, isort] = sort(st3(:,1), 'ascend'); +st3 = st3(isort,:); + +rez.st3 = st3; +if ~isempty(ops.nNeighPC) + % re-sort coefficients for projections + rez.cProjPC(irun+1:end, :) = []; + rez.cProjPC = reshape(rez.cProjPC, size(rez.cProjPC,1), [], 3); + rez.cProjPC = rez.cProjPC(isort, :,:); + for ik = 1:Nfilt + iSp = rez.st3(:,2)==ik; + OneToN = 1:nNeighPC; + [~, isortNeigh] = sort(rez.iNeighPC(:,ik), 'ascend'); + OneToN(isortNeigh) = OneToN; + rez.cProjPC(iSp, :,:) = rez.cProjPC(iSp, OneToN, :); + end + + rez.cProjPC = permute(rez.cProjPC, [1 3 2]); +end +if ~isempty(ops.nNeigh) + rez.cProj(irun+1:end, :) = []; + rez.cProj = rez.cProj(isort, :); + + % re-index the template coefficients + for ik = 1:Nfilt + iSp = rez.st3(:,2)==ik; + OneToN = 1:nNeigh; + [~, isortNeigh] = sort(rez.iNeigh(:,ik), 'ascend'); + OneToN(isortNeigh) = OneToN; + rez.cProj(iSp, :) = rez.cProj(iSp, OneToN); + end +end + + +%% +% rez.ops = ops; +rez.W = W; +rez.U = U; +rez.mu = mu; + +rez.t2p = []; +for i = 1:Nfilt + wav0 = W(:,i,1); + wav0 = my_conv(wav0', .5)'; + [~, itrough] = min(wav0); + [~, t2p] = max(wav0(itrough:end)); + rez.t2p(i,1) = t2p; + rez.t2p(i,2) = itrough; +end + +rez.nbins = histc(rez.st3(:,2), .5:1:Nfilt+1); + +[~, rez.ypos] = max(rez.U(:,:,1), [], 1); +if Nbatch_buffcrit); +else + iNonMatch = 1:size(uS,2); + nS = []; +end \ No newline at end of file diff --git a/initialize/optimizePeaks.m b/initialize/optimizePeaks.m new file mode 100644 index 00000000..a10da20e --- /dev/null +++ b/initialize/optimizePeaks.m @@ -0,0 +1,152 @@ +% addpath('C:\CODE\GitHub\KiloSort\preDetect') +function WUinit=optimizePeaks(ops,uproj) + +nt0 = ops.nt0; + +nProj = size(uproj,2); +nSpikesPerBatch = 4000; +inds = 1:nSpikesPerBatch * floor(size(uproj,1)/nSpikesPerBatch); +inds = reshape(inds, nSpikesPerBatch, []); +% Nbatch = size(inds,2); +iperm = randperm(size(inds,2)); +miniorder = repmat(iperm, 1, ops.nfullpasses); +% miniorder = repmat([1:Nbatch Nbatch:-1:1], 1, ops.nfullpasses/2); + +if ~exist('spikes_merged') + uBase = zeros(1e4, nProj); + nS = zeros(1e4, 1); + ncurr = 1; + + for ibatch = 1:size(inds,2) + % merge in with existing templates + uS = uproj(inds(:,ibatch), :); + [nSnew, iNonMatch] = merge_spikes0(uBase(1:ncurr,:), nS(1:ncurr), uS, ops.crit); + nS(1:ncurr) = nSnew; + % + % reduce non-matches + [uNew, nSadd] = reduce_clusters0(uS(iNonMatch,:), ops.crit); + + % add new spikes to list + uBase(ncurr + [1:size(uNew,1)], :) = uNew; + nS(ncurr + [1:size(uNew,1)]) = nSadd; + + ncurr = ncurr + size(uNew,1); + + if ncurr>1e4 + break; + end + end + % + nS = nS(1:ncurr); + uBase = uBase(1:ncurr, :); + spikes_merged = 1; +end +[~, itsort] = sort(nS, 'descend'); + +%% initialize U +Nfilt = ops.Nfilt; +lam = ops.lam(1) * ones(Nfilt, 1, 'single'); + +ind_filt = itsort(rem([1:Nfilt]-1, numel(itsort)) + 1); +if ops.GPU + U = gpuArray(uBase(ind_filt, :))'; +else + U = uBase(ind_filt, :)'; +end +U = U + .001 * randn(size(U)); +mu = sum(U.^2,1)'.^.5; +U = normc(U); +% + +for i = 1:10 + + idT = zeros(size(inds)); + dWU = zeros(Nfilt, nProj, 'single'); + if ops.GPU + nToT = gpuArray.zeros(Nfilt, 1, 'single'); + Cost = gpuArray(single(0)); + else + nToT = zeros(Nfilt, 1, 'single'); + Cost = single(0); + end + + for ibatch = 1:size(inds,2) + % find clusters + if ops.GPU + clips = reshape(gpuArray(uproj(inds(:,ibatch), :)), nSpikesPerBatch, nProj); + else + clips = reshape(uproj(inds(:,ibatch), :), nSpikesPerBatch, nProj); + end + ci = clips * U; + + ci = bsxfun(@plus, ci, (mu .* lam)'); + cf = bsxfun(@rdivide, ci.^2, 1 + lam'); + cf = bsxfun(@minus, cf, (mu.^2.*lam)'); + + [max_cf, id] = max(cf, [], 2); + + id = gather_try(id); + % x = ci([1:nSpikesPerBatch] + nSpikesPerBatch * (id-1)')' - mu(id) .* lam(id); + idT(:,ibatch) = id; + + if ops.GPU + L = gpuArray.zeros(Nfilt, nSpikesPerBatch, 'single'); + else + L = zeros(Nfilt, nSpikesPerBatch, 'single'); + end + L(id' + [0:Nfilt:(Nfilt*nSpikesPerBatch-1)]) = 1; + dWU = dWU + L * clips; + + nToT = nToT + sum(L, 2); + Cost = Cost + mean(max_cf); + end + dWU = bsxfun(@rdivide, dWU, nToT); + + U = dWU'; + mu = sum(U.^2,1)'.^.5; + U = normc(U); + Cost = Cost/size(inds,2); + +% disp(Cost) + +% plot(sort(log(1+nToT))) +% drawnow +end +%% +Nchan = ops.Nchan; +Nfilt = ops.Nfilt; +Nrank = ops.Nrank; +wPCA = ops.wPCA(:,1:3); +Urec = reshape(U, Nchan, size(wPCA,2), Nfilt); + +Urec= permute(Urec, [2 1 3]); +Wrec = reshape(wPCA * Urec(:,:), nt0, Nchan, Nfilt); + +Wrec = gather_try(Wrec); + +W = zeros(nt0, Nfilt, Nrank, 'single'); +U = zeros(Nchan, Nfilt, Nrank, 'single'); + +Wrec(isnan(Wrec(:))) = 0; +for j = 1:Nfilt + [w sv u] = svd(Wrec(:,:,j)); + w = w * sv; + + Sv = diag(sv); + W(:,j,:) = w(:, 1:Nrank)/sum(Sv(1:ops.Nrank).^2).^.5; + U(:,j,:) = u(:, 1:Nrank); +end + +Uinit = U; +Winit = W; +mu = gather_try(single(mu)); +muinit = mu; + +WUinit = zeros(nt0, Nchan, Nfilt); +for j = 1:Nfilt + WUinit(:,:,j) = muinit(j) * Wrec(:,:,j); +end +WUinit = single(WUinit); +%% + + diff --git a/initialize/reduce_clusters0.m b/initialize/reduce_clusters0.m new file mode 100644 index 00000000..1081d028 --- /dev/null +++ b/initialize/reduce_clusters0.m @@ -0,0 +1,31 @@ +function [uNew, nSnew]= reduce_clusters0(uS, crit) + + cdot = uS * uS'; + +% compute norms of each spike +newNorms = sum(uS.^2, 2)'; + +% compute sum of pairs of norms +cNorms = 1e-10 + repmat(newNorms', 1, numel(newNorms)) +... + repmat(newNorms, numel(newNorms), 1); + +% compute normalized distance between spikes +cdot = 1 - 2*cdot./cNorms; +cdot = cdot + diag(Inf * diag(cdot)); + +[cmin, newind] = min(single(cdot>crit),[],1); +% if someone else votes you in, your votee doesn't count +% newind(ismember(1:nN, newind)) = []; +newind = unique(newind(cmin<.5)); +if ~isempty(newind) + newind = cat(2, newind, find(cmin>.5)); +else + newind = find(cmin>.5); +end + + +uNew = uS(newind, :); + +nNew = size(uNew,1); + +nSnew = merge_spikes0(uNew, zeros(nNew, 1), uS, crit); \ No newline at end of file diff --git a/licence.txt b/licence.txt new file mode 100644 index 00000000..fe3deb27 --- /dev/null +++ b/licence.txt @@ -0,0 +1,339 @@ + GNU GENERAL PUBLIC LICENSE + Version 2, June 1991 + + Copyright (C) 1989, 1991 Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The licenses for most software are designed to take away your +freedom to share and change it. By contrast, the GNU General Public +License is intended to guarantee your freedom to share and change free +software--to make sure the software is free for all its users. This +General Public License applies to most of the Free Software +Foundation's software and to any other program whose authors commit to +using it. (Some other Free Software Foundation software is covered by +the GNU Lesser General Public License instead.) You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +this service if you wish), that you receive source code or can get it +if you want it, that you can change the software or use pieces of it +in new free programs; and that you know you can do these things. + + To protect your rights, we need to make restrictions that forbid +anyone to deny you these rights or to ask you to surrender the rights. +These restrictions translate to certain responsibilities for you if you +distribute copies of the software, or if you modify it. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must give the recipients all the rights that +you have. You must make sure that they, too, receive or can get the +source code. And you must show them these terms so they know their +rights. + + We protect your rights with two steps: (1) copyright the software, and +(2) offer you this license which gives you legal permission to copy, +distribute and/or modify the software. + + Also, for each author's protection and ours, we want to make certain +that everyone understands that there is no warranty for this free +software. If the software is modified by someone else and passed on, we +want its recipients to know that what they have is not the original, so +that any problems introduced by others will not reflect on the original +authors' reputations. + + Finally, any free program is threatened constantly by software +patents. We wish to avoid the danger that redistributors of a free +program will individually obtain patent licenses, in effect making the +program proprietary. To prevent this, we have made it clear that any +patent must be licensed for everyone's free use or not licensed at all. + + The precise terms and conditions for copying, distribution and +modification follow. + + GNU GENERAL PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. This License applies to any program or other work which contains +a notice placed by the copyright holder saying it may be distributed +under the terms of this General Public License. The "Program", below, +refers to any such program or work, and a "work based on the Program" +means either the Program or any derivative work under copyright law: +that is to say, a work containing the Program or a portion of it, +either verbatim or with modifications and/or translated into another +language. (Hereinafter, translation is included without limitation in +the term "modification".) Each licensee is addressed as "you". + +Activities other than copying, distribution and modification are not +covered by this License; they are outside its scope. The act of +running the Program is not restricted, and the output from the Program +is covered only if its contents constitute a work based on the +Program (independent of having been made by running the Program). +Whether that is true depends on what the Program does. + + 1. You may copy and distribute verbatim copies of the Program's +source code as you receive it, in any medium, provided that you +conspicuously and appropriately publish on each copy an appropriate +copyright notice and disclaimer of warranty; keep intact all the +notices that refer to this License and to the absence of any warranty; +and give any other recipients of the Program a copy of this License +along with the Program. + +You may charge a fee for the physical act of transferring a copy, and +you may at your option offer warranty protection in exchange for a fee. + + 2. You may modify your copy or copies of the Program or any portion +of it, thus forming a work based on the Program, and copy and +distribute such modifications or work under the terms of Section 1 +above, provided that you also meet all of these conditions: + + a) You must cause the modified files to carry prominent notices + stating that you changed the files and the date of any change. + + b) You must cause any work that you distribute or publish, that in + whole or in part contains or is derived from the Program or any + part thereof, to be licensed as a whole at no charge to all third + parties under the terms of this License. + + c) If the modified program normally reads commands interactively + when run, you must cause it, when started running for such + interactive use in the most ordinary way, to print or display an + announcement including an appropriate copyright notice and a + notice that there is no warranty (or else, saying that you provide + a warranty) and that users may redistribute the program under + these conditions, and telling the user how to view a copy of this + License. (Exception: if the Program itself is interactive but + does not normally print such an announcement, your work based on + the Program is not required to print an announcement.) + +These requirements apply to the modified work as a whole. If +identifiable sections of that work are not derived from the Program, +and can be reasonably considered independent and separate works in +themselves, then this License, and its terms, do not apply to those +sections when you distribute them as separate works. But when you +distribute the same sections as part of a whole which is a work based +on the Program, the distribution of the whole must be on the terms of +this License, whose permissions for other licensees extend to the +entire whole, and thus to each and every part regardless of who wrote it. + +Thus, it is not the intent of this section to claim rights or contest +your rights to work written entirely by you; rather, the intent is to +exercise the right to control the distribution of derivative or +collective works based on the Program. + +In addition, mere aggregation of another work not based on the Program +with the Program (or with a work based on the Program) on a volume of +a storage or distribution medium does not bring the other work under +the scope of this License. + + 3. You may copy and distribute the Program (or a work based on it, +under Section 2) in object code or executable form under the terms of +Sections 1 and 2 above provided that you also do one of the following: + + a) Accompany it with the complete corresponding machine-readable + source code, which must be distributed under the terms of Sections + 1 and 2 above on a medium customarily used for software interchange; or, + + b) Accompany it with a written offer, valid for at least three + years, to give any third party, for a charge no more than your + cost of physically performing source distribution, a complete + machine-readable copy of the corresponding source code, to be + distributed under the terms of Sections 1 and 2 above on a medium + customarily used for software interchange; or, + + c) Accompany it with the information you received as to the offer + to distribute corresponding source code. (This alternative is + allowed only for noncommercial distribution and only if you + received the program in object code or executable form with such + an offer, in accord with Subsection b above.) + +The source code for a work means the preferred form of the work for +making modifications to it. For an executable work, complete source +code means all the source code for all modules it contains, plus any +associated interface definition files, plus the scripts used to +control compilation and installation of the executable. However, as a +special exception, the source code distributed need not include +anything that is normally distributed (in either source or binary +form) with the major components (compiler, kernel, and so on) of the +operating system on which the executable runs, unless that component +itself accompanies the executable. + +If distribution of executable or object code is made by offering +access to copy from a designated place, then offering equivalent +access to copy the source code from the same place counts as +distribution of the source code, even though third parties are not +compelled to copy the source along with the object code. + + 4. You may not copy, modify, sublicense, or distribute the Program +except as expressly provided under this License. Any attempt +otherwise to copy, modify, sublicense or distribute the Program is +void, and will automatically terminate your rights under this License. +However, parties who have received copies, or rights, from you under +this License will not have their licenses terminated so long as such +parties remain in full compliance. + + 5. You are not required to accept this License, since you have not +signed it. However, nothing else grants you permission to modify or +distribute the Program or its derivative works. These actions are +prohibited by law if you do not accept this License. Therefore, by +modifying or distributing the Program (or any work based on the +Program), you indicate your acceptance of this License to do so, and +all its terms and conditions for copying, distributing or modifying +the Program or works based on it. + + 6. Each time you redistribute the Program (or any work based on the +Program), the recipient automatically receives a license from the +original licensor to copy, distribute or modify the Program subject to +these terms and conditions. You may not impose any further +restrictions on the recipients' exercise of the rights granted herein. +You are not responsible for enforcing compliance by third parties to +this License. + + 7. If, as a consequence of a court judgment or allegation of patent +infringement or for any other reason (not limited to patent issues), +conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot +distribute so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you +may not distribute the Program at all. For example, if a patent +license would not permit royalty-free redistribution of the Program by +all those who receive copies directly or indirectly through you, then +the only way you could satisfy both it and this License would be to +refrain entirely from distribution of the Program. + +If any portion of this section is held invalid or unenforceable under +any particular circumstance, the balance of the section is intended to +apply and the section as a whole is intended to apply in other +circumstances. + +It is not the purpose of this section to induce you to infringe any +patents or other property right claims or to contest validity of any +such claims; this section has the sole purpose of protecting the +integrity of the free software distribution system, which is +implemented by public license practices. Many people have made +generous contributions to the wide range of software distributed +through that system in reliance on consistent application of that +system; it is up to the author/donor to decide if he or she is willing +to distribute software through any other system and a licensee cannot +impose that choice. + +This section is intended to make thoroughly clear what is believed to +be a consequence of the rest of this License. + + 8. If the distribution and/or use of the Program is restricted in +certain countries either by patents or by copyrighted interfaces, the +original copyright holder who places the Program under this License +may add an explicit geographical distribution limitation excluding +those countries, so that distribution is permitted only in or among +countries not thus excluded. In such case, this License incorporates +the limitation as if written in the body of this License. + + 9. The Free Software Foundation may publish revised and/or new versions +of the General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies a version number of this License which applies to it and "any +later version", you have the option of following the terms and conditions +either of that version or of any later version published by the Free +Software Foundation. If the Program does not specify a version number of +this License, you may choose any version ever published by the Free Software +Foundation. + + 10. If you wish to incorporate parts of the Program into other free +programs whose distribution conditions are different, write to the author +to ask for permission. For software which is copyrighted by the Free +Software Foundation, write to the Free Software Foundation; we sometimes +make exceptions for this. Our decision will be guided by the two goals +of preserving the free status of all derivatives of our free software and +of promoting the sharing and reuse of software generally. + + NO WARRANTY + + 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY +FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN +OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES +PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED +OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS +TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE +PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, +REPAIR OR CORRECTION. + + 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR +REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING +OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED +TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY +YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER +PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGES. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +convey the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +Also add information on how to contact you by electronic and paper mail. + +If the program is interactive, make it output a short notice like this +when it starts in an interactive mode: + + Gnomovision version 69, Copyright (C) year name of author + Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, the commands you use may +be called something other than `show w' and `show c'; they could even be +mouse-clicks or menu items--whatever suits your program. + +You should also get your employer (if you work as a programmer) or your +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. Here is a sample; alter the names: + + Yoyodyne, Inc., hereby disclaims all copyright interest in the program + `Gnomovision' (which makes passes at compilers) written by James Hacker. + + , 1 April 1989 + Ty Coon, President of Vice + +This General Public License does not permit incorporating your program into +proprietary programs. If your program is a subroutine library, you may +consider it more useful to permit linking proprietary applications with the +library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License.s \ No newline at end of file diff --git a/mainLoop/alignW.m b/mainLoop/alignW.m new file mode 100644 index 00000000..ec053510 --- /dev/null +++ b/mainLoop/alignW.m @@ -0,0 +1,20 @@ +function W = alignW(W, ops) + +[nt0 , Nfilt] = size(W); + + +[~, imax] = min(W, [], 1); +dmax = -(imax - ops.nt0min); +% dmax = min(1, abs(dmax)) .* sign(dmax); + +for i = 1:Nfilt + if dmax(i)>0 + W((dmax(i) + 1):nt0, i) = W(1:nt0-dmax(i), i); + else + W(1:nt0+dmax(i), i) = W((1-dmax(i)):nt0, i); + end +end + + + + diff --git a/mainLoop/alignWU.m b/mainLoop/alignWU.m new file mode 100644 index 00000000..8f6057e6 --- /dev/null +++ b/mainLoop/alignWU.m @@ -0,0 +1,31 @@ +function WU = alignWU(WU, ops) + +[nt0 , Nchan, Nfilt] = size(WU); +[~, imin] = min(reshape(WU, nt0*Nchan, Nfilt), [], 1); + +iMinChan = ceil(imin/nt0); + + +% imin = rem(imin-1, nt0) + 1; + +% [~, imax] = min(W, [], 1); +% dmax = -(imin - 20); +% dmax = min(1, abs(dmax)) .* sign(dmax); + +dmax = zeros(Nfilt, 1); +for i = 1:Nfilt + wu = WU(:,iMinChan(i),i); +% [~, imin] = min(diff(wu, 1)); + [~, imin] = min(wu); + dmax(i) = - (imin- ops.nt0min); + + if dmax(i)>0 + WU((dmax(i) + 1):nt0, :,i) = WU(1:nt0-dmax(i),:, i); + else + WU(1:nt0+dmax(i),:, i) = WU((1-dmax(i)):nt0,:, i); + end +end + + + + diff --git a/mainLoop/decompose_dWU.m b/mainLoop/decompose_dWU.m new file mode 100644 index 00000000..35aaa17e --- /dev/null +++ b/mainLoop/decompose_dWU.m @@ -0,0 +1,38 @@ +function [W, U, mu, UtU, nu] = decompose_dWU(ops, dWU, Nrank, kcoords) + +[nt0 Nchan Nfilt] = size(dWU); + +W = zeros(nt0, Nrank, Nfilt, 'single'); +U = zeros(Nchan, Nrank, Nfilt, 'single'); +mu = zeros(Nfilt, 1, 'single'); +% dmax = zeros(Nfilt, 1); + +dWU(isnan(dWU)) = 0; +if ops.parfor + parfor k = 1:Nfilt + [W(:,:,k), U(:,:,k), mu(k)] = get_svds(dWU(:,:,k), Nrank); + end +else + for k = 1:Nfilt + [W(:,:,k), U(:,:,k), mu(k)] = get_svds(dWU(:,:,k), Nrank); + end +end +U = permute(U, [1 3 2]); +W = permute(W, [1 3 2]); + +U(isnan(U)) = 0; + +if numel(unique(kcoords))>1 + U = zeroOutKcoords(U, kcoords, ops.criterionNoiseChannels); +end + +UtU = abs(U(:,:,1)' * U(:,:,1)) > .1; + + +Wdiff = cat(1, W, zeros(2, Nfilt, Nrank)) - cat(1, zeros(2, Nfilt, Nrank), W); +nu = sum(sum(Wdiff.^2,1),3); +nu = nu(:); + + + +% mu = min(mu, 200); \ No newline at end of file diff --git a/mainLoop/get_svds.m b/mainLoop/get_svds.m new file mode 100644 index 00000000..9062eae4 --- /dev/null +++ b/mainLoop/get_svds.m @@ -0,0 +1,27 @@ +function [W, U, mu] = get_svds(dWU, Nrank) + +[Wall, Sv, Uall] = svd(gather_try(dWU), 0); +[~, imax] = max(abs(Wall(:,1))); +Uall(:,1) = -Uall(:,1) * sign(Wall(imax,1)); +Wall(:,1) = -Wall(:,1) * sign(Wall(imax,1)); + +% [~, imin] = min(diff(Wall(:,1), 1)); +% [~, imin] = min(Wall(:,1)); +% dmax(k) = - (imin- 20); + +% if dmax(k)>0 +% dWU((dmax(k) + 1):nt0, :,k) = dWU(1:nt0-dmax(k),:, k); +% Wall((dmax(k) + 1):nt0, :) = Wall(1:nt0-dmax(k),:); +% else +% dWU(1:nt0+dmax(k),:, k) = dWU((1-dmax(k)):nt0,:, k); +% Wall(1:nt0+dmax(k),:) = Wall((1-dmax(k)):nt0,:); +% end + +Wall = Wall * Sv; + +Sv = diag(Sv); +mu = sum(Sv(1:Nrank).^2).^.5; +Wall = Wall/mu; + +W = Wall(:,1:Nrank); +U = Uall(:,1:Nrank); \ No newline at end of file diff --git a/mainLoop/initialize_waves0.m b/mainLoop/initialize_waves0.m new file mode 100644 index 00000000..4fe3c52d --- /dev/null +++ b/mainLoop/initialize_waves0.m @@ -0,0 +1,77 @@ +clear W + +tps = [1 5 10 25 40 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 1 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,1) = my_conv(fs, 2); + +tps = [1 5 10 25 40 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 1 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,2) = my_conv(fs, 2); + +tps = [1 5 10 15 25 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 1 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,3) = my_conv(fs, 2); + +tps = [1 5 10 20 30 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 1 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,4) = my_conv(fs, 2); + +tps = [1 5 10 25 40 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 0 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,5) = my_conv(fs, 2); + +tps = [1 5 10 25 40 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 0 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,6) = my_conv(fs, 2); + +tps = [1 5 10 15 25 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 0 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,7) = my_conv(fs, 2); + +tps = [1 5 10 20 30 50 61]; +tps = round(tps * nt0/61); +vs = [0 0 0 -2 0 0 0]; +fs= interp1(tps, vs, 1:nt0, 'linear', 'extrap'); +W(:,1,8) = my_conv(fs, 2); + + +W = (single(W)); +W = squeeze(W); + + +W = W(:, repmat([1 2 3 4 1 2 3 4], ceil(Nfilt/8), 1)) + .1 * my_conv(randn(nt0, 8*ceil(Nfilt/8), 'single')', 5)'; +U = repmat(eye(Nchan, Nchan), 1, ceil(Nfilt/Nchan)); +U = U(:, 1:Nfilt); +U = my_conv(single(U)', 2)'; +U = U .* (1 + .05 * randn(size(U))); +U(abs(U)<.01) = 0; +U = single(U); + + +if 1<0 + W = randn(nt0, Nfilt, 'single'); + U = randn(Nchan, Nfilt, 'single'); + + W = single(my_conv(W', 10)'); + U = single(my_conv(U', 10)'); + + W = normc(W); + U = normc(U); +end + +Uinit = normc(U); +Winit = normc(W); diff --git a/mainLoop/merge_spikes_in.m b/mainLoop/merge_spikes_in.m new file mode 100644 index 00000000..35eb006e --- /dev/null +++ b/mainLoop/merge_spikes_in.m @@ -0,0 +1,35 @@ +function [nS, iNonMatch] = merge_spikes_in(uBase, nS, uS, crit) + +if ~isempty(uBase) + cdot = uBase(:,:,1)' * uS(:,:,1); + for j = 2:size(uBase,3) + cdot = cdot + uBase(:,:,j)' * uS(:,:,j); + end + + baseNorms = sum(sum(uBase.^2, 3),1); + newNorms = sum(sum(uS.^2, 3),1); + + cNorms = 1e-10 + repmat(baseNorms', 1, numel(newNorms)) + repmat(newNorms, numel(baseNorms), 1); + + cdot = 1 - 2*cdot./cNorms; + + [cdotmin, imin] = min(cdot, [], 1); + + iMatch = cdotmincrit); +else + iNonMatch = 1:size(uS,2); + nS = []; +end \ No newline at end of file diff --git a/mainLoop/mexMPregMUcpu.m b/mainLoop/mexMPregMUcpu.m new file mode 100644 index 00000000..25de67c2 --- /dev/null +++ b/mainLoop/mexMPregMUcpu.m @@ -0,0 +1,49 @@ +function [dWU, st, id, x,Cost, nsp] = ... + mexMPregMUcpu(Params,dataRAW,fW,data,UtU,mu, lam , dWU, nu, ops) + +nt0 = ops.nt0; +NT = Params(1); +nFilt = Params(2); +Th = Params(3); + +pm = Params(8); +fdata = fft(data, [], 1); +proj = real(ifft(fdata .* fW(:,:), [], 1)); + +if ops.Nrank > 1 + proj = sum(reshape(proj, NT, nFilt, ops.Nrank),3); +end +Ci = bsxfun(@plus, proj, (mu.*lam)'); +Ci = bsxfun(@rdivide, Ci.^2, 1 + lam'); +Ci = bsxfun(@minus, Ci, (lam .* mu.^2)'); + +[mX, id] = max(Ci,[], 2); + +maX = -my_min(-mX, 31, 1); +id = int32(id); + +st = find((maX < mX + 1e-3) & mX > Th*Th); +st(st>NT-nt0) = []; + +id = id(st); +x = []; +Cost = []; +nsp = []; + +if ~isempty(id) + inds = bsxfun(@plus, st', [1:nt0]'); + dspk = reshape(dataRAW(inds, :), nt0, numel(st), ops.Nchan); + dspk = permute(dspk, [1 3 2]); + + x = zeros(size(id)); + Cost = zeros(size(id)); + nsp = zeros(nFilt,1); + for j = 1:size(dspk,3) + dWU(:,:,id(j)) = pm * dWU(:,:,id(j)) + (1-pm) * dspk(:,:,j); + x(j) = proj(st(j), id(j)); + Cost(j) = maX(st(j)); + nsp(id(j)) = nsp(id(j)) + 1; + end + + id = id - 1; +end \ No newline at end of file diff --git a/mainLoop/reduce_clusters.m b/mainLoop/reduce_clusters.m new file mode 100644 index 00000000..19afcc6a --- /dev/null +++ b/mainLoop/reduce_clusters.m @@ -0,0 +1,35 @@ +function [uNew, nSnew]= reduce_clusters(uS, crit) + + +cdot = uS(:,:,1)' * uS(:,:,1); +for j = 2:size(uS,3) + cdot = cdot + uS(:,:,j)' * uS(:,:,j); +end + +% compute norms of each spike +newNorms = sum(sum(uS.^2, 3),1); + +% compute sum of pairs of norms +cNorms = 1e-10 + repmat(newNorms', 1, numel(newNorms)) +... + repmat(newNorms, numel(newNorms), 1); + +% compute normalized distance between spikes +cdot = 1 - 2*cdot./cNorms; +cdot = cdot + diag(Inf * diag(cdot)); + +[cmin, newind] = min(single(cdot>crit),[],1); +% if someone else votes you in, your votee doesn't count +% newind(ismember(1:nN, newind)) = []; +newind = unique(newind(cmin<.5)); +if ~isempty(newind) + newind = cat(2, newind, find(cmin>.5)); +else + newind = find(cmin>.5); +end + + +uNew = uS(:,newind, :); + +nNew = size(uNew,2); + +nSnew = merge_spikes_in(uNew, zeros(nNew, 1), uS, crit); \ No newline at end of file diff --git a/mainLoop/update_params.m b/mainLoop/update_params.m new file mode 100644 index 00000000..1b5607a9 --- /dev/null +++ b/mainLoop/update_params.m @@ -0,0 +1,52 @@ +function [W, U, mu, UtU] = update_params(mu, W, U, dWUtot, nspikes) + +[Nchan, Nfilt, Nrank] = size(U); + +dWUtotCPU = gather_try(dWUtot); +ntot = sum(nspikes,2); + +for k = 1:Nfilt + if ntot(k)>5 + + [Uall, Sv, Vall] = svd(gather_try(dWUtotCPU(:,:,k)), 0); + Sv = diag(Sv); + sumSv2 = sum(Sv(1:Nrank).^2).^.5; + for irank = 1:Nrank + [~, imax] = max(abs(Uall(:,irank)), [], 1); + W(:,k,irank) = - Uall(:,irank) * sign(Uall(imax,irank)) * Sv(irank)/sumSv2; + U(:,k,irank) = - Vall(:,irank) * sign(Uall(imax,irank)); + end + mmax = max(abs(U(:,k,1))); + Usize = squeeze(abs(U(:,k,:))); + Usize = Usize .* repmat(Sv(1:Nrank)'/Sv(1), Nchan, 1); + ibad = max(Usize, [], 2) < .1 * mmax; + + U(ibad,k,:) = 0; + end +end + +% mu = zeros(Nfilt,1, 'single'); +for k = 1:Nfilt + if ntot(k)>5 + wu = squeeze(W(:,k,:)) * squeeze(U(:,k,:))'; + mu(k) = sum(sum(wu.*squeeze(dWUtotCPU(:,:,k)))); + end +end + +for k = 1:Nfilt + if ntot(k)>5 + wu = squeeze(W(:,k,:)) * squeeze(U(:,k,:))'; + newnorm = sum(wu(:).^2).^.5; + W(:,k,:) = W(:,k,:)/newnorm; + end +end + +% compute adjacency matrix UtU +U(isnan(U)) = 0; +U0 = gpuArray(U); +utu = gpuArray.zeros(Nfilt, 'single'); +for irank = 1:Nrank + utu = utu + (U0(:,:,irank)' * U0(:,:,irank)); +end + +UtU = logical(utu); diff --git a/mainLoop/zeroOutKcoords.m b/mainLoop/zeroOutKcoords.m new file mode 100644 index 00000000..868c3415 --- /dev/null +++ b/mainLoop/zeroOutKcoords.m @@ -0,0 +1,39 @@ +function U = zeroOutKcoords(U, kcoords, criterionNoiseChannels) + +[M, imax] = max(abs(U(:,:,1)), [], 1); + +% determine over how many channel groups each template exists +aU = sum(U.^2,3).^.5; +ngroups = max(kcoords(:)); + +aUgroups = zeros(ngroups, size(U,2)); +for j = 1:ngroups + aUgroups(j, :) = mean(aU(kcoords==j,:), 1); +end + +% the "effective" number of channel groups is defined below. +% for cases when X channel groups have equal non-zero weights, this number +% equals X +nEffective = sum(aUgroups,1).^2./sum(aUgroups.^2, 1); + +[nEffSort, isort] = sort(nEffective, 'descend'); + +if criterionNoiseChannels<1 + % if this criterion is less than 1, it will be treated as a fraction + % of the total number of clusters + nNoise = ceil(criterionNoiseChannels * size(U,2)); + ThLocal = nEffSort(nNoise); +else + % if this criterion is larger than 1, it will be treated as the + % effective number of channel groups at which to set the threshold + ThLocal = criterionNoiseChannels; +end + + +for i = 1:size(U,2) + if ThLocal > nEffective(i) + U(kcoords~=kcoords(imax(i)),i,:) = 0; + U(:,i,:) = normc(squeeze(U(:,i,:))); + end +end + diff --git a/master_file_example_MOVEME.m b/master_file_example_MOVEME.m new file mode 100644 index 00000000..1db07922 --- /dev/null +++ b/master_file_example_MOVEME.m @@ -0,0 +1,34 @@ +% default options are in parenthesis after the comment + +addpath(genpath('D:\CODE\GitHub\KiloSort')) % path to kilosort folder +addpath(genpath('D:\CODE\GitHub\npy-matlab')) % path to npy-matlab scripts + +pathToYourConfigFile = 'D:\CODE\Kilosort\configFiles'; % take from Github folder and put it somewhere else (together with the master_file) +run(fullfile(pathToYourConfigFile, 'StandardConfig_MOVEME.m')) + +tic; % start timer +% +if ops.GPU + gpuDevice(1); % initialize GPU (will erase any existing GPU arrays) +end + +if strcmp(ops.datatype , 'openEphys') + ops = convertOpenEphysToRawBInary(ops); % convert data, only for OpenEphys +end +% +[rez, DATA, uproj] = preprocessData(ops); % preprocess data and extract spikes for initialization +rez = fitTemplates(rez, DATA, uproj); % fit templates iteratively +rez = fullMPMU(rez, DATA);% extract final spike times (overlapping extraction) + +% AutoMerge. rez2Phy will use for clusters the new 5th column of st3 if you run this) +% rez = merge_posthoc2(rez); + +% save matlab results file +save(fullfile(ops.root, 'rez.mat'), 'rez', '-v7.3'); + +% save python results file for Phy +rezToPhy(rez, ops.root); + +% remove temporary file +delete(ops.fproc); +%% diff --git a/merge_posthoc2.m b/merge_posthoc2.m new file mode 100644 index 00000000..feee9eff --- /dev/null +++ b/merge_posthoc2.m @@ -0,0 +1,122 @@ +function rez = merge_posthoc2(rez) +%fracse = 0.1; +mu = rez.mu; +fracse = rez.ops.fracse; + +ops = rez.ops; +LAM = ops.lam(3) * (20./mu).^2; +Nfilt = rez.ops.Nfilt; + +Cmerge = Inf *ones(Nfilt); +tfi = rez.iNeigh; +tf = rez.cProj; +clusterIDs = rez.st3(:,2); +% +nSpikes = size(rez.st3,1); + +fmax = zeros(nSpikes,1, 'single'); +pairs = {}; +for testID = 1:Nfilt + spikesTest = clusterIDs==testID; +% tfnew = bsxfun(@plus, tf(spikesTest, :), LAM(tfi(:, testID))'.*mu(tfi(:, testID))'); +% tf(spikesTest, :) = bsxfun(@rdivide, tfnew, sqrt(1+LAM(tfi(:, testID)))'); + + pp = tfi(:, testID); + pp(pp==testID) = []; + pairs{testID} = pp; + [~, isame] = min( abs(tfi(:, testID)-testID)); + fmax(spikesTest, 1) = tf(spikesTest, isame); +end + + +% +inewclust = 0; +clear iMegaC +picked = zeros(Nfilt, 1); +% tic +while 1 + [maxseed, iseed] = max(rez.nbins(1:Nfilt) .* (1-picked), [], 1); +% [maxseed, iseed] = max(mu(1:Nfilt) .* (1-picked), [], 1); + if maxseed<500 + break; + end + picked(iseed) = 1; + % 21, 69, + % +% iseed = 410; + + run_list = [iseed]; + pair_list = pairs{iseed}; + strun = find(clusterIDs==iseed); + + + while ~isempty(pair_list) + % +% picked_pairs = rez.nbins(pair_list); + + [mmax, ipair] = max(rez.nbins(pair_list)); + + + if mmax<100 + break; + end + + ipair = pair_list(ipair); + + % + imm = ismember(tfi(:, ipair), run_list); + if sum(imm) + % + new_spikes = find(clusterIDs==ipair); + f1new = max(tf(new_spikes, imm), [], 2); + + f2new = fmax(new_spikes); + + f1old = fmax(strun); + f2old = NaN * ones(numel(f1old), 1, 'single'); + i0 = 0; + for j = 1:length(run_list) + ifeat = find(tfi(:, run_list(j))==ipair); + if ~isempty(ifeat) + f2old(i0 + (1:rez.nbins(run_list(j))),1) = ... + tf(clusterIDs==run_list(j), ifeat); + i0 = i0 + rez.nbins(run_list(j)); + end + end + + f1old(isnan(f2old))=[]; + f2old(isnan(f2old))=[]; + mo = merging_score(f1old - f2old, f1new-f2new, ops.fracse); + + + if mo<3 + strun = cat(1, strun, new_spikes); + run_list(end+1) = ipair; + picked(ipair) = 1; + if mmax>300 + pair_list = unique(cat(1, pair_list, pairs{ipair})); + pair_list(ismember(pair_list, run_list)) = []; + end + end + end + pair_list(pair_list==ipair) = []; + end + + inewclust = inewclust + 1; + + iMegaC{inewclust} = run_list; +% [sum(picked) run_list] +end + +% toc +% + +iMega = zeros(Nfilt, 1); +for i = 1:length(iMegaC) + iMega(iMegaC{i}) = iMegaC{i}(1); +end +rez.iMega = iMega; +rez.iMegaC = iMegaC; + + +rez.st3(:,5) = iMega(rez.st3(:,2)); \ No newline at end of file diff --git a/mergesplits/distance_betwxt.m b/mergesplits/distance_betwxt.m new file mode 100644 index 00000000..e6f6ad4e --- /dev/null +++ b/mergesplits/distance_betwxt.m @@ -0,0 +1,21 @@ +function [d2d, iY, drez] = distance_betwxt(dWU) +[nt0, Nchan, Nfilt] = size(dWU); + +dWU = reshape(dWU, nt0*Nchan, Nfilt); +d2d = dWU' * dWU; + +mu = sum(dWU.^2,1).^.5; +mu = mu'; + +muall2 = repmat(mu.^2, 1, Nfilt); +d2d = 1 - 2 * d2d./(1e-30 + muall2+ muall2'); + +d2d = 1- triu(1 - d2d, 1); + +[dMin, iY] = min(d2d, [], 1); + +drez = dMin; + + +end + diff --git a/mergesplits/merging_score.m b/mergesplits/merging_score.m new file mode 100644 index 00000000..5062e501 --- /dev/null +++ b/mergesplits/merging_score.m @@ -0,0 +1,24 @@ +function steps = merging_score(fold, fnew, fracse) + + +troughToPeakRatio = 3; + +l1 = min(fnew); +l2 = max(fold); + +se = (std(fold) + std(fnew))/2; +se25 = fracse * se; +b2 = [0:se25:-l1]; +b1 = [0:se25:l2]; + +hs1 = my_conv(histc(fold, b1), 1); +hs2 = my_conv(histc(-fnew, b2), 1); + +mmax = min(max(hs1), max(hs2)); + +m1 = ceil(mean(fold)/se25); +m2 = -ceil(mean(fnew)/se25); + +steps = sum(hs1(1:m1)splitT); + +mu = sum(sum(dWUtot.^2,1),2).^.5; +mu = mu(:); +freeInd = find(nSpikes<200 | mu'<10 | isnan(mu')); + +for k = 1:nmerged + % merge the two clusters + iMerged = iY(isort(k)); + wt = [nSpikes(iMerged); nSpikes(isort(k))]; + wt = wt/sum(wt); +% mu(iMerged) = [mu(iMerged) mu(isort(k))] * wt; + + dWUtot(:,:,iMerged) = dWUtot(:,:,iMerged) * wt(1) + dWUtot(:,:,isort(k)) * wt(2); + dWUtot(:,:,isort(k)) = 1e-10; + + nspikes(iMerged, :) = nspikes(iMerged, :) + nspikes(isort(k), :); + nspikes(isort(k), :) = 0; +end + + +for k = 1:min(nmerged+numel(freeInd), nsplit) + if k<=numel(freeInd) + inew= freeInd(k); + else + inew = isort(k - numel(freeInd)); + end + + mu0 = mu(iY1(k)); + + % split the bimodal cluster, overwrite merged cluster + mu(inew) = mu1(k); + mu(iY1(k)) = mu2(k); + + dbins(:, inew) = u1(:, k) /Nbatch; + dbins(:, iY1(k)) = u2(:, k) /Nbatch; + + nspikes(inew, :) = nspikes(iY1(k), :)/2; + nspikes(iY1(k), :) = nspikes(iY1(k), :)/2; + dWUtot(:,:,inew) = mu1(k)/mu0 * dWUtot(:,:,iY1(k)); %/npm(iY1(k)); + dWUtot(:,:,iY1(k)) = mu2(k)/mu0 * dWUtot(:,:,iY1(k)); %/npm(iY1(k)); +end + +d2d = pairwise_dists(dWUtot, WUinit); +dmatch = min(d2d, [], 1); + +[~, inovel] = sort(dmatch, 'descend'); +% inovel = find(dmatch(1:1000)>.4); +% inovel = inovel(randperm(numel(inovel))); + +i0 = 0; + +for k = 1+min(nmerged+numel(freeInd), nsplit):nmerged+numel(freeInd) + % add new clusters + i0 = i0 + 1; + if i0>numel(inovel) + break; + end + if k<=numel(freeInd) + inew= freeInd(k); + else + inew = isort(k - numel(freeInd)); + end + + dbins(:, inew) = 1; + + nspikes(inew, :) = 1/8; + + + dWUtot(:,:,inew) = WUinit(:,:,inovel(i0)); %ratio * mu1(k)/mu0 * dWUtot(:,:,iY1(k)); + +end + +nswitch = [min(nmerged, nsplit) i0]; %min(nmerged+numel(freeInd), nsplit); + diff --git a/mergesplits/split_clust.m b/mergesplits/split_clust.m new file mode 100644 index 00000000..843b8fc4 --- /dev/null +++ b/mergesplits/split_clust.m @@ -0,0 +1,78 @@ +function [score, iY, mu1, mu2, u1, u2] = split_clust(uu, nhist) + +nhist = nhist(:); + +nspikes = sum(uu, 1); + +uc = zeros(size(uu)); +for i = 1:size(uu,2) + uc(:,i) = my_conv(uu(:,i)', max(.5, min(4, 2000/nspikes(i))))'; %.5 +% uc(:,i) = my_conv2(uu(:,i), max(.25, min(4, 2000/nspikes(i))), 1); +end +% +uc = uc ./repmat(sum(uc,1),size(uc,1), 1); +ucum = cumsum(uc, 1); +% +dd = diff(uc, 1); + +iY = zeros(1000,1); +mu1 = zeros(1000,1); +mu2 = zeros(1000,1); +var1 = zeros(1000,1); +var2 = zeros(1000,1); +u1 = zeros(size(uu,1), 1000); +u2 = zeros(size(uu,1), 1000); + +maxM = max(uc, [], 1); + +inew = 0; + +Nfilt = size(uu,2); +mu0 = sum(repmat(nhist(1:100, 1), 1, Nfilt) .* uc, 1); +var0 = sum((repmat(nhist(1:100), 1, Nfilt) - repmat(mu0, 100, 1)).^2 .* uc, 1); + +for i = 1:Nfilt + ix = find(dd(1:end-1, i)<0 & dd(2:end, i)>0); + + ix = ix(ucum(ix, i)>.1 & ucum(ix, i)<.8 & uc(ix,i)<.8 * maxM(i)); %.9 not .95 + if nspikes(i) > 500 && numel(ix)>0 + ix = ix(1); + + inew = inew + 1; + + normuc = sum(uc(1:ix, i)); + mu1(inew) = sum(nhist(1:ix) .* uc(1:ix, i)) /normuc; + mu2(inew) = sum(nhist(1+ix:100) .* uc(1+ix:100, i))/(1-normuc); + + var1(inew) = sum((nhist(1:ix)-mu1(inew)).^2 .* uc(1:ix, i)) /normuc; + var2(inew) = sum((nhist(1+ix:100)-mu2(inew)).^2 .* uc(1+ix:100, i))/(1-normuc); + + u1(1:ix,inew) = uu(1:ix, i); + u2(1+ix:100,inew) = uu(1+ix:100, i); + + iY(inew) = i; + end + +end + +mu1 = mu1(1:inew); +mu2 = mu2(1:inew); +var1 = var1(1:inew); +var2 = var2(1:inew); +u1 = u1(:,1:inew); +u2 = u2(:,1:inew); + +n1 = sum(u1,1)'; +n2 = sum(u2,1)'; +iY = iY(1:inew); + +score = 1 - (n1.*var1 + n2.*var2)./((n1+n2).*var0(iY)'); +% score = ((n1+n2).*var0(iY)' - (n1.*var1 + n2.*var2))./var0(iY)'; +[~, isort] = sort(score, 'descend'); + +iY = iY(isort); +mu1 = mu1(isort); +mu2 = mu2(isort); +u1 = u1(:,isort); +u2 = u2(:,isort); +score = score(isort); diff --git a/preProcess/convertOpenEphysToRawBInary.m b/preProcess/convertOpenEphysToRawBInary.m new file mode 100644 index 00000000..e7d1d90f --- /dev/null +++ b/preProcess/convertOpenEphysToRawBInary.m @@ -0,0 +1,71 @@ +function ops = convertOpenEphysToRawBInary(ops) + +fname = fullfile(ops.root, sprintf('%s.dat', ops.fbinary)); +fidout = fopen(fname, 'w'); +% +clear fs +for j = 1:ops.Nchan + fs{j} = dir(fullfile(ops.root, sprintf('*CH%d_*.continuous', j) )); +end +nblocks = cellfun(@(x) numel(x), fs); +if numel(unique(nblocks))>1 + error('different number of blocks for different channels!') +end +% +nBlocks = unique(nblocks); +nSamples = 1024; % fixed to 1024 for now! + +fid = cell(ops.Nchan, 1); + +tic +for k = 1:nBlocks + for j = 1:ops.Nchan + fid{j} = fopen(fullfile(ops.root, fs{j}(k).name)); + % discard header information + fseek(fid{j}, 1024, 0); + end + % + nsamps = 0; + flag = 1; + while 1 + samples = zeros(nSamples * 1000, ops.Nchan, 'int16'); + for j = 1:ops.Nchan + collectSamps = zeros(nSamples * 1000, 1, 'int16'); + + rawData = fread(fid{j}, 1000 * (nSamples + 6), '1030*int16', 10, 'b'); + + nbatches = ceil(numel(rawData)/(nSamples+6)); + for s = 1:nbatches + rawSamps = rawData((s-1) * (nSamples + 6) +6+ [1:nSamples]); + collectSamps((s-1)*nSamples + [1:nSamples]) = rawSamps; + end + samples(:,j) = collectSamps; + end + + if nbatches<1000 + flag = 0; + end + if flag==0 + samples = samples(1:s*nSamples, :); + end + + samples = samples'; + fwrite(fidout, samples, 'int16'); + + nsamps = nsamps + size(samples,2); + + if flag==0 + break; + end + end + ops.nSamplesBlocks(k) = nsamps; + + for j = 1:ops.Nchan + fclose(fid{j}); + end + +end + +fclose(fidout); + +toc \ No newline at end of file diff --git a/preProcess/get_PCproj.m b/preProcess/get_PCproj.m new file mode 100644 index 00000000..fb0fef3f --- /dev/null +++ b/preProcess/get_PCproj.m @@ -0,0 +1,18 @@ +function Us = get_PCproj(S1, row, col, wPCA, maskMaxChans) + +[nT, nChan] = size(S1); +dt = -21 + [1:size(wPCA,1)]; +inds = repmat(row', numel(dt), 1) + repmat(dt', 1, numel(row)); + +clips = reshape(S1(inds, :), numel(dt), numel(row), nChan); + + +mask = repmat([1:nChan], [numel(row) 1]) - repmat(col, 1, nChan); +Mask(1,:,:) = abs(mask)= nt1-1 + imin = imin+1; + end + for i = [imin:icurrent-1 icurrent+1:imax-1] + if (Mask(nt0 + (st0(i) - st0(icurrent)), id0(icurrent), id0(i))) + isol(icurrent) = false; + break; + end + end + + icurrent = icurrent + 1; +end diff --git a/preProcess/isolated_peaks.m b/preProcess/isolated_peaks.m new file mode 100644 index 00000000..15ad9948 --- /dev/null +++ b/preProcess/isolated_peaks.m @@ -0,0 +1,12 @@ +function [row, col, mu] = isolated_peaks(S1, loc_range, long_range, Th) +% loc_range = [3 1]; +% long_range = [30 6]; +smin = my_min(S1, loc_range, [1 2]); +peaks = single(S11e-6); + xc = xcoords(connected>1e-6); + yc = ycoords(connected>1e-6); + catch + chanMapConn = 1+chanNums(connected>1e-6); + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; + end + ops.Nchan = getOr(ops, 'Nchan', sum(connected>1e-6)); + ops.NchanTOT = getOr(ops, 'NchanTOT', numel(connected)); + if exist('fs', 'var') + ops.fs = getOr(ops, 'fs', fs); + end + else + chanMap = ops.chanMap; + chanMapConn = ops.chanMap; + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; + connected = true(numel(chanMap), 1); + + ops.Nchan = numel(connected); + ops.NchanTOT = numel(connected); + end +else + chanMap = 1:ops.Nchan; + connected = true(numel(chanMap), 1); + + chanMapConn = 1:ops.Nchan; + xc = zeros(numel(chanMapConn), 1); + yc = [1:1:numel(chanMapConn)]'; +end +if exist('kcoords', 'var') + kcoords = kcoords(connected); +else + kcoords = ones(ops.Nchan, 1); +end +NchanTOT = ops.NchanTOT; +NT = ops.NT ; + +rez.ops = ops; +rez.xc = xc; +rez.yc = yc; +if exist('xcoords') + rez.xcoords = xcoords; + rez.ycoords = ycoords; +else + rez.xcoords = xc; + rez.ycoords = yc; +end +rez.connected = connected; +rez.ops.chanMap = chanMap; +rez.ops.kcoords = kcoords; + +d = dir(ops.fbinary); +ops.sampsToRead = floor(d.bytes/NchanTOT/2); + +if ispc + dmem = memory; + memfree = dmem.MemAvailableAllArrays/8; + memallocated = min(ops.ForceMaxRAMforDat, dmem.MemAvailableAllArrays) - memfree; + memallocated = max(0, memallocated); +else + memallocated = ops.ForceMaxRAMforDat; +end +nint16s = memallocated/2; + +NTbuff = NT + 4*ops.ntbuff; +Nbatch = ceil(d.bytes/2/NchanTOT /(NT-ops.ntbuff)); +Nbatch_buff = floor(4/5 * nint16s/rez.ops.Nchan /(NT-ops.ntbuff)); % factor of 4/5 for storing PCs of spikes +Nbatch_buff = min(Nbatch_buff, Nbatch); + +%% load data into patches, filter, compute covariance +if isfield(ops,'fslow')&&ops.fslowsize(uproj,1) + uproj(1e6 + size(uproj,1), 1) = 0; + end + + uproj(i0 + (1:numel(row)), :) = gather_try(uS); + i0 = i0 + numel(row); + end + + if ibatch<=Nbatch_buff + DATA(:,:,ibatch) = gather_try(datr); + else + datcpu = gather_try(int16(datr)); + fwrite(fidW, datcpu, 'int16'); + end + +end + +if strcmp(ops.initialize, 'fromData') + uproj(i0+1:end, :) = []; +end +Wrot = gather_try(Wrot); +rez.Wrot = Wrot; + +fclose(fidW); +fclose(fid); +if ops.verbose + fprintf('Time %3.2f. Whitened data written to disk... \n', toc); + fprintf('Time %3.2f. Preprocessing complete!\n', toc); +end + + +rez.temp.Nbatch = Nbatch; +rez.temp.Nbatch_buff = Nbatch_buff; + diff --git a/tests/gather_mean_spikes.m b/tests/gather_mean_spikes.m new file mode 100644 index 00000000..27153774 --- /dev/null +++ b/tests/gather_mean_spikes.m @@ -0,0 +1,81 @@ +tic +if ~isempty(ops.chanMap) + load(ops.chanMap); + chanMapConn = chanMap(connected>1e-6); +else + chanMapConn = 1:ops.Nchan; +end +batch_path = fullfile(root, 'batches'); +if ~exist(batch_path, 'dir') + mkdir(batch_path); +end +NchanTOT = ops.NchanTOT; + +d = dir(fullfile(root, fname)); +ops.sampsToRead = floor(d.bytes/NchanTOT/2); + + +NT = 128*1024+ ops.ntbuff; +NTbuff = NT + 4*ops.ntbuff; +Nbatch = ceil(d.bytes/2/NchanTOT /(NT-ops.ntbuff)); + +% load data into patches, filter, compute covariance, write back to +% disk + +fprintf('Time %3.0fs. Loading raw data... \n', toc); +fid = fopen(fullfile(root, fname), 'r'); +ibatch = 0; +Nchan = ops.Nchan; + +Nchans = ops.Nchan; +ts = [1:1:nt0]'; + +clear stimes +for iNN = 1:size(rez.W,2) + stimes{iNN} = rez.st3(rez.st3(:,2)==iNN,1); +end +%stimes = gtimes; + +Wraw = zeros(nt0, Nchans, numel(stimes)); +for ibatch = 1:Nbatch + if ibatch>Nbatch_buff + offset = 2 * ops.Nchan*batchstart(ibatch-Nbatch_buff); % - ioffset; + fseek(fid, offset, 'bof'); + dat = fread(fid, [NT ops.Nchan], '*int16'); + else + dat = DATA(:,:,ibatch); + end + dataRAW = gpuArray(dat); + dataRAW = single(dataRAW); + dataRAW = dataRAW / ops.scaleproc; + + + if ibatch==1; ioffset = 0; + else ioffset = ops.ntbuff; + end + + for iNN = 1:numel(stimes) + st = stimes{iNN} + ioffset - (NT-ops.ntbuff)*(ibatch-1) - 20; + st(st<0) = []; + st(st>NT-ops.ntbuff) = []; + + if ~isempty(st) + inds = repmat(st', nt0, 1) + repmat(ts, 1, numel(st)); + + Wraw(:,:,iNN) = Wraw(:,:,iNN) + ... + gather_try(squeeze(sum(reshape(dataRAW(inds, :), nt0, numel(st), Nchans),2))); + end + end + +end + +for iNN = 1:numel(stimes) + Wraw(:,:,iNN) = Wraw(:,:,iNN)/numel(stimes{iNN}); +end +fprintf('Time %3.2f. Mean waveforms computed... \n', toc); + + + + + + diff --git a/tests/gather_raw_mean_spikes.m b/tests/gather_raw_mean_spikes.m new file mode 100644 index 00000000..6b4eefd1 --- /dev/null +++ b/tests/gather_raw_mean_spikes.m @@ -0,0 +1,92 @@ +tic +nt0 = ops.nt0; + +if ~isempty(ops.chanMap) + load(ops.chanMap); + chanMapConn = chanMap(connected>1e-6); +else + chanMapConn = 1:ops.Nchan; +end +batch_path = fullfile(root, 'batches'); +if ~exist(batch_path, 'dir') + mkdir(batch_path); +end +NchanTOT = ops.NchanTOT; + +d = dir(fullfile(root, fname)); +ops.sampsToRead = floor(d.bytes/NchanTOT/2); + + +NT = 128*1024+ ops.ntbuff; +NTbuff = NT + 4*ops.ntbuff; +Nbatch = ceil(d.bytes/2/NchanTOT /(NT-ops.ntbuff)); + +% load data into patches, filter, compute covariance, write back to +% disk + +fprintf('Time %3.0fs. Loading raw data... \n', toc); +fid = fopen(fullfile(root, fname), 'r'); +ibatch = 0; +Nchan = ops.Nchan; + +Nchans = ops.Nchan; +ts = [1:1:nt0]'; + +clear stimes +% for iNN = 1:size(rez.W,2) +% stimes{iNN} = rez.st3pos(rez.st3pos(:,2)==iNN,1); +% end +stimes = gtimes; + +Wraw = zeros(nt0, Nchans, numel(stimes)); + +while 1 + ibatch = ibatch + 1; + + offset = max(0, 2*NchanTOT*((NT - ops.ntbuff) * (ibatch-1) - 2*ops.ntbuff)); + if ibatch==1 + ioffset = 0; + else + ioffset = ops.ntbuff; + end + fseek(fid, offset, 'bof'); + buff = fread(fid, [NchanTOT NTbuff], '*int16'); + + if isempty(buff) + break; + end + nsampcurr = size(buff,2); + if nsampcurrNT-ops.ntbuff) = []; + + if ~isempty(st) + inds = repmat(st', nt0, 1) + repmat(ts, 1, numel(st)); + + Wraw(:,:,iNN) = Wraw(:,:,iNN) + ... + squeeze(sum(reshape(buff(inds, :), nt0, numel(st), Nchans),2)); + end + end + +end +fclose(fid); + +for iNN = 1:numel(stimes) + Wraw(:,:,iNN) = Wraw(:,:,iNN)/numel(stimes{iNN}); +end +fprintf('Time %3.2f. Mean waveforms computed... \n', toc); + + + + + + diff --git a/tests/plotPCcoefs.m b/tests/plotPCcoefs.m new file mode 100644 index 00000000..475b8e32 --- /dev/null +++ b/tests/plotPCcoefs.m @@ -0,0 +1,4 @@ +clustID = clustID+1; +iSp = find(st3(:,2)==clustID); + +plot(rez.cProjPC(iSp+1,1,1), rez.cProjPC(iSp+1,1,2), '.') \ No newline at end of file diff --git a/tests/plot_final_waveforms.m b/tests/plot_final_waveforms.m new file mode 100644 index 00000000..1fa6c07b --- /dev/null +++ b/tests/plot_final_waveforms.m @@ -0,0 +1,79 @@ +%% +isV1pyr = rez.nbins(1:Nfilt)> 1000 & rez.ypos'>0 & rez.t2p(:,1)>10; +isV1pv = rez.nbins(1:Nfilt)> 1000 & rez.ypos'>0 & rez.t2p(:,1)<=10; +W0 = alignW(W(:,:,1)); + +figure(1) +hist(rez.t2p(rez.nbins>1000,1), 1:1:nt0) +xlabel('trough to peak of templates with >1000 spikes') +set(gcf, 'Color', 'w') +ylabel('number of templates') +% export_fig('fig1.pdf') + +figure(2) +which_cells = find(isV1pyr); +[~, isort] = sort(rez.ypos(which_cells), 'ascend'); +which_cells = which_cells(isort); +subplot(1,2,1) +imagesc(U(:, which_cells,1)) +title('RS templates: spatial profile') +colormap('gray') +subplot(1,2,2) +plot(W0(:, which_cells,1)) +axis tight +title('temporal profile') +set(gcf, 'Color', 'w') +% export_fig('fig2.pdf') + +figure(3) +which_cells = find(isV1pv); +[~, isort] = sort(rez.ypos(which_cells), 'ascend'); +which_cells = which_cells(isort); +subplot(1,2,1) +imagesc(U(:, which_cells,1)) +colormap('gray') + +title('FS templates: spatial profile') +subplot(1,2,2) +plot(W0(:, which_cells,1)) +axis tight +title('temporal profile') +set(gcf, 'Color', 'w') + +% export_fig('fig3.pdf') +%% bad iNN 121, 478 +iscell = rez.nbins(1:Nfilt)> 300; + +[~, isortmu] = sort(mu, 'descend'); + +for i = 1:Nfilt + + ts = -.5:25:2000; + iNN = isortmu(i); + uu = histc(diff(st3pos(st3pos(:,2)==iNN, 1)), ts); + subplot(2,2,1) + bar([-fliplr(ts) ts(2:end)]', [flipud(uu); uu(2:end)]) + axis tight + title(mu(iNN)) + + subplot(2,2,2) + hist(rez.st3pos(rez.st3pos(:,2)==iNN,3), 100) + title(mu(iNN)) + axis tight + + [mWmax, iNN2] = max(mWtW(iNN, 1:Nfilt)); + uu = histc(diff(st3pos(st3pos(:,2)==iNN2, 1)), ts); + subplot(2,2,4) + bar([-fliplr(ts) ts(2:end)]', [flipud(uu); uu(2:end)]) + axis tight + title(mu(iNN2)) + + uu = histc(diff(st3pos(st3pos(:,2)==iNN2 | st3pos(:,2)==iNN, 1)), ts); + subplot(2,2,3) + bar([-fliplr(ts) ts(2:end)]', [flipud(uu); uu(2:end)]) + axis tight + title(mWmax) + + pause +end + diff --git a/tests/plot_waveforms.m b/tests/plot_waveforms.m new file mode 100644 index 00000000..0f9e2f34 --- /dev/null +++ b/tests/plot_waveforms.m @@ -0,0 +1,13 @@ +function plot_waveforms(W) + +W = W - repmat(mean(W,1), size(W,1), 1); + +uu = max(abs(W(:))); + +for i = 1:size(W,2) + plot(i*uu/24 + W(:,i), 'k') + hold all +end +hold off + +axis tight \ No newline at end of file diff --git a/tests/plot_waveforms2.m b/tests/plot_waveforms2.m new file mode 100644 index 00000000..3e54b94f --- /dev/null +++ b/tests/plot_waveforms2.m @@ -0,0 +1,92 @@ +function plot_waveforms2(rez,clusts,s,Wraw) +close all; +ops=rez.ops; + +if nargin<3 + s=[]; + Wraw=[]; +end + +if isempty(clusts) + clusts=1:ops.Nfilt; +end + +for k =clusts + ik = find(rez.st3(:,2)==k); + if length(ik)>5 + figure(k);clf; + colormap('parula') + subplot(1,3,1) + wa = squeeze(rez.U(:,k,:)) * squeeze(rez.W(:,k,:))'; + wa = 200 * wa; + mW=wa(:,21:end); + t=(1:size(mW,2))/ops.fs*1e3; + imagesc(t,(1:size(mW,1)),mW) + title('mean waveform') + xlabel('time (ms)') + axis tight + subplot(2,3,2) + m = sort(rez.st3(ik,3), 'descend');%amplitudes + plot(m,'linewidth',2) + xlabel('sorted spikes') + ylabel('amplitudes') + title('CDF of Amplitude') + axis tight + if ~isempty(s) + subplot(2,3,5) + plot(Wraw(:,:,k),'linewidth',2) + xlabel('time sample') + ylabel('voltage (\muV)') + title('mean waveform') + axis tight + subplot(2,3,3) + Mmax=max(abs(wa), [], 2);%get the max amplitude in each channel + [~, ichan] = max(Mmax); + imagesc(squeeze(s{k,2}(:,:,ichan))') + ylabel('spike index') + xlabel('time sample') + title('Waveform Stability') + axis tight + subplot(2,3,6) + amp=max(squeeze(s{k,2}(:,:,ichan))); + plot(s{k,1}/ops.fs,amp,'.'); + maxAmp(k)=mean(amp); + else + subplot(2,3,5) + plot(t,wa(:,21:end)') + xlabel('time (ms)') + ylabel('voltage (\muV)') + title('mean waveform') + axis tight + tSpike=rez.st3(ik,1)/ops.fs; + subplot(2,3,6) + maxAmp(k)=mean(rez.st3(ik,3)); + plot(tSpike,rez.st3(ik,3),'.'); + axis tight; + xlabel('time (s)') + ylabel('amplitude (\muV)') + title('Spike Raster') + yy=ylim; + ylim([0,yy(2)]) + subplot(2,3,3) + width=.03;%in s + binSize=.0003;%in s + newFs=1/binSize; + % resample at a lower Fs in tune with bins + Sinds=round(tSpike*newFs); + tSpike2=zeros(1,max(Sinds));%this is more memory intensive but faster than a for loop + tSpike2(Sinds)=1; + [c,lags]=xcorr(tSpike2,newFs*width); + c(ceil(end/2))=NaN; + c=c/max(c); + plot(lags*binSize*1e3,c) + xlabel('time (ms)') + ylabel('corr') + title('AutoCorrelogram') + axis tight; + yy=ylim; + ylim([0,yy(2)]); + end + suptitle(sprintf('cluster #%d', k)) + end +end diff --git a/tests/plots_of_template_coefficients.m b/tests/plots_of_template_coefficients.m new file mode 100644 index 00000000..3a5acc49 --- /dev/null +++ b/tests/plots_of_template_coefficients.m @@ -0,0 +1,67 @@ +testID = 156; +mu = rez.mu; +clusterIDs = rez.st3(:,2); +% tfi = rez.iNeigh; +% tf = rez.cProj; + +spikesTest = find(clusterIDs==testID); + +simIDs = rez.iNeigh(:,testID); +% +figure(2) +clf +figure(1) +clf + +LAM = ops.lam(3) * (20./mu).^2; + +nSP = ceil(sqrt(length(simIDs))); +for s = 1:length(simIDs) + simS_T = find(rez.iNeigh(:,simIDs(s))==testID); + spikesSim = find(clusterIDs==simIDs(s)); + + if simIDs(s)~=testID && numel(spikesSim)>20 && ~isempty(simS_T) + figure(2) + subplot(4, 8, s); + + plot(rez.cProj(spikesTest,1), rez.cProj(spikesTest,s), '.') + hold on; + plot(rez.cProj(spikesSim,simS_T), rez.cProj(spikesSim,1), '.') + + title(sprintf('%d vs %d', testID, simIDs(s))) + axis tight + + figure(1) + subplot(nSP/2, 2*nSP, s); + + ft1 = [rez.cProj(spikesTest,1); rez.cProj(spikesSim,simS_T)]; + ft2 = [rez.cProj(spikesTest,s); rez.cProj(spikesSim,1)]; + + ft1 = (ft1 + LAM(testID) * mu(testID)) / sqrt(1 + LAM(testID)); + ft2 = (ft2 + LAM(simIDs(s)) * mu(simIDs(s))) / sqrt(1 + LAM(simIDs(s))); + + df = ft1 - ft2; + l1 = min(df(:)); + l2 = max(df(:)); +% bins = linspace(l1, l2, 100); + df1 = df(1:numel(spikesTest)); + df2 = df(1+numel(spikesTest):end); + + se = (std(df1) + std(df2))/2; + se25 = se/10; + b2 = [0:se25:-l1]; + b1 = [0:se25:l2]; + + hs1 = my_conv(histc(df1, b1), 1); + hs2 = my_conv(histc(-df2, b2), 1); + + + mlow = min(max(hs1(1), hs1(2)), max(hs2(1), hs2(2))); + plot(b1, hs1, 'Linewidth', 2) + hold on + plot(-b2, hs2, 'Linewidth', 2) + hold off + axis tight + title({sprintf('%d vs %d', testID, simIDs(s)), sprintf('%2.2f and %2.2f', max(hs2)/mlow, max(hs1)/mlow)}) + end +end \ No newline at end of file diff --git a/tests/testCodeFromPhy.m b/tests/testCodeFromPhy.m new file mode 100644 index 00000000..a7697f25 --- /dev/null +++ b/tests/testCodeFromPhy.m @@ -0,0 +1,41 @@ +addpath('D:\DATA\Spikes\EvaluationCode') +% datFilename = '20150601_all.dat'; +%% +dd = load('C:\DATA\Spikes\set7\20151102_1_ks_results.mat'); +root = fullfile('C:\DATA\Spikes', sprintf('set%d', idset)); + +fname = 'spike_clustersSorted.npy'; + +gtCluMan = readNPY(fullfile(root, fname)); + +totM = zeros(1000, 1); +for i = 1:1000 + totM(i) = numel(unique(dd.rez.st3(gtCluMan==i, 2))); + nsp(i) = sum(gtCluMan==i); +end +%% +igt = ismember(gtCluMan, find(totM==1 & nsp>1000)); + +% igt = ismember(gtCluMan, 563); + +gtRes = double(dd.rez.st3(igt, 1)); +gtClu = double(gtCluMan(igt, 1)); +%% +testRes = st3(:,1); +testClu = st3(:,2); + +[allScores, allFPrates, allMissRates, allMerges] = ... + compareClustering(gtClu, gtRes, testClu, testRes, []); +% +clid = unique(gtClu); +clear gtimes +for k = 1:length(clid) + gtimes{k} = double(gtRes(gtClu==clid(k))); +end + + +%% +bestPostMerge = []; +for j = 1:length(allScores) + bestPostMerge(j) = allScores{j}(end); +end \ No newline at end of file diff --git a/utils/gather_try.m b/utils/gather_try.m new file mode 100644 index 00000000..d399fba5 --- /dev/null +++ b/utils/gather_try.m @@ -0,0 +1,6 @@ +function x = gather_try(x) + +try + x = gather(x); +catch +end \ No newline at end of file diff --git a/utils/getOr.m b/utils/getOr.m new file mode 100644 index 00000000..539214ad --- /dev/null +++ b/utils/getOr.m @@ -0,0 +1,26 @@ +function v = getOr(s, field, default) +%getOr Returns the structure field or a default if either don't exist +% v = getOr(s, field, [default]) returns the 'field' of the structure 's' +% or 'default' if the structure is empty of the field does not exist. If +% default is not specified it defaults to []. 'field' can also be a cell +% array of fields, in which case it will check for all of them and return +% the value of the first field that exists, if any (otherwise the default +% value). + +if nargin < 3 + default = []; +end + +fieldExists = isfield(s, field); +if any(fieldExists) + if iscellstr(field) + v = s.(field{find(fieldExists, 1)}); + else + v = s.(field); + end +else + v = default; +end + +end + diff --git a/utils/my_conv.m b/utils/my_conv.m new file mode 100644 index 00000000..02a7bf0d --- /dev/null +++ b/utils/my_conv.m @@ -0,0 +1,20 @@ +function Smooth = my_conv(S1, sig, varargin) + +NN = size(S1,1); +NT = size(S1,2); + +dt = -4*sig:1:4*sig; +gaus = exp( - dt.^2/(2*sig^2)); +gaus = gaus'/sum(gaus); + +% Norms = conv(ones(NT,1), gaus, 'same'); +%Smooth = zeros(NN, NT); +%for n = 1:NN +% Smooth(n,:) = (conv(S1(n,:)', gaus, 'same')./Norms)'; +%end + +Smooth = filter(gaus, 1, [S1' ones(NT,1); zeros(ceil(4*sig), NN+1)]); +Smooth = Smooth(1+ceil(4*sig):end, :); +Smooth = Smooth(:,1:NN) ./ (Smooth(:, NN+1) * ones(1,NN)); + +Smooth = Smooth'; \ No newline at end of file diff --git a/utils/my_conv2.m b/utils/my_conv2.m new file mode 100644 index 00000000..f2a7ac7b --- /dev/null +++ b/utils/my_conv2.m @@ -0,0 +1,55 @@ +function S1 = my_conv2(S1, sig, varargin) +% takes an extra argument which specifies which dimension to filter on +% extra argument can be a vector with all dimensions that need to be +% smoothed, in which case sig can also be a vector of different smoothing +% constants + +if sig>.25 + idims = 2; + if ~isempty(varargin) + idims = varargin{1}; + end + if numel(idims)>1 && numel(sig)>1 + sigall = sig; + else + sigall = repmat(sig, numel(idims), 1); + end + + for i = 1:length(idims) + sig = sigall(i); + + idim = idims(i); + Nd = ndims(S1); + + S1 = permute(S1, [idim 1:idim-1 idim+1:Nd]); + + dsnew = size(S1); + + S1 = reshape(S1, size(S1,1), []); + dsnew2 = size(S1); + + % NN = size(S1,1); + % NT = size(S1,2); + + tmax = ceil(4*sig); + dt = -tmax:1:tmax; + gaus = exp( - dt.^2/(2*sig^2)); + gaus = gaus'/sum(gaus); + + % Norms = conv(ones(NT,1), gaus, 'same'); + % Smooth = zeros(NN, NT); + % for n = 1:NN + % Smooth(n,:) = (conv(S1(n,:)', gaus, 'same')./Norms)'; + % end + + cNorm = filter(gaus, 1, cat(1, ones(dsnew2(1), 1), zeros(tmax,1))); + cNorm = cNorm(1+tmax:end, :); + S1 = filter(gaus, 1, cat(1, S1, zeros([tmax, dsnew2(2)]))); + S1(1:tmax, :) = []; + S1 = reshape(S1, dsnew); + + S1 = bsxfun(@rdivide, S1, cNorm); + + S1 = permute(S1, [2:idim 1 idim+1:Nd]); + end +end \ No newline at end of file diff --git a/utils/my_inv.m b/utils/my_inv.m new file mode 100644 index 00000000..4a11d4ae --- /dev/null +++ b/utils/my_inv.m @@ -0,0 +1,7 @@ +function Minv = my_inv(M, eps) + +[U, Sv, V] = svd(M); + +Sv = max(Sv, eps); + +Minv = U * diag(1./diag(Sv)) * V'; \ No newline at end of file diff --git a/utils/my_min.m b/utils/my_min.m new file mode 100644 index 00000000..677ec560 --- /dev/null +++ b/utils/my_min.m @@ -0,0 +1,39 @@ +function S1 = my_min(S1, sig, varargin) +% takes an extra argument which specifies which dimension to filter on +% extra argument can be a vector with all dimensions that need to be +% smoothed, in which case sig can also be a vector of different smoothing +% constants + +idims = 2; +if ~isempty(varargin) + idims = varargin{1}; +end +if numel(idims)>1 && numel(sig)>1 + sigall = sig; +else + sigall = repmat(sig, numel(idims), 1); +end + +for i = 1:length(idims) + sig = sigall(i); + + idim = idims(i); + Nd = ndims(S1); + + S1 = permute(S1, [idim 1:idim-1 idim+1:Nd]); + + dsnew = size(S1); + + S1 = reshape(S1, size(S1,1), []); + dsnew2 = size(S1); + + S1 = cat(1, Inf*ones([sig, dsnew2(2)]),S1, Inf*ones([sig, dsnew2(2)])); + Smax = S1(1:dsnew2(1), :); + for j = 1:2*sig + Smax = min(Smax, S1(j + (1:dsnew2(1)), :)); + end + + S1 = reshape(Smax, dsnew); + + S1 = permute(S1, [2:idim 1 idim+1:Nd]); +end \ No newline at end of file diff --git a/utils/my_sum.m b/utils/my_sum.m new file mode 100644 index 00000000..abef711c --- /dev/null +++ b/utils/my_sum.m @@ -0,0 +1,39 @@ +function S1 = my_sum(S1, sig, varargin) +% takes an extra argument which specifies which dimension to filter on +% extra argument can be a vector with all dimensions that need to be +% smoothed, in which case sig can also be a vector of different smoothing +% constants + +idims = 2; +if ~isempty(varargin) + idims = varargin{1}; +end +if numel(idims)>1 && numel(sig)>1 + sigall = sig; +else + sigall = repmat(sig, numel(idims), 1); +end + +for i = 1:length(idims) + sig = sigall(i); + + idim = idims(i); + Nd = ndims(S1); + + S1 = permute(S1, [idim 1:idim-1 idim+1:Nd]); + + dsnew = size(S1); + + S1 = reshape(S1, size(S1,1), []); + dsnew2 = size(S1); + + S1 = cat(1, 0*ones([sig, dsnew2(2)]),S1, 0*ones([sig, dsnew2(2)])); + Smax = S1(1:dsnew2(1), :); + for j = 1:2*sig + Smax = Smax + S1(j + (1:dsnew2(1)), :); + end + + S1 = reshape(Smax, dsnew); + + S1 = permute(S1, [2:idim 1 idim+1:Nd]); +end \ No newline at end of file diff --git a/utils/normc.m b/utils/normc.m new file mode 100644 index 00000000..b7e68cf1 --- /dev/null +++ b/utils/normc.m @@ -0,0 +1,3 @@ +function v = normc(v) + +v = v./repmat(sum(v.^2, 1), size(v,1),1).^.5; \ No newline at end of file diff --git a/utils/sq.m b/utils/sq.m new file mode 100644 index 00000000..58c40bee --- /dev/null +++ b/utils/sq.m @@ -0,0 +1,4 @@ +function x = sq(x) + + +x = squeeze(x); \ No newline at end of file