From e6117b691e4c95afd4c63974db2e687278551824 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Wed, 16 Oct 2024 22:32:36 +0300 Subject: [PATCH] feat(fal_client): use cdn v3 --- projects/fal_client/src/fal_client/client.py | 173 +++++++++++++++++-- 1 file changed, 155 insertions(+), 18 deletions(-) diff --git a/projects/fal_client/src/fal_client/client.py b/projects/fal_client/src/fal_client/client.py index 11ea2f4f..b3cb1fc6 100644 --- a/projects/fal_client/src/fal_client/client.py +++ b/projects/fal_client/src/fal_client/client.py @@ -6,6 +6,8 @@ import asyncio import time import base64 +import threading +from datetime import datetime, timezone from dataclasses import dataclass, field from functools import cached_property from typing import Any, AsyncIterator, Iterator, TYPE_CHECKING, Optional, Literal @@ -23,10 +25,96 @@ RUN_URL_FORMAT = f"https://{FAL_RUN_HOST}/" QUEUE_URL_FORMAT = f"https://queue.{FAL_RUN_HOST}/" REALTIME_URL_FORMAT = f"wss://{FAL_RUN_HOST}/" -CDN_URL = "https://fal.media" +REST_URL = "https://rest.alpha.fal.ai" +CDN_URL = "https://v3.fal.media" USER_AGENT = "fal-client/0.2.2 (python)" +@dataclass +class CDNToken: + token: str + token_type: str + base_upload_url: str + expires_at: datetime + + def is_expired(self) -> bool: + return datetime.now(timezone.utc) >= self.expires_at + + +class CDNTokenManager: + def __init__(self, key: str) -> None: + self._key = key + self._token: CDNToken = CDNToken( + token="", + token_type="", + base_upload_url="", + expires_at=datetime.min.replace(tzinfo=timezone.utc), + ) + self._lock: threading.Lock = threading.Lock() + self._url = f"{REST_URL}/storage/auth/token?storage_type=fal-cdn-v3" + self._headers = { + "Authorization": f"Key {self._key}", + "Accept": "application/json", + "Content-Type": "application/json", + } + + def _refresh_token(self) -> CDNToken: + with httpx.Client() as client: + response = client.post(self._url, headers=self._headers, data=b"{}") + response.raise_for_status() + data = response.json() + + return CDNToken( + token=data["token"], + token_type=data["token_type"], + base_upload_url=data["base_url"], + expires_at=datetime.fromisoformat(data["expires_at"]), + ) + + def get_token(self) -> CDNToken: + with self._lock: + if self._token.is_expired(): + self._token = self._refresh_token() + return self._token + + +class AsyncCDNTokenManager: + def __init__(self, key: str) -> None: + self._key = key + self._token: CDNToken = CDNToken( + token="", + token_type="", + base_upload_url="", + expires_at=datetime.min.replace(tzinfo=timezone.utc), + ) + self._lock: threading.Lock = threading.Lock() + self._url = f"{REST_URL}/storage/auth/token?storage_type=fal-cdn-v3" + self._headers = { + "Authorization": f"Key {self._key}", + "Accept": "application/json", + "Content-Type": "application/json", + } + + async def _refresh_token(self) -> None: + async with httpx.AsyncClient() as client: + response = await client.post(self._url, headers=self._headers, data=b"{}") + response.raise_for_status() + data = response.json() + + return CDNToken( + token=data["token"], + token_type=data["token_type"], + base_upload_url=data["base_url"], + expires_at=datetime.fromisoformat(data["expires_at"]), + ) + + async def get_token(self) -> CDNToken: + with self._lock: + if self._token.is_expired(): + self._token = await self._refresh_token() + return self._token + + class FalClientError(Exception): pass @@ -265,13 +353,18 @@ class AsyncClient: key: str | None = field(default=None, repr=False) default_timeout: float = 120.0 - @cached_property - def _client(self) -> httpx.AsyncClient: + def _get_key(self) -> str: if self.key is None: - key = fetch_credentials() - else: - key = self.key + return fetch_credentials() + return self.key + @cached_property + def _token_manager(self) -> AsyncCDNTokenManager: + return AsyncCDNTokenManager(self._get_key()) + + @cached_property + def _client(self) -> httpx.AsyncClient: + key = self._get_key() return httpx.AsyncClient( headers={ "Authorization": f"Key {key}", @@ -280,6 +373,16 @@ def _client(self) -> httpx.AsyncClient: timeout=self.default_timeout, ) + async def _get_cdn_client(self) -> httpx.AsyncClient: + token = await self._token_manager.get_token() + return httpx.AsyncClient( + headers={ + "Authorization": f"{token.token_type} {token.token}", + "User-Agent": USER_AGENT, + }, + timeout=self.default_timeout, + ) + async def run( self, application: str, @@ -425,14 +528,22 @@ async def stream( async for event in events.aiter_sse(): yield event.json() - async def upload(self, data: str | bytes, content_type: str) -> str: + async def upload( + self, data: str | bytes, content_type: str, file_name: str | None = None + ) -> str: """Upload the given data blob to the CDN and return the access URL. The content type should be specified as the second argument. Use upload_file or upload_image for convenience.""" - response = await self._client.post( + client = await self._get_cdn_client() + + headers = {"Content-Type": content_type} + if file_name is not None: + headers["X-Fal-File-Name"] = file_name + + response = await client.post( CDN_URL + "/files/upload", data=data, - headers={"Content-Type": content_type}, + headers=headers, ) _raise_for_status(response) @@ -446,7 +557,9 @@ async def upload_file(self, path: os.PathLike) -> str: mime_type = "application/octet-stream" with open(path, "rb") as file: - return await self.upload(file.read(), mime_type) + return await self.upload( + file.read(), mime_type, file_name=os.path.basename(path) + ) async def upload_image(self, image: Image.Image, format: str = "jpeg") -> str: """Upload a pillow image object to the CDN and return the access URL.""" @@ -461,12 +574,14 @@ class SyncClient: key: str | None = field(default=None, repr=False) default_timeout: float = 120.0 + def _get_key(self) -> str: + if self.key is None: + return fetch_credentials() + return self.key + @cached_property def _client(self) -> httpx.Client: - if self.key is None: - key = fetch_credentials() - else: - key = self.key + key = self._get_key() return httpx.Client( headers={ "Authorization": f"Key {key}", @@ -475,6 +590,20 @@ def _client(self) -> httpx.Client: timeout=self.default_timeout, ) + @cached_property + def _token_manager(self) -> CDNTokenManager: + return CDNTokenManager(self._get_key()) + + def _get_cdn_client(self) -> httpx.Client: + token = self._token_manager.get_token() + return httpx.Client( + headers={ + "Authorization": f"{token.token_type} {token.token}", + "User-Agent": USER_AGENT, + }, + timeout=self.default_timeout, + ) + def run( self, application: str, @@ -617,14 +746,22 @@ def stream( for event in events.iter_sse(): yield event.json() - def upload(self, data: str | bytes, content_type: str) -> str: + def upload( + self, data: str | bytes, content_type: str, file_name: str | None = None + ) -> str: """Upload the given data blob to the CDN and return the access URL. The content type should be specified as the second argument. Use upload_file or upload_image for convenience.""" - response = self._client.post( + client = self._get_cdn_client() + + headers = {"Content-Type": content_type} + if file_name is not None: + headers["X-Fal-File-Name"] = file_name + + response = client.post( CDN_URL + "/files/upload", data=data, - headers={"Content-Type": content_type}, + headers=headers, ) _raise_for_status(response) @@ -638,7 +775,7 @@ def upload_file(self, path: os.PathLike) -> str: mime_type = "application/octet-stream" with open(path, "rb") as file: - return self.upload(file.read(), mime_type) + return self.upload(file.read(), mime_type, file_name=os.path.basename(path)) def upload_image(self, image: Image.Image, format: str = "jpeg") -> str: """Upload a pillow image object to the CDN and return the access URL."""