Skip to content

Commit

Permalink
Plotting bug fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
cbbcbail committed Aug 2, 2024
1 parent 1b1cb0c commit e36de44
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
30 changes: 20 additions & 10 deletions flexibleSubsetSelection/objective.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# --- Imports ------------------------------------------------------------------

# Standard library
from collections import Counter
from typing import Callable

# Third party
import ot
from ott.geometry import costs, pointcloud

import numpy as np
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -87,8 +89,8 @@ def earthMoversDistance(subset: np.ndarray, dataset: np.ndarray) -> float:
"""
return ot.emd2([], [], ot.dist(subset, dataset))

def sinkhornDistance(distances: np.ndarray, reg: float = 0.1,
verbose: bool = False) -> float:
def sinkhornDistance(distances: np.ndarray, datasetLength, subsetLength,
reg: float = 0.1, verbose: bool = False) -> float:
"""
Computes the Sinkhorn distance using the POT library.
Expand All @@ -100,13 +102,10 @@ def sinkhornDistance(distances: np.ndarray, reg: float = 0.1,
Returns:
float: Sinkhorn distance.
"""
n, m = subset.shape[0], dataset.shape[0]
distanceMatrix = compute_distance_matrix(subset, dataset)
print("Computed Distance Matrix")

return ot.sinkhorn2(np.ones(n) / n,
np.ones(m) / m,
distanceMatrix,
print(distances.shape)
return ot.sinkhorn2(np.ones(datasetLength) / datasetLength,
np.ones(subsetLength) / subsetLength,
distances,
reg,
stopThr=1e-05,
verbose=verbose)
Expand Down Expand Up @@ -153,4 +152,15 @@ def emdCategorical(subset, dataset, features, categorical, categories):
dataset_data = dataset.loc[dataset[categorical] == category, features].values
emd_loss = ot.emd2([], [], ot.dist(subset_data, dataset_data))
emd_losses.append(emd_loss)
return emd_losses
return emd_losses

def entropy(array: np.ndarray) -> float:
counts = Counter(map(tuple, array))
total = sum(counts.values())
probabilities = np.array(list(counts.values()))/total
return np.sum(probabilities * np.log(probabilities))

def sinkhorn(subset, fullData, solveFunction):
geometry = pointcloud.PointCloud(fullData, subset)
sinkhornOutput = solveFunction(geometry)
return sinkhornOutput.reg_ot_cost
25 changes: 14 additions & 11 deletions flexibleSubsetSelection/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,21 @@ def scatter(ax: Axes, color: Color, dataset: (Dataset | None) = None,
if not hasattr(ax, "zaxis"):
raise ValueError("3D data is specified but axis is not 3D.")
initializePane3D(ax, color["grey"])
data = []
colors = []
if dataset is not None:
ax.scatter(dataset.data[features[0]],
dataset.data[features[1]],
dataset.data[features[2]],
color = color["green"],
**parameters)
if subset is not None:
ax.scatter(subset.data[features[0]],
subset.data[features[1]],
subset.data[features[2]],
color = color["darkGreen"],
**parameters)
data.append(dataset.data)
colors.extend([color["green"]] * len(dataset.data))
if subset is not None:
data.append(subset.data)
colors.extend([color["darkGreen"]] * len(subset.data))

data = np.concatenate(data, axis=0)
ax.scatter(data[:, features[0]],
data[:, features[1]],
data[:, features[2]],
c=colors,
**parameters)
else:
if dataset is not None:
sns.scatterplot(data = dataset.data,
Expand Down

0 comments on commit e36de44

Please sign in to comment.