Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: create and train GNN for gene expression prediction using gene regulatory network #1

Merged
merged 49 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
7c5e020
feat: add dependencies required for GNN
macwiatrak Jan 31, 2025
ef82063
feat: add the dataset.py for preprocessing the BactGraph dataset.py
macwiatrak Jan 31, 2025
f91ba68
feat: add placeholder train.py as well as init for model and trainer,…
macwiatrak Jan 31, 2025
fc74910
feat: add support for training the model including the data reader an…
macwiatrak Feb 1, 2025
ecd388f
feat: add support for training the model including the data reader an…
macwiatrak Feb 1, 2025
1348c13
debug: debug data processing
macwiatrak Feb 1, 2025
c3d3b2e
fix: add output_dim as an argument
macwiatrak Feb 1, 2025
b75af59
fix: n_unique genes
macwiatrak Feb 1, 2025
55a5903
fix: fix passing args to trainer
macwiatrak Feb 1, 2025
306e38e
fix: fix fetching protein embeddings in the dataset.py
macwiatrak Feb 1, 2025
e20ea14
fix: fix fetching protein embeddings in the dataset.py
macwiatrak Feb 1, 2025
1af18b6
fix: convert edge indices to long tensors
macwiatrak Feb 1, 2025
0e5357e
fix: fix adding bias to the output
macwiatrak Feb 1, 2025
89982db
fix: fix adding bias to the output
macwiatrak Feb 1, 2025
53695c1
fix: fix adding bias to the output
macwiatrak Feb 1, 2025
e40d0a9
debug: debug nan in loss
macwiatrak Feb 1, 2025
0e99651
debug: debug nan in loss
macwiatrak Feb 1, 2025
90ee5bc
debug: remove debugging print
macwiatrak Feb 1, 2025
f832fd4
feat: experience with removing bias
macwiatrak Feb 1, 2025
4eb351a
feat: group embeddings by label
macwiatrak Feb 1, 2025
f296ab8
feat: add gene matrix parameter
macwiatrak Feb 1, 2025
cf6a6ed
feat: add gene matrix parameter
macwiatrak Feb 1, 2025
520f1c7
feat: add gene matrix parameter
macwiatrak Feb 1, 2025
840824e
feat: add gene matrix parameter
macwiatrak Feb 1, 2025
75c699c
feat: add gene matrix parameter
macwiatrak Feb 1, 2025
d0961a8
feat: add gene matrix parameter
macwiatrak Feb 1, 2025
45aca53
debug: randomize the network
macwiatrak Feb 1, 2025
472cb4b
debug: revert randomize the network
macwiatrak Feb 2, 2025
ed06cc3
feat: add cosine annealing LR
macwiatrak Feb 2, 2025
3cbc1bf
feat: add t_max for cosine annealing LR
macwiatrak Feb 2, 2025
ece65d5
exp: remove cosine annealing LR
macwiatrak Feb 2, 2025
5fc2447
exp: don't use softplus
macwiatrak Feb 2, 2025
7d33bd0
feat: add gene_layers instead of gene_matrix
macwiatrak Feb 2, 2025
0183303
feat: only add biad, no gene layers
macwiatrak Feb 2, 2025
18c7300
feat: only add bias, no gene layers
macwiatrak Feb 2, 2025
f33bd93
feat: only add bias, no gene layers
macwiatrak Feb 2, 2025
f07330b
feat: try to revert back to previous version with gene matrix
macwiatrak Feb 2, 2025
fb07df1
feat: experiment with a single non-gene specific layer
macwiatrak Feb 2, 2025
bee8c49
feat: experiment with no gene specific layer
macwiatrak Feb 2, 2025
9a68d94
feat: do testing on val and test after training
macwiatrak Feb 2, 2025
fb87742
feat: change the train_size
macwiatrak Feb 2, 2025
1c224e2
feat: randomize the network experiment
macwiatrak Feb 2, 2025
3702143
feat: fully connected network experiment
macwiatrak Feb 2, 2025
954a359
exp: subset to half of genes
macwiatrak Feb 2, 2025
b12e6a6
exp: subset to quarter of genes
macwiatrak Feb 2, 2025
a5665d4
exp: run expn with gene matrix and randomizing the network
macwiatrak Feb 3, 2025
fc084bc
exp: revert randomizing the network
macwiatrak Feb 3, 2025
014a040
exp: flip network direction
macwiatrak Feb 3, 2025
3f6df87
refactor: revert to previous version
macwiatrak Feb 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added bactgraph/modeling/__init__.py
Empty file.
107 changes: 107 additions & 0 deletions bactgraph/modeling/data_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import os
import random
from collections.abc import Callable
from typing import Any

import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

from bactgraph.modeling.dataset import BactGraphDataset

BACTMAP_PROTEINS_FILE_NAME = "bactmap_proteins_prot_embeds.parquet"
NORMALISED_EXPRESSION_FILE_NAME = "norm_dat_pao1.tsv"
PERTURB_NETWORK_FILE_NAME = "llcb_perturb_hits_adj_matrix.tsv"


def preprocess_data_for_training(
input_dir: str,
transform_norm_expression_fn: Callable = np.log10,
train_size: float = 0.7,
test_size: float = 0.2,
batch_size: int = 32,
num_workers: int = 4,
random_seed: int = 42,
) -> dict[str, Any]:
"""Preprocess the data for training the BactGraph model."""
# read the data
protein_embeddings = pd.read_parquet(os.path.join(input_dir, BACTMAP_PROTEINS_FILE_NAME))
expression_df = pd.read_csv(os.path.join(input_dir, NORMALISED_EXPRESSION_FILE_NAME), sep="\t").set_index(
"feature_id"
)
perturb_network = pd.read_csv(os.path.join(input_dir, PERTURB_NETWORK_FILE_NAME), sep="\t").set_index("gene_id")

# keep only genes which are in all files
prot_emb_genes = set(protein_embeddings.columns.tolist())
expression_genes = set(expression_df.index.tolist())
perturb_network_genes = set(perturb_network.index.tolist() + perturb_network.columns.tolist())

genes_of_interest = list(prot_emb_genes.intersection(expression_genes).intersection(perturb_network_genes))
print(f"Total nr of genes available: {len(genes_of_interest)}")

# subset the genes of interest
protein_embeddings = protein_embeddings[genes_of_interest]
expression_df = expression_df[expression_df.index.isin(genes_of_interest)]
perturb_network = perturb_network[[g for g in genes_of_interest if g in perturb_network.columns]]
perturb_network = perturb_network[perturb_network.index.isin(genes_of_interest)]

# subset to the strains with expression data
strains_w_expression = expression_df.columns.tolist()
strains_w_prot_emb = protein_embeddings.index.tolist()
strains_of_interest = list(set(strains_w_expression).intersection(strains_w_prot_emb))
expression_df = expression_df[strains_of_interest]
protein_embeddings = protein_embeddings.loc[strains_of_interest]

# split the data
random.seed(random_seed)
random.shuffle(strains_of_interest)
train_size = int(len(strains_of_interest) * train_size)
test_size = int(len(strains_of_interest) * test_size)
train_strains = strains_of_interest[:train_size]
test_strains = strains_of_interest[train_size : train_size + test_size]
val_strains = strains_of_interest[train_size + test_size :]

gene2idx = {gene: idx for idx, gene in enumerate(protein_embeddings.columns)}

# create datasets
train_dataset = BactGraphDataset(
protein_embeddings=protein_embeddings.loc[train_strains],
expression_df=expression_df[train_strains],
gene2idx=gene2idx,
perturb_network=perturb_network,
transform_norm_expression_fn=transform_norm_expression_fn,
random_seed=random_seed,
)
val_dataset = BactGraphDataset(
protein_embeddings=protein_embeddings.loc[val_strains],
expression_df=expression_df[val_strains],
gene2idx=gene2idx,
perturb_network=perturb_network,
transform_norm_expression_fn=transform_norm_expression_fn,
random_seed=random_seed,
)
test_dataset = BactGraphDataset(
protein_embeddings=protein_embeddings.loc[test_strains],
expression_df=expression_df[test_strains],
gene2idx=gene2idx,
perturb_network=perturb_network,
transform_norm_expression_fn=transform_norm_expression_fn,
random_seed=random_seed,
)

# create dataloaders
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
)
val_dataloader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=num_workers)

return dict( # noqa
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
test_dataloader=test_dataloader,
n_train_size=len(train_strains),
gene2idx=gene2idx,
)
92 changes: 92 additions & 0 deletions bactgraph/modeling/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from collections.abc import Callable

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

BACTMAP_PROTEINS_FILE_NAME = "bactmap_proteins_prot_embeds.parquet"
NORMALISED_EXPRESSION_FILE_NAME = "norm_dat_pao1.tsv"
PERTURB_NETWORK_FILE_NAME = "bactmap_proteins_prot_embeds.parquet"


def perturb_mtx_to_triples(df: pd.DataFrame, gene2idx: dict[str, int]) -> torch.Tensor:
"""Conver perturbation dataframe to triples with non-zero values for training."""
# 1. "Stack" the DataFrame so that rows become part of a MultiIndex
nonzero_stacked = df.stack() # This will convert the DataFrame into a Series

# 2. Filter out zero values
nonzero_stacked = nonzero_stacked[nonzero_stacked != 0]

# 3. Convert to a list of (index_name, column_name, value) tuples
triples = list(
zip(
nonzero_stacked.index.get_level_values(0), # index name
nonzero_stacked.index.get_level_values(1), # column name
nonzero_stacked.values,
strict=False, # value
)
)

triples = torch.tensor(
[
[gene2idx[gene1] for gene1, _, _ in triples],
[gene2idx[gene2] for _, gene2, _ in triples],
[val for _, _, val in triples],
],
dtype=torch.float32,
)
return triples


class BactGraphDataset(Dataset):
"""Dataset of gene networks in bacteria for BactGraph project."""

def __init__(
self,
protein_embeddings: pd.DataFrame,
expression_df: pd.DataFrame,
gene2idx: dict[str, int],
perturb_network: pd.DataFrame,
transform_norm_expression_fn: Callable = np.log10,
random_seed: int = 42,
):
self.protein_embeddings = protein_embeddings
self.expression_df = expression_df
self.gene2idx = gene2idx

# get triples
self.triples = perturb_mtx_to_triples(perturb_network, self.gene2idx)[:2, :]
# reverse the direction
# self.triples = self.triples[:2, :].flip(0)
# randomize the network experiment
# print("Randomizing the network experiment by randomly sampling edges.")
# torch.manual_seed(random_seed)
# self.triples = torch.randint(0, len(self.gene2idx), self.triples.shape)
# fully connected network
# self.triples = torch.stack(
# [torch.arange(len(self.gene2idx)), torch.arange(len(self.gene2idx)), torch.ones(len(self.gene2idx))],
# dim=0,
# )

# normalise the expression data
# revert previous log2 transformation (the data was provided like this)
self.expression_df = self.expression_df.apply(np.exp2)
# transform the data with the provided function
self.expression_df = self.expression_df.apply(transform_norm_expression_fn).fillna(-100.0)

self.strains = self.expression_df.columns.tolist()

def __len__(self):
return len(self.expression_df.columns)

def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# get the expression data for the idx-th strain
strain = self.strains[idx]
# get protein embeddings
prot_emb = torch.tensor(np.stack(self.protein_embeddings.loc[strain].values), dtype=torch.float32)
expr_values = torch.tensor(
[self.expression_df.loc[gene, strain] for gene in self.protein_embeddings.columns], dtype=torch.float32
)
gene_idx = torch.arange(len(self.protein_embeddings.columns), dtype=torch.long)
return prot_emb, self.triples, expr_values, gene_idx
Loading
Loading