-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* Allow semantic search with taxonomies #69 * rdf querying for taxonomy CSO * Update src/dossier_search/concept_search/faiss_search.py Co-authored-by: Wojciech Kusa <[email protected]> * rdf search extension for CCS taxonomy * fix missing requirement * fix specific exceptions in get_id * fix(taxonomy): semantic search works on M1 chip - lexical search on M1 is disabled * fix(taxonomy): fix settings import * fix(taxonomy): faiss indexing works on cpu * style(taxonomy): style with black * build(taxonomy): update requirements * fix(taxonomy): sort imports, catch more specific exceptions * docs(taxonomy): update readme and change default M1_CHIP to False * fix path parameter optional Co-authored-by: Wojciech Kusa <[email protected]> Co-authored-by: Wojciech Kusa <[email protected]> Co-authored-by: Ayah Soufan <[email protected]>
- Loading branch information
1 parent
61dc4ec
commit 993d98c
Showing
7 changed files
with
496 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from os.path import exists | ||
|
||
import faiss | ||
import numpy as np | ||
import torch | ||
from dossier_search.settings import M1_CHIP | ||
from transformers import AutoTokenizer, AutoModel | ||
|
||
if M1_CHIP: | ||
# solves problems with MKL library on M1 macbook | ||
# FIXME: this should be replaced by a proper requirements for M1 | ||
import os | ||
|
||
os.environ["KMP_DUPLICATE_LIB_OK"] = "True" | ||
|
||
|
||
class SemanticSearch: | ||
def __init__(self, data: list, tax_name: str): | ||
model_name = "allenai/scibert_scivocab_cased" | ||
self.model = AutoModel.from_pretrained( | ||
model_name, output_hidden_states=True | ||
).eval() | ||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=16) | ||
|
||
self.n_dimensions = self.model.pooler.dense.out_features | ||
self.data = data | ||
|
||
index_path = "../../data/processed/Taxonomy{}index.bin".format(tax_name) | ||
|
||
if exists(index_path): | ||
self.taxonomy_index = faiss.read_index(index_path) | ||
else: | ||
self.taxonomy_index = self.create_faiss_index() | ||
try: | ||
faiss.write_index( | ||
faiss.index_gpu_to_cpu(self.taxonomy_index), index_path | ||
) | ||
except AttributeError: | ||
faiss.write_index(self.taxonomy_index, index_path) | ||
|
||
try: | ||
res = faiss.StandardGpuResources() | ||
self.taxonomy_index = faiss.index_cpu_to_gpu(res, 0, self.taxonomy_index) | ||
except AttributeError: | ||
pass | ||
|
||
def embedding(self, word: str): | ||
""" | ||
embedds the word input as the sum of the last 4 hidden | ||
states of Bert embeddings | ||
""" | ||
model = self.model | ||
tokenizer = self.tokenizer | ||
word = [word] if word.__class__ != list else word | ||
marked_text = ["[CLS] " + i + " [SEP]" for i in word] | ||
|
||
padding = "max_length" | ||
tokenized_text = [ | ||
tokenizer.tokenize(i, padding=padding, truncation=True) for i in marked_text | ||
] | ||
tokenized_text_ = [ | ||
tokenizer.tokenize(i, padding=False, truncation=True) for i in marked_text | ||
] | ||
wordpiece_vectors = [len(i) - 1 for i in tokenized_text_] | ||
|
||
indexed_tokens = [tokenizer.convert_tokens_to_ids(i) for i in tokenized_text] | ||
tokens_tensor = torch.tensor([indexed_tokens]).squeeze(0) | ||
|
||
with torch.no_grad(): | ||
outputs = model(tokens_tensor) | ||
hidden_states = outputs["hidden_states"] | ||
|
||
# combine the layers to make a single tensor. | ||
token_embeddings = torch.stack(hidden_states, dim=0) | ||
token_embeddings = torch.squeeze(token_embeddings, dim=1) | ||
|
||
try: | ||
token_embeddings = token_embeddings.permute(1, 2, 0, 3) | ||
except RuntimeError: | ||
# RuntimeError: number of dims don't match in permute | ||
token_embeddings = token_embeddings.unsqueeze(0).permute(0, 2, 1, 3) | ||
|
||
out = [] | ||
for item, wordpiece_vector in zip(token_embeddings, wordpiece_vectors): | ||
token_vecs_sum = [] | ||
for token in item: | ||
sum_vec = torch.sum(token[-4:], dim=0) | ||
token_vecs_sum.append(sum_vec) | ||
out_i = torch.stack(token_vecs_sum[1:wordpiece_vector], dim=-1) | ||
out_i = out_i.mean(dim=-1) | ||
out.append(out_i) | ||
return torch.stack(out) | ||
|
||
def create_faiss_index(self): | ||
""" | ||
create a faiss index of type FlatL2 from data vectors of n_dimensions | ||
""" | ||
|
||
fastIndex = faiss.IndexFlatL2(self.n_dimensions) | ||
|
||
try: | ||
# copy the index to GPU | ||
res = faiss.StandardGpuResources() | ||
fastIndex = faiss.index_cpu_to_gpu(res, 0, fastIndex) | ||
except AttributeError: | ||
pass | ||
|
||
# index data | ||
embeddings = self.embedding(self.data).detach().numpy() | ||
fastIndex.add(np.stack(embeddings).astype("float32")) | ||
return fastIndex | ||
|
||
def do_faiss_lookup(self, text): | ||
""" | ||
the semantic search of word in the indexed collection | ||
""" | ||
vector = self.embedding(text).detach().numpy().astype("float32") | ||
|
||
score, index = self.taxonomy_index.search(vector, 1) | ||
|
||
return self.data[index[0][0]], score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from os.path import exists, join | ||
|
||
import pyterrier as pt | ||
|
||
if not pt.started(): | ||
pt.init() | ||
|
||
|
||
class LexicalSearch: | ||
def __init__(self, data: list, tax_name: str): | ||
self.data = data | ||
self.path = "../../data/processed/taxonomy{}iter_index".format(tax_name) | ||
self.indexref = self.get_index() | ||
|
||
def collection_iter(self, collection): | ||
for idx, text in enumerate(collection): | ||
yield {"docno": idx, "text": text} | ||
|
||
def get_index(self): | ||
if exists(join(self.path, "data.properties")): | ||
indexref = pt.IndexFactory.of(join(self.path, "data.properties")) | ||
else: | ||
iter_indexer = pt.IterDictIndexer(self.path, blocks=True, overwrite=True) | ||
|
||
doc_iter = self.collection_iter(self.data) | ||
indexref = iter_indexer.index(doc_iter) | ||
|
||
return indexref | ||
|
||
def lexical_search(self, query): | ||
retr_controls = { | ||
"wmodel": "BM25", | ||
"string.use_utf": "true", | ||
"end": 0, | ||
} | ||
|
||
retr = pt.BatchRetrieve(self.indexref, controls=retr_controls) | ||
|
||
res = retr.search(query) | ||
|
||
if len(res) > 0 and res.score.values[0] > 7: | ||
return self.data[int(res.docno.values[0])] | ||
else: | ||
return -100 |
Oops, something went wrong.