Skip to content

Commit

Permalink
more optimizations, about 27% faster
Browse files Browse the repository at this point in the history
  • Loading branch information
greenaddress committed May 16, 2024
1 parent 7650813 commit 9efb473
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 771 deletions.
156 changes: 54 additions & 102 deletions src/bytewords.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,152 +65,104 @@ static const int16_t _lookup[] = {
242, -1, -1, -1
};

static bool decode_word(const string& word, size_t word_len, uint8_t& output) {
const size_t dim = 26;
static inline bool decode_word(const string& word, size_t word_len, uint8_t& output) {
constexpr size_t dim = 26;

// Sanity check
if(word.length() != word_len) {
if (word.length() != word_len) {
return false;
}

// If the coordinates generated by the first and last letters are out of bounds,
// or the lookup table contains -1 at the coordinates, then the word is not valid.
int x = tolower(word[0]) - 'a';
int y = tolower(word[word_len == 4 ? 3 : 1]) - 'a';
if(!(0 <= x && x < dim && 0 <= y && y < dim)) {

if (static_cast<unsigned>(x) >= dim || static_cast<unsigned>(y) >= dim) {
return false;
}

size_t offset = y * dim + x;
int16_t value = _lookup[offset];
if(value == -1) {
if (value == -1) {
return false;
}

// If we're decoding a full four-letter word, verify that the two middle letters are correct.
if(word_len == 4) {
if (word_len == 4) {
const char* byteword = bytewords + value * 4;
int c1 = tolower(word[1]);
int c2 = tolower(word[2]);
if(c1 != byteword[1] || c2 != byteword[2]) {
if (tolower(word[1]) != byteword[1] || tolower(word[2]) != byteword[2]) {
return false;
}
}

// Successful decode.
output = value;
output = static_cast<uint8_t>(value);
return true;
}

static string get_word(uint8_t index) {
const auto* p = &bytewords[index * 4];
return {p, p + 4};
}

static string get_minimal_word(uint8_t index) {
string word;
word.reserve(2);
const auto* p = &bytewords[index * 4];
word.push_back(*p);
word.push_back(*(p + 3));
return word;
}

static string encode(const ByteVector& buf, const string& separator) {
auto len = buf.size();
StringVector words;
words.reserve(len);
for(int i = 0; i < len; i++) {
auto byte = buf[i];
words.push_back(get_word(byte));
}
return join(words, separator);
}

static inline ByteVector crc32_bytes(const ByteVector &buf) {
uint32_t checksum = __builtin_bswap32(esp_crc32_le(0, buf.data(), buf.size()));
const uint32_t checksum = __builtin_bswap32(esp_crc32_le(0, buf.data(), buf.size()));
ByteVector result(sizeof(checksum));
std::memcpy(result.data(), &checksum, sizeof(checksum));
memcpy(result.data(), &checksum, sizeof(checksum));
return result;
}

static ByteVector add_crc(const ByteVector& buf) {
auto crc_buf = crc32_bytes(buf);
auto result = buf;
string Bytewords::encode(style style, const ByteVector& bytes) {
auto crc_buf = crc32_bytes(bytes);
ByteVector result = bytes;
append(result, crc_buf);
return result;
}

static string encode_with_separator(const ByteVector& buf, const string& separator) {
auto crc_buf = add_crc(buf);
return encode(crc_buf, separator);
}

static string encode_minimal(const ByteVector& buf) {
string result;
auto crc_buf = add_crc(buf);
auto len = crc_buf.size();
for(int i = 0; i < len; i++) {
auto byte = crc_buf[i];
result.append(get_minimal_word(byte));
if (style == minimal) {
std::string r;
r.reserve(result.size() * 2);
for (uint8_t byte : result) {
const char* p = &bytewords[byte * 4];
r.push_back(p[0]);
r.push_back(p[3]);
}
return r;
}
return result;
}

static ByteVector _decode(const string& s, char separator, size_t word_len) {
assert(style == standard || style == uri);

StringVector words;
if(word_len == 4) {
words = split(s, separator);
} else {
words = partition(s, 2);
words.reserve(result.size());
for (uint8_t byte : result) {
words.emplace_back(&bytewords[byte * 4], 4);
}
return join(words, style == standard ? " " : "-");
}

ByteVector Bytewords::decode(style style, const string& s) {
assert(style == standard || style == uri || style == minimal);
const size_t word_len = (style == minimal) ? 2 : 4;
const char separator = (style == standard) ? ' ' : (style == uri) ? '-' : 0;

StringVector words = (word_len == 4) ? split(s, separator) : partition(s, 2);

const size_t num_words = words.size();
if (num_words < 5) return ByteVector();

ByteVector buf;
buf.reserve(words.size());
for (const auto &word : words) {
buf.reserve(num_words);

for (const auto& word : words) {
uint8_t output;
if (!decode_word(word, word_len, output)) {
// Failed to decode word
return ByteVector();
}
buf.push_back(output);
}
if(buf.size() < 5) {
return ByteVector();
}
auto p = split(buf, buf.size() - 4);
auto body = p.first;
auto body_checksum = p.second;
auto checksum = crc32_bytes(body);
if(checksum != body_checksum) {
return ByteVector();
}

return body;
}
if (buf.size() < 5) return ByteVector();

string Bytewords::encode(style style, const ByteVector& bytes) {
switch(style) {
case standard:
return encode_with_separator(bytes, " ");
case uri:
return encode_with_separator(bytes, "-");
case minimal:
return encode_minimal(bytes);
}
assert(false);
return string();
}
const auto body_size = buf.size() - 4;
const ByteVector body(buf.begin(), buf.begin() + body_size);
const ByteVector body_checksum(buf.begin() + body_size, buf.end());

ByteVector Bytewords::decode(style style, const string& string) {
switch(style) {
case standard:
return _decode(string, ' ', 4);
case uri:
return _decode(string, '-', 4);
case minimal:
return _decode(string, 0, 2);
const auto checksum = crc32_bytes(body);

if (std::equal(body_checksum.begin(), body_checksum.end(), checksum.begin())) {
return body;
}
assert(false);
return ByteVector();
}


}
81 changes: 49 additions & 32 deletions src/fountain-decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ bool FountainDecoder::receive_part(FountainEncoder::Part& encoder_part) {
}

// Keep track of how many parts we've processed
processed_parts_count_ += 1;
++processed_parts_count_;

return true;
}
Expand All @@ -85,68 +85,81 @@ void FountainDecoder::process_queue_item() {
}

void FountainDecoder::reduce_mixed_by(const Part& p) {
// Reduce all the current mixed parts by the given part
PartVector reduced_parts;
for(auto i = _mixed_parts.begin(); i != _mixed_parts.end(); i++) {
reduced_parts.push_back(reduce_part_by_part(i->second, p));
reduced_parts.reserve(_mixed_parts.size());

for (auto it = _mixed_parts.begin(); it != _mixed_parts.end(); ++it) {
reduced_parts.push_back(reduce_part_by_part(it->second, p));
}

// Collect all the remaining mixed parts
PartDict new_mixed;
for(auto reduced_part: reduced_parts) {
// If this reduced part is now simple
if(reduced_part.is_simple()) {
// Add it to the queue

for (auto& reduced_part : reduced_parts) {
if (reduced_part.is_simple()) {
enqueue(reduced_part);
} else {
// Otherwise, add it to the list of current mixed parts
new_mixed.insert(pair(reduced_part.indexes(), reduced_part));
new_mixed.emplace(reduced_part.indexes(), move(reduced_part));
}
}
_mixed_parts = new_mixed;

_mixed_parts = move(new_mixed);
}


FountainDecoder::Part FountainDecoder::reduce_part_by_part(const Part& a, const Part& b) const {
// If the fragments mixed into `b` are a strict (proper) subset of those in `a`...
if(is_strict_subset(b.indexes(), a.indexes())) {
// The new fragments in the revised part are `a` - `b`.
if (is_strict_subset(b.indexes(), a.indexes())) {
auto new_indexes = set_difference(a.indexes(), b.indexes());
// The new data in the revised part are `a` XOR `b`
auto new_data = xor_with(a.data(), b.data());

ByteVector new_data = a.data();
const auto& s = b.data();
const size_t count = new_data.size();

for (size_t i = 0; i < count; ++i) {
new_data[i] ^= s[i];
}

return Part(new_indexes, new_data);
} else {
// `a` is not reducable by `b`, so return a
return a;
}
}

void FountainDecoder::process_simple_part(Part& p) {
// Don't process duplicate parts
auto fragment_index = p.index();
if(contains(received_part_indexes_, fragment_index)) return;
if (received_part_indexes_.find(fragment_index) != received_part_indexes_.end()) return;

// Record this part
_simple_parts.insert(pair(p.indexes(), p));
_simple_parts.emplace(p.indexes(), p);
received_part_indexes_.insert(fragment_index);

// If we've received all the parts
if(received_part_indexes_ == _expected_part_indexes) {
if (received_part_indexes_ == _expected_part_indexes) {
// Reassemble the message from its fragments
PartVector sorted_parts;
transform(_simple_parts.begin(), _simple_parts.end(), back_inserter(sorted_parts), [&](auto elem) { return elem.second; });
sorted_parts.reserve(_simple_parts.size());
for (const auto& elem : _simple_parts) {
sorted_parts.push_back(elem.second);
}

sort(sorted_parts.begin(), sorted_parts.end(),
[](const Part& a, const Part& b) -> bool {
return a.index() < b.index();
}
);

ByteVectorVector fragments;
transform(sorted_parts.begin(), sorted_parts.end(), back_inserter(fragments), [&](auto part) { return part.data(); });
fragments.reserve(sorted_parts.size());
for (const auto& part : sorted_parts) {
fragments.push_back(part.data());
}

auto message = join_fragments(fragments, *_expected_message_len);

// Verify the message checksum and note success or failure
auto checksum = esp_crc32_le(0, message.data(), message.size());
if(checksum == _expected_checksum) {
result_ = message;
if (checksum == _expected_checksum) {
result_ = move(message);
} else {
result_ = InvalidChecksum();
}
Expand All @@ -158,23 +171,27 @@ void FountainDecoder::process_simple_part(Part& p) {

void FountainDecoder::process_mixed_part(const Part& p) {
// Don't process duplicate parts
if(any_of(_mixed_parts.begin(), _mixed_parts.end(), [&](auto r) { return r.first == p.indexes(); })) {
if (any_of(_mixed_parts.begin(), _mixed_parts.end(), [&](const auto& r) { return r.first == p.indexes(); })) {
return;
}

// Reduce this part by all the others
auto p2 = accumulate(_simple_parts.begin(), _simple_parts.end(), p, [&](auto p, auto r) { return reduce_part_by_part(p, r.second); });
p2 = accumulate(_mixed_parts.begin(), _mixed_parts.end(), p2, [&](auto p, auto r) { return reduce_part_by_part(p, r.second); });
Part p2 = p;
for (const auto& r : _simple_parts) {
p2 = reduce_part_by_part(p2, r.second);
}
for (const auto& r : _mixed_parts) {
p2 = reduce_part_by_part(p2, r.second);
}

// If the part is now simple
if(p2.is_simple()) {
if (p2.is_simple()) {
// Add it to the queue
enqueue(p2);
} else {
// Reduce all the mixed parts by this one
reduce_mixed_by(p2);
// Record this new mixed part
_mixed_parts.insert(pair(p2.indexes(), p2));
_mixed_parts.emplace(p2.indexes(), p2);
}
}

Expand All @@ -185,7 +202,7 @@ bool FountainDecoder::validate_part(const FountainEncoder::Part& p) {
if(!_expected_part_indexes.has_value()) {
// Record the things that all the other parts we see will have to match to be valid.
_expected_part_indexes = PartIndexes();
for(size_t i = 0; i < p.seq_len(); i++) { _expected_part_indexes->insert(i); }
for(size_t i = 0; i < p.seq_len(); ++i) { _expected_part_indexes->insert(i); }
_expected_message_len = p.message_len();
_expected_checksum = p.checksum();
_expected_fragment_len = p.data().size();
Expand Down
Loading

0 comments on commit 9efb473

Please sign in to comment.