-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhelper.py
86 lines (67 loc) · 2.97 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from collections import defaultdict

import torch
from jina import Document, DocumentArray, DocumentArrayMemmap, Executor, requests
from sentence_transformers import SentenceTransformer
class SimpleIndexer(Executor):
    """Simple on-disk indexer.

    Stores incoming documents in a :class:`DocumentArrayMemmap` under the
    executor workspace and serves cosine-similarity search over their
    embeddings.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        print(f"{self.workspace} IS self.workspace")
        # os.path.join avoids a malformed path when workspace already ends
        # with a separator (the original used raw string concatenation).
        self._docs = DocumentArrayMemmap(os.path.join(self.workspace, 'indexer'))

    @requests(on='/index')
    def index(self, docs: 'DocumentArray', **kwargs):
        """Append the incoming ``docs`` to the on-disk index (no-op if empty)."""
        if docs:
            self._docs.extend(docs)

    @requests(on='/search')
    def search(self, docs: 'DocumentArray', **kwargs):
        """Append best matches to each document in docs."""
        # Match queries against the index using cosine similarity;
        # normalization=(1, 0) rescales scores so the best match is 1.
        docs.match(
            DocumentArray(self._docs),
            metric='cosine',
            normalization=(1, 0),
            limit=10,
            exclude_self=True,
        )
        print(docs.embeddings.shape)
        for d in docs:
            # NOTE(review): the same file is overwritten for every doc —
            # looks like a debugging artifact; confirm before relying on it.
            d.plot('match.svg')
            # Aggregate per-match cosine scores by parent document id.
            match_similarity = defaultdict(float)
            for m in d.matches:
                match_similarity[m.parent_id] += m.scores['cosine'].value
            # TODO(review): sorted_similarities is computed but never stored
            # on the doc or returned — confirm whether the aggregated ranking
            # was meant to replace d.matches.
            sorted_similarities = sorted(
                match_similarity.items(), key=lambda v: v[1], reverse=True
            )
            # Remove embedding as it is not needed anymore in the response
            d.pop('embedding')
class TextEncoder(Executor):
    """Encodes document texts into embeddings with a sentence-transformers model."""

    def __init__(self, parameters: dict = None, *args, **kwargs):
        """Load the encoder model on CPU.

        :param parameters: optional config dict; only the key
            ``'traversal_paths'`` is read (defaults to ``'r'`` — root docs).
            Defaults to ``None`` instead of a dict literal to avoid the
            shared-mutable-default-argument pitfall; behavior is unchanged.
        """
        super().__init__(*args, **kwargs)
        self.model = SentenceTransformer(
            'all-MiniLM-L12-v2', device='cpu', cache_folder='.'
        )
        # Fresh dict per instance when the caller passes nothing.
        self.parameters = (
            parameters if parameters is not None else {'traversal_paths': 'r'}
        )

    @requests(on=['/search', '/embed'])
    def encode(self, docs: DocumentArray, **kwargs):
        """Wraps encoder from sentence-transformers package"""
        print("BEHOLD!!!! I AM EMBEDDING !!!!")
        traversal_paths = self.parameters.get('traversal_paths')
        target = docs.traverse_flat(traversal_paths)
        # inference_mode disables autograd bookkeeping — faster, less memory.
        with torch.inference_mode():
            embeddings = self.model.encode(target.texts)
            target.embeddings = embeddings