diff --git a/README.md b/README.md index 3d6fb37..2a8f496 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,11 @@ def batched_euclidean_distance(x: Tensor, y: Tensor) -> Tensor: a = torch.rand((10000, 800)) b = torch.rand((12000, 800)) batched_euclidean_distance(a, b) -a = a.cuda() -b = b.cuda() -batched_euclidean_distance(a, b) # Cuda device is synchronized if function arguments are on device. + +if torch.cuda.is_available(): + a = a.cuda() + b = b.cuda() + batched_euclidean_distance(a, b) # Cuda device is synchronized if function arguments are on device. ``` Prints: ``` @@ -49,8 +51,10 @@ batched_euclidean_distance(CudaTensor[10000, 800], CudaTensor[12000, 800]) -> to * `show_kwargs` (`bool`): If `True`, displays the keyword arguments according to `display_level`. Default: `False`. * `display_level` (`int`): The level of verbosity used when printing function arguments ad keyword arguments. If `0`, prints the type of the parameters. If `1`, prints values for all primitive types, shapes for arrays, tensors, dataframes and length for sequences. Otherwise, prints values for all parameters. Default: `1`. * `sep` (`str`): The separator used when printing function arguments and keyword arguments. Default: `', '`. - * `file_path` (`str`): If not `None`, writes the measurement at the end of the given file path. For thread safe file writing configure use `logger_name` instead. Can't be used in conjunction with `logger_name`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`. - * `logger_name` (`str`): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction with `file_path`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`. See [Using a logger](#using-a-logger). + * `stdout` (`bool`): If `True`, writes the elapsed time to stdout. Default: `True`. + * `file_path` (`str`): If not `None`, writes the measurement at the end of the given file path. For thread safe file writing configure use `logger_name` instead. Default: `None`. + * `logger_name` (`str`): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction with `file_path`. Default: `None`. See [Using a logger](#using-a-logger). + * `return_time` (`bool`): If `True`, returns the elapsed time in addition to the wrapped function's return value. Default: `False`. * `out` (`dict`): If not `None`, stores the elapsed time in nanoseconds in the given dict using the function name as key. If the key already exists, adds the time to the existing value. Default: `None`. See [Storing the elapsed time in a dict](#storing-the-elapsed-time-in-a-dict). 2. `nested_timed` is similar to `timed`, however it is designed to work nicely with multiple timed functions that call each other, displaying both the total execution time and the difference after subtracting other timed functions on the same call stack. See [Nested timing decorator](#nested-timing-decorator). @@ -72,7 +76,27 @@ def fibonacci(n: int) -> int: fibonacci(10000) -# fibonacci() -> total time: 2114100ns +# fibonacci() -> total time: 1114100ns +``` + +Getting both the function's return value and the elapsed time. +```py +from timed_decorator.simple_timed import timed + + +@timed(return_time=True) +def fibonacci(n: int) -> int: + assert n > 0 + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + + +value, elapsed = fibonacci(10000) +print(f'10000th fibonacci number has {len(str(value))} digits. Calculating it took {elapsed}ns.') +# fibonacci() -> total time: 1001200ns +# 10000th fibonacci number has 2090 digits. Calculating it took 1001200ns. ``` Set `collect_gc=False` to disable pre-collection of garbage. @@ -91,7 +115,7 @@ def fibonacci(n: int) -> int: fibonacci(10000) -# fibonacci() -> total time: 2062400ns +# fibonacci() -> total time: 1062400ns ``` Using seconds instead of nanoseconds. @@ -114,7 +138,7 @@ def recursive_fibonacci(n: int) -> int: call_recursive_fibonacci(30) -# call_recursive_fibonacci() -> total time: 0.098s +# call_recursive_fibonacci() -> total time: 0.045s ``` Displaying function parameters: @@ -305,7 +329,7 @@ logging.basicConfig() logging.root.setLevel(logging.NOTSET) -@timed(logger_name='TEST_LOGGER') +@timed(logger_name='TEST_LOGGER', stdout=False) def fn(): sleep(1) @@ -333,7 +357,7 @@ logging.root.setLevel(logging.NOTSET) logging.getLogger('TEST_LOGGER').addHandler(log_handler) -@timed(logger_name='TEST_LOGGER') +@timed(logger_name='TEST_LOGGER', stdout=False) def fn(): sleep(1) @@ -357,7 +381,7 @@ from timed_decorator.simple_timed import timed ns = {} -@timed(out=ns) +@timed(out=ns, stdout=False) def fn(): sleep(1) @@ -369,8 +393,6 @@ print(ns) ``` Prints ``` -fn() -> total time: 1000767300ns {'fn': 1000767300} -fn() -> total time: 1000238800ns {'fn': 2001006100} ``` diff --git a/pyproject.toml b/pyproject.toml index fb36986..976f64c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "timed-decorator" -version = "1.2.2" +version = "1.3.0" #requires-python = ">=3.10" requires-python = ">=3.7" description = "A timing decorator for python functions." @@ -15,6 +15,8 @@ maintainers = [ classifiers = [ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", "Programming Language :: Python :: 3", ] diff --git a/tests/test_usage.py b/tests/test_usage.py index 3a7e2a5..a1b502b 100644 --- a/tests/test_usage.py +++ b/tests/test_usage.py @@ -50,22 +50,22 @@ def sleeping_fn(x): @nested_timed(collect_gc=False, use_seconds=True, precision=3) def other_fn(): - sleep(0.5) - sleeping_fn(0.5) + sleep(0.1) + sleeping_fn(0.1) - sleep(1) - sleeping_fn(1) + sleep(0.1) + sleeping_fn(0.1) other_fn() - sleeping_fn(1) + sleeping_fn(0.1) nested_fn() def test_file_usage(self): filename = 'file.txt' - @timed(file_path=filename) + @timed(file_path=filename, stdout=False) def fn(): - sleep(1) + sleep(0.5) try: fn() @@ -86,9 +86,9 @@ def test_logger_usage(self): logging.root.setLevel(logging.NOTSET) logging.getLogger(logger_name).addHandler(log_handler) - @timed(logger_name=logger_name) + @timed(logger_name=logger_name, stdout=False) def fn(): - sleep(1) + sleep(0.5) fn() fn() @@ -101,14 +101,24 @@ def fn(): def test_ns_output(self): ns = {} - @timed(out=ns) + @timed(out=ns, stdout=False) def fn(): - sleep(1) + sleep(0.5) fn() self.assertIsInstance(ns[fn.__name__], int) - self.assertGreater(ns[fn.__name__], 1**9) + self.assertGreater(ns[fn.__name__], 1**9 / 2) + + def test_return_time(self): + @timed(return_time=True, stdout=False) + def fn(): + sleep(0.5) + + _, elapsed = fn() + + self.assertIsInstance(elapsed, int) + self.assertGreater(elapsed, 1**9 / 2) if __name__ == '__main__': diff --git a/timed_decorator/nested_timed.py b/timed_decorator/nested_timed.py index ca238ab..666956e 100644 --- a/timed_decorator/nested_timed.py +++ b/timed_decorator/nested_timed.py @@ -18,8 +18,10 @@ def nested_timed(collect_gc: bool = True, show_kwargs: bool = False, display_level: int = 1, sep: str = ', ', + stdout: bool = True, file_path: Union[str, None] = None, logger_name: Union[str, None] = None, + return_time: bool = False, out: dict = None): """ A nested timing decorator that measures the time elapsed during the function call and accounts for other decorators @@ -40,20 +42,20 @@ def nested_timed(collect_gc: bool = True, prints the type of the parameters. If `1`, prints values for all primitive types, shapes for arrays, tensors, dataframes and length for sequences. Otherwise, prints values for all parameters. Default: `1`. sep (str): The separator used when printing function arguments and keyword arguments. Default: `', '`. + stdout (bool): If `True`, writes the elapsed time to stdout. Default: `True`. file_path (str): If not `None`, writes the measurement at the end of the given file path. For thread safe - file writing configure use `logger_name` instead. Can't be used in conjunction with `logger_name`. If both - `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`. + file writing configure use `logger_name` instead. Default: `None`. logger_name (str): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction - with `file_path`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`. + with `file_path`. Default: `None`. + return_time (bool): If `True`, returns the elapsed time in addition to the wrapped function's return value. + Default: `False`. out (dict): If not `None`, stores the elapsed time in nanoseconds in the given dict using the function name as key. If the key already exists, adds the time to the existing value. Default: `None`. """ - assert file_path is None or logger_name is None - gc_collect = collect if collect_gc else nop time_formatter = TimeFormatter(use_seconds, precision) input_formatter = InputFormatter(show_args, show_kwargs, display_level, sep) - logger = Logger(file_path, logger_name) + logger = Logger(stdout, file_path, logger_name) ns_out = write_mutable if out is not None else nop def decorator(fn): @@ -97,6 +99,8 @@ def wrap(*args, **kwargs): logger('\t' * nested_level + f'{input_formatter(fn.__name__, *args, **kwargs)} ' f'-> total time: {time_formatter(elapsed)}, ' f'own time: {time_formatter(own_time)}') + if return_time: + return ret, elapsed return ret return wrap diff --git a/timed_decorator/simple_timed.py b/timed_decorator/simple_timed.py index a674f60..d026252 100644 --- a/timed_decorator/simple_timed.py +++ b/timed_decorator/simple_timed.py @@ -15,8 +15,10 @@ def timed(collect_gc: bool = True, show_kwargs: bool = False, display_level: int = 1, sep: str = ', ', + stdout: bool = True, file_path: Union[str, None] = None, logger_name: Union[str, None] = None, + return_time: bool = False, out: dict = None): """ A simple timing decorator that measures the time elapsed during the function call and prints it. @@ -36,20 +38,20 @@ def timed(collect_gc: bool = True, prints the type of the parameters. If `1`, prints values for all primitive types, shapes for arrays, tensors, dataframes and length for sequences. Otherwise, prints values for all parameters. Default: `1`. sep (str): The separator used when printing function arguments and keyword arguments. Default: `', '`. + stdout (bool): If `True`, writes the elapsed time to stdout. Default: `True`. file_path (str): If not `None`, writes the measurement at the end of the given file path. For thread safe - file writing configure use `logger_name` instead. Can't be used in conjunction with `logger_name`. If both - `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`. + file writing configure use `logger_name` instead. Default: `None`. logger_name (str): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction - with `file_path`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`. + with `file_path`. Default: `None`. + return_time (bool): If `True`, returns the elapsed time in addition to the wrapped function's return value. + Default: `False`. out (dict): If not `None`, stores the elapsed time in nanoseconds in the given dict using the function name as key. If the key already exists, adds the time to the existing value. Default: `None`. """ - assert file_path is None or logger_name is None - gc_collect = collect if collect_gc else nop time_formatter = TimeFormatter(use_seconds, precision) input_formatter = InputFormatter(show_args, show_kwargs, display_level, sep) - logger = Logger(file_path, logger_name) + logger = Logger(stdout, file_path, logger_name) ns_out = write_mutable if out is not None else nop def decorator(fn): @@ -71,6 +73,8 @@ def wrap(*args, **kwargs): elapsed = end - start ns_out(out, fn.__name__, elapsed) logger(f'{input_formatter(fn.__name__, *args, **kwargs)} -> total time: {time_formatter(elapsed)}') + if return_time: + return ret, elapsed return ret return wrap diff --git a/timed_decorator/utils.py b/timed_decorator/utils.py index 8be0db5..2e2d956 100644 --- a/timed_decorator/utils.py +++ b/timed_decorator/utils.py @@ -50,20 +50,22 @@ def __call__(self, nanoseconds): class Logger: - def __init__(self, file_path: Union[str, None], logger_name: Union[str, None]): - assert file_path is None or logger_name is None - + def __init__(self, stdout: bool, file_path: Union[str, None], logger_name: Union[str, None]): + self.stdout = stdout self.file_path = file_path self.logger_name = logger_name def __call__(self, string: str): + if self.stdout: + print(string) + if self.file_path is not None: with open(self.file_path, 'a') as f: f.write(string + '\n') - elif self.logger_name is not None: + + if self.logger_name is not None: logging.getLogger(self.logger_name).info(string) - else: - print(string) + class InputFormatter: