-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmulti_classifier.py
47 lines (33 loc) · 1.44 KB
/
multi_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from classifier import Classifier
def start_and_end_slot(sub_dir_name) -> (int, int):
start_slot = int(sub_dir_name.split("_")[1])
end_slot = int(sub_dir_name.split("_")[3])
return (start_slot, end_slot)
class MultiClassifier:
    """Dispatch block-reward classification to per-slot-range classifiers.

    Each sub-directory of ``data_dir`` holds the data for one ``Classifier``.
    The directory's name encodes the inclusive slot range that classifier
    covers; ``start_and_end_slot`` parses it.
    """

    def __init__(self, data_dir):
        """Load one ``Classifier`` per sub-directory of *data_dir*.

        Plain files in *data_dir* are skipped. Every sub-directory name must
        be parseable by ``start_and_end_slot``; otherwise its error propagates.
        """
        classifiers = []
        # Sort the listing so loading, log output and tie-breaking between
        # equal start slots do not depend on filesystem listing order.
        for sub_dir_name in sorted(os.listdir(data_dir)):
            sub_dir_path = os.path.join(data_dir, sub_dir_name)
            # os.listdir also returns plain files (e.g. a stray README).
            # They are not classifier data directories, and previously they
            # crashed loading in slot parsing or in Classifier(...).
            if not os.path.isdir(sub_dir_path):
                continue
            start_slot, end_slot = start_and_end_slot(sub_dir_name)
            print(f"loading classifier for range {start_slot}..={end_slot}")
            classifier = Classifier(sub_dir_path)
            classifiers.append((start_slot, end_slot, classifier))
        # List of classifiers, ordered by start slot:
        # [(start_slot, end_slot, classifier)]
        self.classifiers = sorted(classifiers, key=lambda x: x[0])

    def classify(self, block_reward):
        """Classify *block_reward* with the classifier covering its slot.

        The slot is ``int(block_reward["meta"]["slot"])``. Classifiers are
        tried in start-slot order; the first with
        ``start_slot <= slot <= end_slot`` is used. The last classifier
        (highest start slot) also accepts any slot beyond its ``end_slot``.

        Returns:
            Whatever the matching classifier's ``classify`` returns.

        Raises:
            Exception: if no classifier covers the slot, i.e. the slot lies
                before the first range or in a gap between ranges.
        """
        slot = int(block_reward["meta"]["slot"])
        last = len(self.classifiers) - 1
        for i, (start_slot, end_slot, classifier) in enumerate(self.classifiers):
            # Allow the last classifier to be used for slots beyond its end slot
            if start_slot <= slot and (slot <= end_slot or i == last):
                return classifier.classify(block_reward)
        raise Exception(f"no classifier known for slot {slot}")

    def scores(self):
        """Return ``(start_slot, end_slot, classifier.score)`` per classifier, in start-slot order."""
        return [
            (start, end, classifier.score)
            for (start, end, classifier) in self.classifiers
        ]