Incorporate PR feedback
evanderiel committed Feb 19, 2024
1 parent cdf1f2c commit db47d63
Showing 6 changed files with 81 additions and 34 deletions.
5 changes: 3 additions & 2 deletions aana/configs/endpoints.py
@@ -265,8 +265,9 @@
name="imagegen",
path="/generate_image",
summary="Generates an image from a text prompt",
outputs=[EndpointOutput(name="image", output="stablediffusion2-image")],
outputs=[
EndpointOutput(name="image_path", output="image_path_stablediffusion2")
],
)
],

}
2 changes: 1 addition & 1 deletion aana/configs/pipeline.py
@@ -725,7 +725,7 @@
"inputs": [{"name": "prompt", "key": "prompt", "path": "prompt"}],
"outputs": [
{
"name": "stablediffusion2-image",
"name": "image_stablediffusion2",
"key": "image",
"path": "stablediffusion2-image",
}
20 changes: 13 additions & 7 deletions aana/deployments/stablediffusion2_deployment.py
@@ -11,9 +11,9 @@


class StableDiffusion2Output(TypedDict):
"""Output class."""
"""Output class for the StableDiffusion2 deployment."""

data: Any
image: Any


class StableDiffusion2Config(BaseModel):
@@ -26,8 +26,6 @@ class StableDiffusion2Config(BaseModel):

model: str
dtype: Dtype = Field(default=Dtype.AUTO)
batch_size: int = Field(default=1)
num_processing_threads: int = Field(default=1)


@serve.deployment
@@ -39,9 +37,9 @@ async def apply_config(self, config: dict[str, Any]):
The method is called when the deployment is created or updated.
It loads the model and processor from HuggingFace.
It loads the model and scheduler from HuggingFace.
The configuration should conform to the HFBlip2Config schema.
The configuration should conform to the StableDiffusion2Config schema.
"""
config_obj = StableDiffusion2Config(**config)

@@ -59,11 +57,19 @@ async def apply_config(self, config: dict[str, Any]):
scheduler=EulerDiscreteScheduler.from_pretrained(
self.model_id, subfolder="scheduler"
),
device_map=self.device,
)

self.model.to(self.device)

async def generate(self, prompt: Prompt) -> StableDiffusion2Output:
"""Generates output."""
"""Runs the model on a given prompt and returns the first output.
Arguments:
prompt (Prompt): the prompt to the model.
Returns:
StableDiffusion2Output: a dictionary with one key containing the result
"""
image = self.model(str(prompt)).images[0]
return {"image": image}
8 changes: 8 additions & 0 deletions aana/models/core/file.py
@@ -0,0 +1,8 @@
from pathlib import Path
from typing import TypedDict


class PathResult(TypedDict):
"""Represents a path result describing a file on disk."""

path: Path
8 changes: 8 additions & 0 deletions aana/utils/image.py
@@ -0,0 +1,8 @@
from pathlib import Path
from aana.models.core.file import PathResult
from aana.models.core.image import Image


def save_image(image: Image, full_path: Path) -> PathResult:
    """Saves the image to the given path and returns the path as a PathResult."""
    image.save_from_content(full_path)
    return {"path": full_path}
72 changes: 48 additions & 24 deletions docs/getting_started.md
@@ -1,8 +1,4 @@
# Aana SDK Getting Started

This is the living version of this document. A [previous draft](https://docs.google.com/document/d/1z1y7Gxo1RwL_9MyTzRVvwA6tsOaVz8JKzlOA9Bxq2FU/edit#heading=h.70hndo88ymuh) lived on Google Docs.


## Code overview

aana/ - top level source code directory for the project
@@ -62,9 +58,9 @@ This is the living version of this document. A [previous draft](https://docs.goo

## Adding a New Model

A ray deployment is a standardized interface for any kind of functionality that needs to manage state (primarily, but not limited to, an AI model that needs to be fetched and loaded onto a GPU). New deployments should inherit from aana.deployments.base_deployment.BaseDeployment, and be in the aana/deployments folder. Additionally, a deployment config for the deployment will have to be added to the aana/configs/deployments.py file. Here's a simple example using Stable Diffusion 2:
A ray deployment is a standardized interface for any kind of functionality that needs to manage state (primarily, but not limited to, an AI model that needs to be fetched and loaded onto a GPU). New deployments should inherit from aana.deployments.base_deployment.BaseDeployment, and be in the aana/deployments folder. Additionally, a deployment config for the deployment will have to be added to the `[aana/configs/deployments.py](aana/configs/deployments.py)` file. Here's a simple example using Stable Diffusion 2:

aana/configs/deployments.py:
[aana/configs/deployments.py](aana/configs/deployments.py):
```python
deployments = {
"stablediffusion2_deployment": StableDiffusion2Deployment.options(
@@ -77,7 +73,7 @@ deployments = {
),
}
```
aana/deployments/stablediffusion2_deployment.py:
[aana/deployments/stablediffusion2_deployment.py](aana/deployments/stablediffusion2_deployment.py):
```python
from typing import TYPE_CHECKING, Any, TypedDict

@@ -168,8 +164,10 @@ For this simple pipeline, we only need three nodes:

Step 3 is actually only necessary because the SDK doesn't return binary files yet. At some point we expect this to be supported, at which point this pipeline would only need two steps.

Here are nodes we need (`aana/configs/pipeline.py`):
Here are nodes we need for `[aana/configs/pipeline.py](aana/configs/pipeline.py)`:
```python
from aana.models.pydantic.prompt import Prompt

nodes = [
{
"name": "prompt",
@@ -242,7 +240,7 @@ endpoints = {
}
```

Finally, we need to write the save_image function we referenced above (`aana/utils/image.py`)
Finally, we need to write the save_image function we referenced above in `[aana/utils/image.py](aana/utils/image.py)`:

```python
from pathlib import Path
@@ -262,6 +260,37 @@ def save_image(image: Image) -> dict[str, Path]:
return {"path": full_path}
```

Now we have everything we need to run the SDK and send a request.

## Running the SDK

Poetry handles dependency isolation as well as building the SDK into a module, so you can run the SDK as follows:

```bash
poetry run aana [--host <host>] [--port <port>] --target <target>
```

In our case, our target is called `stablediffusion2`, and we're fine with the default host and port (0.0.0.0 and 8000, respectively), so we can just say:

```bash
poetry run aana --target stablediffusion2
```

There will be lots of messages about various components starting up, but if everything has been configured correctly, it will end with `Deployed Serve app successfully.` That is the cue to begin sending requests.

Documentation is also automatically generated and available at http://{host}:{port}/docs and http://{host}:{port}/redoc, depending on which format you prefer.


## Sending a request to the SDK

The SDK currently accepts only one kind of request: an HTTP POST using form data, with the `body` form field containing the input data as a JSON-encoded string. For example, using `curl`:

```bash
curl -X POST 0.0.0.0:8000/generate_image -F body='{"prompt": "a dog"}'
```

Since JSON input typically contains spaces as well as `"`, `{}`, and `[]` characters, which the shell interprets specially, it is easiest to wrap the whole JSON in single quotes, `'{"like": "this"}'`.
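
If you would rather send the request from Python, here is a minimal sketch using the `requests` library; the endpoint path and payload simply mirror the curl example above, and depending on how the endpoint parses form data you may need to send a multipart request instead:

```python
import json

import requests

# JSON-encode the input and send it as the `body` form field,
# mirroring the curl example above.
payload = {"prompt": "a dog"}
response = requests.post(
    "http://0.0.0.0:8000/generate_image",
    data={"body": json.dumps(payload)},
)
response.raise_for_status()
print(response.json())
```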

# More complicated - Blip2 video captioning

Here's a more complicated example for video captioning.
@@ -283,7 +312,7 @@ deployments = {
).dict(),
),
```
aana/deployments/hfblip2_deployment.py:
[aana/deployments/hfblip2_deployment.py](aana/deployments/hfblip2_deployment.py):

```python
class HFBlip2Config(BaseModel): # BaseModel makes sure it's serializable
@@ -435,7 +464,7 @@ A typical workflow with one inference stage might be:

![](diagram.png)

Here's an example of these for a video processing pipeline(aana/config/pipeline.py):
Here's an example of these for a video processing pipeline ([aana/config/pipeline.py](aana/config/pipeline.py)):
```python
# Input
{
@@ -588,7 +617,7 @@ Here's an example of these for a video processing pipeline(aana/config/pipeline.

## Adding endpoints

Now we're almost done. The last stage is to add a run target with endpoints that refer to the node inputs and outputs (aana/config/endpoints.py).
Now we're almost done. The last stage is to add a run target with endpoints that refer to the node inputs and outputs ([aana/config/endpoints.py](aana/config/endpoints.py)).

```python
endpoints = {
@@ -618,7 +647,7 @@ endpoints = {

Okay, can you figure out what we might have forgotten?

We created a postprocessing step to save the video captions to the DB, but since we didn't include its output in the endpoint definition ("captions_ids"), that step **won't** run. It will get the output of the captions model, determine that no more steps are needed, and return it to the user. A working endpoint that wanted to include `EndpointOutput(name="caption_ids", output="caption_ids")` in the list of outputs.
We created a postprocessing step to save the video captions to the DB, but since we didn't include its output in the endpoint definition ("captions_ids"), that step **won't** run. It will get the output of the captions model, determine that no more steps are needed, and return it to the user. A working endpoint that did everything we wanted would have to include `EndpointOutput(name="caption_ids", output="caption_ids")` in the list of outputs.
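
For illustration, a corrected endpoint might look like the sketch below; it reuses the `Endpoint`/`EndpointOutput` structure from the stablediffusion2 example earlier, and the path, summary, and captions output name are placeholders rather than the project's actual values:

```python
Endpoint(
    name="blip2_video_captions",
    path="/video/caption",  # placeholder path
    summary="Generates captions for a video",
    outputs=[
        # The captions themselves.
        EndpointOutput(name="captions", output="captions"),
        # Including caption_ids makes the save-to-DB node part of the
        # required outputs, so the postprocessing step actually runs.
        EndpointOutput(name="caption_ids", output="caption_ids"),
    ],
)
```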


## Saving to a DB
@@ -628,7 +657,7 @@ Aana SDK is designed to have two databases, a structured database layer with res

### Saving to datastore

You will need to add database entity models as a subclass of `aana.models.db.base.BaseEntity` to a class file in `aana/models/db/`. Additionally, to avoid import issues, you will need to import that model inside `aana/models/db/__init__.py`.
You will need to add database entity models as a subclass of `aana.models.db.base.BaseEntity` to a class file in `[aana/models/db/](aana/models/db/)`. Additionally, to avoid import issues, you will need to import that model inside `[aana/models/db/__init__.py](aana/models/db/__init__.py)`.
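
As a rough sketch, an entity model might look like the following; this assumes `BaseEntity` is a standard SQLAlchemy declarative base that does not already define a primary key, and the table and column names are purely illustrative:

```python
from sqlalchemy import Column, Integer, String

from aana.models.db.base import BaseEntity


class TranscriptEntity(BaseEntity):
    """Illustrative entity storing a transcript line for a video."""

    __tablename__ = "transcript"

    id = Column(Integer, primary_key=True)  # assumed not provided by BaseEntity
    video_id = Column(String, nullable=False)  # id of the parent video
    text = Column(String, nullable=False)  # the transcript text itself
```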

Once you have defined your model, you will need to create an Alembic migration that adds the necessary table and modifies other tables if needed. Do this just by running

@@ -643,7 +672,7 @@ The app will automatically run the migration when you start up, so the rest is t

We wrap access to the datastore in a repository class. There is a generic BaseRepository that provides the following methods: create, create_multiple, read (by id), and delete (by id). If you want to fetch by another parameter (for example, by a parent object id), you will need to add a method to the repository yourself. Update logic is TODO, since the semantics of in-place updates with the SQLAlchemy ORM are a bit complex.

(aana/repository/datastore/caption_repository.py)
([aana/repository/datastore/caption_repository.py](aana/repository/datastore/caption_repository.py))

```python
class CaptionsRepository(BaseRepository[CaptionEntity]):
@@ -656,7 +685,7 @@ class CaptionsRepository(BaseRepository[CaptionEntity]):

That's it!

To make a repository work with the pipeline, it's easiest to wrap the repository actions in helper functions, like so (aana/utils/db.py):
To make a repository work with the pipeline, it's easiest to wrap the repository actions in helper functions, like so ([aana/utils/db.py](aana/utils/db.py)):

```python
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -688,23 +717,18 @@ TODO


### Running the SDK

So, you have created a new deployment for your model, defined the pipeline nodes for it, and created the endpoints to call it. How do you run the SDK?

The "normal" way to run the SDK right now is to call the SDK as a module with following syntax:

As before, you can run the SDK the usual way:
```bash
poetry run python aana --host 0.0.0.0 --port 8000 --target blip2
```

Host and port are optional; the defaults are `0.0.0.0` and `8000`, respectively, but you must give a deployment target or the SDK doesn't know what to run. Once the SDK has initialized the pipeline, downloaded any remote resources necessary, and loaded the model weights, it will print "Deployed Serve app successfully." and that is the cue that it is ready to serve traffic and perform tasks, including inference. Documentation is also automatically generated and available at http://{port}:{host}/docs and http://{port}:{host}/redoc, depending on which format you prefer.

Once the SDK has initialized the pipeline, downloaded any remote resources necessary, and loaded the model weights, it will print "Deployed Serve app successfully." and that is the cue that it is ready to serve traffic and perform tasks, including inference.

## Tests

Write unit tests for every freestanding function or task function you add. Write unit tests for deployment methods or static functions that transform data without sending it to the inference engine (and refactor the deployment code so that as much functionality as possible is modularized so that it may be tested).

Additionally, please write some tests for the `tests/deployments` folder that will load your deployment and programmatically run some inputs through it to validate that the deployment itself works as expected. Note, however, that due to the size and complexity of loading deployments that this might fail even if the logic is correct, if for example the user is running on a machine a GPU that is too small for the model.
Additionally, please write some tests for the `[tests/deployments](tests/deployments)` folder that will load your deployment and programmatically run some inputs through it to validate that the deployment itself works as expected. Note, however, that due to the size and complexity of loading deployments, these tests might fail even if the logic is correct, for example if the user is running on a machine with a GPU that is too small for the model.

You can tell PyTest to skip tests under certain conditions. For example, if there is a deployment that doesn't make sense to run without a GPU, you can use the decorator `@pytest.mark.skipif(not is_gpu_available(), reason="GPU is not available")` to tell pytest not to run it if there's no GPU available. The function `is_gpu_available()` is defined in `aana.tests.utils`.
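
For example, a GPU-only deployment test might be guarded like this sketch (the test body is a placeholder; replace it with real calls into your deployment):

```python
import pytest

from aana.tests.utils import is_gpu_available


@pytest.mark.skipif(not is_gpu_available(), reason="GPU is not available")
def test_my_deployment_on_gpu():
    # Placeholder body: load the deployment and run a small input through it.
    # This test is skipped automatically on machines without a GPU.
    assert is_gpu_available()
```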

