diff --git a/src/prefect/results.py b/src/prefect/results.py index 1fe738a762dc..e9089f054032 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import inspect import os import socket import threading import uuid from functools import partial +from operator import methodcaller from pathlib import Path from typing import ( TYPE_CHECKING, @@ -11,10 +14,8 @@ Any, Callable, ClassVar, - Dict, Generic, Optional, - Tuple, TypeVar, Union, ) @@ -38,6 +39,7 @@ emit_result_read_event, emit_result_write_event, ) +from prefect._internal.compatibility.async_dispatch import async_dispatch from prefect.blocks.core import Block from prefect.exceptions import ( ConfigurationError, @@ -58,28 +60,29 @@ from prefect.utilities.asyncutils import sync_compatible if TYPE_CHECKING: + import logging + from prefect import Flow, Task from prefect.transactions import IsolationLevel ResultStorage = Union[WritableFileSystem, str] ResultSerializer = Union[Serializer, str] -LITERAL_TYPES = {type(None), bool, UUID} +LITERAL_TYPES: set[type] = {type(None), bool, UUID} -def DEFAULT_STORAGE_KEY_FN(): +def DEFAULT_STORAGE_KEY_FN() -> str: return uuid.uuid4().hex -logger = get_logger("results") +logger: "logging.Logger" = get_logger("results") P = ParamSpec("P") R = TypeVar("R") -_default_storages: Dict[Tuple[str, str], WritableFileSystem] = {} +_default_storages: dict[tuple[str, str], WritableFileSystem] = {} -@sync_compatible -async def get_default_result_storage() -> WritableFileSystem: +async def aget_default_result_storage() -> WritableFileSystem: """ Generate a default file system for result storage. """ @@ -93,7 +96,7 @@ async def get_default_result_storage() -> WritableFileSystem: return _default_storages[cache_key] if default_block is not None: - storage = await resolve_result_storage(default_block) + storage = await aresolve_result_storage(default_block) else: # Use the local file system storage = LocalFileSystem(basepath=str(basepath)) @@ -102,9 +105,34 @@ async def get_default_result_storage() -> WritableFileSystem: return storage -@sync_compatible -async def resolve_result_storage( - result_storage: Union[ResultStorage, UUID, Path], +@async_dispatch(aget_default_result_storage) +def get_default_result_storage() -> WritableFileSystem: + """ + Generate a default file system for result storage. + """ + settings = get_current_settings() + default_block = settings.results.default_storage_block + basepath = settings.results.local_storage_path + + cache_key = (str(default_block), str(basepath)) + + if cache_key in _default_storages: + return _default_storages[cache_key] + + if default_block is not None: + storage = resolve_result_storage(default_block, _sync=True) + if TYPE_CHECKING: + assert isinstance(storage, WritableFileSystem) + else: + # Use the local file system + storage = LocalFileSystem(basepath=str(basepath)) + + _default_storages[cache_key] = storage + return storage + + +async def aresolve_result_storage( + result_storage: ResultStorage | UUID | Path, ) -> WritableFileSystem: """ Resolve one of the valid `ResultStorage` input types into a saved block @@ -113,23 +141,23 @@ async def resolve_result_storage( from prefect.client.orchestration import get_client client = get_client() + storage_block: WritableFileSystem if isinstance(result_storage, Block): storage_block = result_storage - - if storage_block._block_document_id is not None: - # Avoid saving the block if it already has an identifier assigned - storage_block_id = storage_block._block_document_id - else: - storage_block_id = None elif isinstance(result_storage, Path): storage_block = LocalFileSystem(basepath=str(result_storage)) elif isinstance(result_storage, str): - storage_block = await Block.aload(result_storage, client=client) - storage_block_id = storage_block._block_document_id - assert storage_block_id is not None, "Loaded storage blocks must have ids" - elif isinstance(result_storage, UUID): + block = await Block.aload(result_storage, client=client) + if TYPE_CHECKING: + assert isinstance(block, WritableFileSystem) + storage_block = block + elif isinstance(result_storage, UUID): # pyright: ignore[reportUnnecessaryIsInstance] block_document = await client.read_block_document(result_storage) - storage_block = Block._from_block_document(block_document) + from_block_document = methodcaller("_from_block_document", block_document) + block = from_block_document(Block) + if TYPE_CHECKING: + assert isinstance(block, WritableFileSystem) + storage_block = block else: raise TypeError( "Result storage must be one of the following types: 'UUID', 'Block', " @@ -139,6 +167,42 @@ async def resolve_result_storage( return storage_block +@async_dispatch(aresolve_result_storage) +def resolve_result_storage( + result_storage: ResultStorage | UUID | Path, +) -> WritableFileSystem: + """ + Resolve one of the valid `ResultStorage` input types into a saved block + document id and an instance of the block. + """ + from prefect.client.orchestration import get_client + + client = get_client(sync_client=True) + storage_block: WritableFileSystem + if isinstance(result_storage, Block): + storage_block = result_storage + elif isinstance(result_storage, Path): + storage_block = LocalFileSystem(basepath=str(result_storage)) + elif isinstance(result_storage, str): + block = Block.load(result_storage, _sync=True) + if TYPE_CHECKING: + assert isinstance(block, WritableFileSystem) + storage_block = block + elif isinstance(result_storage, UUID): # pyright: ignore[reportUnnecessaryIsInstance] + block_document = client.read_block_document(result_storage) + from_block_document = methodcaller("_from_block_document", block_document) + block = from_block_document(Block) + if TYPE_CHECKING: + assert isinstance(block, WritableFileSystem) + storage_block = block + else: + raise TypeError( + "Result storage must be one of the following types: 'UUID', 'Block', " + f"'str'. Got unsupported type {type(result_storage).__name__!r}." + ) + return storage_block + + def resolve_serializer(serializer: ResultSerializer) -> Serializer: """ Resolve one of the valid `ResultSerializer` input types into a serializer @@ -146,7 +210,7 @@ def resolve_serializer(serializer: ResultSerializer) -> Serializer: """ if isinstance(serializer, Serializer): return serializer - elif isinstance(serializer, str): + elif isinstance(serializer, str): # pyright: ignore[reportUnnecessaryIsInstance] return Serializer(type=serializer) else: raise TypeError( @@ -163,11 +227,14 @@ async def get_or_create_default_task_scheduling_storage() -> ResultStorage: default_block = settings.tasks.scheduling.default_storage_block if default_block is not None: - return await Block.aload(default_block) + block = await Block.aload(default_block) + if TYPE_CHECKING: + assert isinstance(block, WritableFileSystem) + return block # otherwise, use the local file system basepath = settings.results.local_storage_path - return LocalFileSystem(basepath=basepath) + return LocalFileSystem(basepath=str(basepath)) def get_default_result_serializer() -> Serializer: @@ -225,7 +292,7 @@ def _format_user_supplied_storage_key(key: str) -> str: async def _call_explicitly_async_block_method( - block: Union[WritableFileSystem, NullFileSystem], + block: WritableFileSystem | NullFileSystem, method: str, args: tuple[Any, ...], kwargs: dict[str, Any], @@ -301,13 +368,13 @@ class ResultStore(BaseModel): cache: LRUCache[str, "ResultRecord[Any]"] = Field(default_factory=default_cache) @property - def result_storage_block_id(self) -> Optional[UUID]: + def result_storage_block_id(self) -> UUID | None: if self.result_storage is None: return None - return self.result_storage._block_document_id + return getattr(self.result_storage, "_block_document_id", None) @sync_compatible - async def update_for_flow(self, flow: "Flow") -> Self: + async def update_for_flow(self, flow: "Flow[..., Any]") -> Self: """ Create a new result store for a flow with updated settings. @@ -317,15 +384,16 @@ async def update_for_flow(self, flow: "Flow") -> Self: Returns: An updated result store. """ - update = {} + update: dict[str, Any] = {} + update["cache_result_in_memory"] = flow.cache_result_in_memory if flow.result_storage is not None: - update["result_storage"] = await resolve_result_storage(flow.result_storage) + update["result_storage"] = await aresolve_result_storage( + flow.result_storage + ) if flow.result_serializer is not None: update["serializer"] = resolve_serializer(flow.result_serializer) - if flow.cache_result_in_memory is not None: - update["cache_result_in_memory"] = flow.cache_result_in_memory if self.result_storage is None and update.get("result_storage") is None: - update["result_storage"] = await get_default_result_storage() + update["result_storage"] = await aget_default_result_storage() update["metadata_storage"] = NullFileSystem() return self.model_copy(update=update) @@ -342,13 +410,14 @@ async def update_for_task(self: Self, task: "Task[P, R]") -> Self: """ from prefect.transactions import get_transaction - update = {} + update: dict[str, Any] = {} + update["cache_result_in_memory"] = task.cache_result_in_memory if task.result_storage is not None: - update["result_storage"] = await resolve_result_storage(task.result_storage) + update["result_storage"] = await aresolve_result_storage( + task.result_storage + ) if task.result_serializer is not None: update["serializer"] = resolve_serializer(task.result_serializer) - if task.cache_result_in_memory is not None: - update["cache_result_in_memory"] = task.cache_result_in_memory if task.result_storage_key is not None: update["storage_key_fn"] = partial( _format_user_supplied_storage_key, task.result_storage_key @@ -360,18 +429,20 @@ async def update_for_task(self: Self, task: "Task[P, R]") -> Self: ): update["lock_manager"] = current_txn.store.lock_manager - if task.cache_policy is not None and task.cache_policy is not NotSet: + from prefect.cache_policies import CachePolicy + + if isinstance(task.cache_policy, CachePolicy): if task.cache_policy.key_storage is not None: storage = task.cache_policy.key_storage if isinstance(storage, str) and not len(storage.split("/")) == 2: storage = Path(storage) - update["metadata_storage"] = await resolve_result_storage(storage) + update["metadata_storage"] = await aresolve_result_storage(storage) # if the cache policy has a lock manager, it takes precedence over the parent transaction if task.cache_policy.lock_manager is not None: update["lock_manager"] = task.cache_policy.lock_manager if self.result_storage is None and update.get("result_storage") is None: - update["result_storage"] = await get_default_result_storage() + update["result_storage"] = await aget_default_result_storage() if ( isinstance(self.metadata_storage, NullFileSystem) and update.get("metadata_storage", NotSet) is NotSet @@ -424,7 +495,7 @@ async def _exists(self, key: str) -> bool: ) if content is None: return False - record = ResultRecord.deserialize(content) + record: ResultRecord[Any] = ResultRecord.deserialize(content) metadata = record.metadata except Exception: return False @@ -462,10 +533,10 @@ async def aexists(self, key: str) -> bool: return await self._exists(key=key, _sync=False) def _resolved_key_path(self, key: str) -> str: - if self.result_storage_block_id is None and hasattr( - self.result_storage, "_resolve_path" + if self.result_storage_block_id is None and ( + _resolve_path := getattr(self.result_storage, "_resolve_path", None) ): - return str(self.result_storage._resolve_path(key)) + return str(_resolve_path(key)) return key @sync_compatible @@ -490,12 +561,12 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]": resolved_key_path = self._resolved_key_path(key) if resolved_key_path in self.cache: - cached_result = self.cache[resolved_key_path] + cached_result: ResultRecord[Any] = self.cache[resolved_key_path] await emit_result_read_event(self, resolved_key_path, cached=True) return cached_result if self.result_storage is None: - self.result_storage = await get_default_result_storage() + self.result_storage = await aget_default_result_storage() if self.metadata_storage is not None: metadata_content = await _call_explicitly_async_block_method( @@ -539,7 +610,7 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]": def read( self, key: str, - holder: Optional[str] = None, + holder: str | None = None, ) -> "ResultRecord[Any]": """ Read a result record from storage. @@ -557,7 +628,7 @@ def read( async def aread( self, key: str, - holder: Optional[str] = None, + holder: str | None = None, ) -> "ResultRecord[Any]": """ Read a result record from storage. @@ -575,8 +646,8 @@ async def aread( def create_result_record( self, obj: Any, - key: Optional[str] = None, - expiration: Optional[DateTime] = None, + key: str | None = None, + expiration: DateTime | None = None, ) -> "ResultRecord[Any]": """ Create a result record. @@ -590,10 +661,12 @@ def create_result_record( if self.result_storage is None: self.result_storage = get_default_result_storage(_sync=True) + if TYPE_CHECKING: + assert isinstance(self.result_storage, WritableFileSystem) if self.result_storage_block_id is None: - if hasattr(self.result_storage, "_resolve_path"): - key = str(self.result_storage._resolve_path(key)) + if _resolve_path := getattr(self.result_storage, "_resolve_path", None): + key = str(_resolve_path(key)) return ResultRecord( result=obj, @@ -608,10 +681,10 @@ def create_result_record( def write( self, obj: Any, - key: Optional[str] = None, - expiration: Optional[DateTime] = None, - holder: Optional[str] = None, - ): + key: str | None = None, + expiration: DateTime | None = None, + holder: str | None = None, + ) -> None: """ Write a result to storage. @@ -635,10 +708,10 @@ def write( async def awrite( self, obj: Any, - key: Optional[str] = None, - expiration: Optional[DateTime] = None, - holder: Optional[str] = None, - ): + key: str | None = None, + expiration: DateTime | None = None, + holder: str | None = None, + ) -> None: """ Write a result to storage. @@ -657,7 +730,9 @@ async def awrite( ) @sync_compatible - async def _persist_result_record(self, result_record: "ResultRecord", holder: str): + async def _persist_result_record( + self, result_record: "ResultRecord[Any]", holder: str + ) -> None: """ Persist a result record to storage. @@ -672,8 +747,10 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st key = result_record.metadata.storage_key if result_record.metadata.storage_block_id is None: basepath = ( - self.result_storage._resolve_path("") - if hasattr(self.result_storage, "_resolve_path") + _resolve_path("") + if ( + _resolve_path := getattr(self.result_storage, "_resolve_path", None) + ) else Path(".").resolve() ) base_key = str(Path(key).relative_to(basepath)) @@ -689,7 +766,7 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st f"another holder." ) if self.result_storage is None: - self.result_storage = await get_default_result_storage() + self.result_storage = await aget_default_result_storage() # If metadata storage is configured, write result and metadata separately if self.metadata_storage is not None: @@ -719,8 +796,8 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st self.cache[key] = result_record def persist_result_record( - self, result_record: "ResultRecord", holder: Optional[str] = None - ): + self, result_record: "ResultRecord[Any]", holder: str | None = None + ) -> None: """ Persist a result record to storage. @@ -733,8 +810,8 @@ def persist_result_record( ) async def apersist_result_record( - self, result_record: "ResultRecord", holder: Optional[str] = None - ): + self, result_record: "ResultRecord[Any]", holder: str | None = None + ) -> None: """ Persist a result record to storage. @@ -766,7 +843,7 @@ def supports_isolation_level(self, level: "IsolationLevel") -> bool: raise ValueError(f"Unsupported isolation level: {level}") def acquire_lock( - self, key: str, holder: Optional[str] = None, timeout: Optional[float] = None + self, key: str, holder: str | None = None, timeout: float | None = None ) -> bool: """ Acquire a lock for a result record. @@ -789,7 +866,7 @@ def acquire_lock( return self.lock_manager.acquire_lock(key, holder, timeout) async def aacquire_lock( - self, key: str, holder: Optional[str] = None, timeout: Optional[float] = None + self, key: str, holder: str | None = None, timeout: float | None = None ) -> bool: """ Acquire a lock for a result record. @@ -812,7 +889,7 @@ async def aacquire_lock( return await self.lock_manager.aacquire_lock(key, holder, timeout) - def release_lock(self, key: str, holder: Optional[str] = None): + def release_lock(self, key: str, holder: str | None = None) -> None: """ Release a lock for a result record. @@ -841,7 +918,7 @@ def is_locked(self, key: str) -> bool: ) return self.lock_manager.is_locked(key) - def is_lock_holder(self, key: str, holder: Optional[str] = None) -> bool: + def is_lock_holder(self, key: str, holder: str | None = None) -> bool: """ Check if the current holder is the lock holder for the result record. @@ -861,7 +938,7 @@ def is_lock_holder(self, key: str, holder: Optional[str] = None) -> bool: ) return self.lock_manager.is_lock_holder(key, holder) - def wait_for_lock(self, key: str, timeout: Optional[float] = None) -> bool: + def wait_for_lock(self, key: str, timeout: float | None = None) -> bool: """ Wait for the corresponding transaction record to become free. """ @@ -872,7 +949,7 @@ def wait_for_lock(self, key: str, timeout: Optional[float] = None) -> bool: ) return self.lock_manager.wait_for_lock(key, timeout) - async def await_for_lock(self, key: str, timeout: Optional[float] = None) -> bool: + async def await_for_lock(self, key: str, timeout: float | None = None) -> bool: """ Wait for the corresponding transaction record to become free. """ @@ -886,13 +963,14 @@ async def await_for_lock(self, key: str, timeout: Optional[float] = None) -> boo # TODO: These two methods need to find a new home @sync_compatible - async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]): + async def store_parameters(self, identifier: UUID, parameters: dict[str, Any]): record = ResultRecord( result=parameters, metadata=ResultRecordMetadata( serializer=self.serializer, storage_key=str(identifier) ), ) + await _call_explicitly_async_block_method( self.result_storage, "write_path", @@ -906,7 +984,7 @@ async def read_parameters(self, identifier: UUID) -> dict[str, Any]: raise ValueError( "Result store is not configured - must have a result storage block to read parameters" ) - record = ResultRecord.deserialize( + record: ResultRecord[Any] = ResultRecord.deserialize( await _call_explicitly_async_block_method( self.result_storage, "read_path", @@ -988,7 +1066,7 @@ class ResultRecord(BaseModel, Generic[R]): result: R @property - def expiration(self) -> Optional[DateTime]: + def expiration(self) -> DateTime | None: return self.metadata.expiration @property @@ -1012,7 +1090,7 @@ def serialize_result(self) -> bytes: and str(exc).startswith("cannot pickle") ): try: - from IPython import get_ipython + from IPython.core.getipython import get_ipython if get_ipython() is not None: extra_info = inspect.cleandoc( @@ -1041,7 +1119,7 @@ def serialize_result(self) -> bytes: @model_validator(mode="before") @classmethod - def coerce_old_format(cls, value: Any) -> Any: + def coerce_old_format(cls, value: dict[str, Any] | Any) -> dict[str, Any]: if isinstance(value, dict): if "data" in value: value["result"] = value.pop("data") @@ -1071,12 +1149,14 @@ async def _from_metadata(cls, metadata: ResultRecordMetadata) -> "ResultRecord[R if metadata.storage_block_id is None: storage_block = None else: - storage_block = await resolve_result_storage( - metadata.storage_block_id, _sync=False - ) + storage_block = await aresolve_result_storage(metadata.storage_block_id) store = ResultStore( result_storage=storage_block, serializer=metadata.serializer ) + if metadata.storage_key is None: + raise ValueError( + "storage_key is required to hydrate a result record from metadata" + ) result = await store.aread(metadata.storage_key) return result @@ -1101,7 +1181,7 @@ def serialize( @classmethod def deserialize( - cls, data: bytes, backup_serializer: Optional[Serializer] = None + cls, data: bytes, backup_serializer: Serializer | None = None ) -> "ResultRecord[R]": """ Deserialize a record from bytes. @@ -1151,7 +1231,7 @@ def deserialize_from_result_and_metadata( result=result_record_metadata.serializer.loads(result), ) - def __eq__(self, other): + def __eq__(self, other: Any | "ResultRecord[Any]") -> bool: if not isinstance(other, ResultRecord): return False return self.metadata == other.metadata and self.result == other.result