From 89b588aac11f4086fa42f7c9d79ece78b648eb19 Mon Sep 17 00:00:00 2001 From: fubuloubu <3859395+fubuloubu@users.noreply.github.com> Date: Sat, 4 May 2024 12:19:08 -0400 Subject: [PATCH] refactor: introduce `is_scalar_type` instead --- silverback/recorder.py | 18 +++++++++--------- silverback/types.py | 15 +++++++++------ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/silverback/recorder.py b/silverback/recorder.py index 60a16d11..e5e074da 100644 --- a/silverback/recorder.py +++ b/silverback/recorder.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Iterator, get_args +from typing import Any, Iterator from ape.logging import get_logger from pydantic import BaseModel, Field @@ -11,9 +11,9 @@ INT96_RANGE, Datapoint, ScalarDatapoint, - ScalarType, SilverbackID, UTCTimestamp, + is_scalar_type, iso_format, utc_now, ) @@ -40,10 +40,10 @@ class TaskResult(BaseModel): @classmethod def _extract_custom_metrics(cls, result: Any, task_name: str) -> dict[str, Datapoint]: - if isinstance(result, Datapoint): # type: ignore[arg-type,misc] + if isinstance(result, Datapoint): return {"result": result} - elif isinstance(result, get_args(ScalarType)): + elif is_scalar_type(result): if isinstance(result, int) and not (INT96_RANGE[0] <= result <= INT96_RANGE[1]): logger.warn("Result integer is out of range suitable for parquet. Ignoring.") else: @@ -56,14 +56,14 @@ def _extract_custom_metrics(cls, result: Any, task_name: str) -> dict[str, Datap logger.warning(f"Cannot handle return type of '{task_name}': '{type(result)}'.") return {} - converted_result = {} + converted_results = {} for metric_name, metric_value in result.items(): if isinstance(metric_value, Datapoint): # type: ignore[arg-type,misc] - converted_result[metric_name] = metric_value + converted_results[metric_name] = metric_value - elif isinstance(metric_value, ScalarType): # type: ignore[arg-type,misc] - converted_result[metric_name] = ScalarDatapoint(data=metric_value) + elif is_scalar_type(metric_value): + converted_results[metric_name] = ScalarDatapoint(data=metric_value) else: logger.warning( @@ -71,7 +71,7 @@ def _extract_custom_metrics(cls, result: Any, task_name: str) -> dict[str, Datap f" '{type(metric_value)}'." ) - return converted_result + return converted_results @classmethod def _extract_system_metrics(cls, labels: dict) -> dict: diff --git a/silverback/types.py b/silverback/types.py index 6d1f203d..6c4da544 100644 --- a/silverback/types.py +++ b/silverback/types.py @@ -1,11 +1,11 @@ from datetime import datetime, timezone from decimal import Decimal from enum import Enum # NOTE: `enum.StrEnum` only in Python 3.11+ -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, Field from pydantic.functional_serializers import PlainSerializer -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args class TaskType(str, Enum): @@ -47,14 +47,17 @@ class _BaseDatapoint(BaseModel): INT96_RANGE = (-(2**95), 2**95 - 1) Int96 = Annotated[int, Field(ge=INT96_RANGE[0], le=INT96_RANGE[1])] # NOTE: only these types of data are implicitly converted e.g. `{"something": 1, "else": 0.001}` -PydanticScalarType = bool | Int96 | float | Decimal -# NOTE: Use this for isinstance() comparisons and the like -ScalarType = bool | int | float | Decimal +ScalarType = bool | Int96 | float | Decimal +SCALAR_TYPES = tuple(t.__origin__ if hasattr(t, "__origin__") else t for t in get_args(ScalarType)) + + +def is_scalar_type(val: Any) -> bool: + return isinstance(val, SCALAR_TYPES) class ScalarDatapoint(_BaseDatapoint): type: Literal["scalar"] = "scalar" - data: PydanticScalarType + data: ScalarType # NOTE: Other datapoint types must be explicitly used