Skip to content

Commit

Permalink
refactor: fal image
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Rochetti <[email protected]>
  • Loading branch information
badayvedat and drochetti committed Dec 26, 2023
1 parent 6347826 commit 8b9357a
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 36 deletions.
71 changes: 69 additions & 2 deletions projects/fal/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions projects/fal/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ types-python-dateutil = "^2.8.0"
# For 3.9 and earlier, importlib-metadata's newer versions are included in the standard library.
importlib-metadata = { version = ">=4.4", python = "<3.10" }
boto3 = "^1.33.8"
pillow = "^10.1.0"

[tool.poetry.group.dev.dependencies]
openapi-python-client = "^0.14.1"
Expand Down
2 changes: 1 addition & 1 deletion projects/fal/src/fal/toolkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from fal.toolkit.file import CompressedFile, File
from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
from fal.toolkit.image import Image, ImageSizeInput, get_image_size
from fal.toolkit.mainify import mainify
from fal.toolkit.utils import (
FAL_MODEL_WEIGHTS_DIR,
Expand Down
46 changes: 30 additions & 16 deletions projects/fal/src/fal/toolkit/image/image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import io
from io import BytesIO
from typing import TYPE_CHECKING, Literal, Optional, Union

from fal.toolkit.file.file import DEFAULT_REPOSITORY, File
from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
from fal.toolkit.file.types import FileRepository, RepositoryId
from fal.toolkit.mainify import mainify
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -44,7 +43,7 @@ class ImageSize(BaseModel):

ImageSizeInput = Union[ImageSize, ImageSizePreset]


@mainify
def get_image_size(source: ImageSizeInput) -> ImageSize:
if isinstance(source, ImageSize):
return source
Expand All @@ -54,8 +53,6 @@ def get_image_size(source: ImageSizeInput) -> ImageSize:
raise TypeError(f"Invalid value for ImageSize: {source}")


get_image_size.__module__ = "__main__"

ImageFormat = Literal["png", "jpeg", "jpg", "webp", "gif"]


Expand All @@ -77,21 +74,25 @@ class Image(File):
def from_bytes( # type: ignore[override]
cls,
data: bytes,
format: ImageFormat,
size: ImageSize | None = None,
format: ImageFormat | None = None,
file_name: str | None = None,
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
) -> Image:
file_data = FileData(
data=BytesIO(data), content_type=f"image/{format}", file_name=file_name
)
return cls(
file_data=file_data,
from PIL import Image as PILImage

pil_image = PILImage.open(BytesIO(data))

return cls.from_pil(
pil_image=pil_image,
format=format,
file_name=file_name,
repository=repository,
width=size.width if size else None,
height=size.height if size else None,
)

@classmethod
def _from_url(cls, url: str):
return super()._from_url(url)

@classmethod
def from_pil(
cls,
Expand All @@ -101,9 +102,12 @@ def from_pil(
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
) -> Image:
size = ImageSize(width=pil_image.width, height=pil_image.height)

if format is None:
format = pil_image.format or "png" # type: ignore[assignment]
assert format # for type checker

content_type = f"image/{format}"

saving_options = {}
if format == "png":
Expand All @@ -113,8 +117,18 @@ def from_pil(
# efficiently.
saving_options["compress_level"] = 1

with io.BytesIO() as f:
with BytesIO() as f:
pil_image.save(f, format=format, **saving_options)
raw_image = f.getvalue()

return cls.from_bytes(raw_image, format, size, file_name, repository)
return super().from_bytes(
data=raw_image,
repository=repository,
content_type=content_type,
file_name=file_name,
)

def to_pil(self) -> PILImage.Image:
from PIL import Image as PILImage
image_buffer = BytesIO(self.as_bytes())
return PILImage.open(image_buffer)
141 changes: 124 additions & 17 deletions projects/fal/tests/toolkit/image_test_requires_pil.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,139 @@
from __future__ import annotations

import base64
import io
from base64 import b64encode
from io import BytesIO

from fal.toolkit.image.image import Image
import pytest
from fal.toolkit import Image, mainify
from PIL import Image as PILImage
from pydantic import BaseModel, Field


# taken from chatgpt
def images_are_equal(img1: PILImage.Image, img2: PILImage.Image) -> bool:
if img1.size != img2.size:
return False
@mainify
def get_image(as_pil: bool = True):
pil_image = PILImage.new("RGB", (1, 1), (255, 255, 255))
if as_pil:
return pil_image

return pil_image_to_bytes(pil_image)


@mainify
def pil_image_to_bytes(image: PILImage.Image) -> bytes:
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
return image_bytes.getvalue()


def fal_image_downloaded(image: Image):
return image.file_size != None


def fal_image_url_matches(image: Image, url: str):
return image.url == url


def fal_image_content_matches(image: Image, content: bytes):
image1 = PILImage.open(BytesIO(image.as_bytes()))
image2 = PILImage.open(BytesIO(content))
return images_are_equal(image1, image2)


@mainify
def image_to_data_uri(image: PILImage.Image) -> str:
image_bytes = pil_image_to_bytes(image)
b64_encoded = b64encode(image_bytes).decode("utf-8")
return f"data:image/png;base64,{b64_encoded}"


def images_are_equal(img1: PILImage.Image, img2: PILImage.Image) -> bool:
pixels1 = list(img1.getdata())
pixels2 = list(img2.getdata())
return pixels1 == pixels2

for p1, p2 in zip(pixels1, pixels2):
if p1 != p2:
return False

return True
def assert_fal_images_equal(fal_image_1: Image, fal_image_2: Image):
assert (
fal_image_1.file_size == fal_image_2.file_size
), "Image file size should match"
assert (
fal_image_1.content_type == fal_image_2.content_type
), "Content type should match"
assert fal_image_1.url == fal_image_2.url, "URL should match"
assert fal_image_1.width == fal_image_2.width, "Width should match"
assert fal_image_1.height == fal_image_2.height, "Height should match"


def test_image_matches():
# 1x1 white png image
base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx5QAAAABJRU5ErkJggg=="
image_bytes = base64.b64decode(base64_image)
pil_image = PILImage.open(io.BytesIO(image_bytes))
image_file: Image = Image.from_pil(pil_image, format="png", repository="in_memory")
output_pil_image = PILImage.open(io.BytesIO(image_file.as_bytes()))
pil_image = get_image()

image_file = Image.from_pil(pil_image, repository="in_memory")
output_pil_image = PILImage.open(BytesIO(image_file.as_bytes()))

assert images_are_equal(output_pil_image, pil_image)


def test_fal_image_from_pil(isolated_client):
def fal_image_from_pil():
pil_image = get_image()
return Image.from_pil(pil_image, repository="in_memory")

@isolated_client(requirements=["pillow", "pydantic==1.10.12"])
def fal_image_from_bytes_remote():
return fal_image_from_pil()

local_image = fal_image_from_pil()
remote_image = fal_image_from_bytes_remote()

assert fal_image_content_matches(remote_image, get_image(as_pil=False))

assert_fal_images_equal(local_image, remote_image)


def test_fal_image_from_bytes(isolated_client):
image_bytes = get_image(as_pil=False)

def fal_image_from_bytes():
return Image.from_bytes(image_bytes, repository="in_memory")

@isolated_client(requirements=["pillow", "pydantic==1.10.12"])
def fal_image_from_bytes_remote():
return fal_image_from_bytes()

local_image = fal_image_from_bytes()
remote_image = fal_image_from_bytes_remote()

assert fal_image_content_matches(remote_image, image_bytes)
assert_fal_images_equal(local_image, remote_image)


@pytest.mark.parametrize(
"image_url",
[
"https://storage.googleapis.com/falserverless/model_tests/remove_background/elephant.jpg",
image_to_data_uri(get_image()),
],
)
def test_fal_image_input(isolated_client, image_url):
class TestInput(BaseModel):
image: Image = Field()

def test_input():
return TestInput(image=image_url).image

@isolated_client(requirements=["pillow", "pydantic==1.10.12"])
def test_input_remote():
return test_input()

local_input_image = test_input()
remote_input_image = test_input_remote()

# Image is not downloaded until it is needed
assert not fal_image_downloaded(local_input_image)
assert not fal_image_downloaded(remote_input_image)

assert fal_image_url_matches(local_input_image, image_url)

# Image will be downloaded when trying to access its content
assert_fal_images_equal(local_input_image, remote_input_image)
assert fal_image_downloaded(local_input_image)

0 comments on commit 8b9357a

Please sign in to comment.