Skip to content

Commit

Permalink
Add wrapper for deprecating functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616781556
  • Loading branch information
mtthss authored and ChexDev committed Mar 18, 2024
1 parent d2a7137 commit 5a9e7f5
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 1 deletion.
4 changes: 4 additions & 0 deletions chex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from chex._src.variants import params_product
from chex._src.variants import TestCase
from chex._src.variants import variants
from chex._src.warnings import warn_deprecated_function
from chex._src.warnings import warn_keyword_args_only_in_future
from chex._src.warnings import warn_only_n_pos_args_in_future

Expand Down Expand Up @@ -188,6 +189,9 @@
"Shape",
"TestCase",
"variants",
"warn_deprecated_function",
"warn_keyword_args_only_in_future",
"warn_only_n_pos_args_in_future",
"with_jittable_assertions",
)

Expand Down
27 changes: 27 additions & 0 deletions chex/_src/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,30 @@ def wrapper(*args, **kwargs):
warn_keyword_args_only_in_future = functools.partial(
warn_only_n_pos_args_in_future, n=0
)


def warn_deprecated_function(fun, replacement):
"""A decorator to mark a function definition as deprecated.
Example usage:
@warn_deprecated_function(fun, replacement='g')
def f(a, b):
return a + b
Args:
fun: the deprecated function.
replacement: the name of the function to be used instead.
Returns:
the wrapped function.
"""

@functools.wraps(fun)
def new_fun(*args, **kwargs):
warnings.warn(
f'The function {fun.__name__} is deprecated, '
f'please use {replacement} instead.',
category=DeprecationWarning,
stacklevel=2)
return fun(*args, **kwargs)
return new_fun
11 changes: 10 additions & 1 deletion chex/_src/warnings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def f(a, b, c):
return a + b + c


@functools.partial(warnings.warn_deprecated_function, replacement='h')
def g(a, b, c):
return a + b + c


class WarningsTest(absltest.TestCase):

def test_warn_only_n_pos_args_in_future(self):
Expand All @@ -34,6 +39,10 @@ def test_warn_only_n_pos_args_in_future(self):
with self.assertWarns(Warning):
f(1, 2, c=3)

def test_warn_deprecated_function(self):
with self.assertWarns(Warning):
g(1, 2, 3)


if __name__ == "__main__":
if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ Warnings

.. currentmodule:: chex

.. autofunction:: warn_deprecated_function
.. autofunction:: warn_keyword_args_only_in_future
.. autofunction:: warn_only_n_pos_args_in_future

Expand Down

0 comments on commit 5a9e7f5

Please sign in to comment.