Skip to content

Commit

Permalink
* code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
IHappyPlant committed Oct 30, 2022
1 parent cb1f12e commit c1932a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
7 changes: 7 additions & 0 deletions sources/python/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,10 @@ def fit_predict(self, data, data_object):
"""
self.fit(data)
return self.predict(data_object)

@abc.abstractmethod
def predict_proba(self, data):
"""
:type data: DataObject
"""
pass
47 changes: 26 additions & 21 deletions sources/python/metric_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,27 @@ def get_params(self):
"dataset": self._dataset
}

def _get_weights_by_class(self, classes_weights):
return {
cl: sum((cw["weight"] for cw in classes_weights
if cw["class"] == cl))
for cl in set(cw["class"] for cw in classes_weights)
}

@staticmethod
def _get_classes_weights(objects, weights):
return [{"class": o.classcode, "weight": w}
for o, w in zip(objects, weights)]

def fit(self, data, **kwargs):
self._dataset = data

def predict(self, data):
weights_table = self.predict_proba(data)
max_weight = max(weights_table.values())
return max(weights_table, key=weights_table.get) if max_weight > 0 \
else None


class KWNN(MetricClassifier):
"""
Expand Down Expand Up @@ -68,19 +86,13 @@ def _sort_objects_by_dist(self, objects, obj):
"""
return sorted(objects, key=lambda x: self._metric(x, obj))

def predict(self, data):
def predict_proba(self, data):
sorted_objects = self._sort_objects_by_dist(self._dataset, data)
n_objects = sorted_objects[:self._n]
weights = self._weights_calculator.get_weights(n_objects)
cls_weights = [{
"class": o.classcode,
"weight": w
} for o, w in zip(n_objects, weights)]
result_table = {
cl: sum((cw["weight"] for cw in cls_weights if cw["class"] == cl))
for cl in set((obj.classcode for obj in n_objects))
}
return max(result_table, key=result_table.get)
cls_weights = self._get_classes_weights(n_objects, weights)
result_table = self._get_weights_by_class(cls_weights)
return result_table


class ParzenWindow(MetricClassifier):
Expand All @@ -100,19 +112,12 @@ def _get_weights(self, distances):
return self._weights_calculator.get_weights(
np.array(list(distances)) / self._h)

def predict(self, data):
def predict_proba(self, data):
distances = (self._metric(data, other) for other in self._dataset)
weights = self._get_weights(distances)
cls_weights = [{
"class": o.classcode,
"weight": w
} for o, w in zip(self._dataset, weights)]
res_table = {
cl: sum((cw["weight"] for cw in cls_weights if cw["class"] == cl))
for cl in set((obj.classcode for obj in self._dataset))
}
max_weight = max(res_table.values())
return max(res_table, key=res_table.get) if max_weight > 0 else None
cls_weights = self._get_classes_weights(self._dataset, weights)
result_table = self._get_weights_by_class(cls_weights)
return result_table


class PotentialFunctions(ParzenWindow):
Expand Down

0 comments on commit c1932a1

Please sign in to comment.