Refactor threading in DownloadManager #35

Closed · wants to merge 3 commits
110 changes: 30 additions & 80 deletions biobricks/dvc_fetcher.py
@@ -1,8 +1,7 @@
import biobricks.checks
import biobricks.config
from biobricks.logger import logger


from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
import requests, threading, time, shutil, os
import signal
@@ -13,91 +12,42 @@ def signal_handler(signum, frame, interrupt_event):
    interrupt_event.set()
    logger.info("Interrupt signal received. Attempting to terminate downloads gracefully...")

class PositionManager:
    def __init__(self):
        self.available_positions = []
        self.lock = threading.Lock()
        self.max_position = 0

    def get_position(self):
        with self.lock:
            if self.available_positions:
                return self.available_positions.pop(0)
            else:
                self.max_position += 1
                return self.max_position

    def release_position(self, position):
        with self.lock:
            self.available_positions.append(position)
            self.available_positions.sort()
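
A brief note on what this (now-removed) helper was for: it hands out and recycles integer slots so each per-file tqdm bar keeps a stable terminal row. A minimal sketch of the idea, with made-up slot names for illustration:

# Illustrative only, not part of the PR: how PositionManager slots were consumed.
pm = PositionManager()

slot_a = pm.get_position()   # -> 1: first per-file bar renders on terminal row 1
slot_b = pm.get_position()   # -> 2: second bar renders on the row below
pm.release_position(slot_a)  # row 1 becomes reusable
slot_c = pm.get_position()   # -> 1 again: the freed row is handed out first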

class DownloadThread(threading.Thread):

    def __init__(self, url, total_progress_bar, path, headers, position_manager, semaphore, interrupt_event):
        super(DownloadThread, self).__init__()
        self.url = url
        self.total_progress_bar = total_progress_bar
        self.path = path
        self.headers = headers
        self.position_manager = position_manager
        self.semaphore = semaphore
        self.interrupt_event = interrupt_event

    def run(self):
        position = self.position_manager.get_position()
        self.path.parent.mkdir(parents=True, exist_ok=True)
        try:
            response = requests.get(self.url, stream=True, headers=self.headers)
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            block_size = 1024

            with tqdm(total=total_size, unit='iB', unit_scale=True, disable=False, desc=str(self.path), position=position, leave=False) as progress:
                with open(self.path, 'wb') as file:
                    for data in response.iter_content(chunk_size=block_size):
                        if self.interrupt_event.is_set():  # Check if the thread should stop
                            logger.info(f"Stopping download of {self.url}")
                            return  # Exit the thread gracefully
                        if data:
                            file.write(data)
                            progress.update(len(data))
                            self.total_progress_bar.update(len(data))
        finally:
            self.semaphore.release()  # Release the semaphore when the thread is done
            self.position_manager.release_position(position)

@dataclass
class DownloadManager:
    headers: dict
    skip_existing: bool = False
    progress_bar: tqdm = None
    active_threads: int = 0
    interrupt_event: threading.Event = threading.Event()

    def download_files(self, urls, paths, total_size, max_threads=4):
    def exec_task(self, url, path):
        path.parent.mkdir(parents=True, exist_ok=True)  # ensure the target directory exists
        response = requests.get(url, stream=True, headers=self.headers)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024
        with tqdm(total=total_size, unit='iB', unit_scale=True, disable=False, desc=str(path), leave=False) as progress:
            with open(path, 'wb') as file:
                for data in response.iter_content(chunk_size=block_size):
                    if self.interrupt_event.is_set():  # Check if the task should stop
                        logger.info(f"Stopping download of {url}")
                        return  # Exit the task gracefully
                    if data:
                        file.write(data)
                        progress.update(len(data))
                        if self.progress_bar:  # overall progress bar, if one was created
                            self.progress_bar.update(len(data))

    def download_exec(self, urls, paths, max_threads=4):
        signal.signal(signal.SIGINT, lambda signum, frame: signal_handler(signum, frame, self.interrupt_event))

        self.progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, position=0, desc="Overall Progress")
        position_manager = PositionManager()
        semaphore = threading.Semaphore(max_threads)
        threads = []

        for url, path in zip(urls, paths):
            semaphore.acquire()  # Block until a semaphore permit is available
            if self.interrupt_event.is_set():
                logger.info("Download process interrupted. Waiting for ongoing downloads to complete...")
                semaphore.release()
                break
            thread = DownloadThread(url, self.progress_bar, path, {'BBToken': biobricks.config.token()}, position_manager, semaphore, self.interrupt_event)
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        self.progress_bar.close()
        print(f"\n{len(paths)} files downloaded successfully!")

        with ThreadPoolExecutor(max_threads) as executor:
            futures = {executor.submit(self.exec_task, url, path) for url, path in zip(urls, paths)}
            for future in as_completed(futures):
                try:
                    future.result()  # re-raise any exception from the worker
                except Exception as e:
                    logger.warning(f"Exception occurred while downloading brick: {e}")


class DVCFetcher:

@@ -176,8 +126,8 @@ def fetch_outs(self, brick, prefixes=['brick/', 'data/']) -> tuple[list[dict], i

        # download files
        cache_paths = [self._remote_url_to_cache_path(url) for url in urls]
        downloader = DownloadManager()
        downloader.download_files(urls, cache_paths, total_size)
        downloader = DownloadManager(headers = {'BBToken': biobricks.config.token()})
        downloader.download_exec(urls, cache_paths, 4)

        # build a symlink between each cache_path and its corresponding path
        brick_paths = [brick.path() / path for path in paths]
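
As a quick usage note on the call site above: with this refactor the fetcher builds a DownloadManager carrying the auth header and hands it the URL/cache-path pairs. A rough, hedged sketch; the import path, URLs, and cache paths are assumptions for illustration, not from the PR:

# Illustrative only: invoking the refactored DownloadManager.
from pathlib import Path
import biobricks.config
from biobricks.dvc_fetcher import DownloadManager   # assumed import path for this module

urls = ["https://example.org/files/md5/ab/cdef"]                  # hypothetical remote URLs
cache_paths = [Path("/tmp/biobricks-cache/ab/cdef")]              # hypothetical cache paths

downloader = DownloadManager(headers={'BBToken': biobricks.config.token()})
downloader.download_exec(urls, cache_paths, max_threads=4)        # up to 4 concurrent downloads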