refactor: fal file #16

Closed · wants to merge 10 commits
2 changes: 1 addition & 1 deletion projects/fal/src/fal/toolkit/__init__.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from fal.toolkit.file.file import File
from fal.toolkit.file import File
badayvedat marked this conversation as resolved.
from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
from fal.toolkit.mainify import mainify
from fal.toolkit.utils import (
130 changes: 96 additions & 34 deletions projects/fal/src/fal/toolkit/file/file.py
@@ -1,17 +1,23 @@
from __future__ import annotations

from io import BytesIO, FileIO, IOBase

from pathlib import Path
from typing import Callable
from typing import Any, Callable
from urllib.parse import urlparse

from fal.toolkit.file.providers.fal import FalFileRepository, InMemoryRepository
from fal.toolkit.file.providers.gcp import GoogleStorageRepository
from fal.toolkit.file.providers.r2 import R2Repository
from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
from fal.toolkit.file.types import FileData, FileRepository, RepositoryId, RemoteFileIO
from fal.toolkit.mainify import mainify
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.typing import Optional

FileRepositoryFactory = Callable[[], FileRepository]

DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"

BUILT_IN_REPOSITORIES: dict[RepositoryId, FileRepositoryFactory] = {
"fal": lambda: FalFileRepository(),
"in_memory": lambda: InMemoryRepository(),
@@ -20,21 +26,24 @@
}


@mainify
def get_builtin_repository(id: RepositoryId) -> FileRepository:
if id not in BUILT_IN_REPOSITORIES.keys():
raise ValueError(f'"{id}" is not a valid built-in file repository')
return BUILT_IN_REPOSITORIES[id]()


get_builtin_repository.__module__ = "__main__"

DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"
@mainify
def get_repository(id: FileRepository | RepositoryId) -> FileRepository:
if isinstance(id, FileRepository):
return id
return get_builtin_repository(id)


@mainify
class File(BaseModel):
# public properties
_file_data: FileData = PrivateAttr()

url: str = Field(
description="The URL where the file can be downloaded from.",
examples=["https://url.to/generated/file/z9RV14K95DvU.png"],
@@ -47,31 +56,13 @@ class File(BaseModel):
description="The name of the file. It will be auto-generated if not provided.",
examples=["z9RV14K95DvU.png"],
)
file_size: int = Field(
description="The size of the file in bytes.", examples=[4404019]
file_size: Optional[int] = Field(
description="The size of the file in bytes, when available.", examples=[4404019]
)

def __init__(self, **kwargs):
if "file_data" in kwargs:
data = kwargs.pop("file_data")
repository = kwargs.pop("repository", None)

repo = (
repository
if isinstance(repository, FileRepository)
else get_builtin_repository(repository)
)
self._file_data = data

kwargs.update(
{
"url": repo.save(data),
"content_type": data.content_type,
"file_name": data.file_name,
"file_size": len(data.data),
}
)

self._file_data = kwargs.pop("file_data")
super().__init__(**kwargs)

@classmethod
@@ -82,11 +73,34 @@ def from_bytes(
file_name: str | None = None,
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
) -> File:
return cls(
file_data=FileData(data, content_type, file_name),
return cls.from_fileobj(
BytesIO(data),
content_type=content_type,
file_name=file_name,
repository=repository,
)

@classmethod
def from_fileobj(
cls,
fileobj: IOBase,
content_type: str | None = None,
file_name: str | None = None,
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
) -> File:
file_data = FileData(fileobj, content_type, file_name)

file_repository = get_repository(repository)
url = file_repository.save(file_data)

return cls(
file_data=file_data,
url=url,
content_type=file_data.content_type,
file_name=file_data.file_name,
file_size=file_data.file_size,
)

@classmethod
def from_path(
cls,
@@ -95,13 +109,61 @@ def from_path(
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
) -> File:
file_path = Path(path)

if not file_path.exists():
raise FileNotFoundError(f"File {file_path} does not exist")
with open(file_path, "rb") as f:
data = f.read()
return File.from_bytes(
data, content_type, file_name=file_path.name, repository=repository

file_data = FileData(FileIO(file_path), content_type, file_path.name)

file_repository = get_repository(repository)
url = file_repository.save(file_data)

return cls(
file_data=file_data,
url=url,
content_type=file_data.content_type,
file_name=file_data.file_name,
file_size=file_data.file_size,
)

# Pydantic custom validator for input type conversion
@classmethod
def __get_validators__(cls):
yield cls.__convert_from_str

@classmethod
def __convert_from_str(cls, value: Any):
if isinstance(value, str):
url = urlparse(value)
if url.scheme not in ["http", "https", "data"]:
raise ValueError(f"value must be a valid URL")
return cls._from_url(url.geturl())
return value

@classmethod
def _from_url(
cls,
url: str,
) -> File:
remote_url = RemoteFileIO(url)
file_data = FileData(remote_url)

return cls(
file_data=file_data,
url=url,
content_type=file_data.content_type,
file_name=file_data.file_name,
file_size=file_data.file_size,
)

def as_bytes(self) -> bytes:
return self._file_data.data
content = self._file_data.as_bytes()
self.file_size = len(content)

if not content:
raise Exception("File is empty")
return content

def save(self, path: str | Path):
file_path = Path(path)
file_path.write_bytes(self.as_bytes())
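
For reference, a minimal usage sketch of the refactored constructors on this branch. It assumes the built-in "in_memory" repository (so no upload credentials or network are needed), a hypothetical local file /tmp/example.png, and the placeholder URL from the field examples above:

from pydantic import BaseModel

from fal.toolkit import File

# Raw bytes are wrapped in a BytesIO and streamed to the repository
# through the new from_fileobj() path.
f = File.from_bytes(
    b"hello fal",
    content_type="text/plain",
    file_name="hello.txt",
    repository="in_memory",
)
print(f.url[:40], f.content_type, f.file_size)

# Files on disk are now wrapped in a FileIO instead of being read
# fully into memory before the upload (the path here is hypothetical).
g = File.from_path("/tmp/example.png", repository="in_memory")

# Plain URL strings are accepted wherever a File field is declared;
# the new validator converts them lazily via _from_url().
class Input(BaseModel):
    image: File

inp = Input(image="https://url.to/generated/file/z9RV14K95DvU.png")
content = inp.image.as_bytes()  # the download happens only here
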
12 changes: 10 additions & 2 deletions projects/fal/src/fal/toolkit/file/providers/fal.py
@@ -47,6 +52,11 @@ def save(self, file: FileData) -> str:
upload_url = result["upload_url"]
self._upload_file(upload_url, file)

# Reset the file pointer to the beginning of the file
# so that it can be read again.
file.data.seek(0)
file.file_size = file.data.tell()

return result["file_url"]
except HTTPError as e:
raise FileUploadException(
@@ -58,7 +63,7 @@ def _upload_file(self, upload_url: str, file: FileData):
upload_url,
method="PUT",
data=file.data,
headers={"Content-Type": file.content_type},
headers={"Content-Type": file.content_type}, # type: ignore
)

with urlopen(req):
@@ -69,4 +74,7 @@ def _upload_file(self, upload_url: str, file: FileData):
@dataclass
class InMemoryRepository(FileRepository):
def save(self, file: FileData) -> str:
return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
return (
f"data:{file.content_type};base64,"
f'{b64encode(file.as_bytes()).decode("utf-8")}' # type: ignore
)
11 changes: 9 additions & 2 deletions projects/fal/src/fal/toolkit/file/providers/gcp.py
@@ -51,10 +62,17 @@ def bucket(self):
return self._bucket

def save(self, data: FileData) -> str:
destination_path = os.path.join(self.folder, data.file_name)
if not data.file_name:
raise ValueError("File name is required")

file_name = data.file_name
destination_path = os.path.join(self.folder, file_name)

gcp_blob = self.bucket.blob(destination_path)
gcp_blob.upload_from_string(data.data, content_type=data.content_type)

with data.data as file:
gcp_blob.upload_from_file(file, content_type=data.content_type)
Comment on lines +62 to +63

Contributor: Will this work if filename == "" ?

badayvedat (Contributor, Author), Jan 23, 2024: good catch! we should've raised an error in that case

data.file_size = file.tell()

if self.url_expiration is None:
return gcp_blob.public_url
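
A small sketch of the case raised in the review thread above, assuming a repository instance repo (e.g. a GoogleStorageRepository or R2Repository) whose construction is elided here. An explicit empty file name is preserved by FileData, so the new provider-level guard is what turns it into a clear error instead of a destination key that collapses to the folder prefix:

from io import BytesIO

from fal.toolkit.file.types import FileData

# file_name="" is kept as-is by FileData (only None triggers the
# auto-generated name), which is exactly the case the reviewer asked about.
data = FileData(BytesIO(b"hello"), content_type="text/plain", file_name="")
assert data.file_name == ""

# With the guard added in this PR, the provider rejects it up front:
# repo.save(data)  ->  ValueError: File name is required
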
10 changes: 8 additions & 2 deletions projects/fal/src/fal/toolkit/file/providers/r2.py
@@ -68,12 +73,18 @@ def bucket(self):
return self._bucket

def save(self, data: FileData) -> str:
destination_path = os.path.join(self.key, data.file_name)
if not data.file_name:
raise ValueError("File name is required")

file_name = data.file_name

destination_path = os.path.join(self.key, file_name)

s3_object = self.bucket.Object(destination_path)
s3_object.upload_fileobj(
BytesIO(data.data), ExtraArgs={"ContentType": data.content_type}
BytesIO(data.data), ExtraArgs={"ContentType": data.content_type} # type: ignore
)
data.file_size = data.data.tell()

public_url = self._s3_client.generate_presigned_url(
ClientMethod="get_object",
73 changes: 68 additions & 5 deletions projects/fal/src/fal/toolkit/file/types.py
@@ -1,26 +1,76 @@
from __future__ import annotations

from dataclasses import dataclass
from io import FileIO, IOBase
from mimetypes import guess_extension, guess_type
from os import remove as remove_file
from tempfile import mkdtemp
from typing import Literal
from uuid import uuid4

from fal.toolkit.mainify import mainify
from fal.toolkit.utils.download_utils import download_file


RepositoryId = Literal["fal", "in_memory", "gcp_storage", "r2"]


@mainify
class RemoteFileIO(IOBase):
url: str
_file: FileIO | None

def __init__(self, url: str):
self.url = url
self._file = None

def _ensure_file_is_downloaded(self):
temp_dir = mkdtemp(prefix="fal_file_", suffix="", dir=None)
file_path = download_file(self.url, temp_dir)

self._file = FileIO(file_path, mode="rb")

def read(self, size: int = -1) -> bytes:
if not self.is_downloaded():
self._ensure_file_is_downloaded()

return self._file.read(size) # type: ignore

def is_downloaded(self) -> bool:
return self._file is not None

def close(self) -> None:
if self._file:
self._file.close()
remove_file(str(self._file.name))
self._file = None


@mainify
class FileData:
data: bytes
content_type: str
file_name: str
data: IOBase
content_type: str | None = None
file_name: str | None = None
file_size: int | None = None

_cached_content: bytes | None = None

def __init__(
self, data: bytes, content_type: str | None = None, file_name: str | None = None
self,
data: IOBase,
content_type: str | None = None,
file_name: str | None = None,
):
self.data = data

if content_type is None and file_name is not None:
content_type, _ = guess_type(file_name or "")

# If the data is a remote file, try to guess the content type from the url
url = getattr(data, "url", None)
if url and content_type is None:
content_type, _ = guess_type(url)

# Ultimately fallback to a generic binary file mime type
self.content_type = content_type or "application/octet-stream"

@@ -30,8 +80,21 @@ def __init__(
else:
self.file_name = file_name

def as_bytes(self) -> bytes:
if self._cached_content:
return self._cached_content

RepositoryId = Literal["fal", "in_memory", "gcp_storage", "r2"]
content: bytes | str = self.data.read()
# For files that open in text mode, convert to bytes
if isinstance(content, str):
content = content.encode()

self._cached_content = content

if content:
self.file_size = len(content)

return content


@mainify
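
Finally, a minimal sketch of the lazy remote-file behaviour added in types.py, assuming url points at a file that can actually be downloaded (the value below is a placeholder); nothing is fetched until the bytes are first read:

from fal.toolkit.file.types import FileData, RemoteFileIO

url = "https://url.to/generated/file/z9RV14K95DvU.png"  # placeholder URL

remote = RemoteFileIO(url)
assert not remote.is_downloaded()      # constructing it does not download

data = FileData(remote)                # content type is guessed from the URL
content = data.as_bytes()              # first read downloads to a temp dir
assert remote.is_downloaded()
assert data.file_size == len(content)  # size is filled in after the read

remote.close()                         # closes and removes the temp file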