Development #105

Merged: 3 commits, Jun 11, 2024
rayTrain_Tune.py: 15 changes (12 additions, 3 deletions)
@@ -1,5 +1,14 @@
-# This is a simple tuning script for the raytune library.
-# no support for distributed training in this file.
+"""
+This is a tuning script for the raytune library.
+
+Support for DDP is handled by RayTune:
+- This is done by increasing the number of workers in the ScalingConfig.
+- For example, the following would distribute training across 2 GPUs
+  (num_workers * resources_per_worker['GPU']):
+      num_workers=2,
+      use_gpu=True,
+      resources_per_worker={"CPU": 2, "GPU": 1},
+"""
 
 import random
 import os
@@ -82,7 +91,7 @@ def train_func(config):
 search_space = {
     ## constants:
     "epochs": 20,
-    "model": cfg.MODEL_OPT.GVPL,
+    "model": cfg.MODEL_OPT.DG,
 
     "dataset": cfg.DATA_OPT.kiba,
     "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
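As an aside, the ScalingConfig described in the new docstring plugs into Ray Train roughly as follows. This is a minimal sketch, assuming a recent Ray 2.x with ray[train] installed; train_func stands in for the per-worker training loop defined in this script, and everything else is illustrative:

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    ...  # per-worker training loop; Ray wraps the model in DDP

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=2,                              # two DDP workers
        use_gpu=True,
        resources_per_worker={"CPU": 2, "GPU": 1},  # so 2 GPUs in total
    ),
)
result = trainer.fit()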
raytune.py: 77 changes (0 additions, 77 deletions)

This file was deleted.

raytune_DDP.py: 134 changes (0 additions, 134 deletions)

This file was deleted.

src/models/gvp_branch.py → src/models/branches.py: 121 changes (119 additions, 2 deletions)
@@ -1,9 +1,126 @@
-from torch import nn
 import torch
-from torch_scatter import scatter_mean
+import torch.nn as nn
+import torch_geometric
+
+from torch_geometric.nn import GCNConv, GATConv, global_max_pool as gmp, global_mean_pool as gep
+from torch_geometric.utils import dropout_edge, dropout_node
+from torch_geometric import data as geo_data
+from torch_geometric.nn import summary
+
+from transformers import AutoTokenizer, EsmModel
+from transformers.utils import logging
+
+from torch_scatter import scatter_mean
 from src.models.utils import GVP, GVPConvLayer, LayerNorm
 
+class ESMBranch(nn.Module):
+    def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
+                 num_feat=320, emb_dim=512, output_dim=128, dropout=0.2,
+                 dropout_gnn=0.0, extra_fc_lyr=False,
+                 esm_only=False, edge_weight=None):  # both are referenced in forward()
+
+        super(ESMBranch, self).__init__()
+
+        # Protein graph:
+        self.conv1 = GCNConv(num_feat, emb_dim)
+        self.conv2 = GCNConv(emb_dim, emb_dim * 2)
+        self.conv3 = GCNConv(emb_dim * 2, emb_dim * 4)
+
+        if not extra_fc_lyr:
+            self.fc_g1 = nn.Linear(emb_dim * 4, 1024)
+        else:
+            self.fc_g1 = nn.Linear(emb_dim * 4, emb_dim * 2)
+            self.fc_g1b = nn.Linear(emb_dim * 2, 1024)
+
+        self.extra_fc_lyr = extra_fc_lyr
+        self.fc_g2 = nn.Linear(1024, output_dim)
+
+        # forward() reads these attributes, so they must be set here
+        self.esm_only = esm_only
+        self.edge_weight = edge_weight
+
+        # loading the model raises a warning since the LM head is missing,
+        # but that is okay since we are not using it:
+        logging.set_verbosity(logging.CRITICAL)
+        self.esm_tok = AutoTokenizer.from_pretrained(esm_head)
+        self.esm_mdl = EsmModel.from_pretrained(esm_head)
+        self.esm_mdl.requires_grad_(False)  # freeze weights
+
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(dropout)
+        # dropout for edges and nodes is handled by torch_geometric in the forward pass
+        self.dropout_gnn = dropout_gnn
+
+    def forward(self, data):
+        #### ESM embeddings ####
+        # <cls> and <sep> tokens are added to the sequence by the tokenizer
+        seq_tok = self.esm_tok(data.pro_seq,
+                               return_tensors='pt',
+                               padding=True)  # [B, L_max+2]
+        seq_tok['input_ids'] = seq_tok['input_ids'].to(data.x.device)
+        seq_tok['attention_mask'] = seq_tok['attention_mask'].to(data.x.device)
+
+        esm_emb = self.esm_mdl(**seq_tok).last_hidden_state  # [B, L_max+2, emb_dim]
+
+        # remove the <cls> token
+        esm_emb = esm_emb[:, 1:, :]  # [B, L_max+1, emb_dim]
+
+        # remove the <sep> and padding tokens by masking each row to its true length
+        L_max = esm_emb.shape[1]  # L_max+1
+        mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq) for seq in data.pro_seq])[:, None]
+        mask = mask.flatten(0, 1).to(esm_emb.device)  # [B*(L_max+1)]
+
+        # flatten from [B, L_max+1, emb_dim]
+        esm_emb = esm_emb.flatten(0, 1)  # to [B*(L_max+1), emb_dim]
+        esm_emb = esm_emb[mask]  # [B*L, emb_dim]
+
+        if self.esm_only:
+            target_x = esm_emb  # [B*L, emb_dim]
+        else:
+            # append ESM embeddings to the protein node features
+            target_x = torch.cat((esm_emb, data.x), dim=1)
+            # ->> [B*L, emb_dim+feat_dim]
+
+        #### Graph NN ####
+        ei = data.edge_index
+        # if edge_weight doesn't exist, no error is thrown; it is just passed as None
+        ew = data.edge_weight if (self.edge_weight is not None and
+                                  self.edge_weight != 'binary') else None
+
+        target_x = self.relu(target_x)
+        # dropout_node also drops the edges of dropped nodes, so any edge
+        # weights are filtered with the returned edge mask to stay aligned
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)
+        # conv1
+        xt = self.conv1(target_x, ei_drp, None if ew is None else ew[e_mask])
+        xt = self.relu(xt)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)
+        # conv2
+        xt = self.conv2(xt, ei_drp, None if ew is None else ew[e_mask])
+        xt = self.relu(xt)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)
+        # conv3
+        xt = self.conv3(xt, ei_drp, None if ew is None else ew[e_mask])
+        xt = self.relu(xt)
+
+        # flatten/pool
+        xt = gep(xt, data.batch)  # global mean pooling
+        xt = self.relu(xt)
+        xt = self.dropout(xt)
+
+        #### FC layers ####
+        xt = self.fc_g1(xt)
+        xt = self.relu(xt)
+        xt = self.dropout(xt)
+
+        if self.extra_fc_lyr:
+            xt = self.fc_g1b(xt)
+            xt = self.relu(xt)
+            xt = self.dropout(xt)
+
+        xt = self.fc_g2(xt)
+        xt = self.relu(xt)
+        xt = self.dropout(xt)
+        return xt


 # Adapted from https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/models.py
 class GVPBranchProt(nn.Module):
     '''
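Stepping outside the diff for a moment: the flatten-then-mask trick in ESMBranch.forward is easy to sanity-check in isolation. A small sketch with toy tensors (no ESM download; the sizes are made up for illustration):

import torch

B, L, D = 2, 4, 8                       # batch, L_max+1, embedding dim
emb = torch.randn(B, L, D)              # stand-in for the sliced ESM output
lengths = torch.tensor([3, 2])          # true sequence lengths per graph
mask = torch.arange(L)[None, :] < lengths[:, None]    # [B, L]
flat = emb.flatten(0, 1)[mask.flatten(0, 1)]          # [sum(lengths), D]
assert flat.shape == (5, D)             # 3 + 2 residues survive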
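Finally, a hypothetical usage sketch for the new ESMBranch. The Data fields (pro_seq, x, edge_index, batch) mirror what forward reads; the toy graph and sequence are invented. esm_only=True is used so the 320-dim ESM embeddings alone match conv1's default num_feat=320 (with esm_only=False, num_feat would have to equal the ESM width plus the node-feature width):

import torch
from torch_geometric.data import Data
from src.models.branches import ESMBranch

seq = "MKTAY"  # toy 5-residue protein
data = Data(
    x=torch.randn(len(seq), 320),                   # node features
    edge_index=torch.tensor([[0, 1, 2, 3],
                             [1, 2, 3, 4]]),        # chain topology
    batch=torch.zeros(len(seq), dtype=torch.long),  # single graph
)
data.pro_seq = [seq]  # one sequence string per graph

branch = ESMBranch(esm_only=True).eval()
with torch.no_grad():
    out = branch(data)
print(out.shape)  # torch.Size([1, 128]) == [B, output_dim]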