diff --git a/kge/model/kge_model.py b/kge/model/kge_model.py index d99a69fd9..f59e81fcd 100644 --- a/kge/model/kge_model.py +++ b/kge/model/kge_model.py @@ -1,5 +1,6 @@ import importlib import tempfile +from collections import OrderedDict from torch import Tensor import torch.nn @@ -69,7 +70,15 @@ def save(self): def load(self, savepoint): "Loads state from a saved data structure" - self.load_state_dict(savepoint[0]) + # handle deprecated keys + state_dict = OrderedDict() + for k,v in savepoint[0].items(): + if k == "_entity_embedder.embeddings.weight": + k = "_entity_embedder._embeddings.weight" + if k == "_relation_embedder.embeddings.weight": + k = "_relation_embedder._embeddings.weight" + state_dict[k] = v + self.load_state_dict(state_dict) self.meta = savepoint[1]