-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into pydantic_v2_migration
- Loading branch information
Showing
10 changed files
with
1,606 additions
and
346 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Any, TypedDict | ||
|
||
import PIL | ||
import torch | ||
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline | ||
from pydantic import BaseModel, Field | ||
from ray import serve | ||
|
||
from aana.deployments.base_deployment import BaseDeployment | ||
from aana.models.core.dtype import Dtype | ||
from aana.models.pydantic.prompt import Prompt | ||
|
||
|
||
class StableDiffusion2Output(TypedDict): | ||
"""Output class for the StableDiffusion2 deployment.""" | ||
|
||
image: PIL.Image.Image | ||
|
||
|
||
class StableDiffusion2Config(BaseModel): | ||
"""The configuration for the Stable Diffusion 2 deployment. | ||
Attributes: | ||
model (str): the model ID on HuggingFace | ||
dtype (str): the data type (optional, default: "auto"), one of "auto", "float32", "float16" | ||
""" | ||
|
||
model: str | ||
dtype: Dtype = Field(default=Dtype.AUTO) | ||
|
||
|
||
@serve.deployment | ||
class StableDiffusion2Deployment(BaseDeployment): | ||
"""Stable Diffusion 2 deployment.""" | ||
|
||
async def apply_config(self, config: dict[str, Any]): | ||
"""Apply the configuration. | ||
The method is called when the deployment is created or updated. | ||
It loads the model and scheduler from HuggingFace. | ||
The configuration should conform to the StableDiffusion2Confgi schema. | ||
""" | ||
config_obj = StableDiffusion2Config(**config) | ||
|
||
# Load the model and processor from HuggingFace | ||
self.model_id = config_obj.model | ||
self.dtype = config_obj.dtype | ||
if self.dtype == Dtype.INT8: | ||
self.torch_dtype = Dtype.FLOAT16.to_torch() | ||
else: | ||
self.torch_dtype = self.dtype.to_torch() | ||
self.device = "cuda" if torch.cuda.is_available() else "cpu" | ||
self.model = StableDiffusionPipeline.from_pretrained( | ||
self.model_id, | ||
torch_dtype=self.torch_dtype, | ||
scheduler=EulerDiscreteScheduler.from_pretrained( | ||
self.model_id, subfolder="scheduler" | ||
), | ||
device_map="auto", | ||
) | ||
|
||
self.model.to(self.device) | ||
|
||
async def generate(self, prompt: Prompt) -> StableDiffusion2Output: | ||
"""Runs the model on a given prompt and returns the first output. | ||
Arguments: | ||
prompt (Prompt): the prompt to the model. | ||
Returns: | ||
StableDiffusion2Output: a dictionary with one key containing the result | ||
""" | ||
image = self.model(str(prompt)).images[0] | ||
return {"image": image} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from pathlib import Path | ||
from typing import TypedDict | ||
|
||
|
||
class PathResult(TypedDict): | ||
"""Represents a path result describing a file on disk.""" | ||
|
||
path: Path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from pathlib import Path | ||
from uuid import uuid4 | ||
|
||
import PIL.Image | ||
|
||
from aana.configs.settings import settings | ||
from aana.models.core.file import PathResult | ||
|
||
|
||
def save_image(image: PIL.Image.Image, full_path: Path | None = None) -> PathResult: | ||
"""Saves an image to the given full path, or randomely generates one if no path is supplied. | ||
Arguments: | ||
image (Image): the image to save | ||
full_path (Path|None): the path to save the image to. If None, will generate one randomly. | ||
Returns: | ||
PathResult: contains the path to the saved image. | ||
""" | ||
if not full_path: | ||
full_path = settings.image_dir / f"{uuid4()}.png" | ||
image.save(full_path) | ||
return {"path": full_path} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.