From 3c61b3e8808d9d80217df4c17b6ce313156c0105 Mon Sep 17 00:00:00 2001
From: Fabian Keller
Date: Wed, 11 Dec 2024 15:20:27 +0100
Subject: [PATCH] ENH Typing: fix library interface (#2265)

Improve typing (re-export) in __init__.py files.
---
 src/peft/__init__.py         | 96 +++++++++++++++++++++++++++++++++---
 src/peft/tuners/__init__.py  | 61 ++++++++++++++++++++---
 src/peft/utils/__init__.py   | 48 +++++++++++++++---
 src/peft/utils/peft_types.py |  6 +--
 4 files changed, 186 insertions(+), 25 deletions(-)

diff --git a/src/peft/__init__.py b/src/peft/__init__.py
index d01b3b0b7c..6e4e124c18 100644
--- a/src/peft/__init__.py
+++ b/src/peft/__init__.py
@@ -1,8 +1,3 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all.
-
-# coding=utf-8
 # Copyright 2023-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +14,7 @@
 
 __version__ = "0.14.1.dev0"
 
-from .auto import (
+from .auto import (  # noqa: I001
     AutoPeftModel,
     AutoPeftModelForCausalLM,
     AutoPeftModelForSequenceClassification,
@@ -91,7 +86,6 @@
     XLoraModel,
     HRAConfig,
     HRAModel,
-    VBLoRAConfig,
     get_eva_state_dict,
     initialize_lora_eva_weights,
     CPTEmbedding,
@@ -113,3 +107,91 @@
     cast_mixed_precision_params,
 )
 from .config import PeftConfig, PromptLearningConfig
+
+__all__ = [
+    "MODEL_TYPE_TO_PEFT_MODEL_MAPPING",
+    "PEFT_TYPE_TO_CONFIG_MAPPING",
+    "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
+    "AdaLoraConfig",
+    "AdaLoraModel",
+    "AdaptionPromptConfig",
+    "AdaptionPromptModel",
+    "AutoPeftModel",
+    "AutoPeftModelForCausalLM",
+    "AutoPeftModelForFeatureExtraction",
+    "AutoPeftModelForQuestionAnswering",
+    "AutoPeftModelForSeq2SeqLM",
+    "AutoPeftModelForSequenceClassification",
+    "AutoPeftModelForTokenClassification",
+    "BOFTConfig",
+    "BOFTModel",
+    "BoneConfig",
+    "BoneModel",
+    "CPTConfig",
+    "CPTEmbedding",
+    "EvaConfig",
+    "FourierFTConfig",
+    "FourierFTModel",
+    "HRAConfig",
+    "HRAModel",
+    "IA3Config",
+    "IA3Model",
+    "LNTuningConfig",
+    "LNTuningModel",
+    "LoHaConfig",
+    "LoHaModel",
+    "LoKrConfig",
+    "LoKrModel",
+    "LoftQConfig",
+    "LoraConfig",
+    "LoraModel",
+    "LoraRuntimeConfig",
+    "MultitaskPromptTuningConfig",
+    "MultitaskPromptTuningInit",
+    "OFTConfig",
+    "OFTModel",
+    "PeftConfig",
+    "PeftMixedModel",
+    "PeftModel",
+    "PeftModelForCausalLM",
+    "PeftModelForFeatureExtraction",
+    "PeftModelForQuestionAnswering",
+    "PeftModelForSeq2SeqLM",
+    "PeftModelForSequenceClassification",
+    "PeftModelForTokenClassification",
+    "PeftType",
+    "PolyConfig",
+    "PolyModel",
+    "PrefixEncoder",
+    "PrefixTuningConfig",
+    "PromptEmbedding",
+    "PromptEncoder",
+    "PromptEncoderConfig",
+    "PromptEncoderReparameterizationType",
+    "PromptLearningConfig",
+    "PromptTuningConfig",
+    "PromptTuningInit",
+    "TaskType",
+    "VBLoRAConfig",
+    "VBLoRAConfig",
+    "VBLoRAModel",
+    "VeraConfig",
+    "VeraModel",
+    "XLoraConfig",
+    "XLoraModel",
+    "bloom_model_postprocess_past_key_value",
+    "cast_mixed_precision_params",
+    "get_eva_state_dict",
+    "get_layer_status",
+    "get_model_status",
+    "get_peft_config",
+    "get_peft_model",
+    "get_peft_model_state_dict",
+    "initialize_lora_eva_weights",
+    "inject_adapter_in_model",
+    "load_peft_weights",
+    "prepare_model_for_kbit_training",
+    "replace_lora_weights_loftq",
+    "set_peft_model_state_dict",
+    "shift_tokens_right",
+]
diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py
index 605d601614..bc3d08a0db 100644
--- a/src/peft/tuners/__init__.py
+++ b/src/peft/tuners/__init__.py
@@ -1,8 +1,3 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all
-
-# coding=utf-8
 # Copyright 2023-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
+from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel  # noqa: I001
 from .lora import (
     LoraConfig,
     LoraModel,
@@ -47,3 +42,57 @@
 from .vblora import VBLoRAConfig, VBLoRAModel
 from .cpt import CPTConfig, CPTEmbedding
 from .bone import BoneConfig, BoneModel
+
+__all__ = [
+    "AdaLoraConfig",
+    "AdaLoraModel",
+    "AdaptionPromptConfig",
+    "AdaptionPromptModel",
+    "BOFTConfig",
+    "BOFTModel",
+    "BoneConfig",
+    "BoneModel",
+    "CPTConfig",
+    "CPTEmbedding",
+    "EvaConfig",
+    "FourierFTConfig",
+    "FourierFTModel",
+    "HRAConfig",
+    "HRAModel",
+    "IA3Config",
+    "IA3Model",
+    "LNTuningConfig",
+    "LNTuningModel",
+    "LoHaConfig",
+    "LoHaModel",
+    "LoKrConfig",
+    "LoKrModel",
+    "LoftQConfig",
+    "LoraConfig",
+    "LoraModel",
+    "LoraRuntimeConfig",
+    "MixedModel",
+    "MultitaskPromptEmbedding",
+    "MultitaskPromptTuningConfig",
+    "MultitaskPromptTuningInit",
+    "OFTConfig",
+    "OFTModel",
+    "PolyConfig",
+    "PolyModel",
+    "PrefixEncoder",
+    "PrefixTuningConfig",
+    "PromptEmbedding",
+    "PromptEncoder",
+    "PromptEncoderConfig",
+    "PromptEncoderReparameterizationType",
+    "PromptTuningConfig",
+    "PromptTuningInit",
+    "VBLoRAConfig",
+    "VBLoRAModel",
+    "VeraConfig",
+    "VeraModel",
+    "XLoraConfig",
+    "XLoraModel",
+    "get_eva_state_dict",
+    "initialize_lora_eva_weights",
+]
diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py
index 63a7216168..d9bf8fbe28 100644
--- a/src/peft/utils/__init__.py
+++ b/src/peft/utils/__init__.py
@@ -1,8 +1,3 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all
-
-# coding=utf-8
 # Copyright 2023-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
-from .integrations import map_cache_to_layer_device_map
+from .integrations import map_cache_to_layer_device_map  # noqa: I001
 from .loftq_utils import replace_lora_weights_loftq
 from .peft_types import PeftType, TaskType
 from .other import (
@@ -54,3 +48,43 @@
     cast_mixed_precision_params,
 )
 from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights
+
+__all__ = [
+    "CONFIG_NAME",
+    "INCLUDE_LINEAR_LAYERS_SHORTHAND",
+    "SAFETENSORS_WEIGHTS_NAME",
+    "TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING",
+    "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING",
+    "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING",
+    "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING",
+    "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING",
+    "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
+    "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
+    "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
+    "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
+    "WEIGHTS_NAME",
+    "ModulesToSaveWrapper",
+    "PeftType",
+    "TaskType",
+    "_freeze_adapter",
+    "_get_batch_size",
+    "_get_submodules",
+    "_is_valid_match",
+    "_prepare_prompt_learning_config",
+    "_set_adapter",
+    "_set_trainable",
+    "bloom_model_postprocess_past_key_value",
+    "cast_mixed_precision_params",
+    "get_auto_gptq_quant_linear",
+    "get_peft_model_state_dict",
+    "get_quantization_config",
+    "id_tensor_storage",
+    "infer_device",
+    "load_peft_weights",
+    "map_cache_to_layer_device_map",
+    "prepare_model_for_kbit_training",
+    "replace_lora_weights_loftq",
+    "set_peft_model_state_dict",
+    "shift_tokens_right",
+    "transpose",
+]
diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py
index 04be2a6d66..d2f1539074 100644
--- a/src/peft/utils/peft_types.py
+++ b/src/peft/utils/peft_types.py
@@ -1,8 +1,3 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all
-
-# coding=utf-8
 # Copyright 2023-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import enum