Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cost function documentation #95

Merged
merged 3 commits into from
Jun 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions whobpyt/optimization/cost_FC.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,53 @@


class CostsFC(AbstractLoss):
"""
Cost function for Fitting the Functional Connectivity (FC) matrix.
The cost function is the negative log-likelihood of the Pearson correlation between the simulated FC and empirical FC.

Attributes:
-----------
simKey: str
string key to reference this cost function, e.g. "CostsFC".

Methods:
--------
loss: function
calculates functional connectivity and uses it to calculate the loss
"""
def __init__(self, simKey):
    """Create the FC cost function.

    Parameters:
    -----------
    simKey: str
        type of cost function to be used; presumably the string key naming
        this cost / the simulated quantity the loss reads -- confirm
        against callers
    """
    # Zero-argument super() is equivalent to super(CostsFC, self) inside
    # the class body; initialize the AbstractLoss base first.
    super().__init__()
    self.simKey = simKey

def loss(self, sim, emp, model: torch.nn.Module = None, state_vals = None):
logits_series_tf = sim
labels_series_tf = emp

"""
Calculate the Pearson Correlation between the simFC and empFC.
From there, the probability and negative log-likelihood.
"""Function to calculate the cost function for Functional Connectivity (FC) fitting. It initially calculates the FC matrix using the data from the BOLD time series, makes that mean-zero, and then calculates the Pearson Correlation between the simulated FC and empirical FC. The correlation value is then mapped into the 0-1 range. We then use this value as a probability and take its negative log likelihood as a cross-entropy-like loss.

Parameters
----------
logits_series_tf: tensor with node_size X datapoint
sim: torch.tensor with node_size X datapoint
simulated BOLD
labels_series_tf: tensor with node_size X datapoint
emp: torch.tensor with node_size X datapoint
empirical BOLD
model: torch.nn.Module
model to be used for the loss calculation
default: None
state_vals: list
list of state values
default: None

Returns
-------
losses_corr: torch.tensor
cost function value
"""

logits_series_tf = sim
labels_series_tf = emp
# get node_size() and TRs_per_window()
node_size = logits_series_tf.shape[0]
truncated_backprop_length = logits_series_tf.shape[1]
Expand All @@ -44,32 +72,30 @@ def loss(self, sim, emp, model: torch.nn.Module = None, state_vals = None):
cov_sim = torch.matmul(logits_series_tf_n, torch.transpose(logits_series_tf_n, 0, 1))
cov_def = torch.matmul(labels_series_tf_n, torch.transpose(labels_series_tf_n, 0, 1))

# fc for sim and empirical BOLDs
# Getting the FC matrix for the simulated and empirical BOLD signals
FC_sim_T = torch.matmul(torch.matmul(torch.diag(torch.reciprocal(torch.sqrt(
torch.diag(cov_sim)))), cov_sim),
torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_sim)))))
FC_T = torch.matmul(torch.matmul(torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_def)))), cov_def),
torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_def)))))
torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_sim))))) # SIMULATED FC
FC_T = torch.matmul(torch.matmul(torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_def)))), cov_def), torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_def))))) # EMPIRICAL FC

# mask for lower triangle without diagonal
# Masking out the upper triangle of the FC matrix and keeping the lower triangle
ones_tri = torch.tril(torch.ones_like(FC_T), -1)
zeros = torch.zeros_like(FC_T)  # create a tensor of all zeros
mask = torch.greater(ones_tri, zeros)  # boolean tensor, mask[i] = True iff ones_tri[i] > 0 (strict lower triangle)

# mask out fc to vector with elements of the lower triangle
FC_tri_v = torch.masked_select(FC_T, mask)
FC_sim_tri_v = torch.masked_select(FC_sim_T, mask)

# remove the mean across the elements
# Bring the FC mean to zero
FC_v = FC_tri_v - torch.mean(FC_tri_v)
FC_sim_v = FC_sim_tri_v - torch.mean(FC_sim_tri_v)

# corr_coef
# Calculate the correlation coefficient between the simulated FC and empirical FC
corr_FC = torch.sum(torch.multiply(FC_v, FC_sim_v)) \
* torch.reciprocal(torch.sqrt(torch.sum(torch.multiply(FC_v, FC_v)))) \
* torch.reciprocal(torch.sqrt(torch.sum(torch.multiply(FC_sim_v, FC_sim_v))))

# use surprise: corr to calculate probability and -log
# Bringing the corr-FC to the 0-1 range, and calculating the negative log-likelihood
losses_corr = -torch.log(0.5000 + 0.5 * corr_FC) # torch.mean((FC_v -FC_sim_v)**2)#
return losses_corr