Commit: mend
amva13 committed Jan 7, 2025
1 parent a684dd3 commit 22bbc07
Showing 1 changed file with 101 additions and 52 deletions.
tdc/test/test_hf.py: 153 changes (101 additions & 52 deletions)
@@ -1,66 +1,115 @@
 # -*- coding: utf-8 -*-
+from huggingface_hub import create_repo
+from huggingface_hub import HfApi, snapshot_download, hf_hub_download
+import os
 
-from __future__ import division
-from __future__ import print_function
+deeppurpose_repo = [
+    'hERG_Karim-Morgan',
+    'hERG_Karim-CNN',
+    'hERG_Karim-AttentiveFP',
+    'BBB_Martins-AttentiveFP',
+    'BBB_Martins-Morgan',
+    'BBB_Martins-CNN',
+    'CYP3A4_Veith-Morgan',
+    'CYP3A4_Veith-CNN',
+    'CYP3A4_Veith-AttentiveFP',
+]
 
-import os
-import sys
+model_hub = ["Geneformer", "scGPT"]
 
-import unittest
-import shutil
-import pytest
 
-# temporary solution for relative imports in case TDC is not installed
-# if TDC is installed, no need to use the following line
-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+# TODO: add verification for the generation other than simple integration
+class tdc_hf_interface:
+    '''
+    Example use cases:
+    # initialize an interface object with HF repo name
+    tdc_hf_herg = tdc_hf_interface("hERG_Karim-Morgan")
+    # upload folder/files to this repo
+    tdc_hf_herg.upload('./Morgan_herg_karim_optimal')
+    # load deeppurpose model from this repo
+    dp_model = tdc_hf_herg.load_deeppurpose('./data')
+    dp_model.predict(XXX)
+    '''
+
+    def __init__(self, repo_name):
+        self.repo_id = "tdc/" + repo_name
+        try:
+            self.model_name = repo_name.split('-')[1]
+        except:
+            self.model_name = repo_name
 
-class TestHF(unittest.TestCase):
+    def upload(self, folder_path):
+        create_repo(repo_id=self.repo_id)
+        api = HfApi()
+        api.upload_folder(folder_path=folder_path,
+                          path_in_repo="model",
+                          repo_id=self.repo_id,
+                          repo_type="model")
 
-    def setUp(self):
-        print(os.getcwd())
-        pass
+    def file_download(self, save_path, filename):
+        model_ckpt = hf_hub_download(repo_id=self.repo_id,
+                                     filename=filename,
+                                     cache_dir=save_path)
 
-    @pytest.mark.skip(
-        reason="This test is skipped due to deeppurpose installation dependency"
-    )
-    @unittest.skip(reason="DeepPurpose")
-    def test_hf_load_predict(self):
-        from tdc.single_pred import Tox
-        data = Tox(name='herg_karim')
+    def repo_download(self, save_path):
+        snapshot_download(repo_id=self.repo_id, cache_dir=save_path)
 
-        from tdc import tdc_hf_interface
-        tdc_hf = tdc_hf_interface("hERG_Karim-CNN")
-        # load deeppurpose model from this repo
-        dp_model = tdc_hf.load_deeppurpose('./data')
-        tdc_hf.predict_deeppurpose(dp_model, ['CC(=O)NC1=CC=C(O)C=C1'])
+    def load(self):
+        if self.model_name not in model_hub:
+            raise Exception("this model is not in the TDC model hub GH repo.")
+        elif self.model_name == "Geneformer":
+            from transformers import AutoModelForMaskedLM
+            model = AutoModelForMaskedLM.from_pretrained(
+                "ctheodoris/Geneformer")
+            return model
+        elif self.model_name == "scGPT":
+            from transformers import AutoModel
+            model = AutoModel.from_pretrained("tdc/scGPT")
+            return model
+        raise Exception("Not implemented yet!")
 
-    def test_hf_transformer(self):
-        from tdc import tdc_hf_interface
-        # from transformers import Pipeline
-        from transformers import BertForMaskedLM as BertModel
-        geneformer = tdc_hf_interface("Geneformer")
-        model = geneformer.load()
-        # assert isinstance(pipeline, Pipeline)
-        assert isinstance(model, BertModel), type(model)
+    def load_deeppurpose(self, save_path):
+        if self.repo_id[4:] in deeppurpose_repo:
+            save_path = save_path + '/' + self.repo_id[4:]
+            if not os.path.exists(save_path):
+                os.mkdir(save_path)
+            self.file_download(save_path, "model/model.pt")
+            self.file_download(save_path, "model/config.pkl")
 
-    # def test_hf_load_new_pytorch_standard(self):
-    #     from tdc import tdc_hf_interface
-    #     # from tdc.resource.dataloader import DataLoader
-    #     # data = DataLoader(name="pinnacle_dti")
-    #     tdc_hf = tdc_hf_interface("mli-PINNACLE")
-    #     dp_model = tdc_hf.load()
-    #     assert dp_model is not None
+            save_path = save_path + '/models--tdc--' + self.repo_id[
+                4:] + '/blobs/'
+            file_name1 = save_path + os.listdir(save_path)[0]
+            file_name2 = save_path + os.listdir(save_path)[1]
 
-    def tearDown(self):
-        try:
-            print(os.getcwd())
-            shutil.rmtree(os.path.join(os.getcwd(), "data"))
-        except:
-            pass
+            if os.path.getsize(file_name1) > os.path.getsize(file_name2):
+                model_file, config_file = file_name1, file_name2
+            else:
+                config_file, model_file = file_name1, file_name2
+
+            os.rename(model_file, save_path + 'model.pt')
+            os.rename(config_file, save_path + 'config.pkl')
+            try:
+                from DeepPurpose import CompoundPred
+            except:
+                raise ValueError(
+                    "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation"
+                )
+
+            net = CompoundPred.model_pretrained(path_dir=save_path)
+            return net
+        else:
+            raise ValueError("This repo does not host a DeepPurpose model!")
 
-if __name__ == "__main__":
-    unittest.main()
+    def predict_deeppurpose(self, model, drugs):
+        try:
+            from DeepPurpose import utils
+        except:
+            raise ValueError(
+                "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation"
+            )
+        if self.model_name == 'AttentiveFP':
+            self.model_name = 'DGL_' + self.model_name
+        X_pred = utils.data_process(X_drug=drugs,
+                                    y=[0] * len(drugs),
+                                    drug_encoding=self.model_name,
+                                    split_method='no_split')
+        y_pred = model.predict(X_pred)[0]
+        return y_pred

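For reference, a minimal usage sketch of the tdc_hf_interface class added above, assembled from its docstring and the tests removed in this commit; it assumes the class is exported as tdc.tdc_hf_interface (as the removed tests imported it) and that the optional transformers and DeepPurpose dependencies are installed:

from tdc import tdc_hf_interface

# Upload path (from the class docstring): push a trained model folder to the tdc/ HF repo.
tdc_hf_herg = tdc_hf_interface("hERG_Karim-Morgan")
tdc_hf_herg.upload('./Morgan_herg_karim_optimal')

# Transformer path: repo names listed in model_hub ("Geneformer", "scGPT").
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()  # returns the pretrained model via transformers

# DeepPurpose path: repo names listed in deeppurpose_repo, e.g. "hERG_Karim-CNN".
tdc_hf = tdc_hf_interface("hERG_Karim-CNN")
dp_model = tdc_hf.load_deeppurpose('./data')  # downloads model.pt and config.pkl under ./data
preds = tdc_hf.predict_deeppurpose(dp_model, ['CC(=O)NC1=CC=C(O)C=C1'])  # SMILES taken from the removed test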