diff --git a/temporalio/activity.py b/temporalio/activity.py index 281cfcb8..4c5d917f 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -12,12 +12,14 @@ import asyncio import contextvars import dataclasses +import enum import inspect import logging import threading from contextlib import AbstractContextManager, contextmanager from dataclasses import dataclass from datetime import datetime, timedelta +import types from typing import ( Any, Callable, @@ -480,6 +482,8 @@ def base_logger(self) -> logging.Logger: """Logger that will have contextual activity details embedded.""" +temporal_activity_def_key = "__temporal_activity_definition" + @dataclass(frozen=True) class _Definition: name: Optional[str] @@ -492,7 +496,7 @@ class _Definition: @staticmethod def from_callable(fn: Callable) -> Optional[_Definition]: - defn = getattr(fn, "__temporal_activity_definition", None) + defn = getattr(fn, temporal_activity_def_key, None) if isinstance(defn, _Definition): # We have to replace the function with the given callable here # because the one passed in may be a method or some other partial @@ -519,7 +523,7 @@ def _apply_to_callable( no_thread_cancel_exception: bool = False, ) -> None: # Validate the activity - if hasattr(fn, "__temporal_activity_definition"): + if hasattr(fn, temporal_activity_def_key): raise ValueError("Function already contains activity definition") elif not callable(fn): raise TypeError("Activity is not callable") @@ -530,7 +534,7 @@ def _apply_to_callable( raise TypeError("Activity cannot have keyword-only arguments") setattr( fn, - "__temporal_activity_definition", + temporal_activity_def_key, _Definition( name=activity_name, fn=fn, @@ -557,3 +561,50 @@ def __post_init__(self) -> None: ) object.__setattr__(self, "arg_types", arg_types) object.__setattr__(self, "ret_type", ret_type) + + +def get_activities(module: types.ModuleType) -> list[Callable]: + activities = [] + for name, member in inspect.getmembers(module): + if inspect.isfunction(member) and hasattr(member, temporal_activity_def_key): + activities.append(getattr(module, name)) + return activities + + +class ActivitiesProvider: + class __MethodType(enum.Enum): + CLASS_METHOD = (True, False) + STATIC_METHOD = (False, True) + INSTANCE_METHOD = (False, False) + + @classmethod + def __get_activities( + cls, + instance: Union[ + type["ActivitiesProvider"], "ActivitiesProvider" + ], + ) -> list[Callable]: + throw_exception_for_instance_method = isinstance(instance, type) + activities = [] + for name, member in inspect.getmembers(cls): + is_method_or_fn = inspect.isfunction(member) or inspect.ismethod(member) + is_activity = hasattr(member, temporal_activity_def_key) + if not (is_method_or_fn and is_activity): + continue + is_classmethod = isinstance(inspect.getattr_static(cls, name), classmethod) + is_staticmethod = isinstance(inspect.getattr_static(cls, name), staticmethod) + method_type = cls.__MethodType((is_classmethod, is_staticmethod)) + if method_type is cls.__MethodType.INSTANCE_METHOD and throw_exception_for_instance_method: + raise ValueError( + f"Class {cls.__name__} method {name} is an activity, but it is an instance method. " + "Because of that, you cannot gather activities from the class, you must get them from " + "an instance using instance.get_activities_from_instance()") + activities.append(getattr(instance, name)) + return activities + + def get_activities_from_instance(self) -> list[Callable]: + return self.__get_activities(self) + + @classmethod + def get_activities_from_cls(cls) -> list[Callable]: + return cls.__get_activities(cls) \ No newline at end of file diff --git a/tests/testing/test_activity.py b/tests/testing/test_activity.py index 29b66c77..5d054658 100644 --- a/tests/testing/test_activity.py +++ b/tests/testing/test_activity.py @@ -1,4 +1,5 @@ import asyncio +import sys import threading import time from contextvars import copy_context @@ -110,3 +111,69 @@ async def assert_equals(a: str, b: str) -> None: assert type(expected_err) == type(actual_err) assert str(expected_err) == str(actual_err) + + +def test_get_activities_from_cls(): + class ClassAndStaticActivities(activity.ActivitiesProvider): + @classmethod + @activity.defn + async def class_method_activity(cls): + pass + + @staticmethod + @activity.defn + async def static_method_activity(): + pass + + assert ClassAndStaticActivities.get_activities_from_cls() == [ + ClassAndStaticActivities.class_method_activity, + ClassAndStaticActivities.static_method_activity, + ] + +class _AllActivityMethodTypes(activity.ActivitiesProvider): + @activity.defn + async def instance_method_activity(self): + pass + + @classmethod + @activity.defn + async def class_method_activity(cls): + pass + + @staticmethod + @activity.defn + async def static_method_activity(): + pass + +def test_get_activities_from_cls_error(): + try: + _AllActivityMethodTypes.get_activities_from_cls() + raise Exception("above call should have thrown value error") + except ValueError as ex: + assert str(ex) == (f"Class _AllActivityMethodTypes method instance_method_activity is an activity, but it is an instance method. " + "Because of that, you cannot gather activities from the class, you must get them from " + "an instance using instance.get_activities_from_instance()" + ) + +def test_get_activities_from_instance(): + inst = _AllActivityMethodTypes() + assert inst.get_activities_from_instance() == [ + inst.class_method_activity, + inst.instance_method_activity, + inst.static_method_activity, + ] + +@activity.defn +def _some_activity(): + pass + +@activity.defn +async def _some_async_activity(): + pass + +def test_get_activities(): + current_module = sys.modules[__name__] + assert activity.get_activities(current_module) == [ + _some_activity, + _some_async_activity + ]