Skip to content

Commit

Permalink
Merge branch 'main' into pydantic_v2_migration
Browse files Browse the repository at this point in the history
  • Loading branch information
movchan74 committed Feb 26, 2024
2 parents cddfd71 + d9a62fa commit 55ce3bc
Show file tree
Hide file tree
Showing 10 changed files with 1,606 additions and 346 deletions.
13 changes: 13 additions & 0 deletions aana/configs/deployments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from aana.deployments.hf_blip2_deployment import HFBlip2Config, HFBlip2Deployment
from aana.deployments.stablediffusion2_deployment import (
StableDiffusion2Config,
StableDiffusion2Deployment,
)
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
from aana.deployments.whisper_deployment import (
WhisperComputeType,
Expand Down Expand Up @@ -46,4 +50,13 @@
compute_type=WhisperComputeType.FLOAT16,
).model_dump(),
),
"stablediffusion2_deployment": StableDiffusion2Deployment.options(
num_replicas=1,
max_concurrent_queries=1000,
ray_actor_options={"num_gpus": 1},
user_config=StableDiffusion2Config(
model="stabilityai/stable-diffusion-2",
dtype=Dtype.FLOAT16,
).dict(),
),
}
13 changes: 13 additions & 0 deletions aana/configs/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,17 @@
],
),
],
"stablediffusion2": [
Endpoint(
name="imagegen",
path="/generate_image",
summary="Generates an image from a text prompt",
outputs=[
EndpointOutput(
name="image_path_stablediffusion2",
output="image_path_stablediffusion2",
)
],
)
],
}
36 changes: 36 additions & 0 deletions aana/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
It is used to generate the pipeline and the API endpoints.
"""
import PIL.Image

from aana.models.pydantic.asr_output import (
AsrSegments,
Expand Down Expand Up @@ -721,6 +722,41 @@
}
],
},
{
"name": "stable-diffusion-2-imagegen",
"type": "ray_deployment",
"deployment_name": "stablediffusion2_deployment",
"method": "generate",
"inputs": [{"name": "prompt", "key": "prompt", "path": "prompt"}],
"outputs": [
{
"name": "image_stablediffusion2",
"key": "image",
"path": "stablediffusion2-image",
"data_model": PIL.Image.Image,
}
],
},
{
"name": "save_image_stablediffusion2",
"type": "function",
"function": "aana.utils.image.save_image",
"dict_output": True,
"inputs": [
{
"name": "image_stablediffusion2",
"key": "image",
"path": "stablediffusion2-image",
},
],
"outputs": [
{
"name": "image_path_stablediffusion2",
"key": "path",
"path": "image_path",
}
],
},
{
"name": "save_video",
"type": "function",
Expand Down
76 changes: 76 additions & 0 deletions aana/deployments/stablediffusion2_deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Any, TypedDict

import PIL
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from pydantic import BaseModel, Field
from ray import serve

from aana.deployments.base_deployment import BaseDeployment
from aana.models.core.dtype import Dtype
from aana.models.pydantic.prompt import Prompt


class StableDiffusion2Output(TypedDict):
"""Output class for the StableDiffusion2 deployment."""

image: PIL.Image.Image


class StableDiffusion2Config(BaseModel):
"""The configuration for the Stable Diffusion 2 deployment.
Attributes:
model (str): the model ID on HuggingFace
dtype (str): the data type (optional, default: "auto"), one of "auto", "float32", "float16"
"""

model: str
dtype: Dtype = Field(default=Dtype.AUTO)


@serve.deployment
class StableDiffusion2Deployment(BaseDeployment):
"""Stable Diffusion 2 deployment."""

async def apply_config(self, config: dict[str, Any]):
"""Apply the configuration.
The method is called when the deployment is created or updated.
It loads the model and scheduler from HuggingFace.
The configuration should conform to the StableDiffusion2Confgi schema.
"""
config_obj = StableDiffusion2Config(**config)

# Load the model and processor from HuggingFace
self.model_id = config_obj.model
self.dtype = config_obj.dtype
if self.dtype == Dtype.INT8:
self.torch_dtype = Dtype.FLOAT16.to_torch()
else:
self.torch_dtype = self.dtype.to_torch()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = StableDiffusionPipeline.from_pretrained(
self.model_id,
torch_dtype=self.torch_dtype,
scheduler=EulerDiscreteScheduler.from_pretrained(
self.model_id, subfolder="scheduler"
),
device_map="auto",
)

self.model.to(self.device)

async def generate(self, prompt: Prompt) -> StableDiffusion2Output:
"""Runs the model on a given prompt and returns the first output.
Arguments:
prompt (Prompt): the prompt to the model.
Returns:
StableDiffusion2Output: a dictionary with one key containing the result
"""
image = self.model(str(prompt)).images[0]
return {"image": image}
8 changes: 8 additions & 0 deletions aana/models/core/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pathlib import Path
from typing import TypedDict


class PathResult(TypedDict):
"""Represents a path result describing a file on disk."""

path: Path
23 changes: 23 additions & 0 deletions aana/utils/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path
from uuid import uuid4

import PIL.Image

from aana.configs.settings import settings
from aana.models.core.file import PathResult


def save_image(image: PIL.Image.Image, full_path: Path | None = None) -> PathResult:
"""Saves an image to the given full path, or randomely generates one if no path is supplied.
Arguments:
image (Image): the image to save
full_path (Path|None): the path to save the image to. If None, will generate one randomly.
Returns:
PathResult: contains the path to the saved image.
"""
if not full_path:
full_path = settings.image_dir / f"{uuid4()}.png"
image.save(full_path)
return {"path": full_path}
Binary file added docs/diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 55ce3bc

Please sign in to comment.