From 53afbf3f641302a40d00b8e82c0c547f21b7597f Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Wed, 12 Feb 2025 13:02:19 -0800 Subject: [PATCH] fix: attempt automatic embedding table update for previous models --- fs2/model.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/fs2/model.py b/fs2/model.py index d7276bf..798f77d 100644 --- a/fs2/model.py +++ b/fs2/model.py @@ -15,6 +15,7 @@ from everyvoice.text.features import N_PHONOLOGICAL_FEATURES from everyvoice.text.lookups import LookupTable from everyvoice.text.text_processor import TextProcessor +from everyvoice.text.utils import get_symbols_from_symbol_dict, symbol_sorter from everyvoice.utils import pydantic_validation_error_shortener from everyvoice.utils.heavy import expand from loguru import logger @@ -35,7 +36,7 @@ class FastSpeech2(pl.LightningModule): - _VERSION: str = "1.1" + _VERSION: str = "1.2" def __init__( self, @@ -296,6 +297,57 @@ def check_and_upgrade_checkpoint(self, checkpoint): # Upgrading from 0.0 to 1.0 requires no changes; future versions might require changes checkpoint["model_info"]["version"] = "1.0" + # We changed the handling of phonological features in everyvoice==0.3.0 + if ( + ckpt_version < Version("1.2") + and checkpoint["hyper_parameters"]["config"]["model"][ + "target_text_representation_level" + ] + == TargetTrainingTextRepresentationLevel.phonological_features.value + ): + raise ValueError( + f"""There were breaking changes to the handling of phonological features in version 1.2, introduced in version 0.3.0 of EveryVoice. + Your model is version {ckpt_version} and your model may not work as a result. Please downgrade to everyvoice 0.2.0.""" + ) + + elif ckpt_version < Version("1.2"): + old_hardcoded_symbols = [ + "\x80", + " ", + "", + "", + "", + "", + "", + "", + ] + checkpoint_symbols = symbol_sorter( + get_symbols_from_symbol_dict( + checkpoint["hyper_parameters"]["config"]["text"]["symbols"] + ), + hardcoded_initial_symbols=old_hardcoded_symbols, + ) + model_symbols = self.text_processor.symbols + assert ( + checkpoint_symbols <= model_symbols + ), "Unfortunately we are unable to automatically update your embedding table. Please re-train your model or downgrade to everyvoice 0.2.0" + checkpoint_symbol_indices = torch.Tensor( + [ + model_symbols.index(c) if c in model_symbols else 0 + for c in checkpoint_symbols + ] + ).int() + new_weights = torch.zeros(self.text_input_layer.weight.size()) + # Copy data into the correct positions + new_weights[checkpoint_symbol_indices, :] = checkpoint["state_dict"][ + "text_input_layer.weight" + ] + # Update the checkpoint's state_dict with the new weights + checkpoint["state_dict"]["text_input_layer.weight"] = new_weights + logger.warning( + f"Your checkpoint was trained using version {ckpt_version} but your code is currently running {self._VERSION}. We have attempted to update your checkpoint automatically, but if you encounter issues, please re-train your model or downgrade to everyvoice 0.2.0" + ) + return checkpoint def on_load_checkpoint(self, checkpoint):