Skip to content

Commit

Permalink
fix: use global lifecycle object in save of media
Browse files Browse the repository at this point in the history
  • Loading branch information
chamini2 committed Dec 5, 2024
1 parent 7cb321e commit 1189fbf
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
23 changes: 13 additions & 10 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from fal.api import RouteSignature
from fal.exceptions import FalServerlessException, RequestCancelledException
from fal.logging import get_logger
from fal.toolkit.file import get_lifecycle_preference
from fal.toolkit.file.providers.fal import GLOBAL_LIFECYCLE_PREFERENCE
from fal.toolkit.file import request_lifecycle_repference
from fal.toolkit.file.providers.fal import LIFECYCLE_PREFERENCE

REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
REQUEST_ID_KEY = "x-fal-request-id"
Expand Down Expand Up @@ -342,13 +342,11 @@ async def provide_hints_headers(request, call_next):
@app.middleware("http")
async def set_global_object_preference(request, call_next):
try:
preference_dict = get_lifecycle_preference(request) or {}
expiration_duration = preference_dict.get("expiration_duration_seconds")
if expiration_duration is not None:
GLOBAL_LIFECYCLE_PREFERENCE.expiration_duration_seconds = int(
expiration_duration
)

preference_dict = request_lifecycle_repference(request)
if preference_dict is not None:
# This will not work properly for apps with multiplexing enabled
# we may mix up the preferences between requests
LIFECYCLE_PREFERENCE.set(preference_dict)
except Exception:
from fastapi.logger import logger

Expand All @@ -357,7 +355,12 @@ async def set_global_object_preference(request, call_next):
self.__class__.__name__,
)

return await call_next(request)
try:
return await call_next(request)
finally:
# We may miss the global preference if there are operations
# being done in the background that go beyond the request
LIFECYCLE_PREFERENCE.set(None)

@app.middleware("http")
async def set_request_id(request, call_next):
Expand Down
11 changes: 8 additions & 3 deletions projects/fal/src/fal/toolkit/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pydantic import BaseModel, Field

from fal.toolkit.file.providers.fal import (
LIFECYCLE_PREFERENCE,
FalCDNFileRepository,
FalFileRepository,
FalFileRepositoryV2,
Expand Down Expand Up @@ -149,7 +150,9 @@ def from_bytes(

fdata = FileData(data, content_type, file_name)

object_lifecycle_preference = get_lifecycle_preference(request)
object_lifecycle_preference = (
request_lifecycle_repference(request) or LIFECYCLE_PREFERENCE.get()
)

try:
url = repo.save(fdata, object_lifecycle_preference, **save_kwargs)
Expand Down Expand Up @@ -203,7 +206,9 @@ def from_path(
fallback_save_kwargs = fallback_save_kwargs or {}

content_type = content_type or "application/octet-stream"
object_lifecycle_preference = get_lifecycle_preference(request)
object_lifecycle_preference = (
request_lifecycle_repference(request) or LIFECYCLE_PREFERENCE.get()
)

try:
url, data = repo.save_file(
Expand Down Expand Up @@ -288,7 +293,7 @@ def __del__(self):
shutil.rmtree(self.extract_dir)


def get_lifecycle_preference(request: Optional[Request]) -> dict[str, str] | None:
def request_lifecycle_repference(request: Optional[Request]) -> dict[str, str] | None:
import json

if request is None:
Expand Down
20 changes: 14 additions & 6 deletions projects/fal/src/fal/toolkit/file/providers/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Generic, TypeVar
from urllib.error import HTTPError
from urllib.parse import urlparse, urlunparse
from urllib.request import Request, urlopen
Expand Down Expand Up @@ -104,14 +105,21 @@ class FalV3TokenManager(FalV2TokenManager):
fal_v3_token_manager = FalV3TokenManager()


@dataclass
class ObjectLifecyclePreference:
expiration_duration_seconds: int
VariableType = TypeVar("VariableType")


class VariableReference(Generic[VariableType]):
def __init__(self, value: VariableType) -> None:
self.set(value)

def get(self) -> VariableType:
return self.value

def set(self, value: VariableType) -> None:
self.value = value


GLOBAL_LIFECYCLE_PREFERENCE = ObjectLifecyclePreference(
expiration_duration_seconds=86400
)
LIFECYCLE_PREFERENCE: VariableReference[dict[str, str] | None] = VariableReference(None)


@dataclass
Expand Down

0 comments on commit 1189fbf

Please sign in to comment.