diff --git a/src/agg/mean_aggl.cpp b/src/agg/mean_aggl.cpp index 89a1a42..7698061 100644 --- a/src/agg/mean_aggl.cpp +++ b/src/agg/mean_aggl.cpp @@ -127,7 +127,7 @@ struct agglomeration_size_heuristic_t struct agglomeration_semantic_heuristic_t { - aff_t aff_threshold = 0.5; + aff_t aff_threshold = 1.0; size_t total_signal_threshold = 100'000; double dominant_signal_ratio = 0.6; }; @@ -269,7 +269,12 @@ std::vector load_sem(const char * sem_filename, const std::vector()); + if (sem_counts[k][0] > 0 and v[0] > 0 and v[0] != sem_counts[k][0]) { + sem_counts[k][1] = 1; + } + if (v[0] > 0) { + sem_counts[k] = v; + } } return sem_counts; } @@ -624,20 +629,14 @@ std::pair sem_label(const sem_array_t & labels) bool sem_can_merge(const sem_array_t & labels1, const sem_array_t & labels2, const agglomeration_semantic_heuristic_t & sem_params) { - auto max_label1 = std::distance(labels1.begin(), std::max_element(labels1.begin(), labels1.end())); - auto max_label2 = std::distance(labels2.begin(), std::max_element(labels2.begin(), labels2.end())); - auto total_label1 = std::accumulate(labels1.begin(), labels1.end(), static_cast(0)); - auto total_label2 = std::accumulate(labels2.begin(), labels2.end(), static_cast(0)); - if (labels1[max_label1] < sem_params.dominant_signal_ratio * total_label1 || total_label1 < sem_params.total_signal_threshold) { //unsure about the semantic label - return true; + if (labels1[1] > 0 or labels2[1] > 0) { + return false; } - if (labels2[max_label2] < sem_params.dominant_signal_ratio * total_label2 || total_label2 < sem_params.total_signal_threshold) { //unsure about the semantic label + if (labels1[0] == 0 or labels2[0] == 0 or labels1[0] == labels2[0]) { return true; + } else { + return false; } - if (max_label1 == max_label2) { - return true; - } - return false; } template , class Plus = std::plus, @@ -752,9 +751,11 @@ inline agglomeration_output_t agglomerate_cc(agglomeration_data_t std::swap(seg_size[v0], seg_size[s]); if (!sem_counts.empty()) { - std::transform(sem_counts[v0].begin(), sem_counts[v0].end(), sem_counts[v1].begin(), sem_counts[v0].begin(), std::plus()); - sem_counts[v1] = sem_array_t(); - std::swap(sem_counts[v0], sem_counts[s]); + if (sem_counts[v0][0] > 0) { + std::swap(sem_counts[v0], sem_counts[s]); + } else { + std::swap(sem_counts[v1], sem_counts[s]); + } } output.merged_rg_vector.push_back(*(e.edge)); diff --git a/src/seg/SemExtractor.hpp b/src/seg/SemExtractor.hpp index 1e4ef14..7a519f7 100644 --- a/src/seg/SemExtractor.hpp +++ b/src/seg/SemExtractor.hpp @@ -6,7 +6,7 @@ #include #include -using sem_array_t = std::array; +using sem_array_t = std::array; template class SemExtractor @@ -17,8 +17,14 @@ class SemExtractor void collectVoxel(Coord c, Tseg segid) { auto sem_label = m_sem[c[0]][c[1]][c[2]]; - if (sem_label >= 0 and sem_map[sem_label] >= 0) { - m_labels[segid][sem_map[sem_label]] += 1; + if (m_labels[segid][1] > 0 or sem_label == 0) { + return; + } + + if (m_labels[segid][0] == 0) { + m_labels[segid][0] = sem_label; + } else if (m_labels[segid][0] != sem_label){ + m_labels[segid][1] = 1; } } @@ -38,10 +44,10 @@ class SemExtractor if (chunkMap.count(k) > 0) { svid = chunkMap.at(k); } - if (remapped_labels.count(svid) == 0) { + if (remapped_labels.count(svid) == 0 or remapped_labels[svid][0] == 0) { remapped_labels[svid] = v; - } else { - std::transform(remapped_labels[svid].begin(), remapped_labels[svid].end(), v.begin(), remapped_labels[svid].begin(), std::plus()); + } else if (v[0] > 0 and remapped_labels[svid][0] != v[0]) { + remapped_labels[svid][1] = 1; } } for (const auto & [k,v] : remapped_labels) { diff --git a/src/seg/Types.h b/src/seg/Types.h index 54ae489..5082dfe 100644 --- a/src/seg/Types.h +++ b/src/seg/Types.h @@ -22,7 +22,7 @@ using ContactRegionExt = MapContainer >; template using Edge = std::array >, 3>; -using semantic_t = uint8_t; +using semantic_t = uint64_t; template struct __attribute__((packed)) matching_entry_t