From 1cc0193561c2703e3cad525b0ce308a9641788a2 Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 12 Oct 2024 22:04:00 +0900 Subject: [PATCH 01/10] Add deprecate decorator --- src/accelerate/utils/__init__.py | 1 + src/accelerate/utils/deprecation.py | 87 +++++++++++++++++++++++++++++ src/accelerate/utils/modeling.py | 2 + tests/test_utils.py | 62 ++++++++++++++++++++ 4 files changed, 152 insertions(+) create mode 100644 src/accelerate/utils/deprecation.py diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 403953de622..d4f3d969788 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -240,6 +240,7 @@ from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler +from .deprecation import deprecated from .memory import find_executable_batch_size, release_memory from .other import ( check_os_kernel, diff --git a/src/accelerate/utils/deprecation.py b/src/accelerate/utils/deprecation.py new file mode 100644 index 00000000000..d4321119f08 --- /dev/null +++ b/src/accelerate/utils/deprecation.py @@ -0,0 +1,87 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import textwrap +import warnings +from typing import Callable, TypeVar + +from typing_extensions import ParamSpec + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """Marks functions as deprecated. + + It will result in a warning when the function is called and a note in the docstring. + + Args: + since (`str`): + The version when the function was first deprecated. + removed_in (`str`): + The version when the function will be removed. + instructions (`str`): + The action users should take. + + Returns: + `Callable`: A decorator that will mark the function as deprecated. + """ + + def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(function) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + warnings.warn( + f"'{function.__module__}.{function.__name__}' " + f"is deprecated in version {since} and will be " + f"removed in {removed_in}. Please {instructions}.", + category=FutureWarning, + stacklevel=2, + ) + return function(*args, **kwargs) + + # Add a deprecation note to the docstring. + docstring = function.__doc__ or "" + + # Add a note to the docstring. + deprecation_note = textwrap.dedent( + f"""\ + .. deprecated:: {since} + Deprecated and will be removed in version {removed_in}. Please {instructions}. + """ + ) + + # Split docstring at first occurrence of newline + summary_and_body = docstring.split("\n\n", 1) + + if len(summary_and_body) > 1: + summary, body = summary_and_body + + # Dedent the body. We cannot do this with the presence of the summary because + # the body contains leading whitespaces when the summary does not. + body = textwrap.dedent(body) + + new_docstring_parts = [deprecation_note, "\n\n", summary, body] + else: + summary = summary_and_body[0] + + new_docstring_parts = [deprecation_note, "\n\n", summary] + + wrapper.__doc__ = "".join(new_docstring_parts) + + return wrapper + + return decorator diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 181694b633c..f7cc6a2b8a8 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -31,6 +31,7 @@ from ..state import AcceleratorState from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME from .dataclasses import AutocastKwargs, CustomDtype, DistributedType +from .deprecation import deprecated from .imports import ( is_mlu_available, is_mps_available, @@ -471,6 +472,7 @@ class FindTiedParametersResult(list): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @deprecated(since="1.0.0rc0", removed_in="1.3.0", instructions="use another method instead") def values(self): warnings.warn( "The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ", diff --git a/tests/test_utils.py b/tests/test_utils.py index cabdb55a1a6..8dd033c2594 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,6 +14,7 @@ import os import pickle import tempfile +import textwrap import unittest import warnings from collections import UserDict, namedtuple @@ -54,6 +55,7 @@ save, send_to_device, ) +from accelerate.utils.deprecation import deprecated from accelerate.utils.operations import is_namedtuple @@ -413,3 +415,63 @@ def test_convert_dict_to_env_variables(self): with self.assertLogs("accelerate.utils.environment", level="WARNING"): valid_env_items = convert_dict_to_env_variables(env) assert valid_env_items == ["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"] + + def test_deprecated(self): + @deprecated("0.2.0", "0.3.0", "toy instruction") + def deprecated_demo(arg1: int, arg2: int) -> tuple: + """This is a long summary. This is a long summary. This is a long + summary. This is a long summary. + + Args: + arg1 (int): Long description with a line break. Long description + with a line break. + arg2 (int): short description. + + Returns: + Long description without a line break. Long description without + a line break. + """ + return arg1, arg2 + + with pytest.warns( + FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. Please toy instruction." + ): + self.assertEqual((1, 2), deprecated_demo(1, 2)) + + # Clean up docstring for comparison + expected_docstring = textwrap.dedent(""" + .. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. Please toy instruction. + + This is a long summary. This is a long summary. This is a long + summary. This is a long summary. + + Args: + arg1 (int): Long description with a line break. Long description + with a line break. + arg2 (int): short description. + + Returns: + Long description without a line break. Long description without + a line break. + """) + # Remove all extra whitespace for comparison + expected_docstring = "".join(expected_docstring.split()) + actual_docstring = "".join(deprecated_demo.__doc__.split()) + + self.assertEqual(expected_docstring, actual_docstring) + + @deprecated("0.2.0", "0.3.0", "toy instruction") + def deprecated_demo1(): + """Short summary.""" + + expected_docstring1 = textwrap.dedent(""" + .. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. Please toy instruction. + + Short summary. + """) + expected_docstring1 = "".join(expected_docstring1.split()) + actual_docstring1 = "".join(deprecated_demo1.__doc__.split()) + + self.assertEqual(expected_docstring1, actual_docstring1) From eaa0ab2e9978072070fad681f35f29017e3db4f1 Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 12 Oct 2024 22:12:55 +0900 Subject: [PATCH 02/10] Fix minor --- tests/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8dd033c2594..103ebd7abbc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -438,7 +438,6 @@ def deprecated_demo(arg1: int, arg2: int) -> tuple: ): self.assertEqual((1, 2), deprecated_demo(1, 2)) - # Clean up docstring for comparison expected_docstring = textwrap.dedent(""" .. deprecated:: 0.2.0 Deprecated and will be removed in version 0.3.0. Please toy instruction. @@ -455,7 +454,7 @@ def deprecated_demo(arg1: int, arg2: int) -> tuple: Long description without a line break. Long description without a line break. """) - # Remove all extra whitespace for comparison + expected_docstring = "".join(expected_docstring.split()) actual_docstring = "".join(deprecated_demo.__doc__.split()) From 1c5b43e5c1499a28ed2f53b6fc726d7875bbcd69 Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 12 Oct 2024 22:17:55 +0900 Subject: [PATCH 03/10] Refactor test var name --- tests/test_utils.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 103ebd7abbc..b0b256d1e9c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -418,7 +418,7 @@ def test_convert_dict_to_env_variables(self): def test_deprecated(self): @deprecated("0.2.0", "0.3.0", "toy instruction") - def deprecated_demo(arg1: int, arg2: int) -> tuple: + def long_deprecated_demo(arg1: int, arg2: int) -> tuple: """This is a long summary. This is a long summary. This is a long summary. This is a long summary. @@ -436,9 +436,9 @@ def deprecated_demo(arg1: int, arg2: int) -> tuple: with pytest.warns( FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. Please toy instruction." ): - self.assertEqual((1, 2), deprecated_demo(1, 2)) + self.assertEqual((1, 2), long_deprecated_demo(1, 2)) - expected_docstring = textwrap.dedent(""" + long_expected_docstring = textwrap.dedent(""" .. deprecated:: 0.2.0 Deprecated and will be removed in version 0.3.0. Please toy instruction. @@ -455,22 +455,22 @@ def deprecated_demo(arg1: int, arg2: int) -> tuple: a line break. """) - expected_docstring = "".join(expected_docstring.split()) - actual_docstring = "".join(deprecated_demo.__doc__.split()) + long_expected_docstring = "".join(long_expected_docstring.split()) + long_actual_docstring = "".join(long_deprecated_demo.__doc__.split()) - self.assertEqual(expected_docstring, actual_docstring) + self.assertEqual(long_expected_docstring, long_actual_docstring) @deprecated("0.2.0", "0.3.0", "toy instruction") - def deprecated_demo1(): + def short_deprecated_demo(): """Short summary.""" - expected_docstring1 = textwrap.dedent(""" + short_expected_docstring = textwrap.dedent(""" .. deprecated:: 0.2.0 Deprecated and will be removed in version 0.3.0. Please toy instruction. Short summary. """) - expected_docstring1 = "".join(expected_docstring1.split()) - actual_docstring1 = "".join(deprecated_demo1.__doc__.split()) + short_expected_docstring = "".join(short_expected_docstring.split()) + short_actual_docstring = "".join(short_deprecated_demo.__doc__.split()) - self.assertEqual(expected_docstring1, actual_docstring1) + self.assertEqual(short_expected_docstring, short_actual_docstring) From 174425c82f23ee68024ae085255aed06b29b43fb Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 12 Oct 2024 22:24:27 +0900 Subject: [PATCH 04/10] Fix minor --- tests/test_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index b0b256d1e9c..e4738064d35 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -423,13 +423,11 @@ def long_deprecated_demo(arg1: int, arg2: int) -> tuple: summary. This is a long summary. Args: - arg1 (int): Long description with a line break. Long description - with a line break. - arg2 (int): short description. + arg1 (int): Description. + arg2 (int): Description. Returns: - Long description without a line break. Long description without - a line break. + Description. """ return arg1, arg2 @@ -446,13 +444,11 @@ def long_deprecated_demo(arg1: int, arg2: int) -> tuple: summary. This is a long summary. Args: - arg1 (int): Long description with a line break. Long description - with a line break. - arg2 (int): short description. + arg1 (int): Description. + arg2 (int): Description. Returns: - Long description without a line break. Long description without - a line break. + Description. """) long_expected_docstring = "".join(long_expected_docstring.split()) From 558e11eac8889bc88275d5132c10bfc2e2639f95 Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 12 Oct 2024 22:28:05 +0900 Subject: [PATCH 05/10] Del cmt --- src/accelerate/utils/deprecation.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/accelerate/utils/deprecation.py b/src/accelerate/utils/deprecation.py index d4321119f08..102ffad937a 100644 --- a/src/accelerate/utils/deprecation.py +++ b/src/accelerate/utils/deprecation.py @@ -56,7 +56,6 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: # Add a deprecation note to the docstring. docstring = function.__doc__ or "" - # Add a note to the docstring. deprecation_note = textwrap.dedent( f"""\ .. deprecated:: {since} @@ -66,18 +65,12 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: # Split docstring at first occurrence of newline summary_and_body = docstring.split("\n\n", 1) - if len(summary_and_body) > 1: summary, body = summary_and_body - - # Dedent the body. We cannot do this with the presence of the summary because - # the body contains leading whitespaces when the summary does not. body = textwrap.dedent(body) - new_docstring_parts = [deprecation_note, "\n\n", summary, body] else: summary = summary_and_body[0] - new_docstring_parts = [deprecation_note, "\n\n", summary] wrapper.__doc__ = "".join(new_docstring_parts) From c0fb0abed8ee63a829596eb54030256dea6917ef Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 12 Oct 2024 22:30:44 +0900 Subject: [PATCH 06/10] Del prev warninigs --- src/accelerate/utils/modeling.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index f7cc6a2b8a8..8dd55160dab 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -474,10 +474,6 @@ def __init__(self, *args, **kwargs): @deprecated(since="1.0.0rc0", removed_in="1.3.0", instructions="use another method instead") def values(self): - warnings.warn( - "The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ", - FutureWarning, - ) return sum([x[1:] for x in self], []) From 835c9c9815cb5da8bb4997127071ee483adc4021 Mon Sep 17 00:00:00 2001 From: yhna940 Date: Thu, 31 Oct 2024 22:27:15 +0900 Subject: [PATCH 07/10] Apply style --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index ab006a5170c..811dade7a19 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -473,7 +473,7 @@ def short_deprecated_demo(): short_actual_docstring = "".join(short_deprecated_demo.__doc__.split()) self.assertEqual(short_expected_docstring, short_actual_docstring) - + def test_has_offloaded_params(self): model = RegressionModel() assert not has_offloaded_params(model) From a7a43aa2f2268885b5db25a1324b04b59921ec65 Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 2 Nov 2024 10:33:55 +0900 Subject: [PATCH 08/10] Rename instruction arg --- src/accelerate/utils/deprecation.py | 8 ++++---- src/accelerate/utils/modeling.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/deprecation.py b/src/accelerate/utils/deprecation.py index 102ffad937a..4dba318e44e 100644 --- a/src/accelerate/utils/deprecation.py +++ b/src/accelerate/utils/deprecation.py @@ -24,7 +24,7 @@ _P = ParamSpec("_P") -def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: +def deprecated(since: str, removed_in: str, instruction: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: """Marks functions as deprecated. It will result in a warning when the function is called and a note in the docstring. @@ -34,7 +34,7 @@ def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[Call The version when the function was first deprecated. removed_in (`str`): The version when the function will be removed. - instructions (`str`): + instruction (`str`): The action users should take. Returns: @@ -47,7 +47,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: warnings.warn( f"'{function.__module__}.{function.__name__}' " f"is deprecated in version {since} and will be " - f"removed in {removed_in}. Please {instructions}.", + f"removed in {removed_in}. Please {instruction}.", category=FutureWarning, stacklevel=2, ) @@ -59,7 +59,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: deprecation_note = textwrap.dedent( f"""\ .. deprecated:: {since} - Deprecated and will be removed in version {removed_in}. Please {instructions}. + Deprecated and will be removed in version {removed_in}. Please {instruction}. """ ) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 6360d624632..eff3966b0a3 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -472,7 +472,7 @@ class FindTiedParametersResult(list): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - @deprecated(since="1.0.0rc0", removed_in="1.3.0", instructions="use another method instead") + @deprecated(since="1.0.0rc0", removed_in="1.3.0", instruction="use another method instead") def values(self): return sum([x[1:] for x in self], []) From 2cb46485f9bfda5c6108978fe786eeda3227ef0f Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 2 Nov 2024 10:36:52 +0900 Subject: [PATCH 09/10] Del pls msg --- src/accelerate/utils/deprecation.py | 4 ++-- tests/test_utils.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/deprecation.py b/src/accelerate/utils/deprecation.py index 4dba318e44e..9f70ef23b78 100644 --- a/src/accelerate/utils/deprecation.py +++ b/src/accelerate/utils/deprecation.py @@ -47,7 +47,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: warnings.warn( f"'{function.__module__}.{function.__name__}' " f"is deprecated in version {since} and will be " - f"removed in {removed_in}. Please {instruction}.", + f"removed in {removed_in}. {instruction}.", category=FutureWarning, stacklevel=2, ) @@ -59,7 +59,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: deprecation_note = textwrap.dedent( f"""\ .. deprecated:: {since} - Deprecated and will be removed in version {removed_in}. Please {instruction}. + Deprecated and will be removed in version {removed_in}. {instruction}. """ ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 811dade7a19..61bf477f335 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -435,13 +435,13 @@ def long_deprecated_demo(arg1: int, arg2: int) -> tuple: return arg1, arg2 with pytest.warns( - FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. Please toy instruction." + FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. toy instruction." ): self.assertEqual((1, 2), long_deprecated_demo(1, 2)) long_expected_docstring = textwrap.dedent(""" .. deprecated:: 0.2.0 - Deprecated and will be removed in version 0.3.0. Please toy instruction. + Deprecated and will be removed in version 0.3.0. toy instruction. This is a long summary. This is a long summary. This is a long summary. This is a long summary. @@ -465,7 +465,7 @@ def short_deprecated_demo(): short_expected_docstring = textwrap.dedent(""" .. deprecated:: 0.2.0 - Deprecated and will be removed in version 0.3.0. Please toy instruction. + Deprecated and will be removed in version 0.3.0. toy instruction. Short summary. """) From 89b1974c0539bbcf21d70abf91505021d8991de4 Mon Sep 17 00:00:00 2001 From: yhna Date: Sat, 2 Nov 2024 10:41:05 +0900 Subject: [PATCH 10/10] Add cls test case --- tests/test_utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 61bf477f335..2d9f969837f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -474,6 +474,28 @@ def short_deprecated_demo(): self.assertEqual(short_expected_docstring, short_actual_docstring) + @deprecated("0.2.0", "0.3.0", "toy instruction") + class OldClass: + """Old class docstring.""" + + def method(self): + pass + + with pytest.warns( + FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. toy instruction." + ): + OldClass() + + class_expected_docstring = textwrap.dedent(""" + .. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. toy instruction. + Old class docstring. + """) + class_expected_docstring = "".join(class_expected_docstring.split()) + class_actual_docstring = "".join(OldClass.__doc__.split()) + + self.assertEqual(class_expected_docstring, class_actual_docstring) + def test_has_offloaded_params(self): model = RegressionModel() assert not has_offloaded_params(model)