-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
69 changed files
with
3,550 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
function logbeta1 = HMMbackwardSFC(v,phghm,SFC) | ||
%HMMBACKWARDSFC Backward Pass (beta method) for the Switching Autoregressive HMM | ||
% logbeta=HMMbackwardSFC(v,phghm,a,sigma2,Tskip) | ||
% | ||
% Inputs: | ||
% v : observations | ||
% phghm : state (switch) transition matrix p(h(t)|h(t-1)) | ||
% Tskip : the number of timesteps to skip before a switch update is allowed | ||
% | ||
% Outputs: | ||
% logbeta: log backward messages log p(v(t+1:T)|h(t),v(t-L+1:t)) | ||
% See also HMMforwardSFC.m and demoSFClearn.m | ||
T = size(v,2); %length of time series | ||
H = length(SFC.a); % # of states | ||
M = size(v,1); %# of regions | ||
% logbeta recursion | ||
logbeta1 = zeros(H,T); | ||
logbeta1(:,T) = zeros(H,1); | ||
for t=T:-1:2 | ||
phatvgh1 = zeros(H,1); | ||
for h = 1:H | ||
%for Wishart distribution | ||
b = SFC.b(:,:,h); | ||
a = SFC.a(h); | ||
lambdap = SFC.lambdap(h); | ||
mp = SFC.mp(:,h); | ||
term1 = -M/2 * log(2*pi); | ||
term2 = -0.5*log(det(0.5*b)); | ||
digamma_args = repmat(a + 1,1,M)-(1:M) ; | ||
term3 = 0.5*sum(digamma(0.5*digamma_args)); | ||
term4 = -0.5*a*(v(:,t)-mp)'*pinv(b)*(v(:,t)-mp); | ||
term5 = -0.5*M/lambdap; | ||
phatvgh1(h) = exp(term1 + term2 + term3 + term4 + term5)+eps; % Eq 44 | ||
end | ||
logbeta1(:,t-1)=logsumexp(repmat(logbeta1(:,t),1,H),repmat(phatvgh1,1,H).*phghm); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
function [logalpha1,loglik1] = HMMforwardSFC(v,phghm,ph1,SFC) | ||
%HMMFORWARDSFC Switching Autoregressive HMM with switches updated only every Tskip timesteps | ||
% [logalpha,loglik]=HMMforwardSFC(v,phghm,ph1,a,sigma2,Tskip) | ||
% | ||
% Inputs: | ||
% v : observations | ||
% phghm : state (switch) transition matrix p(h(t)|h(t-1)) | ||
% ph1 : prior state distribution | ||
% % Tskip : the number of timesteps to skip before a switch update is allowed | ||
% | ||
% Outputs: | ||
% logalpha : log forward messages | ||
% loglik : sequence log likelihood log p(v(1:T)) | ||
% See also HMMbackwardSFC.m and demoSFClearn.m | ||
T = size(v,2); %length of time series | ||
H = length(ph1); % # of states | ||
M = size(v,1); %# of regions | ||
logalpha1 = zeros(H,T); | ||
% logalpha recursion: | ||
for h = 1:H | ||
%for Normal - Wishart distribution | ||
b = SFC.b(:,:,h); | ||
a = SFC.a(h); | ||
lambdap = SFC.lambdap(h); | ||
mp = SFC.mp(:,h); | ||
term1 = -M/2 * log(2*pi); | ||
term2 = -0.5*log(det(0.5*b)); | ||
digamma_args = repmat(a + 1,1,M)-(1:M) ; | ||
term3 = 0.5*sum(digamma(0.5*digamma_args)); | ||
term4 = -0.5 *a* (v(:,1)-mp)'*pinv(b)*(v(:,1)-mp); | ||
term5 = -0.5*M/lambdap; | ||
logalpha1(h,1) = term1 + term2 + term3 + term4 + term5 + log(ph1(h)); % Eq 44 | ||
end | ||
for t = 2:T | ||
phatvgh1 = zeros(H,1); | ||
for h = 1:H | ||
%for Normal- Wishart distribution | ||
b = SFC.b(:,:,h); | ||
a = SFC.a(h); | ||
lambdap = SFC.lambdap(h); | ||
mp = SFC.mp(:,h); | ||
term1 = -M/2 * log(2*pi); | ||
term2 = -0.5*log(det(0.5*b)); | ||
digamma_args = repmat(a + 1,1,M)-(1:M) ; | ||
term3 = 0.5*sum(digamma(0.5*digamma_args)); | ||
term4 = -0.5*a*(v(:,t)-mp)'*pinv(b)*(v(:,t)-mp); | ||
term5 = -0.5*M/lambdap; | ||
phatvgh1(h) = exp(term1 + term2 + term3 + term4 + term5) + eps; % Eq 44 | ||
end | ||
logalpha1(:,t)=logsumexp(repmat(logalpha1(:,t-1),1,H),repmat(phatvgh1',H,1).*phghm'); | ||
end | ||
loglik1 = logsumexp(logalpha1(:,T),ones(H,1)); % log likelihood |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
function [phtgV1T,phthtpgV1T1]=HMMsmoothSFC(logalpha,logbeta,SFC,phghm,v) | ||
%HMMSMOOTHSFC Switching Autoregressive HMM smoothing | ||
% [phtgV1T,phthtpgV1T]=HMMsmoothSFC(logalpha,logbeta,a,sigma2,phghm,v,Tskip) | ||
% return the smoothed pointwise posterior p(h(t)|v(1:T)) and pairwise smoothed posterior p(h(t),h(t+1)|v(1:T)). | ||
% | ||
% Inputs: | ||
% logalpha : log alpha messages (see HMMforwardSFC.m) | ||
% logbeta : log beta messages (see HMMbackwardSFC.m) | ||
% % phghm : state (switch) transition matrix p(h(t)|h(t-1)) | ||
% v : observations | ||
% Tskip : the number of timesteps to skip before a switch update is allowed | ||
% | ||
% Outputs: | ||
% phtgV1T : smoothed posterior p(h(t)|v(1:T)) | ||
% phthtpgV1T : smoothed posterior p(h(t),h(t+1)|v(1:T)) | ||
% See also HMMforwardSFC.m, HMMbackwardSFC.m, demoSFClearn.m | ||
T = size(v,2); %length of time series | ||
H = length(SFC.a); % # of states | ||
M = size(v,1); %# of regions | ||
% smoothed posteriors: pointwise marginals: | ||
for t=1:T | ||
logphtgV1T(:,t)=logalpha(:,t)+logbeta(:,t); % alpha-beta approach | ||
phtgV1T(:,t)=condexp(logphtgV1T(:,t)); | ||
end | ||
% smoothed posteriors: pairwise marginals p(h(t),h(t+1)|v(1:T)): | ||
for t=2:T | ||
atmp=condexp(logalpha(:,t-1)); | ||
btmp=condexp(logbeta(:,t)); | ||
phatvgh1 = zeros(H,1); | ||
for h = 1:H | ||
%for Wishart distribution | ||
b = SFC.b(:,:,h); | ||
a = SFC.a(h); | ||
lambdap = SFC.lambdap(h); | ||
mp = SFC.mp(:,h); | ||
term1 = -M/2 * log(2*pi); | ||
term2 = -0.5*log(det(0.5*b)); | ||
digamma_args = repmat(a + 1,1,M)-(1:M) ; | ||
term3 = 0.5*sum(digamma(0.5*digamma_args)); | ||
term4 = -0.5*a*(v(:,t)-mp)'*pinv(b)*(v(:,t)-mp); | ||
term5 = -0.5*M/lambdap; | ||
phatvgh1(h) = exp(term1 + term2 + term3 + term4 + term5) + eps; | ||
end | ||
phatvgh1=condexp(phatvgh1); | ||
phghmt=phghm; | ||
ctmp1 = repmat(atmp,1,H).*phghmt'.*repmat(phatvgh1'.*btmp',H,1)+eps; % two timestep potential | ||
phthtpgV1T1(:,:,t-1)=ctmp1./sum(sum(ctmp1)); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
function [maxstate logprob]=HMMviterbi(v,phghm,ph1,pvgh) | ||
%HMMVITERBI Viterbi most likely joint hidden state of a HMM | ||
% [maxstate logprob]=HMMviterbi(v,phghm,ph1,pvgh) | ||
% | ||
% Inputs: | ||
% v : visible (obervation) sequence being a vector v=[2 1 3 3 1 ...] | ||
% phghm : homogeneous transition distribution phghm(i,j)=p(h(t)=i|h(t-1)=j) | ||
% ph1 : initial distribution | ||
% pvgh : homogeneous emission disrtribution pvgh(i,j)=p(v(t)=i|h(t)=j) | ||
% | ||
% Outputs: | ||
% maxstate : most likely joint hidden (latent) state sequence | ||
% logprob : associated log probability of the most likely hidden sequence | ||
% See also demoHMMinference.m | ||
%import brml.* | ||
T=size(v,2); H=size(phghm,1); | ||
mu(:,T)=ones(H,1); | ||
for t=T:-1:2 | ||
tmp = repmat(pvgh(v(t),:)'.*mu(:,t),1,H).*phghm; | ||
mu(:,t-1)= condp(max(tmp)'); % normalise to avoid underflow | ||
end | ||
% backtrack | ||
[val hs(1)]=max(ph1.*pvgh(v(1),:)'.*mu(:,1)); | ||
for t=2:T | ||
tmp = pvgh(v(t),:)'.*phghm(:,hs(t-1)); | ||
[val hs(t)]=max(tmp.*mu(:,t)); | ||
end | ||
maxstate=hs; | ||
logprob=log(ph1(hs(1)))+log(pvgh(v(1),hs(1))); | ||
for t=2:T | ||
logprob=logprob+log(phghm(hs(t),hs(t-1)))+log(pvgh(v(t),hs(t))); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
|
||
function [outlist] = ReadList(inlist) | ||
|
||
inlist = strtrim(inlist); | ||
if iscell(inlist) | ||
[pathstr, name, ext] = fileparts(inlist{1}); | ||
if strcmp(ext,'.txt') | ||
inlist_file = inlist{1}; | ||
if ~exist(inlist_file,'file') | ||
error('Cannot find file: %s \n', inlist_file); | ||
else | ||
outlist = GetList(inlist_file); | ||
end | ||
else | ||
outlist = strtrim(inlist); | ||
end | ||
else | ||
[pathstr, name, ext] = fileparts(inlist); | ||
if strcmp(ext,'.txt') | ||
inlist_file = inlist; | ||
if ~exist(inlist_file,'file') | ||
error('Cannot find file: %s \n', inlist_file); | ||
else | ||
outlist = GetList(inlist_file); | ||
end | ||
else | ||
outlist = strtrim({inlist}); | ||
end | ||
end | ||
|
||
end | ||
|
||
function slist = GetList(filename) | ||
|
||
fid = fopen(filename); | ||
cnt = 1; | ||
|
||
while ~feof(fid) | ||
fstr = fgetl(fid); | ||
str = strtrim(fstr); | ||
if ~isempty(str) | ||
slist{cnt} = str; | ||
cnt = cnt + 1; | ||
end | ||
end | ||
|
||
fclose(fid); | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
resultfmt = '/home/sryali/VB-HMM/Results/stanford_dataset_grp2-%d.mat' | ||
Model = {} | ||
|
||
for ii=1:100 | ||
foo = load(sprintf(resultfmt,ii)); | ||
Model{ii} = foo.Model(:); | ||
end | ||
|
||
save('/home/sryali/VB-HMM/Results/stanford_dataset_grp2-All.mat','Model') | ||
%save('/home/ksupekar/stanford_dataset_grp2-All.mat','Model') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
resultfmt = '/home/sryali/VB-HMM/Results/WM_HCP_1_20-%d.mat' | ||
Model = {} | ||
|
||
for ii=1:100 | ||
foo = load(sprintf(resultfmt,ii)); | ||
Model{ii} = foo.Model(:); | ||
end | ||
|
||
save('/home/sryali/VB-HMM/Results/WM_HCP_1_20-All.mat','Model') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
clear all | ||
close all | ||
clc | ||
|
||
addpath(genpath('/mnt/musk2/home/fmri/fmrihome/scripts/GCCA_toolbox_sep21')) | ||
addpath(genpath('/home/fmri/fmrihome/SPM/spm8_scripts')); | ||
addpath(genpath('/mnt/mapricot/musk2/home/sryali/Work/toolboxes/VB_HMM/vbhmm')) | ||
addpath('/mnt/mapricot/musk2/home/sryali/Work/switchingFC/VB-HMM/Scripts/common_scripts') | ||
addpath(genpath('/home/tianwenc/Toolbox/BCT/BCT_04_05_2014')); | ||
data_dir = '/mnt/mandarin2/Public_Data/HCP/Stats/TS_for_TRSBN/Smith/'; | ||
prefix = 'NonInterestRegIC_eigen1_ts_rfMRI_REST1_LR_'; | ||
saveDir = '/mnt/mapricot/musk2/home/sryali/Work/switchingFC/VB-HMM/Results/HCP/'; | ||
%result_fname = 'Adults_9ROIs_25Clusters_lambda_10p3.mat'; | ||
result_fname = 'Adults_6ROIs_25Clusters_lambda_10p3_hbm.mat'; | ||
|
||
load([saveDir result_fname]) | ||
subjects = ReadList('/mnt/mandarin1/Public_Data/HCP/Data/subjectslist.txt'); | ||
%subjects = ReadList('/mnt/mandarin1/Public_Data/HCP/Data/subjectslist_common.txt'); | ||
roi_names = {'ACC','lAI','rAI','lDLPFC','lPPC','rDLPFC','rPPC','preCue','VMPFC'}; | ||
rois = [1,3,6,7,8,9]; | ||
%rois = 1:9; | ||
roi_names = roi_names(rois); | ||
data = []; | ||
%%%%%%%%%%% Get the Data %%%%%%%%%%%%%%% | ||
for subj = 1:length(subjects) | ||
load([data_dir prefix subjects{subj} '.mat']); | ||
X = roi_data.timeseries; | ||
X = X(rois,:); | ||
for k = 1:size(X,1) | ||
x = X(k,:); | ||
X(k,:) = (x-mean(x))/std(x); | ||
end | ||
data_subj(:,:,subj) = X; %for Vitterbi states | ||
data_states{subj} = X; | ||
end | ||
figure(1) | ||
for k = 1:length(Model) | ||
%subplot(length(Model),1,k) | ||
plot(Model{k}.F + 5*randn,'o-') | ||
hold on | ||
title('Log-Lower Bound') | ||
end | ||
for repetition = 1:length(Model) | ||
F = Model{repetition}.F; | ||
if repetition == 1 | ||
model = Model{repetition}; | ||
elseif (max(model.F) < max(F)) | ||
model = Model{repetition}; | ||
end | ||
end | ||
state_prob_smooth = []; | ||
Task = []; | ||
for n = 1:length(data_states) | ||
state_prob_smooth = [state_prob_smooth model.state_prob_smooth{n}]; | ||
% % Task = [Task task{n}]; | ||
end | ||
[est_states] = est_states_vitterbi(data_states,model); | ||
post_states = est_states; | ||
|
||
%[max_probs,post_states] = max(state_prob_smooth); | ||
|
||
figure(2) | ||
subplot(311) | ||
imagesc(state_prob_smooth) | ||
subplot(312) | ||
plot(post_states,'linewidth',2) | ||
subplot(313) | ||
plot(model.F,'o-','linewidth',2) | ||
|
||
K = size(model.Wa,1); | ||
counts_post = zeros(1,K); | ||
for k = 1:K | ||
counts_post(k) = length(find(post_states == k)); | ||
end | ||
counts_post = counts_post/sum(counts_post); | ||
[percent_dominant dominant_states] = sort(counts_post,'descend'); | ||
figure(3) | ||
subplot(211) | ||
bar(counts_post*100) | ||
[fractional_occupancy, mean_life,Counters] = summary_stats(post_states,K); | ||
subplot(212) | ||
bar(mean_life) | ||
%Estimates of Covariance | ||
for k = 1:K | ||
ap = model.ap(k); bp = model.bp(:,:,k); | ||
est_cov(:,:,k) = bp/ap; | ||
invD = inv(diag(sqrt(diag(est_cov(:,:,k))))); | ||
pearson_corr(:,:,k) = invD*est_cov(:,:,k)*invD; | ||
%Partial Correlation | ||
inv_est_cov(:,:,k) = inv(est_cov(:,:,k)); | ||
invD = inv(diag(sqrt(diag(inv_est_cov(:,:,k))))); | ||
partial_corr(:,:,k) = -invD*inv_est_cov(:,:,k)*invD; | ||
end | ||
%estimeted state transition matrix A | ||
Wa = model.Wa; | ||
Wa = Wa'; | ||
H = size(Wa,1); | ||
Aest = Wa./repmat(sum(Wa,1),H,1);% transition distribution p(h(t)|h(t-1)) | ||
%estimeted pi | ||
Wpi = model.Wpi; | ||
piest = Wpi./sum(Wpi); | ||
figure(4) | ||
for k = 1:15 | ||
subplot(3,5,k) | ||
cca_plotcausality(abs(pearson_corr(:,:,dominant_states(k))) > 0.2,roi_names,5); | ||
end | ||
for k = 1:6 | ||
[est_network1,clust_mtx1] = clusters_community_detection((pearson_corr(:,:,dominant_states(k)))); | ||
[est_network2,clust_mtx2] = clusters_community_detection((partial_corr(:,:,dominant_states(k)))); | ||
figure(5) | ||
subplot(2,3,k) | ||
cca_plotcausality(est_network1,roi_names,5); | ||
figure(6) | ||
subplot(2,3,k) | ||
cca_plotcausality(est_network2,roi_names,5); | ||
end | ||
static_corr = corr(cell2mat(data_states)'); | ||
invD = inv(diag(sqrt(diag(static_corr)))); | ||
inv_cov = inv(cov(cell2mat(data_states)')); | ||
invD = inv(diag(sqrt(diag(inv_cov)))); | ||
partial_corr = invD*inv_cov*invD; | ||
[est_network1,clust_mtx1] = clusters_community_detection((static_corr)); | ||
[est_network2,clust_mtx2] = clusters_community_detection((partial_corr)); | ||
figure(7) | ||
subplot(211) | ||
cca_plotcausality(est_network1,roi_names,5); | ||
subplot(212) | ||
cca_plotcausality(est_network2,roi_names,5); | ||
|
Oops, something went wrong.