diff --git a/biobricks/dvc_fetcher.py b/biobricks/dvc_fetcher.py index 5dfd128..e567a4c 100644 --- a/biobricks/dvc_fetcher.py +++ b/biobricks/dvc_fetcher.py @@ -1,8 +1,7 @@ import biobricks.checks import biobricks.config from biobricks.logger import logger - - +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass import requests, threading, time, shutil, os import signal @@ -13,91 +12,42 @@ def signal_handler(signum, frame, interrupt_event): interrupt_event.set() logger.info("Interrupt signal received. Attempting to terminate downloads gracefully...") -class PositionManager: - def __init__(self): - self.available_positions = [] - self.lock = threading.Lock() - self.max_position = 0 - - def get_position(self): - with self.lock: - if self.available_positions: - return self.available_positions.pop(0) - else: - self.max_position += 1 - return self.max_position - - def release_position(self, position): - with self.lock: - self.available_positions.append(position) - self.available_positions.sort() - -class DownloadThread(threading.Thread): - - def __init__(self, url, total_progress_bar, path, headers, position_manager, semaphore, interrupt_event): - super(DownloadThread, self).__init__() - self.url = url - self.total_progress_bar = total_progress_bar - self.path = path - self.headers = headers - self.position_manager = position_manager - self.semaphore = semaphore - self.interrupt_event = interrupt_event - - def run(self): - position = self.position_manager.get_position() - self.path.parent.mkdir(parents=True, exist_ok=True) - try: - response = requests.get(self.url, stream=True, headers=self.headers) - response.raise_for_status() - total_size = int(response.headers.get('content-length', 0)) - block_size = 1024 - - with tqdm(total=total_size, unit='iB', unit_scale=True, disable=False, desc=str(self.path), position=position, leave=False) as progress: - with open(self.path, 'wb') as file: - for data in response.iter_content(chunk_size=block_size): - if self.interrupt_event.is_set(): # Check if the thread should stop - logger.info(f"Stopping download of {self.url}") - return # Exit the thread gracefully - if data: - file.write(data) - progress.update(len(data)) - self.total_progress_bar.update(len(data)) - finally: - self.semaphore.release() # Release the semaphore when the thread is done - self.position_manager.release_position(position) @dataclass class DownloadManager: + headers: dict skip_existing: bool = False progress_bar : tqdm = None active_threads : int = 0 interrupt_event : threading.Event = threading.Event() - def download_files(self, urls, paths, total_size, max_threads=4): + def exec_task(self, url, path): + response = requests.get(url, stream=True, headers=self.headers) + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + block_size = 1024 + with tqdm(total=total_size, unit='iB', unit_scale=True, disable=False, desc=str(path), leave=False) as progress: + with open(path, 'wb') as file: + for data in response.iter_content(chunk_size=block_size): + if self.interrupt_event.is_set(): # Check if the thread should stop + logger.info(f"Stopping download of {self.url}") + return # Exit the thread gracefully + if data: + file.write(data) + progress.update(len(data)) + self.total_progress_bar.update(len(data)) + + def download_exec(self, urls, paths, max_threads=4): signal.signal(signal.SIGINT, lambda signum, frame: signal_handler(signum, frame, self.interrupt_event)) - - self.progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, position=0, desc="Overall Progress") - position_manager = PositionManager() - semaphore = threading.Semaphore(max_threads) - threads = [] - - for url, path in zip(urls, paths): - semaphore.acquire() # Block until a semaphore permit is available - if self.interrupt_event.is_set(): - logger.info("Download process interrupted. Waiting for ongoing downloads to complete...") - semaphore.release() - break - thread = DownloadThread(url, self.progress_bar, path, {'BBToken': biobricks.config.token()}, position_manager, semaphore, self.interrupt_event) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - self.progress_bar.close() - print(f"\n{len(paths)} files downloaded successfully!") - + with ThreadPoolExecutor(max_threads) as exec: + exec_args = dict(zip(urls, paths)) + futures = {exec.submit(self.exec_task, url, path) for url, path in exec_args} + for future in as_completed(futures): + try: + data = future.result() + except Exception as e: + logger.warning("Exception occurred while downloading brick.") + class DVCFetcher: @@ -176,8 +126,8 @@ def fetch_outs(self, brick, prefixes=['brick/', 'data/']) -> tuple[list[dict], i # download files cache_paths = [self._remote_url_to_cache_path(url) for url in urls] - downloader = DownloadManager() - downloader.download_files(urls, cache_paths, total_size) + downloader = DownloadManager(headers = {'BBToken': biobricks.config.token()}) + downloader.download_exec(urls, cache_paths, 4) # build a symlink between each cache_path and its corresponding path brick_paths = [brick.path() / path for path in paths]