Skip to content

Commit

Permalink
refactor: introduce is_scalar_type instead
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed May 4, 2024
1 parent 2920b39 commit 89b588a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
18 changes: 9 additions & 9 deletions silverback/recorder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Iterator, get_args
from typing import Any, Iterator

from ape.logging import get_logger
from pydantic import BaseModel, Field
Expand All @@ -11,9 +11,9 @@
INT96_RANGE,
Datapoint,
ScalarDatapoint,
ScalarType,
SilverbackID,
UTCTimestamp,
is_scalar_type,
iso_format,
utc_now,
)
Expand All @@ -40,10 +40,10 @@ class TaskResult(BaseModel):

@classmethod
def _extract_custom_metrics(cls, result: Any, task_name: str) -> dict[str, Datapoint]:
if isinstance(result, Datapoint): # type: ignore[arg-type,misc]
if isinstance(result, Datapoint):
return {"result": result}

elif isinstance(result, get_args(ScalarType)):
elif is_scalar_type(result):
if isinstance(result, int) and not (INT96_RANGE[0] <= result <= INT96_RANGE[1]):
logger.warn("Result integer is out of range suitable for parquet. Ignoring.")
else:
Expand All @@ -56,22 +56,22 @@ def _extract_custom_metrics(cls, result: Any, task_name: str) -> dict[str, Datap
logger.warning(f"Cannot handle return type of '{task_name}': '{type(result)}'.")
return {}

converted_result = {}
converted_results = {}

for metric_name, metric_value in result.items():
if isinstance(metric_value, Datapoint): # type: ignore[arg-type,misc]
converted_result[metric_name] = metric_value
converted_results[metric_name] = metric_value

elif isinstance(metric_value, ScalarType): # type: ignore[arg-type,misc]
converted_result[metric_name] = ScalarDatapoint(data=metric_value)
elif is_scalar_type(metric_value):
converted_results[metric_name] = ScalarDatapoint(data=metric_value)

else:
logger.warning(
f"Cannot handle type of metric '{task_name}.{metric_name}':"
f" '{type(metric_value)}'."
)

return converted_result
return converted_results

@classmethod
def _extract_system_metrics(cls, labels: dict) -> dict:
Expand Down
15 changes: 9 additions & 6 deletions silverback/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum # NOTE: `enum.StrEnum` only in Python 3.11+
from typing import Literal
from typing import Any, Literal

from pydantic import BaseModel, Field
from pydantic.functional_serializers import PlainSerializer
from typing_extensions import Annotated
from typing_extensions import Annotated, get_args


class TaskType(str, Enum):
Expand Down Expand Up @@ -47,14 +47,17 @@ class _BaseDatapoint(BaseModel):
INT96_RANGE = (-(2**95), 2**95 - 1)
Int96 = Annotated[int, Field(ge=INT96_RANGE[0], le=INT96_RANGE[1])]
# NOTE: only these types of data are implicitly converted e.g. `{"something": 1, "else": 0.001}`
PydanticScalarType = bool | Int96 | float | Decimal
# NOTE: Use this for isinstance() comparisons and the like
ScalarType = bool | int | float | Decimal
ScalarType = bool | Int96 | float | Decimal
SCALAR_TYPES = tuple(t.__origin__ if hasattr(t, "__origin__") else t for t in get_args(ScalarType))


def is_scalar_type(val: Any) -> bool:
return isinstance(val, SCALAR_TYPES)


class ScalarDatapoint(_BaseDatapoint):
type: Literal["scalar"] = "scalar"
data: PydanticScalarType
data: ScalarType


# NOTE: Other datapoint types must be explicitly used
Expand Down

0 comments on commit 89b588a

Please sign in to comment.