diff --git a/projects/fal/src/fal/toolkit/__init__.py b/projects/fal/src/fal/toolkit/__init__.py
index deb0a4dc..ef0c928b 100644
--- a/projects/fal/src/fal/toolkit/__init__.py
+++ b/projects/fal/src/fal/toolkit/__init__.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from fal.toolkit.file.file import File
+from fal.toolkit.file import File
 from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
 from fal.toolkit.mainify import mainify
 from fal.toolkit.utils import (
diff --git a/projects/fal/src/fal/toolkit/file/file.py b/projects/fal/src/fal/toolkit/file/file.py
index 72bbba0f..ad5ef046 100644
--- a/projects/fal/src/fal/toolkit/file/file.py
+++ b/projects/fal/src/fal/toolkit/file/file.py
@@ -1,17 +1,23 @@
 from __future__ import annotations
 
+from io import BytesIO, FileIO, IOBase
+
 from pathlib import Path
-from typing import Callable
+from typing import Any, Callable
+from urllib.parse import urlparse
 
 from fal.toolkit.file.providers.fal import FalFileRepository, InMemoryRepository
 from fal.toolkit.file.providers.gcp import GoogleStorageRepository
 from fal.toolkit.file.providers.r2 import R2Repository
-from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
+from fal.toolkit.file.types import FileData, FileRepository, RepositoryId, RemoteFileIO
 from fal.toolkit.mainify import mainify
 from pydantic import BaseModel, Field, PrivateAttr
+from pydantic.typing import Optional
 
 FileRepositoryFactory = Callable[[], FileRepository]
 
+DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"
+
 BUILT_IN_REPOSITORIES: dict[RepositoryId, FileRepositoryFactory] = {
     "fal": lambda: FalFileRepository(),
     "in_memory": lambda: InMemoryRepository(),
@@ -20,21 +26,24 @@
 }
 
 
+@mainify
 def get_builtin_repository(id: RepositoryId) -> FileRepository:
     if id not in BUILT_IN_REPOSITORIES.keys():
         raise ValueError(f'"{id}" is not a valid built-in file repository')
     return BUILT_IN_REPOSITORIES[id]()
 
 
-get_builtin_repository.__module__ = "__main__"
-
-DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"
+@mainify
+def get_repository(id: FileRepository | RepositoryId) -> FileRepository:
+    if isinstance(id, FileRepository):
+        return id
+    return get_builtin_repository(id)
 
 
 @mainify
 class File(BaseModel):
-    # public properties
     _file_data: FileData = PrivateAttr()
+
     url: str = Field(
         description="The URL where the file can be downloaded from.",
         examples=["https://url.to/generated/file/z9RV14K95DvU.png"],
     )
@@ -47,31 +56,13 @@ class File(BaseModel):
         description="The name of the file. It will be auto-generated if not provided.",
         examples=["z9RV14K95DvU.png"],
     )
-    file_size: int = Field(
-        description="The size of the file in bytes.", examples=[4404019]
+    file_size: Optional[int] = Field(
+        description="The size of the file in bytes, when available.", examples=[4404019]
     )
 
     def __init__(self, **kwargs):
         if "file_data" in kwargs:
-            data = kwargs.pop("file_data")
-            repository = kwargs.pop("repository", None)
-
-            repo = (
-                repository
-                if isinstance(repository, FileRepository)
-                else get_builtin_repository(repository)
-            )
-            self._file_data = data
-
-            kwargs.update(
-                {
-                    "url": repo.save(data),
-                    "content_type": data.content_type,
-                    "file_name": data.file_name,
-                    "file_size": len(data.data),
-                }
-            )
-
+            self._file_data = kwargs.pop("file_data")
         super().__init__(**kwargs)
 
     @classmethod
@@ -82,11 +73,34 @@ def from_bytes(
         cls,
         data: bytes,
         content_type: str | None = None,
         file_name: str | None = None,
         repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
     ) -> File:
-        return cls(
-            file_data=FileData(data, content_type, file_name),
+        return cls.from_fileobj(
+            BytesIO(data),
+            content_type=content_type,
+            file_name=file_name,
             repository=repository,
         )
 
+    @classmethod
+    def from_fileobj(
+        cls,
+        fileobj: IOBase,
+        content_type: str | None = None,
+        file_name: str | None = None,
+        repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
+    ) -> File:
+        file_data = FileData(fileobj, content_type, file_name)
+
+        file_repository = get_repository(repository)
+        url = file_repository.save(file_data)
+
+        return cls(
+            file_data=file_data,
+            url=url,
+            content_type=file_data.content_type,
+            file_name=file_data.file_name,
+            file_size=file_data.file_size,
+        )
+
     @classmethod
     def from_path(
         cls,
@@ -95,13 +109,61 @@
         repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
     ) -> File:
         file_path = Path(path)
+
         if not file_path.exists():
             raise FileNotFoundError(f"File {file_path} does not exist")
-        with open(file_path, "rb") as f:
-            data = f.read()
-        return File.from_bytes(
-            data, content_type, file_name=file_path.name, repository=repository
+
+        file_data = FileData(FileIO(file_path), content_type, file_path.name)
+
+        file_repository = get_repository(repository)
+        url = file_repository.save(file_data)
+
+        return cls(
+            file_data=file_data,
+            url=url,
+            content_type=file_data.content_type,
+            file_name=file_data.file_name,
+            file_size=file_data.file_size,
+        )
+
+    # Pydantic custom validator for input type conversion
+    @classmethod
+    def __get_validators__(cls):
+        yield cls.__convert_from_str
+
+    @classmethod
+    def __convert_from_str(cls, value: Any):
+        if isinstance(value, str):
+            url = urlparse(value)
+            if url.scheme not in ["http", "https", "data"]:
+                raise ValueError(f"value must be a valid URL")
+            return cls._from_url(url.geturl())
+        return value
+
+    @classmethod
+    def _from_url(
+        cls,
+        url: str,
+    ) -> File:
+        remote_url = RemoteFileIO(url)
+        file_data = FileData(remote_url)
+
+        return cls(
+            file_data=file_data,
+            url=url,
+            content_type=file_data.content_type,
+            file_name=file_data.file_name,
+            file_size=file_data.file_size,
         )
 
     def as_bytes(self) -> bytes:
-        return self._file_data.data
+        content = self._file_data.as_bytes()
+        self.file_size = len(content)
+
+        if not content:
+            raise Exception("File is empty")
+        return content
+
+    def save(self, path: str | Path):
+        file_path = Path(path)
+        file_path.write_bytes(self.as_bytes())
diff --git a/projects/fal/src/fal/toolkit/file/providers/fal.py b/projects/fal/src/fal/toolkit/file/providers/fal.py
index 65a0c661..db0ff7a1 100644
--- a/projects/fal/src/fal/toolkit/file/providers/fal.py
+++ b/projects/fal/src/fal/toolkit/file/providers/fal.py
@@ -47,6 +47,11 @@ def save(self, file: FileData) -> str:
             upload_url = result["upload_url"]
             self._upload_file(upload_url, file)
 
+            # Reset the file pointer to the beginning of the file
+            # so that it can be read again.
+            file.data.seek(0)
+            file.file_size = file.data.tell()
+
             return result["file_url"]
         except HTTPError as e:
             raise FileUploadException(
@@ -58,7 +63,7 @@ def _upload_file(self, upload_url: str, file: FileData):
             upload_url,
             method="PUT",
             data=file.data,
-            headers={"Content-Type": file.content_type},
+            headers={"Content-Type": file.content_type},  # type: ignore
         )
 
         with urlopen(req):
@@ -69,4 +74,7 @@
 @dataclass
 class InMemoryRepository(FileRepository):
     def save(self, file: FileData) -> str:
-        return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
+        return (
+            f"data:{file.content_type};base64,"
+            f'{b64encode(file.as_bytes()).decode("utf-8")}'  # type: ignore
+        )
diff --git a/projects/fal/src/fal/toolkit/file/providers/gcp.py b/projects/fal/src/fal/toolkit/file/providers/gcp.py
index 7a0d1c65..14d925ed 100644
--- a/projects/fal/src/fal/toolkit/file/providers/gcp.py
+++ b/projects/fal/src/fal/toolkit/file/providers/gcp.py
@@ -51,10 +51,17 @@ def bucket(self):
         return self._bucket
 
     def save(self, data: FileData) -> str:
-        destination_path = os.path.join(self.folder, data.file_name)
+        if not data.file_name:
+            raise ValueError("File name is required")
+
+        file_name = data.file_name
+        destination_path = os.path.join(self.folder, file_name)
         gcp_blob = self.bucket.blob(destination_path)
-        gcp_blob.upload_from_string(data.data, content_type=data.content_type)
+
+        with data.data as file:
+            gcp_blob.upload_from_file(file, content_type=data.content_type)
+            data.file_size = file.tell()
 
         if self.url_expiration is None:
             return gcp_blob.public_url
diff --git a/projects/fal/src/fal/toolkit/file/providers/r2.py b/projects/fal/src/fal/toolkit/file/providers/r2.py
index dbd8c85c..427fdafd 100644
--- a/projects/fal/src/fal/toolkit/file/providers/r2.py
+++ b/projects/fal/src/fal/toolkit/file/providers/r2.py
@@ -68,12 +68,18 @@ def bucket(self):
         return self._bucket
 
     def save(self, data: FileData) -> str:
-        destination_path = os.path.join(self.key, data.file_name)
+        if not data.file_name:
+            raise ValueError("File name is required")
+
+        file_name = data.file_name
+
+        destination_path = os.path.join(self.key, file_name)
         s3_object = self.bucket.Object(destination_path)
         s3_object.upload_fileobj(
-            BytesIO(data.data), ExtraArgs={"ContentType": data.content_type}
+            BytesIO(data.data), ExtraArgs={"ContentType": data.content_type}  # type: ignore
         )
+        data.file_size = data.data.tell()
 
         public_url = self._s3_client.generate_presigned_url(
             ClientMethod="get_object",
diff --git a/projects/fal/src/fal/toolkit/file/types.py b/projects/fal/src/fal/toolkit/file/types.py
index 9c09a85c..29b4f1ed 100644
--- a/projects/fal/src/fal/toolkit/file/types.py
+++ b/projects/fal/src/fal/toolkit/file/types.py
@@ -1,26 +1,76 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from io import FileIO, IOBase
 from mimetypes import guess_extension, guess_type
+from os import remove as remove_file
+from tempfile import mkdtemp
 from typing import Literal
 from uuid import uuid4
 
 from fal.toolkit.mainify import mainify
+from fal.toolkit.utils.download_utils import download_file
+
+
+RepositoryId = Literal["fal", "in_memory", "gcp_storage", "r2"]
+
+
+@mainify
+class RemoteFileIO(IOBase):
+    url: str
+    _file: FileIO | None
+
+    def __init__(self, url: str):
+        self.url = url
+        self._file = None
+
+    def _ensure_file_is_downloaded(self):
+        temp_dir = mkdtemp(prefix="fal_file_", suffix="", dir=None)
+        file_path = download_file(self.url, temp_dir)
+
+        self._file = FileIO(file_path, mode="rb")
+
+    def read(self, size: int = -1) -> bytes:
+        if not self.is_downloaded():
+            self._ensure_file_is_downloaded()
+
+        return self._file.read(size)  # type: ignore
+
+    def is_downloaded(self) -> bool:
+        return self._file is not None
+
+    def close(self) -> None:
+        if self._file:
+            self._file.close()
+            remove_file(str(self._file.name))
+        self._file = None
 
 
 @mainify
 class FileData:
-    data: bytes
-    content_type: str
-    file_name: str
+    data: IOBase
+    content_type: str | None = None
+    file_name: str | None = None
+    file_size: int | None = None
+
+    _cached_content: bytes | None = None
 
     def __init__(
-        self, data: bytes, content_type: str | None = None, file_name: str | None = None
+        self,
+        data: IOBase,
+        content_type: str | None = None,
+        file_name: str | None = None,
     ):
         self.data = data
+
         if content_type is None and file_name is not None:
             content_type, _ = guess_type(file_name or "")
+        # If the data is a remote file, try to guess the content type from the url
+        url = getattr(data, "url", None)
+        if url and content_type is None:
+            content_type, _ = guess_type(url)
+        # Ultimately fallback to a generic binary file mime type
         self.content_type = content_type or "application/octet-stream"
@@ -30,8 +80,21 @@ def __init__(
         else:
             self.file_name = file_name
 
+    def as_bytes(self) -> bytes:
+        if self._cached_content:
+            return self._cached_content
 
-RepositoryId = Literal["fal", "in_memory", "gcp_storage", "r2"]
+        content: bytes | str = self.data.read()
+        # For files that open in text mode, convert to bytes
+        if isinstance(content, str):
+            content = content.encode()
+
+        self._cached_content = content
+
+        if content:
+            self.file_size = len(content)
+
+        return content
 
 
 @mainify
diff --git a/projects/fal/src/fal/toolkit/image/image.py b/projects/fal/src/fal/toolkit/image/image.py
index a53d89b0..0b526d96 100644
--- a/projects/fal/src/fal/toolkit/image/image.py
+++ b/projects/fal/src/fal/toolkit/image/image.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
 import io
+from io import BytesIO
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 from fal.toolkit.file.file import DEFAULT_REPOSITORY, File
-from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
+from fal.toolkit.file.types import FileRepository, RepositoryId
 from fal.toolkit.mainify import mainify
 from pydantic import BaseModel, Field
@@ -81,14 +82,8 @@ def from_bytes(  # type: ignore[override]
         file_name: str | None = None,
         repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
     ) -> Image:
-        file_data = FileData(
-            data=data, content_type=f"image/{format}", file_name=file_name
-        )
-        return cls(
-            file_data=file_data,
-            repository=repository,
-            width=size.width if size else None,
-            height=size.height if size else None,
+        return cls.from_fileobj(
+            BytesIO(data), content_type=f"image/{format}", file_name=file_name, repository=repository
         )
 
     @classmethod
diff --git a/projects/fal/src/fal/toolkit/utils/download_utils.py b/projects/fal/src/fal/toolkit/utils/download_utils.py
index 319efef0..59dc9e55 100644
--- a/projects/fal/src/fal/toolkit/utils/download_utils.py
+++ b/projects/fal/src/fal/toolkit/utils/download_utils.py
@@ -71,8 +71,13 @@ def _get_remote_file_properties(url: str) -> tuple[str, int]:
         content_length = int(response.headers.get("Content-Length", -1))
 
     if not file_name:
-        url_path = urlparse(url).path
-        file_name = Path(url_path).name or _hash_url(url)
+        parsed_url = urlparse(url)
+
+        if parsed_url.scheme == "data":
+            file_name = _hash_url(url)
+        else:
+            url_path = parsed_url.path
+            file_name = Path(url_path).name or _hash_url(url)
 
     return file_name, content_length
diff --git a/projects/fal/tests/integration_test.py b/projects/fal/tests/integration_test.py
index db2f28ce..f6cd25f3 100644
--- a/projects/fal/tests/integration_test.py
+++ b/projects/fal/tests/integration_test.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 from pathlib import Path
+import tempfile
 from uuid import uuid4
-
+from fal.toolkit import File
 import fal
 import pytest
 from fal import FalServerlessHost, FalServerlessKeyCredentials, local, sync_dir
@@ -429,3 +430,112 @@ def clone_with_force():
     assert (
         first_repo_stat.st_mtime < third_repo_stat.st_mtime
     ), "The repository should be cloned again with force=True"
+
+
+def fal_file_downloaded(file: File):
+    return file.file_size != None
+
+
+def fal_file_url_matches(file: File, url: str):
+    return file.url == url
+
+
+def fal_file_content_matches(file: File, content: str):
+    return file.as_bytes().decode() == content
+
+
+def test_fal_file_from_path(isolated_client):
+    @isolated_client(requirements=["pydantic==1.10.12"])
+    def fal_file_from_temp(content: str):
+        with tempfile.NamedTemporaryFile() as temp_file:
+            file_path = temp_file.name
+
+            with open(file_path, "w") as fp:
+                fp.write(content)
+
+            return File.from_path(file_path, repository="in_memory")
+
+    file_content = "file-test"
+    file = fal_file_from_temp(file_content)
+
+    assert fal_file_content_matches(file, file_content)
+
+
+def test_fal_file_from_bytes(isolated_client):
+    @isolated_client(requirements=["pydantic==1.10.12"])
+    def fal_file_from_bytes(content: str):
+        return File.from_bytes(content.encode(), repository="in_memory")
+
+    file_content = "file-test"
+    file = fal_file_from_bytes(file_content)
+
+    assert fal_file_content_matches(file, file_content)
+
+
+def test_fal_file_from_fileobj(isolated_client):
+    @isolated_client(requirements=["pydantic==1.10.12"])
+    def fal_file_from_fileobj(content: str):
+        temp_file = tempfile.NamedTemporaryFile(delete=False)
+        with open(temp_file.name, "w+") as fp:
+            fp.write(content)
+
+            fp.seek(0)
+            return File.from_fileobj(fp, repository="in_memory")
+
+    file_content = "file-test"
+    file = fal_file_from_fileobj(file_content)
+
+    assert fal_file_content_matches(file, file_content)
+
+
+def test_fal_file_save(isolated_client):
+    @isolated_client(requirements=["pydantic==1.10.12"])
+    def fal_file_to_local_file(content: str):
+        file = File.from_bytes(content.encode(), repository="in_memory")
+
+        with tempfile.NamedTemporaryFile() as temp_file:
+            file_name = temp_file.name
+            file.save(file_name)
+
+            with open(file_name) as fp:
+                file_content = fp.read()
+
+        return file_content
+
+    file_content = "file-test"
+    saved_file_content = fal_file_to_local_file(file_content)
+
+    assert file_content == saved_file_content
+
+
+@pytest.mark.parametrize(
+    "file_url, expected_content",
+    [
+        (
+            "https://raw.githubusercontent.com/fal-ai/fal/fe0e2a1aa4b46a42a93bad0fbd9aca4aefcb4296/README.md",
+            "projects/fal/README.md",
+        ),
+        ("data:text/plain;charset=UTF-8,fal", "fal"),
+    ],
+)
+def test_fal_file_input(isolated_client, file_url: str, expected_content: str):
+    from pydantic import BaseModel, Field
+
+    class TestInput(BaseModel):
+        file: File = Field()
+
+    @isolated_client(requirements=["pydantic==1.10.12"])
+    def init_file_on_fal(input: TestInput) -> File:
+        return input.file
+
+    test_input = TestInput(file=file_url)
+    file = init_file_on_fal(test_input)
+
+    # File is not downloaded until it is needed
+    assert not fal_file_downloaded(file)
+
+    assert fal_file_url_matches(file, file_url)
+
+    # File will be downloaded when content is accessed
+    assert fal_file_content_matches(file, expected_content)
+    assert fal_file_downloaded(file)
diff --git a/projects/fal/tests/test_stability.py b/projects/fal/tests/test_stability.py
index c023e530..de978da6 100644
--- a/projects/fal/tests/test_stability.py
+++ b/projects/fal/tests/test_stability.py
@@ -490,13 +490,12 @@ def test_fal_storage(isolated_client):
     )
     assert file.as_bytes().decode().endswith("local")
 
-    @isolated_client(serve=True)
+    @isolated_client(requirements=["pydantic==1.10.12"])
     def hello_file():  # Run in the isolated environment
         return File.from_bytes(b"Hello fal storage from isolated", repository="fal")
 
-    local_fn = hello_file.on(serve=False)
-    file = local_fn()
+    file = hello_file()
 
     assert file.url.startswith(
         "https://storage.googleapis.com/isolate-dev-smiling-shark_toolkit_bucket/"
     )