Skip to content

Commit

Permalink
better rgfa support in call
Browse files Browse the repository at this point in the history
  • Loading branch information
glennhickey committed Oct 19, 2023
1 parent f29425d commit ea5d6e5
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 28 deletions.
124 changes: 101 additions & 23 deletions src/graph_caller.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "graph_caller.hpp"
#include "algorithms/expand_context.hpp"
#include "annotation.hpp"
#include "rgfa.hpp"

//#define debug

Expand All @@ -18,42 +19,52 @@ void GraphCaller::call_top_level_snarls(const HandleGraph& graph, RecurseType re

// Used to recurse on children of parents that can't be called
size_t thread_count = get_thread_count();
vector<vector<const Snarl*>> snarl_queue(thread_count);
vector<vector<pair<const Snarl*, int>>> snarl_queue(thread_count);

// Run the snarl caller on a snarl, and queue up the children if it fails
auto process_snarl = [&](const Snarl* snarl) {
auto process_snarl = [&](const Snarl* snarl, int ploidy_override) {

if (!snarl_manager.is_trivial(snarl, graph)) {

#ifdef debug
cerr << "GraphCaller running call_snarl on " << pb2json(*snarl) << endl;
#endif

bool was_called = call_snarl(*snarl);
if (recurse_type == RecurseAlways || (!was_called && recurse_type == RecurseOnFail)) {
const vector<const Snarl*>& children = snarl_manager.children_of(snarl);
vector<const Snarl*>& thread_queue = snarl_queue[omp_get_thread_num()];
thread_queue.insert(thread_queue.end(), children.begin(), children.end());
const vector<const Snarl*>& children = snarl_manager.children_of(snarl);
vector<int> child_ploidies(children.size(), -1);
bool was_called = call_snarl(*snarl, ploidy_override, &child_ploidies);
if (recurse_type == RecurseAlways || (!was_called && recurse_type == RecurseOnFail)) {
vector<pair<const Snarl*, int>>& thread_queue = snarl_queue[omp_get_thread_num()];
for (int64_t i = 0; i < children.size(); ++i) {
thread_queue.push_back(make_pair(children[i], child_ploidies[i]));
}
}
}
};

// Start with the top level snarls
snarl_manager.for_each_top_level_snarl_parallel(process_snarl);
// Queue them up since process_snarl is no longer a valid callback for the iterator snarl_manager.for_each_top_level_snarl()
vector<const Snarl*> top_level_snarls;
snarl_manager.for_each_top_level_snarl([&](const Snarl* snarl) {
top_level_snarls.push_back(snarl);
});
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < top_level_snarls.size(); ++i) {
process_snarl(top_level_snarls[i], -1);
}

// Then recurse on any children the snarl caller failed to handle
while (!std::all_of(snarl_queue.begin(), snarl_queue.end(),
[](const vector<const Snarl*>& snarl_vec) {return snarl_vec.empty();})) {
vector<const Snarl*> cur_queue;
for (vector<const Snarl*>& thread_queue : snarl_queue) {
[](const vector<pair<const Snarl*, int>>& snarl_vec) {return snarl_vec.empty();})) {
vector<pair<const Snarl*, int>> cur_queue;
for (vector<pair<const Snarl*, int>>& thread_queue : snarl_queue) {
cur_queue.reserve(cur_queue.size() + thread_queue.size());
std::move(thread_queue.begin(), thread_queue.end(), std::back_inserter(cur_queue));
thread_queue.clear();
}

#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < cur_queue.size(); ++i) {
process_snarl(cur_queue[i]);
process_snarl(cur_queue[i].first, cur_queue[i].second);
}
}

Expand Down Expand Up @@ -175,8 +186,48 @@ vector<Chain> GraphCaller::break_chain(const HandleGraph& graph, const Chain& ch

return chain_frags;
}


void GraphCaller::resolve_child_ploidies(const HandleGraph& graph, const Snarl* snarl, const vector<SnarlTraversal>& travs,
const vector<int>& genotype, vector<int>& child_ploidies) {
const vector<const Snarl*>& children = snarl_manager.children_of(snarl);

child_ploidies.resize(children.size());

VCFOutputCaller::VCFOutputCaller(const string& sample_name) : sample_name(sample_name), translation(nullptr), include_nested(false)
if (!children.empty()) {
// index the nodes in the traversals
vector<unordered_set<handle_t>> trav_indexes(genotype.size());
for (int64_t i = 0; i < genotype.size(); ++i) {
const SnarlTraversal& trav = travs[genotype[i]];
unordered_set<handle_t>& trav_idx = trav_indexes[i];
if (i > 0 && genotype[i] == genotype[i-1]) {
trav_idx = trav_indexes[i-1];
} else {
for (int j = 0; j < trav.visit_size(); ++j) {
trav_idx.insert(graph.get_handle(trav.visit(j).node_id(), trav.visit(j).backward()));
}
}
}

// for every child snarl, count the number of traversals that spans it -- this will be its ploidy
for (int64_t i = 0; i < children.size(); ++i) {
const Snarl* child_snarl = children[i];
child_ploidies[i] = 0;
handle_t child_start = graph.get_handle(child_snarl->start().node_id(), child_snarl->start().backward());
handle_t child_end = graph.get_handle(child_snarl->end().node_id(), child_snarl->end().backward());
for (int64_t j = 0; j < trav_indexes.size(); ++j) {
unordered_set<handle_t>& trav_idx = trav_indexes[j];
if ((trav_idx.count(child_start) && trav_idx.count(child_end)) ||
(trav_idx.count(graph.flip(child_start)) && trav_idx.count(graph.flip(child_end)))) {
++child_ploidies[i];
}
}
}
}
}


VCFOutputCaller::VCFOutputCaller(const string& sample_name) : sample_name(sample_name), translation(nullptr), include_nested(false), write_full_names(false)
{
output_variants.resize(get_thread_count());
}
Expand Down Expand Up @@ -491,11 +542,11 @@ void VCFOutputCaller::emit_variant(const PathPositionHandleGraph& graph, SnarlCa

// resolve subpath naming
subrange_t subrange;
string basepath_name = Paths::strip_subrange(ref_path_name, &subrange);
string basepath_name = Paths::strip_subrange(RGFACover::revert_rgfa_path_name(ref_path_name), &subrange);
size_t basepath_offset = subrange == PathMetadata::NO_SUBRANGE ? 0 : subrange.first;
// in VCF we usually just want a contig
string contig_name = PathMetadata::parse_locus_name(basepath_name);
if (contig_name != PathMetadata::NO_LOCUS_NAME) {
if (contig_name != PathMetadata::NO_LOCUS_NAME && !this->write_full_names) {
basepath_name = contig_name;
}
// fill out the rest of the variant
Expand Down Expand Up @@ -766,6 +817,20 @@ void VCFOutputCaller::scan_snarl(const string& allele_string, function<void(cons
}
}

void VCFOutputCaller::toggle_full_names_from_paths(const vector<string>& ref_paths) {
set<string> samples;
set<int> haplotypes;
for (const string& path_name : ref_paths) {
samples.insert(PathMetadata::parse_sample_name(path_name));
haplotypes.insert(PathMetadata::parse_haplotype(path_name));
if (samples.size() > 1 || haplotypes.size() > 1) {
this->write_full_names = true;
return;
}
}
this->write_full_names = false;
}

GAFOutputCaller::GAFOutputCaller(AlignmentEmitter* emitter, const string& sample_name, const vector<string>& ref_paths,
size_t trav_padding) :
emitter(emitter),
Expand Down Expand Up @@ -1054,7 +1119,7 @@ VCFGenotyper::~VCFGenotyper() {

}

bool VCFGenotyper::call_snarl(const Snarl& snarl) {
bool VCFGenotyper::call_snarl(const Snarl& snarl, int ploidy_override, vector<int>* out_child_ploidies) {

// could be that our graph is a subgraph of the graph the snarls were computed from
// so bypass snarls we can't process
Expand Down Expand Up @@ -1296,6 +1361,8 @@ LegacyCaller::LegacyCaller(const PathPositionHandleGraph& graph,
// our graph is not in vg format. we will make graphs for each site as needed and work with those
traversal_finder = nullptr;
}

this->toggle_full_names_from_paths(ref_paths);
}

LegacyCaller::~LegacyCaller() {
Expand All @@ -1305,7 +1372,7 @@ LegacyCaller::~LegacyCaller() {
}
}

bool LegacyCaller::call_snarl(const Snarl& snarl) {
bool LegacyCaller::call_snarl(const Snarl& snarl, int ploidy_override, vector<int>* out_child_ploidies) {

// if we can't handle the snarl, then the GraphCaller framework will recurse on its children
if (!is_traversable(snarl)) {
Expand Down Expand Up @@ -1630,20 +1697,26 @@ FlowCaller::FlowCaller(const PathPositionHandleGraph& graph,
ref_ploidies[ref_paths[i]] = i < ref_path_ploidies.size() ? ref_path_ploidies[i] : 2;
}

this->toggle_full_names_from_paths(ref_paths);
}

FlowCaller::~FlowCaller() {

}

bool FlowCaller::call_snarl(const Snarl& managed_snarl) {
bool FlowCaller::call_snarl(const Snarl& managed_snarl, int ploidy_override, vector<int>* out_child_ploidies) {

// todo: In order to experiment with merging consecutive snarls to make longer traversals,
// I am experimenting with sending "fake" snarls through this code. So make a local
// copy to work on to do things like flip -- calling any snarl_manager code that
// wants a pointer will crash.
Snarl snarl = managed_snarl;

if (ploidy_override == 0) {
// returning true is a bit ironic, but we do *not* want to recurse so it's important we do
return true;
}

if (snarl.start().node_id() == snarl.end().node_id() ||
!graph.has_node(snarl.start().node_id()) || !graph.has_node(snarl.end().node_id())) {
// can't call one-node or out-of graph snarls.
Expand Down Expand Up @@ -1803,10 +1876,14 @@ bool FlowCaller::call_snarl(const Snarl& managed_snarl) {
// use our support caller to choose our genotype
vector<int> trav_genotype;
unique_ptr<SnarlCaller::CallInfo> trav_call_info;
int ploidy = ref_ploidies[ref_path_name];
int ploidy = ploidy_override == -1 ? ref_ploidies[ref_path_name] : ploidy_override;
std::tie(trav_genotype, trav_call_info) = snarl_caller.genotype(snarl, travs, ref_trav_idx, ploidy, ref_path_name,
make_pair(get<0>(ref_interval), get<1>(ref_interval)));

if (out_child_ploidies != nullptr) {
// derive ploidies for child snarls from the genotype traversals and save them
resolve_child_ploidies(graph, &managed_snarl, travs, trav_genotype, *out_child_ploidies);
}

assert(trav_genotype.empty() || trav_genotype.size() == ploidy);

if (!gaf_output) {
Expand Down Expand Up @@ -1866,14 +1943,15 @@ NestedFlowCaller::NestedFlowCaller(const PathPositionHandleGraph& graph,
ref_path_set.insert(ref_paths[i]);
ref_ploidies[ref_paths[i]] = i < ref_path_ploidies.size() ? ref_path_ploidies[i] : 2;
}


this->toggle_full_names_from_paths(ref_paths);
}

NestedFlowCaller::~NestedFlowCaller() {

}

bool NestedFlowCaller::call_snarl(const Snarl& managed_snarl) {
bool NestedFlowCaller::call_snarl(const Snarl& managed_snarl, int ploidy_override, vector<int>* out_child_ploidies) {

// remember the calls for each child snarl in this table
CallTable call_table;
Expand Down
22 changes: 17 additions & 5 deletions src/graph_caller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,16 @@ class GraphCaller {
RecurseType recurise_type = RecurseOnFail);

/// Call a given snarl, and print the output to out_stream
virtual bool call_snarl(const Snarl& snarl) = 0;
virtual bool call_snarl(const Snarl& snarl, int ploidy_override = -1, vector<int>* out_child_ploidies = nullptr) = 0;

protected:

/// Break up a chain into bits that we want to call using size heuristics
vector<Chain> break_chain(const HandleGraph& graph, const Chain& chain, size_t max_edges, size_t max_trivial);

/// Compute the child ploidies baseds on the called genotype
void resolve_child_ploidies(const HandleGraph& graph, const Snarl* snarl, const vector<SnarlTraversal>& travs,
const vector<int>& genotype, vector<int>& child_ploidies);

protected:

Expand Down Expand Up @@ -135,6 +139,11 @@ class VCFOutputCaller {

// update the PS and LV tags in the output buffer (called in write_variants if include_nested is true)
void update_nesting_info_tags(const SnarlManager* snarl_manager);

// toggle write_full_names automatically depending on ref_paths
// (keeps backwards compatibility when it was always false whenever possible,
// but flips to full names when necessary)
void toggle_full_names_from_paths(const vector<string>& ref_paths);

/// output vcf
mutable vcflib::VariantCallFile output_vcf;
Expand All @@ -154,6 +163,9 @@ class VCFOutputCaller {

// need to write LV/PS info tags
bool include_nested;

// toggle writing full name or just contig name
bool write_full_names;
};

/**
Expand Down Expand Up @@ -221,7 +233,7 @@ class VCFGenotyper : public GraphCaller, public VCFOutputCaller, public GAFOutpu

virtual ~VCFGenotyper();

virtual bool call_snarl(const Snarl& snarl);
virtual bool call_snarl(const Snarl& snarl, int ploidy_override = -1, vector<int>* out_child_ploidies = nullptr);

virtual string vcf_header(const PathHandleGraph& graph, const vector<string>& contigs,
const vector<size_t>& contig_length_overrides = {}) const;
Expand Down Expand Up @@ -274,7 +286,7 @@ class LegacyCaller : public GraphCaller, public VCFOutputCaller {

virtual ~LegacyCaller();

virtual bool call_snarl(const Snarl& snarl);
virtual bool call_snarl(const Snarl& snarl, int ploidy_override = -1, vector<int>* out_child_ploidies = nullptr);

virtual string vcf_header(const PathHandleGraph& graph, const vector<string>& contigs,
const vector<size_t>& contig_length_overrides = {}) const;
Expand Down Expand Up @@ -375,7 +387,7 @@ class FlowCaller : public GraphCaller, public VCFOutputCaller, public GAFOutputC

virtual ~FlowCaller();

virtual bool call_snarl(const Snarl& snarl);
virtual bool call_snarl(const Snarl& snarl, int ploidy_override = -1, vector<int>* out_child_ploidies = nullptr);

virtual string vcf_header(const PathHandleGraph& graph, const vector<string>& contigs,
const vector<size_t>& contig_length_overrides = {}) const;
Expand Down Expand Up @@ -449,7 +461,7 @@ class NestedFlowCaller : public GraphCaller, public VCFOutputCaller, public GAFO

virtual ~NestedFlowCaller();

virtual bool call_snarl(const Snarl& snarl);
virtual bool call_snarl(const Snarl& snarl, int ploidy_override = -1, vector<int>* out_child_ploidies = nullptr);

virtual string vcf_header(const PathHandleGraph& graph, const vector<string>& contigs,
const vector<size_t>& contig_length_overrides = {}) const;
Expand Down

1 comment on commit ea5d6e5

@adamnovak
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vg CI tests complete for branch rgfa2. View the full report here.

16 tests passed, 0 tests failed and 0 tests skipped in 17334 seconds

Please sign in to comment.