Skip to content

Commit

Permalink
feat: Add support for Timed Clusterization (#3654)
Browse files Browse the repository at this point in the history
This add possibility of doing clusterization using as well time information
CarloVarni authored Oct 4, 2024
1 parent 74244eb commit 6cea6b2
Showing 6 changed files with 401 additions and 84 deletions.
50 changes: 41 additions & 9 deletions Core/include/Acts/Clusterization/Clusterization.hpp
Original file line number Diff line number Diff line change
@@ -13,6 +13,26 @@

namespace Acts::Ccl {

template <typename Cell>
concept HasRetrievableColumnInfo = requires(Cell cell) {
{ getCellColumn(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableRowInfo = requires(Cell cell) {
{ getCellRow(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableLabelInfo = requires(Cell cell) {
{ getCellLabel(cell) } -> std::same_as<int&>;
};

template <typename Cell, typename Cluster>
concept CanAcceptCell = requires(Cell cell, Cluster cluster) {
{ clusterAddCell(cluster, cell) } -> std::same_as<void>;
};

using Label = int;
constexpr Label NO_LABEL = 0;

@@ -28,17 +48,21 @@ enum class ConnectResult {

// Default connection type for 2-D grids: 4- or 8-cell connectivity
template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Connect2D {
bool conn8;
Connect2D() : conn8{true} {}
bool conn8{true};
Connect2D() = default;
explicit Connect2D(bool commonCorner) : conn8{commonCorner} {}
ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ~Connect2D() = default;
};

// Default connection type for 1-D grids: 2-cell connectivity
template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Connect1D {
ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ~Connect1D() = default;
};

// Default connection type based on GridDim
@@ -49,13 +73,16 @@ struct DefaultConnect {
};

template <typename Cell>
struct DefaultConnect<Cell, 2> : public Connect2D<Cell> {
explicit DefaultConnect(bool commonCorner) : Connect2D<Cell>(commonCorner) {}
DefaultConnect() : DefaultConnect(true) {}
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {
~DefaultConnect() override = default;
};

template <typename Cell>
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};
struct DefaultConnect<Cell, 2> : public Connect2D<Cell> {
explicit DefaultConnect(bool commonCorner) : Connect2D<Cell>(commonCorner) {}
DefaultConnect() = default;
~DefaultConnect() override = default;
};

/// @brief labelClusters
///
@@ -70,6 +97,8 @@ struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};
template <typename CellCollection, std::size_t GridDim = 2,
typename Connect =
DefaultConnect<typename CellCollection::value_type, GridDim>>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect = Connect());

/// @brief mergeClusters
@@ -82,6 +111,9 @@ void labelClusters(CellCollection& cells, Connect connect = Connect());
/// @return nothing
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& /*cells*/);

/// @brief createClusters
100 changes: 25 additions & 75 deletions Core/include/Acts/Clusterization/Clusterization.ipp
Original file line number Diff line number Diff line change
@@ -14,86 +14,34 @@

namespace Acts::Ccl::internal {

// Machinery for validating generic Cell/Cluster types at compile-time

template <typename, std::size_t, typename T = void>
struct cellTypeHasRequiredFunctions : std::false_type {};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 2,
std::void_t<decltype(getCellRow(std::declval<T>())),
decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 1,
std::void_t<decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename, typename, typename T = void>
struct clusterTypeHasRequiredFunctions : std::false_type {};

template <typename T, typename U>
struct clusterTypeHasRequiredFunctions<
T, U,
std::void_t<decltype(clusterAddCell(std::declval<T>(), std::declval<U>()))>>
: std::true_type {};

template <std::size_t GridDim>
constexpr void staticCheckGridDim() {
static_assert(
GridDim == 1 || GridDim == 2,
"mergeClusters is only defined for grid dimensions of 1 or 2. ");
}

template <typename T, std::size_t GridDim>
constexpr void staticCheckCellType() {
constexpr bool hasFns = cellTypeHasRequiredFunctions<T, GridDim>();
static_assert(hasFns,
"Cell type should have the following functions: "
"'int getCellRow(const Cell&)', "
"'int getCellColumn(const Cell&)', "
"'Label& getCellLabel(Cell&)'");
}

template <typename T, typename U>
constexpr void staticCheckClusterType() {
constexpr bool hasFns = clusterTypeHasRequiredFunctions<T, U>();
static_assert(hasFns,
"Cluster type should have the following function: "
"'void clusterAddCell(Cluster&, const Cell&)'");
}

template <typename Cell, std::size_t GridDim>
struct Compare {
static_assert(GridDim != 1 && GridDim != 2,
"Only grid dimensions of 1 or 2 are supported");
};

// Comparator function object for cells, column-wise ordering
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 2> {
// Specialization for 1-D grids
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Compare<Cell, 1> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return (col0 == col1) ? row0 < row1 : col0 < col1;
return col0 < col1;
}
};

// Specialization for 1-D grids
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 1> {
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Compare<Cell, 2> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return col0 < col1;
return (col0 == col1) ? row0 < row1 : col0 < col1;
}
};

@@ -184,6 +132,10 @@ Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
}

template <typename CellCollection, typename ClusterCollection>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type> &&
Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
typename ClusterCollection::value_type>)
ClusterCollection mergeClustersImpl(CellCollection& cells) {
using Cluster = typename ClusterCollection::value_type;

@@ -215,6 +167,8 @@ ClusterCollection mergeClustersImpl(CellCollection& cells) {
namespace Acts::Ccl {

template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
@@ -237,7 +191,7 @@ ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
return ConnectResult::eNoConn;
}

template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
@@ -267,17 +221,19 @@ void recordEquivalences(const internal::Connections<GridDim> seen,
}

template <typename CellCollection, std::size_t GridDim, typename Connect>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();

internal::DisjointSets ds{};

// Sort cells by position to enable in-order scan
std::ranges::sort(cells, internal::Compare<Cell, GridDim>());

// First pass: Allocate labels and record equivalences
for (auto it = cells.begin(); it != cells.end(); ++it) {
for (auto it = std::ranges::begin(cells); it != std::ranges::end(cells);
++it) {
const internal::Connections<GridDim> seen =
internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
if (seen.nconn == 0) {
@@ -299,13 +255,11 @@ void labelClusters(CellCollection& cells, Connect connect) {

template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim = 2>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& cells) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckGridDim<GridDim>();
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();

if constexpr (GridDim > 1) {
// Sort the cells by their cluster label, only needed if more than
// one spatial dimension
@@ -318,10 +272,6 @@ ClusterCollection mergeClusters(CellCollection& cells) {
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim, typename Connect>
ClusterCollection createClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();
labelClusters<CellCollection, GridDim, Connect>(cells, connect);
return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
}
38 changes: 38 additions & 0 deletions Core/include/Acts/Clusterization/TimedClusterization.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Clusterization/Clusterization.hpp"
#include "Acts/Definitions/Algebra.hpp"

#include <limits>

namespace Acts::Ccl {

template <typename Cell>
concept HasRetrievableTimeInfo = requires(Cell cell) {
{ getCellTime(cell) } -> std::same_as<Acts::ActsScalar>;
};

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
struct TimedConnect : public Acts::Ccl::DefaultConnect<Cell, N> {
Acts::ActsScalar timeTolerance{std::numeric_limits<Acts::ActsScalar>::max()};

TimedConnect() = default;
TimedConnect(Acts::ActsScalar time);
TimedConnect(Acts::ActsScalar time, bool commonCorner)
requires(N == 2);
~TimedConnect() override = default;

ConnectResult operator()(const Cell& ref, const Cell& iter) const override;
};

} // namespace Acts::Ccl

#include "Acts/Clusterization/TimedClusterization.ipp"
36 changes: 36 additions & 0 deletions Core/include/Acts/Clusterization/TimedClusterization.ipp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

namespace Acts::Ccl {

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
TimedConnect<Cell, N>::TimedConnect(Acts::ActsScalar time)
: timeTolerance(time) {}

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
TimedConnect<Cell, N>::TimedConnect(Acts::ActsScalar time, bool commonCorner)
requires(N == 2)
: Acts::Ccl::DefaultConnect<Cell, N>(commonCorner), timeTolerance(time) {}

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
Acts::Ccl::ConnectResult TimedConnect<Cell, N>::operator()(
const Cell& ref, const Cell& iter) const {
Acts::Ccl::ConnectResult spaceCompatibility =
Acts::Ccl::DefaultConnect<Cell, N>::operator()(ref, iter);
if (spaceCompatibility != Acts::Ccl::ConnectResult::eConn) {
return spaceCompatibility;
}

if (std::abs(getCellTime(ref) - getCellTime(iter)) < timeTolerance) {
return Acts::Ccl::ConnectResult::eConn;
}

return Acts::Ccl::ConnectResult::eNoConn;
}

} // namespace Acts::Ccl
1 change: 1 addition & 0 deletions Tests/UnitTests/Core/Clusterization/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_unittest(Clusterization1D ClusterizationTests1D.cpp)
add_unittest(Clusterization2D ClusterizationTests2D.cpp)
add_unittest(TimedClusterization TimedClusterizationTests.cpp)
260 changes: 260 additions & 0 deletions Tests/UnitTests/Core/Clusterization/TimedClusterizationTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#include <boost/test/unit_test.hpp>

#include "Acts/Clusterization/TimedClusterization.hpp"

namespace Acts::Test {

// Define objects
using Identifier = std::size_t;
struct Cell {
Cell(Identifier identifier, int c, int r, double t)
: id(identifier), column(c), row(r), time(t) {}

Identifier id{};
int column{0};
int row{0};
int label{-1};
double time{0.};
};

struct Cluster {
std::vector<Identifier> ids{};
};

using CellCollection = std::vector<Acts::Test::Cell>;
using ClusterCollection = std::vector<Acts::Test::Cluster>;

// Define functions
static inline int getCellRow(const Cell& cell) {
return cell.row;
}

static inline int getCellColumn(const Cell& cell) {
return cell.column;
}

static inline int& getCellLabel(Cell& cell) {
return cell.label;
}

static inline double getCellTime(const Cell& cell) {
return cell.time;
}

static void clusterAddCell(Cluster& cl, const Cell& cell) {
cl.ids.push_back(cell.id);
}

BOOST_AUTO_TEST_CASE(TimedGrid_1D_withtime) {
// 1x10 matrix
/*
X X X Y O X Y Y X X
*/
// 6 + 3 cells -> 3 + 2 clusters in total

std::vector<Cell> cells;
// X
cells.emplace_back(0ul, 0, -1, 0);
cells.emplace_back(1ul, 1, -1, 0);
cells.emplace_back(2ul, 2, -1, 0);
cells.emplace_back(3ul, 5, -1, 0);
cells.emplace_back(4ul, 8, -1, 0);
cells.emplace_back(5ul, 9, -1, 0);
// Y
cells.emplace_back(6ul, 3, 0, 1);
cells.emplace_back(7ul, 6, 1, 1);
cells.emplace_back(8ul, 7, 1, 1);

std::vector<std::vector<Identifier>> expectedResults;
expectedResults.push_back({0ul, 1ul, 2ul});
expectedResults.push_back({6ul});
expectedResults.push_back({3ul});
expectedResults.push_back({7ul, 8ul});
expectedResults.push_back({4ul, 5ul});

ClusterCollection clusters =
Acts::Ccl::createClusters<CellCollection, ClusterCollection, 1>(
cells, Acts::Ccl::TimedConnect<Cell, 1>(0.5));

BOOST_CHECK_EQUAL(5ul, clusters.size());

for (std::size_t i(0); i < clusters.size(); ++i) {
std::vector<Identifier>& timedIds = clusters[i].ids;
const std::vector<Identifier>& expected = expectedResults[i];
std::sort(timedIds.begin(), timedIds.end());
BOOST_CHECK_EQUAL(timedIds.size(), expected.size());

for (std::size_t j(0); j < timedIds.size(); ++j) {
BOOST_CHECK_EQUAL(timedIds[j], expected[j]);
}
}
}

BOOST_AUTO_TEST_CASE(TimedGrid_2D_notime) {
// 4x4 matrix
/*
X O O X
O O O X
X X O O
X O O X
*/
// 7 cells -> 4 clusters in total

std::vector<Cell> cells;
cells.emplace_back(0ul, 0, 0, 0);
cells.emplace_back(1ul, 3, 0, 0);
cells.emplace_back(2ul, 3, 1, 0);
cells.emplace_back(3ul, 0, 2, 0);
cells.emplace_back(4ul, 1, 2, 0);
cells.emplace_back(5ul, 0, 3, 0);
cells.emplace_back(6ul, 3, 3, 0);

std::vector<std::vector<Identifier>> expectedResults;
expectedResults.push_back({0ul});
expectedResults.push_back({3ul, 4ul, 5ul});
expectedResults.push_back({1ul, 2ul});
expectedResults.push_back({6ul});

ClusterCollection clusters =
Acts::Ccl::createClusters<CellCollection, ClusterCollection, 2>(
cells,
Acts::Ccl::TimedConnect<Cell, 2>(std::numeric_limits<double>::max()));

BOOST_CHECK_EQUAL(4ul, clusters.size());

// Compare against default connect (only space)
ClusterCollection defaultClusters =
Acts::Ccl::createClusters<CellCollection, ClusterCollection, 2>(
cells, Acts::Ccl::DefaultConnect<Cell, 2>());

BOOST_CHECK_EQUAL(4ul, defaultClusters.size());
BOOST_CHECK_EQUAL(defaultClusters.size(), expectedResults.size());

std::vector<std::size_t> sizes{1, 3, 2, 1};
for (std::size_t i(0); i < clusters.size(); ++i) {
std::vector<Identifier>& timedIds = clusters[i].ids;
std::vector<Identifier>& defaultIds = defaultClusters[i].ids;
const std::vector<Identifier>& expected = expectedResults[i];
BOOST_CHECK_EQUAL(timedIds.size(), defaultIds.size());
BOOST_CHECK_EQUAL(timedIds.size(), sizes[i]);
BOOST_CHECK_EQUAL(timedIds.size(), expected.size());

std::sort(timedIds.begin(), timedIds.end());
std::sort(defaultIds.begin(), defaultIds.end());
for (std::size_t j(0); j < timedIds.size(); ++j) {
BOOST_CHECK_EQUAL(timedIds[j], defaultIds[j]);
BOOST_CHECK_EQUAL(timedIds[j], expected[j]);
}
}
}

BOOST_AUTO_TEST_CASE(TimedGrid_2D_withtime) {
// 4x4 matrix
/*
X Y O X
O Y Y X
X X Z Z
X O O X
*/
// 7 + 3 + 2 cells -> 4 + 1 + 1 clusters in total

std::vector<Cell> cells;
// X
cells.emplace_back(0ul, 0, 0, 0);
cells.emplace_back(1ul, 3, 0, 0);
cells.emplace_back(2ul, 3, 1, 0);
cells.emplace_back(3ul, 0, 2, 0);
cells.emplace_back(4ul, 1, 2, 0);
cells.emplace_back(5ul, 0, 3, 0);
cells.emplace_back(6ul, 3, 3, 0);
// Y
cells.emplace_back(7ul, 1, 0, 1);
cells.emplace_back(8ul, 1, 1, 1);
cells.emplace_back(9ul, 2, 1, 1);
// Z
cells.emplace_back(10ul, 2, 2, 2);
cells.emplace_back(11ul, 3, 2, 2);

std::vector<std::vector<Identifier>> expectedResults;
expectedResults.push_back({0ul});
expectedResults.push_back({3ul, 4ul, 5ul});
expectedResults.push_back({7ul, 8ul, 9ul});
expectedResults.push_back({10ul, 11ul});
expectedResults.push_back({1ul, 2ul});
expectedResults.push_back({6ul});

ClusterCollection clusters =
Acts::Ccl::createClusters<CellCollection, ClusterCollection, 2>(
cells, Acts::Ccl::TimedConnect<Cell, 2>(0.5));

BOOST_CHECK_EQUAL(6ul, clusters.size());

std::vector<std::size_t> sizes{1, 3, 3, 2, 2, 1};
for (std::size_t i(0); i < clusters.size(); ++i) {
std::vector<Identifier>& timedIds = clusters[i].ids;
BOOST_CHECK_EQUAL(timedIds.size(), sizes[i]);
std::sort(timedIds.begin(), timedIds.end());

const std::vector<Identifier>& expected = expectedResults[i];
BOOST_CHECK_EQUAL(timedIds.size(), expected.size());

for (std::size_t j(0); j < timedIds.size(); ++j) {
BOOST_CHECK_EQUAL(timedIds[j], expected[j]);
}
}
}

BOOST_AUTO_TEST_CASE(TimedGrid_2D_noTollerance) {
// 4x4 matrix
/*
X O O X
O O O X
X X O O
X O O X
*/
// 7 cells -> 7 clusters in total
// since time requirement will never be satisfied

std::vector<Cell> cells;
cells.emplace_back(0ul, 0, 0, 0);
cells.emplace_back(1ul, 3, 0, 0);
cells.emplace_back(2ul, 3, 1, 0);
cells.emplace_back(3ul, 0, 2, 0);
cells.emplace_back(4ul, 1, 2, 0);
cells.emplace_back(5ul, 0, 3, 0);
cells.emplace_back(6ul, 3, 3, 0);

std::vector<std::vector<Identifier>> expectedResults;
expectedResults.push_back({0ul});
expectedResults.push_back({3ul});
expectedResults.push_back({5ul});
expectedResults.push_back({4ul});
expectedResults.push_back({1ul});
expectedResults.push_back({2ul});
expectedResults.push_back({6ul});

ClusterCollection clusters =
Acts::Ccl::createClusters<CellCollection, ClusterCollection, 2>(
cells, Acts::Ccl::TimedConnect<Cell, 2>(0.));

BOOST_CHECK_EQUAL(7ul, clusters.size());

for (std::size_t i(0); i < clusters.size(); ++i) {
std::vector<Identifier>& timedIds = clusters[i].ids;
const std::vector<Identifier>& expected = expectedResults[i];

BOOST_CHECK_EQUAL(timedIds.size(), 1);
BOOST_CHECK_EQUAL(timedIds.size(), expected.size());
BOOST_CHECK_EQUAL(timedIds[0], expected[0]);
}
}

} // namespace Acts::Test

0 comments on commit 6cea6b2

Please sign in to comment.