Skip to content

Commit

Permalink
Simplify connected components impl in distributed
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Aug 13, 2019
1 parent ef09a10 commit 296e756
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 104 deletions.
123 changes: 25 additions & 98 deletions include/nifty/distributed/graph_tools.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -408,61 +408,23 @@ namespace distributed {
sets.make_set(node_id);
}

std::vector<NodeType> nodes;
graph.nodes(nodes);

// First pass:
// iterate over each node and create new label at node
// or assign representative of the neighbor node
NodeType currentLabel = 0;
for(const NodeType u : nodes){

if(ignoreLabel && (u == 0)) {
const auto & edges = graph.edges();
for(const auto & edge: edges) {
const uint64_t u = edge.first;
const uint64_t v = edge.second;
const uint64_t lU = labels(u);
const uint64_t lV = labels(v);
if(ignoreLabel && (lU == 0 || lV == 0)) {
continue;
}

// iterate over the nodes in the neighborhood
// and collect the nodes that are connected
const auto & nhood = graph.nodeAdjacency(u);
std::set<NodeType> ngbLabels;
const auto lU = labels(u);

for(auto nhIt = nhood.begin(); nhIt != nhood.end(); ++nhIt) {
const NodeType v = nhIt->first;
const auto lV = labels(v);

// nodes are connected if the edge has the value 0
// this is in accordance with cut edges being 1
if(lU == lV) {
ngbLabels.insert(v);
}
}

// check if we are connected to any of the neighbors
// and if the neighbor labels need to be merged
if(ngbLabels.size() == 0) {
// no connection -> make new label @ current node
out(u) = ++currentLabel;
} else if (ngbLabels.size() == 1) {
// only single label -> we assign its representative to the current node
const uint64_t ngb = *ngbLabels.begin();
out(u) = sets.find_set(ngb);
sets.link(u, ngb);
} else {
// multiple labels -> we merge them and assign representative to the current node
std::vector<NodeType> tmp_labels(ngbLabels.begin(), ngbLabels.end());
for(unsigned ii = 1; ii < tmp_labels.size(); ++ii) {
sets.link(tmp_labels[ii - 1], tmp_labels[ii]);
}
sets.link(u, tmp_labels[0]);
out(u) = sets.find_set(tmp_labels[0]);
if(lU == lV) {
sets.link(u, v);
}
}

// Second pass:
// Assign representative to each pixel
for(const NodeType u : nodes){
out(u) = sets.find_set(out(u));
// assign representative to each pixel
for(std::size_t u = 0; u < out.size(); ++u){
out(u) = sets.find_set(u);
}
}

Expand All @@ -471,9 +433,8 @@ namespace distributed {
template<class EDGES, class NODES>
void connectedComponents(const Graph & graph,
const xt::xexpression<EDGES> & edges_exp,
const bool ignoreLabel,
xt::xexpression<NODES> & labels_exp) {
const auto & edges = edges_exp.derived_cast();
const auto & edgeLabels = edges_exp.derived_cast();
auto & labels = labels_exp.derived_cast();

std::vector<NodeType> nodes;
Expand All @@ -490,56 +451,22 @@ namespace distributed {
sets.make_set(node_id);
}

// First pass:
// iterate over each node and create new label at node
// or assign representative of the neighbor node
NodeType currentLabel = 0;
for(const NodeType node : nodes){

if(ignoreLabel && (node == 0)) {
continue;
}
const auto & edges = graph.edges();
for(std::size_t edge_id = 0; edge_id < edges.size(); ++edge_id) {
const auto & edge = edges[edge_id];
const uint64_t u = edge.first;
const uint64_t v = edge.second;

// iterate over the nodes in the neighborhood
// and collect the nodes that are connected
const auto & nhood = graph.nodeAdjacency(node);
std::set<NodeType> ngbLabels;
for(auto nhIt = nhood.begin(); nhIt != nhood.end(); ++nhIt) {
const NodeType nhNode = nhIt->first;
const EdgeIndexType nhEdge = nhIt->second;

// nodes are connected if the edge has the value 0
// this is in accordance with cut edges being 1
if(!edges(nhEdge)) {
ngbLabels.insert(nhNode);
}
}

// check if we are connected to any of the neighbors
// and if the neighbor labels need to be merged
if(ngbLabels.size() == 0) {
// no connection -> make new label @ current pixel
labels(node) = ++currentLabel;
} else if (ngbLabels.size() == 1) {
// only single label -> we assign its representative to the current pixel
const uint64_t ngb = *ngbLabels.begin();
sets.link(node, ngb);
labels(node) = sets.find_set(ngb);
} else {
// multiple labels -> we merge them and assign representative to the current pixel
std::vector<NodeType> tmp_labels(ngbLabels.begin(), ngbLabels.end());
for(unsigned ii = 1; ii < tmp_labels.size(); ++ii) {
sets.link(tmp_labels[ii - 1], tmp_labels[ii]);
}
sets.link(node, tmp_labels[0]);
labels(node) = sets.find_set(tmp_labels[0]);
// nodes are connected if the edge has the value 0
// this is in accordance with cut edges being 1
if(!edgeLabels(edge_id)) {
sets.link(u, v);
}
}

// Second pass:
// Assign representative to each pixel
for(const NodeType node : nodes){
labels(node) = sets.find_set(labels(node));
// assign representative to each pixel
for(std::size_t u = 0; u < labels.size(); ++u){
labels(u) = sets.find_set(u);
}
}

Expand Down
4 changes: 2 additions & 2 deletions include/nifty/graph/components.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public:
ufd_(graph.nodeIdUpperBound()+1),
offset_(ufd_.numberOfElements() - graph_.numberOfNodes()),
needsReset_(false){

}

uint64_t build(){
Expand Down Expand Up @@ -54,7 +54,7 @@ public:
if(edgeLabels[edge] == 0){
const auto u = graph_.u(edge);
const auto v = graph_.v(edge);
ufd_.merge(u,v);
ufd_.merge(u,v);
}
}
needsReset_ = true;
Expand Down
7 changes: 3 additions & 4 deletions src/python/lib/distributed/graph_extraction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,14 @@ namespace distributed {


module.def("connectedComponents", [](const Graph & graph,
const xt::pytensor<bool, 1> & edgeLabels,
const bool ignoreLabel){
const xt::pytensor<bool, 1> & edgeLabels){
xt::pytensor<NodeType, 1> labels = xt::zeros<NodeType>({graph.maxNodeId() + 1});
{
py::gil_scoped_release allowThreads;
connectedComponents(graph, edgeLabels, ignoreLabel, labels);
connectedComponents(graph, edgeLabels, labels);
}
return labels;
}, py::arg("graph"), py::arg("edgeLabels"), py::arg("ignoreLabel"));
}, py::arg("graph"), py::arg("edgeLabels"));


module.def("connectedComponentsFromNodes", [](const Graph & graph,
Expand Down

0 comments on commit 296e756

Please sign in to comment.