Skip to content

Commit

Permalink
avoiding loading models and embeddings to the gpu by default (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
ospitia authored Jun 21, 2022
1 parent 0dac3d6 commit 13ad6e5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/dossier_search/concept_search/concept_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_words_score(


class ConceptRate:
def __init__(self, model_name: str = "allenai/scibert_scivocab_cased", mount_on_gpu: bool = True):
def __init__(self, model_name: str = "allenai/scibert_scivocab_cased", mount_on_gpu: bool = False):
self.model = AutoModel.from_pretrained(model_name, output_hidden_states=True).eval()
self.device = 'cuda' if mount_on_gpu and torch.cuda.is_available() else 'cpu'
self.model.to(self.device)
Expand Down
31 changes: 16 additions & 15 deletions src/dossier_search/concept_search/faiss_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@


class SemanticSearch:
def __init__(self, data: list, tax_name: str):
def __init__(self, data: list, tax_name: str, mount_gpu: bool = False):
model_name = "allenai/scibert_scivocab_cased"
self.model = AutoModel.from_pretrained(
model_name, output_hidden_states=True
).eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=16)

self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.mount_gpu = mount_gpu
self.device = "cuda:0" if torch.cuda.is_available() and mount_gpu else "cpu"

self.model = self.model.to(self.device)


self.n_dimensions = self.model.pooler.dense.out_features
self.data = data

Expand All @@ -43,11 +42,12 @@ def __init__(self, data: list, tax_name: str):
except AttributeError:
faiss.write_index(self.taxonomy_index, index_path)

try:
res = faiss.StandardGpuResources()
self.taxonomy_index = faiss.index_cpu_to_gpu(res, 0, self.taxonomy_index)
except AttributeError:
pass
if mount_gpu:
try:
res = faiss.StandardGpuResources()
self.taxonomy_index = faiss.index_cpu_to_gpu(res, 0, self.taxonomy_index)
except AttributeError:
pass

def embedding(self, word: str):
"""
Expand Down Expand Up @@ -103,12 +103,13 @@ def create_faiss_index(self):

fastIndex = faiss.IndexFlatL2(self.n_dimensions)

try:
# copy the index to GPU
res = faiss.StandardGpuResources()
fastIndex = faiss.index_cpu_to_gpu(res, 0, fastIndex)
except AttributeError:
pass
if self.mount_gpu:
try:
# copy the index to GPU
res = faiss.StandardGpuResources()
fastIndex = faiss.index_cpu_to_gpu(res, 0, fastIndex)
except AttributeError:
pass

# index data
embeddings = self.embedding(self.data).cpu().detach().numpy()
Expand Down

0 comments on commit 13ad6e5

Please sign in to comment.