diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 5ef718b2e..fe97cfb4f 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -235,6 +235,10 @@ KvsAll: s_o: False _po: True + # Dataset splits from which the labels for a query are taken. Default: If + # nothing is specified, then the train split is used. + label_splits: [] + # Options for negative sampling training (train.type=="negative_sampling") negative_sampling: # Negative sampler to use @@ -264,7 +268,7 @@ negative_sampling: p: False # as above o: False # as above - split: '' # split containing the positives; default is train.split + splits: [] # splits containing the positives; default is train.split # Implementation to use for filtering. # standard: use slow generic implementation, available for all samplers @@ -327,7 +331,8 @@ eval: # mean_reciprocal_rank_filtered_with_test. filter_with_test: True - # Type of evaluation (entity_ranking only at the moment) + # Type of evaluation (entity_ranking or training_loss). Currently, + # entity_ranking runs training_loss as well. 
type: entity_ranking # Compute Hits@K for these choices of K diff --git a/kge/indexing.py b/kge/indexing.py index b4e136ef4..e394323c6 100644 --- a/kge/indexing.py +++ b/kge/indexing.py @@ -2,6 +2,7 @@ from collections import defaultdict, OrderedDict import numba import numpy as np +from kge.misc import powerset, merge_dicts_of_1dim_torch_tensors def _group_by(keys, values) -> dict: @@ -222,6 +223,15 @@ def _invert_ids(dataset, obj: str): dataset.config.log(f"Indexed {len(inv)} {obj} ids", prefix=" ") +def merge_KvsAll_indexes(dataset, split, key): + value = dict([("sp", "o"), ("po", "s"), ("so", "p")])[key] + split_combi_str = "_".join(sorted(split)) + index_name = f"{split_combi_str}_{key}_to_{value}" + indexes = [dataset.index(f"{_split}_{key}_to_{value}") for _split in split] + dataset._indexes[index_name] = merge_dicts_of_1dim_torch_tensors(indexes) + return dataset._indexes[index_name] + + def create_default_index_functions(dataset: "Dataset"): for split in dataset.files_of_type("triples"): for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]: @@ -229,6 +239,15 @@ def create_default_index_functions(dataset: "Dataset"): dataset.index_functions[f"{split}_{key}_to_{value}"] = IndexWrapper( index_KvsAll, split=split, key=key ) + # create all combinations of splits of length 2 and 3 + for split_combi in powerset(dataset.files_of_type("triples"), [2, 3]): + for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]: + split_combi_str = "_".join(sorted(split_combi)) + index_name = f"{split_combi_str}_{key}_to_{value}" + dataset.index_functions[index_name] = IndexWrapper( + merge_KvsAll_indexes, split=split_combi, key=key + ) + dataset.index_functions["relation_types"] = index_relation_types dataset.index_functions["relations_per_type"] = index_relation_types dataset.index_functions["frequency_percentiles"] = index_frequency_percentiles diff --git a/kge/job/entity_ranking.py b/kge/job/entity_ranking.py index 9b36d1715..299551bf9 100644 --- 
a/kge/job/entity_ranking.py +++ b/kge/job/entity_ranking.py @@ -1,9 +1,10 @@ import math import time +from typing import Dict, Any import torch import kge.job -from kge.job import EvaluationJob, Job +from kge.job import EvaluationJob, Job, TrainingJob from kge import Config, Dataset from collections import defaultdict @@ -13,18 +14,23 @@ class EntityRankingJob(EvaluationJob): def __init__(self, config: Config, dataset: Dataset, parent_job, model): super().__init__(config, dataset, parent_job, model) - self.is_prepared = False if self.__class__ == EntityRankingJob: for f in Job.job_created_hooks: f(self) + max_k = min( + self.dataset.num_entities(), max(self.config.get("eval.hits_at_k_s")) + ) + self.hits_at_k_s = list( + filter(lambda x: x <= max_k, self.config.get("eval.hits_at_k_s")) + ) + self.filter_with_test = config.get("eval.filter_with_test") + + def _prepare(self): """Construct all indexes needed to run.""" - if self.is_prepared: - return - # create data and precompute indexes self.triples = self.dataset.split(self.config.get("eval.split")) for split in self.filter_splits: @@ -75,16 +81,8 @@ def _collate(self, batch): return batch, label_coords, test_label_coords @torch.no_grad() - def run(self) -> dict: - self._prepare() - - was_training = self.model.training - self.model.eval() - self.config.log( - "Evaluating on " - + self.eval_split - + " data (epoch {})...".format(self.epoch) - ) + def _run(self) -> Dict[str, Any]: + num_entities = self.dataset.num_entities() # we also filter with test data if requested @@ -399,28 +397,6 @@ def merge_hist(target_hists, source_hists): event="eval_completed", **metrics, ) - for f in self.post_epoch_trace_hooks: - f(self, trace_entry) - - # if validation metric is not present, try to compute it - metric_name = self.config.get("valid.metric") - if metric_name not in trace_entry: - trace_entry[metric_name] = eval( - self.config.get("valid.metric_expr"), - None, - dict(config=self.config, **trace_entry), - ) - - # write out 
trace - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) - - # reset model and return metrics - if was_training: - self.model.train() - self.config.log("Finished evaluating on " + self.eval_split + " split.") - - for f in self.post_valid_hooks: - f(self, trace_entry) return trace_entry diff --git a/kge/job/eval.py b/kge/job/eval.py index 138726728..491f4fb12 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -1,10 +1,10 @@ -import torch +import time +from typing import Any, Optional, Dict +import torch from kge import Config, Dataset -from kge.job import Job -from kge.model import KgeModel -from typing import Dict, Union, Optional +from kge.job import Job, TrainingJob class EvaluationJob(Job): @@ -16,12 +16,6 @@ def __init__(self, config, dataset, parent_job, model): self.model = model self.batch_size = config.get("eval.batch_size") self.device = self.config.get("job.device") - max_k = min( - self.dataset.num_entities(), max(self.config.get("eval.hits_at_k_s")) - ) - self.hits_at_k_s = list( - filter(lambda x: x <= max_k, self.config.get("eval.hits_at_k_s")) - ) self.config.check("train.trace_level", ["example", "batch", "epoch"]) self.trace_examples = self.config.get("eval.trace_level") == "example" self.trace_batch = ( @@ -31,9 +25,11 @@ def __init__(self, config, dataset, parent_job, model): self.filter_splits = self.config.get("eval.filter_splits") if self.eval_split not in self.filter_splits: self.filter_splits.append(self.eval_split) - self.filter_with_test = config.get("eval.filter_with_test") self.epoch = -1 + self.verbose = True + self.is_prepared = False + #: Hooks run after training for an epoch. 
#: Signature: job, trace_entry self.post_epoch_hooks = [] @@ -64,6 +60,22 @@ def __init__(self, config, dataset, parent_job, model): if config.get("eval.metrics_per.argument_frequency"): self.hist_hooks.append(hist_per_frequency_percentile) + # Add the training loss as a default to every evaluation job + # TODO: create AggregatingEvaluationsJob that runs and aggregates a list + # of EvaluationJobs, such that users can configure combinations of + # EvalJobs themselves. Then this can be removed. + # See https://github.com/uma-pi1/kge/issues/102 + if not isinstance(self, TrainingLossEvaluationJob): + self.eval_train_loss_job = TrainingLossEvaluationJob( + config, dataset, parent_job=self, model=model + ) + self.eval_train_loss_job.verbose = False + self.post_epoch_trace_hooks.append( + lambda job, trace: trace.update( + avg_loss=self.eval_train_loss_job.run()["avg_loss"] + ) + ) + # all done, run job_created_hooks if necessary if self.__class__ == EvaluationJob: for f in Job.job_created_hooks: f(self) @@ -81,10 +93,66 @@ def create(config, dataset, parent_job=None, model=None): return EntityPairRankingJob( config, dataset, parent_job=parent_job, model=model ) + elif config.get("eval.type") == "training_loss": + return TrainingLossEvaluationJob( + config, dataset, parent_job=parent_job, model=model + ) else: raise ValueError("eval.type") - def run(self) -> dict: + def _prepare(self): + """Prepare this job for running. 
Guaranteed to be called exactly once + """ + raise NotImplementedError + + def run(self) -> Dict[str, Any]: + + if not self.is_prepared: + self._prepare() + self.model.prepare_job(self) # let the model add some hooks + self.is_prepared = True + + was_training = self.model.training + self.model.eval() + self.config.log( + "Evaluating on " + + self.eval_split + + " data (epoch {})...".format(self.epoch), + echo=self.verbose, + ) + + trace_entry = self._run() + + # if validation metric is not present, try to compute it + metric_name = self.config.get("valid.metric") + if metric_name not in trace_entry: + trace_entry[metric_name] = eval( + self.config.get("valid.metric_expr"), + None, + dict(config=self.config, **trace_entry), + ) + + for f in self.post_epoch_trace_hooks: + f(self, trace_entry) + + # write out trace + trace_entry = self.trace( + **trace_entry, echo=self.verbose, echo_prefix=" ", log=True + ) + + # reset model and return metrics + if was_training: + self.model.train() + self.config.log( + "Finished evaluating on " + self.eval_split + " split.", echo=self.verbose + ) + + for f in self.post_valid_hooks: + f(self, trace_entry) + + return trace_entry + + def _run(self) -> Dict[str, Any]: """ Compute evaluation metrics, output results to trace file """ raise NotImplementedError @@ -132,8 +200,65 @@ def create_from( return super().create_from(checkpoint, new_config, dataset, parent_job) -# HISTOGRAM COMPUTATION ############################################################### +class TrainingLossEvaluationJob(EvaluationJob): + """ Training loss evaluation protocol """ + + def __init__(self, config: Config, dataset: Dataset, parent_job, model): + super().__init__(config, dataset, parent_job, model) + self.is_prepared = True + + train_job_on_eval_split_config = config.clone() + train_job_on_eval_split_config.set("train.split", self.eval_split) + train_job_on_eval_split_config.set("verbose", False) + train_job_on_eval_split_config.set( 
"negative_sampling.filtering.splits", + [self.config.get("train.split"), self.eval_split] + ["valid"] + if self.eval_split == "test" + else [], + ) + train_job_on_eval_split_config.set( + "KvsAll.label_splits", + [self.config.get("train.split"), self.eval_split] + ["valid"] + if self.eval_split == "test" + else [], + ) + self._train_job = TrainingJob.create( + config=train_job_on_eval_split_config, + parent_job=self, + dataset=dataset, + model=model, + initialize_for_forward_only=True, + ) + self._train_job_verbose = False + if self.__class__ == TrainingLossEvaluationJob: + for f in Job.job_created_hooks: + f(self) + + @torch.no_grad() + def _run(self) -> Dict[str, Any]: + + epoch_time = -time.time() + + self.epoch = self.parent_job.epoch + epoch_time += time.time() + + train_trace_entry = self._train_job.run_epoch() + # compute trace + trace_entry = dict( + type="training_loss", + scope="epoch", + split=self.eval_split, + epoch=self.epoch, + epoch_time=epoch_time, + event="eval_completed", + avg_loss=train_trace_entry["avg_loss"], + ) + + return trace_entry + + +# HISTOGRAM COMPUTATION ############################################################### def __initialize_hist(hists, key, job): """If there is no histogram with given `key` in `hists`, add an empty one.""" diff --git a/kge/job/train.py b/kge/job/train.py index e778bf40b..6a0532b7a 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -1,8 +1,6 @@ -import itertools import os import math import time -import sys from collections import defaultdict from dataclasses import dataclass @@ -10,13 +8,14 @@ import torch import torch.utils.data import numpy as np +from torch.utils.data import DataLoader from kge import Config, Dataset from kge.job import Job from kge.model import KgeModel from kge.util import KgeLoss, KgeOptimizer, KgeSampler, KgeLRScheduler -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import kge.job.util SLOTS = [0, 1, 2] @@ 
-52,7 +51,12 @@ class TrainingJob(Job): """ def __init__( - self, config: Config, dataset: Dataset, parent_job: Job = None, model=None + self, + config: Config, + dataset: Dataset, + parent_job: Job = None, + model = None, + initialize_for_forward_only=False, ) -> None: from kge.job import EvaluationJob @@ -61,27 +65,33 @@ def __init__( self.model: KgeModel = KgeModel.create(config, dataset) else: self.model: KgeModel = model - self.optimizer = KgeOptimizer.create(config, self.model) - self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer) self.loss = KgeLoss.create(config) self.abort_on_nan: bool = config.get("train.abort_on_nan") self.batch_size: int = config.get("train.batch_size") self.device: str = self.config.get("job.device") self.train_split = config.get("train.split") - self.config.check("train.trace_level", ["batch", "epoch"]) self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 - self.valid_trace: List[Dict[str, Any]] = [] - valid_conf = config.clone() - valid_conf.set("job.type", "eval") - if self.config.get("valid.split") != "": - valid_conf.set("eval.split", self.config.get("valid.split")) - valid_conf.set("eval.trace_level", self.config.get("valid.trace_level")) - self.valid_job = EvaluationJob.create( - valid_conf, dataset, parent_job=self, model=self.model - ) self.is_prepared = False + self.is_forward_only = initialize_for_forward_only + + if not initialize_for_forward_only: + + self.model.train() + + self.optimizer = KgeOptimizer.create(config, self.model) + self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer) + + valid_conf = config.clone() + valid_conf.set("job.type", "eval") + if self.config.get("valid.split") != "": + valid_conf.set("eval.split", self.config.get("valid.split")) + valid_conf.set("eval.trace_level", self.config.get("valid.trace_level")) + self.valid_job = EvaluationJob.create( + valid_conf, dataset, parent_job=self, model=self.model + ) + self.valid_trace: List[Dict[str, 
Any]] = [] # attributes filled in by implementing classes self.loader = None @@ -116,25 +126,38 @@ def __init__( for f in Job.job_created_hooks: f(self) - self.model.train() - @staticmethod def create( - config: Config, dataset: Dataset, parent_job: Job = None, model=None + config: Config, + dataset: Dataset, + parent_job: Job = None, + model = None, + initialize_for_forward_only=False, ) -> "TrainingJob": """Factory method to create a training job.""" if config.get("train.type") == "KvsAll": - return TrainingJobKvsAll(config, dataset, parent_job, model=model) + return TrainingJobKvsAll( + config, dataset, parent_job, model, initialize_for_forward_only, + ) elif config.get("train.type") == "negative_sampling": - return TrainingJobNegativeSampling(config, dataset, parent_job, model=model) + return TrainingJobNegativeSampling( + config, dataset, parent_job, model, initialize_for_forward_only, + ) elif config.get("train.type") == "1vsAll": - return TrainingJob1vsAll(config, dataset, parent_job, model=model) + return TrainingJob1vsAll( + config, dataset, parent_job, model, initialize_for_forward_only, + ) else: # perhaps TODO: try class with specified name -> extensibility raise ValueError("train.type") def run(self) -> None: """Start/resume the training job and run to completion.""" + if self.is_forward_only: + raise Exception( + f"{self.__class__.__name__} was initialized for forward only. You can only call run_epoch()" + ) + self.config.log("Starting training...") checkpoint_every = self.config.get("train.checkpoint.every") checkpoint_keep = self.config.get("train.checkpoint.keep") @@ -292,7 +315,7 @@ def _load(self, checkpoint: Dict) -> str: ) ) - def run_epoch(self) -> Dict[str, Any]: + def run_epoch(self, verbose: bool = True) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." 
# prepare the job is not done already @@ -317,7 +340,8 @@ def run_epoch(self) -> Dict[str, Any]: f(self) # process batch (preprocessing + forward pass + backward pass on loss) - self.optimizer.zero_grad() + if not self.is_forward_only: + self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( batch_index, batch ) @@ -337,7 +361,8 @@ def run_epoch(self) -> Dict[str, Any]: batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - penalty_value_torch.backward() + if not self.is_forward_only: + penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() sum_penalty += penalty @@ -377,9 +402,11 @@ def run_epoch(self) -> Dict[str, Any]: ) # update parameters - batch_optimizer_time = -time.time() - self.optimizer.step() - batch_optimizer_time += time.time() + batch_optimizer_time = 0 + if not self.is_forward_only: + batch_optimizer_time = -time.time() + self.optimizer.step() + batch_optimizer_time += time.time() # tracing/logging if self.trace_batch: @@ -404,29 +431,29 @@ def run_epoch(self) -> Dict[str, Any]: for f in self.post_batch_trace_hooks: f(self, batch_trace) self.trace(**batch_trace, event="batch_completed") - self.config.print( - ( - "\r" # go back - + "{} batch{: " - + str(1 + int(math.ceil(math.log10(len(self.loader))))) - + "d}/{}" - + ", avg_loss {:.4E}, penalty {:.4E}, cost {:.4E}, time {:6.2f}s" - + "\033[K" # clear to right - ).format( - self.config.log_prefix, - batch_index, - len(self.loader) - 1, - batch_result.avg_loss, - penalty, - cost_value, - batch_result.prepare_time - + batch_forward_time - + batch_backward_time - + batch_optimizer_time, - ), - end="", - flush=True, - ) + if echo_trace: + self.config.print( + "\r" # go back + + "{} batch{: " + + str(1 + int(math.ceil(math.log10(len(self.loader))))) + + "d}/{}" + + ", avg_loss {:.4E}, penalty 
{:.4E}, cost {:.4E}, time {:6.2f}s" + + "\033[K" # clear to right + ).format( + self.config.log_prefix, + batch_index, + len(self.loader) - 1, + batch_result.avg_loss, + penalty, + cost_value, + batch_result.prepare_time + + batch_forward_time + + batch_backward_time + + batch_optimizer_time, + ), + end="", + flush=True, + ) # update times prepare_time += batch_result.prepare_time @@ -448,7 +475,6 @@ def run_epoch(self) -> Dict[str, Any]: split=self.train_split, batches=len(self.loader), size=self.num_examples, - lr=[group["lr"] for group in self.optimizer.param_groups], avg_loss=sum_loss / self.num_examples, avg_penalty=sum_penalty / len(self.loader), avg_penalties={k: p / len(self.loader) for k, p in sum_penalties.items()}, @@ -461,9 +487,15 @@ def run_epoch(self) -> Dict[str, Any]: other_time=other_time, event="epoch_completed", ) + if not self.is_forward_only: + trace_entry.update( + lr=[group["lr"] for group in self.optimizer.param_groups], + ) for f in self.post_epoch_trace_hooks: f(self, trace_entry) - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) + trace_entry = self.trace( + **trace_entry, echo=verbose, echo_prefix=" ", log=True + ) return trace_entry def _prepare(self): @@ -506,8 +538,10 @@ class TrainingJobKvsAll(TrainingJob): - Example: a query + its labels, e.g., (John,marriedTo), [Jane] """ - def __init__(self, config, dataset, parent_job=None, model=None): - super().__init__(config, dataset, parent_job, model=model) + def __init__( + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) self.label_smoothing = config.check_range( "KvsAll.label_smoothing", float("-inf"), 1.0, max_inclusive=False ) @@ -543,6 +577,15 @@ def __init__(self, config, dataset, parent_job=None, model=None): ) ) + self.label_splits = set(self.config.get("KvsAll.label_splits")) | set( + [self.train_split] + ) + + self.queries = 
None + self.labels = None + self.label_offsets = None + self.query_end_index = None + config.log("Initializing 1-to-N training job...") self.type_str = "KvsAll" @@ -553,13 +596,6 @@ def __init__(self, config, dataset, parent_job=None, model=None): def _prepare(self): from kge.indexing import index_KvsAll_to_torch - # determine enabled query types - self.query_types = [ - key - for key, enabled in self.config.get("KvsAll.query_types").items() - if enabled - ] - #' for each query type: list of queries self.queries = {} @@ -573,7 +609,14 @@ def _prepare(self): #' for each query type (ordered as in self.query_types), index right after last #' example of that type in the list of all examples - self.query_end_index = [] + self.query_end_index = {} + + # determine enabled query types + self.query_types = [ + key + for key, enabled in self.config.get("KvsAll.query_types").items() + if enabled + ] # construct relevant data structures self.num_examples = 0 @@ -583,9 +626,12 @@ def _prepare(self): if query_type == "sp_" else ("so_to_p" if query_type == "s_o" else "po_to_s") ) - index = self.dataset.index(f"{self.train_split}_{index_type}") + split_combi_str = "_".join(sorted(self.label_splits)) + label_index = self.dataset.index(f"{split_combi_str}_{index_type}") + query_index = self.dataset.index(f"{self.train_split}_{index_type}") + index = {k: v for k, v in label_index.items() if k in query_index} self.num_examples += len(index) - self.query_end_index.append(self.num_examples) + self.query_end_index[query_type] = self.num_examples # Convert indexes to pytorch tensors (as described above). 
( @@ -623,8 +669,8 @@ def collate(batch): num_ones = 0 for example_index in batch: start = 0 - for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type_index] + for query_type in self.query_types: + end = self.query_end_index[query_type] if example_index < end: example_index -= start num_ones += self.label_offsets[query_type][example_index + 1] @@ -641,7 +687,7 @@ def collate(batch): for batch_index, example_index in enumerate(batch): start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type_index] + end = self.query_end_index[query_type] if example_index < end: example_index -= start query_type_indexes_batch[batch_index] = query_type_index @@ -759,7 +805,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - loss_value.backward() + if not self.is_forward_only: + loss_value.backward() backward_time += time.time() # all done @@ -769,10 +816,10 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: class TrainingJobNegativeSampling(TrainingJob): - def __init__(self, config, dataset, parent_job=None, model=None): - super().__init__(config, dataset, parent_job, model=model) + def __init__( + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False): + super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) self._sampler = KgeSampler.create(config, "negative_sampling", dataset) - self.is_prepared = False self._implementation = self.config.check( "negative_sampling.implementation", ["triple", "all", "batch", "auto"], ) @@ -799,9 +846,6 @@ def __init__(self, config, dataset, parent_job=None, model=None): def _prepare(self): """Construct dataloader""" - if self.is_prepared: - return - self.num_examples = self.dataset.split(self.train_split).size(0) self.loader = 
torch.utils.data.DataLoader( range(self.num_examples), @@ -813,8 +857,6 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - self.is_prepared = True - def _get_collate_fun(self): # create the collate function def collate(batch): @@ -1007,7 +1049,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # backward pass for this chunk backward_time -= time.time() - loss_value_torch.backward() + if not self.is_forward_only: + loss_value_torch.backward() backward_time += time.time() # all done @@ -1019,9 +1062,10 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: class TrainingJob1vsAll(TrainingJob): """Samples SPO pairs and queries sp_ and _po, treating all other entities as negative.""" - def __init__(self, config, dataset, parent_job=None, model=None): - super().__init__(config, dataset, parent_job, model=model) - self.is_prepared = False + def __init__( + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) config.log("Initializing spo training job...") self.type_str = "1vsAll" @@ -1032,9 +1076,6 @@ def __init__(self, config, dataset, parent_job=None, model=None): def _prepare(self): """Construct dataloader""" - if self.is_prepared: - return - self.num_examples = self.dataset.split(self.train_split).size(0) self.loader = torch.utils.data.DataLoader( range(self.num_examples), @@ -1048,8 +1089,6 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - self.is_prepared = True - def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() @@ -1064,7 +1103,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - loss_value_sp.backward() + if not self.is_forward_only: + 
loss_value_sp.backward() backward_time += time.time() # forward/backward pass (po) @@ -1074,7 +1114,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - loss_value_po.backward() + if not self.is_forward_only: + loss_value_po.backward() backward_time += time.time() # all done diff --git a/kge/misc.py b/kge/misc.py index 90db3750c..01613d37b 100644 --- a/kge/misc.py +++ b/kge/misc.py @@ -1,3 +1,5 @@ +import torch +from itertools import chain, combinations from typing import List from torch import nn as nn @@ -6,6 +8,7 @@ import inspect import subprocess + def is_number(s, number_type): """ Returns True is string is a number. """ try: @@ -15,6 +18,33 @@ def is_number(s, number_type): return False +def merge_dicts_of_1dim_torch_tensors(keys_from_dols, values_from_dols=None): + if values_from_dols is None: + values_from_dols = keys_from_dols + keys = set(chain.from_iterable([d.keys() for d in keys_from_dols])) + no = torch.tensor([]).int() + return dict( + ( + k, + torch.cat([d.get(k, no) for d in values_from_dols]).sort()[0], + ) + for k in keys + ) + + +def powerset(iterable, filter_lens=None): + s = list(iterable) + return list( + map( + sorted, + filter( + lambda l: len(l) in filter_lens if filter_lens else True, + chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)), + ), + ) + ) + + # from https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script def get_git_revision_hash(): try: @@ -69,6 +99,7 @@ def is_exe(fpath): def kge_base_dir(): import kge + return os.path.abspath(filename_in_module(kge, "..")) diff --git a/kge/util/sampler.py b/kge/util/sampler.py index 5f1ab31b8..daff4a87b 100644 --- a/kge/util/sampler.py +++ b/kge/util/sampler.py @@ -3,7 +3,7 @@ import random import torch -from typing import Optional +from typing import Optional, List import numpy as np import numba @@ -29,9 +29,9 @@ 
def __init__(self, config: Config, configuration_key: str, dataset: Dataset): "Without replacement sampling is only supported when " "shared negative sampling is enabled." ) - self.filtering_split = config.get("negative_sampling.filtering.split") - if self.filtering_split == "": - self.filtering_split = config.get("train.split") + self.filtering_splits: List[str] = config.get("negative_sampling.filtering.splits") + if len(self.filtering_splits) == 0: + self.filtering_splits.append(config.get("train.split")) for slot in SLOTS: slot_str = SLOT_STR[slot] self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}") @@ -43,7 +43,10 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset): # otherwise every worker would create every index again and again if self.filter_positives[slot]: pair = ["po", "so", "sp"][slot] - dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}") + for filtering_split in self.filtering_splits: + dataset.index(f"{filtering_split}_{pair}_to_{slot_str}") + filtering_splits_str = '_'.join(sorted(self.filtering_splits)) + dataset.index(f"{filtering_splits_str}_{pair}_to_{slot_str}") if any(self.filter_positives): if self.shared: raise ValueError( @@ -144,13 +147,19 @@ def _filter_and_resample( """Filter and resample indices until only negatives have been created. 
""" pair_str = ["po", "so", "sp"][slot] # holding the positive indices for the respective pair - index = self.dataset.index( - f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}" - ) cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols] + split_combi_str = "_".join(sorted(self.filtering_splits)) + index = self.dataset.index( + f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}" + ) for i in range(positive_triples.size(0)): - positives = index.get((pairs[i][0].item(), pairs[i][1].item())).numpy() + pair = (pairs[i][0].item(), pairs[i][1].item()) + positives = ( + index.get(pair).numpy() + if pair in index + else torch.IntTensor([]).numpy() + ) # indices of samples that have to be sampled again resample_idx = where_in(negative_samples[i].numpy(), positives) # number of new samples needed @@ -231,9 +240,7 @@ def _sample_shared( # contain its positive, drop that positive. For all other rows, drop a random # position. shared_samples_index = {s: j for j, s in enumerate(shared_samples)} - replacement = np.random.choice( - num_distinct + 1, batch_size, replace=True - ) + replacement = np.random.choice(num_distinct + 1, batch_size, replace=True) drop = torch.tensor( [ shared_samples_index.get(s, replacement[i]) @@ -265,8 +272,9 @@ def _filter_and_resample_fast( ): pair_str = ["po", "so", "sp"][slot] # holding the positive indices for the respective pair + split_combi_str = "_".join(sorted(self.filtering_splits)) index = self.dataset.index( - f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}" + f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}" ) cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols].numpy() @@ -279,13 +287,18 @@ def _filter_and_resample_fast( positives_index = numba.typed.Dict() for i in range(batch_size): pair = (pairs[i][0], pairs[i][1]) - positives_index[pair] = index.get(pair).numpy() + positives_index[pair] = ( + index.get(pair).numpy() + if pair in index + else torch.IntTensor([]).numpy() + ) 
negative_samples = negative_samples.numpy() KgeUniformSampler._filter_and_resample_numba( negative_samples, pairs, positives_index, batch_size, int(voc_size), ) return torch.tensor(negative_samples, dtype=torch.int64) + @staticmethod @numba.njit def _filter_and_resample_numba( negative_samples, pairs, positives_index, batch_size, voc_size