From 5eb49910ed71b3a4cd3db48e1e5e5839f75426e0 Mon Sep 17 00:00:00 2001 From: TheRootOf3 Date: Fri, 22 Mar 2024 09:14:19 +0000 Subject: [PATCH] :art: Format Python code with psf/black --- .../evaluation_scripts/constants.py | 78 +++++++++---------- .../evaluation_scripts/moderation.py | 24 ++++-- 2 files changed, 55 insertions(+), 47 deletions(-) diff --git a/eval_harmfulness/evaluation_scripts/constants.py b/eval_harmfulness/evaluation_scripts/constants.py index a36c741..3115e7f 100644 --- a/eval_harmfulness/evaluation_scripts/constants.py +++ b/eval_harmfulness/evaluation_scripts/constants.py @@ -18,58 +18,58 @@ __all__ = [ - 'IGNORE_INDEX', - 'DEFAULT_BOS_TOKEN', - 'DEFAULT_EOS_TOKEN', - 'DEFAULT_PAD_TOKEN', - 'DEFAULT_UNK_TOKEN', - 'PROMPT_BEGIN', - 'PROMPT_USER', - 'PROMPT_ASSISTANT', - 'PROMPT_INPUT', - 'PROMPT_DICT', - 'ADAM_BETAS', - 'NUM_LABELS', - 'LABEL_NAMES', - 'ID2LABELS', + "IGNORE_INDEX", + "DEFAULT_BOS_TOKEN", + "DEFAULT_EOS_TOKEN", + "DEFAULT_PAD_TOKEN", + "DEFAULT_UNK_TOKEN", + "PROMPT_BEGIN", + "PROMPT_USER", + "PROMPT_ASSISTANT", + "PROMPT_INPUT", + "PROMPT_DICT", + "ADAM_BETAS", + "NUM_LABELS", + "LABEL_NAMES", + "ID2LABELS", ] IGNORE_INDEX: int = -100 -DEFAULT_BOS_TOKEN: str = '' -DEFAULT_EOS_TOKEN: str = '' -DEFAULT_PAD_TOKEN: str = '' -DEFAULT_UNK_TOKEN: str = '' +DEFAULT_BOS_TOKEN: str = "" +DEFAULT_EOS_TOKEN: str = "" +DEFAULT_PAD_TOKEN: str = "" +DEFAULT_UNK_TOKEN: str = "" -PROMPT_BEGIN: str = 'BEGINNING OF CONVERSATION: ' -PROMPT_USER: str = 'USER: {input} ' -PROMPT_ASSISTANT: str = 'ASSISTANT:' # should not have a space at the end +PROMPT_BEGIN: str = "BEGINNING OF CONVERSATION: " +PROMPT_USER: str = "USER: {input} " +PROMPT_ASSISTANT: str = "ASSISTANT:" # should not have a space at the end PROMPT_INPUT: str = PROMPT_BEGIN + PROMPT_USER + PROMPT_ASSISTANT PROMPT_DICT: dict[str, str] = { - 'prompt_begin': PROMPT_BEGIN, - 'prompt_user': PROMPT_USER, - 'prompt_assistant': PROMPT_ASSISTANT, - 'prompt_input': PROMPT_INPUT, + "prompt_begin": PROMPT_BEGIN, + "prompt_user": PROMPT_USER, + "prompt_assistant": PROMPT_ASSISTANT, + "prompt_input": PROMPT_INPUT, } ADAM_BETAS: tuple[float, float] = (0.9, 0.95) LABEL_NAMES: list[str] = [ - 'animal_abuse', - 'child_abuse', - 'controversial_topics,politics', - 'discrimination,stereotype,injustice', - 'drug_abuse,weapons,banned_substance', - 'financial_crime,property_crime,theft', - 'hate_speech,offensive_language', - 'misinformation_regarding_ethics,laws_and_safety', - 'non_violent_unethical_behavior', - 'privacy_violation', - 'self_harm', - 'sexually_explicit,adult_content', - 'terrorism,organized_crime', - 'violence,aiding_and_abetting,incitement', + "animal_abuse", + "child_abuse", + "controversial_topics,politics", + "discrimination,stereotype,injustice", + "drug_abuse,weapons,banned_substance", + "financial_crime,property_crime,theft", + "hate_speech,offensive_language", + "misinformation_regarding_ethics,laws_and_safety", + "non_violent_unethical_behavior", + "privacy_violation", + "self_harm", + "sexually_explicit,adult_content", + "terrorism,organized_crime", + "violence,aiding_and_abetting,incitement", ] NUM_LABELS: int = len(LABEL_NAMES) diff --git a/eval_harmfulness/evaluation_scripts/moderation.py b/eval_harmfulness/evaluation_scripts/moderation.py index 8086b25..950cf46 100644 --- a/eval_harmfulness/evaluation_scripts/moderation.py +++ b/eval_harmfulness/evaluation_scripts/moderation.py @@ -222,7 +222,8 @@ def predict( batch_size: int, return_bool: Literal[False], threshold: float, - ) -> list[dict[str, float]]: ... + ) -> list[dict[str, float]]: + ... @overload def predict( @@ -231,7 +232,8 @@ def predict( batch_size: int, return_bool: Literal[True], threshold: float, - ) -> list[dict[str, bool]]: ... + ) -> list[dict[str, bool]]: + ... @overload def predict( @@ -240,7 +242,8 @@ def predict( batch_size: int, return_bool: Literal[False], threshold: float, - ) -> dict[str, float]: ... + ) -> dict[str, float]: + ... @overload def predict( @@ -249,7 +252,8 @@ def predict( batch_size: int, return_bool: Literal[True], threshold: float, - ) -> dict[str, bool]: ... + ) -> dict[str, bool]: + ... @torch.inference_mode() def predict( @@ -322,7 +326,8 @@ def predict( # pylint: disable=arguments-differ batch_size: int, return_bool: Literal[False], threshold: float, - ) -> list[dict[str, float]]: ... + ) -> list[dict[str, float]]: + ... @overload def predict( # pylint: disable=arguments-differ @@ -332,7 +337,8 @@ def predict( # pylint: disable=arguments-differ batch_size: int, return_bool: Literal[True], threshold: float, - ) -> list[dict[str, bool]]: ... + ) -> list[dict[str, bool]]: + ... @overload def predict( # pylint: disable=arguments-differ @@ -342,7 +348,8 @@ def predict( # pylint: disable=arguments-differ batch_size: int, return_bool: Literal[False], threshold: float, - ) -> dict[str, float]: ... + ) -> dict[str, float]: + ... @overload def predict( # pylint: disable=arguments-differ @@ -352,7 +359,8 @@ def predict( # pylint: disable=arguments-differ batch_size: int, return_bool: Literal[True], threshold: float, - ) -> dict[str, bool]: ... + ) -> dict[str, bool]: + ... @torch.inference_mode() def predict( # pylint: disable=arguments-differ,arguments-renamed