From 7d33bd0d42c4e7c862d34abb98a1a1fc9a21cbbd Mon Sep 17 00:00:00 2001
From: macwiatrak
Date: Sun, 2 Feb 2025 12:21:22 +0000
Subject: [PATCH] feat: add gene_layers instead of gene_matrix

---
 bactgraph/modeling/model.py | 18 ++++++++++++------
 bactgraph/modeling/train.py |  2 +-
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/bactgraph/modeling/model.py b/bactgraph/modeling/model.py
index f63d160..57ee47e 100644
--- a/bactgraph/modeling/model.py
+++ b/bactgraph/modeling/model.py
@@ -118,9 +118,11 @@ def __init__(self, config: dict):
         )

         # self.linear = nn.Linear(config["output_dim"], 1)
-        self.bias = torch.nn.Parameter(torch.zeros(config["n_genes"])).unsqueeze(1)
+        # self.bias = torch.nn.Parameter(torch.zeros(config["n_genes"])).unsqueeze(1)
         self.dropout = nn.Dropout(config["dropout"])
-        self.gene_matrix = nn.Parameter(nn.init.xavier_normal_(torch.empty(config["n_genes"], config["output_dim"])))
+        # self.gene_matrix = nn.Parameter(nn.init.xavier_normal_(torch.empty(config["n_genes"], config["output_dim"])))
+
+        self.gene_layers = nn.ModuleList([nn.Linear(config["output_dim"], 1) for _ in range(config["n_genes"])])

         # Learning rate (default to 1e-3 if not specified)
         self.lr = config.get("lr", 1e-3)
@@ -133,10 +135,14 @@ def forward(self, x_batch: torch.Tensor, edge_index_batch: torch.Tensor, gene_in
         # logits = self.gat_module(x, edge_index).squeeze()  # + self.bias.repeat(batch_size)
         last_hidden_state = self.gat_module(x, edge_index)
         last_hidden_state = group_by_label(self.dropout(last_hidden_state), gene_indices.view(-1))
-        logits = torch.einsum(
-            "bnm,bm->bn", last_hidden_state, self.gene_matrix.to(last_hidden_state.device)
-        ) + self.bias.to(last_hidden_state.device)
-        return logits  # F.softplus(logits)
+        # logits = torch.einsum(
+        #     "bnm,bm->bn", last_hidden_state, self.gene_matrix.to(last_hidden_state.device)
+        # ) + self.bias.to(last_hidden_state.device)
+        logits = []
+        for idx, gene_lhs in enumerate(last_hidden_state):
+            logits.append(self.gene_layers[idx](gene_lhs))
+        logits = torch.stack(logits, dim=1).squeeze()
+        return F.softplus(logits)

     def training_step(self, batch, batch_idx):
         """Training step."""
diff --git a/bactgraph/modeling/train.py b/bactgraph/modeling/train.py
index dd4cf1a..7c07a32 100644
--- a/bactgraph/modeling/train.py
+++ b/bactgraph/modeling/train.py
@@ -73,7 +73,7 @@ def __init__(self):
     monitor_metric: str = "val_r2"
     early_stop_patience: int = 15
     gradient_clip_val: float = 0.0
-    t_max: int = 10
+    # t_max: int = 10


 def main(args):
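
For readers following the change: the commit swaps the single shared gene_matrix (plus a per-gene bias) for one nn.Linear(config["output_dim"], 1) head per gene, so each gene learns its own weight vector and scalar bias, and the output now passes through F.softplus to keep predictions non-negative. The sketch below is illustrative only, not code from the repository; it assumes the grouped hidden state has shape (n_genes, batch, output_dim), which is what the `for idx, gene_lhs in enumerate(last_hidden_state)` loop implies, and shows that the per-gene heads are equivalent to a single einsum over stacked weights.

    # Illustrative sketch (not repository code): per-gene Linear heads
    # vs. a vectorized equivalent. Shapes are assumptions inferred from
    # the diff; `hidden` stands in for the output of group_by_label.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    n_genes, batch, output_dim = 4, 8, 16
    hidden = torch.randn(n_genes, batch, output_dim)

    # One Linear(output_dim, 1) per gene, as in the commit: each gene
    # gets its own weight vector and scalar bias instead of one shared
    # gene_matrix parameter.
    gene_layers = nn.ModuleList([nn.Linear(output_dim, 1) for _ in range(n_genes)])

    # Loop form used in the patch: stack the per-gene (batch, 1) outputs.
    logits = torch.stack(
        [layer(h) for layer, h in zip(gene_layers, hidden)], dim=1
    ).squeeze(-1)  # -> (batch, n_genes)

    # Equivalent vectorized form: stack the per-gene parameters once and
    # contract in a single einsum, avoiding the Python loop.
    w = torch.stack([layer.weight.squeeze(0) for layer in gene_layers])  # (n_genes, output_dim)
    b = torch.stack([layer.bias.squeeze(0) for layer in gene_layers])    # (n_genes,)
    logits_vec = torch.einsum("nbm,nm->bn", hidden, w) + b

    assert torch.allclose(logits, logits_vec, atol=1e-6)
    print(F.softplus(logits).shape)  # softplus keeps predictions non-negative

If the Python loop ever shows up in profiles, the stacked-weight einsum form above produces the same logits in one batched contraction while keeping the per-gene parameterization.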