Commit
feat: Adding protein design
JaktensTid committed Dec 9, 2023
1 parent c601edf commit d262028
Showing 16 changed files with 516 additions and 178 deletions.
5 changes: 5 additions & 0 deletions config.json
@@ -33,6 +33,11 @@
"name": "DTI",
"type": "dti",
"task": "dti"
},
{
"name": "rfdiffusion",
"type": "protein_design",
"task": "protein_design"
}
]
}
28 changes: 28 additions & 0 deletions installation/install_rfdiffusion.sh
@@ -0,0 +1,28 @@
protein_design_root=src/ai/custom_models/protein_design
rfdiff_root=src/ai/custom_models/protein_design/RFdiffusion

# Install RFdiffusion
#git clone https://github.com/RosettaCommons/RFdiffusion.git $protein_design_root/RFdiffusion

#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt

# Optional:
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt

# original structure prediction weights
#wget -P $rfdiff_root/models http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt

conda env create -f $rfdiff_root/env/SE3nv.yml

conda activate SE3nv
pip install --no-cache-dir -r $rfdiff_root/env/SE3Transformer/requirements.txt
python $rfdiff_root/env/SE3Transformer/setup.py install
pip install -e $rfdiff_root # install the rfdiffusion module from the root of the repository

conda run -n SE3nv pip install flask --upgrade
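
A minimal post-install sanity check, sketched here under the assumption that the commented-out wget commands above have been run and the checkpoints live under $rfdiff_root/models; the directory and file names simply mirror the script's variables.

# check_rfdiffusion_weights.py -- sketch: verify the RFdiffusion checkpoints
# referenced in install_rfdiffusion.sh are present before starting the service.
import os

RFDIFF_ROOT = "src/ai/custom_models/protein_design/RFdiffusion"  # same as $rfdiff_root
EXPECTED_CHECKPOINTS = [
    "Base_ckpt.pt",
    "Complex_base_ckpt.pt",
    "Complex_Fold_base_ckpt.pt",
    "InpaintSeq_ckpt.pt",
    "InpaintSeq_Fold_ckpt.pt",
    "ActiveSite_ckpt.pt",
    "Base_epoch8_ckpt.pt",
]

missing = [name for name in EXPECTED_CHECKPOINTS
           if not os.path.isfile(os.path.join(RFDIFF_ROOT, "models", name))]

if missing:
    print("Missing RFdiffusion checkpoints:", ", ".join(missing))
else:
    print("All required RFdiffusion checkpoints are in place.")
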
3 changes: 3 additions & 0 deletions requirements.txt
@@ -3,6 +3,9 @@ flask-cors
pandas
numpy
transformers

torch==1.13.1

flask_socketio
torchdrug==0.1.2
torchmetrics==1.2.0
25 changes: 25 additions & 0 deletions src/ai/custom_models/protein_design/rfdiffusion_microservice.py
@@ -0,0 +1,25 @@
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route('/rfdiffusion', methods=['POST'])
def rfdiffusion():
try:
# Assuming the incoming data is in JSON format
data = request.get_json()

# Accessing a key named 'message' from the JSON data
message = data.get('message', 'No message provided.')

# You can perform any processing with the received data here

response = {'status': 'success', 'message': f'Message received: {message}'}

return jsonify(response)

except Exception as e:
response = {'status': 'error', 'message': str(e)}
return jsonify(response), 500 # Internal Server Error

if __name__ == '__main__':
app.run(debug=True)
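
A sketch of how a client might exercise the stub endpoint above, assuming the microservice runs on Flask's default host and port (http://127.0.0.1:5000); for now the stub only echoes the 'message' field back.

# rfdiffusion_client_sketch.py -- call the stub endpoint above with requests.
import requests

payload = {"message": "ping from the protein design pipeline"}
response = requests.post("http://127.0.0.1:5000/rfdiffusion", json=payload)
response.raise_for_status()
print(response.json())  # {'status': 'success', 'message': 'Message received: ...'}
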
98 changes: 78 additions & 20 deletions src/ai/model.py
@@ -22,8 +22,8 @@
from src.ai.custom_models.drug_target.umol.src.make_msa_seq_feats_colab import process
from src.ai.custom_models.drug_target.umol.src.net.model import config
from src.ai.custom_models.drug_target.umol.src.predict_colab import predict
from src.ai.custom_models.drug_target.umol.src.relax.align_ligand_conformer_colab import read_pdb, generate_best_conformer, align_coords_transform, write_sdf

from src.ai.custom_models.drug_target.umol.src.relax.align_ligand_conformer_colab import read_pdb, \
generate_best_conformer, align_coords_transform, write_sdf

from Bio.PDB import PDBParser
import rdkit.Chem as Chem
@@ -47,6 +47,7 @@

dirname = os.path.dirname


class BaseModel:
def __init__(self, model_name: str, gpu: bool, model_task=""):
self.model_name = model_name
@@ -434,16 +435,16 @@ def post_process(chosen, rdkitMolFile, dataset, result_folder):

for protein_idx in range(len(protein_names)):
protein_name = protein_names[protein_idx]
logger.info(f"Making dti predictions for {protein_idx+1} out of {len(protein_names)} protein...")
logger.info(f"Making dti predictions for {protein_idx + 1} out of {len(protein_names)} protein...")
result_folder = os.path.join(experiment_folder, protein_name, 'result')
if not os.path.exists(result_folder):
os.mkdir(result_folder)
ligand2compounddict = self.protein2ligandcompdict[protein_name]
for ligand_idx in range(len(ligands_names)):
ligand_name = ligands_names[ligand_idx]
logger.info(f"Making dti predictions for {protein_idx+1} out of {len(protein_names)} protein...\n \
Currently predicting {ligand_idx+1} out of {len(ligands_names)} ligands...")
info = pd.read_csv(os.path.join(experiment_folder,protein_name, f"{ligand_name}_temp_info.csv"))
logger.info(f"Making dti predictions for {protein_idx + 1} out of {len(protein_names)} protein...\n \
Currently predicting {ligand_idx + 1} out of {len(ligands_names)} ligands...")
info = pd.read_csv(os.path.join(experiment_folder, protein_name, f"{ligand_name}_temp_info.csv"))
compound_dict, rdkitMolFile = ligand2compounddict[ligand_name]
dataset_path = f"{experiment_folder}/{protein_name}/{ligand_name}_dataset/"
os.system(f"rm -r {dataset_path}")
@@ -492,7 +493,6 @@ def _prepare_protein_data(self, experiment_folder, protein_files_paths):
# Add other protein data preparation steps here
self.submit_fasta_and_save_a3m(settings.FASTA_API, protein_files_paths, experiment_folder)


def submit_fasta_and_save_a3m(self, api_url, fasta_file_paths, experiment_folder):
for fasta_file_path in fasta_file_paths:
# Extract the base name to create the .a3m file name
@@ -536,10 +536,10 @@ def prepare_ligand_data(self, fasta_files, ligands, ligand_names, experiment_fol
if not os.path.exists(protein_folder):
os.makedirs(protein_folder)

MSA=os.path.join(experiment_folder, protein_name, protein_name+'.a3m')
PROCESSED_MSA=os.path.join(experiment_folder, protein_name, protein_name+'_processed.a3m')
MSA = os.path.join(experiment_folder, protein_name, protein_name + '.a3m')
PROCESSED_MSA = os.path.join(experiment_folder, protein_name, protein_name + '_processed.a3m')
process_a3m(MSA, get_sequence(fasta_file), PROCESSED_MSA)
MSA=PROCESSED_MSA
MSA = PROCESSED_MSA

# Process MSA features
feature_dict = process(fasta_file, [MSA]) # Assuming MSA is defined elsewhere
@@ -548,9 +548,12 @@ def prepare_ligand_data(self, fasta_files, ligands, ligand_names, experiment_fol
pickle.dump(feature_dict, f, protocol=4)
logger.info('Saved MSA features to', features_output_path)

atom_encoding = {'B':0, 'C':1, 'F':2, 'I':3, 'N':4, 'O':5, 'P':6, 'S':7,'Br':8, 'Cl':9, #Individual encoding
'As':10, 'Co':10, 'Fe':10, 'Mg':10, 'Pt':10, 'Rh':10, 'Ru':10, 'Se':10, 'Si':10, 'Te':10, 'V':10, 'Zn':10 #Joint (rare)
}
atom_encoding = {'B': 0, 'C': 1, 'F': 2, 'I': 3, 'N': 4, 'O': 5, 'P': 6, 'S': 7, 'Br': 8, 'Cl': 9,
# Individual encoding
'As': 10, 'Co': 10, 'Fe': 10, 'Mg': 10, 'Pt': 10, 'Rh': 10, 'Ru': 10, 'Se': 10, 'Si': 10,
'Te': 10, 'V': 10, 'Zn': 10
# Joint (rare)
}

# Process ligand features
atom_types, atoms, bond_types, bond_lengths, bond_mask = bonds_from_smiles(ligand, atom_encoding)
@@ -569,7 +572,8 @@ def prepare_ligand_data(self, fasta_files, ligands, ligand_names, experiment_fol

def load_model(self):
print("Loading uMol DTI params...")
self.model_params_path = os.path.join(dirname(os.path.abspath(__file__)), 'custom_models', 'drug_target', 'models_ckpt', 'params.pkl')
self.model_params_path = os.path.join(dirname(os.path.abspath(__file__)), 'custom_models', 'drug_target',
'models_ckpt', 'params.pkl')
url = 'https://huggingface.co/thomasshelby/uMol_params/resolve/main/params.pkl?download=true'
if not os.path.exists(self.model_params_path):
response = requests.get(url, stream=True)
@@ -596,17 +600,19 @@ def _raw_inference(self, protein_ids, ligands, ligands_names, experiment_folder,

with open(self.model_params_path, 'rb') as file:
PARAMS = pickle.load(file)

ID = ID.split('/')[-1]
# Predict
predict(config.CONFIG, MSA_FEATS, LIGAND_FEATS, ID, target_positions, PARAMS, num_recycles, outdir=result_folder)
predict(config.CONFIG, MSA_FEATS, LIGAND_FEATS, ID, target_positions, PARAMS, num_recycles,
outdir=result_folder)

# Process the prediction
RAW_PDB = os.path.join(result_folder, f'{ID}_pred_raw.pdb')

# Get a conformer
pred_ligand = read_pdb(RAW_PDB)
best_conf, best_conf_pos, best_conf_err, atoms, nonH_inds, mol, best_conf_id = generate_best_conformer(pred_ligand['chain_coords'], LIGAND)
best_conf, best_conf_pos, best_conf_err, atoms, nonH_inds, mol, best_conf_id = generate_best_conformer(
pred_ligand['chain_coords'], LIGAND)

# Align it to the prediction
aligned_conf_pos = align_coords_transform(pred_ligand['chain_coords'], best_conf_pos, nonH_inds)
@@ -619,20 +625,20 @@ def _raw_inference(self, protein_ids, ligands, ligands_names, experiment_folder,
protein_pdb_path = os.path.join(result_folder, f'{ID}_pred_protein.pdb')
ligand_plddt_path = os.path.join(result_folder, f'{LIGAND_NAME}_ligand_plddt.csv')

with open(RAW_PDB, 'r') as infile, open(protein_pdb_path, 'w') as protein_out, open(ligand_plddt_path, 'w') as ligand_out:
with open(RAW_PDB, 'r') as infile, open(protein_pdb_path, 'w') as protein_out, open(ligand_plddt_path,
'w') as ligand_out:
for line in infile:
if line.startswith('ATOM'):
protein_out.write(line)
elif line.startswith('HETATM'):
ligand_out.write(line[64:66] + '\n') # Extracting plDDT values


def predict(self, ligand_files_paths, protein_files, experiment_folder: str):
logger.info("Making dti predictions...")
ligands_names, ligands_smiles = read_sdf_files(ligand_files_paths)
protein_file_paths = [os.path.splitext(x)[0] for x in protein_files]

#TODO: ADD p2rank to identify target positions
# TODO: ADD p2rank to identify target positions
target_array = np.asarray([])
protein_names = [os.path.splitext(protein_file_path)[-2] for protein_file_path in protein_files]
self.prepare_data(ligands_names, ligands_smiles, protein_file_paths, protein_files, experiment_folder)
@@ -643,3 +649,55 @@ def predict(self, ligand_files_paths, protein_files, experiment_folder: str):
experiment_folder=experiment_folder,
target_positions=target_array,
num_recycles=3)


class RfDiffusionProteinDesign(BaseModel):
def __init__(self, model_name: str, gpu: bool, model_task=""):
super().__init__(model_name, gpu, model_task)

def load_model(self):
"""Load model and tokenizer here"""
pass

# Method to get raw model outputs
def _raw_inference(self,
pdb_content: str = None,
contig: str = '50',
symmetry: str = None,
timesteps: int = 50,
hotspots: str = ''):
result = requests.post(settings.RFDIFFUSION_API, json={
'pdb_content': pdb_content,
'contig': contig,
'symmetry': symmetry,
'timesteps': timesteps,
'hotspots': hotspots
})
return result.json()

# Method to return raw outputs in the desired format
def predict(self,
pdb_content: str = None,
contig: str = '50',
symmetry: str = None,
timesteps: int = 50,
hotspots: str = ''):
"""
Run raw inference on rfdiffusion
:param pdb_content: pdb content (optional). If not specified design an unconditional monomer
:param contig: contigs (around what and how to design a protein). See https://github.com/RosettaCommons/RFdiffusion
:param symmetry: To design a symmetrical protein,
Available symmetries:
- Cyclic symmetry (C_n) # call as c5
- Dihedral symmetry (D_n) # call as d5
- Tetrahedral symmetry # call as tetrahedral
- Octahedral symmetry # call as octahedral
- Icosahedral symmetry # call as icosahedral
Default None
:param timesteps: default 50 ( Desired iterations to generate structure. )
:param hotspots: A30, A33, A34
The model optionally readily learns that it should be making an interface which involving these hotspot residues. Input is ChainResidueNumber: A100 for residue 100 on chain A.
:return:
"""
result = self._raw_inference(pdb_content, contig, symmetry, timesteps, hotspots)
return result
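
A usage sketch for the new class, assuming settings.RFDIFFUSION_API points at a running RFdiffusion microservice and the repository root is on PYTHONPATH; the input PDB path and contig string are illustrative only.

# design_sketch.py -- hedged example of driving RfDiffusionProteinDesign directly.
from src.ai.model import RfDiffusionProteinDesign

model = RfDiffusionProteinDesign(model_name="rfdiffusion", gpu=True, model_task="protein_design")
model.load_model()  # currently a no-op; inference happens in the microservice

with open("target.pdb") as pdb_file:  # hypothetical input structure
    pdb_content = pdb_file.read()

result = model.predict(
    pdb_content=pdb_content,
    contig="A1-100/0 50-70",   # keep residues 1-100 of chain A, design a new 50-70 residue segment
    symmetry=None,             # e.g. "c5" for cyclic symmetry, per the docstring above
    timesteps=50,
    hotspots="A30, A33, A34",
)
print(result)  # JSON returned by the microservice
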
14 changes: 14 additions & 0 deletions src/ai/model_factory.py
@@ -104,4 +104,18 @@ def create_model(model_metadata: Dict[str, str], use_gpu: bool = False) -> Union

return model

if model_type == "protein_design":

if 'task' in model_metadata:
model_task = model_metadata["task"]

model = RfDiffusionProteinDesign(model_name=model_name, gpu=use_gpu, model_task=model_task)
model.load_model()

return model

raise UnknownModelException()
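
A sketch of wiring the new config.json entry through the factory, assuming create_model is importable as a module-level function (as the hunk header suggests) and that it reads the "name" and "type" keys of the metadata dict added to config.json above.

# factory_sketch.py -- create the protein design model from its config entry.
from src.ai.model_factory import create_model

rfdiffusion_metadata = {          # copied from the config.json block in this commit
    "name": "rfdiffusion",
    "type": "protein_design",
    "task": "protein_design",
}

model = create_model(rfdiffusion_metadata, use_gpu=True)
response = model.predict(contig="50")  # no pdb_content: design an unconditional 50-residue monomer
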
6 changes: 3 additions & 3 deletions src/server/api_handlers/drug_target.py
@@ -24,7 +24,7 @@ def inference(self, request):
experiment_id = self.drug_discovery.run(ligand_files=ligand_files, protein_files=protein_files,
experiment_id=experiment_id)
self.experiments_loader.save_experiment_metadata(experiment_id, experiment_name=experiment_name)
data = self.experiments_loader.load_result(experiment_id)
data = self.experiments_loader.load_experiment(experiment_id)

return {'id': experiment_id, 'name': experiment_name, 'data': data}

@@ -39,7 +39,7 @@ def get_experiment(self, request):
if not self.experiments_loader.experiment_exists(experiment_id):
return {'id': experiment_id, 'name': experiment_name, 'data': {}}

data = self.experiments_loader.load_result(experiment_id)
data = self.experiments_loader.load_experiment(experiment_id)
return {'id': experiment_id, 'data': data}

def change_experiment_name(self, request: Request):
@@ -62,7 +62,7 @@ def download_combined_pdb(self, request):
experiment_id = j['experiment_id']
experiment_selected_index = j['selected_index']

data = self.experiments_loader.load_result(experiment_id)
data = self.experiments_loader.load_experiment(experiment_id)

combined_pdb = combine_sdf_pdb(data[experiment_selected_index])

4 changes: 2 additions & 2 deletions src/server/api_handlers/drug_target_demo.py
@@ -16,7 +16,7 @@ def get_experiments(self):
def get_experiment(self, request):
experiment_id = request.args.get('id')
experiment_name = request.args.get('name')
data = self.experiments_loader.load_result(experiment_id)
data = self.experiments_loader.load_experiment(experiment_id)
return {'id': experiment_id, 'data': data}

def change_experiment_name(self, request: Request):
Expand All @@ -33,7 +33,7 @@ def download_combined_pdb(self, request):
experiment_id = j['experiment_id']
experiment_selected_index = j['selected_index']

data = self.experiments_loader.load_result(experiment_id)
data = self.experiments_loader.load_experiment(experiment_id)

combined_pdb = combine_sdf_pdb(data[experiment_selected_index]['sdf'], data[experiment_selected_index]['pdb'])
