Skip to content

Commit

Permalink
Support loading of old checkpoints with current embedders
Browse files Browse the repository at this point in the history
  • Loading branch information
rgemulla committed May 25, 2020
1 parent 8d519e0 commit f2bcbdb
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion kge/model/kge_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
import tempfile
from collections import OrderedDict

from torch import Tensor
import torch.nn
Expand Down Expand Up @@ -69,7 +70,15 @@ def save(self):

def load(self, savepoint):
"Loads state from a saved data structure"
self.load_state_dict(savepoint[0])
# handle deprecated keys
state_dict = OrderedDict()
for k,v in savepoint[0].items():
if k == "_entity_embedder.embeddings.weight":
k = "_entity_embedder._embeddings.weight"
if k == "_relation_embedder.embeddings.weight":
k = "_relation_embedder._embeddings.weight"
state_dict[k] = v
self.load_state_dict(state_dict)
self.meta = savepoint[1]


Expand Down

0 comments on commit f2bcbdb

Please sign in to comment.