Skip to content

Commit

Permalink
refactor(assets): Add offline assets package generation
Browse files Browse the repository at this point in the history
  • Loading branch information
awwaawwa committed Mar 2, 2025
1 parent e65e4ea commit c0c60bf
Showing 1 changed file with 128 additions and 2 deletions.
130 changes: 128 additions & 2 deletions babeldoc/assets/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import logging
import threading
import zipfile
from pathlib import Path

import httpx
Expand Down Expand Up @@ -180,7 +181,7 @@ async def get_fastest_upstream(client: httpx.AsyncClient | None = None):

async def get_doclayout_onnx_model_path_async(client: httpx.AsyncClient | None = None):
onnx_path = get_cache_file_path(
"doclayout_yolo_docstructbench_imgsz1024.onnx", "model"
"doclayout_yolo_docstructbench_imgsz1024.onnx", "models"
)
if verify_file(onnx_path, DOCLAYOUT_YOLO_DOCSTRUCTBENCH_IMGSZ1024ONNX_SHA3_256):
return onnx_path
Expand Down Expand Up @@ -260,6 +261,16 @@ def get_font_family(lang_code: str):


async def download_all_fonts_async(client: httpx.AsyncClient | None = None):
for font_file_name in EMBEDDING_FONT_METADATA:
if not verify_file(
get_cache_file_path(font_file_name, "fonts"),
EMBEDDING_FONT_METADATA[font_file_name]["sha3_256"],
):
break
else:
logger.debug("All fonts are already downloaded")
return

fastest_upstream, font_metadata = await get_fastest_upstream_for_font(client)
if fastest_upstream is None:
logger.error("Failed to get fastest upstream")
Expand Down Expand Up @@ -288,10 +299,125 @@ def warmup():
run_coro(async_warmup())


def generate_all_assets_file_list():
    """Build the manifest of asset files and their expected SHA3-256 hashes.

    Returns:
        A dict with two keys, ``"fonts"`` and ``"models"``. Each maps to a
        list of ``{"name": ..., "sha3_256": ...}`` entries: one entry per key
        of ``EMBEDDING_FONT_METADATA`` for the fonts, and a single entry for
        the DocLayout ONNX model.
    """
    font_entries = [
        {
            "name": font_name,
            "sha3_256": EMBEDDING_FONT_METADATA[font_name]["sha3_256"],
        }
        for font_name in EMBEDDING_FONT_METADATA
    ]
    model_entries = [
        {
            "name": "doclayout_yolo_docstructbench_imgsz1024.onnx",
            "sha3_256": DOCLAYOUT_YOLO_DOCSTRUCTBENCH_IMGSZ1024ONNX_SHA3_256,
        }
    ]
    # "fonts" is inserted before "models" so callers iterating the manifest
    # see the same order as before.
    return {"fonts": font_entries, "models": model_entries}


async def generate_offline_assets_package_async(output_path: Path | None = None):
    """Bundle every cached font and model into one offline zip package.

    ``async_warmup()`` is awaited first (presumably so that all assets are
    present in the cache -- confirm against its definition). Every file in
    the manifest is then checked with ``verify_file`` before the archive is
    created, so a bad cache entry never leaves a truncated package behind.

    Args:
        output_path: Destination of the zip file. When ``None`` the package
            is written to the ``"assets"`` cache location as
            ``offline_assets_<tag>.zip``, where ``<tag>`` comes from
            ``get_offline_assets_tag``.

    Raises:
        SystemExit: With exit code 1 if any asset fails verification.
    """
    # ``sys.exit`` rather than the ``site``-provided ``exit`` builtin, which
    # is missing under ``python -S`` and in frozen applications.
    import sys

    await async_warmup()
    file_list = generate_all_assets_file_list()
    offline_assets_tag = get_offline_assets_tag(file_list)
    if output_path is None:
        output_path = get_cache_file_path(
            f"offline_assets_{offline_assets_tag}.zip", "assets"
        )

    # Pass 1: verify everything up front, before touching output_path.
    members = []
    for file_type, file_descs in file_list.items():
        for file_desc in file_descs:
            file_name = file_desc["name"]
            file_path = get_cache_file_path(file_name, file_type)
            if not verify_file(file_path, file_desc["sha3_256"]):
                logger.error(f"File {file_path} is corrupted")
                sys.exit(1)
            members.append((file_path, f"{file_type}/{file_name}"))

    # Pass 2: write the archive. ZipFile.write streams from disk instead of
    # loading each (potentially large) asset fully into memory.
    with zipfile.ZipFile(
        output_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9
    ) as zipf:
        for file_type in file_list:
            zipf.mkdir(file_type)
        for file_path, archive_name in members:
            zipf.write(file_path, arcname=archive_name)
    logger.info(f"Offline assets package generated at {output_path}")


async def restore_offline_assets_package_async(input_path: Path | None = None):
    """Restore cached fonts and models from an offline assets zip package.

    Files that already pass ``verify_file`` are left untouched; every other
    manifest entry is extracted from the archive into its cache location and
    verified again afterwards.

    Args:
        input_path: Path of the package to restore from. When ``None`` the
            default ``offline_assets_<tag>.zip`` in the ``"assets"`` cache
            location is used. When given, the tag embedded in the file name
            must equal the tag of the current asset manifest.

    Raises:
        SystemExit: With exit code 1 if the file name carries no recognizable
            tag, the tag does not match the current manifest, or an extracted
            file fails verification.
    """
    import re
    import shutil

    # ``sys.exit`` rather than the ``site``-provided ``exit`` builtin, which
    # is missing under ``python -S`` and in frozen applications.
    import sys

    file_list = generate_all_assets_file_list()
    offline_assets_tag = get_offline_assets_tag(file_list)
    if input_path is None:
        input_path = get_cache_file_path(
            f"offline_assets_{offline_assets_tag}.zip", "assets"
        )
    else:
        tag_match = re.match(r"offline_assets_(.*)\.zip", input_path.name)
        if tag_match is None:
            # Previously this case crashed with AttributeError on ``.group``.
            logger.critical(
                f"Cannot read offline assets tag from file name: {input_path.name}"
            )
            sys.exit(1)
        offline_assets_tag_from_input_path = tag_match.group(1)
        if offline_assets_tag != offline_assets_tag_from_input_path:
            logger.critical(
                f"Offline assets tag mismatch: {offline_assets_tag} != {offline_assets_tag_from_input_path}"
            )
            sys.exit(1)

    nothing_changed = True
    with zipfile.ZipFile(input_path, "r") as zipf:
        for file_type, file_descs in file_list.items():
            for file_desc in file_descs:
                file_name = file_desc["name"]
                file_path = get_cache_file_path(file_name, file_type)

                if verify_file(file_path, file_desc["sha3_256"]):
                    continue
                nothing_changed = False
                # Stream the member to disk instead of reading it fully into
                # memory first.
                with (
                    zipf.open(f"{file_type}/{file_name}", "r") as src,
                    file_path.open("wb") as dst,
                ):
                    shutil.copyfileobj(src, dst)
                # Verify only after the destination is closed, so the data
                # has been flushed to disk.
                if not verify_file(file_path, file_desc["sha3_256"]):
                    logger.critical(
                        "Offline assets package is corrupted, please delete it and try again"
                    )
                    sys.exit(1)
    if not nothing_changed:
        logger.info(f"Offline assets package restored from {input_path}")


def get_offline_assets_tag(file_list: dict | None = None):
    """Return the SHA3-256 hex digest identifying an asset manifest.

    Args:
        file_list: Manifest to fingerprint. When ``None`` it is built with
            ``generate_all_assets_file_list()``.

    Returns:
        The hex digest of the manifest serialized by ``orjson`` with sorted
        keys, two-space indentation and a trailing newline.
    """
    manifest = generate_all_assets_file_list() if file_list is None else file_list

    import orjson

    dump_options = (
        orjson.OPT_APPEND_NEWLINE | orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS
    )
    # noinspection PyTypeChecker
    serialized = orjson.dumps(manifest, option=dump_options)
    return hashlib.sha3_256(serialized).hexdigest()


def generate_offline_assets_package(output_path: Path | None = None):
    """Synchronous wrapper around ``generate_offline_assets_package_async``.

    Args:
        output_path: Forwarded unchanged to the async implementation.

    Returns:
        Whatever ``run_coro`` returns for the packaging coroutine.
    """
    packaging_coro = generate_offline_assets_package_async(output_path)
    return run_coro(packaging_coro)


def restore_offline_assets_package(input_path: Path):
return run_coro(restore_offline_assets_package_async(input_path))


if __name__ == "__main__":
    # Manual entry point: set up verbose logging, then run the warmup.
    # To exercise the packaging helpers by hand, call
    # generate_offline_assets_package() or restore_offline_assets_package()
    # here instead.
    from rich.logging import RichHandler

    logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()])
    # Silence the HTTP client libraries' own debug output.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)
    warmup()

0 comments on commit c0c60bf

Please sign in to comment.