Merge pull request #101 from Nzteb/reg_api
(WIP) Change embedder penalty API
rgemulla authored May 25, 2020
2 parents 1ec6a1d + 925318a commit 952ec5a
Showing 2 changed files with 23 additions and 18 deletions.
kge/model/embedder/lookup_embedder.py (14 changes: 7 additions & 7 deletions)
@@ -121,10 +121,10 @@ def penalty(self, **kwargs) -> List[Tensor]:
                 ]
             else:
                 # weighted Lp regularization
-                unique_ids, counts = torch.unique(
-                    kwargs["batch"]["triples"][:, kwargs["slot"]], return_counts=True
+                unique_indexes, counts = torch.unique(
+                    kwargs["indexes"], return_counts=True
                 )
-                parameters = self._embeddings(unique_ids)
+                parameters = self._embeddings(unique_indexes)
                 if p % 2 == 1:
                     parameters = torch.abs(parameters)
                 result += [
@@ -135,10 +135,10 @@ def penalty(self, **kwargs) -> List[Tensor]:
                             / p
                             * (parameters ** p * counts.float().view(-1, 1))
                         ).sum()
-                        # In contrast to unweighted Lp regulariztion, rescaling by number of
-                        # triples is necessary here so that penalty term is correct in
-                        # expectation
-                        / len(kwargs["batch"]["triples"]),
+                        # In contrast to unweighted Lp regularization, rescaling by
+                        # number of triples/indexes is necessary here so that penalty
+                        # term is correct in expectation
+                        / len(kwargs["indexes"]),
                 )
             ]
         else:  # unknown regularization
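
To make the comment in the last hunk concrete: under the new API, the weighted Lp term penalizes each unique embedding once, weighted by how often its index occurs in the batch, and rescales by the batch size so the term is correct in expectation under minibatch sampling. A minimal standalone sketch of that computation (the embedding table, weight, and indexes below are made up for illustration; only the arithmetic mirrors lookup_embedder.py):

import torch

# Hypothetical stand-ins: a small embedding table and a batch of indexes
# as the new API passes them (entity 0 occurs three times).
embeddings = torch.nn.Embedding(5, 4)
indexes = torch.tensor([0, 0, 0, 2, 4])
regularize_weight, p = 0.1, 3

# Each unique embedding enters the sum once, weighted by its count,
# instead of once per occurrence.
unique_indexes, counts = torch.unique(indexes, return_counts=True)
parameters = embeddings(unique_indexes)
if p % 2 == 1:  # odd p: use |x|^p rather than x^p
    parameters = torch.abs(parameters)

# Dividing by len(indexes) turns the sum into an average per index,
# which is what makes the penalty correct in expectation.
penalty = (
    regularize_weight / p * (parameters ** p * counts.float().view(-1, 1))
).sum() / len(indexes)
print(penalty.item())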
kge/model/kge_model.py (27 changes: 16 additions & 11 deletions)
@@ -426,19 +426,24 @@ def append_num_parameter(job, trace):
         job.post_epoch_trace_hooks.append(append_num_parameter)

     def penalty(self, **kwargs) -> List[Tensor]:
+        # Note: If the subject and object embedder are identical, embeddings may be
+        # penalized twice. This is intended (and necessary, e.g., if the penalty is
+        # weighted).
         if "batch" in kwargs and "triples" in kwargs["batch"]:
-            kwargs["batch"]["triples"] = kwargs["batch"]["triples"].to(
-                self.config.get("job.device")
+            triples = kwargs["batch"]["triples"].to(self.config.get("job.device"))
+            return (
+                super().penalty(**kwargs)
+                + self.get_s_embedder().penalty(indexes=triples[:, S], **kwargs)
+                + self.get_p_embedder().penalty(indexes=triples[:, P], **kwargs)
+                + self.get_o_embedder().penalty(indexes=triples[:, O], **kwargs)
             )
-        return (
-            # Note: If the subject and object embedder are identical, embeddings may be
-            # penalized twice. This is intended (and necessary, e.g., if the penalty is
-            # weighted).
-            super().penalty(**kwargs)
-            + self.get_s_embedder().penalty(slot=S, **kwargs)
-            + self.get_p_embedder().penalty(slot=P, **kwargs)
-            + self.get_o_embedder().penalty(slot=O, **kwargs)
-        )
+        else:
+            return (
+                super().penalty(**kwargs)
+                + self.get_s_embedder().penalty(**kwargs)
+                + self.get_p_embedder().penalty(**kwargs)
+                + self.get_o_embedder().penalty(**kwargs)
+            )

     def get_s_embedder(self) -> KgeEmbedder:
         return self._entity_embedder
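
For the caller-facing side of the change: KgeModel.penalty now slices the batch triples once and hands each embedder only the indexes for its slot, instead of passing slot=... and letting every embedder dig into batch["triples"] itself. A toy sketch of that dispatch under the new signature (ToyEmbedder and all numbers are hypothetical stand-ins, not the repository's classes):

from typing import List

import torch
from torch import Tensor

S, P, O = 0, 1, 2  # slot constants, as used throughout kge

class ToyEmbedder:
    # Stand-in for a KgeEmbedder; only the new penalty signature matters here.
    def __init__(self, num_objects: int, dim: int):
        self._embeddings = torch.nn.Embedding(num_objects, dim)

    def penalty(self, **kwargs) -> List[Tensor]:
        if "indexes" not in kwargs:
            return []  # no indexes given: nothing batch-dependent to penalize
        # Weighted L2 penalty over the indexes this embedder was handed,
        # rescaled as in lookup_embedder.py above.
        unique_indexes, counts = torch.unique(kwargs["indexes"], return_counts=True)
        parameters = self._embeddings(unique_indexes)
        return [
            (parameters ** 2 * counts.float().view(-1, 1)).sum()
            / len(kwargs["indexes"])
        ]

entity, relation = ToyEmbedder(5, 4), ToyEmbedder(3, 4)
triples = torch.tensor([[0, 0, 1], [2, 1, 4]])  # toy (s, p, o) id triples
# Model-side dispatch after this commit: slice the triples once, then pass
# per-slot indexes down to each embedder.
penalties = (
    entity.penalty(indexes=triples[:, S])
    + relation.penalty(indexes=triples[:, P])
    + entity.penalty(indexes=triples[:, O])
)
print(penalties)

Note that the shared entity embedder is applied to both the subject and the object column, so its embeddings can be penalized twice; the comment moved to the top of KgeModel.penalty flags this as intended.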
