From c1932a1dd6c140a8e5640cad9c9c5bd0c35eba31 Mon Sep 17 00:00:00 2001
From: IHappyPlant
Date: Sun, 30 Oct 2022 15:35:43 +0300
Subject: [PATCH] * code refactor

---
 sources/python/base.py               |  7 +++++
 sources/python/metric_classifiers.py | 47 +++++++++++++++-------------
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/sources/python/base.py b/sources/python/base.py
index baf8695..3db1bf5 100644
--- a/sources/python/base.py
+++ b/sources/python/base.py
@@ -72,3 +72,10 @@ def fit_predict(self, data, data_object):
         """
         self.fit(data)
         return self.predict(data_object)
+
+    @abc.abstractmethod
+    def predict_proba(self, data):
+        """
+        :type data: DataObject
+        """
+        pass
diff --git a/sources/python/metric_classifiers.py b/sources/python/metric_classifiers.py
index b4b5561..81f5685 100644
--- a/sources/python/metric_classifiers.py
+++ b/sources/python/metric_classifiers.py
@@ -33,9 +33,27 @@ def get_params(self):
             "dataset": self._dataset
         }
 
+    def _get_weights_by_class(self, classes_weights):
+        return {
+            cl: sum((cw["weight"] for cw in classes_weights
+                     if cw["class"] == cl))
+            for cl in set(cw["class"] for cw in classes_weights)
+        }
+
+    @staticmethod
+    def _get_classes_weights(objects, weights):
+        return [{"class": o.classcode, "weight": w}
+                for o, w in zip(objects, weights)]
+
     def fit(self, data, **kwargs):
         self._dataset = data
 
+    def predict(self, data):
+        weights_table = self.predict_proba(data)
+        max_weight = max(weights_table.values())
+        return max(weights_table, key=weights_table.get) if max_weight > 0 \
+            else None
+
 
 class KWNN(MetricClassifier):
     """
@@ -68,19 +86,13 @@ def _sort_objects_by_dist(self, objects, obj):
         """
         return sorted(objects, key=lambda x: self._metric(x, obj))
 
-    def predict(self, data):
+    def predict_proba(self, data):
         sorted_objects = self._sort_objects_by_dist(self._dataset, data)
         n_objects = sorted_objects[:self._n]
         weights = self._weights_calculator.get_weights(n_objects)
-        cls_weights = [{
-            "class": o.classcode,
-            "weight": w
-        } for o, w in zip(n_objects, weights)]
-        result_table = {
-            cl: sum((cw["weight"] for cw in cls_weights if cw["class"] == cl))
-            for cl in set((obj.classcode for obj in n_objects))
-        }
-        return max(result_table, key=result_table.get)
+        cls_weights = self._get_classes_weights(n_objects, weights)
+        result_table = self._get_weights_by_class(cls_weights)
+        return result_table
 
 
 class ParzenWindow(MetricClassifier):
@@ -100,19 +112,12 @@ def _get_weights(self, distances):
         return self._weights_calculator.get_weights(
             np.array(list(distances)) / self._h)
 
-    def predict(self, data):
+    def predict_proba(self, data):
         distances = (self._metric(data, other) for other in self._dataset)
         weights = self._get_weights(distances)
-        cls_weights = [{
-            "class": o.classcode,
-            "weight": w
-        } for o, w in zip(self._dataset, weights)]
-        res_table = {
-            cl: sum((cw["weight"] for cw in cls_weights if cw["class"] == cl))
-            for cl in set((obj.classcode for obj in self._dataset))
-        }
-        max_weight = max(res_table.values())
-        return max(res_table, key=res_table.get) if max_weight > 0 else None
+        cls_weights = self._get_classes_weights(self._dataset, weights)
+        result_table = self._get_weights_by_class(cls_weights)
+        return result_table
 
 
 class PotentialFunctions(ParzenWindow):