remove registration of derived lightning-ir models
fschlatt committed Nov 27, 2024
1 parent 78e373a commit efa4724
Showing 7 changed files with 87 additions and 63 deletions.
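In short, this commit stops registering the dynamically derived config, model, and tokenizer classes with the transformers Auto* classes and instead records the backbone in the saved artifacts (a backbone_model_type key in config.json and a backbone_tokenizer_class key in tokenizer_config.json), so generic loading keeps working. A hedged sketch of the round-trip that the new tests/test_config.py added in this commit exercises; the BiEncoderConfig import path and the "bert" backbone are assumptions, not taken from this diff:

# Sketch only: build a derived config through the class factory (as the
# library does internally), save it, and load it back generically.
from transformers import CONFIG_MAPPING, AutoConfig

from lightning_ir.base.class_factory import LightningIRConfigClassFactory
from lightning_ir.base.config import LightningIRConfig
from lightning_ir.bi_encoder.config import BiEncoderConfig  # assumed import path

BackboneConfig = CONFIG_MAPPING["bert"]
DerivedConfig = LightningIRConfigClassFactory(BiEncoderConfig).from_backbone_class(BackboneConfig)

config = DerivedConfig()
config.save_pretrained("ckpt-dir")  # config.json now records backbone_model_type

# The derived class is no longer AutoConfig-registered up front, but both the
# generic Lightning IR entry point and AutoConfig still resolve it from the
# stored backbone_model_type.
via_base = LightningIRConfig.from_pretrained("ckpt-dir")
via_auto = AutoConfig.from_pretrained("ckpt-dir")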
28 changes: 8 additions & 20 deletions lightning_ir/base/class_factory.py
@@ -15,9 +15,6 @@
CONFIG_MAPPING,
MODEL_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
AutoModel,
AutoTokenizer,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
@@ -159,13 +156,11 @@ def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[Lig
f"{self.cc_lir_model_type}{BackboneClass.__name__}",
(LightningIRConfigMixin, BackboneClass),
{
"model_type": f"{BackboneClass.model_type}-{self.MixinConfig.model_type}",
"model_type": self.MixinConfig.model_type,
"backbone_model_type": BackboneClass.model_type,
"mixin_config": self.MixinConfig,
},
)

AutoConfig.register(DerivedLightningIRConfig.model_type, DerivedLightningIRConfig, exist_ok=True)

return DerivedLightningIRConfig


@@ -217,9 +212,6 @@ def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[Ligh
(LightningIRModelMixin, BackboneClass),
{"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward},
)

AutoModel.register(DerivedLightningIRConfig, DerivedLightningIRModel, exist_ok=True)

return DerivedLightningIRModel


@@ -252,10 +244,12 @@ def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) ->
except (OSError, ValueError):
# best guess at model type
config_dict = get_tokenizer_config(model_name_or_path)
Tokenizer = tokenizer_class_from_name(config_dict["tokenizer_class"])
for config, tokenizers in TOKENIZER_MAPPING.items():
if Tokenizer in tokenizers:
return getattr(config, "backbone_model_type", None) or getattr(config, "model_type")
backbone_tokenizer_class = config_dict.get("backbone_tokenizer_class", None)
if backbone_tokenizer_class is not None:
Tokenizer = tokenizer_class_from_name(backbone_tokenizer_class)
for config, tokenizers in TOKENIZER_MAPPING.items():
if Tokenizer in tokenizers:
return getattr(config, "model_type")
raise ValueError("No backbone model found in the configuration")

def from_pretrained(
Expand Down Expand Up @@ -305,12 +299,6 @@ def from_backbone_classes(
)
if DerivedLightningIRTokenizers[1] is not None:
DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]

DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(BackboneConfig)
AutoTokenizer.register(
DerivedLightningIRConfig, DerivedLightningIRTokenizers[0], DerivedLightningIRTokenizers[1]
)

return DerivedLightningIRTokenizers

def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
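For context, the removed calls followed the standard transformers registration API. A minimal sketch of that pattern with placeholder classes (not the actual lightning-ir types), assuming a reasonably recent transformers version that supports exist_ok:

# Sketch of the eager registration this commit removes; placeholder classes only.
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class DerivedLightningIRConfig(PretrainedConfig):
    model_type = "bert-bi-encoder"  # old hyphenated model_type scheme


class DerivedLightningIRModel(PreTrainedModel):
    config_class = DerivedLightningIRConfig


# Before this commit, every derived class was made known to the Auto* classes:
AutoConfig.register(DerivedLightningIRConfig.model_type, DerivedLightningIRConfig, exist_ok=True)
AutoModel.register(DerivedLightningIRConfig, DerivedLightningIRModel, exist_ok=True)
# AutoTokenizer.register(DerivedLightningIRConfig, slow_cls, fast_cls) was removed as well;
# it needs concrete tokenizer classes, so it is only indicated as a comment here.

After the commit, the derived classes are created on demand inside the Lightning IR from_pretrained overrides instead of being registered when they are built.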
42 changes: 5 additions & 37 deletions lightning_ir/base/config.py
@@ -9,13 +9,13 @@ class from the Hugging Face Transformers library.
from pathlib import Path
from typing import Any, Dict, Set

from transformers import CONFIG_MAPPING
from transformers import PretrainedConfig

from .class_factory import LightningIRConfigClassFactory
from .external_model_hub import CHECKPOINT_MAPPING


class LightningIRConfig:
class LightningIRConfig(PretrainedConfig):
"""The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
transformers.PretrainedConfig_ class.
@@ -71,10 +71,7 @@ def to_dict(self) -> Dict[str, Any]:
:return: Configuration dictionary
:rtype: Dict[str, Any]
"""
if hasattr(super(), "to_dict"):
output = getattr(super(), "to_dict")()
else:
output = self.to_added_args_dict()
output = getattr(super(), "to_dict")()
if self.backbone_model_type is not None:
output["backbone_model_type"] = self.backbone_model_type
return output
@@ -93,7 +90,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwa
:return: Derived LightningIRConfig class
:rtype: LightningIRConfig
"""
# provides AutoConfig.from_pretrained support
if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
# no backbone config found, create derived lightning-ir config based on backbone config
config = None
if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
@@ -111,34 +110,3 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwa
derived_config.update(config.to_dict())
return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

@classmethod
def from_dict(cls, config_dict: Dict[str, Any], *args, **kwargs) -> "LightningIRConfig":
"""Loads the configuration from a dictionary. Wraps the transformers.PretrainedConfig.from_dict_ method to
return a derived LightningIRConfig class. See :class:`.LightningIRConfigClassFactory` for more details.
.. _transformers.PretrainedConfig.from_dict: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.from_dict
:param config_dict: Configuration dictionary
:type config_dict: Dict[str, Any]
:raises ValueError: If the model type does not match the configuration model type
:return: Derived LightningIRConfig class
:rtype: LightningIRConfig
"""
if all(issubclass(base, LightningIRConfig) for base in cls.__bases__) or cls is LightningIRConfig:
if "backbone_model_type" in config_dict:
backbone_model_type = config_dict["backbone_model_type"]
model_type = config_dict["model_type"]
if cls is not LightningIRConfig and model_type != cls.model_type:
raise ValueError(
f"Model type {model_type} does not match configuration model type {cls.model_type}"
)
else:
backbone_model_type = config_dict["model_type"]
model_type = cls.model_type
MixinConfig = CONFIG_MAPPING[model_type]
BackboneConfig = CONFIG_MAPPING[backbone_model_type]
cls = LightningIRConfigClassFactory(MixinConfig).from_backbone_class(BackboneConfig)
return cls.from_dict(config_dict, *args, **kwargs)
return super(LightningIRConfig, cls).from_dict(config_dict, *args, **kwargs)
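Net effect on the serialized configuration: the derived config's model_type is now just the mixin's model_type, and the backbone is carried in the separate backbone_model_type key written by to_dict, instead of being fused into a hyphenated model type. An illustrative sketch of the resulting config.json shape; the key names come from the diff, the concrete values ("bi-encoder", "bert") are assumptions:

# Illustrative only: shape of config.json after this commit for a BERT-backed
# bi-encoder; the remaining backbone fields (hidden_size, vocab_size, ...) are unchanged.
serialized_config = {
    "model_type": "bi-encoder",     # mixin model_type (previously "bert-bi-encoder")
    "backbone_model_type": "bert",  # new key added by LightningIRConfig.to_dict
}

# On loading, the Lightning IR from_pretrained override rebuilds the derived class
# (via LightningIRConfigClassFactory) from the stored information, which replaces
# the from_dict override removed above.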
3 changes: 2 additions & 1 deletion lightning_ir/base/model.py
@@ -161,7 +161,8 @@ def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "Li
"""Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method and to return a
derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.
.. _transformers.PreTrainedModel.from_pretrained: https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained # noqa
.. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:param model_name_or_path: Name or path of the pretrained model
:type model_name_or_path: str | Path
43 changes: 40 additions & 3 deletions lightning_ir/base/tokenizer.py
@@ -4,16 +4,18 @@
This module contains the main tokenizer class for the Lightning IR library.
"""

from typing import Dict, Sequence, Type
import json
from os import PathLike
from typing import Dict, Sequence, Tuple, Type

from transformers import TOKENIZER_MAPPING, BatchEncoding
from transformers import TOKENIZER_MAPPING, BatchEncoding, PreTrainedTokenizerBase

from .class_factory import LightningIRTokenizerClassFactory
from .config import LightningIRConfig
from .external_model_hub import CHECKPOINT_MAPPING


class LightningIRTokenizer:
class LightningIRTokenizer(PreTrainedTokenizerBase):
"""Base class for Lightning IR tokenizers. Derived classes implement the tokenize method for handling query
and document tokenization. It acts as mixin for a transformers.PreTrainedTokenizer_ backbone tokenizer.
@@ -77,6 +79,7 @@ def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "Lightning
:return: A derived LightningIRTokenizer consisting of a backbone tokenizer and a LightningIRTokenizer mixin
:rtype: LightningIRTokenizer
"""
# provides AutoTokenizer.from_pretrained support
config = kwargs.pop("config", None)
if config is not None:
kwargs.update(config.to_tokenizer_dict())
@@ -94,6 +97,7 @@ def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "Lightning
Config = LightningIRTokenizerClassFactory.get_lightning_ir_config(model_name_or_path)
if Config is None:
raise ValueError("Pass a config to `from_pretrained`.")
Config = getattr(Config, "mixin_config", Config)
BackboneConfig = LightningIRTokenizerClassFactory.get_backbone_config(model_name_or_path)
BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
if kwargs.get("use_fast", True):
@@ -103,3 +107,36 @@ def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "Lightning
cls = LightningIRTokenizerClassFactory(Config).from_backbone_class(BackboneTokenizer)
return cls.from_pretrained(model_name_or_path, *args, **kwargs)
return super(LightningIRTokenizer, cls).from_pretrained(model_name_or_path, *args, **kwargs)

def _save_pretrained(
self,
save_directory: str | PathLike,
file_names: Tuple[str],
legacy_format: bool | None = None,
filename_prefix: str | None = None,
) -> Tuple[str]:
# bit of a hack to change the tokenizer class in the stored tokenizer config to only contain the
# lightning_ir tokenizer class (removing the backbone tokenizer class)
save_files = super()._save_pretrained(save_directory, file_names, legacy_format, filename_prefix)
config_file = save_files[0]
with open(config_file, "r") as file:
tokenizer_config = json.load(file)

tokenizer_class = None
backbone_tokenizer_class = None
for base in self.__class__.__bases__:
if issubclass(base, LightningIRTokenizer):
if tokenizer_class is not None:
raise ValueError("Multiple Lightning IR tokenizer classes found.")
tokenizer_class = base.__name__
continue
if issubclass(base, PreTrainedTokenizerBase):
backbone_tokenizer_class = base.__name__

tokenizer_config["tokenizer_class"] = tokenizer_class
tokenizer_config["backbone_tokenizer_class"] = backbone_tokenizer_class

with open(config_file, "w") as file:
out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
file.write(out_str)
return save_files
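The _save_pretrained override above rewrites tokenizer_config.json so that tokenizer_class names only the Lightning IR mixin class and the backbone class moves into the new backbone_tokenizer_class key; get_backbone_model_type in class_factory.py reads that key back when loading. A hedged sketch of the stored keys and the lookup; the concrete class names are illustrative, not taken from the diff:

# get_backbone_model_type resolves the backbone roughly like this, using
# public transformers helpers.
from transformers import TOKENIZER_MAPPING
from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

# Illustrative tokenizer_config.json contents after saving a derived tokenizer;
# the usual settings (model_max_length, special tokens, ...) are omitted and the
# class names are assumptions.
saved_tokenizer_config = {
    "tokenizer_class": "BiEncoderTokenizer",          # Lightning IR mixin class only
    "backbone_tokenizer_class": "BertTokenizerFast",  # new key written by _save_pretrained
}

Tokenizer = tokenizer_class_from_name(saved_tokenizer_config["backbone_tokenizer_class"])
for config_cls, tokenizers in TOKENIZER_MAPPING.items():
    if Tokenizer in tokenizers:
        backbone_model_type = config_cls.model_type  # e.g. "bert"
        break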
2 changes: 1 addition & 1 deletion lightning_ir/bi_encoder/config.py
@@ -149,9 +149,9 @@ def save_pretrained(self, save_directory: str | PathLike, **kwargs) -> None:
:param save_directory: Directory to save the configuration
:type save_directory: str | PathLike
"""
super().save_pretrained(save_directory, **kwargs)
with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)
return super().save_pretrained(save_directory, **kwargs)

@classmethod
def get_config_dict(
31 changes: 31 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,31 @@
from pathlib import Path

from transformers import AutoConfig

from lightning_ir.base.config import LightningIRConfig
from lightning_ir.base.module import LightningIRModule


def test_serialize_deserialize(module: LightningIRModule, tmp_path: Path):
config = module.model.config
config_class = module.model.config_class
save_dir = str(tmp_path / config_class.model_type)
config.save_pretrained(save_dir)
new_configs = [
config.__class__.from_pretrained(save_dir),
config.__class__.__bases__[0].from_pretrained(save_dir),
LightningIRConfig.from_pretrained(save_dir),
AutoConfig.from_pretrained(save_dir),
]
for new_config in new_configs:
for key, value in config.__dict__.items():
if key in (
"torch_dtype",
"_name_or_path",
"_commit_hash",
"transformers_version",
"_attn_implementation_autoset",
"_attn_implementation_internal",
):
continue
assert getattr(new_config, key) == value
1 change: 0 additions & 1 deletion tests/test_model.py
@@ -77,7 +77,6 @@ def test_seralize_deserialize(module: LightningIRModule, tmp_path: Path):
"_name_or_path",
"_commit_hash",
"transformers_version",
"model_type",
"_attn_implementation_autoset",
):
continue
