Skip to content

Commit

Permalink
double alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
marius10p committed Apr 25, 2020
1 parent 66069e6 commit 742ea27
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 37 deletions.
Binary file not shown.
82 changes: 48 additions & 34 deletions preProcess/clusterSingleBatches.m
Original file line number Diff line number Diff line change
Expand Up @@ -100,42 +100,55 @@
fprintf('time %2.2f, pre clustered %d / %d batches \n', toc, ibatch, nBatches)
end
end

%%
% find Z offsets
% anothr one of these Params variables transporting parameters to the C++ code
Params = [1 NrankPC Nfilt 0 size(W,1) 0 NchanNear Nchan];
% splits = [0, 1108, nBatches];
ib = 1:nBatches;
Params(1) = size(Ws,3) * length(ib); % the total number of templates is the number of templates per batch times the number of batches
[imin, ww] = find_integer_shifts(Params, Whs(:, ib),Ws(:,:,:,ib),...
mus(:, ib), ns(:,ib), iC, Nchan, Nfilt);

%%
% iChan = 1:Nchan;
% iUp = mod(iChan + 2-1, Nchan)+1;
% iDown = mod(iChan - 2-1, Nchan)+1;
% iMap = [iUp(iUp); iUp; iChan; iDown; iDown(iDown)];
%
%
% mu1 = 1e-5 + sq(sum(sum(ww{1}.^2, 1),2));
% mu2 = 1e-5 + sq(sum(sum(ww{2}.^2, 1),2));
%
% CC = gpuArray.zeros(Nfilt, Nfilt, 5);
% for k = 1:5
% W0 = ww{1};
% for t = 1:Nfilt
% W0(:,iMap(k,:),t) = W0(:,:,t);
% end
% X1 = reshape(W0, [nPCs * Nchan, Nfilt]);
% X2 = reshape(ww{2}, [nPCs * Nchan, Nfilt]);
% CC(:,:, k) = X1' * X2 ./ (mu1 * mu2');
% end
%
% csum = sq(mean(max(CC, [], 1), 2));
% [cmax, imax] = max(csum);
% imin = cat(2, iminy{1}, iminy{2} + (imax-3));
% imin = min(5, max(1, imin));

if isfield(ops, 'midpoint')
splits = [0, ceil(ops.midpoint/ops.NT), nBatches];
else
splits = [0, nBatches];
end
for k = 1:length(splits)-1
ib = splits(k)+1:splits(k+1);
Params(1) = size(Ws,3) * length(ib); % the total number of templates is the number of templates per batch times the number of batches
[iminy{k}, ww{k}, Ns{k}] = find_integer_shifts(Params, Whs(:, ib),Ws(:,:,:,ib),...
mus(:, ib), ns(:,ib), iC, Nchan, Nfilt);
end

if isfield(ops, 'midpoint')
iChan = 1:Nchan;
iUp = mod(iChan + 2-1, Nchan)+1;
iDown = mod(iChan - 2-1, Nchan)+1;
iMap = [iUp(iUp); iUp; iChan; iDown; iDown(iDown)];


mu1 = 1e-5 + sq(sum(sum(ww{1}.^2, 1),2)).^.5;
mu2 = 1e-5 + sq(sum(sum(ww{2}.^2, 1),2)).^.5;

CC = gpuArray.zeros(Nfilt, Nfilt, 5);
for k = 1:5
W0 = ww{1};
for t = 1:Nfilt
W0(:,iMap(k,:),t) = W0(:,:,t);
end
X1 = reshape(W0, [nPCs * Nchan, Nfilt]);
X2 = reshape(ww{2}, [nPCs * Nchan, Nfilt]);
CC(:,:, k) = X1' * X2 ./ (mu1 * mu2');
% CC(:,:, k) = 2 * X1' * X2 - mu1.^2 - mu2'.^2;
end

csum = sq(mean(max(CC.* Ns{2}' , [], 1), 2));
% csum = sq(mean(max(CC , [], 1), 2));
[cmax, imax] = max(csum);
imin = cat(2, iminy{1}, iminy{2} + (imax-3));
imin = min(5, max(1, imin));

disp(imax)
else
imin = iminy{1};
end
%%
Params(1) = size(Ws,3) * size(Ws,4); % the total number of templates is the number of templates per batch times the number of batches

Expand All @@ -161,7 +174,6 @@
% pairs of templates that live on the same channels are potential "matches"
iMatch = sq(min(abs(single(iC) - reshape(Wh0, 1, 1, [])), [], 1))<.1;


% compute dissimilarities for iMatch = 1
[iclust, ds] = mexDistances2(Params, Ws, W, iMatch, iC-1, Whs2-1, mus, mu);

Expand Down Expand Up @@ -205,9 +217,11 @@

fprintf('time %2.2f, Re-ordered %d batches. \n', toc, nBatches)
%%
nup = 0;
for ibatch = 1:nBatches
if abs(imin(ibatch) - 3) > 0
shift_batch_on_disk(rez, ibatch, imin(ibatch) - 3);
nup = nup + 1;
end
end
fprintf('time %2.2f, Shifted up/down %d batches. \n', toc, nBatches)
fprintf('time %2.2f, Shifted up/down %d batches. \n', toc, nup)
10 changes: 7 additions & 3 deletions preProcess/find_integer_shifts.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [imin, W0] = find_integer_shifts(Params, Whs,Ws,mus, ns, iC, Nchan, Nfilt)
function [imin, W0, Ns] = find_integer_shifts(Params, Whs,Ws,mus, ns, iC, Nchan, Nfilt)

nBatches = size(Whs, 2);
nPCs = size(Ws, 1);
Expand Down Expand Up @@ -35,7 +35,7 @@
iMatch = sq(min(abs(single(iC) - reshape(iMap(k, Wh0), 1, 1, [])), [], 1))<.1;

% compute dissimilarities for iMatch = 1
[iclust, ds] = mexDistances2(Params, Ws, W, iMatch, iC-1, Whs-1, mus, mu0);
[iclust, ds] = mexDistances2(Params, Ws, W, iMatch, iC-1, Whs-1, mus, sq(mu0));

% ds are squared Euclidian distances
iclustall(:,:,k) = reshape(iclust, Nfilt, []) + 1;
Expand All @@ -53,6 +53,7 @@

W0 = gpuArray.zeros(nPCs , Nchan, Nfilt, 'single');
nn = 1e-4 * ones(Nfilt,1);
Ns = zeros(Nfilt,1);
for j = 1:length(irange)
ibatch = irange(j);
icl = iclustall(:, ibatch, medshift);
Expand All @@ -61,6 +62,7 @@
W0(:,iC(:, Whs(t,ibatch)), icl(t)) = W0(:,iC(:, Whs(t,ibatch)), icl(t)) + ...
mus(t, ibatch) * Ws(:,:,t,ibatch) ;
nn(icl(t)) = nn(icl(t)) + 1;
Ns(icl(t)) = Ns(icl(t)) + gather(ns(t, ibatch));
end
end
end
Expand All @@ -69,7 +71,9 @@
end
mu0 = 1e-3 + sum(sum(W0.^2,1),2).^.5;
W0 = W0 ./ mu0;
mu0 = sq(mu0);
Ns = Ns./nn;
% mu0 = sq(mu0);
end

W0 = W0 .* mu0;

0 comments on commit 742ea27

Please sign in to comment.