Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(fal_client): use cdn v3 #344

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 155 additions & 18 deletions projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import asyncio
import time
import base64
import threading
from datetime import datetime, timezone
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, AsyncIterator, Iterator, TYPE_CHECKING, Optional, Literal
Expand All @@ -23,10 +25,96 @@
RUN_URL_FORMAT = f"https://{FAL_RUN_HOST}/"
QUEUE_URL_FORMAT = f"https://queue.{FAL_RUN_HOST}/"
REALTIME_URL_FORMAT = f"wss://{FAL_RUN_HOST}/"
CDN_URL = "https://fal.media"
REST_URL = "https://rest.alpha.fal.ai"
CDN_URL = "https://v3.fal.media"
USER_AGENT = "fal-client/0.2.2 (python)"


@dataclass
class CDNToken:
token: str
token_type: str
base_upload_url: str
expires_at: datetime

def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at


class CDNTokenManager:
def __init__(self, key: str) -> None:
self._key = key
self._token: CDNToken = CDNToken(
token="",
token_type="",
base_upload_url="",
expires_at=datetime.min.replace(tzinfo=timezone.utc),
)
self._lock: threading.Lock = threading.Lock()
self._url = f"{REST_URL}/storage/auth/token?storage_type=fal-cdn-v3"
self._headers = {
"Authorization": f"Key {self._key}",
"Accept": "application/json",
"Content-Type": "application/json",
}

def _refresh_token(self) -> CDNToken:
with httpx.Client() as client:
response = client.post(self._url, headers=self._headers, data=b"{}")
response.raise_for_status()
data = response.json()

return CDNToken(
token=data["token"],
token_type=data["token_type"],
base_upload_url=data["base_url"],
expires_at=datetime.fromisoformat(data["expires_at"]),
)

def get_token(self) -> CDNToken:
with self._lock:
if self._token.is_expired():
self._token = self._refresh_token()
return self._token


class AsyncCDNTokenManager:
def __init__(self, key: str) -> None:
self._key = key
self._token: CDNToken = CDNToken(
token="",
token_type="",
base_upload_url="",
expires_at=datetime.min.replace(tzinfo=timezone.utc),
)
self._lock: threading.Lock = threading.Lock()
self._url = f"{REST_URL}/storage/auth/token?storage_type=fal-cdn-v3"
self._headers = {
"Authorization": f"Key {self._key}",
"Accept": "application/json",
"Content-Type": "application/json",
}

async def _refresh_token(self) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(self._url, headers=self._headers, data=b"{}")
response.raise_for_status()
data = response.json()

return CDNToken(
token=data["token"],
token_type=data["token_type"],
base_upload_url=data["base_url"],
expires_at=datetime.fromisoformat(data["expires_at"]),
)

async def get_token(self) -> CDNToken:
with self._lock:
if self._token.is_expired():
self._token = await self._refresh_token()
return self._token


class FalClientError(Exception):
pass

Expand Down Expand Up @@ -265,13 +353,18 @@ class AsyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 120.0

@cached_property
def _client(self) -> httpx.AsyncClient:
def _get_key(self) -> str:
if self.key is None:
key = fetch_credentials()
else:
key = self.key
return fetch_credentials()
return self.key

@cached_property
def _token_manager(self) -> AsyncCDNTokenManager:
return AsyncCDNTokenManager(self._get_key())

@cached_property
def _client(self) -> httpx.AsyncClient:
key = self._get_key()
return httpx.AsyncClient(
headers={
"Authorization": f"Key {key}",
Expand All @@ -280,6 +373,16 @@ def _client(self) -> httpx.AsyncClient:
timeout=self.default_timeout,
)

async def _get_cdn_client(self) -> httpx.AsyncClient:
token = await self._token_manager.get_token()
return httpx.AsyncClient(
headers={
"Authorization": f"{token.token_type} {token.token}",
"User-Agent": USER_AGENT,
},
timeout=self.default_timeout,
)

async def run(
self,
application: str,
Expand Down Expand Up @@ -425,14 +528,22 @@ async def stream(
async for event in events.aiter_sse():
yield event.json()

async def upload(self, data: str | bytes, content_type: str) -> str:
async def upload(
self, data: str | bytes, content_type: str, file_name: str | None = None
) -> str:
"""Upload the given data blob to the CDN and return the access URL. The content type should be specified
as the second argument. Use upload_file or upload_image for convenience."""

response = await self._client.post(
client = await self._get_cdn_client()

headers = {"Content-Type": content_type}
if file_name is not None:
headers["X-Fal-File-Name"] = file_name

response = await client.post(
CDN_URL + "/files/upload",
data=data,
headers={"Content-Type": content_type},
headers=headers,
)
_raise_for_status(response)

Expand All @@ -446,7 +557,9 @@ async def upload_file(self, path: os.PathLike) -> str:
mime_type = "application/octet-stream"

with open(path, "rb") as file:
return await self.upload(file.read(), mime_type)
return await self.upload(
file.read(), mime_type, file_name=os.path.basename(path)
)

async def upload_image(self, image: Image.Image, format: str = "jpeg") -> str:
"""Upload a pillow image object to the CDN and return the access URL."""
Expand All @@ -461,12 +574,14 @@ class SyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 120.0

def _get_key(self) -> str:
if self.key is None:
return fetch_credentials()
return self.key

@cached_property
def _client(self) -> httpx.Client:
if self.key is None:
key = fetch_credentials()
else:
key = self.key
key = self._get_key()
return httpx.Client(
headers={
"Authorization": f"Key {key}",
Expand All @@ -475,6 +590,20 @@ def _client(self) -> httpx.Client:
timeout=self.default_timeout,
)

@cached_property
def _token_manager(self) -> CDNTokenManager:
return CDNTokenManager(self._get_key())

def _get_cdn_client(self) -> httpx.Client:
token = self._token_manager.get_token()
return httpx.Client(
headers={
"Authorization": f"{token.token_type} {token.token}",
"User-Agent": USER_AGENT,
},
timeout=self.default_timeout,
)

def run(
self,
application: str,
Expand Down Expand Up @@ -617,14 +746,22 @@ def stream(
for event in events.iter_sse():
yield event.json()

def upload(self, data: str | bytes, content_type: str) -> str:
def upload(
self, data: str | bytes, content_type: str, file_name: str | None = None
) -> str:
"""Upload the given data blob to the CDN and return the access URL. The content type should be specified
as the second argument. Use upload_file or upload_image for convenience."""

response = self._client.post(
client = self._get_cdn_client()

headers = {"Content-Type": content_type}
if file_name is not None:
headers["X-Fal-File-Name"] = file_name

response = client.post(
CDN_URL + "/files/upload",
data=data,
headers={"Content-Type": content_type},
headers=headers,
)
_raise_for_status(response)

Expand All @@ -638,7 +775,7 @@ def upload_file(self, path: os.PathLike) -> str:
mime_type = "application/octet-stream"

with open(path, "rb") as file:
return self.upload(file.read(), mime_type)
return self.upload(file.read(), mime_type, file_name=os.path.basename(path))

def upload_image(self, image: Image.Image, format: str = "jpeg") -> str:
"""Upload a pillow image object to the CDN and return the access URL."""
Expand Down
Loading