partition sparse switches into packed segments
Summary:
This change is motivated by agampe's analysis and idea to do more fine-grained splitting when beneficial.

This generalizes the splitting of switches:
```
1. Splitting sparse switches into packed segments and remaining sparse
   switches. We only do this if we can find a partitioning of a sparse
   switch of size N into M packed segments and a remaining sparse switch
   of size L such that
       M + log2(L) <= log2(N).
   This transformation is largely size neutral.
   Before the transformation, the interpreter would have O(log2(N)) and
   compiled code O(N). After the transformation, the interpreter gets down
   to O(M + log2(L)) and compiled code to O(M + L). So while the runtime
   performance of the interpreter won't change much, compiled code will run
   much faster; if L gets close to 0, we get O(M) <= O(log2(N)).
```
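
For intuition, a minimal standalone sketch of the acceptance criterion above (illustrative only; the function name splitting_is_worthwhile and the example numbers are assumptions, not code from this commit — the actual check lives in the pass's partition helper):
```
#include <cmath>
#include <cstddef>

// Accept a split of a sparse switch with N case keys into M packed segments
// plus L leftover sparse keys iff M + log2(L) <= log2(N).
bool splitting_is_worthwhile(size_t n_case_keys /* N */,
                             size_t n_packed_segments /* M */,
                             size_t n_sparse_keys /* L */) {
  double partitioned_cost = static_cast<double>(n_packed_segments);
  if (n_sparse_keys > 0) {
    partitioned_cost += std::log2(static_cast<double>(n_sparse_keys));
  }
  return partitioned_cost <= std::log2(static_cast<double>(n_case_keys));
}

// Example: N = 64 case keys, M = 2 packed segments, L = 4 leftovers:
// 2 + log2(4) = 4 <= log2(64) = 6, so the split is considered beneficial.
```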

Reviewed By: agampe

Differential Revision: D68985937

fbshipit-source-id: 72d67aef0172749a4a259b527c163e5acc9dd7c5
Nikolai Tillmann authored and facebook-github-bot committed Feb 4, 2025
1 parent b228b1b commit c203380
Showing 3 changed files with 334 additions and 88 deletions.
223 changes: 148 additions & 75 deletions opt/reduce-sparse-switches/ReduceSparseSwitchesPass.cpp
@@ -17,6 +17,7 @@
#include "ScopedCFG.h"
#include "Show.h"
#include "SourceBlocks.h"
#include "StlUtil.h"
#include "Trace.h"
#include "Walkers.h"

@@ -50,13 +51,19 @@ bool is_sufficiently_sparse(cfg::Block* block) {
return ckeb->sufficiently_sparse();
}

static cfg::Block* split_sparse_switch_into_packed_and_sparse(
struct Segment {
int64_t first;
int64_t last;
Segment(int64_t first, int64_t last) : first(first), last(last) {}
size_t size() const { return last - first + 1; }
};

static void split_sparse_switch_into_packed_and_sparse(
cfg::ControlFlowGraph& cfg,
cfg::Block* block,
const IRList::iterator& switch_insn_it,
IRList::iterator switch_insn_it,
const std::vector<int32_t>& case_keys,
int64_t first,
int64_t last) {
const std::vector<Segment>& packed_segments) {
// We now rewrite the switch from
// /* sparse */ switch (selector) {
// case K_0: goto B_0;
@@ -87,37 +94,45 @@ static cfg::Block* split_sparse_switch_into_packed_and_sparse(
// }
// }

auto selector_reg = switch_insn_it->insn->src(0);
auto* goto_block = block->goes_to();
auto* secondary_switch_block = cfg.create_block();
auto succs_copy = block->succs();
auto first_case_key = case_keys[first];
auto last_case_key = case_keys[last];
std::vector<std::pair<int32_t, cfg::Block*>> secondary_switch_case_to_block;
secondary_switch_case_to_block.reserve(case_keys.size() - last + first - 1);
std::vector<cfg::Edge*> sparse_edges;
sparse_edges.reserve(case_keys.size() - last + first - 1);
for (auto* e : succs_copy) {
if (e->type() == cfg::EDGE_GOTO) {
cfg.set_edge_target(e, secondary_switch_block);
continue;
for (auto [first, last] : packed_segments) {
auto selector_reg = switch_insn_it->insn->src(0);
auto* goto_block = block->goes_to();
auto first_case_key = case_keys[first];
auto last_case_key = case_keys[last];
std::vector<std::pair<int32_t, cfg::Block*>> secondary_switch_case_to_block;
std::vector<cfg::Edge*> sparse_edges;
cfg::Edge* default_edge = nullptr;
for (auto* e : block->succs()) {
if (e->type() == cfg::EDGE_GOTO) {
default_edge = e;
continue;
}
always_assert(e->type() == cfg::EDGE_BRANCH);
auto case_key = *e->case_key();
if (case_key >= first_case_key && case_key <= last_case_key) {
continue;
}
secondary_switch_case_to_block.emplace_back(case_key, e->target());
sparse_edges.push_back(e);
}
always_assert(e->type() == cfg::EDGE_BRANCH);
auto case_key = *e->case_key();
if (case_key >= first_case_key && case_key <= last_case_key) {
continue;
if (secondary_switch_case_to_block.empty()) {
// The last switch includes all remaining case keys, so there's
// nothing to rewrite.
break;
}
secondary_switch_case_to_block.emplace_back(case_key, e->target());
sparse_edges.push_back(e);
auto* secondary_switch_block = cfg.create_block();
always_assert(default_edge != nullptr);
cfg.set_edge_target(default_edge, secondary_switch_block);
cfg.delete_edges(sparse_edges.begin(), sparse_edges.end());
cfg.create_branch(
secondary_switch_block,
(new IRInstruction(OPCODE_SWITCH))->set_src(0, selector_reg),
goto_block, secondary_switch_case_to_block);
always_assert(!is_sufficiently_sparse(block));

block = secondary_switch_block;
switch_insn_it = block->get_last_insn();
}
cfg.delete_edges(sparse_edges.begin(), sparse_edges.end());
cfg.create_branch(
secondary_switch_block,
(new IRInstruction(OPCODE_SWITCH))->set_src(0, selector_reg), goto_block,
secondary_switch_case_to_block);

always_assert(!is_sufficiently_sparse(block));
return secondary_switch_block;
}

static void multiplex_sparse_switch_into_packed_and_sparse(
@@ -265,13 +280,94 @@ void write_sparse_switches(DexStoresVector& stores,
});
}

bool partition(const std::vector<int32_t>& case_keys,
size_t min_switch_cases_per_segment,
std::vector<Segment>* packed_segments,
std::vector<int32_t>* sparse_case_keys) {
// We start by treating each case key as a separate segment. Then we
// iteratively merge adjacent segments that together are packed. After each
// merge, we check whether the segment to the left has now become mergeable
// as well; if not, we keep going to the right. Eventually, all mergeable
// segments will have been merged. This algorithm is linear in the number of
// case keys (only the later sorting of packed segments is obviously not).
packed_segments->reserve(case_keys.size());
auto back = [v = packed_segments]() -> Segment& { return v->back(); };
auto back2 = [v = packed_segments]() -> Segment& { return v->end()[-2]; };
for (size_t i = 0; i < case_keys.size(); ++i) {
packed_segments->emplace_back(i, i);
// Iteratively fuse last two segments until no longer possible.
while (packed_segments->size() >= 2 &&
!instruction_lowering::CaseKeysExtent{
case_keys[back2().first], case_keys[back().last],
(uint32_t)(back().last - back2().first + 1)}
.sufficiently_sparse()) {
// Fuse last two segments.
back2().last = back().last;
packed_segments->pop_back();
}
}

// We now scan all remaining segments, identify which ones are trivial
// (single case key) segments, and move those over to the sparse case keys
// collection.
std20::erase_if(*packed_segments, [&](const auto& segment) {
if (segment.first != segment.last) {
return false;
}
sparse_case_keys->push_back(case_keys[segment.first]);
return true;
});

// Make it so that the largest packed segments come first to reduce average
// runtime cost, assuming that all case-keys get selected with the same
// frequency.
std::sort(packed_segments->begin(), packed_segments->end(),
[](const auto& a, const auto& b) {
if (a.size() != b.size()) {
return a.size() > b.size();
}
return a.first < b.first;
});

// We move unproductive (too small) packed segments over to the remaining
// sparse keys.
auto unpack_last_segment = [&]() {
const auto& segment = packed_segments->back();
for (int64_t i = segment.first; i <= segment.last; i++) {
sparse_case_keys->push_back(case_keys[i]);
}
packed_segments->pop_back();
};
while (!packed_segments->empty() &&
packed_segments->back().size() < sparse_case_keys->size()) {
unpack_last_segment();
}

double partitioned_log2_cost = packed_segments->size();
if (!sparse_case_keys->empty()) {
partitioned_log2_cost += std::log2(sparse_case_keys->size());
}
if (partitioned_log2_cost > std::log2(case_keys.size())) {
return false;
}

// Okay, so it's conceptually worthwhile. We still want to avoid overly
// small packed segments for practical reasons.
while (!packed_segments->empty() &&
packed_segments->back().size() < min_switch_cases_per_segment) {
unpack_last_segment();
}

return !packed_segments->empty();
}
} // namespace

// Find switches which can be split into packed and sparse switches, and apply
// the transformation.
ReduceSparseSwitchesPass::Stats
ReduceSparseSwitchesPass::splitting_transformation(size_t min_switch_cases,
cfg::ControlFlowGraph& cfg) {
ReduceSparseSwitchesPass::splitting_transformation(
size_t min_switch_cases,
size_t min_switch_cases_per_segment,
cfg::ControlFlowGraph& cfg) {
always_assert(min_switch_cases > 0);
ReduceSparseSwitchesPass::Stats stats;
for (auto* block : cfg.blocks()) {
@@ -303,44 +399,22 @@ ReduceSparseSwitchesPass::splitting_transformation(size_t min_switch_cases,
always_assert(case_keys.size() + 1 == block->succs().size());
always_assert(!case_keys.empty());
std::sort(case_keys.begin(), case_keys.end());
int64_t min_splitting_size = (case_keys.size() + 1) / 2;
always_assert(min_splitting_size > 0);
int64_t first = 0;
int64_t last = case_keys.size() - 1;
while (last - first + 1 >= min_splitting_size &&
instruction_lowering::CaseKeysExtent{
case_keys[first], case_keys[last], (uint32_t)(last - first + 1)}
.sufficiently_sparse()) {
// The way we defined min_splitting_size implies that the number of packed
// switch case keys we are looking for is at least half the size of the
// original switch. Thus, the middle case key must be part of the packed
// switch case keys. We use this fact to decide which extreme case key to
// eliminate. However, we need to do some extra gymnastics to deal with
// the case of an even switch size where there is no middle case key.
auto mid_case_key2 = (case_keys[(first + last) / 2] +
(int64_t)case_keys[(first + last + 1) / 2]);
if (mid_case_key2 - 2 * (int64_t)case_keys[first] >
2 * (int64_t)case_keys[last] - mid_case_key2) {
first++;
} else {
last--;
}
}
if (last - first + 1 < min_splitting_size) {

std::vector<Segment> packed_segments;
std::vector<int32_t> sparse_case_keys;

if (!partition(case_keys, min_switch_cases_per_segment, &packed_segments,
&sparse_case_keys)) {
continue;
}

auto* remaining_block = split_sparse_switch_into_packed_and_sparse(
cfg, block, last_insn_it, case_keys, first, last);
split_sparse_switch_into_packed_and_sparse(cfg, block, last_insn_it,
case_keys, packed_segments);

stats.splitting_transformations++;
if (is_sufficiently_sparse(remaining_block)) {
stats.splitting_transformations_packed_segments++;
stats.splitting_transformations_switch_cases_packed += last - first + 1;
} else {
stats.splitting_transformations_packed_segments += 2;
stats.splitting_transformations_switch_cases_packed += case_keys.size();
}
stats.splitting_transformations_packed_segments += packed_segments.size();
stats.splitting_transformations_switch_cases_packed +=
case_keys.size() - sparse_case_keys.size();
}
return stats;
}
@@ -433,6 +507,10 @@ void ReduceSparseSwitchesPass::bind_config() {
m_config.min_splitting_switch_cases,
m_config.min_splitting_switch_cases);

bind("min_splitting_switch_cases_per_segment",
m_config.min_splitting_switch_cases_per_segment,
m_config.min_splitting_switch_cases_per_segment);

bind("min_multiplexing_switch_cases",
m_config.min_multiplexing_switch_cases,
m_config.min_multiplexing_switch_cases);
@@ -464,14 +542,9 @@ void ReduceSparseSwitchesPass::run_pass(DexStoresVector& stores,
}

auto& cfg = code.cfg();
Stats local_stats;
size_t last_splitting_transformations = 0;
do {
last_splitting_transformations = local_stats.splitting_transformations;
local_stats +=
splitting_transformation(m_config.min_splitting_switch_cases, cfg);
} while (last_splitting_transformations !=
local_stats.splitting_transformations);
Stats local_stats = splitting_transformation(
m_config.min_splitting_switch_cases,
m_config.min_splitting_switch_cases_per_segment, cfg);

local_stats += multiplexing_transformation(
m_config.min_multiplexing_switch_cases, cfg);
31 changes: 21 additions & 10 deletions opt/reduce-sparse-switches/ReduceSparseSwitchesPass.h
@@ -44,8 +44,12 @@ class ReduceSparseSwitchesPass : public Pass {

struct Config {
// Starting at 10, the splitting transformation is always a code-size win
// when using 1 packed segment.
uint64_t min_splitting_switch_cases{10};

// To avoid excessive overhead.
uint64_t min_splitting_switch_cases_per_segment{3};

uint64_t min_multiplexing_switch_cases{64};

uint64_t write_sparse_switches{std::numeric_limits<uint64_t>::max()};
@@ -75,12 +79,17 @@ O(N) time to execute, where N is the number of switch cases.
This pass performs two transformations which are designed to improve
runtime performance:
1. Splitting sparse switches into a main packed switch and a secondary sparse
switch. This transformation is only performed if we find a packed
sub-sequence of case keys that contains at least half of all case keys,
so that we shave off at least one operation from the binary search the
interpreter needs to do over the remaining sparse case keys, making sure
we never degrade worst-case complexity. We run this to a fixed point.
1. Splitting sparse switches into packed segments and remaining sparse
switches. We only do this if we can find a partitioning of a sparse
switch of size N into M packed segments and a remaining sparse switch
of size L such that
M + log2(L) <= log2(N).
This transformation is largely size neutral.
Before the transformation, the interpreter would have O(log2(N)) and
compiled code O(N). After the transformation, the interpreter gets down
to O(M + log2(L)) and compiled code to O(M + L). So while the runtime
performance of the interpreter won't change much, compiled code will run
much faster; if L gets close to 0, we get O(M) <= O(log2(N)).
2. Multiplexing sparse switches into a main packed switch with secondary sparse
switches for each main switch case. The basic idea is that we partition a
large number of sparse switch cases into several buckets of relatively small
@@ -93,10 +102,10 @@ runtime performance:
Given a switch with N case keys, we aim at partitioning it into
M = ~sqrt(N) buckets with ~sqrt(N) case keys in each bucket. (We don't
achieve that in practice, and there are rounding effects as well.)
In that case, before the transformation, the interpreter would have O(log N)
In that case, before the transformation, the interpreter would have O(log2(N))
and compiled code O(N). After the transformation, the interpreter gets down
to O(log sqrt(N)) and compiled code to O(sqrt(N)).
(We could try to partition buckets even further, e.g. down to log(N), but
to O(log2(sqrt(N))) and compiled code to O(sqrt(N)).
(We could try to partition buckets even further, e.g. down to log2(N), but
that might result in an excessive size regression.)
)");
}
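
For intuition only, a rough sketch of the multiplexing cost model described in the pass description above (not part of this diff; MultiplexingCost and estimate_multiplexing_cost are hypothetical names used purely for illustration):
```
#include <cmath>
#include <cstddef>

// With N sparse case keys multiplexed into ~sqrt(N) buckets, the interpreter
// replaces one binary search over N keys with a packed-switch dispatch plus a
// binary search over ~sqrt(N) keys; compiled code goes from ~N to ~sqrt(N)
// compare-and-branch operations.
struct MultiplexingCost {
  double interpreter_before; // ~log2(N)
  double interpreter_after;  // ~log2(sqrt(N)) = log2(N) / 2
  double compiled_before;    // ~N
  double compiled_after;     // ~sqrt(N)
};

MultiplexingCost estimate_multiplexing_cost(size_t n_case_keys) {
  double n = static_cast<double>(n_case_keys);
  return MultiplexingCost{std::log2(n), std::log2(std::sqrt(n)), n,
                          std::sqrt(n)};
}

// Example: for N = 256, the interpreter drops from log2(256) = 8 to
// log2(16) = 4 key comparisons, and compiled code from ~256 to ~16 compares.
```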
Expand All @@ -106,7 +115,9 @@ runtime performance:
void run_pass(DexStoresVector&, ConfigFiles&, PassManager&) override;

static ReduceSparseSwitchesPass::Stats splitting_transformation(
size_t min_switch_cases, cfg::ControlFlowGraph& cfg);
size_t min_switch_cases,
size_t min_switch_cases_per_segment,
cfg::ControlFlowGraph& cfg);

static ReduceSparseSwitchesPass::Stats multiplexing_transformation(
size_t min_switch_cases, cfg::ControlFlowGraph& cfg);