Skip to content

Commit

Permalink
feat: added a version number to FastSpeech2
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed Oct 16, 2024
1 parent 739cedd commit 5b5b9bd
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions fs2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@


class FastSpeech2(pl.LightningModule):
__version__: str = "1"

def __init__(
self,
config: dict | FastSpeech2Config,
Expand Down Expand Up @@ -256,6 +258,13 @@ def on_load_checkpoint(self, checkpoint):
Note, this shouldn't fail on different versions of pydantic anymore,
but it will fail on breaking changes to the config. We should catch those exceptions
and handle them appropriately."""
if "model_info" in checkpoint:
assert (
checkpoint["model_info"]["name"] == FastSpeech2.__name__
), f"""Wrong model type ({checkpoint["model_info"]["name"]}), we are expecting a { FastSpeech2.__name__ }"""
assert (
checkpoint["model_info"]["version"] == self.__version__
), f"""Wrong model's version({checkpoint["model_info"]["version"]}), we are expecting version {FastSpeech2.__version__}"""
self.config = FeaturePredictionConfig(
**checkpoint["hyper_parameters"]["config"]
)
Expand All @@ -271,6 +280,10 @@ def on_save_checkpoint(self, checkpoint):
checkpoint["hyper_parameters"]["config"] = self.config.model_checkpoint_dump()
if self.stats is not None:
checkpoint["hyper_parameters"]["stats"] = self.stats.model_dump(mode="json")
checkpoint["model_info"] = {
"name": self.__class__.__name__,
"version": self.__version__,
}

def predict_step(self, batch, batch_idx):
with torch.no_grad():
Expand Down

0 comments on commit 5b5b9bd

Please sign in to comment.