
feat: fal.App for multiple endpoints #27

Merged: 3 commits into main on Jan 10, 2024

Conversation

@isidentical (Collaborator) commented Jan 9, 2024

Mainly for discussion of the idea, not the concrete implementation. Let's see if we can get it into a decent state to start using it.

import fal
from fal.toolkit import Image, ImageSizeInput, get_image_size
from pydantic import BaseModel, Field


class InputModel(BaseModel):
    prompt: str
    seed: int = Field(default=42, ge=0, le=2**32 - 1)


class OutputModel(BaseModel):
    images: list[Image]
    seed: int


class Text2ImageInputModel(InputModel):
    image_size: ImageSizeInput = "square_hd"


class Image2ImageInputModel(InputModel):
    image_url: str
    strength: float = Field(default=0.5, ge=0.0, le=1.0)


class InpaintingInputModel(InputModel):
    image_url: str
    mask_url: str
    strength: float = Field(default=0.5, ge=0.0, le=1.0)


class StableDiffusion(fal.App, _scheduler="nomad"):
    machine_type = "GPU"
    requirements = [
        "diffusers==0.25.0",
        "transformers",
        "torch>=2.1",
        "accelerate",
    ]

    def setup(self):
        import torch
        from diffusers import (
            AutoPipelineForText2Image,
            AutoPipelineForImage2Image,
            AutoPipelineForInpainting,
        )

        self.pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
        self.pipeline_img2img = AutoPipelineForImage2Image.from_pipe(
            self.pipeline_text2img
        )
        self.pipeline_inpainting = AutoPipelineForInpainting.from_pipe(
            self.pipeline_text2img
        )

    def parse_image_url(self, url: str) -> object:
        from io import BytesIO
        from urllib.request import urlopen

        from PIL import Image

        # Read the body up front: PIL's Image.open is lazy and needs a
        # seekable buffer, which a raw HTTP response is not.
        with urlopen(url) as stream:
            return Image.open(BytesIO(stream.read()))

    @fal.endpoint("/text-to-image")
    def text_to_image(self, input: Text2ImageInputModel) -> OutputModel:
        import torch

        image_size = get_image_size(input.image_size)
        result = self.pipeline_text2img(
            prompt=input.prompt,
            generator=torch.Generator("cuda").manual_seed(input.seed),
            width=image_size.width,
            height=image_size.height,
        )
        return OutputModel(
            images=[Image.from_pil(image) for image in result.images],
            seed=input.seed,
        )

    @fal.endpoint("/image-to-image")
    def image_to_image(self, input: Image2ImageInputModel) -> OutputModel:
        import torch

        result = self.pipeline_img2img(
            prompt=input.prompt,
            image=self.parse_image_url(input.image_url),
            generator=torch.Generator("cuda").manual_seed(input.seed),
            strength=input.strength,
        )
        return OutputModel(
            images=[Image.from_pil(image) for image in result.images],
            seed=input.seed,
        )

    @fal.endpoint("/inpainting")
    def inpainting(self, input: InpaintingInputModel) -> OutputModel:
        import torch

        result = self.pipeline_inpainting(
            prompt=input.prompt,
            image=self.parse_image_url(input.image_url),
            mask_image=self.parse_image_url(input.mask_url),
            generator=torch.Generator("cuda").manual_seed(input.seed),
            strength=input.strength,
        )
        return OutputModel(
            images=[Image.from_pil(image) for image in result.images],
            seed=input.seed,
        )


if __name__ == "__main__":
    # SDK usage, TBD
    app = StableDiffusion()  # returns a shallow proxy object which,
                             # when called, will create the stateful
                             # app if it is not already in the process
                             # and treat the calls as if they were coming
                             # from a local Python process
    result = app.text_to_image(...)

$ fal run t.py StableDiffusion
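
For illustration, once fal run brings the app up and prints the gateway URL, the endpoints can be exercised over plain HTTP. A minimal sketch using the requests library; the URL is a placeholder (the real one is printed at startup) and the payload fields follow Text2ImageInputModel above:

import requests

# Placeholder; substitute the URL printed by `fal run`.
BASE_URL = "https://<app-id>.gateway.alpha.fal.ai"

response = requests.post(
    f"{BASE_URL}/text-to-image",
    json={
        "prompt": "an astronaut riding a horse",
        "seed": 42,
        "image_size": "square_hd",
    },
)
response.raise_for_status()

payload = response.json()  # shaped like OutputModel: {"images": [...], "seed": 42}
print(payload["seed"], len(payload["images"]))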

@mederka (Contributor) commented Jan 9, 2024

Is this how you deploy?

fal run t.py StableDiffusion

@isidentical (Collaborator, Author) replied:

No, that's how you start the local test server (as opposed to just calling the function, which is what we have now); the deployment flow is the same:

❯ fal fn run t.py StableDiffusion
2024-01-09 21:39:08.976 [info     ] !!! HEADS UP !!! Scheduling a job with <isolate_controller.scheduler.nomad.scheduler.NomadJobScheduler object at 0x7f7fa1516ad0>
2024-01-09 21:39:11.370 [info     ] Access your exposed service at https://5d0299e5-479d-48a3-b26a-8fb38207ba17.gateway.alpha.fal.ai
2024-01-09 21:39:17.996 [stderr   ] 
2024-01-09 21:39:17.996 [stderr   ] Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]
2024-01-09 21:39:18.714 [stderr   ] 
...

❯ fal fn serve t.py StableDiffusion --alias lolz-diffusion       
Registered a new revision for function 'lolz-diffusion' (revision='4eaeab3c-1619-4ef5-97fa-5dcc8edd7488').
URL: https://47358913-lolz-diffusion.gateway.alpha.fal.ai

@mederka (Contributor) commented Jan 9, 2024

How come _scheduler is a parameter?

@isidentical (Collaborator, Author) replied:

How come _scheduler is a parameter?

There is a set of generic parameters (the same set of arguments we have on @fal.function) and a set of host-specific parameters (fal vs local). For the latter we need to either use a dict (host_options = {"a": "b"}) or pass them as keyword arguments in the class definition; the former can also be supported there: class T(fal.App, requirements=[...], a=b).

Open for comments on this: we can have a single way to configure everything (but it might become super cluttered due to requirements), OR have two different segments which configure two different things (the current way).
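
To make the two segments concrete, here is a sketch of the alternatives being discussed (class names are illustrative, not a committed API):

# Option A: one configuration surface; host-specific options live in a
# dict attribute next to the generic ones.
class StableDiffusionA(fal.App):
    host_options = {"_scheduler": "nomad"}  # host-specific (fal vs local)
    machine_type = "GPU"                    # generic, shared with @fal.function
    requirements = ["diffusers==0.25.0"]

# Option B (the current way): host-specific options as class keywords,
# generic options as class attributes.
class StableDiffusionB(fal.App, _scheduler="nomad"):
    machine_type = "GPU"
    requirements = ["diffusers==0.25.0"]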

@mederka (Contributor) commented Jan 9, 2024

So, when testing locally, I might need to test the fal.endpoints locally as well. How would we do this?

@isidentical (Collaborator, Author) replied:

Given the application above, when developing you just edit your code and run fal fn run t.py StableDiffusion (which gives you an endpoint you can perform tests against). Once you are ready to deploy, you use fal fn serve t.py StableDiffusion as if it were just a single @fal.function (with the same options, etc.).
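
And for the SDK-style flow from the PR description (still marked TBD there), a call through the proxy object might look roughly like this; the input instance is an illustrative guess, not a settled call shape:

# Hypothetical local SDK usage; the proxy creates the stateful app on
# first call and dispatches to it as if it were a local object.
app = StableDiffusion()
result = app.text_to_image(
    Text2ImageInputModel(prompt="a cabin in the woods", image_size="square_hd")
)
print(result.seed, len(result.images))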

@isidentical force-pushed the multiple-endpoints branch 2 times, most recently from 1adae46 to 31f9c63, on January 10, 2024 15:50
@isidentical marked this pull request as ready for review on January 10, 2024 15:56
@isidentical changed the title from "wip: feat: fal.App for multiple endpoints" to "feat: fal.App for multiple endpoints" on Jan 10, 2024
def marker_fn(callable: EndpointT) -> EndpointT:
    if hasattr(callable, "route_signature"):
        raise ValueError(
            f"Can't set multiple routes for the same function: {callable.__name__}"
Contributor

Why is that?

Collaborator Author

We don't handle it yet, but in the future I want to be able to support multiple endpoints for a single function by stacking @fal.endpoint. This is done to reserve that use case (instead of treating it as an override, which we would later have to break).
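
A sketch of the reserved (not yet supported) use case, with stacked decorators mapping one function to several routes:

# Hypothetical future usage; today the second decorator raises
# ValueError because route_signature is already set by the first.
@fal.endpoint("/text-to-image")
@fal.endpoint("/t2i")  # alias route
def text_to_image(self, input: Text2ImageInputModel) -> OutputModel:
    ...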

Contributor

aha ok 👍

@squat (Contributor) left a comment:

great idea and work!

@isidentical merged commit 74afc08 into main on Jan 10, 2024
4 checks passed
@isidentical deleted the multiple-endpoints branch on January 10, 2024 20:19
Comment on lines +78 to +146
def _build_app(self) -> FastAPI:
    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware

    _app = FastAPI()

    _app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_headers=["*"],
        allow_methods=["*"],
        allow_origins=["*"],
    )

    routes: dict[RouteSignature, Callable[..., Any]] = {
        signature: endpoint
        for _, endpoint in inspect.getmembers(self, inspect.ismethod)
        if (signature := getattr(endpoint, "route_signature", None))
    }
    if not routes:
        raise ValueError("An application must have at least one route!")

    for signature, endpoint in routes.items():
        _app.add_api_route(
            signature.path,
            endpoint,
            name=endpoint.__name__,
            methods=["POST"],
        )

    return _app

def openapi(self) -> dict[str, Any]:
    """
    Build the OpenAPI specification for the served function.
    Attach needed metadata for a better integration with fal.
    """
    app = self._build_app()
    spec = app.openapi()
    self._mark_order_openapi(spec)
    return spec

def _mark_order_openapi(self, spec: dict[str, Any]):
    """
    Add x-fal-order-* keys to the OpenAPI specification to help the rendering of the UI.

    NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
    """

    def mark_order(obj: dict[str, Any], key: str):
        obj[f"x-fal-order-{key}"] = list(obj[key].keys())

    mark_order(spec, "paths")

    def order_schema_object(schema: dict[str, Any]):
        """
        Mark the order of properties in the schema object.
        They can have 'allOf', 'properties' or '$ref' key.
        """
        if "allOf" in schema:
            for sub_schema in schema["allOf"]:
                order_schema_object(sub_schema)
        if "properties" in schema:
            mark_order(schema, "properties")

    for key in spec["components"].get("schemas") or {}:
        order_schema_object(spec["components"]["schemas"][key])

    return spec
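
To illustrate what _mark_order_openapi produces, a hypothetical fragment of a marked spec (the paths and schema contents are invented for the example):

# Each ordered mapping gains an x-fal-order-* sibling key listing its
# keys in declaration order, so the UI can render fields predictably.
spec_fragment = {
    "paths": {"/text-to-image": {}, "/image-to-image": {}},
    "x-fal-order-paths": ["/text-to-image", "/image-to-image"],
    "components": {
        "schemas": {
            "Text2ImageInputModel": {
                "properties": {"prompt": {}, "seed": {}, "image_size": {}},
                "x-fal-order-properties": ["prompt", "seed", "image_size"],
            },
        },
    },
}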
Member

this whole thing already existed for serve functions, can we merge them?
