Skip to content

Commit

Permalink
refinements
Browse files Browse the repository at this point in the history
  • Loading branch information
marius10p committed May 5, 2020
1 parent 742ea27 commit 947abbb
Show file tree
Hide file tree
Showing 12 changed files with 461 additions and 88 deletions.
Binary file added configFiles/NP2_kilosortChanMap.mat
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
46 changes: 24 additions & 22 deletions mainLoop/runTemplates.m
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,28 @@
Nchan = rez.ops.Nchan;
nt0 = rez.ops.nt0;

nKeep = min(Nchan*3,20); % how many PCs to keep
rez.W_a = zeros(nt0 * Nrank, nKeep, Nfilt, 'single');
rez.W_b = zeros(Nbatches, nKeep, Nfilt, 'single');
rez.U_a = zeros(Nchan* Nrank, nKeep, Nfilt, 'single');
rez.U_b = zeros(Nbatches, nKeep, Nfilt, 'single');
for j = 1:Nfilt
% do this for every template separately
WA = reshape(rez.WA(:, j, :, :), [], Nbatches);
WA = gpuArray(WA); % svd on the GPU was faster for this, but the Python randomized CPU version might be faster still
[A, B, C] = svdecon(WA);
% W_a times W_b results in a reconstruction of the time components
rez.W_a(:,:,j) = gather(A(:, 1:nKeep) * B(1:nKeep, 1:nKeep));
rez.W_b(:,:,j) = gather(C(:, 1:nKeep));
if 0
nKeep = min(Nchan*3,20); % how many PCs to keep
rez.W_a = zeros(nt0 * Nrank, nKeep, Nfilt, 'single');
rez.W_b = zeros(Nbatches, nKeep, Nfilt, 'single');
rez.U_a = zeros(Nchan* Nrank, nKeep, Nfilt, 'single');
rez.U_b = zeros(Nbatches, nKeep, Nfilt, 'single');
for j = 1:Nfilt
% do this for every template separately
WA = reshape(rez.WA(:, j, :, :), [], Nbatches);
WA = gpuArray(WA); % svd on the GPU was faster for this, but the Python randomized CPU version might be faster still
[A, B, C] = svdecon(WA);
% W_a times W_b results in a reconstruction of the time components
rez.W_a(:,:,j) = gather(A(:, 1:nKeep) * B(1:nKeep, 1:nKeep));
rez.W_b(:,:,j) = gather(C(:, 1:nKeep));

UA = reshape(rez.UA(:, j, :, :), [], Nbatches);
UA = gpuArray(UA);
[A, B, C] = svdecon(UA);
% U_a times U_b results in a reconstruction of the time components
rez.U_a(:,:,j) = gather(A(:, 1:nKeep) * B(1:nKeep, 1:nKeep));
rez.U_b(:,:,j) = gather(C(:, 1:nKeep));
end

UA = reshape(rez.UA(:, j, :, :), [], Nbatches);
UA = gpuArray(UA);
[A, B, C] = svdecon(UA);
% U_a times U_b results in a reconstruction of the time components
rez.U_a(:,:,j) = gather(A(:, 1:nKeep) * B(1:nKeep, 1:nKeep));
rez.U_b(:,:,j) = gather(C(:, 1:nKeep));
end

fprintf('Finished compressing time-varying templates \n')
fprintf('Finished compressing time-varying templates \n')
end
97 changes: 97 additions & 0 deletions preProcess/benchmarks_manipulator.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
% benchmark manipulator
addpath(genpath('D:\GitHub\KiloSort2')) % path to kilosort folder
addpath('D:\GitHub\npy-matlab')

pathToYourConfigFile = 'D:\GitHub\KiloSort2\configFiles'; % take from Github folder and put it somewhere else (together with the master_file)
run(fullfile(pathToYourConfigFile, 'configFile384.m'))
rootH = 'H:\';
ops.fproc = fullfile(rootH, 'temp_wh.dat'); % proc file on a fast SSD

root0 = 'F:\Spikes\manipulator';
chanmaps = {'neuropixPhase3B1_kilosortChanMap_all.mat', 'NP2_kilosortChanMap.mat'};
for j = 1:2
chanmaps{j} = fullfile(root0, chanmaps{j});
end

addpath('D:\Drive\CODE\KS2')
%%
dt = 1;
root0 = 'F:\Spikes\manipulator';
dr = [];
cc = [];
cc2 = {};
%%
for iprobe = 2
for isess = [3]
for ishift = [1 2 3]
rootZ = fullfile(root0, sprintf('drift%d', isess), sprintf('p%d', iprobe));
if ishift<3
fname = fullfile(rootZ, sprintf('rez2_%d.mat', ishift-1));
load(fname)
dr{iprobe, isess} = rez.row_shifts;
else
fname = fullfile(rootZ, 'rez_datashift.mat');
load(fname)
end

bin = ops.fs * dt;
st = rez.st3(:,1);
clu = rez.st3(:,2);
nspikes = length(st);
S = sparse(clu, ceil(st/bin), ones(nspikes,1));
igood = get_good_units(rez);
S = S(igood, :);
nbins = size(S,2);

root_drift = fullfile(root0, sprintf('drift%d', isess));
dy = readNPY(fullfile(root_drift, 'manip.positions.npy'));
ts = readNPY(fullfile(root_drift, sprintf('manip.timestamps_p%d.npy', iprobe)));
dy_samps = interp1(ts, dy, dt/2 + [0:dt:dt*(nbins-1)]);
dy_samps(isnan(dy_samps)) = 0;

ix = ceil(ts(2)/dt):ceil(ts(end)/dt);
ix2 = [1:ceil(ts(2)/dt) ceil(ts(end)/dt):nbins];

ix2(length(ix)+1:end) = [];

cc{iprobe, isess, ishift} = corr(dy_samps(ix)', S(:, ix)');
cc2{iprobe, isess, ishift} = corr(dy_samps(ix(1:length(ix2)))', S(:, ix2)');
end
end
end
%%
iprobe = 2;
isess = 3;
ishift = 1;
icc1 = cc{iprobe, isess, ishift};
icc2 = cc2{iprobe, isess, ishift};

cq = quantile(icc2, [.025, .975]);
NN = length(icc1);
nbad = sum(icc1<cq(1)) + sum(icc1>cq(2));
disp([NN-nbad, nbad, nbad/NN])
%%
csd = cellfun(@(x) mean(abs(x)), cc);
c2sd = cellfun(@(x) mean(abs(x)), cc2);
%%
sq(mean(csd, 2))
sq(mean(c2sd, 2))

%%
iprobe = 2;
isess = 3;
ishift = 1;
rootZ = fullfile(root0, sprintf('drift%d', isess), sprintf('p%d', iprobe));
fname = fullfile(rootZ, sprintf('rez2_%d.mat', ishift-1));
load(fname)
%%

imagesc(rez.ccb, [-5, 5])
%%
cellfun(@(x) numel(x), cc)



%%


68 changes: 5 additions & 63 deletions preProcess/clusterSingleBatches.m
Original file line number Diff line number Diff line change
Expand Up @@ -100,69 +100,18 @@
fprintf('time %2.2f, pre clustered %d / %d batches \n', toc, ibatch, nBatches)
end
end
%%
% find Z offsets
% anothr one of these Params variables transporting parameters to the C++ code
Params = [1 NrankPC Nfilt 0 size(W,1) 0 NchanNear Nchan];

if isfield(ops, 'midpoint')
splits = [0, ceil(ops.midpoint/ops.NT), nBatches];
else
splits = [0, nBatches];
end
for k = 1:length(splits)-1
ib = splits(k)+1:splits(k+1);
Params(1) = size(Ws,3) * length(ib); % the total number of templates is the number of templates per batch times the number of batches
[iminy{k}, ww{k}, Ns{k}] = find_integer_shifts(Params, Whs(:, ib),Ws(:,:,:,ib),...
mus(:, ib), ns(:,ib), iC, Nchan, Nfilt);
end

if isfield(ops, 'midpoint')
iChan = 1:Nchan;
iUp = mod(iChan + 2-1, Nchan)+1;
iDown = mod(iChan - 2-1, Nchan)+1;
iMap = [iUp(iUp); iUp; iChan; iDown; iDown(iDown)];


mu1 = 1e-5 + sq(sum(sum(ww{1}.^2, 1),2)).^.5;
mu2 = 1e-5 + sq(sum(sum(ww{2}.^2, 1),2)).^.5;

CC = gpuArray.zeros(Nfilt, Nfilt, 5);
for k = 1:5
W0 = ww{1};
for t = 1:Nfilt
W0(:,iMap(k,:),t) = W0(:,:,t);
end
X1 = reshape(W0, [nPCs * Nchan, Nfilt]);
X2 = reshape(ww{2}, [nPCs * Nchan, Nfilt]);
CC(:,:, k) = X1' * X2 ./ (mu1 * mu2');
% CC(:,:, k) = 2 * X1' * X2 - mu1.^2 - mu2'.^2;
end

csum = sq(mean(max(CC.* Ns{2}' , [], 1), 2));
% csum = sq(mean(max(CC , [], 1), 2));
[cmax, imax] = max(csum);
imin = cat(2, iminy{1}, iminy{2} + (imax-3));
imin = min(5, max(1, imin));

disp(imax)
else
imin = iminy{1};
end

%%
Params(1) = size(Ws,3) * size(Ws,4); % the total number of templates is the number of templates per batch times the number of batches

Whs2 = mod(Whs + 2 * int32(imin-3) - 1, Nchan) + 1;
rez.row_shifts = imin - 3;

tic

% initialize dissimilarity matrix
ccb = gpuArray.zeros(nBatches, 'single');

for ibatch = 1:nBatches
% for every batch, compute in parallel its dissimilarity to ALL other batches
Wh0 = single(Whs2(:, ibatch)); % this one is the primary batch
Wh0 = single(Whs(:, ibatch)); % this one is the primary batch
mu = mus(:, ibatch);

% embed the templates from the primary batch back into a full, sparse representation
Expand All @@ -175,7 +124,7 @@
iMatch = sq(min(abs(single(iC) - reshape(Wh0, 1, 1, [])), [], 1))<.1;

% compute dissimilarities for iMatch = 1
[iclust, ds] = mexDistances2(Params, Ws, W, iMatch, iC-1, Whs2-1, mus, mu);
[iclust, ds] = mexDistances2(Params, Ws, W, iMatch, iC-1, Whs-1, mus, mu);

% ds are squared Euclidian distances
ds = reshape(ds, Nfilt, []); % this should just be an Nfilt-long vector
Expand Down Expand Up @@ -215,13 +164,6 @@
rez.iorig = gather(iorig);
rez.ccbsort = gather(ccbsort);

% rez.iorig = randperm(nBatches);

fprintf('time %2.2f, Re-ordered %d batches. \n', toc, nBatches)
%%
nup = 0;
for ibatch = 1:nBatches
if abs(imin(ibatch) - 3) > 0
shift_batch_on_disk(rez, ibatch, imin(ibatch) - 3);
nup = nup + 1;
end
end
fprintf('time %2.2f, Shifted up/down %d batches. \n', toc, nup)
Loading

0 comments on commit 947abbb

Please sign in to comment.