Skip to content

Commit

Permalink
fix: attempt automatic embedding table update for previous models
Browse files Browse the repository at this point in the history
  • Loading branch information
roedoejet committed Feb 12, 2025
1 parent a4841c7 commit 53afbf3
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion fs2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from everyvoice.text.features import N_PHONOLOGICAL_FEATURES
from everyvoice.text.lookups import LookupTable
from everyvoice.text.text_processor import TextProcessor
from everyvoice.text.utils import get_symbols_from_symbol_dict, symbol_sorter
from everyvoice.utils import pydantic_validation_error_shortener
from everyvoice.utils.heavy import expand
from loguru import logger
Expand All @@ -35,7 +36,7 @@


class FastSpeech2(pl.LightningModule):
_VERSION: str = "1.1"
_VERSION: str = "1.2"

def __init__(
self,
Expand Down Expand Up @@ -296,6 +297,57 @@ def check_and_upgrade_checkpoint(self, checkpoint):
# Upgrading from 0.0 to 1.0 requires no changes; future versions might require changes
checkpoint["model_info"]["version"] = "1.0"

# We changed the handling of phonological features in everyvoice==0.3.0
if (
ckpt_version < Version("1.2")
and checkpoint["hyper_parameters"]["config"]["model"][
"target_text_representation_level"
]
== TargetTrainingTextRepresentationLevel.phonological_features.value
):
raise ValueError(
f"""There were breaking changes to the handling of phonological features in version 1.2, introduced in version 0.3.0 of EveryVoice.
Your model is version {ckpt_version} and your model may not work as a result. Please downgrade to everyvoice 0.2.0."""
)

elif ckpt_version < Version("1.2"):
old_hardcoded_symbols = [
"\x80",
" ",
"<EXCL>",
"<QINT>",
"<QUOTE>",
"<BB>",
"<SB>",
"<EPS>",
]
checkpoint_symbols = symbol_sorter(
get_symbols_from_symbol_dict(
checkpoint["hyper_parameters"]["config"]["text"]["symbols"]
),
hardcoded_initial_symbols=old_hardcoded_symbols,
)
model_symbols = self.text_processor.symbols
assert (
checkpoint_symbols <= model_symbols
), "Unfortunately we are unable to automatically update your embedding table. Please re-train your model or downgrade to everyvoice 0.2.0"
checkpoint_symbol_indices = torch.Tensor(
[
model_symbols.index(c) if c in model_symbols else 0
for c in checkpoint_symbols
]
).int()
new_weights = torch.zeros(self.text_input_layer.weight.size())
# Copy data into the correct positions
new_weights[checkpoint_symbol_indices, :] = checkpoint["state_dict"][
"text_input_layer.weight"
]
# Update the checkpoint's state_dict with the new weights
checkpoint["state_dict"]["text_input_layer.weight"] = new_weights
logger.warning(
f"Your checkpoint was trained using version {ckpt_version} but your code is currently running {self._VERSION}. We have attempted to update your checkpoint automatically, but if you encounter issues, please re-train your model or downgrade to everyvoice 0.2.0"
)

return checkpoint

def on_load_checkpoint(self, checkpoint):
Expand Down

0 comments on commit 53afbf3

Please sign in to comment.