diff --git a/projects/fal/src/fal/toolkit/file/providers/fal.py b/projects/fal/src/fal/toolkit/file/providers/fal.py
index e6105a58..d55836ff 100644
--- a/projects/fal/src/fal/toolkit/file/providers/fal.py
+++ b/projects/fal/src/fal/toolkit/file/providers/fal.py
@@ -205,12 +205,12 @@ class MultipartUpload:
 
     def __init__(
         self,
-        file_path: str | Path,
+        file_name: str,
         chunk_size: int | None = None,
         content_type: str | None = None,
         max_concurrency: int | None = None,
     ) -> None:
-        self.file_path = file_path
+        self.file_name = file_name
         self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
         self.content_type = content_type or "application/octet-stream"
         self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
@@ -230,7 +230,7 @@ def create(self):
             },
             data=json.dumps(
                 {
-                    "file_name": os.path.basename(self.file_path),
+                    "file_name": self.file_name,
                     "content_type": self.content_type,
                 }
             ).encode(),
@@ -244,47 +244,29 @@ def create(self):
                 f"Error initiating upload. Status {exc.status}: {exc.reason}"
             )
 
-    def _upload_part(self, url: str, part_number: int) -> dict:
-        with open(self.file_path, "rb") as f:
-            start = (part_number - 1) * self.chunk_size
-            f.seek(start)
-            data = f.read(self.chunk_size)
-            req = Request(
-                url,
-                method="PUT",
-                headers={"Content-Type": self.content_type},
-                data=data,
-            )
+    def upload_part(self, part_number: int, data: bytes) -> None:
+        url = f"{self._upload_url}&part_number={part_number}"
+
+        req = Request(
+            url,
+            method="PUT",
+            headers={"Content-Type": self.content_type},
+            data=data,
+        )
 
-            try:
-                with urlopen(req) as resp:
-                    return {
+        try:
+            with urlopen(req) as resp:
+                self._parts.append(
+                    {
                         "part_number": part_number,
                         "etag": resp.headers["ETag"],
                     }
-            except HTTPError as exc:
-                raise FileUploadException(
-                    f"Error uploading part {part_number} to {url}. "
-                    f"Status {exc.status}: {exc.reason}"
                 )
-
-    def upload(self) -> None:
-        import concurrent.futures
-
-        parts = math.ceil(os.path.getsize(self.file_path) / self.chunk_size)
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.max_concurrency
-        ) as executor:
-            futures = []
-            for part_number in range(1, parts + 1):
-                upload_url = f"{self._upload_url}&part_number={part_number}"
-                futures.append(
-                    executor.submit(self._upload_part, upload_url, part_number)
-                )
-
-            for future in concurrent.futures.as_completed(futures):
-                entry = future.result()
-                self._parts.append(entry)
+        except HTTPError as exc:
+            raise FileUploadException(
+                f"Error uploading part {part_number} to {url}. "
+                f"Status {exc.status}: {exc.reason}"
+            )
 
     def complete(self):
         url = self._upload_url
@@ -307,6 +289,82 @@ def complete(self):
 
         return self._file_url
 
+    @classmethod
+    def save(
+        cls,
+        file: FileData,
+        chunk_size: int | None = None,
+        max_concurrency: int | None = None,
+    ):
+        import concurrent.futures
+
+        multipart = cls(
+            file.file_name,
+            chunk_size=chunk_size,
+            content_type=file.content_type,
+            max_concurrency=max_concurrency,
+        )
+        multipart.create()
+
+        parts = math.ceil(len(file.data) / multipart.chunk_size)
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=multipart.max_concurrency
+        ) as executor:
+            futures = []
+            for part_number in range(1, parts + 1):
+                start = (part_number - 1) * multipart.chunk_size
+                data = file.data[start : start + multipart.chunk_size]
+                futures.append(
+                    executor.submit(multipart.upload_part, part_number, data)
+                )
+
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
+
+        return multipart.complete()
+
+    @classmethod
+    def save_file(
+        cls,
+        file_path: str | Path,
+        chunk_size: int | None = None,
+        content_type: str | None = None,
+        max_concurrency: int | None = None,
+    ) -> str:
+        import concurrent.futures
+
+        file_name = os.path.basename(file_path)
+        size = os.path.getsize(file_path)
+
+        multipart = cls(
+            file_name,
+            chunk_size=chunk_size,
+            content_type=content_type,
+            max_concurrency=max_concurrency,
+        )
+        multipart.create()
+
+        parts = math.ceil(size / multipart.chunk_size)
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=multipart.max_concurrency
+        ) as executor:
+            futures = []
+            for part_number in range(1, parts + 1):
+
+                def _upload_part(pn: int) -> None:
+                    with open(file_path, "rb") as f:
+                        start = (pn - 1) * multipart.chunk_size
+                        f.seek(start)
+                        data = f.read(multipart.chunk_size)
+                        multipart.upload_part(pn, data)
+
+                futures.append(executor.submit(_upload_part, part_number))
+
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
+
+        return multipart.complete()
+
 
 class InternalMultipartUploadV3:
     MULTIPART_THRESHOLD = 100 * 1024 * 1024
@@ -315,12 +373,12 @@ class InternalMultipartUploadV3:
 
     def __init__(
         self,
-        file_path: str | Path,
+        file_name: str,
         chunk_size: int | None = None,
         content_type: str | None = None,
         max_concurrency: int | None = None,
     ) -> None:
-        self.file_path = file_path
+        self.file_name = file_name
         self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
         self.content_type = content_type or "application/octet-stream"
         self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
@@ -359,7 +417,7 @@ def create(self):
                 **self.auth_headers,
                 "Accept": "application/json",
                 "Content-Type": self.content_type,
-                "X-Fal-File-Name": os.path.basename(self.file_path),
+                "X-Fal-File-Name": self.file_name,
             },
         )
         with urlopen(req) as response:
@@ -373,52 +431,32 @@ def create(self):
             )
 
     @retry(max_retries=5, base_delay=1, backoff_type="exponential", jitter=True)
-    def _upload_part(self, url: str, part_number: int) -> dict:
-        with open(self.file_path, "rb") as f:
-            start = (part_number - 1) * self.chunk_size
-            f.seek(start)
-            data = f.read(self.chunk_size)
-            req = Request(
-                url,
-                method="PUT",
-                headers={
-                    **self.auth_headers,
-                    "Content-Type": self.content_type,
-                },
-                data=data,
-            )
+    def upload_part(self, part_number: int, data: bytes) -> None:
+        url = f"{self.access_url}/multipart/{self.upload_id}/{part_number}"
+
+        req = Request(
+            url,
+            method="PUT",
+            headers={
+                **self.auth_headers,
+                "Content-Type": self.content_type,
+            },
+            data=data,
+        )
 
-            try:
-                with urlopen(req) as resp:
-                    return {
+        try:
+            with urlopen(req) as resp:
+                self._parts.append(
+                    {
                         "partNumber": part_number,
                         "etag": resp.headers["ETag"],
                     }
-            except HTTPError as exc:
-                raise FileUploadException(
-                    f"Error uploading part {part_number} to {url}. "
-                    f"Status {exc.status}: {exc.reason}"
-                )
-
-    def upload(self) -> None:
-        import concurrent.futures
-
-        parts = math.ceil(os.path.getsize(self.file_path) / self.chunk_size)
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.max_concurrency
-        ) as executor:
-            futures = []
-            for part_number in range(1, parts + 1):
-                upload_url = (
-                    f"{self.access_url}/multipart/{self.upload_id}/{part_number}"
-                )
-                futures.append(
-                    executor.submit(self._upload_part, upload_url, part_number)
                 )
-
-            for future in concurrent.futures.as_completed(futures):
-                entry = future.result()
-                self._parts.append(entry)
+        except HTTPError as exc:
+            raise FileUploadException(
+                f"Error uploading part {part_number} to {url}. "
+                f"Status {exc.status}: {exc.reason}"
+            )
 
     def complete(self) -> str:
         url = f"{self.access_url}/multipart/{self.upload_id}/complete"
@@ -442,13 +480,106 @@ def complete(self) -> str:
 
         return self.access_url
 
+    @classmethod
+    def save(
+        cls,
+        file: FileData,
+        chunk_size: int | None = None,
+        max_concurrency: int | None = None,
+    ):
+        import concurrent.futures
+
+        multipart = cls(
+            file.file_name,
+            chunk_size=chunk_size,
+            content_type=file.content_type,
+            max_concurrency=max_concurrency,
+        )
+        multipart.create()
+
+        parts = math.ceil(len(file.data) / multipart.chunk_size)
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=multipart.max_concurrency
+        ) as executor:
+            futures = []
+            for part_number in range(1, parts + 1):
+                start = (part_number - 1) * multipart.chunk_size
+                data = file.data[start : start + multipart.chunk_size]
+                futures.append(
+                    executor.submit(multipart.upload_part, part_number, data)
+                )
+
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
+
+        return multipart.complete()
+
+    @classmethod
+    def save_file(
+        cls,
+        file_path: str | Path,
+        chunk_size: int | None = None,
+        content_type: str | None = None,
+        max_concurrency: int | None = None,
+    ) -> str:
+        import concurrent.futures
+
+        file_name = os.path.basename(file_path)
+        size = os.path.getsize(file_path)
+
+        multipart = cls(
+            file_name,
+            chunk_size=chunk_size,
+            content_type=content_type,
+            max_concurrency=max_concurrency,
+        )
+        multipart.create()
+
+        parts = math.ceil(size / multipart.chunk_size)
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=multipart.max_concurrency
+        ) as executor:
+            futures = []
+            for part_number in range(1, parts + 1):
+
+                def _upload_part(pn: int) -> None:
+                    with open(file_path, "rb") as f:
+                        start = (pn - 1) * multipart.chunk_size
+                        f.seek(start)
+                        data = f.read(multipart.chunk_size)
+                        multipart.upload_part(pn, data)
+
+                futures.append(executor.submit(_upload_part, part_number))
+
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
+
+        return multipart.complete()
+
 
 @dataclass
 class FalFileRepositoryV2(FalFileRepositoryBase):
     @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
     def save(
-        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
+        self,
+        file: FileData,
+        multipart: bool | None = None,
+        multipart_threshold: int | None = None,
+        multipart_chunk_size: int | None = None,
+        multipart_max_concurrency: int | None = None,
+        object_lifecycle_preference: dict[str, str] | None = None,
     ) -> str:
+        if multipart is None:
+            threshold = multipart_threshold or MultipartUpload.MULTIPART_THRESHOLD
+            multipart = len(file.data) > threshold
+
+        if multipart:
+            return MultipartUpload.save(
+                file,
+                chunk_size=multipart_chunk_size,
+                max_concurrency=multipart_max_concurrency,
+            )
+
         token = fal_v2_token_manager.get_token()
         headers = {
             "Authorization": f"{token.token_type} {token.token}",
@@ -475,23 +606,6 @@ def save(
                 f"Error initiating upload. Status {e.status}: {e.reason}"
             )
 
-    def _save_multipart(
-        self,
-        file_path: str | Path,
-        chunk_size: int | None = None,
-        content_type: str | None = None,
-        max_concurrency: int | None = None,
-    ) -> str:
-        multipart = MultipartUpload(
-            file_path,
-            chunk_size=chunk_size,
-            content_type=content_type,
-            max_concurrency=max_concurrency,
-        )
-        multipart.create()
-        multipart.upload()
-        return multipart.complete()
-
     def save_file(
         self,
        file_path: str | Path,
@@ -507,7 +621,7 @@ def save_file(
             multipart = os.path.getsize(file_path) > threshold
 
         if multipart:
-            url = self._save_multipart(
+            url = MultipartUpload.save_file(
                 file_path,
                 chunk_size=multipart_chunk_size,
                 content_type=content_type,
                 max_concurrency=multipart_max_concurrency,
@@ -608,8 +722,27 @@ def _object_lifecycle_headers(
 
     @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
     def save(
-        self, file: FileData, object_lifecycle_preference: dict[str, str] | None
+        self,
+        file: FileData,
+        multipart: bool | None = None,
+        multipart_threshold: int | None = None,
+        multipart_chunk_size: int | None = None,
+        multipart_max_concurrency: int | None = None,
+        object_lifecycle_preference: dict[str, str] | None = None,
     ) -> str:
+        if multipart is None:
+            threshold = (
+                multipart_threshold or InternalMultipartUploadV3.MULTIPART_THRESHOLD
+            )
+            multipart = len(file.data) > threshold
+
+        if multipart:
+            return InternalMultipartUploadV3.save(
+                file,
+                chunk_size=multipart_chunk_size,
+                max_concurrency=multipart_max_concurrency,
+            )
+
         headers = {
             **self.auth_headers,
             "Accept": "application/json",
@@ -640,23 +773,6 @@ def auth_headers(self) -> dict[str, str]:
             "User-Agent": "fal/0.1.0",
         }
 
-    def _save_multipart(
-        self,
-        file_path: str | Path,
-        chunk_size: int | None = None,
-        content_type: str | None = None,
-        max_concurrency: int | None = None,
-    ) -> str:
-        multipart = InternalMultipartUploadV3(
-            file_path,
-            chunk_size=chunk_size,
-            content_type=content_type,
-            max_concurrency=max_concurrency,
-        )
-        multipart.create()
-        multipart.upload()
-        return multipart.complete()
-
     def save_file(
         self,
         file_path: str | Path,
@@ -672,7 +788,7 @@ def save_file(
             multipart = os.path.getsize(file_path) > threshold
 
         if multipart:
-            url = self._save_multipart(
+            url = InternalMultipartUploadV3.save_file(
                 file_path,
                 chunk_size=multipart_chunk_size,
                 content_type=content_type,
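
Usage sketch (editor's note, not part of the patch): the new classmethods fold the old create()/upload()/complete() sequence into a single call and decouple multipart uploads from the filesystem — save() takes an in-memory FileData, while save_file() wraps an on-disk path and reads each part lazily inside its worker thread. The repositories' save()/save_file() now route through these automatically when the payload exceeds the multipart threshold (100 MiB for the V3 uploader) or when multipart=True is passed. A minimal sketch, assuming fal credentials are configured in the environment and that FileData accepts the data/content_type/file_name fields this module reads from it; the FileData import path is a guess.

    from fal.toolkit.file.providers.fal import MultipartUpload
    from fal.toolkit.file.file import FileData  # assumed import path

    # In-memory payload: save() slices file.data into chunk_size parts and
    # uploads them concurrently via upload_part(); 25 MiB at 10 MiB parts
    # yields three PUTs.
    url = MultipartUpload.save(
        FileData(
            data=b"\0" * (25 * 1024 * 1024),  # dummy payload
            content_type="application/octet-stream",
            file_name="blob.bin",
        ),
        chunk_size=10 * 1024 * 1024,
        max_concurrency=8,
    )

    # On-disk payload: each worker re-opens the file, seeks to its slice,
    # and reads only chunk_size bytes, so the file is never fully in memory.
    url = MultipartUpload.save_file("video.mp4", max_concurrency=8)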