Skip to content

Commit

Permalink
Merge pull request #17 from IBM/pre_v.0.3.2
Browse files Browse the repository at this point in the history
Pre v.0.3.2
  • Loading branch information
RaulFD-creator authored Aug 14, 2024
2 parents a2f1369 + 2a23227 commit 6d7e374
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
20 changes: 16 additions & 4 deletions autopeptideml/utils/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
'prot_t5_xxl_uniref50': 1024,
'prot_t5_xl_half_uniref50-enc': 1024,
'prot_bert': 1024,
'ProstT5': 1024
'ProstT5': 1024,
'ankh-base': 768,
'ankh-large': 1536
}

SYNONYMS = {
Expand All @@ -30,7 +32,9 @@
'esm1b': 'esm1b_t33_650M_UR50S',
'esm2-150m': 'esm2_t30_150M_UR50D',
'esm2-35m': 'esm2_t12_35M_UR50D',
'esm2-8m': 'esm2_t6_8M_UR50D'
'esm2-8m': 'esm2_t6_8M_UR50D',
'ankh-base': 'ankh-base',
'ankh-large': 'ankh-large'
}


Expand Down Expand Up @@ -59,7 +63,13 @@ def compute_batch(self, batch: list, average_pooling: bool):
inputs = self.tokenizer(batch, add_special_tokens=True, padding="longest", return_tensors="pt")
inputs = inputs.to(self.device)
with torch.no_grad():
embd_rpr = self.model(**inputs).last_hidden_state
if self.lab == 'ElnaggarLab':
embd_rpr = self.model(input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
decoder_input_ids=inputs['input_ids']
).last_hidden_state
else:
embd_rpr = self.model(**input_ids)

output = []
for idx in range(len(batch)):
Expand Down Expand Up @@ -102,7 +112,9 @@ def _load_model(self, model: str):
if 'pro' in model.lower():
self.lab = 'Rostlab'
elif 'esm' in model.lower():
self.lab = 'facebook'
self.lab = 'facebook'
elif 'ankh' in model.lower():
self.lab = 'ElnaggarLab'
if 't5' in model.lower():
self.tokenizer = T5Tokenizer.from_pretrained(f'Rostlab/{model}', do_lower_case=False)
self.model = T5EncoderModel.from_pretrained(f"Rostlab/{model}")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@
name='autopeptideml',
packages=find_packages(exclude=['examples']),
url='https://ibm.github.io/AutoPeptideML/',
version='0.3.1',
version='0.3.2',
zip_safe=False,
)

0 comments on commit 6d7e374

Please sign in to comment.