Skip to content

Commit

Permalink
feat: add gene_layers instead of gene_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
macwiatrak committed Feb 2, 2025
1 parent 5fc2447 commit 7d33bd0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
18 changes: 12 additions & 6 deletions bactgraph/modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ def __init__(self, config: dict):
)

# self.linear = nn.Linear(config["output_dim"], 1)
self.bias = torch.nn.Parameter(torch.zeros(config["n_genes"])).unsqueeze(1)
# self.bias = torch.nn.Parameter(torch.zeros(config["n_genes"])).unsqueeze(1)
self.dropout = nn.Dropout(config["dropout"])
self.gene_matrix = nn.Parameter(nn.init.xavier_normal_(torch.empty(config["n_genes"], config["output_dim"])))
# self.gene_matrix = nn.Parameter(nn.init.xavier_normal_(torch.empty(config["n_genes"], config["output_dim"])))

self.gene_layers = nn.ModuleList([nn.Linear(config["output_dim"], 1) for _ in range(config["n_genes"])])

# Learning rate (default to 1e-3 if not specified)
self.lr = config.get("lr", 1e-3)
Expand All @@ -133,10 +135,14 @@ def forward(self, x_batch: torch.Tensor, edge_index_batch: torch.Tensor, gene_in
# logits = self.gat_module(x, edge_index).squeeze() # + self.bias.repeat(batch_size)
last_hidden_state = self.gat_module(x, edge_index)
last_hidden_state = group_by_label(self.dropout(last_hidden_state), gene_indices.view(-1))
logits = torch.einsum(
"bnm,bm->bn", last_hidden_state, self.gene_matrix.to(last_hidden_state.device)
) + self.bias.to(last_hidden_state.device)
return logits # F.softplus(logits)
# logits = torch.einsum(
# "bnm,bm->bn", last_hidden_state, self.gene_matrix.to(last_hidden_state.device)
# ) + self.bias.to(last_hidden_state.device)
logits = []
for idx, gene_lhs in enumerate(last_hidden_state):
logits.append(self.gene_layers[idx](gene_lhs))
logits = torch.stack(logits, dim=1).squeeze()
return F.softplus(logits)

def training_step(self, batch, batch_idx):
"""Training step."""
Expand Down
2 changes: 1 addition & 1 deletion bactgraph/modeling/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self):
monitor_metric: str = "val_r2"
early_stop_patience: int = 15
gradient_clip_val: float = 0.0
t_max: int = 10
# t_max: int = 10


def main(args):
Expand Down

0 comments on commit 7d33bd0

Please sign in to comment.