diff --git a/data/download_all.sh b/data/download_all.sh index 6424f80bb..ea63e0d9d 100755 --- a/data/download_all.sh +++ b/data/download_all.sh @@ -201,3 +201,28 @@ else echo wikidata5m already prepared fi + +# wn11 +if [ ! -d "$BASEDIR/wn11" ]; then + echo Downloading wn11 + cd $BASEDIR + # TODO this also creates a __MACOSX folder on non-mac in the zip + # TODO download file from mannheim server + curl -O https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/wordnet11.zip + unzip wordnet11.zip + if [ -d "__MACOSX" ]; then + rm -r __MACOSX + fi + mv wordnet11 wn11 + mv wn11/dev.txt wn11/valid.txt +else + echo wn11 already present +fi + +if [ ! -f "$BASEDIR/wn11/dataset.yaml" ]; then + python preprocess.py wn11 --triple_class +else + echo wn11 already prepared +fi + + diff --git a/data/preprocess.py b/data/preprocess.py index fa1c735db..3d2627528 100755 --- a/data/preprocess.py +++ b/data/preprocess.py @@ -7,7 +7,7 @@ During preprocessing, each distinct entity name and each distinct distinct relation name is assigned an index (dense). The index-to-object mapping is stored in files -"entity_map.del" and "relation_map.del", resp. The triples (as indexes) are stored in +"entity_ids.del" and "relation_ids.del", resp. The triples (as indexes) are stored in files "train.del", "valid.del", and "test.del". Metadata information is stored in a file "dataset.yaml". @@ -19,6 +19,7 @@ import numpy as np from collections import OrderedDict + def store_map(symbol_map, filename): with open(filename, "w") as f: for symbol, index in symbol_map.items(): @@ -29,14 +30,31 @@ def store_map(symbol_map, filename): parser = argparse.ArgumentParser() parser.add_argument("folder", type=str) parser.add_argument("--order_sop", action="store_true") + parser.add_argument("--triple_class", action="store_true") args = parser.parse_args() print(f"Preprocessing {args.folder}...") raw_split_files = {"train": "train.txt", "valid": "valid.txt", "test": "test.txt"} split_files = {"train": "train.del", "valid": "valid.del", "test": "test.del"} - string_files = {"entity_strings": "entity_strings.del", "relation_strings": "relation_strings.del"} - split_files_without_unseen = {"train_sample": "train_sample.del", "valid_without_unseen": "valid_without_unseen.del", - "test_without_unseen": "test_without_unseen.del"} + + string_files = { + "entity_strings": "entity_strings.del", + "relation_strings": "relation_strings.del", + } + split_files_without_unseen = { + "train_sample": "train_sample.del", + "valid_without_unseen": "valid_without_unseen.del", + "test_without_unseen": "test_without_unseen.del", + } + + if args.triple_class: + split_files_negatives = { + "valid_negatives": "valid_negatives.del", + "test_negatives": "test_negatives.del"} + split_files_negatives_without_unseen = { + "valid_negatives_without_unseen": "valid_negatives_without_unseen.del", + "test_negatives_without_unseen": "test_negatives_without_unseen.del"} + split_sizes = {} if args.order_sop: @@ -73,7 +91,7 @@ def store_map(symbol_map, filename): if "train" in split: entities_in_train = entities.copy() relations_in_train = relations.copy() - + print(f"{len(relations)} distinct relations") print(f"{len(entities)} distinct entities") print("Writing relation and entity map...") @@ -87,17 +105,61 @@ def store_map(symbol_map, filename): for split, filename in split_files.items(): if split in ["valid", "test"]: split_without_unseen = split + "_without_unseen" - f_wo_unseen = open(os.path.join(args.folder, - split_files_without_unseen[split_without_unseen]), "w") + 
f_wo_unseen = open( + os.path.join( + args.folder, split_files_without_unseen[split_without_unseen] + ), + "w", + ) + if args.triple_class: + split_negatives_wo_unseen = f"{split}_negatives_without_unseen" + f_negatives_wo_unseen = open( + os.path.join( + args.folder, + split_files_negatives_without_unseen[split_negatives_wo_unseen] + ), + "w" + ) else: split_without_unseen = split + "_sample" - f_tr_sample = open(os.path.join(args.folder, - split_files_without_unseen[split_without_unseen]), "w") - train_sample = np.random.choice(split_sizes["train"], split_sizes["valid"], False) + f_tr_sample = open( + os.path.join( + args.folder, split_files_without_unseen[split_without_unseen] + ), + "w", + ) + train_sample = np.random.choice( + split_sizes["train"], split_sizes["valid"], False + ) with open(os.path.join(args.folder, filename), "w") as f: - size_unseen = 0 + if args.triple_class and split in ["valid", "test"]: + split_negatives = f"{split}_negatives" + f_negatives = open( + os.path.join( + args.folder, + split_files_negatives[split_negatives], + ), + "w", + ) + + if args.triple_class: + size_negatives = 0 + size_negatives_unseen = 0 + # positives; valid and test sizes have to be recalculated + size_positives = 0 + size_positives_unseen = 0 + else: + size_positives_unseen = 0 for n, t in enumerate(raw[split]): - f.write( + if args.triple_class and split in ["valid", "test"] and int(t[3]) == -1: + file_wrapper = f_negatives + size_negatives += 1 + elif args.triple_class and split in ["valid", "test"]: + size_positives += 1 + file_wrapper = f + else: + file_wrapper = f + file_wrapper.write( str(entities[t[S]]) + "\t" + str(relations[t[P]]) @@ -114,10 +176,22 @@ def store_map(symbol_map, filename): + str(entities[t[O]]) + "\n" ) - size_unseen += 1 - elif split in ["valid", "test"] and t[S] in entities_in_train and \ - t[O] in entities_in_train and t[P] in relations_in_train: - f_wo_unseen.write( + size_positives_unseen += 1 + elif ( + split in ["valid", "test"] + and t[S] in entities_in_train + and t[O] in entities_in_train + and t[P] in relations_in_train + ): + + if args.triple_class and int(t[3]) == -1: + file_wrapper = f_negatives_wo_unseen + size_negatives_unseen += 1 + else: + file_wrapper = f_wo_unseen + size_positives_unseen += 1 + + file_wrapper.write( str(entities[t[S]]) + "\t" + str(relations[t[P]]) @@ -125,17 +199,18 @@ def store_map(symbol_map, filename): + str(entities[t[O]]) + "\n" ) - size_unseen += 1 - without_unseen_sizes[split_without_unseen] = size_unseen + if args.triple_class and split in ["valid", "test"]: + without_unseen_sizes[split_negatives_wo_unseen] = size_negatives_unseen + split_sizes[split] = size_positives + split_sizes[split_negatives] = size_negatives + without_unseen_sizes[split_without_unseen] = size_positives_unseen # write config print("Writing dataset.yaml...") dataset_config = dict( - name=args.folder, - num_entities=len(entities), - num_relations=len(relations), + name=args.folder, num_entities=len(entities), num_relations=len(relations), ) - for obj in [ "entity", "relation" ]: + for obj in ["entity", "relation"]: dataset_config[f"files.{obj}_ids.filename"] = f"{obj}_ids.del" dataset_config[f"files.{obj}_ids.type"] = "map" for split in split_files.keys(): @@ -143,9 +218,26 @@ def store_map(symbol_map, filename): dataset_config[f"files.{split}.type"] = "triples" dataset_config[f"files.{split}.size"] = split_sizes.get(split) for split in split_files_without_unseen.keys(): - dataset_config[f"files.{split}.filename"] = 
split_files_without_unseen.get(split) + dataset_config[f"files.{split}.filename"] = split_files_without_unseen.get( + split + ) dataset_config[f"files.{split}.type"] = "triples" dataset_config[f"files.{split}.size"] = without_unseen_sizes.get(split) + if args.triple_class: + for split in split_files_negatives.keys(): + dataset_config[f"files.{split}.filename"] = split_files_negatives.get(split) + dataset_config[f"files.{split}.type"] = "triples" + dataset_config[f"files.{split}.size"] = split_sizes[split] + + for split in split_files_negatives_without_unseen.keys(): + dataset_config[f"files.{split}.filename"] = split_files_negatives_without_unseen.get( + split) + dataset_config[f"files.{split}.type"] = "triples" + dataset_config[f"files.{split}.size"] = without_unseen_sizes[ + split] + + + for string in string_files.keys(): if os.path.exists(os.path.join(args.folder, string_files[string])): dataset_config[f"files.{string}.filename"] = string_files.get(string) diff --git a/examples/toy-complex-train-tripleclass.yaml b/examples/toy-complex-train-tripleclass.yaml new file mode 100644 index 000000000..d75cdd811 --- /dev/null +++ b/examples/toy-complex-train-tripleclass.yaml @@ -0,0 +1,21 @@ +job.type: train +dataset.name: wn11 +model: complex +train: + optimizer: Adagrad + optimizer_args: + lr: 0.2 + weight_decay: 0.4e-7 +lookup_embedder.dim: 100 +#lookup_embedder.initialize: normal_ +lookup_embedder.initialize: xavier_uniform_ +eval: + type: triple_classification +triple_classification.random_seed: False +triple_classification.negatives_from: data + + +valid.metric: accuracy +valid.every: 1 + + diff --git a/examples/toy-complex-train.yaml b/examples/toy-complex-train.yaml index bfa3fba27..59d8bc78b 100644 --- a/examples/toy-complex-train.yaml +++ b/examples/toy-complex-train.yaml @@ -1,5 +1,7 @@ job.type: train -dataset.name: toy +#dataset.name: toy +dataset.name: fb15k-237 +#dataset.name: fb15k train: optimizer: Adagrad @@ -9,13 +11,14 @@ train: lr_scheduler_args: mode: max patience: 4 + batch_size: 1024 + +eval.type: triple_classification + +valid.every: 1 +valid.metric: accuracy model: complex lookup_embedder: dim: 100 - regularize_weight: 0.8e-7 - initialize: normal_ - initialize_args: - normal_: - mean: 0.0 - std: 0.1 + regularize_weight: 0.0 diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 08ca70681..f06a8eea6 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -330,7 +330,7 @@ eval: # mean_reciprocal_rank_filtered_with_test. filter_with_test: True - # Type of evaluation (entity_ranking only at the moment) + # Type of evaluation (entity_ranking, triple_classification) type: entity_ranking # How to handle cases with ties between the correct answer and other answers, e.g., @@ -423,6 +423,14 @@ valid: ## EVALUATION ################################################################## +triple_classification: + random_seed: False + # How to obtain negative triple labels. Possible values are: + # - corruption: Create negatives by randomly corrupting existing triples (positives) + # - data : Obtain negative labels from the dataset. 
This assumes the dataset
+  #                 contains the splits 'valid_negatives' and 'test_negatives'
+  negatives_from: corruption
+
 
 ## HYPERPARAMETER SEARCH #######################################################
diff --git a/kge/job/__init__.py b/kge/job/__init__.py
index c3bee8a37..de00c257b 100644
--- a/kge/job/__init__.py
+++ b/kge/job/__init__.py
@@ -9,3 +9,4 @@ from kge.job.ax_search import AxSearchJob
 from kge.job.entity_ranking import EntityRankingJob
 from kge.job.entity_pair_ranking import EntityPairRankingJob
+from kge.job.triple_classification import TripleClassificationJob
diff --git a/kge/job/eval.py b/kge/job/eval.py
index 138726728..d97e838a9 100644
--- a/kge/job/eval.py
+++ b/kge/job/eval.py
@@ -72,7 +72,7 @@ def __init__(self, config, dataset, parent_job, model):
     @staticmethod
     def create(config, dataset, parent_job=None, model=None):
         """Factory method to create an evaluation job """
-        from kge.job import EntityRankingJob, EntityPairRankingJob
+        from kge.job import EntityRankingJob, EntityPairRankingJob, TripleClassificationJob
 
         # create the job
         if config.get("eval.type") == "entity_ranking":
@@ -81,6 +81,10 @@ def create(config, dataset, parent_job=None, model=None):
             return EntityPairRankingJob(
                 config, dataset, parent_job=parent_job, model=model
             )
+        elif config.get("eval.type") == "triple_classification":
+            return TripleClassificationJob(
+                config, dataset, parent_job=parent_job, model=model
+            )
         else:
             raise ValueError("eval.type")
 
diff --git a/kge/job/triple_classification.py b/kge/job/triple_classification.py
new file mode 100644
index 000000000..918b19520
--- /dev/null
+++ b/kge/job/triple_classification.py
@@ -0,0 +1,361 @@
+import time
+
+import torch
+from kge import Dataset, Config, Configurable
+from kge.util.sampler import KgeUniformSampler
+from kge.job import EvaluationJob
+
+SLOTS = [0, 1, 2]
+SLOT_STR = ["s", "p", "o"]
+S, P, O = SLOTS
+
+
+class TripleClassificationSampler(Configurable):
+    def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
+        super().__init__(config, configuration_key)
+        self.dataset = dataset
+        self._is_prepared = False
+        self.train_data = None
+        self.s_entities = None
+        self.o_entities = None
+        uni_sampler_config = config.clone()
+        # uni_sampler_config.set("negative_sampling.num_samples.s", self.get_option("num_samples.s"))
+        # TODO this is redundant, as uniform.sample() is called with "num_samples" in self.sample()
+        uni_sampler_config.set("negative_sampling.num_samples.s", 1)
+        # TODO consider changing the API of KgeSampler.sample() to also accept a "filter"
+        # parameter, as is already the case for "num_samples"; then we would not have to
+        # rely here on configuration options that actually belong to a training job
+        uni_sampler_config.set("negative_sampling.filtering.s", True)
+        # uni_sampler_config.set("negative_sampling.num_samples.o", self.get_option("num_samples.o"))
+        uni_sampler_config.set("negative_sampling.num_samples.o", 1)
+        uni_sampler_config.set("negative_sampling.filtering.o", True)
+        self.uniform_sampler = KgeUniformSampler(
+            uni_sampler_config, "negative_sampling", dataset
+        )
+
+    def _prepare(self):
+        train_data = self.dataset.split("train")
+        # TODO probably outdated: refers to the commented-out sampling code below
+        self.s_entities = train_data[:, S].unique().tolist()
+        self.o_entities = train_data[:, O].unique().tolist()
+        self._is_prepared = True
+
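+    # What sample() below produces: for an input batch of two positives
+    # [[s1, p1, o1], [s2, p2, o2]], the returned `corrupted` tensor interleaves each
+    # positive with one corrupted copy,
+    #   [[s1, p1, o1],    # label True
+    #    [s1', p1, o1],   # label False (or [s1, p1, o1'] if the object was corrupted)
+    #    [s2, p2, o2],    # label True
+    #    [s2', p2, o2]],  # label False
+    # and `labels` is the matching boolean vector [True, False, True, False].
+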
+    def sample(self, positive_triples: torch.Tensor):
+        """Generate positive and negative (corrupted) triples with labels.
+
+        Takes each of the given positive triples and creates one corrupted copy of it
+        by replacing either the subject or the object (chosen at random) with an
+        entity drawn by a uniform negative sampler; sampling is filtered so that known
+        positive triples are not produced as negatives.
+
+        Returns:
+            corrupted: Tensor with the original and corrupted triples interleaved
+                (each positive is immediately followed by its corrupted copy).
+
+            labels: Boolean vector with the labels of the corresponding triples
+                (True for positives, False for corrupted triples).
+        """
+
+        if not self._is_prepared:
+            self._prepare()
+
+        # Create the interleaved positive/corrupted triples and the corresponding labels
+        corrupted = positive_triples.repeat(1, 2).view(-1, 3)
+        labels = torch.as_tensor([1, 0] * len(positive_triples)).type(torch.bool)
+
+        # Randomly decide per triple whether to corrupt the subject (True) or the object (False)
+        sample_subject = torch.randint(2, (len(positive_triples),)).type(torch.bool)
+
+        # Corrupt the subject slot
+        # (earlier variant sampled only subjects seen in the training data:)
+        # corrupted[1::2][:, S][sample_subject] = torch.as_tensor(
+        #     random.choice(self.s_entities)
+        # )
+        corrupted[1::2, S][sample_subject] = self.uniform_sampler.sample(
+            corrupted[1::2][sample_subject], S, 1
+        ).view(-1)
+
+        # Corrupt the object slot
+        # corrupted[1::2][:, O][(sample_subject == False)] = torch.as_tensor(
+        #     random.choice(self.o_entities)
+        # )
+        corrupted[1::2, O][~sample_subject] = self.uniform_sampler.sample(
+            corrupted[1::2][~sample_subject], O, 1
+        ).view(-1)
+
+        return (
+            corrupted.to(self.config.get("job.device")),
+            labels.to(self.config.get("job.device")),
+        )
+
+
+class TripleClassificationJob(EvaluationJob):
+    """Triple classification evaluation protocol.
+
+    Tests a model's ability to discriminate between true and false triples by
+    thresholding scores. First, negative triples are obtained for the validation and
+    test data, either by randomly corrupting each positive triple or by reading
+    labeled negatives from the dataset. Then the model scores of all triples are
+    computed. Afterwards, a threshold is determined for each relation by maximizing
+    accuracy on validation data. A triple from the evaluation data is then predicted
+    as true if its score is at least the threshold of its relation. The reported
+    metric is accuracy on the evaluation data.
+    """
+
+    def __init__(self, config, dataset, parent_job, model):
+        super().__init__(config, dataset, parent_job, model)
+        self.valid_data_is_prepared = False
+        self.triple_classification_sampler = TripleClassificationSampler(
+            config, "triple_classification", dataset
+        )
+        self.config.check(
+            "triple_classification.negatives_from", ["corruption", "data"]
+        )
+        self.negatives_from = self.config.get("triple_classification.negatives_from")
+        if self.negatives_from == "data":
+            try:
+                self.config.get("dataset.files.valid_negatives.type")
+                self.config.get("dataset.files.test_negatives.type")
+            except:
+                raise KeyError(
+                    "No splits valid_negatives/test_negatives found for the dataset. "
+                    "Provide a dataset with splits valid_negatives and test_negatives "
+                    "or run triple classification with negatives_from=corruption"
+                )
+
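+    # When negatives_from == "data", the dataset is expected to declare the negative
+    # splits in its dataset.yaml (as written by data/preprocess.py --triple_class),
+    # for example:
+    #   files.valid_negatives.filename: valid_negatives.del
+    #   files.valid_negatives.type: triples
+    #   files.test_negatives.filename: test_negatives.del
+    #   files.test_negatives.type: triples
+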
+    def _prepare(self):
+        """Prepare the tuning and evaluation data (positives and negatives).
+
+        The negative triples are generated (or read from the dataset) only for the
+        first evaluated epoch; afterwards valid_data_is_prepared is set to True so
+        that every epoch is evaluated on the same data. Thresholds are always tuned
+        on validation data. For model selection (eval_split == "valid"), accuracy is
+        also reported on validation data; for testing (eval_split == "test"), the
+        thresholds tuned on validation data are applied to the test data.
+        """
+
+        if self.valid_data_is_prepared:
+            return
+
+        self.config.log("Generating positive and negative triples...")
+
+        # TODO maybe generalize this to allow other splits such as valid_without_unseen
+        if self.eval_split == "test" and self.negatives_from == "corruption":
+            (
+                self.tune_data,
+                self.tune_labels,
+            ) = self.triple_classification_sampler.sample(self.dataset.split("valid"))
+            (
+                self.eval_data,
+                self.eval_labels,
+            ) = self.triple_classification_sampler.sample(self.dataset.split("test"))
+
+        elif self.eval_split == "test" and self.negatives_from == "data":
+            positives_valid = self.dataset.split("valid")
+            negatives_valid = self.dataset.split("valid_negatives")
+            self.tune_data = torch.cat((positives_valid, negatives_valid)).to(
+                self.device
+            )
+            self.tune_labels = torch.cat(
+                (torch.ones(positives_valid.size(0)), torch.zeros(negatives_valid.size(0)))
+            ).to(self.device)
+
+            positives_test = self.dataset.split("test")
+            negatives_test = self.dataset.split("test_negatives")
+            self.eval_data = torch.cat((positives_test, negatives_test)).to(
+                self.device
+            )
+            self.eval_labels = torch.cat(
+                (torch.ones(positives_test.size(0)), torch.zeros(negatives_test.size(0)))
+            ).to(self.device)
+
+        elif self.eval_split == "valid" and self.negatives_from == "corruption":
+            (
+                self.tune_data,
+                self.tune_labels,
+            ) = self.triple_classification_sampler.sample(self.dataset.split("valid"))
+            (
+                self.eval_data,
+                self.eval_labels,
+            ) = self.triple_classification_sampler.sample(self.dataset.split("valid"))
+
+        elif self.eval_split == "valid" and self.negatives_from == "data":
+            positives = self.dataset.split("valid")
+            negatives = self.dataset.split("valid_negatives")
+            self.tune_data = torch.cat((positives, negatives)).to(self.device)
+            self.tune_labels = torch.cat(
+                (torch.ones(positives.size(0)), torch.zeros(negatives.size(0)))
+            ).to(self.device)
+
+            self.eval_data = self.tune_data
+            self.eval_labels = self.tune_labels
+
+        # let the model add some hooks, if it wants to do so
+        self.model.prepare_job(self)
+        self.valid_data_is_prepared = True
+
+    def run(self):
+        """Runs the triple classification job."""
+
+        self._prepare()
+
+        was_training = self.model.training
+        self.model.eval()
+
+        epoch_time = -time.time()
+
+        # Score the tuning data (validation)
+        s_tune, p_tune, o_tune = (
+            self.tune_data[:, 0],
+            self.tune_data[:, 1],
+            self.tune_data[:, 2],
+        )
+        p_tune_unique = p_tune.unique()
+        tune_scores = self.model.score_spo(s_tune, p_tune, o_tune)
+
+        # Score the evaluation data
+        s_eval, p_eval, o_eval = (
+            self.eval_data[:, 0],
+            self.eval_data[:, 1],
+            self.eval_data[:, 2],
+        )
+        p_eval_unique = p_eval.unique()
+        eval_scores = self.model.score_spo(s_eval, p_eval, o_eval)
+
+        # Find the best threshold for every relation on the tuning data
+        rel_thresholds = self.findThresholds(p_tune_unique, tune_scores)
+
+        # Make predictions for the specified evaluation data
+        self.config.log("Evaluating on {} data.".format(self.eval_split))
+        metrics, not_in_eval = self.predict(
+            eval_scores, rel_thresholds, p_tune_unique, p_eval_unique
+        )
+
type="triple_classification", + scope="epoch", + data_thresholds="Valid", + data_evaluate=self.eval_split, + epoch=self.epoch, + epoch_time=epoch_time, + **metrics, + ) + for f in self.post_epoch_trace_hooks: + f(self, trace_entry) + + # if validation metric is not present, try to compute it + metric_name = self.config.get("valid.metric") + if metric_name not in trace_entry: + trace_entry[metric_name] = eval( + self.config.get("valid.metric_expr"), + None, + {"config": self.config, **trace_entry}, + ) + + # write out trace + trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) + + # reset model and return metrics + if was_training: + self.model.train() + self.config.log("Finished evaluating on " + self.eval_split + " data.") + + return trace_entry + + def findThresholds(self, p_tune_unique, tune_scores): + """Find the best thresholds per relation by maximizing accuracy on + validation data. + + The thresholds are found for every relation by maximizing the accuracy on + the validation data. For a given relation, if the scores of all triple in + the relation are sorted, the perfect threshold is always a cut between two + of the scores. This means, that multiple possible values can be defined as + thresholds and give the highest accuracy. To evaluate only as many possible + thresholds as really necessary, the scores themselves are considered as + possible thresholds. This allows for a fast implementation. + + Args: + + p_tune: 1-D tensor containing the relations of the corrupted validation + dataset. + + tune_scores: 2-D tensor containing the scores of all corrupted + validation triples. + + rel_tune_scores: Dictionary containing the scores of the triples in a + relation. + + tune_thresh_labels: 1-D tensor containing the labels of all corrupted + tuning triples. + + tune_data: Dataset used. Should be the corrupted validation dataset. + + Returns: + rel_thresholds: Dictionary with thresholds per relation + {relation: thresholds}. + E.g.: {1: tensor(-2.0843, grad_fn=)} + """ + + # Initialize accuracies and thresholds + rel_thresholds = {r: -float("inf") for r in range(self.dataset.num_relations())} + + # Change the valid scores from a 2D to a 1D tensor + # tune_scores = torch.as_tensor( + # [float(tune_scores[i]) for i in range(len(tune_scores))] + # ).to(self.device) + + for r in p_tune_unique: + # 0-1 vector for indexing triples of the current relation + current_rel = self.tune_data[:, 1] == r + true_labels = self.tune_labels[current_rel].view(-1) + + # tune_scores[current_rel] and rel_tune_scores[r] both + # contain the scores of the current relation. In the comparison, every + # score is evaluated as possible threshold against all scores. + predictions = ( + tune_scores[current_rel].view(-1, 1) + >= tune_scores[current_rel].view(1, -1) + ).t() + + accuracies = (predictions == true_labels).float().sum(dim=1) + accuracies_max = accuracies.max() + + # Choose the smallest score of the ones which give the maximum + # accuracy as threshold to stay consistent. + rel_thresholds[r.item()] = tune_scores[current_rel][ + accuracies_max == accuracies + ].min() + + return rel_thresholds + + def predict(self, eval_scores, rel_thresholds, p_tune_unique, p_eval_unique): + """Makes predictions on evaluation/test data. + + Parameters: + rel_thresholds: Dictionary with relation thresholds. + + Returns: + rel_predictions: Dictionary with predictions for the triples in a relation, e.g. {1: [0, 0, 1, 1]}. 
+    def predict(self, eval_scores, rel_thresholds, p_tune_unique, p_eval_unique):
+        """Make predictions on the evaluation data and compute accuracy.
+
+        Args:
+            eval_scores: 1-D tensor with the model scores of all evaluation triples.
+
+            rel_thresholds: Dictionary with the per-relation thresholds found on the
+                tuning data.
+
+            p_tune_unique: 1-D tensor with the distinct relations of the tuning data.
+
+            p_eval_unique: 1-D tensor with the distinct relations of the evaluation data.
+
+        Returns:
+            metrics: Dictionary with the computed metrics (currently 'accuracy').
+
+            not_in_eval: List of relations that occur in the evaluation data but not
+                in the tuning data (their triples effectively count as misclassified).
+        """
+
+        tptn = 0  # number of correct predictions (true positives + true negatives)
+        # relations that occur in the evaluation data but not in the tuning data
+        not_in_eval = []
+        for r in p_eval_unique:
+            # a threshold only exists for relations seen in the tuning data
+            if r in p_tune_unique:
+                # Predict
+                current_rel = self.eval_data[:, 1] == r
+                true_labels = self.eval_labels[current_rel]
+                predictions = eval_scores[current_rel] >= rel_thresholds[r.item()]
+                tptn += (predictions == true_labels).float().sum().item()
+            else:
+                not_in_eval.append(r)
+
+        metrics = dict(accuracy=tptn / self.eval_data.size(0))
+
+        return metrics, not_in_eval
diff --git a/kge/util/sampler.py b/kge/util/sampler.py
index 5f1ab31b8..03f9f1ac5 100644
--- a/kge/util/sampler.py
+++ b/kge/util/sampler.py
@@ -279,7 +279,11 @@ def _filter_and_resample_fast(
         positives_index = numba.typed.Dict()
         for i in range(batch_size):
             pair = (pairs[i][0], pairs[i][1])
-            positives_index[pair] = index.get(pair).numpy()
+            positives_index[pair] = (
+                index.get(pair).numpy()
+                if pair in index
+                else torch.IntTensor([]).numpy()
+            )
         negative_samples = negative_samples.numpy()
         KgeUniformSampler._filter_and_resample_numba(
             negative_samples, pairs, positives_index, batch_size, int(voc_size),
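For reference, a minimal end-to-end run of the new evaluation might look as follows. This is a sketch: it assumes the usual `kge start` entry point and relies on the wn11 preparation and the example config added by this patch.

```bash
# download and preprocess the datasets, including wn11 with its labeled
# negative validation/test triples (runs preprocess.py wn11 --triple_class)
cd data
./download_all.sh
cd ..

# train ComplEx on wn11 and validate with triple classification accuracy
kge start examples/toy-complex-train-tripleclass.yaml
```

With `triple_classification.negatives_from: data`, thresholds are tuned on the labeled wn11 negatives; setting `negatives_from: corruption` instead generates negatives by randomly corrupting the positive triples.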