Skip to content

Commit

Permalink
Merge pull request #314 from 4ms/wifi_compressed_flash
Browse files Browse the repository at this point in the history
Flash compressed wifi images
  • Loading branch information
danngreen authored Sep 26, 2024
2 parents e076e14 + 839c88c commit bcb671a
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 39 deletions.
33 changes: 25 additions & 8 deletions firmware/flashing/manifest_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,46 @@
import hashlib
import re
import shutil
import zlib

ManifestFormatVersion = 1


def process_file(dest_dir, filename, imagetype, *, name=None, address=None):
def process_file(dest_dir, filename, imagetype, *, name=None, address=None, compressed=False):

with open(filename, 'rb') as file:

entry = dict()

entry["filename"] = os.path.basename(filename)
entry["type"] = imagetype
entry["filesize"] = os.stat(filename).st_size
entry["md5"] = hashlib.md5(file.read()).hexdigest()
entry["address"] = int(address)

if name is not None:
entry["name"] = name

try:
shutil.copyfile(filename, os.path.join(dest_dir, os.path.basename(filename)))
except shutil.SameFileError:
pass
if not compressed:
entry["filename"] = os.path.basename(filename)
entry["filesize"] = os.stat(filename).st_size

try:
shutil.copyfile(filename, os.path.join(dest_dir, entry["filename"]))
except shutil.SameFileError:
pass
else:
entry["filename"] = os.path.basename(filename) + ".gz"
entry["uncompressed_size"] = os.stat(filename).st_size

with open(filename, "rb") as orig_file:
with open(os.path.join(dest_dir, entry["filename"]), "wb") as dest_file:

# use level 9 as in esptool, can probably be changed
compressed_data = zlib.compress(orig_file.read(), level=9)
dest_file.write(compressed_data)

entry["filesize"] = len(compressed_data)



return entry

Expand Down Expand Up @@ -97,7 +114,7 @@ def parse_file_version(version):
j["files"].append(process_file(destination_dir, args.wifi_app_file, "wifi", name="Wifi Application", address=0x10000))

if args.wifi_fs_file:
j["files"].append(process_file(destination_dir, args.wifi_fs_file, "wifi", name="Wifi Filesystem", address=0x200000))
j["files"].append(process_file(destination_dir, args.wifi_fs_file, "wifi", name="Wifi Filesystem", address=0x200000, compressed=True))

with open(args.out_file, "w+") as out_file:
data_json = json.dumps(j, indent=4)
Expand Down
2 changes: 1 addition & 1 deletion firmware/lib/esp-serial-flasher
1 change: 1 addition & 0 deletions firmware/src/calibrate/calibration_routine.hh
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ private:
if (storage.request_file_flash(IntercoreStorageMessage::FlashTarget::QSPI,
{(uint8_t *)(&cal_data), sizeof(cal_data)},
CalDataFlashOffset,
std::nullopt,
&bytes_written))
state = State::WritingCal;
}
Expand Down
1 change: 1 addition & 0 deletions firmware/src/core_intercom/intercore_message.hh
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ struct IntercoreStorageMessage {

uint32_t address;
uint32_t length;
std::optional<uint32_t> uncompressed_size;
StaticString<32> checksum;
uint32_t *bytes_processed;
enum FlashTarget : uint8_t { WIFI, QSPI };
Expand Down
3 changes: 3 additions & 0 deletions firmware/src/core_m4/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ include(${CMAKE_SOURCE_DIR}/cmake/common.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/arch_mp15xm4.cmake)
set_arch_flags()

include(${CMAKE_SOURCE_DIR}/cmake/log_levels.cmake)
enable_logging()

# set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)

set(M4DIR ${CMAKE_CURRENT_SOURCE_DIR}) # /firmware/src/core_m4
Expand Down
42 changes: 25 additions & 17 deletions firmware/src/fw_update/firmware_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ std::optional<IntercoreStorageMessage> FirmwareWriter::handle_message(const Inte
auto buf = std::span<uint8_t>{(uint8_t *)message.buffer.data(), message.buffer.size()};

if (message.flashTarget == WIFI) {
return flashWifi(buf, message.address, *message.bytes_processed);
return flashWifi(buf, message.address, message.uncompressed_size, *message.bytes_processed);

} else if (message.flashTarget == QSPI) {
return flashQSPI(buf, message.address, *message.bytes_processed);
Expand All @@ -62,21 +62,20 @@ IntercoreStorageMessage FirmwareWriter::compareChecksumWifi(uint32_t address, ui

if (result == ESP_LOADER_SUCCESS) {

if (result == ESP_LOADER_SUCCESS) {
result = Flasher::verify(address, length, checksum);
if (result == ESP_LOADER_ERROR_INVALID_MD5) {
returnValue = {.message_type = ChecksumMismatch};
} else if (result == ESP_LOADER_SUCCESS) {
pr_dbg("-> Checksum matches\n");
returnValue = {.message_type = ChecksumMatch};
} else {
pr_dbg("-> Cannot get checksum\n");
returnValue = {.message_type = ChecksumFailed};
}
result = Flasher::verify(address, length, checksum);

if (result == ESP_LOADER_ERROR_INVALID_MD5) {
returnValue = {.message_type = ChecksumMismatch};

} else if (result == ESP_LOADER_SUCCESS) {
pr_dbg("-> Checksum matches\n");
returnValue = {.message_type = ChecksumMatch};

} else {
pr_err("Cannot write dummy byte to wifi flash\n");
pr_err("-> Cannot get checksum\n");
returnValue = {.message_type = ChecksumFailed};
}

} else {
pr_err("Cannot connect to wifi bootloader\n");
returnValue = {.message_type = ChecksumFailed};
Expand All @@ -87,7 +86,10 @@ IntercoreStorageMessage FirmwareWriter::compareChecksumWifi(uint32_t address, ui
return returnValue;
}

IntercoreStorageMessage FirmwareWriter::flashWifi(std::span<uint8_t> buffer, uint32_t address, uint32_t &bytesWritten) {
IntercoreStorageMessage FirmwareWriter::flashWifi(std::span<uint8_t> buffer,
uint32_t address,
std::optional<uint32_t> uncompressed_size,
uint32_t &bytesWritten) {
IntercoreStorageMessage returnValue;

// Stop wifi reception before long running operation
Expand All @@ -100,7 +102,7 @@ IntercoreStorageMessage FirmwareWriter::flashWifi(std::span<uint8_t> buffer, uin
if (result == ESP_LOADER_SUCCESS) {
const std::size_t BatchSize = 1024;

result = Flasher::flash_start(address, buffer.size(), BatchSize);
result = Flasher::flash_start(address, buffer.size(), BatchSize, uncompressed_size);

if (result == ESP_LOADER_SUCCESS) {
bool error_during_writes = false;
Expand All @@ -109,7 +111,7 @@ IntercoreStorageMessage FirmwareWriter::flashWifi(std::span<uint8_t> buffer, uin
auto to_read = std::min<std::size_t>(buffer.size() - bytesWritten, BatchSize);
auto thisBatch = buffer.subspan(bytesWritten, to_read);

result = Flasher::flash_process(thisBatch);
result = Flasher::flash_process(thisBatch, uncompressed_size.has_value());

if (result != ESP_LOADER_SUCCESS) {
error_during_writes = true;
Expand All @@ -120,7 +122,12 @@ IntercoreStorageMessage FirmwareWriter::flashWifi(std::span<uint8_t> buffer, uin
}

if (not error_during_writes) {
pr_trace("-> Flashing completed\n");
result = Flasher::flash_finish(uncompressed_size.has_value());
error_during_writes = result != ESP_LOADER_SUCCESS;
}

if (not error_during_writes) {
pr_dbg("-> Flashing completed\n");
returnValue = {.message_type = FlashingOk};
} else {
pr_trace("-> Flashing failed\n");
Expand Down Expand Up @@ -165,6 +172,7 @@ FirmwareWriter::compareChecksumQSPI(uint32_t address, uint32_t length, Checksum_
offset += bytesToRead;
bytesChecked = offset;
} else {
pr_trace("Flash checksum could not be read\n");
return {.message_type = ChecksumFailed};
}
}
Expand Down
2 changes: 1 addition & 1 deletion firmware/src/fw_update/firmware_writer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public:

private:
IntercoreStorageMessage compareChecksumWifi(uint32_t, uint32_t, Checksum_t);
IntercoreStorageMessage flashWifi(std::span<uint8_t>, uint32_t address, uint32_t &bytesWritten);
IntercoreStorageMessage flashWifi(std::span<uint8_t>, uint32_t address, std::optional<uint32_t> uncompressed_size, uint32_t &bytesWritten);
IntercoreStorageMessage compareChecksumQSPI(uint32_t, uint32_t, Checksum_t, uint32_t &bytesWritten);
IntercoreStorageMessage flashQSPI(std::span<uint8_t>, uint32_t address, uint32_t &bytesWritten);

Expand Down
8 changes: 8 additions & 0 deletions firmware/src/fw_update/manifest_parse.hh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ inline bool read(ryml::ConstNodeRef const &n, UpdateFile *updateFile) {
if (n.has_child("name")) {
n["name"] >> updateFile->name;
}

if (n.has_child("uncompressed_size"))
{
uint32_t val;
n["uncompressed_size"] >> val;
updateFile->uncompressed_size = val;
}

return true;
} else {
pr_err("Missing required fields\n");
Expand Down
1 change: 1 addition & 0 deletions firmware/src/fw_update/update_file.hh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct UpdateFile {
UpdateType type;
std::string filename;
uint32_t filesize = 0;
std::optional<uint32_t> uncompressed_size;
std::optional<StaticString<32>> md5;
uint32_t address;
std::string name;
Expand Down
13 changes: 10 additions & 3 deletions firmware/src/fw_update/updater_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ FirmwareUpdaterProxy::Status FirmwareUpdaterProxy::process() {
current_file_name = thisFile.name;

auto result = file_storage.request_checksum_compare(
*target, *checksumValue, thisFile.address, thisFile.filesize, &sharedMem->bytes_processed);
*target,
*checksumValue,
thisFile.address,
thisFile.uncompressed_size.value_or(thisFile.filesize),
&sharedMem->bytes_processed);

if (not result) {
abortWithMessage("Cannot trigger comparing checksums");
Expand Down Expand Up @@ -206,8 +210,11 @@ FirmwareUpdaterProxy::Status FirmwareUpdaterProxy::process() {
current_file_name = thisFile.name;

if (auto target = GetTargetForUpdateType(thisFile.type); target) {
auto result = file_storage.request_file_flash(
*target, thisLoadedFile, thisFile.address, &sharedMem->bytes_processed);
auto result = file_storage.request_file_flash(*target,
thisLoadedFile,
thisFile.address,
thisFile.uncompressed_size,
&sharedMem->bytes_processed);

if (not result) {
abortWithMessage("Cannot trigger flashing");
Expand Down
2 changes: 2 additions & 0 deletions firmware/src/patch_file/file_storage_proxy.hh
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,13 @@ public:
[[nodiscard]] bool request_file_flash(IntercoreStorageMessage::FlashTarget target,
std::span<uint8_t> buffer,
uint32_t address,
std::optional<uint32_t> uncompressed_size,
uint32_t *bytes_processed) {
IntercoreStorageMessage message{
.message_type = StartFlashing,
.buffer = {(char *)buffer.data(), buffer.size()},
.address = address,
.uncompressed_size = uncompressed_size,
.bytes_processed = bytes_processed,
.flashTarget = target,
};
Expand Down
43 changes: 36 additions & 7 deletions firmware/src/wifi/flasher/flasher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "custom_port.h"
#include "esp_loader.h"
#include "esp_loader_io.h"
#include "drivers/stm32xx.h"

#include <console/pr_dbg.hh>

Expand All @@ -27,7 +28,7 @@ esp_loader_error_t init(uint32_t baudrate)
}
else
{
pr_dbg("Flasher: Connected to target\n");
pr_trace("Flasher: Connected to target\n");

err = esp_loader_change_transmission_rate(baudrate);
if (err == ESP_LOADER_ERROR_UNSUPPORTED_FUNC)
Expand All @@ -47,7 +48,7 @@ esp_loader_error_t init(uint32_t baudrate)
}
else
{
pr_dbg("Flasher: Transmission rate changed changed\n");
pr_trace("Flasher: Transmission rate changed\n");
}
}
}
Expand Down Expand Up @@ -119,7 +120,8 @@ esp_loader_error_t flash(uint32_t address, std::span<const uint8_t> buffer)

esp_loader_error_t verify(uint32_t address, uint32_t length, std::string_view checksum)
{
pr_dbg("Flasher: Getting checksum from %08x-%08x\n", address, address + length);
HAL_Delay(10);
pr_trace("Flasher: Getting checksum from %08x-%08x\n", address, address + length);

std::array<uint8_t,32> readValue;

Expand Down Expand Up @@ -178,14 +180,41 @@ esp_loader_error_t conditional_flash(uint32_t address, std::span<const uint8_t>
return result;
}

esp_loader_error_t flash_start(uint32_t address, uint32_t length, uint32_t batchSize)
esp_loader_error_t flash_start(uint32_t address, uint32_t length, uint32_t batchSize, std::optional<std::size_t> uncompressed_size)
{
return esp_loader_flash_start(address, length, batchSize);
if (not uncompressed_size)
{
return esp_loader_flash_start(address, length, batchSize);
}
else
{
return esp_loader_flash_defl_start(address, *uncompressed_size, length, batchSize);
}
}

esp_loader_error_t flash_process(std::span<uint8_t> buffer, bool compressed)
{
if (not compressed)
{
return esp_loader_flash_write(buffer.data(), buffer.size());
}
else
{
return esp_loader_flash_defl_write(buffer.data(), buffer.size());
}
}

esp_loader_error_t flash_process(std::span<uint8_t> buffer)
esp_loader_error_t flash_finish(bool compressed)
{
return esp_loader_flash_write(buffer.data(), buffer.size());
if (not compressed)
{
return esp_loader_flash_finish(false);
}
else
{
printf("Finishing defl flash\n");
return esp_loader_flash_defl_finish(false);
}
}

}
6 changes: 4 additions & 2 deletions firmware/src/wifi/flasher/flasher.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <cstdint>
#include <span>
#include <optional>
#include <string_view>

namespace Flasher
Expand All @@ -17,7 +18,8 @@ esp_loader_error_t flash(uint32_t address, std::span<const uint8_t>);
esp_loader_error_t verify(uint32_t address, uint32_t length, std::string_view);
esp_loader_error_t conditional_flash(uint32_t address, std::span<const uint8_t>, std::string_view);

esp_loader_error_t flash_start(uint32_t address, uint32_t length, uint32_t batchSize);
esp_loader_error_t flash_process(std::span<uint8_t>);
esp_loader_error_t flash_start(uint32_t address, uint32_t length, uint32_t batchSize, std::optional<std::size_t> uncompressed_size);
esp_loader_error_t flash_process(std::span<uint8_t>, bool compressed);
esp_loader_error_t flash_finish(bool compressed);

} // namespace Flasher

0 comments on commit bcb671a

Please sign in to comment.