diff --git a/lightning_ir/base/class_factory.py b/lightning_ir/base/class_factory.py
index 9dde0fc..b169559 100644
--- a/lightning_ir/base/class_factory.py
+++ b/lightning_ir/base/class_factory.py
@@ -15,9 +15,6 @@
     CONFIG_MAPPING,
     MODEL_MAPPING,
     TOKENIZER_MAPPING,
-    AutoConfig,
-    AutoModel,
-    AutoTokenizer,
     PretrainedConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
@@ -159,13 +156,11 @@ def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[Lig
             f"{self.cc_lir_model_type}{BackboneClass.__name__}",
             (LightningIRConfigMixin, BackboneClass),
             {
-                "model_type": f"{BackboneClass.model_type}-{self.MixinConfig.model_type}",
+                "model_type": self.MixinConfig.model_type,
                 "backbone_model_type": BackboneClass.model_type,
+                "mixin_config": self.MixinConfig,
             },
         )
-
-        AutoConfig.register(DerivedLightningIRConfig.model_type, DerivedLightningIRConfig, exist_ok=True)
-
         return DerivedLightningIRConfig
@@ -217,9 +212,6 @@ def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[Ligh
             (LightningIRModelMixin, BackboneClass),
             {"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward},
         )
-
-        AutoModel.register(DerivedLightningIRConfig, DerivedLightningIRModel, exist_ok=True)
-
         return DerivedLightningIRModel
@@ -252,10 +244,12 @@ def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) ->
         except (OSError, ValueError):
             # best guess at model type
             config_dict = get_tokenizer_config(model_name_or_path)
-            Tokenizer = tokenizer_class_from_name(config_dict["tokenizer_class"])
-            for config, tokenizers in TOKENIZER_MAPPING.items():
-                if Tokenizer in tokenizers:
-                    return getattr(config, "backbone_model_type", None) or getattr(config, "model_type")
+            backbone_tokenizer_class = config_dict.get("backbone_tokenizer_class", None)
+            if backbone_tokenizer_class is not None:
+                Tokenizer = tokenizer_class_from_name(backbone_tokenizer_class)
+                for config, tokenizers in TOKENIZER_MAPPING.items():
+                    if Tokenizer in tokenizers:
+                        return getattr(config, "model_type")
         raise ValueError("No backbone model found in the configuration")

     def from_pretrained(
@@ -305,12 +299,6 @@ def from_backbone_classes(
             )
             if DerivedLightningIRTokenizers[1] is not None:
                 DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]
-
-        DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(BackboneConfig)
-        AutoTokenizer.register(
-            DerivedLightningIRConfig, DerivedLightningIRTokenizers[0], DerivedLightningIRTokenizers[1]
-        )
-
         return DerivedLightningIRTokenizers

     def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
diff --git a/lightning_ir/base/config.py b/lightning_ir/base/config.py
index 3889289..05d3e65 100644
--- a/lightning_ir/base/config.py
+++ b/lightning_ir/base/config.py
@@ -9,13 +9,13 @@ class from the Hugging Face Transformers library.
 from pathlib import Path
 from typing import Any, Dict, Set

-from transformers import CONFIG_MAPPING
+from transformers import PretrainedConfig

 from .class_factory import LightningIRConfigClassFactory
 from .external_model_hub import CHECKPOINT_MAPPING


-class LightningIRConfig:
+class LightningIRConfig(PretrainedConfig):
     """The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
     transformers.PretrainedConfig_ class.
@@ -71,10 +71,7 @@ def to_dict(self) -> Dict[str, Any]:
         :return: Configuration dictionary
         :rtype: Dict[str, Any]
         """
-        if hasattr(super(), "to_dict"):
-            output = getattr(super(), "to_dict")()
-        else:
-            output = self.to_added_args_dict()
+        output = getattr(super(), "to_dict")()
         if self.backbone_model_type is not None:
             output["backbone_model_type"] = self.backbone_model_type
         return output
@@ -93,7 +90,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwa
         :return: Derived LightningIRConfig class
         :rtype: LightningIRConfig
         """
+        # provides AutoConfig.from_pretrained support
         if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
+            # no backbone config found, create derived lightning-ir config based on backbone config
             config = None
             if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                 config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
@@ -111,34 +110,3 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwa
                 derived_config.update(config.to_dict())
             return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
         return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-
-    @classmethod
-    def from_dict(cls, config_dict: Dict[str, Any], *args, **kwargs) -> "LightningIRConfig":
-        """Loads the configuration from a dictionary. Wraps the transformers.PretrainedConfig.from_dict_ method to
-        return a derived LightningIRConfig class. See :class:`.LightningIRConfigClassFactory` for more details.
-
-        .. _transformers.PretrainedConfig.from_dict: \
-https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.from_dict
-
-        :param config_dict: Configuration dictionary
-        :type config_dict: Dict[str, Any]
-        :raises ValueError: If the model type does not match the configuration model type
-        :return: Derived LightningIRConfig class
-        :rtype: LightningIRConfig
-        """
-        if all(issubclass(base, LightningIRConfig) for base in cls.__bases__) or cls is LightningIRConfig:
-            if "backbone_model_type" in config_dict:
-                backbone_model_type = config_dict["backbone_model_type"]
-                model_type = config_dict["model_type"]
-                if cls is not LightningIRConfig and model_type != cls.model_type:
-                    raise ValueError(
-                        f"Model type {model_type} does not match configuration model type {cls.model_type}"
-                    )
-            else:
-                backbone_model_type = config_dict["model_type"]
-                model_type = cls.model_type
-            MixinConfig = CONFIG_MAPPING[model_type]
-            BackboneConfig = CONFIG_MAPPING[backbone_model_type]
-            cls = LightningIRConfigClassFactory(MixinConfig).from_backbone_class(BackboneConfig)
-            return cls.from_dict(config_dict, *args, **kwargs)
-        return super(LightningIRConfig, cls).from_dict(config_dict, *args, **kwargs)
diff --git a/lightning_ir/base/model.py b/lightning_ir/base/model.py
index 998ba0c..b90b7e2 100644
--- a/lightning_ir/base/model.py
+++ b/lightning_ir/base/model.py
@@ -161,7 +161,8 @@ def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "Li
         """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method and to return a
         derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

-        .. _transformers.PreTrainedModel.from_pretrained: https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained # noqa
+        .. _transformers.PreTrainedModel.from_pretrained: \
+https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

         :param model_name_or_path: Name or path of the pretrained model
         :type model_name_or_path: str | Path
diff --git a/lightning_ir/base/tokenizer.py b/lightning_ir/base/tokenizer.py
index 8be8c46..fc53b39 100644
--- a/lightning_ir/base/tokenizer.py
+++ b/lightning_ir/base/tokenizer.py
@@ -4,16 +4,18 @@ This module contains the main tokenizer class for the Lightning IR library.
 """

-from typing import Dict, Sequence, Type
+import json
+from os import PathLike
+from typing import Dict, Sequence, Tuple, Type

-from transformers import TOKENIZER_MAPPING, BatchEncoding
+from transformers import TOKENIZER_MAPPING, BatchEncoding, PreTrainedTokenizerBase

 from .class_factory import LightningIRTokenizerClassFactory
 from .config import LightningIRConfig
 from .external_model_hub import CHECKPOINT_MAPPING


-class LightningIRTokenizer:
+class LightningIRTokenizer(PreTrainedTokenizerBase):
     """Base class for Lightning IR tokenizers. Derived classes implement the tokenize method for handling query
     and document tokenization. It acts as mixin for a transformers.PreTrainedTokenizer_ backbone tokenizer.
@@ -77,6 +79,7 @@ def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "Lightning
         :return: A derived LightningIRTokenizer consisting of a backbone tokenizer and a LightningIRTokenizer mixin
         :rtype: LightningIRTokenizer
         """
+        # provides AutoTokenizer.from_pretrained support
         config = kwargs.pop("config", None)
         if config is not None:
             kwargs.update(config.to_tokenizer_dict())
@@ -94,6 +97,7 @@ def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "Lightning
                 Config = LightningIRTokenizerClassFactory.get_lightning_ir_config(model_name_or_path)
                 if Config is None:
                     raise ValueError("Pass a config to `from_pretrained`.")
+                Config = getattr(Config, "mixin_config", Config)
             BackboneConfig = LightningIRTokenizerClassFactory.get_backbone_config(model_name_or_path)
             BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
             if kwargs.get("use_fast", True):
@@ -103,3 +107,36 @@ def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "Lightning
             cls = LightningIRTokenizerClassFactory(Config).from_backbone_class(BackboneTokenizer)
             return cls.from_pretrained(model_name_or_path, *args, **kwargs)
         return super(LightningIRTokenizer, cls).from_pretrained(model_name_or_path, *args, **kwargs)
+
+    def _save_pretrained(
+        self,
+        save_directory: str | PathLike,
+        file_names: Tuple[str],
+        legacy_format: bool | None = None,
+        filename_prefix: str | None = None,
+    ) -> Tuple[str]:
+        # bit of a hack to change the tokenizer class in the stored tokenizer config to only contain the
+        # lightning_ir tokenizer class (removing the backbone tokenizer class)
+        save_files = super()._save_pretrained(save_directory, file_names, legacy_format, filename_prefix)
+        config_file = save_files[0]
+        with open(config_file, "r") as file:
+            tokenizer_config = json.load(file)
+
+        tokenizer_class = None
+        backbone_tokenizer_class = None
+        for base in self.__class__.__bases__:
+            if issubclass(base, LightningIRTokenizer):
+                if tokenizer_class is not None:
+                    raise ValueError("Multiple Lightning IR tokenizer classes found.")
+                tokenizer_class = base.__name__
+                continue
+            if issubclass(base, PreTrainedTokenizerBase):
+                backbone_tokenizer_class = base.__name__
+
+        tokenizer_config["tokenizer_class"] = tokenizer_class
+        tokenizer_config["backbone_tokenizer_class"] = backbone_tokenizer_class
+
+        with open(config_file, "w") as file:
+            out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+            file.write(out_str)
+        return save_files
diff --git a/lightning_ir/bi_encoder/config.py b/lightning_ir/bi_encoder/config.py
index 3259972..3c96aab 100644
--- a/lightning_ir/bi_encoder/config.py
+++ b/lightning_ir/bi_encoder/config.py
@@ -149,9 +149,9 @@ def save_pretrained(self, save_directory: str | PathLike, **kwargs) -> None:
         :param save_directory: Directory to save the configuration
         :type save_directory: str | PathLike
         """
+        super().save_pretrained(save_directory, **kwargs)
         with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
             json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)
-        return super().save_pretrained(save_directory, **kwargs)

     @classmethod
     def get_config_dict(
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..cd4a2d3
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,31 @@
+from pathlib import Path
+
+from transformers import AutoConfig
+
+from lightning_ir.base.config import LightningIRConfig
+from lightning_ir.base.module import LightningIRModule
+
+
+def test_serialize_deserialize(module: LightningIRModule, tmp_path: Path):
+    config = module.model.config
+    config_class = module.model.config_class
+    save_dir = str(tmp_path / config_class.model_type)
+    config.save_pretrained(save_dir)
+    new_configs = [
+        config.__class__.from_pretrained(save_dir),
+        config.__class__.__bases__[0].from_pretrained(save_dir),
+        LightningIRConfig.from_pretrained(save_dir),
+        AutoConfig.from_pretrained(save_dir),
+    ]
+    for new_config in new_configs:
+        for key, value in config.__dict__.items():
+            if key in (
+                "torch_dtype",
+                "_name_or_path",
+                "_commit_hash",
+                "transformers_version",
+                "_attn_implementation_autoset",
+                "_attn_implementation_internal",
+            ):
+                continue
+            assert getattr(new_config, key) == value
diff --git a/tests/test_model.py b/tests/test_model.py
index aa07233..66e3608 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -77,7 +77,6 @@ def test_seralize_deserialize(module: LightningIRModule, tmp_path: Path):
                 "_name_or_path",
                 "_commit_hash",
                 "transformers_version",
-                "model_type",
                 "_attn_implementation_autoset",
             ):
                 continue
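For reference, a minimal sketch of the serialization round trip the changes above enable. It is not part of the patch: the checkpoint path and output directory are placeholders, and it assumes lightning_ir and transformers are installed.

# round_trip_sketch.py -- illustration only; "path/to/lightning-ir-checkpoint"
# and "saved-encoder" are placeholders, not artifacts of this repository.
from transformers import AutoConfig

from lightning_ir.base.config import LightningIRConfig
from lightning_ir.base.tokenizer import LightningIRTokenizer

checkpoint = "path/to/lightning-ir-checkpoint"  # any fine-tuned Lightning IR checkpoint

# from_pretrained resolves the derived config and tokenizer classes via the class factories.
config = LightningIRConfig.from_pretrained(checkpoint)
tokenizer = LightningIRTokenizer.from_pretrained(checkpoint)

# save_pretrained now stores the mixin's model_type plus backbone_model_type in config.json,
# and tokenizer_class plus backbone_tokenizer_class in tokenizer_config.json.
config.save_pretrained("saved-encoder")
tokenizer.save_pretrained("saved-encoder")

# Reloading works through the Lightning IR classes as well as AutoConfig (exercised by the
# new tests/test_config.py), without the Auto* registrations removed in class_factory.py.
reloaded = AutoConfig.from_pretrained("saved-encoder")
assert reloaded.backbone_model_type == config.backbone_model_type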