diff --git a/flexibleSubsetSelection/objective.py b/flexibleSubsetSelection/objective.py index ac027b5..e2b70a8 100644 --- a/flexibleSubsetSelection/objective.py +++ b/flexibleSubsetSelection/objective.py @@ -1,10 +1,12 @@ # --- Imports ------------------------------------------------------------------ # Standard library +from collections import Counter from typing import Callable # Third party import ot +from ott.geometry import costs, pointcloud import numpy as np from numpy.typing import ArrayLike @@ -87,8 +89,8 @@ def earthMoversDistance(subset: np.ndarray, dataset: np.ndarray) -> float: """ return ot.emd2([], [], ot.dist(subset, dataset)) -def sinkhornDistance(distances: np.ndarray, reg: float = 0.1, - verbose: bool = False) -> float: +def sinkhornDistance(distances: np.ndarray, datasetLength, subsetLength, + reg: float = 0.1, verbose: bool = False) -> float: """ Computes the Sinkhorn distance using the POT library. @@ -100,13 +102,10 @@ def sinkhornDistance(distances: np.ndarray, reg: float = 0.1, Returns: float: Sinkhorn distance. """ - n, m = subset.shape[0], dataset.shape[0] - distanceMatrix = compute_distance_matrix(subset, dataset) - print("Computed Distance Matrix") - - return ot.sinkhorn2(np.ones(n) / n, - np.ones(m) / m, - distanceMatrix, + print(distances.shape) + return ot.sinkhorn2(np.ones(datasetLength) / datasetLength, + np.ones(subsetLength) / subsetLength, + distances, reg, stopThr=1e-05, verbose=verbose) @@ -153,4 +152,15 @@ def emdCategorical(subset, dataset, features, categorical, categories): dataset_data = dataset.loc[dataset[categorical] == category, features].values emd_loss = ot.emd2([], [], ot.dist(subset_data, dataset_data)) emd_losses.append(emd_loss) - return emd_losses \ No newline at end of file + return emd_losses + +def entropy(array: np.ndarray) -> float: + counts = Counter(map(tuple, array)) + total = sum(counts.values()) + probabilities = np.array(list(counts.values()))/total + return np.sum(probabilities * np.log(probabilities)) + +def sinkhorn(subset, fullData, solveFunction): + geometry = pointcloud.PointCloud(fullData, subset) + sinkhornOutput = solveFunction(geometry) + return sinkhornOutput.reg_ot_cost \ No newline at end of file diff --git a/flexibleSubsetSelection/plot.py b/flexibleSubsetSelection/plot.py index 3741718..5833409 100644 --- a/flexibleSubsetSelection/plot.py +++ b/flexibleSubsetSelection/plot.py @@ -219,18 +219,21 @@ def scatter(ax: Axes, color: Color, dataset: (Dataset | None) = None, if not hasattr(ax, "zaxis"): raise ValueError("3D data is specified but axis is not 3D.") initializePane3D(ax, color["grey"]) + data = [] + colors = [] if dataset is not None: - ax.scatter(dataset.data[features[0]], - dataset.data[features[1]], - dataset.data[features[2]], - color = color["green"], - **parameters) - if subset is not None: - ax.scatter(subset.data[features[0]], - subset.data[features[1]], - subset.data[features[2]], - color = color["darkGreen"], - **parameters) + data.append(dataset.data) + colors.extend([color["green"]] * len(dataset.data)) + if subset is not None: + data.append(subset.data) + colors.extend([color["darkGreen"]] * len(subset.data)) + + data = np.concatenate(data, axis=0) + ax.scatter(data[:, features[0]], + data[:, features[1]], + data[:, features[2]], + c=colors, + **parameters) else: if dataset is not None: sns.scatterplot(data = dataset.data,