Allow semantic search with taxonomies #69 (#80)
* Allow semantic search with taxonomies #69

* rdf querying for taxonomy CSO

* Update src/dossier_search/concept_search/faiss_search.py

Co-authored-by: Wojciech Kusa <[email protected]>

* rdf search extension for CCS taxonomy

* fix missing requirement

* fix specific exceptions in get_id

* fix(taxonomy): semantic search works on M1 chip

- lexical search on M1 is disabled

* fix(taxonomy): fix settings import

* fix(taxonomy): faiss indexing works on cpu

* style(taxonomy): style with black

* build(taxonomy): update requirements

* fix(taxonomy): sort imports, catch more specific exceptions

* docs(taxonomy): update readme and change default M1_CHIP to False

* fix path parameter optional

Co-authored-by: Wojciech Kusa <[email protected]>
Co-authored-by: Wojciech Kusa <[email protected]>
Co-authored-by: Ayah Soufan <[email protected]>
4 people authored May 17, 2022
1 parent 61dc4ec commit 993d98c
Showing 7 changed files with 496 additions and 30 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -70,3 +70,12 @@ Server should be available at http://127.0.0.1:8000/
(cruise-literature)$ python manage.py runserver YOUR_IP:YOUR_PORT
```

## 3. Troubleshooting

### 3.1 M1 Macbook

If you are using a laptop with an M1 chip, change the following line in the [settings.py](src/dossier_search/dossier_search/settings.py) file:

```python
M1_CHIP = True
```
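
Enabling this flag works around MKL library problems with FAISS on the M1 (see `faiss_search.py`); note that with `M1_CHIP = True`, lexical search is disabled, as stated in the commit messages above.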
Empty file added data/processed/.gitkeep
Empty file.
9 changes: 8 additions & 1 deletion requirements.txt
@@ -7,4 +7,11 @@ wikipedia~=1.4.0
pandas==1.4.1
xmltodict==0.12.0
fuzzywuzzy==0.18.0
requests~=2.27.1
torch==1.10.2
transformers==4.14.1
python-terrier==0.8.1
rdflib==6.1.1
faiss-cpu==1.7.2
numpy~=1.22.3
faiss-gpu==1.7.2
121 changes: 121 additions & 0 deletions src/dossier_search/concept_search/faiss_search.py
@@ -0,0 +1,121 @@
from os.path import exists

import faiss
import numpy as np
import torch
from dossier_search.settings import M1_CHIP
from transformers import AutoTokenizer, AutoModel

if M1_CHIP:
    # solves problems with the MKL library on M1 MacBooks
    # FIXME: this should be replaced by proper requirements for M1
    import os

    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class SemanticSearch:
    def __init__(self, data: list, tax_name: str):
        model_name = "allenai/scibert_scivocab_cased"
        self.model = AutoModel.from_pretrained(
            model_name, output_hidden_states=True
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=16)

        self.n_dimensions = self.model.pooler.dense.out_features
        self.data = data

        index_path = "../../data/processed/Taxonomy{}index.bin".format(tax_name)

        if exists(index_path):
            self.taxonomy_index = faiss.read_index(index_path)
        else:
            self.taxonomy_index = self.create_faiss_index()
            try:
                # a GPU index must be moved back to CPU before serialization
                faiss.write_index(
                    faiss.index_gpu_to_cpu(self.taxonomy_index), index_path
                )
            except AttributeError:
                # faiss-cpu has no index_gpu_to_cpu; the index is already on CPU
                faiss.write_index(self.taxonomy_index, index_path)

        try:
            # move the index to GPU when faiss-gpu is installed
            res = faiss.StandardGpuResources()
            self.taxonomy_index = faiss.index_cpu_to_gpu(res, 0, self.taxonomy_index)
        except AttributeError:
            pass

    def embedding(self, word: str):
        """
        Embeds the input (a string or a list of strings) as the sum
        of the last 4 hidden states of the BERT embeddings.
        """
        model = self.model
        tokenizer = self.tokenizer
        word = [word] if word.__class__ != list else word
        marked_text = ["[CLS] " + i + " [SEP]" for i in word]

        padding = "max_length"
        tokenized_text = [
            tokenizer.tokenize(i, padding=padding, truncation=True) for i in marked_text
        ]
        tokenized_text_ = [
            tokenizer.tokenize(i, padding=False, truncation=True) for i in marked_text
        ]
        # position of [SEP] in each unpadded sequence, used to drop padding later
        wordpiece_vectors = [len(i) - 1 for i in tokenized_text_]

        indexed_tokens = [tokenizer.convert_tokens_to_ids(i) for i in tokenized_text]
        tokens_tensor = torch.tensor([indexed_tokens]).squeeze(0)

        with torch.no_grad():
            outputs = model(tokens_tensor)
            hidden_states = outputs["hidden_states"]

        # combine the layers to make a single tensor
        token_embeddings = torch.stack(hidden_states, dim=0)
        token_embeddings = torch.squeeze(token_embeddings, dim=1)

        try:
            token_embeddings = token_embeddings.permute(1, 2, 0, 3)
        except RuntimeError:
            # RuntimeError: number of dims don't match in permute
            token_embeddings = token_embeddings.unsqueeze(0).permute(0, 2, 1, 3)

        out = []
        for item, wordpiece_vector in zip(token_embeddings, wordpiece_vectors):
            token_vecs_sum = []
            for token in item:
                # sum the last 4 hidden layers for each token
                sum_vec = torch.sum(token[-4:], dim=0)
                token_vecs_sum.append(sum_vec)
            # average over the real tokens, skipping [CLS], [SEP], and padding
            out_i = torch.stack(token_vecs_sum[1:wordpiece_vector], dim=-1)
            out_i = out_i.mean(dim=-1)
            out.append(out_i)
        return torch.stack(out)

    def create_faiss_index(self):
        """
        Creates a FAISS index of type FlatL2 from data vectors of n_dimensions.
        """
        fastIndex = faiss.IndexFlatL2(self.n_dimensions)

        try:
            # copy the index to GPU
            res = faiss.StandardGpuResources()
            fastIndex = faiss.index_cpu_to_gpu(res, 0, fastIndex)
        except AttributeError:
            pass

        # index data
        embeddings = self.embedding(self.data).detach().numpy()
        fastIndex.add(np.stack(embeddings).astype("float32"))
        return fastIndex

    def do_faiss_lookup(self, text):
        """
        Semantic search for a word in the indexed collection.
        """
        vector = self.embedding(text).detach().numpy().astype("float32")

        score, index = self.taxonomy_index.search(vector, 1)

        return self.data[index[0][0]], score
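
For orientation, a minimal usage sketch of `SemanticSearch`: the term list and `tax_name` below are hypothetical, the import path assumes the `concept_search` app is importable, and the relative `../../data/processed/` directory must exist, as in `__init__`:

```python
# Hypothetical usage sketch: terms and tax_name are invented, and the
# import path assumes the concept_search package is on sys.path.
from concept_search.faiss_search import SemanticSearch

terms = ["machine learning", "information retrieval", "neural networks"]
searcher = SemanticSearch(data=terms, tax_name="CSO")

# do_faiss_lookup returns the nearest taxonomy term and its L2 distance
match, score = searcher.do_faiss_lookup("deep learning")
print(match, score)
```

Note that the constructor builds (or loads) the FAISS index up front, so the first instantiation for a taxonomy is slow, while subsequent instantiations read the serialized index from disk.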
44 changes: 44 additions & 0 deletions src/dossier_search/concept_search/lexical_search.py
@@ -0,0 +1,44 @@
from os.path import exists, join

import pyterrier as pt

if not pt.started():
    pt.init()


class LexicalSearch:
    def __init__(self, data: list, tax_name: str):
        self.data = data
        self.path = "../../data/processed/taxonomy{}iter_index".format(tax_name)
        self.indexref = self.get_index()

    def collection_iter(self, collection):
        # PyTerrier expects dicts with a docno and a text field
        for idx, text in enumerate(collection):
            yield {"docno": idx, "text": text}

    def get_index(self):
        # reuse an existing Terrier index if one was built before
        if exists(join(self.path, "data.properties")):
            indexref = pt.IndexFactory.of(join(self.path, "data.properties"))
        else:
            iter_indexer = pt.IterDictIndexer(self.path, blocks=True, overwrite=True)

            doc_iter = self.collection_iter(self.data)
            indexref = iter_indexer.index(doc_iter)

        return indexref

    def lexical_search(self, query):
        retr_controls = {
            "wmodel": "BM25",
            "string.use_utf": "true",
            "end": 0,  # keep only the top-ranked result
        }

        retr = pt.BatchRetrieve(self.indexref, controls=retr_controls)

        res = retr.search(query)

        # accept the hit only if its BM25 score clears a fixed threshold
        if len(res) > 0 and res.score.values[0] > 7:
            return self.data[int(res.docno.values[0])]
        else:
            return -100
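
A matching sketch for `LexicalSearch`, under the same hypothetical data; PyTerrier additionally needs a Java runtime for `pt.init()`:

```python
# Hypothetical usage sketch: terms and tax_name are invented.
from concept_search.lexical_search import LexicalSearch

terms = ["machine learning", "information retrieval", "neural networks"]
searcher = LexicalSearch(data=terms, tax_name="CSO")

result = searcher.lexical_search("machine learning")
# lexical_search returns -100 when no document clears the score threshold
if result != -100:
    print("matched taxonomy term:", result)
```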
