Skip to content

Commit

Permalink
Adding n_seeds and clusters members to clusterer (cms-patatrack#41)
Browse files Browse the repository at this point in the history
* Adding `n_seeds` and `clusters` members to clusterer

* Update version

* Add getters for the new attributes

* Update docstring

* Add check to `n_seeds` in equality operator
  • Loading branch information
sbaldu authored May 9, 2024
1 parent cbb194c commit bdecc95
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
38 changes: 35 additions & 3 deletions CLUEstering/CLUEstering.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ class cluster_properties:
----------
n_clusters : int
Number of clusters constructed.
n_seeds : int
Number of seeds found, i.e. the number of clusters excluding the group of outliers.
clusters : np.ndarray
Array containing the list of the clusters found.
cluster_ids : np.ndarray
Array containing the cluster_id of each point.
is_seed : np.ndarray
Expand All @@ -183,6 +187,8 @@ class cluster_properties:
"""

n_clusters : int
n_seeds : int
clusters : np.ndarray
cluster_ids : np.ndarray
is_seed : np.ndarray
cluster_points : np.ndarray
Expand All @@ -192,6 +198,8 @@ class cluster_properties:
def __eq__(self, other):
if self.n_clusters != other.n_clusters:
return False
if self.n_seeds != other.n_seeds:
return False
if not (self.cluster_ids == other.cluster_ids).all():
return False
if not (self.is_seed == other.is_seed).all():
Expand Down Expand Up @@ -667,13 +675,17 @@ def run_clue(self,
Modified attributes
-------------------
n_clusters : int
Number of clusters reconstructed.
n_seeds : int
Number of seeds found, i.e. the number of clusters excluding the group of outliers.
clusters : ndarray
Array containing the list of the clusters found.
cluster_ids : ndarray
Contains the cluster_id corresponding to every point.
is_seed : ndarray
For every point the value is 1 if the point is a seed or an
outlier and 0 if it isn't.
n_clusters : int
Number of clusters reconstructed.
cluster_points : ndarray of lists
Contains, for every cluster, the list of points associated with it.
points_per_cluster : ndarray
Expand Down Expand Up @@ -723,7 +735,9 @@ def run_clue(self,
finish = time.time_ns()
cluster_ids = np.array(cluster_id_is_seed[0])
is_seed = np.array(cluster_id_is_seed[1])
n_clusters = len(np.unique(cluster_ids))
clusters = np.unique(cluster_ids)
n_seeds = np.sum([1 for i in clusters if i > -1])
n_clusters = len(clusters)

cluster_points = [[] for _ in range(n_clusters)]
for i in range(self.clust_data.n_points):
Expand All @@ -735,6 +749,8 @@ def run_clue(self,
output_df = pd.DataFrame(data)

self.clust_prop = cluster_properties(n_clusters,
n_seeds,
clusters,
cluster_ids,
is_seed,
np.asarray(cluster_points, dtype=object),
Expand All @@ -755,6 +771,22 @@ def n_clusters(self) -> int:

return self.clust_prop.n_clusters

@property
def n_seeds(self) -> int:
    '''
    Number of seeds stored in the current clustering properties.

    Read-only view of `clust_prop.n_seeds`.
    '''

    properties = self.clust_prop
    return properties.n_seeds

@property
def clusters(self) -> np.ndarray:
    '''
    Clusters stored in the current clustering properties.

    Read-only view of `clust_prop.clusters`; the stored object is
    returned as-is, without copying.
    '''

    properties = self.clust_prop
    return properties.clusters

@property
def cluster_ids(self) -> np.ndarray:
'''
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from pathlib import Path
from setuptools import setup

__version__ = "2.2.1"
__version__ = "2.2.2"

this_directory = Path(__file__).parent
long_description = (this_directory/'README.md').read_text()

Expand Down

0 comments on commit bdecc95

Please sign in to comment.