diff --git a/examples/toy-complex-train-tripleclass.yaml b/examples/toy-complex-train-tripleclass.yaml index 582f8b72f..d75cdd811 100644 --- a/examples/toy-complex-train-tripleclass.yaml +++ b/examples/toy-complex-train-tripleclass.yaml @@ -1,6 +1,6 @@ job.type: train -dataset.name: toy -model: distmult +dataset.name: wn11 +model: complex train: optimizer: Adagrad optimizer_args: @@ -11,8 +11,11 @@ lookup_embedder.dim: 100 lookup_embedder.initialize: xavier_uniform_ eval: type: triple_classification - metrics_per.relation: False - triple_classification_random_seed: False +triple_classification.random_seed: False +triple_classification.negatives_from: data + + valid.metric: accuracy +valid.every: 1 diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 08ca70681..f06a8eea6 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -330,7 +330,7 @@ eval: # mean_reciprocal_rank_filtered_with_test. filter_with_test: True - # Type of evaluation (entity_ranking only at the moment) + # Type of evaluation (entity_ranking, triple_classification) type: entity_ranking # How to handle cases with ties between the correct answer and other answers, e.g., @@ -423,6 +423,14 @@ valid: ## EVALUATION ################################################################## +triple_classification: + random_seed: False + # How to obtain negative triple labels. Possible values are: + # - corruption: Create negatives by randomly corrupting existing triples (positives) + # - data : Obtain negative labels from the dataset. This assumes the data set + # contains the splits 'valid_negatives' and 'test_negatives' + negatives_from: corruption + ## HYPERPARAMETER SEARCH ####################################################### diff --git a/kge/job/triple_classification.py b/kge/job/triple_classification.py index 2d5cce804..918b19520 100644 --- a/kge/job/triple_classification.py +++ b/kge/job/triple_classification.py @@ -20,7 +20,12 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset): self.o_entities = None uni_sampler_config = config.clone() # uni_sampler_config.set("negative_sampling.num_samples.s", self.get_option("num_samples.s")) + # TODO this is redundant as uniform.sample() is called with "num_samples" here in self.sample() uni_sampler_config.set("negative_sampling.num_samples.s", 1) + # TODO maybe changing the API of KGEsampler.sample() to also accept a param "filter" + # as it is the case already with "num_samples" + # then we would not rely here on configuration options which actually + # belong to a training job uni_sampler_config.set("negative_sampling.filtering.s", True) # uni_sampler_config.set("negative_sampling.num_samples.o", self.get_option("num_samples.o")) uni_sampler_config.set("negative_sampling.num_samples.o", 1) @@ -31,6 +36,7 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset): def _prepare(self,): train_data = self.dataset.split("train") + #TODO probably outdated as it refers to out-commented code self.s_entities = train_data[:, S].unique().tolist() self.o_entities = train_data[:, O].unique().tolist() self._is_prepared = True @@ -105,6 +111,20 @@ def __init__(self, config, dataset, parent_job, model): self.triple_classification_sampler = TripleClassificationSampler( config, "triple_classification", dataset ) + self.config.check( + "triple_classification.negatives_from", ["corruption", "data"] + ) + self.negatives_from = self.config.get("triple_classification.negatives_from") + if self.negatives_from == "data": + try: + self.config.get("dataset.files.valid_negatives.type") + self.config.get("dataset.files.test_negatives.type") + except: + raise KeyError( + "No splits test/valid_negatives found for the dataset. " + "Provide a dataset with splits valid_negatives and test_negatives " + "or run triple classification with negatives_from=corruption" + ) def _prepare(self): """Prepare the corrupted validation and test data. @@ -121,7 +141,8 @@ def _prepare(self): self.config.log("Generate data with corrupted and true triples...") - if self.eval_split == "test": + # TODO maybe should be generalized to allow for other splits as valid_wo_unseen + if self.eval_split == "test" and self.negatives_from == "corruption": ( self.tune_data, self.tune_labels, @@ -130,7 +151,25 @@ def _prepare(self): self.eval_data, self.eval_labels, ) = self.triple_classification_sampler.sample(self.dataset.split("test")) - else: + + elif self.eval_split == "test" and self.negatives_from == "data": + positives_valid = self.dataset.split("valid") + negatives_valid = self.dataset.split("valid_negatives") + self.tune_data = torch.cat((positives_valid, negatives_valid)).to(self.device) + self.tune_labels = torch.cat( + (torch.ones(positives_valid.size(0), torch.zeros(negatives_valid.size(0)))) + ).to(self.device) + + positives_test = self.dataset.split("test") + negatives_test = self.dataset.split("test_negatives") + self.tune_data = torch.cat((positives_test, negatives_test)).to( + self.device) + self.tune_labels = torch.cat( + (torch.ones(positives_test.size(0), + torch.zeros(negatives_test.size(0)))) + ).to(self.device) + + elif self.eval_split == "valid" and self.negatives_from == "corruption": ( self.tune_data, self.tune_labels, @@ -140,6 +179,17 @@ def _prepare(self): self.eval_labels, ) = self.triple_classification_sampler.sample(self.dataset.split("valid")) + elif self.eval_split == "valid" and self.negatives_from == "data": + positives = self.dataset.split("valid") + negatives = self.dataset.split("valid_negatives") + self.tune_data = torch.cat((positives, negatives)).to(self.device) + self.tune_labels = torch.cat( + (torch.ones(positives.size(0)), torch.zeros(negatives.size(0))) + ).to(self.device) + + self.eval_data = self.tune_data + self.eval_labels = self.tune_labels + # let the model add some hooks, if it wants to do so self.model.prepare_job(self) self.valid_data_is_prepared = True