diff --git a/kge/model/embedder/lookup_embedder.py b/kge/model/embedder/lookup_embedder.py
index 147e75671..82e56cc1c 100644
--- a/kge/model/embedder/lookup_embedder.py
+++ b/kge/model/embedder/lookup_embedder.py
@@ -121,10 +121,10 @@ def penalty(self, **kwargs) -> List[Tensor]:
                 ]
             else:
                 # weighted Lp regularization
-                unique_ids, counts = torch.unique(
-                    kwargs["batch"]["triples"][:, kwargs["slot"]], return_counts=True
+                unique_indexes, counts = torch.unique(
+                    kwargs["indexes"], return_counts=True
                 )
-                parameters = self._embeddings(unique_ids)
+                parameters = self._embeddings(unique_indexes)
                 if p % 2 == 1:
                     parameters = torch.abs(parameters)
                 result += [
@@ -135,10 +135,10 @@ def penalty(self, **kwargs) -> List[Tensor]:
                             / p
                             * (parameters ** p * counts.float().view(-1, 1))
                         ).sum()
-                        # In contrast to unweighted Lp regulariztion, rescaling by number of
-                        # triples is necessary here so that penalty term is correct in
-                        # expectation
-                        / len(kwargs["batch"]["triples"]),
+                        # In contrast to unweighted Lp regularization, rescaling by
+                        # number of triples/indexes is necessary here so that penalty
+                        # term is correct in expectation
+                        / len(kwargs["indexes"]),
                     )
                 ]
         else:  # unknown regularization
diff --git a/kge/model/kge_model.py b/kge/model/kge_model.py
index cf4b98e59..d99a69fd9 100644
--- a/kge/model/kge_model.py
+++ b/kge/model/kge_model.py
@@ -426,19 +426,24 @@ def append_num_parameter(job, trace):
         job.post_epoch_trace_hooks.append(append_num_parameter)
 
     def penalty(self, **kwargs) -> List[Tensor]:
+        # Note: If the subject and object embedder are identical, embeddings may be
+        # penalized twice. This is intended (and necessary, e.g., if the penalty is
+        # weighted).
         if "batch" in kwargs and "triples" in kwargs["batch"]:
-            kwargs["batch"]["triples"] = kwargs["batch"]["triples"].to(
-                self.config.get("job.device")
+            triples = kwargs["batch"]["triples"].to(self.config.get("job.device"))
+            return (
+                super().penalty(**kwargs)
+                + self.get_s_embedder().penalty(indexes=triples[:, S], **kwargs)
+                + self.get_p_embedder().penalty(indexes=triples[:, P], **kwargs)
+                + self.get_o_embedder().penalty(indexes=triples[:, O], **kwargs)
+            )
+        else:
+            return (
+                super().penalty(**kwargs)
+                + self.get_s_embedder().penalty(**kwargs)
+                + self.get_p_embedder().penalty(**kwargs)
+                + self.get_o_embedder().penalty(**kwargs)
             )
-        return (
-            # Note: If the subject and object embedder are identical, embeddings may be
-            # penalized twice. This is intended (and necessary, e.g., if the penalty is
-            # weighted).
-            super().penalty(**kwargs)
-            + self.get_s_embedder().penalty(slot=S, **kwargs)
-            + self.get_p_embedder().penalty(slot=P, **kwargs)
-            + self.get_o_embedder().penalty(slot=O, **kwargs)
-        )
 
     def get_s_embedder(self) -> KgeEmbedder:
         return self._entity_embedder
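
Aside (not part of the patch): a minimal, self-contained sketch of why the weighted Lp penalty in lookup_embedder.py rescales by len(kwargs["indexes"]). Summing |e|^p over the unique indexes, with each term weighted by its occurrence count and the total divided by the number of index occurrences, equals the per-occurrence batch average of |e|^p, which is what makes the penalty correct in expectation. The embedding sizes, index values, and seed below are toy values chosen for illustration.

import torch

torch.manual_seed(0)
embeddings = torch.nn.Embedding(5, 4)       # toy lookup table: 5 ids, dim 4
indexes = torch.tensor([0, 1, 1, 3, 3, 3])  # e.g., one slot column of a batch
p, regularize_weight = 3, 0.1

# Penalty as in the patched code: one lookup per *unique* index,
# each weighted by how often that index occurs in the batch.
unique_indexes, counts = torch.unique(indexes, return_counts=True)
parameters = embeddings(unique_indexes)
if p % 2 == 1:
    parameters = torch.abs(parameters)  # odd p needs abs for a proper norm
weighted = (
    regularize_weight
    / p
    * (parameters ** p * counts.float().view(-1, 1)).sum()
    / len(indexes)
)

# Naive equivalent: one lookup per occurrence, averaged over the batch.
naive = (
    regularize_weight / p * (torch.abs(embeddings(indexes)) ** p).sum() / len(indexes)
)

assert torch.allclose(weighted, naive)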