diff --git a/bin/recompress-raw-mime.py b/bin/recompress-raw-mime.py index bd8ce62b1..b5b40290d 100755 --- a/bin/recompress-raw-mime.py +++ b/bin/recompress-raw-mime.py @@ -30,6 +30,7 @@ DEFAULT_RECOMPRESS_BATCH_SIZE = 100 DEFAULT_BATCH_SIZE = 1000 MAX_RECOMPRESS_BATCH_BYTES = 100 * 1024 * 1024 # 100 MB +MAX_RECOMPRESS_PARALLEL_BYTES = 500 * 1024 * 1024 # 500 MB class Resolution(enum.Enum): @@ -39,7 +40,7 @@ class Resolution(enum.Enum): # https://stackoverflow.com/questions/73395864/how-do-i-wait-when-all-threadpoolexecutor-threads-are-busy -class AvailableThreadPoolExecutor(ThreadPoolExecutor): +class RecompressThreadPoolExecutor(ThreadPoolExecutor): """ThreadPoolExecutor that keeps track of the number of available workers. Refs: @@ -50,7 +51,7 @@ def __init__( self, max_workers=None, thread_name_prefix="", initializer=None, initargs=() ): super().__init__(max_workers, thread_name_prefix, initializer, initargs) - self._running_worker_futures: set[Future] = set() + self._running_worker_futures: dict[Future, int] = {} @property def available_workers(self) -> int: @@ -69,16 +70,21 @@ def wait_for_available_worker(self, timeout: "float | None" = None) -> None: start_time = time.monotonic() while True: - if self.available_workers > 0: + if ( + self.available_workers > 0 + and sum(self._running_worker_futures.values()) + < MAX_RECOMPRESS_PARALLEL_BYTES + ): return if timeout is not None and time.monotonic() - start_time > timeout: raise TimeoutError time.sleep(0.1) def submit(self, fn, /, *args, **kwargs): + size = sum(args[0].values()) f = super().submit(fn, *args, **kwargs) - self._running_worker_futures.add(f) - f.add_done_callback(self._running_worker_futures.remove) + self._running_worker_futures[f] = size + f.add_done_callback(self._running_worker_futures.pop) return f @@ -174,14 +180,14 @@ def overwrite_parallel(compressed_raw_mime_by_sha256: "dict[str, bytes]") -> Non def recompress_batch( - recompress_sha256s: "set[str]", *, dry_run=True, compression_level: int = 3 + recompress_sha256s: "dict[str, int]", *, dry_run=True, compression_level: int = 3 ) -> None: if not recompress_sha256s: return data_by_sha256 = { data_sha256: data - for data_sha256, data in download_parallel(recompress_sha256s) + for data_sha256, data in download_parallel(set(recompress_sha256s)) if data is not None and not data.startswith(blockstore.ZSTD_MAGIC_NUMBER_PREFIX) } @@ -306,7 +312,7 @@ def shutdown(signum, frame): assert batch_size > 0 assert recompress_batch_size > 0 - recompress_executor = AvailableThreadPoolExecutor( + recompress_executor = RecompressThreadPoolExecutor( max_workers=recompress_executor_workers ) @@ -322,7 +328,7 @@ def shutdown(signum, frame): max_size, ) - recompress_sha256s = set() + recompress_sha256s: dict[str, int] = {} recompress_bytes = 0 max_id = None @@ -358,7 +364,7 @@ def shutdown(signum, frame): print(*print_arguments) if resolution is Resolution.RECOMPRESS: - recompress_sha256s.add(message.data_sha256) + recompress_sha256s[message.data_sha256] = message.size recompress_bytes += message.size if (