From cd6dfad8e4b79c404c592d1ec3053a68cc7391f0 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 16 May 2020 23:30:48 +0200 Subject: [PATCH 1/9] Add loss/cost on validation data --- kge/job/entity_ranking.py | 24 ++-- kge/job/train.py | 244 ++++++++++++++++++++------------------ 2 files changed, 145 insertions(+), 123 deletions(-) diff --git a/kge/job/entity_ranking.py b/kge/job/entity_ranking.py index 0ac108acc..c47548396 100644 --- a/kge/job/entity_ranking.py +++ b/kge/job/entity_ranking.py @@ -3,7 +3,7 @@ import torch import kge.job -from kge.job import EvaluationJob, Job +from kge.job import EvaluationJob, Job, TrainingJob from kge import Config, Dataset from collections import defaultdict @@ -88,15 +88,11 @@ def run(self) -> dict: num_entities = self.dataset.num_entities() # we also filter with test data if requested - filter_with_test = ( - "test" not in self.filter_splits and self.filter_with_test - ) + filter_with_test = "test" not in self.filter_splits and self.filter_with_test # which rankings to compute (DO NOT REORDER; code assumes the order given here) rankings = ( - ["_raw", "_filt", "_filt_test"] - if filter_with_test - else ["_raw", "_filt"] + ["_raw", "_filt", "_filt_test"] if filter_with_test else ["_raw", "_filt"] ) # dictionary that maps entry of rankings to a sparse tensor containing the @@ -327,7 +323,7 @@ def run(self) -> dict: type="entity_ranking", scope="batch", split=self.eval_split, - filter_splits = self.filter_splits, + filter_splits=self.filter_splits, epoch=self.epoch, batch=batch_number, size=len(batch), @@ -390,6 +386,18 @@ def merge_hist(target_hists, source_hists): ) epoch_time += time.time() + if isinstance(self.parent_job, TrainingJob): + self.parent_job: TrainingJob = self.parent_job + train_trace_entry = self.parent_job.run_epoch( + split=self.eval_split, echo_trace=False, do_backward=False + ) + metrics.update( + avg_loss=train_trace_entry["avg_loss"], + avg_penalty=train_trace_entry["avg_penalty"], + avg_penalties=train_trace_entry["avg_penalties"], + avg_cost=train_trace_entry["avg_cost"], + ) + # compute trace trace_entry = dict( type="entity_ranking", diff --git a/kge/job/train.py b/kge/job/train.py index 90edd2a18..d03685032 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -9,6 +9,7 @@ import torch import torch.utils.data import numpy as np +from torch.utils.data import DataLoader from kge import Config, Dataset from kge.job import Job @@ -63,12 +64,12 @@ def __init__( self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 self.valid_trace: List[Dict[str, Any]] = [] - self.is_prepared = False + self.is_prepared: Dict[str, bool] = defaultdict(lambda: False) self.model.train() # attributes filled in by implementing classes - self.loader = None - self.num_examples = None + self.loader: Dict[str, DataLoader] = dict() + self.num_examples: Dict[str, int] = dict() self.type_str: Optional[str] = None #: Hooks run after training for an epoch. @@ -166,7 +167,9 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch() + trace_entry = self.run_epoch( + split=self.train_split, echo_trace=True, do_backward=True + ) for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -287,14 +290,16 @@ def resume(self, checkpoint_file: str = None) -> None: else: self.config.log("No checkpoint found, starting from scratch...") - def run_epoch(self) -> Dict[str, Any]: + def run_epoch( + self, split: str, echo_trace: bool, do_backward: bool + ) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already - if not self.is_prepared: - self._prepare() + if not self.is_prepared[split]: + self._prepare(split=split) self.model.prepare_job(self) # let the model add some hooks - self.is_prepared = True + self.is_prepared[split] = True # variables that record various statitics sum_loss = 0.0 @@ -307,14 +312,14 @@ def run_epoch(self) -> Dict[str, Any]: optimizer_time = 0.0 # process each batch - for batch_index, batch in enumerate(self.loader): + for batch_index, batch in enumerate(self.loader[split]): for f in self.pre_batch_hooks: f(self) # process batch (preprocessing + forward pass + backward pass on loss) self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch + batch_index, batch, do_backward ) sum_loss += batch_result.avg_loss * batch_result.size @@ -323,7 +328,7 @@ def run_epoch(self) -> Dict[str, Any]: penalties_torch = self.model.penalty( epoch=self.epoch, batch_index=batch_index, - num_batches=len(self.loader), + num_batches=len(self.loader[split]), batch=batch, ) batch_forward_time += time.time() @@ -332,7 +337,8 @@ def run_epoch(self) -> Dict[str, Any]: batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - penalty_value_torch.backward() + if do_backward: + penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() sum_penalty += penalty @@ -372,9 +378,11 @@ def run_epoch(self) -> Dict[str, Any]: ) # update parameters - batch_optimizer_time = -time.time() - self.optimizer.step() - batch_optimizer_time += time.time() + batch_optimizer_time = 0 + if do_backward: + batch_optimizer_time = -time.time() + self.optimizer.step() + batch_optimizer_time += time.time() # tracing/logging if self.trace_batch: @@ -382,10 +390,10 @@ def run_epoch(self) -> Dict[str, Any]: "type": self.type_str, "scope": "batch", "epoch": self.epoch, - "split": self.train_split, + "split": split, "batch": batch_index, "size": batch_result.size, - "batches": len(self.loader), + "batches": len(self.loader[split]), "lr": [group["lr"] for group in self.optimizer.param_groups], "avg_loss": batch_result.avg_loss, "penalties": [p.item() for k, p in penalties_torch], @@ -399,29 +407,30 @@ def run_epoch(self) -> Dict[str, Any]: for f in self.post_batch_trace_hooks: f(self, batch_trace) self.trace(**batch_trace, event="batch_completed") - print( - ( - "\r" # go back - + "{} batch{: " - + str(1 + int(math.ceil(math.log10(len(self.loader))))) - + "d}/{}" - + ", avg_loss {:.4E}, penalty {:.4E}, cost {:.4E}, time {:6.2f}s" - + "\033[K" # clear to right - ).format( - self.config.log_prefix, - batch_index, - len(self.loader) - 1, - batch_result.avg_loss, - penalty, - cost_value, - batch_result.prepare_time - + batch_forward_time - + batch_backward_time - + batch_optimizer_time, - ), - end="", - flush=True, - ) + if echo_trace: + print( + ( + "\r" # go back + + "{} batch{: " + + str(1 + int(math.ceil(math.log10(len(self.loader[split]))))) + + "d}/{}" + + ", avg_loss {:.4E}, penalty {:.4E}, cost {:.4E}, time {:6.2f}s" + + "\033[K" # clear to right + ).format( + self.config.log_prefix, + batch_index, + len(self.loader[split]) - 1, + batch_result.avg_loss, + penalty, + cost_value, + batch_result.prepare_time + + batch_forward_time + + batch_backward_time + + batch_optimizer_time, + ), + end="", + flush=True, + ) # update times prepare_time += batch_result.prepare_time @@ -440,14 +449,15 @@ def run_epoch(self) -> Dict[str, Any]: type=self.type_str, scope="epoch", epoch=self.epoch, - split=self.train_split, - batches=len(self.loader), - size=self.num_examples, + split=split, + batches=len(self.loader[split]), + size=self.num_examples[split], lr=[group["lr"] for group in self.optimizer.param_groups], - avg_loss=sum_loss / self.num_examples, - avg_penalty=sum_penalty / len(self.loader), - avg_penalties={k: p / len(self.loader) for k, p in sum_penalties.items()}, - avg_cost=sum_loss / self.num_examples + sum_penalty / len(self.loader), + avg_loss=sum_loss / self.num_examples[split], + avg_penalty=sum_penalty / len(self.loader[split]), + avg_penalties={k: p / len(self.loader[split]) for k, p in sum_penalties.items()}, + avg_cost=sum_loss / self.num_examples[split] + + sum_penalty / len(self.loader[split]), epoch_time=epoch_time, prepare_time=prepare_time, forward_time=forward_time, @@ -458,10 +468,12 @@ def run_epoch(self) -> Dict[str, Any]: ) for f in self.post_epoch_trace_hooks: f(self, trace_entry) - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) + trace_entry = self.trace( + **trace_entry, echo=echo_trace, echo_prefix=" ", log=True + ) return trace_entry - def _prepare(self): + def _prepare(self, split: str): """Prepare this job for running. Sets (at least) the `loader`, `num_examples`, and `type_str` attributes of this @@ -484,7 +496,7 @@ class _ProcessBatchResult: backward_time: float def _process_batch( - self, batch_index: int, batch + self, batch_index: int, batch, do_backward: bool ) -> "TrainingJob._ProcessBatchResult": "Run forward and backward pass on batch and return results." raise NotImplementedError @@ -538,6 +550,21 @@ def __init__(self, config, dataset, parent_job=None): ) ) + #' for each query type: list of queries + self.queries = {} + + #' for each query type: list of all labels (concatenated across queries) + self.labels = {} + + #' for each query type: list of starting offset of labels in self.labels. The + #' labels for the i-th query of query_type are in labels[query_type] in range + #' label_offsets[query_type][i]:label_offsets[query_type][i+1] + self.label_offsets = {} + + #' for each query type (ordered as in self.query_types), index right after last + #' example of that type in the list of all examples + self.query_end_index = {} + config.log("Initializing 1-to-N training job...") self.type_str = "KvsAll" @@ -545,7 +572,7 @@ def __init__(self, config, dataset, parent_job=None): for f in Job.job_created_hooks: f(self) - def _prepare(self): + def _prepare(self, split: str): from kge.indexing import index_KvsAll_to_torch # determine enabled query types @@ -555,44 +582,29 @@ def _prepare(self): if enabled ] - #' for each query type: list of queries - self.queries = {} - - #' for each query type: list of all labels (concatenated across queries) - self.labels = {} - - #' for each query type: list of starting offset of labels in self.labels. The - #' labels for the i-th query of query_type are in labels[query_type] in range - #' label_offsets[query_type][i]:label_offsets[query_type][i+1] - self.label_offsets = {} - - #' for each query type (ordered as in self.query_types), index right after last - #' example of that type in the list of all examples - self.query_end_index = [] - # construct relevant data structures - self.num_examples = 0 + self.num_examples[split] = 0 for query_type in self.query_types: index_type = ( "sp_to_o" if query_type == "sp_" else ("so_to_p" if query_type == "s_o" else "po_to_s") ) - index = self.dataset.index(f"{self.train_split}_{index_type}") - self.num_examples += len(index) - self.query_end_index.append(self.num_examples) + index = self.dataset.index(f"{split}_{index_type}") + self.num_examples[split] += len(index) + self.query_end_index[query_type + split] = self.num_examples[split] # Convert indexes to pytorch tensors (as described above). ( - self.queries[query_type], - self.labels[query_type], - self.label_offsets[query_type], + self.queries[query_type + split], + self.labels[query_type + split], + self.label_offsets[query_type + split], ) = index_KvsAll_to_torch(index) # create dataloader - self.loader = torch.utils.data.DataLoader( - range(self.num_examples), - collate_fn=self._get_collate_fun(), + self.loader[split] = torch.utils.data.DataLoader( + range(self.num_examples[split]), + collate_fn=self._get_collate_fun(split), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -602,7 +614,7 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - def _get_collate_fun(self): + def _get_collate_fun(self, split: str): # create the collate function def collate(batch): """For a batch of size n, returns a dictionary of: @@ -621,11 +633,15 @@ def collate(batch): for example_index in batch: start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type_index] + end = self.query_end_index[query_type + split] if example_index < end: example_index -= start - num_ones += self.label_offsets[query_type][example_index + 1] - num_ones -= self.label_offsets[query_type][example_index] + num_ones += self.label_offsets[query_type + split][ + example_index + 1 + ] + num_ones -= self.label_offsets[query_type + split][ + example_index + ] break start = end @@ -638,13 +654,13 @@ def collate(batch): for batch_index, example_index in enumerate(batch): start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type_index] + end = self.query_end_index[query_type + split] if example_index < end: example_index -= start query_type_indexes_batch[batch_index] = query_type_index - queries = self.queries[query_type] - label_offsets = self.label_offsets[query_type] - labels = self.labels[query_type] + queries = self.queries[query_type + split] + label_offsets = self.label_offsets[query_type + split] + labels = self.labels[query_type + split] if query_type == "sp_": query_col_1, query_col_2, target_col = S, P, O elif query_type == "s_o": @@ -685,7 +701,9 @@ def collate(batch): return collate - def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: + def _process_batch( + self, batch_index, batch, do_backward: bool + ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() queries_batch = batch["queries"].to(self.device) @@ -756,7 +774,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - loss_value.backward() + if do_backward: + loss_value.backward() backward_time += time.time() # all done @@ -769,7 +788,6 @@ class TrainingJobNegativeSampling(TrainingJob): def __init__(self, config, dataset, parent_job=None): super().__init__(config, dataset, parent_job) self._sampler = KgeSampler.create(config, "negative_sampling", dataset) - self.is_prepared = False self._implementation = self.config.check( "negative_sampling.implementation", ["triple", "all", "batch", "auto"], ) @@ -793,16 +811,13 @@ def __init__(self, config, dataset, parent_job=None): for f in Job.job_created_hooks: f(self) - def _prepare(self): + def _prepare(self, split: str): """Construct dataloader""" - if self.is_prepared: - return - - self.num_examples = self.dataset.split(self.train_split).size(0) - self.loader = torch.utils.data.DataLoader( - range(self.num_examples), - collate_fn=self._get_collate_fun(), + self.num_examples[split] = self.dataset.split(split).size(0) + self.loader[split] = torch.utils.data.DataLoader( + range(self.num_examples[split]), + collate_fn=self._get_collate_fun(split), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -812,9 +827,7 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - self.is_prepared = True - - def _get_collate_fun(self): + def _get_collate_fun(self, split: str): # create the collate function def collate(batch): """For a batch of size n, returns a tuple of: @@ -824,7 +837,7 @@ def collate(batch): in order S,P,O) """ - triples = self.dataset.split(self.train_split)[batch, :].long() + triples = self.dataset.split(split)[batch, :].long() # labels = torch.zeros((len(batch), self._sampler.num_negatives_total + 1)) # labels[:, 0] = 1 # labels = labels.view(-1) @@ -836,7 +849,9 @@ def collate(batch): return collate - def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: + def _process_batch( + self, batch_index, batch, do_backward: bool + ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() batch_triples = batch["triples"].to(self.device) @@ -1006,7 +1021,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # backward pass for this chunk backward_time -= time.time() - loss_value_torch.backward() + if do_backward: + loss_value_torch.backward() backward_time += time.time() # all done @@ -1020,7 +1036,6 @@ class TrainingJob1vsAll(TrainingJob): def __init__(self, config, dataset, parent_job=None): super().__init__(config, dataset, parent_job) - self.is_prepared = False config.log("Initializing spo training job...") self.type_str = "1vsAll" @@ -1028,17 +1043,14 @@ def __init__(self, config, dataset, parent_job=None): for f in Job.job_created_hooks: f(self) - def _prepare(self): + def _prepare(self, split: str): """Construct dataloader""" - if self.is_prepared: - return - - self.num_examples = self.dataset.split(self.train_split).size(0) - self.loader = torch.utils.data.DataLoader( - range(self.num_examples), + self.num_examples[split] = self.dataset.split(split).size(0) + self.loader[split] = torch.utils.data.DataLoader( + range(self.num_examples[split]), collate_fn=lambda batch: { - "triples": self.dataset.split(self.train_split)[batch, :].long() + "triples": self.dataset.split(split)[batch, :].long() }, shuffle=True, batch_size=self.batch_size, @@ -1049,9 +1061,9 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - self.is_prepared = True - - def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: + def _process_batch( + self, batch_index, batch, do_backward: bool + ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() triples = batch["triples"].to(self.device) @@ -1065,7 +1077,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - loss_value_sp.backward() + if do_backward: + loss_value_sp.backward() backward_time += time.time() # forward/backward pass (po) @@ -1075,7 +1088,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - loss_value_po.backward() + if do_backward: + loss_value_po.backward() backward_time += time.time() # all done From 78596ea0072a9801e233c4abc75dc292378b9927 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Thu, 21 May 2020 00:19:40 +0200 Subject: [PATCH 2/9] Use TrainJob in Eval Job Add EvalTrainingLossJob --- kge/job/entity_ranking.py | 15 +--- kge/job/eval.py | 76 +++++++++++++++++- kge/job/train.py | 157 ++++++++++++++++++-------------------- 3 files changed, 154 insertions(+), 94 deletions(-) diff --git a/kge/job/entity_ranking.py b/kge/job/entity_ranking.py index c47548396..5e70031c8 100644 --- a/kge/job/entity_ranking.py +++ b/kge/job/entity_ranking.py @@ -386,17 +386,10 @@ def merge_hist(target_hists, source_hists): ) epoch_time += time.time() - if isinstance(self.parent_job, TrainingJob): - self.parent_job: TrainingJob = self.parent_job - train_trace_entry = self.parent_job.run_epoch( - split=self.eval_split, echo_trace=False, do_backward=False - ) - metrics.update( - avg_loss=train_trace_entry["avg_loss"], - avg_penalty=train_trace_entry["avg_penalty"], - avg_penalties=train_trace_entry["avg_penalties"], - avg_cost=train_trace_entry["avg_cost"], - ) + train_trace_entry = self.eval_train_loss_job.run_epoch( + echo_trace=False, forward_only=False + ) + metrics.update(avg_loss=train_trace_entry["avg_loss"],) # compute trace trace_entry = dict( diff --git a/kge/job/eval.py b/kge/job/eval.py index 9bc1f4759..5f1320c82 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -1,6 +1,9 @@ +import time + import torch +from kge import Config, Dataset -from kge.job import Job +from kge.job import Job, TrainingJob class EvaluationJob(Job): @@ -30,6 +33,12 @@ def __init__(self, config, dataset, parent_job, model): self.filter_with_test = config.get("eval.filter_with_test") self.epoch = -1 + train_job_on_eval_split_config = config.clone() + train_job_on_eval_split_config.set("train.split", self.eval_split) + self.eval_train_loss_job = TrainingJob.create( + config=train_job_on_eval_split_config, parent_job=self, dataset=dataset + ) + #: Hooks run after training for an epoch. #: Signature: job, trace_entry self.post_epoch_hooks = [] @@ -98,6 +107,71 @@ def resume(self, checkpoint_file=None): ) +class EvalTrainingLossJob(EvaluationJob): + """ Entity ranking evaluation protocol """ + + def __init__(self, config: Config, dataset: Dataset, parent_job, model): + super().__init__(config, dataset, parent_job, model) + self.is_prepared = False + + if self.__class__ == EvalTrainingLossJob: + for f in Job.job_created_hooks: + f(self) + + @torch.no_grad() + def run(self) -> dict: + + epoch_time = -time.time() + + was_training = self.model.training + self.model.eval() + self.config.log( + "Evaluating on " + + self.eval_split + + " data (epoch {})...".format(self.epoch) + ) + epoch_time += time.time() + + train_trace_entry = self.eval_train_loss_job.run_epoch( + echo_trace=False, forward_only=False + ) + # compute trace + trace_entry = dict( + type="eval_train_loss_job", + scope="epoch", + split=self.eval_split, + epoch=self.epoch, + epoch_time=epoch_time, + event="eval_completed", + avg_loss=train_trace_entry["avg_loss"], + ) + for f in self.post_epoch_trace_hooks: + f(self, trace_entry) + + # if validation metric is not present, try to compute it + metric_name = self.config.get("valid.metric") + if metric_name not in trace_entry: + trace_entry[metric_name] = eval( + self.config.get("valid.metric_expr"), + None, + dict(config=self.config, **trace_entry), + ) + + # write out trace + trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) + + # reset model and return metrics + if was_training: + self.model.train() + self.config.log("Finished evaluating train loss on " + self.eval_split + " split.") + + for f in self.post_valid_hooks: + f(self, trace_entry) + + return trace_entry + + + # HISTOGRAM COMPUTATION ############################################################### diff --git a/kge/job/train.py b/kge/job/train.py index d03685032..8e26b0bf0 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -52,24 +52,26 @@ def __init__( self.batch_size: int = config.get("train.batch_size") self.device: str = self.config.get("job.device") self.train_split = config.get("train.split") - valid_conf = config.clone() - valid_conf.set("job.type", "eval") - if self.config.get("valid.split") != "": - valid_conf.set("eval.split", self.config.get("valid.split")) - valid_conf.set("eval.trace_level", self.config.get("valid.trace_level")) - self.valid_job = EvaluationJob.create( - valid_conf, dataset, parent_job=self, model=self.model - ) self.config.check("train.trace_level", ["batch", "epoch"]) self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 - self.valid_trace: List[Dict[str, Any]] = [] - self.is_prepared: Dict[str, bool] = defaultdict(lambda: False) + self.is_prepared = False self.model.train() + if config.get("job.type") == "train": + valid_conf = config.clone() + valid_conf.set("job.type", "eval") + if self.config.get("valid.split") != "": + valid_conf.set("eval.split", self.config.get("valid.split")) + valid_conf.set("eval.trace_level", self.config.get("valid.trace_level")) + self.valid_job = EvaluationJob.create( + valid_conf, dataset, parent_job=self, model=self.model + ) + self.valid_trace: List[Dict[str, Any]] = [] + # attributes filled in by implementing classes - self.loader: Dict[str, DataLoader] = dict() - self.num_examples: Dict[str, int] = dict() + self.loader = None + self.num_examples = None self.type_str: Optional[str] = None #: Hooks run after training for an epoch. @@ -167,9 +169,7 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch( - split=self.train_split, echo_trace=True, do_backward=True - ) + trace_entry = self.run_epoch(echo_trace=True, forward_only=True) for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -290,16 +290,14 @@ def resume(self, checkpoint_file: str = None) -> None: else: self.config.log("No checkpoint found, starting from scratch...") - def run_epoch( - self, split: str, echo_trace: bool, do_backward: bool - ) -> Dict[str, Any]: + def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already - if not self.is_prepared[split]: - self._prepare(split=split) + if not self.is_prepared: + self._prepare() self.model.prepare_job(self) # let the model add some hooks - self.is_prepared[split] = True + self.is_prepared = True # variables that record various statitics sum_loss = 0.0 @@ -312,14 +310,14 @@ def run_epoch( optimizer_time = 0.0 # process each batch - for batch_index, batch in enumerate(self.loader[split]): + for batch_index, batch in enumerate(self.loader): for f in self.pre_batch_hooks: f(self) # process batch (preprocessing + forward pass + backward pass on loss) self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch, do_backward + batch_index, batch, forward_only ) sum_loss += batch_result.avg_loss * batch_result.size @@ -328,7 +326,7 @@ def run_epoch( penalties_torch = self.model.penalty( epoch=self.epoch, batch_index=batch_index, - num_batches=len(self.loader[split]), + num_batches=len(self.loader), batch=batch, ) batch_forward_time += time.time() @@ -337,7 +335,7 @@ def run_epoch( batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - if do_backward: + if forward_only: penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() @@ -379,7 +377,7 @@ def run_epoch( # update parameters batch_optimizer_time = 0 - if do_backward: + if forward_only: batch_optimizer_time = -time.time() self.optimizer.step() batch_optimizer_time += time.time() @@ -390,10 +388,10 @@ def run_epoch( "type": self.type_str, "scope": "batch", "epoch": self.epoch, - "split": split, + "split": self.train_split, "batch": batch_index, "size": batch_result.size, - "batches": len(self.loader[split]), + "batches": len(self.loader), "lr": [group["lr"] for group in self.optimizer.param_groups], "avg_loss": batch_result.avg_loss, "penalties": [p.item() for k, p in penalties_torch], @@ -412,14 +410,14 @@ def run_epoch( ( "\r" # go back + "{} batch{: " - + str(1 + int(math.ceil(math.log10(len(self.loader[split]))))) + + str(1 + int(math.ceil(math.log10(len(self.loader))))) + "d}/{}" + ", avg_loss {:.4E}, penalty {:.4E}, cost {:.4E}, time {:6.2f}s" + "\033[K" # clear to right ).format( self.config.log_prefix, batch_index, - len(self.loader[split]) - 1, + len(self.loader) - 1, batch_result.avg_loss, penalty, cost_value, @@ -449,15 +447,14 @@ def run_epoch( type=self.type_str, scope="epoch", epoch=self.epoch, - split=split, - batches=len(self.loader[split]), - size=self.num_examples[split], + split=self.train_split, + batches=len(self.loader), + size=self.num_examples, lr=[group["lr"] for group in self.optimizer.param_groups], - avg_loss=sum_loss / self.num_examples[split], - avg_penalty=sum_penalty / len(self.loader[split]), - avg_penalties={k: p / len(self.loader[split]) for k, p in sum_penalties.items()}, - avg_cost=sum_loss / self.num_examples[split] - + sum_penalty / len(self.loader[split]), + avg_loss=sum_loss / self.num_examples, + avg_penalty=sum_penalty / len(self.loader), + avg_penalties={k: p / len(self.loader) for k, p in sum_penalties.items()}, + avg_cost=sum_loss / self.num_examples + sum_penalty / len(self.loader), epoch_time=epoch_time, prepare_time=prepare_time, forward_time=forward_time, @@ -473,7 +470,7 @@ def run_epoch( ) return trace_entry - def _prepare(self, split: str): + def _prepare(self): """Prepare this job for running. Sets (at least) the `loader`, `num_examples`, and `type_str` attributes of this @@ -496,7 +493,7 @@ class _ProcessBatchResult: backward_time: float def _process_batch( - self, batch_index: int, batch, do_backward: bool + self, batch_index: int, batch, forward_only: bool ) -> "TrainingJob._ProcessBatchResult": "Run forward and backward pass on batch and return results." raise NotImplementedError @@ -572,7 +569,7 @@ def __init__(self, config, dataset, parent_job=None): for f in Job.job_created_hooks: f(self) - def _prepare(self, split: str): + def _prepare(self): from kge.indexing import index_KvsAll_to_torch # determine enabled query types @@ -583,28 +580,28 @@ def _prepare(self, split: str): ] # construct relevant data structures - self.num_examples[split] = 0 + self.num_examples = 0 for query_type in self.query_types: index_type = ( "sp_to_o" if query_type == "sp_" else ("so_to_p" if query_type == "s_o" else "po_to_s") ) - index = self.dataset.index(f"{split}_{index_type}") - self.num_examples[split] += len(index) - self.query_end_index[query_type + split] = self.num_examples[split] + index = self.dataset.index(f"{self.train_split}_{index_type}") + self.num_examples += len(index) + self.query_end_index[query_type] = self.num_examples # Convert indexes to pytorch tensors (as described above). ( - self.queries[query_type + split], - self.labels[query_type + split], - self.label_offsets[query_type + split], + self.queries[query_type], + self.labels[query_type], + self.label_offsets[query_type], ) = index_KvsAll_to_torch(index) # create dataloader - self.loader[split] = torch.utils.data.DataLoader( - range(self.num_examples[split]), - collate_fn=self._get_collate_fun(split), + self.loader = torch.utils.data.DataLoader( + range(self.num_examples), + collate_fn=self._get_collate_fun(), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -614,7 +611,7 @@ def _prepare(self, split: str): pin_memory=self.config.get("train.pin_memory"), ) - def _get_collate_fun(self, split: str): + def _get_collate_fun(self): # create the collate function def collate(batch): """For a batch of size n, returns a dictionary of: @@ -633,15 +630,11 @@ def collate(batch): for example_index in batch: start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type + split] + end = self.query_end_index[query_type] if example_index < end: example_index -= start - num_ones += self.label_offsets[query_type + split][ - example_index + 1 - ] - num_ones -= self.label_offsets[query_type + split][ - example_index - ] + num_ones += self.label_offsets[query_type][example_index + 1] + num_ones -= self.label_offsets[query_type][example_index] break start = end @@ -654,13 +647,13 @@ def collate(batch): for batch_index, example_index in enumerate(batch): start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type + split] + end = self.query_end_index[query_type] if example_index < end: example_index -= start query_type_indexes_batch[batch_index] = query_type_index - queries = self.queries[query_type + split] - label_offsets = self.label_offsets[query_type + split] - labels = self.labels[query_type + split] + queries = self.queries[query_type] + label_offsets = self.label_offsets[query_type] + labels = self.labels[query_type] if query_type == "sp_": query_col_1, query_col_2, target_col = S, P, O elif query_type == "s_o": @@ -702,7 +695,7 @@ def collate(batch): return collate def _process_batch( - self, batch_index, batch, do_backward: bool + self, batch_index, batch, forward_only: bool ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() @@ -774,7 +767,7 @@ def _process_batch( loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - if do_backward: + if forward_only: loss_value.backward() backward_time += time.time() @@ -811,13 +804,13 @@ def __init__(self, config, dataset, parent_job=None): for f in Job.job_created_hooks: f(self) - def _prepare(self, split: str): + def _prepare(self): """Construct dataloader""" - self.num_examples[split] = self.dataset.split(split).size(0) - self.loader[split] = torch.utils.data.DataLoader( - range(self.num_examples[split]), - collate_fn=self._get_collate_fun(split), + self.num_examples = self.dataset.split(self.train_split).size(0) + self.loader = torch.utils.data.DataLoader( + range(self.num_examples), + collate_fn=self._get_collate_fun(), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -827,7 +820,7 @@ def _prepare(self, split: str): pin_memory=self.config.get("train.pin_memory"), ) - def _get_collate_fun(self, split: str): + def _get_collate_fun(self): # create the collate function def collate(batch): """For a batch of size n, returns a tuple of: @@ -837,7 +830,7 @@ def collate(batch): in order S,P,O) """ - triples = self.dataset.split(split)[batch, :].long() + triples = self.dataset.split(self.train_split)[batch, :].long() # labels = torch.zeros((len(batch), self._sampler.num_negatives_total + 1)) # labels[:, 0] = 1 # labels = labels.view(-1) @@ -850,7 +843,7 @@ def collate(batch): return collate def _process_batch( - self, batch_index, batch, do_backward: bool + self, batch_index, batch, forward_only: bool ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() @@ -1021,7 +1014,7 @@ def _process_batch( # backward pass for this chunk backward_time -= time.time() - if do_backward: + if forward_only: loss_value_torch.backward() backward_time += time.time() @@ -1043,14 +1036,14 @@ def __init__(self, config, dataset, parent_job=None): for f in Job.job_created_hooks: f(self) - def _prepare(self, split: str): + def _prepare(self): """Construct dataloader""" - self.num_examples[split] = self.dataset.split(split).size(0) - self.loader[split] = torch.utils.data.DataLoader( - range(self.num_examples[split]), + self.num_examples = self.dataset.split(self.train_split).size(0) + self.loader = torch.utils.data.DataLoader( + range(self.num_examples), collate_fn=lambda batch: { - "triples": self.dataset.split(split)[batch, :].long() + "triples": self.dataset.split(self.train_split)[batch, :].long() }, shuffle=True, batch_size=self.batch_size, @@ -1062,7 +1055,7 @@ def _prepare(self, split: str): ) def _process_batch( - self, batch_index, batch, do_backward: bool + self, batch_index, batch, forward_only: bool ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() @@ -1077,7 +1070,7 @@ def _process_batch( loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - if do_backward: + if forward_only: loss_value_sp.backward() backward_time += time.time() @@ -1088,7 +1081,7 @@ def _process_batch( loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - if do_backward: + if forward_only: loss_value_po.backward() backward_time += time.time() From f204b8eadef40305c2623e26758b8a413f9d4478 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Fri, 22 May 2020 11:55:00 +0200 Subject: [PATCH 3/9] integrate training_loss eval --- kge/config-default.yaml | 3 ++- kge/job/eval.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/kge/config-default.yaml b/kge/config-default.yaml index bce72d8dd..5e4fdac5d 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -324,7 +324,8 @@ eval: # mean_reciprocal_rank_filtered_with_test. filter_with_test: True - # Type of evaluation (entity_ranking only at the moment) + # Type of evaluation (entity_ranking or only training_loss, while entity + # ranking runs training_loss as well.) type: entity_ranking # Compute Hits@K for these choices of K diff --git a/kge/job/eval.py b/kge/job/eval.py index 5f1320c82..0cfde3299 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -85,6 +85,10 @@ def create(config, dataset, parent_job=None, model=None): return EntityPairRankingJob( config, dataset, parent_job=parent_job, model=model ) + elif config.get("eval.type") == "training_loss": + return EvalTrainingLossJob( + config, dataset, parent_job=parent_job, model=model + ) else: raise ValueError("eval.type") @@ -137,7 +141,7 @@ def run(self) -> dict: ) # compute trace trace_entry = dict( - type="eval_train_loss_job", + type="training_loss", scope="epoch", split=self.eval_split, epoch=self.epoch, @@ -163,7 +167,9 @@ def run(self) -> dict: # reset model and return metrics if was_training: self.model.train() - self.config.log("Finished evaluating train loss on " + self.eval_split + " split.") + self.config.log( + "Finished evaluating train loss on " + self.eval_split + " split." + ) for f in self.post_valid_hooks: f(self, trace_entry) @@ -171,7 +177,6 @@ def run(self) -> dict: return trace_entry - # HISTOGRAM COMPUTATION ############################################################### From b8625d304c5670f5c091b3fa5f7ec1524a07c6bc Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 23 May 2020 00:32:53 +0200 Subject: [PATCH 4/9] Adress reviews --- kge/job/eval.py | 2 +- kge/job/train.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index 0cfde3299..4dd42cfe5 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -137,7 +137,7 @@ def run(self) -> dict: epoch_time += time.time() train_trace_entry = self.eval_train_loss_job.run_epoch( - echo_trace=False, forward_only=False + echo_trace=False, forward_only=True ) # compute trace trace_entry = dict( diff --git a/kge/job/train.py b/kge/job/train.py index 8e26b0bf0..c2f13d712 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -169,7 +169,7 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch(echo_trace=True, forward_only=True) + trace_entry = self.run_epoch(echo_trace=True, forward_only=False) for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -335,7 +335,7 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - if forward_only: + if not forward_only: penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() @@ -377,7 +377,7 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: # update parameters batch_optimizer_time = 0 - if forward_only: + if not forward_only: batch_optimizer_time = -time.time() self.optimizer.step() batch_optimizer_time += time.time() @@ -767,7 +767,7 @@ def _process_batch( loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - if forward_only: + if not forward_only: loss_value.backward() backward_time += time.time() @@ -1014,7 +1014,7 @@ def _process_batch( # backward pass for this chunk backward_time -= time.time() - if forward_only: + if not forward_only: loss_value_torch.backward() backward_time += time.time() @@ -1070,7 +1070,7 @@ def _process_batch( loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - if forward_only: + if not forward_only: loss_value_sp.backward() backward_time += time.time() @@ -1081,7 +1081,7 @@ def _process_batch( loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - if forward_only: + if not forward_only: loss_value_po.backward() backward_time += time.time() From ae1cdbfa157acf7d4d4cec0a917a4072f2f2522e Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 23 May 2020 02:45:41 +0200 Subject: [PATCH 5/9] Small revision of EvaluationJob: run() is implemented now in EvaluationJob with standard stuff that has to be done for every EvaluationJob and then calls self._run() Shift EntityRanking specific stuff to there Make EvalTrainingLossJob an instance variable of EvaluationJob and run it as post_epoch_trace_hook Add self.verbose to EvaluationJob and use it Adress code review --- kge/job/entity_ranking.py | 53 ++++------------ kge/job/eval.py | 129 ++++++++++++++++++++++++-------------- kge/job/train.py | 33 +++++----- 3 files changed, 112 insertions(+), 103 deletions(-) diff --git a/kge/job/entity_ranking.py b/kge/job/entity_ranking.py index 5e70031c8..9e0cc01b9 100644 --- a/kge/job/entity_ranking.py +++ b/kge/job/entity_ranking.py @@ -1,5 +1,6 @@ import math import time +from typing import Dict, Any import torch import kge.job @@ -13,18 +14,23 @@ class EntityRankingJob(EvaluationJob): def __init__(self, config: Config, dataset: Dataset, parent_job, model): super().__init__(config, dataset, parent_job, model) - self.is_prepared = False if self.__class__ == EntityRankingJob: for f in Job.job_created_hooks: f(self) + max_k = min( + self.dataset.num_entities(), max(self.config.get("eval.hits_at_k_s")) + ) + self.hits_at_k_s = list( + filter(lambda x: x <= max_k, self.config.get("eval.hits_at_k_s")) + ) + self.filter_with_test = config.get("eval.filter_with_test") + + def _prepare(self): """Construct all indexes needed to run.""" - if self.is_prepared: - return - # create data and precompute indexes self.triples = self.dataset.split(self.config.get("eval.split")) for split in self.filter_splits: @@ -75,16 +81,8 @@ def _collate(self, batch): return batch, label_coords, test_label_coords @torch.no_grad() - def run(self) -> dict: - self._prepare() - - was_training = self.model.training - self.model.eval() - self.config.log( - "Evaluating on " - + self.eval_split - + " data (epoch {})...".format(self.epoch) - ) + def _run(self) -> Dict[str, Any]: + num_entities = self.dataset.num_entities() # we also filter with test data if requested @@ -386,11 +384,6 @@ def merge_hist(target_hists, source_hists): ) epoch_time += time.time() - train_trace_entry = self.eval_train_loss_job.run_epoch( - echo_trace=False, forward_only=False - ) - metrics.update(avg_loss=train_trace_entry["avg_loss"],) - # compute trace trace_entry = dict( type="entity_ranking", @@ -404,28 +397,6 @@ def merge_hist(target_hists, source_hists): event="eval_completed", **metrics, ) - for f in self.post_epoch_trace_hooks: - f(self, trace_entry) - - # if validation metric is not present, try to compute it - metric_name = self.config.get("valid.metric") - if metric_name not in trace_entry: - trace_entry[metric_name] = eval( - self.config.get("valid.metric_expr"), - None, - dict(config=self.config, **trace_entry), - ) - - # write out trace - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) - - # reset model and return metrics - if was_training: - self.model.train() - self.config.log("Finished evaluating on " + self.eval_split + " split.") - - for f in self.post_valid_hooks: - f(self, trace_entry) return trace_entry diff --git a/kge/job/eval.py b/kge/job/eval.py index 4dd42cfe5..fddc3edbf 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -1,4 +1,5 @@ import time +from typing import Any, Dict import torch from kge import Config, Dataset @@ -15,12 +16,6 @@ def __init__(self, config, dataset, parent_job, model): self.model = model self.batch_size = config.get("eval.batch_size") self.device = self.config.get("job.device") - max_k = min( - self.dataset.num_entities(), max(self.config.get("eval.hits_at_k_s")) - ) - self.hits_at_k_s = list( - filter(lambda x: x <= max_k, self.config.get("eval.hits_at_k_s")) - ) self.config.check("train.trace_level", ["example", "batch", "epoch"]) self.trace_examples = self.config.get("eval.trace_level") == "example" self.trace_batch = ( @@ -30,14 +25,10 @@ def __init__(self, config, dataset, parent_job, model): self.filter_splits = self.config.get("eval.filter_splits") if self.eval_split not in self.filter_splits: self.filter_splits.append(self.eval_split) - self.filter_with_test = config.get("eval.filter_with_test") self.epoch = -1 - train_job_on_eval_split_config = config.clone() - train_job_on_eval_split_config.set("train.split", self.eval_split) - self.eval_train_loss_job = TrainingJob.create( - config=train_job_on_eval_split_config, parent_job=self, dataset=dataset - ) + self.verbose = True + self.is_prepared = False #: Hooks run after training for an epoch. #: Signature: job, trace_entry @@ -68,6 +59,22 @@ def __init__(self, config, dataset, parent_job, model): if config.get("eval.metrics_per.argument_frequency"): self.hist_hooks.append(hist_per_frequency_percentile) + # Add the training loss as a default to every evaluation job + # TODO: create AggregatingEvaluationsJob that runs and aggregates a list + # of EvaluationAjobs, such that users can configure combinations of + # EvalJobs themselves. Then this can be removed. + # See https://github.com/uma-pi1/kge/issues/102 + if not isinstance(self, EvalTrainingLossJob): + self.eval_train_loss_job = EvalTrainingLossJob( + config, dataset, parent_job=self, model=model + ) + self.eval_train_loss_job.verbose = False + self.post_epoch_trace_hooks.append( + lambda job, trace: trace.update( + avg_loss=self.eval_train_loss_job.run()["avg_loss"] + ) + ) + # all done, run job_created_hooks if necessary if self.__class__ == EvaluationJob: for f in Job.job_created_hooks: @@ -92,7 +99,55 @@ def create(config, dataset, parent_job=None, model=None): else: raise ValueError("eval.type") - def run(self) -> dict: + def _prepare(self): + """Prepare this job for running. Guaranteed to be called exactly once + """ + raise NotImplementedError + + def run(self) -> Dict[str, Any]: + + if not self.is_prepared: + self._prepare() + self.model.prepare_job(self) # let the model add some hooks + self.is_prepared = True + + was_training = self.model.training + self.model.eval() + self.config.log( + "Evaluating on " + + self.eval_split + + " data (epoch {})...".format(self.epoch), + echo=self.verbose + ) + + trace_entry = self._run() + + # if validation metric is not present, try to compute it + metric_name = self.config.get("valid.metric") + if metric_name not in trace_entry: + trace_entry[metric_name] = eval( + self.config.get("valid.metric_expr"), + None, + dict(config=self.config, **trace_entry), + ) + + for f in self.post_epoch_trace_hooks: + f(self, trace_entry) + + # write out trace + trace_entry = self.trace(**trace_entry, echo=self.verbose, echo_prefix=" ", log=True) + + # reset model and return metrics + if was_training: + self.model.train() + self.config.log("Finished evaluating on " + self.eval_split + " split.", echo=self.verbose) + + for f in self.post_valid_hooks: + f(self, trace_entry) + + return trace_entry + + def _run(self) -> Dict[str, Any]: """ Compute evaluation metrics, output results to trace file """ raise NotImplementedError @@ -116,28 +171,30 @@ class EvalTrainingLossJob(EvaluationJob): def __init__(self, config: Config, dataset: Dataset, parent_job, model): super().__init__(config, dataset, parent_job, model) - self.is_prepared = False + self.is_prepared = True + + train_job_on_eval_split_config = config.clone() + train_job_on_eval_split_config.set("train.split", self.eval_split) + self._train_job = TrainingJob.create( + config=train_job_on_eval_split_config, parent_job=self, dataset=dataset + ) + + self._train_job_verbose = False if self.__class__ == EvalTrainingLossJob: for f in Job.job_created_hooks: f(self) @torch.no_grad() - def run(self) -> dict: + def _run(self) -> Dict[str, Any]: epoch_time = -time.time() - was_training = self.model.training - self.model.eval() - self.config.log( - "Evaluating on " - + self.eval_split - + " data (epoch {})...".format(self.epoch) - ) + self.epoch = self.parent_job.epoch epoch_time += time.time() - train_trace_entry = self.eval_train_loss_job.run_epoch( - echo_trace=False, forward_only=True + train_trace_entry = self._train_job.run_epoch( + verbose=self._train_job_verbose, forward_only=True ) # compute trace trace_entry = dict( @@ -149,30 +206,6 @@ def run(self) -> dict: event="eval_completed", avg_loss=train_trace_entry["avg_loss"], ) - for f in self.post_epoch_trace_hooks: - f(self, trace_entry) - - # if validation metric is not present, try to compute it - metric_name = self.config.get("valid.metric") - if metric_name not in trace_entry: - trace_entry[metric_name] = eval( - self.config.get("valid.metric_expr"), - None, - dict(config=self.config, **trace_entry), - ) - - # write out trace - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) - - # reset model and return metrics - if was_training: - self.model.train() - self.config.log( - "Finished evaluating train loss on " + self.eval_split + " split." - ) - - for f in self.post_valid_hooks: - f(self, trace_entry) return trace_entry diff --git a/kge/job/train.py b/kge/job/train.py index c2f13d712..06073fb24 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -169,7 +169,7 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch(echo_trace=True, forward_only=False) + trace_entry = self.run_epoch() for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -290,7 +290,7 @@ def resume(self, checkpoint_file: str = None) -> None: else: self.config.log("No checkpoint found, starting from scratch...") - def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: + def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already @@ -405,7 +405,7 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: for f in self.post_batch_trace_hooks: f(self, batch_trace) self.trace(**batch_trace, event="batch_completed") - if echo_trace: + if verbose: print( ( "\r" # go back @@ -466,7 +466,7 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: for f in self.post_epoch_trace_hooks: f(self, trace_entry) trace_entry = self.trace( - **trace_entry, echo=echo_trace, echo_prefix=" ", log=True + **trace_entry, echo=verbose, echo_prefix=" ", log=True ) return trace_entry @@ -547,6 +547,21 @@ def __init__(self, config, dataset, parent_job=None): ) ) + self.queries = None + self.labels = None + self.label_offsets = None + self.query_end_index = None + + config.log("Initializing 1-to-N training job...") + self.type_str = "KvsAll" + + if self.__class__ == TrainingJobKvsAll: + for f in Job.job_created_hooks: + f(self) + + def _prepare(self): + from kge.indexing import index_KvsAll_to_torch + #' for each query type: list of queries self.queries = {} @@ -562,16 +577,6 @@ def __init__(self, config, dataset, parent_job=None): #' example of that type in the list of all examples self.query_end_index = {} - config.log("Initializing 1-to-N training job...") - self.type_str = "KvsAll" - - if self.__class__ == TrainingJobKvsAll: - for f in Job.job_created_hooks: - f(self) - - def _prepare(self): - from kge.indexing import index_KvsAll_to_torch - # determine enabled query types self.query_types = [ key From bac568cba78ebdccdbcfe0517d43f7fcc16f28cf Mon Sep 17 00:00:00 2001 From: Rainer Gemulla Date: Mon, 25 May 2020 11:28:02 +0200 Subject: [PATCH 6/9] Update config-default.yaml --- kge/config-default.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 5e4fdac5d..07f851900 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -324,8 +324,8 @@ eval: # mean_reciprocal_rank_filtered_with_test. filter_with_test: True - # Type of evaluation (entity_ranking or only training_loss, while entity - # ranking runs training_loss as well.) + # Type of evaluation (entity_ranking or training_loss). Currently, + # entity_ranking runs training_loss as well. type: entity_ranking # Compute Hits@K for these choices of K From ba9faf54b8a1a49cf93a541b2d7a58dfe3a92a1e Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Mon, 25 May 2020 14:29:23 +0200 Subject: [PATCH 7/9] Adress code review eval.py: set negativ sampling filtering split to train.split rename to TrainingLossEvaluationJob train.py: initialize only things for train if TrainingJob is used for training more ifs - --- kge/job/eval.py | 11 ++++++----- kge/job/train.py | 19 +++++++++++++------ kge/util/sampler.py | 18 +++++++++++++----- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index fddc3edbf..d34f7eeba 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -64,8 +64,8 @@ def __init__(self, config, dataset, parent_job, model): # of EvaluationAjobs, such that users can configure combinations of # EvalJobs themselves. Then this can be removed. # See https://github.com/uma-pi1/kge/issues/102 - if not isinstance(self, EvalTrainingLossJob): - self.eval_train_loss_job = EvalTrainingLossJob( + if not isinstance(self, TrainingLossEvaluationJob): + self.eval_train_loss_job = TrainingLossEvaluationJob( config, dataset, parent_job=self, model=model ) self.eval_train_loss_job.verbose = False @@ -93,7 +93,7 @@ def create(config, dataset, parent_job=None, model=None): config, dataset, parent_job=parent_job, model=model ) elif config.get("eval.type") == "training_loss": - return EvalTrainingLossJob( + return TrainingLossEvaluationJob( config, dataset, parent_job=parent_job, model=model ) else: @@ -166,7 +166,7 @@ def resume(self, checkpoint_file=None): ) -class EvalTrainingLossJob(EvaluationJob): +class TrainingLossEvaluationJob(EvaluationJob): """ Entity ranking evaluation protocol """ def __init__(self, config: Config, dataset: Dataset, parent_job, model): @@ -175,13 +175,14 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config = config.clone() train_job_on_eval_split_config.set("train.split", self.eval_split) + train_job_on_eval_split_config.set("negative_sampling.filtering.split", self.config.get("train.split")) self._train_job = TrainingJob.create( config=train_job_on_eval_split_config, parent_job=self, dataset=dataset ) self._train_job_verbose = False - if self.__class__ == EvalTrainingLossJob: + if self.__class__ == TrainingLossEvaluationJob: for f in Job.job_created_hooks: f(self) diff --git a/kge/job/train.py b/kge/job/train.py index 06073fb24..e620f62d3 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -45,8 +45,6 @@ def __init__( super().__init__(config, dataset, parent_job) self.model: KgeModel = KgeModel.create(config, dataset) - self.optimizer = KgeOptimizer.create(config, self.model) - self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer) self.loss = KgeLoss.create(config) self.abort_on_nan: bool = config.get("train.abort_on_nan") self.batch_size: int = config.get("train.batch_size") @@ -56,9 +54,14 @@ def __init__( self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 self.is_prepared = False - self.model.train() if config.get("job.type") == "train": + + self.model.train() + + self.optimizer = KgeOptimizer.create(config, self.model) + self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer) + valid_conf = config.clone() valid_conf.set("job.type", "eval") if self.config.get("valid.split") != "": @@ -315,7 +318,8 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str f(self) # process batch (preprocessing + forward pass + backward pass on loss) - self.optimizer.zero_grad() + if not forward_only: + self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( batch_index, batch, forward_only ) @@ -450,7 +454,6 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str split=self.train_split, batches=len(self.loader), size=self.num_examples, - lr=[group["lr"] for group in self.optimizer.param_groups], avg_loss=sum_loss / self.num_examples, avg_penalty=sum_penalty / len(self.loader), avg_penalties={k: p / len(self.loader) for k, p in sum_penalties.items()}, @@ -463,6 +466,10 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str other_time=other_time, event="epoch_completed", ) + if not forward_only: + trace_entry.update( + lr=[group["lr"] for group in self.optimizer.param_groups], + ) for f in self.post_epoch_trace_hooks: f(self, trace_entry) trace_entry = self.trace( @@ -634,7 +641,7 @@ def collate(batch): num_ones = 0 for example_index in batch: start = 0 - for query_type_index, query_type in enumerate(self.query_types): + for query_type in self.query_types: end = self.query_end_index[query_type] if example_index < end: example_index -= start diff --git a/kge/util/sampler.py b/kge/util/sampler.py index 06c5f1a25..75d5f3d4b 100644 --- a/kge/util/sampler.py +++ b/kge/util/sampler.py @@ -148,7 +148,12 @@ def _filter_and_resample( cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols] for i in range(positive_triples.size(0)): - positives = index.get((pairs[i][0].item(), pairs[i][1].item())).numpy() + pair = (pairs[i][0].item(), pairs[i][1].item()) + positives = ( + index.get(pair).numpy() + if pair in index + else torch.IntTensor([]).numpy() + ) # indices of samples that have to be sampled again resample_idx = where_in(negative_samples[i].numpy(), positives) # number of new samples needed @@ -229,9 +234,7 @@ def _sample_shared( # contain its positive, drop that positive. For all other rows, drop a random # position. shared_samples_index = {s: j for j, s in enumerate(shared_samples)} - replacement = np.random.choice( - num_distinct + 1, batch_size, replace=True - ) + replacement = np.random.choice(num_distinct + 1, batch_size, replace=True) drop = torch.tensor( [ shared_samples_index.get(s, replacement[i]) @@ -277,13 +280,18 @@ def _filter_and_resample_fast( positives_index = numba.typed.Dict() for i in range(batch_size): pair = (pairs[i][0], pairs[i][1]) - positives_index[pair] = index.get(pair).numpy() + positives_index[pair] = ( + index.get(pair).numpy() + if pair in index + else torch.IntTensor([]).numpy() + ) negative_samples = negative_samples.numpy() KgeUniformSampler._filter_and_resample_numba( negative_samples, pairs, positives_index, batch_size, int(voc_size), ) return torch.tensor(negative_samples, dtype=torch.int64) + @staticmethod @numba.njit def _filter_and_resample_numba( negative_samples, pairs, positives_index, batch_size, voc_size From 615e4a057c2ac3b328f55efe14da19ec99d1256c Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Tue, 26 May 2020 20:44:46 +0200 Subject: [PATCH 8/9] Adress code review eval.py: set negativ sampling filtering split to train.split rename to TrainingLossEvaluationJob train.py: initialize only things for train if TrainingJob is used for training more ifs --- kge/job/eval.py | 4 +-- kge/job/train.py | 83 +++++++++++++++++++++++++++++------------------- 2 files changed, 53 insertions(+), 34 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index d34f7eeba..3a672f0f0 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -177,7 +177,7 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config.set("train.split", self.eval_split) train_job_on_eval_split_config.set("negative_sampling.filtering.split", self.config.get("train.split")) self._train_job = TrainingJob.create( - config=train_job_on_eval_split_config, parent_job=self, dataset=dataset + config=train_job_on_eval_split_config, parent_job=self, dataset=dataset, initialize_for_forward_only=True, ) self._train_job_verbose = False @@ -195,7 +195,7 @@ def _run(self) -> Dict[str, Any]: epoch_time += time.time() train_trace_entry = self._train_job.run_epoch( - verbose=self._train_job_verbose, forward_only=True + verbose=self._train_job_verbose ) # compute trace trace_entry = dict( diff --git a/kge/job/train.py b/kge/job/train.py index e620f62d3..ae988af85 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -39,7 +39,11 @@ class TrainingJob(Job): """ def __init__( - self, config: Config, dataset: Dataset, parent_job: Job = None + self, + config: Config, + dataset: Dataset, + parent_job: Job = None, + initialize_for_forward_only=False, ) -> None: from kge.job import EvaluationJob @@ -54,8 +58,9 @@ def __init__( self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 self.is_prepared = False + self.is_forward_only = initialize_for_forward_only - if config.get("job.type") == "train": + if not initialize_for_forward_only: self.model.train() @@ -107,21 +112,35 @@ def __init__( @staticmethod def create( - config: Config, dataset: Dataset, parent_job: Job = None + config: Config, + dataset: Dataset, + parent_job: Job = None, + initialize_for_forward_only=False, ) -> "TrainingJob": """Factory method to create a training job.""" if config.get("train.type") == "KvsAll": - return TrainingJobKvsAll(config, dataset, parent_job) + return TrainingJobKvsAll( + config, dataset, parent_job, initialize_for_forward_only, + ) elif config.get("train.type") == "negative_sampling": - return TrainingJobNegativeSampling(config, dataset, parent_job) + return TrainingJobNegativeSampling( + config, dataset, parent_job, initialize_for_forward_only, + ) elif config.get("train.type") == "1vsAll": - return TrainingJob1vsAll(config, dataset, parent_job) + return TrainingJob1vsAll( + config, dataset, parent_job, initialize_for_forward_only, + ) else: # perhaps TODO: try class with specified name -> extensibility raise ValueError("train.type") def run(self) -> None: """Start/resume the training job and run to completion.""" + if self.is_forward_only: + raise Exception( + f"{self.__class__.__name__} was initialized for forward only. You can only call run_epoch()" + ) + self.config.log("Starting training...") checkpoint_every = self.config.get("train.checkpoint.every") checkpoint_keep = self.config.get("train.checkpoint.keep") @@ -293,7 +312,7 @@ def resume(self, checkpoint_file: str = None) -> None: else: self.config.log("No checkpoint found, starting from scratch...") - def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str, Any]: + def run_epoch(self, verbose: bool = True) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already @@ -318,10 +337,10 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str f(self) # process batch (preprocessing + forward pass + backward pass on loss) - if not forward_only: + if not self.is_forward_only: self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch, forward_only + batch_index, batch, self.is_forward_only ) sum_loss += batch_result.avg_loss * batch_result.size @@ -339,7 +358,7 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - if not forward_only: + if not self.is_forward_only: penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() @@ -381,7 +400,7 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str # update parameters batch_optimizer_time = 0 - if not forward_only: + if not self.is_forward_only: batch_optimizer_time = -time.time() self.optimizer.step() batch_optimizer_time += time.time() @@ -466,7 +485,7 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str other_time=other_time, event="epoch_completed", ) - if not forward_only: + if not self.is_forward_only: trace_entry.update( lr=[group["lr"] for group in self.optimizer.param_groups], ) @@ -500,7 +519,7 @@ class _ProcessBatchResult: backward_time: float def _process_batch( - self, batch_index: int, batch, forward_only: bool + self, batch_index: int, batch ) -> "TrainingJob._ProcessBatchResult": "Run forward and backward pass on batch and return results." raise NotImplementedError @@ -517,8 +536,10 @@ class TrainingJobKvsAll(TrainingJob): - Example: a query + its labels, e.g., (John,marriedTo), [Jane] """ - def __init__(self, config, dataset, parent_job=None): - super().__init__(config, dataset, parent_job) + def __init__( + self, config, dataset, parent_job=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, initialize_for_forward_only) self.label_smoothing = config.check_range( "KvsAll.label_smoothing", float("-inf"), 1.0, max_inclusive=False ) @@ -706,9 +727,7 @@ def collate(batch): return collate - def _process_batch( - self, batch_index, batch, forward_only: bool - ) -> TrainingJob._ProcessBatchResult: + def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() queries_batch = batch["queries"].to(self.device) @@ -779,7 +798,7 @@ def _process_batch( loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - if not forward_only: + if not self.is_forward_only: loss_value.backward() backward_time += time.time() @@ -790,8 +809,10 @@ def _process_batch( class TrainingJobNegativeSampling(TrainingJob): - def __init__(self, config, dataset, parent_job=None): - super().__init__(config, dataset, parent_job) + def __init__( + self, config, dataset, parent_job=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, initialize_for_forward_only) self._sampler = KgeSampler.create(config, "negative_sampling", dataset) self._implementation = self.config.check( "negative_sampling.implementation", ["triple", "all", "batch", "auto"], @@ -854,9 +875,7 @@ def collate(batch): return collate - def _process_batch( - self, batch_index, batch, forward_only: bool - ) -> TrainingJob._ProcessBatchResult: + def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() batch_triples = batch["triples"].to(self.device) @@ -1026,7 +1045,7 @@ def _process_batch( # backward pass for this chunk backward_time -= time.time() - if not forward_only: + if not self.is_forward_only: loss_value_torch.backward() backward_time += time.time() @@ -1039,8 +1058,10 @@ def _process_batch( class TrainingJob1vsAll(TrainingJob): """Samples SPO pairs and queries sp_ and _po, treating all other entities as negative.""" - def __init__(self, config, dataset, parent_job=None): - super().__init__(config, dataset, parent_job) + def __init__( + self, config, dataset, parent_job=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, initialize_for_forward_only) config.log("Initializing spo training job...") self.type_str = "1vsAll" @@ -1066,9 +1087,7 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - def _process_batch( - self, batch_index, batch, forward_only: bool - ) -> TrainingJob._ProcessBatchResult: + def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() triples = batch["triples"].to(self.device) @@ -1082,7 +1101,7 @@ def _process_batch( loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - if not forward_only: + if not self.is_forward_only: loss_value_sp.backward() backward_time += time.time() @@ -1093,7 +1112,7 @@ def _process_batch( loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - if not forward_only: + if not self.is_forward_only: loss_value_po.backward() backward_time += time.time() From 66a9bc17d9232a8172cb151df4d4dbfed9b638f3 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Wed, 27 May 2020 01:51:44 +0200 Subject: [PATCH 9/9] Adress code review eval.py: set splits train.py: use self.is_forward_only use dataset.index to load label_index for KvsAll sampler.py: rename filtering splits use dataset.index to load label_index for filtering splits config-default.yaml: add label_splits config for KvsAll indexing.py + misc.py: add index and helper funcs to merge KvsAll dicts --- kge/config-default.yaml | 6 +++++- kge/indexing.py | 19 +++++++++++++++++++ kge/job/eval.py | 34 +++++++++++++++++++++++++--------- kge/job/train.py | 11 +++++++++-- kge/misc.py | 31 +++++++++++++++++++++++++++++++ kge/util/sampler.py | 23 ++++++++++++++--------- 6 files changed, 103 insertions(+), 21 deletions(-) diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 07f851900..c8cea184f 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -232,6 +232,10 @@ KvsAll: s_o: False _po: True + # Dataset splits from which the labels for a query are taken from. Default: If + # nothing is specified, then the train split is used. + label_splits: [] + # Options for negative sampling training (train.type=="negative_sampling") negative_sampling: # Negative sampler to use @@ -261,7 +265,7 @@ negative_sampling: p: False # as above o: False # as above - split: '' # split containing the positives; default is train.split + splits: [] # splits containing the positives; default is train.split # Implementation to use for filtering. # standard: use slow generic implementation, available for all samplers diff --git a/kge/indexing.py b/kge/indexing.py index b4a262915..f737c5ae8 100644 --- a/kge/indexing.py +++ b/kge/indexing.py @@ -2,6 +2,7 @@ from collections import defaultdict, OrderedDict import numba import numpy as np +from kge.misc import powerset, merge_dicts_of_1dim_torch_tensors def _group_by(keys, values) -> dict: @@ -220,6 +221,15 @@ def _invert_ids(dataset, obj: str): dataset.config.log(f"Indexed {len(inv)} {obj} ids", prefix=" ") +def merge_KvsAll_indexes(dataset, split, key): + value = dict([("sp", "o"), ("po", "s"), ("so", "p")])[key] + split_combi_str = "_".join(sorted(split)) + index_name = f"{split_combi_str}_{key}_to_{value}" + indexes = [dataset.index(f"{_split}_{key}_to_{value}") for _split in split] + dataset._indexes[index_name] = merge_dicts_of_1dim_torch_tensors(indexes) + return dataset._indexes[index_name] + + def create_default_index_functions(dataset: "Dataset"): for split in dataset.files_of_type("triples"): for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]: @@ -227,6 +237,15 @@ def create_default_index_functions(dataset: "Dataset"): dataset.index_functions[f"{split}_{key}_to_{value}"] = IndexWrapper( index_KvsAll, split=split, key=key ) + # create all combinations of splits of length 2 and 3 + for split_combi in powerset(dataset.files_of_type("triples"), [2, 3]): + for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]: + split_combi_str = "_".join(sorted(split_combi)) + index_name = f"{split_combi_str}_{key}_to_{value}" + dataset.index_functions[index_name] = IndexWrapper( + merge_KvsAll_indexes, split=split_combi, key=key + ) + dataset.index_functions["relation_types"] = index_relation_types dataset.index_functions["relations_per_type"] = index_relation_types dataset.index_functions["frequency_percentiles"] = index_frequency_percentiles diff --git a/kge/job/eval.py b/kge/job/eval.py index 3a672f0f0..c41a7024a 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -117,7 +117,7 @@ def run(self) -> Dict[str, Any]: "Evaluating on " + self.eval_split + " data (epoch {})...".format(self.epoch), - echo=self.verbose + echo=self.verbose, ) trace_entry = self._run() @@ -135,12 +135,16 @@ def run(self) -> Dict[str, Any]: f(self, trace_entry) # write out trace - trace_entry = self.trace(**trace_entry, echo=self.verbose, echo_prefix=" ", log=True) + trace_entry = self.trace( + **trace_entry, echo=self.verbose, echo_prefix=" ", log=True + ) # reset model and return metrics if was_training: self.model.train() - self.config.log("Finished evaluating on " + self.eval_split + " split.", echo=self.verbose) + self.config.log( + "Finished evaluating on " + self.eval_split + " split.", echo=self.verbose + ) for f in self.post_valid_hooks: f(self, trace_entry) @@ -175,11 +179,25 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config = config.clone() train_job_on_eval_split_config.set("train.split", self.eval_split) - train_job_on_eval_split_config.set("negative_sampling.filtering.split", self.config.get("train.split")) + train_job_on_eval_split_config.set( + "negative_sampling.filtering.splits", + [self.config.get("train.split"), self.eval_split] + ["valid"] + if self.eval_split == "test" + else [], + ) + train_job_on_eval_split_config.set( + "KvsAll.label_splits", + [self.config.get("train.split"), self.eval_split] + ["valid"] + if self.eval_split == "test" + else [], + ) self._train_job = TrainingJob.create( - config=train_job_on_eval_split_config, parent_job=self, dataset=dataset, initialize_for_forward_only=True, + config=train_job_on_eval_split_config, + parent_job=self, + dataset=dataset, + initialize_for_forward_only=True, ) - + self._train_job.model = model self._train_job_verbose = False if self.__class__ == TrainingLossEvaluationJob: @@ -194,9 +212,7 @@ def _run(self) -> Dict[str, Any]: self.epoch = self.parent_job.epoch epoch_time += time.time() - train_trace_entry = self._train_job.run_epoch( - verbose=self._train_job_verbose - ) + train_trace_entry = self._train_job.run_epoch(verbose=self._train_job_verbose) # compute trace trace_entry = dict( type="training_loss", diff --git a/kge/job/train.py b/kge/job/train.py index ae988af85..5649ddec9 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -340,7 +340,7 @@ def run_epoch(self, verbose: bool = True) -> Dict[str, Any]: if not self.is_forward_only: self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch, self.is_forward_only + batch_index, batch ) sum_loss += batch_result.avg_loss * batch_result.size @@ -575,6 +575,10 @@ def __init__( ) ) + self.label_splits = set(self.config.get("KvsAll.label_splits")) | set( + [self.train_split] + ) + self.queries = None self.labels = None self.label_offsets = None @@ -620,7 +624,10 @@ def _prepare(self): if query_type == "sp_" else ("so_to_p" if query_type == "s_o" else "po_to_s") ) - index = self.dataset.index(f"{self.train_split}_{index_type}") + split_combi_str = "_".join(sorted(self.label_splits)) + label_index = self.dataset.index(f"{split_combi_str}_{index_type}") + query_index = self.dataset.index(f"{self.train_split}_{index_type}") + index = {k: v for k, v in label_index.items() if k in query_index} self.num_examples += len(index) self.query_end_index[query_type] = self.num_examples diff --git a/kge/misc.py b/kge/misc.py index 90db3750c..01613d37b 100644 --- a/kge/misc.py +++ b/kge/misc.py @@ -1,3 +1,5 @@ +import torch +from itertools import chain, combinations from typing import List from torch import nn as nn @@ -6,6 +8,7 @@ import inspect import subprocess + def is_number(s, number_type): """ Returns True is string is a number. """ try: @@ -15,6 +18,33 @@ def is_number(s, number_type): return False +def merge_dicts_of_1dim_torch_tensors(keys_from_dols, values_from_dols=None): + if values_from_dols is None: + values_from_dols = keys_from_dols + keys = set(chain.from_iterable([d.keys() for d in keys_from_dols])) + no = torch.tensor([]).int() + return dict( + ( + k, + torch.cat([d.get(k, no) for d in values_from_dols]).sort()[0], + ) + for k in keys + ) + + +def powerset(iterable, filter_lens=None): + s = list(iterable) + return list( + map( + sorted, + filter( + lambda l: len(l) in filter_lens if filter_lens else True, + chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)), + ), + ) + ) + + # from https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script def get_git_revision_hash(): try: @@ -69,6 +99,7 @@ def is_exe(fpath): def kge_base_dir(): import kge + return os.path.abspath(filename_in_module(kge, "..")) diff --git a/kge/util/sampler.py b/kge/util/sampler.py index 75d5f3d4b..f0351f6d2 100644 --- a/kge/util/sampler.py +++ b/kge/util/sampler.py @@ -3,7 +3,7 @@ import random import torch -from typing import Optional +from typing import Optional, List import numpy as np import numba @@ -29,9 +29,9 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset): "Without replacement sampling is only supported when " "shared negative sampling is enabled." ) - self.filtering_split = config.get("negative_sampling.filtering.split") - if self.filtering_split == "": - self.filtering_split = config.get("train.split") + self.filtering_splits: List[str] = config.get("negative_sampling.filtering.splits") + if len(self.filtering_splits) == 0: + self.filtering_splits.append(config.get("train.split")) for slot in SLOTS: slot_str = SLOT_STR[slot] self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}") @@ -43,7 +43,10 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset): # otherwise every worker would create every index again and again if self.filter_positives[slot]: pair = ["po", "so", "sp"][slot] - dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}") + for filtering_split in self.filtering_splits: + dataset.index(f"{filtering_split}_{pair}_to_{slot_str}") + filtering_splits_str = '_'.join(sorted(self.filtering_splits)) + dataset.index(f"{filtering_splits_str}_{pair}_to_{slot_str}") if any(self.filter_positives): if self.shared: raise ValueError( @@ -142,11 +145,12 @@ def _filter_and_resample( """Filter and resample indices until only negatives have been created. """ pair_str = ["po", "so", "sp"][slot] # holding the positive indices for the respective pair - index = self.dataset.index( - f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}" - ) cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols] + split_combi_str = "_".join(sorted(self.filtering_splits)) + index = self.dataset.index( + f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}" + ) for i in range(positive_triples.size(0)): pair = (pairs[i][0].item(), pairs[i][1].item()) positives = ( @@ -266,8 +270,9 @@ def _filter_and_resample_fast( ): pair_str = ["po", "so", "sp"][slot] # holding the positive indices for the respective pair + split_combi_str = "_".join(sorted(self.filtering_splits)) index = self.dataset.index( - f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}" + f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}" ) cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols].numpy()