From dc1141bb0d877dd2f7c6615357eff086f3478024 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 16 May 2020 23:30:48 +0200 Subject: [PATCH 01/10] Add loss/cost on validation data --- kge/job/entity_ranking.py | 14 ++- kge/job/train.py | 249 ++++++++++++++++++++------------------ 2 files changed, 147 insertions(+), 116 deletions(-) diff --git a/kge/job/entity_ranking.py b/kge/job/entity_ranking.py index 9b36d1715..702b370fd 100644 --- a/kge/job/entity_ranking.py +++ b/kge/job/entity_ranking.py @@ -3,7 +3,7 @@ import torch import kge.job -from kge.job import EvaluationJob, Job +from kge.job import EvaluationJob, Job, TrainingJob from kge import Config, Dataset from collections import defaultdict @@ -386,6 +386,18 @@ def merge_hist(target_hists, source_hists): ) epoch_time += time.time() + if isinstance(self.parent_job, TrainingJob): + self.parent_job: TrainingJob = self.parent_job + train_trace_entry = self.parent_job.run_epoch( + split=self.eval_split, echo_trace=False, do_backward=False + ) + metrics.update( + avg_loss=train_trace_entry["avg_loss"], + avg_penalty=train_trace_entry["avg_penalty"], + avg_penalties=train_trace_entry["avg_penalties"], + avg_cost=train_trace_entry["avg_cost"], + ) + # compute trace trace_entry = dict( type="entity_ranking", diff --git a/kge/job/train.py b/kge/job/train.py index e778bf40b..4e3608a05 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -10,6 +10,7 @@ import torch import torch.utils.data import numpy as np +from torch.utils.data import DataLoader from kge import Config, Dataset from kge.job import Job @@ -81,11 +82,16 @@ def __init__( self.valid_job = EvaluationJob.create( valid_conf, dataset, parent_job=self, model=self.model ) - self.is_prepared = False + self.config.check("train.trace_level", ["batch", "epoch"]) + self.trace_batch: bool = self.config.get("train.trace_level") == "batch" + self.epoch: int = 0 + self.valid_trace: List[Dict[str, Any]] = [] + self.is_prepared: Dict[str, bool] = defaultdict(lambda: False) + self.model.train() # attributes filled in by implementing classes - self.loader = None - self.num_examples = None + self.loader: Dict[str, DataLoader] = dict() + self.num_examples: Dict[str, int] = dict() self.type_str: Optional[str] = None #: Hooks run after training for an epoch. @@ -185,7 +191,9 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch() + trace_entry = self.run_epoch( + split=self.train_split, echo_trace=True, do_backward=True + ) for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -292,14 +300,16 @@ def _load(self, checkpoint: Dict) -> str: ) ) - def run_epoch(self) -> Dict[str, Any]: + def run_epoch( + self, split: str, echo_trace: bool, do_backward: bool + ) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already - if not self.is_prepared: - self._prepare() + if not self.is_prepared[split]: + self._prepare(split=split) self.model.prepare_job(self) # let the model add some hooks - self.is_prepared = True + self.is_prepared[split] = True # variables that record various statitics sum_loss = 0.0 @@ -312,14 +322,14 @@ def run_epoch(self) -> Dict[str, Any]: optimizer_time = 0.0 # process each batch - for batch_index, batch in enumerate(self.loader): + for batch_index, batch in enumerate(self.loader[split]): for f in self.pre_batch_hooks: f(self) # process batch (preprocessing + forward pass + backward pass on loss) self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch + batch_index, batch, do_backward ) sum_loss += batch_result.avg_loss * batch_result.size @@ -328,7 +338,7 @@ def run_epoch(self) -> Dict[str, Any]: penalties_torch = self.model.penalty( epoch=self.epoch, batch_index=batch_index, - num_batches=len(self.loader), + num_batches=len(self.loader[split]), batch=batch, ) batch_forward_time += time.time() @@ -337,7 +347,8 @@ def run_epoch(self) -> Dict[str, Any]: batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - penalty_value_torch.backward() + if do_backward: + penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() sum_penalty += penalty @@ -377,9 +388,11 @@ def run_epoch(self) -> Dict[str, Any]: ) # update parameters - batch_optimizer_time = -time.time() - self.optimizer.step() - batch_optimizer_time += time.time() + batch_optimizer_time = 0 + if do_backward: + batch_optimizer_time = -time.time() + self.optimizer.step() + batch_optimizer_time += time.time() # tracing/logging if self.trace_batch: @@ -387,10 +400,10 @@ def run_epoch(self) -> Dict[str, Any]: "type": self.type_str, "scope": "batch", "epoch": self.epoch, - "split": self.train_split, + "split": split, "batch": batch_index, "size": batch_result.size, - "batches": len(self.loader), + "batches": len(self.loader[split]), "lr": [group["lr"] for group in self.optimizer.param_groups], "avg_loss": batch_result.avg_loss, "penalties": [p.item() for k, p in penalties_torch], @@ -404,29 +417,30 @@ def run_epoch(self) -> Dict[str, Any]: for f in self.post_batch_trace_hooks: f(self, batch_trace) self.trace(**batch_trace, event="batch_completed") - self.config.print( - ( - "\r" # go back - + "{} batch{: " - + str(1 + int(math.ceil(math.log10(len(self.loader))))) - + "d}/{}" - + ", avg_loss {:.4E}, penalty {:.4E}, cost {:.4E}, time {:6.2f}s" - + "\033[K" # clear to right - ).format( - self.config.log_prefix, - batch_index, - len(self.loader) - 1, - batch_result.avg_loss, - penalty, - cost_value, - batch_result.prepare_time - + batch_forward_time - + batch_backward_time - + batch_optimizer_time, - ), - end="", - flush=True, - ) + if echo_trace: + self.config.print( + ( + "\r" # go back + + "{} batch{: " + + str(1 + int(math.ceil(math.log10(len(self.loader))))) + + "d}/{}" + + ", avg_loss {:.4E}, penalty {:.4E}, cost {:.4E}, time {:6.2f}s" + + "\033[K" # clear to right + ).format( + self.config.log_prefix, + batch_index, + len(self.loader) - 1, + batch_result.avg_loss, + penalty, + cost_value, + batch_result.prepare_time + + batch_forward_time + + batch_backward_time + + batch_optimizer_time, + ), + end="", + flush=True, + ) # update times prepare_time += batch_result.prepare_time @@ -445,14 +459,15 @@ def run_epoch(self) -> Dict[str, Any]: type=self.type_str, scope="epoch", epoch=self.epoch, - split=self.train_split, - batches=len(self.loader), - size=self.num_examples, + split=split, + batches=len(self.loader[split]), + size=self.num_examples[split], lr=[group["lr"] for group in self.optimizer.param_groups], - avg_loss=sum_loss / self.num_examples, - avg_penalty=sum_penalty / len(self.loader), - avg_penalties={k: p / len(self.loader) for k, p in sum_penalties.items()}, - avg_cost=sum_loss / self.num_examples + sum_penalty / len(self.loader), + avg_loss=sum_loss / self.num_examples[split], + avg_penalty=sum_penalty / len(self.loader[split]), + avg_penalties={k: p / len(self.loader[split]) for k, p in sum_penalties.items()}, + avg_cost=sum_loss / self.num_examples[split] + + sum_penalty / len(self.loader[split]), epoch_time=epoch_time, prepare_time=prepare_time, forward_time=forward_time, @@ -463,10 +478,12 @@ def run_epoch(self) -> Dict[str, Any]: ) for f in self.post_epoch_trace_hooks: f(self, trace_entry) - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) + trace_entry = self.trace( + **trace_entry, echo=echo_trace, echo_prefix=" ", log=True + ) return trace_entry - def _prepare(self): + def _prepare(self, split: str): """Prepare this job for running. Sets (at least) the `loader`, `num_examples`, and `type_str` attributes of this @@ -489,7 +506,7 @@ class _ProcessBatchResult: backward_time: float def _process_batch( - self, batch_index: int, batch + self, batch_index: int, batch, do_backward: bool ) -> "TrainingJob._ProcessBatchResult": "Run forward and backward pass on batch and return results." raise NotImplementedError @@ -543,6 +560,21 @@ def __init__(self, config, dataset, parent_job=None, model=None): ) ) + #' for each query type: list of queries + self.queries = {} + + #' for each query type: list of all labels (concatenated across queries) + self.labels = {} + + #' for each query type: list of starting offset of labels in self.labels. The + #' labels for the i-th query of query_type are in labels[query_type] in range + #' label_offsets[query_type][i]:label_offsets[query_type][i+1] + self.label_offsets = {} + + #' for each query type (ordered as in self.query_types), index right after last + #' example of that type in the list of all examples + self.query_end_index = {} + config.log("Initializing 1-to-N training job...") self.type_str = "KvsAll" @@ -550,7 +582,7 @@ def __init__(self, config, dataset, parent_job=None, model=None): for f in Job.job_created_hooks: f(self) - def _prepare(self): + def _prepare(self, split: str): from kge.indexing import index_KvsAll_to_torch # determine enabled query types @@ -560,44 +592,29 @@ def _prepare(self): if enabled ] - #' for each query type: list of queries - self.queries = {} - - #' for each query type: list of all labels (concatenated across queries) - self.labels = {} - - #' for each query type: list of starting offset of labels in self.labels. The - #' labels for the i-th query of query_type are in labels[query_type] in range - #' label_offsets[query_type][i]:label_offsets[query_type][i+1] - self.label_offsets = {} - - #' for each query type (ordered as in self.query_types), index right after last - #' example of that type in the list of all examples - self.query_end_index = [] - # construct relevant data structures - self.num_examples = 0 + self.num_examples[split] = 0 for query_type in self.query_types: index_type = ( "sp_to_o" if query_type == "sp_" else ("so_to_p" if query_type == "s_o" else "po_to_s") ) - index = self.dataset.index(f"{self.train_split}_{index_type}") - self.num_examples += len(index) - self.query_end_index.append(self.num_examples) + index = self.dataset.index(f"{split}_{index_type}") + self.num_examples[split] += len(index) + self.query_end_index[query_type + split] = self.num_examples[split] # Convert indexes to pytorch tensors (as described above). ( - self.queries[query_type], - self.labels[query_type], - self.label_offsets[query_type], + self.queries[query_type + split], + self.labels[query_type + split], + self.label_offsets[query_type + split], ) = index_KvsAll_to_torch(index) # create dataloader - self.loader = torch.utils.data.DataLoader( - range(self.num_examples), - collate_fn=self._get_collate_fun(), + self.loader[split] = torch.utils.data.DataLoader( + range(self.num_examples[split]), + collate_fn=self._get_collate_fun(split), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -605,7 +622,7 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - def _get_collate_fun(self): + def _get_collate_fun(self, split: str): # create the collate function def collate(batch): """For a batch of size n, returns a dictionary of: @@ -624,11 +641,15 @@ def collate(batch): for example_index in batch: start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type_index] + end = self.query_end_index[query_type + split] if example_index < end: example_index -= start - num_ones += self.label_offsets[query_type][example_index + 1] - num_ones -= self.label_offsets[query_type][example_index] + num_ones += self.label_offsets[query_type + split][ + example_index + 1 + ] + num_ones -= self.label_offsets[query_type + split][ + example_index + ] break start = end @@ -641,13 +662,13 @@ def collate(batch): for batch_index, example_index in enumerate(batch): start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type_index] + end = self.query_end_index[query_type + split] if example_index < end: example_index -= start query_type_indexes_batch[batch_index] = query_type_index - queries = self.queries[query_type] - label_offsets = self.label_offsets[query_type] - labels = self.labels[query_type] + queries = self.queries[query_type + split] + label_offsets = self.label_offsets[query_type + split] + labels = self.labels[query_type + split] if query_type == "sp_": query_col_1, query_col_2, target_col = S, P, O elif query_type == "s_o": @@ -688,7 +709,9 @@ def collate(batch): return collate - def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: + def _process_batch( + self, batch_index, batch, do_backward: bool + ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() queries_batch = batch["queries"].to(self.device) @@ -759,7 +782,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - loss_value.backward() + if do_backward: + loss_value.backward() backward_time += time.time() # all done @@ -772,7 +796,6 @@ class TrainingJobNegativeSampling(TrainingJob): def __init__(self, config, dataset, parent_job=None, model=None): super().__init__(config, dataset, parent_job, model=model) self._sampler = KgeSampler.create(config, "negative_sampling", dataset) - self.is_prepared = False self._implementation = self.config.check( "negative_sampling.implementation", ["triple", "all", "batch", "auto"], ) @@ -796,16 +819,13 @@ def __init__(self, config, dataset, parent_job=None, model=None): for f in Job.job_created_hooks: f(self) - def _prepare(self): + def _prepare(self, split: str): """Construct dataloader""" - if self.is_prepared: - return - - self.num_examples = self.dataset.split(self.train_split).size(0) - self.loader = torch.utils.data.DataLoader( - range(self.num_examples), - collate_fn=self._get_collate_fun(), + self.num_examples[split] = self.dataset.split(split).size(0) + self.loader[split] = torch.utils.data.DataLoader( + range(self.num_examples[split]), + collate_fn=self._get_collate_fun(split), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -813,9 +833,7 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - self.is_prepared = True - - def _get_collate_fun(self): + def _get_collate_fun(self, split: str): # create the collate function def collate(batch): """For a batch of size n, returns a tuple of: @@ -825,7 +843,7 @@ def collate(batch): in order S,P,O) """ - triples = self.dataset.split(self.train_split)[batch, :].long() + triples = self.dataset.split(split)[batch, :].long() # labels = torch.zeros((len(batch), self._sampler.num_negatives_total + 1)) # labels[:, 0] = 1 # labels = labels.view(-1) @@ -837,7 +855,9 @@ def collate(batch): return collate - def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: + def _process_batch( + self, batch_index, batch, do_backward: bool + ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() batch_triples = batch["triples"].to(self.device) @@ -1007,7 +1027,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # backward pass for this chunk backward_time -= time.time() - loss_value_torch.backward() + if do_backward: + loss_value_torch.backward() backward_time += time.time() # all done @@ -1021,7 +1042,6 @@ class TrainingJob1vsAll(TrainingJob): def __init__(self, config, dataset, parent_job=None, model=None): super().__init__(config, dataset, parent_job, model=model) - self.is_prepared = False config.log("Initializing spo training job...") self.type_str = "1vsAll" @@ -1029,17 +1049,14 @@ def __init__(self, config, dataset, parent_job=None, model=None): for f in Job.job_created_hooks: f(self) - def _prepare(self): + def _prepare(self, split: str): """Construct dataloader""" - if self.is_prepared: - return - - self.num_examples = self.dataset.split(self.train_split).size(0) - self.loader = torch.utils.data.DataLoader( - range(self.num_examples), + self.num_examples[split] = self.dataset.split(split).size(0) + self.loader[split] = torch.utils.data.DataLoader( + range(self.num_examples[split]), collate_fn=lambda batch: { - "triples": self.dataset.split(self.train_split)[batch, :].long() + "triples": self.dataset.split(split)[batch, :].long() }, shuffle=True, batch_size=self.batch_size, @@ -1048,9 +1065,9 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - self.is_prepared = True - - def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: + def _process_batch( + self, batch_index, batch, do_backward: bool + ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() triples = batch["triples"].to(self.device) @@ -1064,7 +1081,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - loss_value_sp.backward() + if do_backward: + loss_value_sp.backward() backward_time += time.time() # forward/backward pass (po) @@ -1074,7 +1092,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - loss_value_po.backward() + if do_backward: + loss_value_po.backward() backward_time += time.time() # all done From 2779ef82158c36ca18f8415c7788680cf663e7de Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Thu, 21 May 2020 00:19:40 +0200 Subject: [PATCH 02/10] Use TrainJob in Eval Job Add EvalTrainingLossJob --- kge/job/entity_ranking.py | 15 +--- kge/job/eval.py | 74 ++++++++++++++++++ kge/job/train.py | 158 ++++++++++++++++++-------------------- 3 files changed, 151 insertions(+), 96 deletions(-) diff --git a/kge/job/entity_ranking.py b/kge/job/entity_ranking.py index 702b370fd..19c685800 100644 --- a/kge/job/entity_ranking.py +++ b/kge/job/entity_ranking.py @@ -386,17 +386,10 @@ def merge_hist(target_hists, source_hists): ) epoch_time += time.time() - if isinstance(self.parent_job, TrainingJob): - self.parent_job: TrainingJob = self.parent_job - train_trace_entry = self.parent_job.run_epoch( - split=self.eval_split, echo_trace=False, do_backward=False - ) - metrics.update( - avg_loss=train_trace_entry["avg_loss"], - avg_penalty=train_trace_entry["avg_penalty"], - avg_penalties=train_trace_entry["avg_penalties"], - avg_cost=train_trace_entry["avg_cost"], - ) + train_trace_entry = self.eval_train_loss_job.run_epoch( + echo_trace=False, forward_only=False + ) + metrics.update(avg_loss=train_trace_entry["avg_loss"],) # compute trace trace_entry = dict( diff --git a/kge/job/eval.py b/kge/job/eval.py index 138726728..f2f218055 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -1,4 +1,7 @@ +import time + import torch +from kge import Config, Dataset from kge import Config, Dataset from kge.job import Job @@ -34,6 +37,12 @@ def __init__(self, config, dataset, parent_job, model): self.filter_with_test = config.get("eval.filter_with_test") self.epoch = -1 + train_job_on_eval_split_config = config.clone() + train_job_on_eval_split_config.set("train.split", self.eval_split) + self.eval_train_loss_job = TrainingJob.create( + config=train_job_on_eval_split_config, parent_job=self, dataset=dataset + ) + #: Hooks run after training for an epoch. #: Signature: job, trace_entry self.post_epoch_hooks = [] @@ -132,6 +141,71 @@ def create_from( return super().create_from(checkpoint, new_config, dataset, parent_job) +class EvalTrainingLossJob(EvaluationJob): + """ Entity ranking evaluation protocol """ + + def __init__(self, config: Config, dataset: Dataset, parent_job, model): + super().__init__(config, dataset, parent_job, model) + self.is_prepared = False + + if self.__class__ == EvalTrainingLossJob: + for f in Job.job_created_hooks: + f(self) + + @torch.no_grad() + def run(self) -> dict: + + epoch_time = -time.time() + + was_training = self.model.training + self.model.eval() + self.config.log( + "Evaluating on " + + self.eval_split + + " data (epoch {})...".format(self.epoch) + ) + epoch_time += time.time() + + train_trace_entry = self.eval_train_loss_job.run_epoch( + echo_trace=False, forward_only=False + ) + # compute trace + trace_entry = dict( + type="eval_train_loss_job", + scope="epoch", + split=self.eval_split, + epoch=self.epoch, + epoch_time=epoch_time, + event="eval_completed", + avg_loss=train_trace_entry["avg_loss"], + ) + for f in self.post_epoch_trace_hooks: + f(self, trace_entry) + + # if validation metric is not present, try to compute it + metric_name = self.config.get("valid.metric") + if metric_name not in trace_entry: + trace_entry[metric_name] = eval( + self.config.get("valid.metric_expr"), + None, + dict(config=self.config, **trace_entry), + ) + + # write out trace + trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) + + # reset model and return metrics + if was_training: + self.model.train() + self.config.log("Finished evaluating train loss on " + self.eval_split + " split.") + + for f in self.post_valid_hooks: + f(self, trace_entry) + + return trace_entry + + + # HISTOGRAM COMPUTATION ############################################################### diff --git a/kge/job/train.py b/kge/job/train.py index 4e3608a05..00e5e8fa3 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -69,29 +69,26 @@ def __init__( self.batch_size: int = config.get("train.batch_size") self.device: str = self.config.get("job.device") self.train_split = config.get("train.split") - - self.config.check("train.trace_level", ["batch", "epoch"]) - self.trace_batch: bool = self.config.get("train.trace_level") == "batch" - self.epoch: int = 0 - self.valid_trace: List[Dict[str, Any]] = [] - valid_conf = config.clone() - valid_conf.set("job.type", "eval") - if self.config.get("valid.split") != "": - valid_conf.set("eval.split", self.config.get("valid.split")) - valid_conf.set("eval.trace_level", self.config.get("valid.trace_level")) - self.valid_job = EvaluationJob.create( - valid_conf, dataset, parent_job=self, model=self.model - ) self.config.check("train.trace_level", ["batch", "epoch"]) self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 - self.valid_trace: List[Dict[str, Any]] = [] - self.is_prepared: Dict[str, bool] = defaultdict(lambda: False) + self.is_prepared = False self.model.train() + if config.get("job.type") == "train": + valid_conf = config.clone() + valid_conf.set("job.type", "eval") + if self.config.get("valid.split") != "": + valid_conf.set("eval.split", self.config.get("valid.split")) + valid_conf.set("eval.trace_level", self.config.get("valid.trace_level")) + self.valid_job = EvaluationJob.create( + valid_conf, dataset, parent_job=self, model=self.model + ) + self.valid_trace: List[Dict[str, Any]] = [] + # attributes filled in by implementing classes - self.loader: Dict[str, DataLoader] = dict() - self.num_examples: Dict[str, int] = dict() + self.loader = None + self.num_examples = None self.type_str: Optional[str] = None #: Hooks run after training for an epoch. @@ -191,9 +188,7 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch( - split=self.train_split, echo_trace=True, do_backward=True - ) + trace_entry = self.run_epoch(echo_trace=True, forward_only=True) for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -300,16 +295,14 @@ def _load(self, checkpoint: Dict) -> str: ) ) - def run_epoch( - self, split: str, echo_trace: bool, do_backward: bool - ) -> Dict[str, Any]: + def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already - if not self.is_prepared[split]: - self._prepare(split=split) + if not self.is_prepared: + self._prepare() self.model.prepare_job(self) # let the model add some hooks - self.is_prepared[split] = True + self.is_prepared = True # variables that record various statitics sum_loss = 0.0 @@ -322,14 +315,14 @@ def run_epoch( optimizer_time = 0.0 # process each batch - for batch_index, batch in enumerate(self.loader[split]): + for batch_index, batch in enumerate(self.loader): for f in self.pre_batch_hooks: f(self) # process batch (preprocessing + forward pass + backward pass on loss) self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch, do_backward + batch_index, batch, forward_only ) sum_loss += batch_result.avg_loss * batch_result.size @@ -338,7 +331,7 @@ def run_epoch( penalties_torch = self.model.penalty( epoch=self.epoch, batch_index=batch_index, - num_batches=len(self.loader[split]), + num_batches=len(self.loader), batch=batch, ) batch_forward_time += time.time() @@ -347,7 +340,7 @@ def run_epoch( batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - if do_backward: + if forward_only: penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() @@ -389,7 +382,7 @@ def run_epoch( # update parameters batch_optimizer_time = 0 - if do_backward: + if forward_only: batch_optimizer_time = -time.time() self.optimizer.step() batch_optimizer_time += time.time() @@ -400,10 +393,10 @@ def run_epoch( "type": self.type_str, "scope": "batch", "epoch": self.epoch, - "split": split, + "split": self.train_split, "batch": batch_index, "size": batch_result.size, - "batches": len(self.loader[split]), + "batches": len(self.loader), "lr": [group["lr"] for group in self.optimizer.param_groups], "avg_loss": batch_result.avg_loss, "penalties": [p.item() for k, p in penalties_torch], @@ -459,15 +452,14 @@ def run_epoch( type=self.type_str, scope="epoch", epoch=self.epoch, - split=split, - batches=len(self.loader[split]), - size=self.num_examples[split], + split=self.train_split, + batches=len(self.loader), + size=self.num_examples, lr=[group["lr"] for group in self.optimizer.param_groups], - avg_loss=sum_loss / self.num_examples[split], - avg_penalty=sum_penalty / len(self.loader[split]), - avg_penalties={k: p / len(self.loader[split]) for k, p in sum_penalties.items()}, - avg_cost=sum_loss / self.num_examples[split] - + sum_penalty / len(self.loader[split]), + avg_loss=sum_loss / self.num_examples, + avg_penalty=sum_penalty / len(self.loader), + avg_penalties={k: p / len(self.loader) for k, p in sum_penalties.items()}, + avg_cost=sum_loss / self.num_examples + sum_penalty / len(self.loader), epoch_time=epoch_time, prepare_time=prepare_time, forward_time=forward_time, @@ -483,7 +475,7 @@ def run_epoch( ) return trace_entry - def _prepare(self, split: str): + def _prepare(self): """Prepare this job for running. Sets (at least) the `loader`, `num_examples`, and `type_str` attributes of this @@ -506,7 +498,7 @@ class _ProcessBatchResult: backward_time: float def _process_batch( - self, batch_index: int, batch, do_backward: bool + self, batch_index: int, batch, forward_only: bool ) -> "TrainingJob._ProcessBatchResult": "Run forward and backward pass on batch and return results." raise NotImplementedError @@ -582,7 +574,7 @@ def __init__(self, config, dataset, parent_job=None, model=None): for f in Job.job_created_hooks: f(self) - def _prepare(self, split: str): + def _prepare(self): from kge.indexing import index_KvsAll_to_torch # determine enabled query types @@ -593,28 +585,28 @@ def _prepare(self, split: str): ] # construct relevant data structures - self.num_examples[split] = 0 + self.num_examples = 0 for query_type in self.query_types: index_type = ( "sp_to_o" if query_type == "sp_" else ("so_to_p" if query_type == "s_o" else "po_to_s") ) - index = self.dataset.index(f"{split}_{index_type}") - self.num_examples[split] += len(index) - self.query_end_index[query_type + split] = self.num_examples[split] + index = self.dataset.index(f"{self.train_split}_{index_type}") + self.num_examples += len(index) + self.query_end_index[query_type] = self.num_examples # Convert indexes to pytorch tensors (as described above). ( - self.queries[query_type + split], - self.labels[query_type + split], - self.label_offsets[query_type + split], + self.queries[query_type], + self.labels[query_type], + self.label_offsets[query_type], ) = index_KvsAll_to_torch(index) # create dataloader - self.loader[split] = torch.utils.data.DataLoader( - range(self.num_examples[split]), - collate_fn=self._get_collate_fun(split), + self.loader = torch.utils.data.DataLoader( + range(self.num_examples), + collate_fn=self._get_collate_fun(), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -622,7 +614,7 @@ def _prepare(self, split: str): pin_memory=self.config.get("train.pin_memory"), ) - def _get_collate_fun(self, split: str): + def _get_collate_fun(self): # create the collate function def collate(batch): """For a batch of size n, returns a dictionary of: @@ -641,15 +633,11 @@ def collate(batch): for example_index in batch: start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type + split] + end = self.query_end_index[query_type] if example_index < end: example_index -= start - num_ones += self.label_offsets[query_type + split][ - example_index + 1 - ] - num_ones -= self.label_offsets[query_type + split][ - example_index - ] + num_ones += self.label_offsets[query_type][example_index + 1] + num_ones -= self.label_offsets[query_type][example_index] break start = end @@ -662,13 +650,13 @@ def collate(batch): for batch_index, example_index in enumerate(batch): start = 0 for query_type_index, query_type in enumerate(self.query_types): - end = self.query_end_index[query_type + split] + end = self.query_end_index[query_type] if example_index < end: example_index -= start query_type_indexes_batch[batch_index] = query_type_index - queries = self.queries[query_type + split] - label_offsets = self.label_offsets[query_type + split] - labels = self.labels[query_type + split] + queries = self.queries[query_type] + label_offsets = self.label_offsets[query_type] + labels = self.labels[query_type] if query_type == "sp_": query_col_1, query_col_2, target_col = S, P, O elif query_type == "s_o": @@ -710,7 +698,7 @@ def collate(batch): return collate def _process_batch( - self, batch_index, batch, do_backward: bool + self, batch_index, batch, forward_only: bool ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() @@ -782,7 +770,7 @@ def _process_batch( loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - if do_backward: + if forward_only: loss_value.backward() backward_time += time.time() @@ -819,13 +807,13 @@ def __init__(self, config, dataset, parent_job=None, model=None): for f in Job.job_created_hooks: f(self) - def _prepare(self, split: str): + def _prepare(self): """Construct dataloader""" - self.num_examples[split] = self.dataset.split(split).size(0) - self.loader[split] = torch.utils.data.DataLoader( - range(self.num_examples[split]), - collate_fn=self._get_collate_fun(split), + self.num_examples = self.dataset.split(self.train_split).size(0) + self.loader = torch.utils.data.DataLoader( + range(self.num_examples), + collate_fn=self._get_collate_fun(), shuffle=True, batch_size=self.batch_size, num_workers=self.config.get("train.num_workers"), @@ -833,7 +821,7 @@ def _prepare(self, split: str): pin_memory=self.config.get("train.pin_memory"), ) - def _get_collate_fun(self, split: str): + def _get_collate_fun(self): # create the collate function def collate(batch): """For a batch of size n, returns a tuple of: @@ -843,7 +831,7 @@ def collate(batch): in order S,P,O) """ - triples = self.dataset.split(split)[batch, :].long() + triples = self.dataset.split(self.train_split)[batch, :].long() # labels = torch.zeros((len(batch), self._sampler.num_negatives_total + 1)) # labels[:, 0] = 1 # labels = labels.view(-1) @@ -856,7 +844,7 @@ def collate(batch): return collate def _process_batch( - self, batch_index, batch, do_backward: bool + self, batch_index, batch, forward_only: bool ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() @@ -1027,7 +1015,7 @@ def _process_batch( # backward pass for this chunk backward_time -= time.time() - if do_backward: + if forward_only: loss_value_torch.backward() backward_time += time.time() @@ -1049,14 +1037,14 @@ def __init__(self, config, dataset, parent_job=None, model=None): for f in Job.job_created_hooks: f(self) - def _prepare(self, split: str): + def _prepare(self): """Construct dataloader""" - self.num_examples[split] = self.dataset.split(split).size(0) - self.loader[split] = torch.utils.data.DataLoader( - range(self.num_examples[split]), + self.num_examples = self.dataset.split(self.train_split).size(0) + self.loader = torch.utils.data.DataLoader( + range(self.num_examples), collate_fn=lambda batch: { - "triples": self.dataset.split(split)[batch, :].long() + "triples": self.dataset.split(self.train_split)[batch, :].long() }, shuffle=True, batch_size=self.batch_size, @@ -1066,7 +1054,7 @@ def _prepare(self, split: str): ) def _process_batch( - self, batch_index, batch, do_backward: bool + self, batch_index, batch, forward_only: bool ) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() @@ -1081,7 +1069,7 @@ def _process_batch( loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - if do_backward: + if forward_only: loss_value_sp.backward() backward_time += time.time() @@ -1092,7 +1080,7 @@ def _process_batch( loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - if do_backward: + if forward_only: loss_value_po.backward() backward_time += time.time() From 68f9832093b94243144bf95b82ab1f7dc1a6042f Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Fri, 22 May 2020 11:55:00 +0200 Subject: [PATCH 03/10] integrate training_loss eval --- kge/config-default.yaml | 3 ++- kge/job/eval.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 5ef718b2e..ec508986f 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -327,7 +327,8 @@ eval: # mean_reciprocal_rank_filtered_with_test. filter_with_test: True - # Type of evaluation (entity_ranking only at the moment) + # Type of evaluation (entity_ranking or only training_loss, while entity + # ranking runs training_loss as well.) type: entity_ranking # Compute Hits@K for these choices of K diff --git a/kge/job/eval.py b/kge/job/eval.py index f2f218055..af1241af1 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -90,6 +90,10 @@ def create(config, dataset, parent_job=None, model=None): return EntityPairRankingJob( config, dataset, parent_job=parent_job, model=model ) + elif config.get("eval.type") == "training_loss": + return EvalTrainingLossJob( + config, dataset, parent_job=parent_job, model=model + ) else: raise ValueError("eval.type") @@ -171,7 +175,7 @@ def run(self) -> dict: ) # compute trace trace_entry = dict( - type="eval_train_loss_job", + type="training_loss", scope="epoch", split=self.eval_split, epoch=self.epoch, @@ -197,7 +201,9 @@ def run(self) -> dict: # reset model and return metrics if was_training: self.model.train() - self.config.log("Finished evaluating train loss on " + self.eval_split + " split.") + self.config.log( + "Finished evaluating train loss on " + self.eval_split + " split." + ) for f in self.post_valid_hooks: f(self, trace_entry) @@ -205,7 +211,6 @@ def run(self) -> dict: return trace_entry - # HISTOGRAM COMPUTATION ############################################################### From 6f8e5f4e668789d66b10dce3671a3d3c163c00de Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 23 May 2020 00:32:53 +0200 Subject: [PATCH 04/10] Adress reviews --- kge/job/eval.py | 2 +- kge/job/train.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index af1241af1..c8239a064 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -171,7 +171,7 @@ def run(self) -> dict: epoch_time += time.time() train_trace_entry = self.eval_train_loss_job.run_epoch( - echo_trace=False, forward_only=False + echo_trace=False, forward_only=True ) # compute trace trace_entry = dict( diff --git a/kge/job/train.py b/kge/job/train.py index 00e5e8fa3..f05a91e76 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -188,7 +188,7 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch(echo_trace=True, forward_only=True) + trace_entry = self.run_epoch(echo_trace=True, forward_only=False) for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -340,7 +340,7 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - if forward_only: + if not forward_only: penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() @@ -382,7 +382,7 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: # update parameters batch_optimizer_time = 0 - if forward_only: + if not forward_only: batch_optimizer_time = -time.time() self.optimizer.step() batch_optimizer_time += time.time() @@ -770,7 +770,7 @@ def _process_batch( loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - if forward_only: + if not forward_only: loss_value.backward() backward_time += time.time() @@ -1015,7 +1015,7 @@ def _process_batch( # backward pass for this chunk backward_time -= time.time() - if forward_only: + if not forward_only: loss_value_torch.backward() backward_time += time.time() @@ -1069,7 +1069,7 @@ def _process_batch( loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - if forward_only: + if not forward_only: loss_value_sp.backward() backward_time += time.time() @@ -1080,7 +1080,7 @@ def _process_batch( loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - if forward_only: + if not forward_only: loss_value_po.backward() backward_time += time.time() From c72fc22d63f78f62b61694b41207ff30db0a91c3 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 23 May 2020 02:45:41 +0200 Subject: [PATCH 05/10] Small revision of EvaluationJob: run() is implemented now in EvaluationJob with standard stuff that has to be done for every EvaluationJob and then calls self._run() Shift EntityRanking specific stuff to there Make EvalTrainingLossJob an instance variable of EvaluationJob and run it as post_epoch_trace_hook Add self.verbose to EvaluationJob and use it Adress code review --- kge/job/entity_ranking.py | 53 ++++------------ kge/job/eval.py | 129 ++++++++++++++++++++++++-------------- kge/job/train.py | 32 +++++----- 3 files changed, 111 insertions(+), 103 deletions(-) diff --git a/kge/job/entity_ranking.py b/kge/job/entity_ranking.py index 19c685800..299551bf9 100644 --- a/kge/job/entity_ranking.py +++ b/kge/job/entity_ranking.py @@ -1,5 +1,6 @@ import math import time +from typing import Dict, Any import torch import kge.job @@ -13,18 +14,23 @@ class EntityRankingJob(EvaluationJob): def __init__(self, config: Config, dataset: Dataset, parent_job, model): super().__init__(config, dataset, parent_job, model) - self.is_prepared = False if self.__class__ == EntityRankingJob: for f in Job.job_created_hooks: f(self) + max_k = min( + self.dataset.num_entities(), max(self.config.get("eval.hits_at_k_s")) + ) + self.hits_at_k_s = list( + filter(lambda x: x <= max_k, self.config.get("eval.hits_at_k_s")) + ) + self.filter_with_test = config.get("eval.filter_with_test") + + def _prepare(self): """Construct all indexes needed to run.""" - if self.is_prepared: - return - # create data and precompute indexes self.triples = self.dataset.split(self.config.get("eval.split")) for split in self.filter_splits: @@ -75,16 +81,8 @@ def _collate(self, batch): return batch, label_coords, test_label_coords @torch.no_grad() - def run(self) -> dict: - self._prepare() - - was_training = self.model.training - self.model.eval() - self.config.log( - "Evaluating on " - + self.eval_split - + " data (epoch {})...".format(self.epoch) - ) + def _run(self) -> Dict[str, Any]: + num_entities = self.dataset.num_entities() # we also filter with test data if requested @@ -386,11 +384,6 @@ def merge_hist(target_hists, source_hists): ) epoch_time += time.time() - train_trace_entry = self.eval_train_loss_job.run_epoch( - echo_trace=False, forward_only=False - ) - metrics.update(avg_loss=train_trace_entry["avg_loss"],) - # compute trace trace_entry = dict( type="entity_ranking", @@ -404,28 +397,6 @@ def merge_hist(target_hists, source_hists): event="eval_completed", **metrics, ) - for f in self.post_epoch_trace_hooks: - f(self, trace_entry) - - # if validation metric is not present, try to compute it - metric_name = self.config.get("valid.metric") - if metric_name not in trace_entry: - trace_entry[metric_name] = eval( - self.config.get("valid.metric_expr"), - None, - dict(config=self.config, **trace_entry), - ) - - # write out trace - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) - - # reset model and return metrics - if was_training: - self.model.train() - self.config.log("Finished evaluating on " + self.eval_split + " split.") - - for f in self.post_valid_hooks: - f(self, trace_entry) return trace_entry diff --git a/kge/job/eval.py b/kge/job/eval.py index c8239a064..1ab8049d0 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -1,4 +1,5 @@ import time +from typing import Any, Dict import torch from kge import Config, Dataset @@ -19,12 +20,6 @@ def __init__(self, config, dataset, parent_job, model): self.model = model self.batch_size = config.get("eval.batch_size") self.device = self.config.get("job.device") - max_k = min( - self.dataset.num_entities(), max(self.config.get("eval.hits_at_k_s")) - ) - self.hits_at_k_s = list( - filter(lambda x: x <= max_k, self.config.get("eval.hits_at_k_s")) - ) self.config.check("train.trace_level", ["example", "batch", "epoch"]) self.trace_examples = self.config.get("eval.trace_level") == "example" self.trace_batch = ( @@ -34,14 +29,10 @@ def __init__(self, config, dataset, parent_job, model): self.filter_splits = self.config.get("eval.filter_splits") if self.eval_split not in self.filter_splits: self.filter_splits.append(self.eval_split) - self.filter_with_test = config.get("eval.filter_with_test") self.epoch = -1 - train_job_on_eval_split_config = config.clone() - train_job_on_eval_split_config.set("train.split", self.eval_split) - self.eval_train_loss_job = TrainingJob.create( - config=train_job_on_eval_split_config, parent_job=self, dataset=dataset - ) + self.verbose = True + self.is_prepared = False #: Hooks run after training for an epoch. #: Signature: job, trace_entry @@ -73,6 +64,22 @@ def __init__(self, config, dataset, parent_job, model): if config.get("eval.metrics_per.argument_frequency"): self.hist_hooks.append(hist_per_frequency_percentile) + # Add the training loss as a default to every evaluation job + # TODO: create AggregatingEvaluationsJob that runs and aggregates a list + # of EvaluationAjobs, such that users can configure combinations of + # EvalJobs themselves. Then this can be removed. + # See https://github.com/uma-pi1/kge/issues/102 + if not isinstance(self, EvalTrainingLossJob): + self.eval_train_loss_job = EvalTrainingLossJob( + config, dataset, parent_job=self, model=model + ) + self.eval_train_loss_job.verbose = False + self.post_epoch_trace_hooks.append( + lambda job, trace: trace.update( + avg_loss=self.eval_train_loss_job.run()["avg_loss"] + ) + ) + # all done, run job_created_hooks if necessary if self.__class__ == EvaluationJob: for f in Job.job_created_hooks: @@ -97,7 +104,55 @@ def create(config, dataset, parent_job=None, model=None): else: raise ValueError("eval.type") - def run(self) -> dict: + def _prepare(self): + """Prepare this job for running. Guaranteed to be called exactly once + """ + raise NotImplementedError + + def run(self) -> Dict[str, Any]: + + if not self.is_prepared: + self._prepare() + self.model.prepare_job(self) # let the model add some hooks + self.is_prepared = True + + was_training = self.model.training + self.model.eval() + self.config.log( + "Evaluating on " + + self.eval_split + + " data (epoch {})...".format(self.epoch), + echo=self.verbose + ) + + trace_entry = self._run() + + # if validation metric is not present, try to compute it + metric_name = self.config.get("valid.metric") + if metric_name not in trace_entry: + trace_entry[metric_name] = eval( + self.config.get("valid.metric_expr"), + None, + dict(config=self.config, **trace_entry), + ) + + for f in self.post_epoch_trace_hooks: + f(self, trace_entry) + + # write out trace + trace_entry = self.trace(**trace_entry, echo=self.verbose, echo_prefix=" ", log=True) + + # reset model and return metrics + if was_training: + self.model.train() + self.config.log("Finished evaluating on " + self.eval_split + " split.", echo=self.verbose) + + for f in self.post_valid_hooks: + f(self, trace_entry) + + return trace_entry + + def _run(self) -> Dict[str, Any]: """ Compute evaluation metrics, output results to trace file """ raise NotImplementedError @@ -150,28 +205,30 @@ class EvalTrainingLossJob(EvaluationJob): def __init__(self, config: Config, dataset: Dataset, parent_job, model): super().__init__(config, dataset, parent_job, model) - self.is_prepared = False + self.is_prepared = True + + train_job_on_eval_split_config = config.clone() + train_job_on_eval_split_config.set("train.split", self.eval_split) + self._train_job = TrainingJob.create( + config=train_job_on_eval_split_config, parent_job=self, dataset=dataset + ) + + self._train_job_verbose = False if self.__class__ == EvalTrainingLossJob: for f in Job.job_created_hooks: f(self) @torch.no_grad() - def run(self) -> dict: + def _run(self) -> Dict[str, Any]: epoch_time = -time.time() - was_training = self.model.training - self.model.eval() - self.config.log( - "Evaluating on " - + self.eval_split - + " data (epoch {})...".format(self.epoch) - ) + self.epoch = self.parent_job.epoch epoch_time += time.time() - train_trace_entry = self.eval_train_loss_job.run_epoch( - echo_trace=False, forward_only=True + train_trace_entry = self._train_job.run_epoch( + verbose=self._train_job_verbose, forward_only=True ) # compute trace trace_entry = dict( @@ -183,30 +240,6 @@ def run(self) -> dict: event="eval_completed", avg_loss=train_trace_entry["avg_loss"], ) - for f in self.post_epoch_trace_hooks: - f(self, trace_entry) - - # if validation metric is not present, try to compute it - metric_name = self.config.get("valid.metric") - if metric_name not in trace_entry: - trace_entry[metric_name] = eval( - self.config.get("valid.metric_expr"), - None, - dict(config=self.config, **trace_entry), - ) - - # write out trace - trace_entry = self.trace(**trace_entry, echo=True, echo_prefix=" ", log=True) - - # reset model and return metrics - if was_training: - self.model.train() - self.config.log( - "Finished evaluating train loss on " + self.eval_split + " split." - ) - - for f in self.post_valid_hooks: - f(self, trace_entry) return trace_entry diff --git a/kge/job/train.py b/kge/job/train.py index f05a91e76..7449e0a9c 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -188,7 +188,7 @@ def run(self) -> None: # start a new epoch self.epoch += 1 self.config.log("Starting epoch {}...".format(self.epoch)) - trace_entry = self.run_epoch(echo_trace=True, forward_only=False) + trace_entry = self.run_epoch() for f in self.post_epoch_hooks: f(self, trace_entry) self.config.log("Finished epoch {}.".format(self.epoch)) @@ -295,7 +295,7 @@ def _load(self, checkpoint: Dict) -> str: ) ) - def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: + def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already @@ -412,7 +412,6 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: self.trace(**batch_trace, event="batch_completed") if echo_trace: self.config.print( - ( "\r" # go back + "{} batch{: " + str(1 + int(math.ceil(math.log10(len(self.loader))))) @@ -471,7 +470,7 @@ def run_epoch(self, echo_trace: bool, forward_only: bool) -> Dict[str, Any]: for f in self.post_epoch_trace_hooks: f(self, trace_entry) trace_entry = self.trace( - **trace_entry, echo=echo_trace, echo_prefix=" ", log=True + **trace_entry, echo=verbose, echo_prefix=" ", log=True ) return trace_entry @@ -552,6 +551,21 @@ def __init__(self, config, dataset, parent_job=None, model=None): ) ) + self.queries = None + self.labels = None + self.label_offsets = None + self.query_end_index = None + + config.log("Initializing 1-to-N training job...") + self.type_str = "KvsAll" + + if self.__class__ == TrainingJobKvsAll: + for f in Job.job_created_hooks: + f(self) + + def _prepare(self): + from kge.indexing import index_KvsAll_to_torch + #' for each query type: list of queries self.queries = {} @@ -567,16 +581,6 @@ def __init__(self, config, dataset, parent_job=None, model=None): #' example of that type in the list of all examples self.query_end_index = {} - config.log("Initializing 1-to-N training job...") - self.type_str = "KvsAll" - - if self.__class__ == TrainingJobKvsAll: - for f in Job.job_created_hooks: - f(self) - - def _prepare(self): - from kge.indexing import index_KvsAll_to_torch - # determine enabled query types self.query_types = [ key From da9bf6835b994769356ac7f14f029769ddf301d3 Mon Sep 17 00:00:00 2001 From: Rainer Gemulla Date: Mon, 25 May 2020 11:28:02 +0200 Subject: [PATCH 06/10] Update config-default.yaml --- kge/config-default.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kge/config-default.yaml b/kge/config-default.yaml index ec508986f..9e87f6365 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -327,8 +327,8 @@ eval: # mean_reciprocal_rank_filtered_with_test. filter_with_test: True - # Type of evaluation (entity_ranking or only training_loss, while entity - # ranking runs training_loss as well.) + # Type of evaluation (entity_ranking or training_loss). Currently, + # entity_ranking runs training_loss as well. type: entity_ranking # Compute Hits@K for these choices of K From 97b3904fbe0aaa4593a3ae7db0985228eda3a694 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Mon, 25 May 2020 14:29:23 +0200 Subject: [PATCH 07/10] Adress code review eval.py: set negativ sampling filtering split to train.split rename to TrainingLossEvaluationJob train.py: initialize only things for train if TrainingJob is used for training more ifs - --- kge/job/eval.py | 11 ++++++----- kge/job/train.py | 17 +++++++++++++---- kge/util/sampler.py | 18 +++++++++++++----- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index 1ab8049d0..9143eb2b3 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -69,8 +69,8 @@ def __init__(self, config, dataset, parent_job, model): # of EvaluationAjobs, such that users can configure combinations of # EvalJobs themselves. Then this can be removed. # See https://github.com/uma-pi1/kge/issues/102 - if not isinstance(self, EvalTrainingLossJob): - self.eval_train_loss_job = EvalTrainingLossJob( + if not isinstance(self, TrainingLossEvaluationJob): + self.eval_train_loss_job = TrainingLossEvaluationJob( config, dataset, parent_job=self, model=model ) self.eval_train_loss_job.verbose = False @@ -98,7 +98,7 @@ def create(config, dataset, parent_job=None, model=None): config, dataset, parent_job=parent_job, model=model ) elif config.get("eval.type") == "training_loss": - return EvalTrainingLossJob( + return TrainingLossEvaluationJob( config, dataset, parent_job=parent_job, model=model ) else: @@ -200,7 +200,7 @@ def create_from( return super().create_from(checkpoint, new_config, dataset, parent_job) -class EvalTrainingLossJob(EvaluationJob): +class TrainingLossEvaluationJob(EvaluationJob): """ Entity ranking evaluation protocol """ def __init__(self, config: Config, dataset: Dataset, parent_job, model): @@ -209,13 +209,14 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config = config.clone() train_job_on_eval_split_config.set("train.split", self.eval_split) + train_job_on_eval_split_config.set("negative_sampling.filtering.split", self.config.get("train.split")) self._train_job = TrainingJob.create( config=train_job_on_eval_split_config, parent_job=self, dataset=dataset ) self._train_job_verbose = False - if self.__class__ == EvalTrainingLossJob: + if self.__class__ == TrainingLossEvaluationJob: for f in Job.job_created_hooks: f(self) diff --git a/kge/job/train.py b/kge/job/train.py index 7449e0a9c..7a8b77181 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -73,9 +73,14 @@ def __init__( self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 self.is_prepared = False - self.model.train() if config.get("job.type") == "train": + + self.model.train() + + self.optimizer = KgeOptimizer.create(config, self.model) + self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer) + valid_conf = config.clone() valid_conf.set("job.type", "eval") if self.config.get("valid.split") != "": @@ -320,7 +325,8 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str f(self) # process batch (preprocessing + forward pass + backward pass on loss) - self.optimizer.zero_grad() + if not forward_only: + self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( batch_index, batch, forward_only ) @@ -454,7 +460,6 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str split=self.train_split, batches=len(self.loader), size=self.num_examples, - lr=[group["lr"] for group in self.optimizer.param_groups], avg_loss=sum_loss / self.num_examples, avg_penalty=sum_penalty / len(self.loader), avg_penalties={k: p / len(self.loader) for k, p in sum_penalties.items()}, @@ -467,6 +472,10 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str other_time=other_time, event="epoch_completed", ) + if not forward_only: + trace_entry.update( + lr=[group["lr"] for group in self.optimizer.param_groups], + ) for f in self.post_epoch_trace_hooks: f(self, trace_entry) trace_entry = self.trace( @@ -636,7 +645,7 @@ def collate(batch): num_ones = 0 for example_index in batch: start = 0 - for query_type_index, query_type in enumerate(self.query_types): + for query_type in self.query_types: end = self.query_end_index[query_type] if example_index < end: example_index -= start diff --git a/kge/util/sampler.py b/kge/util/sampler.py index 5f1ab31b8..56173e9fb 100644 --- a/kge/util/sampler.py +++ b/kge/util/sampler.py @@ -150,7 +150,12 @@ def _filter_and_resample( cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols] for i in range(positive_triples.size(0)): - positives = index.get((pairs[i][0].item(), pairs[i][1].item())).numpy() + pair = (pairs[i][0].item(), pairs[i][1].item()) + positives = ( + index.get(pair).numpy() + if pair in index + else torch.IntTensor([]).numpy() + ) # indices of samples that have to be sampled again resample_idx = where_in(negative_samples[i].numpy(), positives) # number of new samples needed @@ -231,9 +236,7 @@ def _sample_shared( # contain its positive, drop that positive. For all other rows, drop a random # position. shared_samples_index = {s: j for j, s in enumerate(shared_samples)} - replacement = np.random.choice( - num_distinct + 1, batch_size, replace=True - ) + replacement = np.random.choice(num_distinct + 1, batch_size, replace=True) drop = torch.tensor( [ shared_samples_index.get(s, replacement[i]) @@ -279,13 +282,18 @@ def _filter_and_resample_fast( positives_index = numba.typed.Dict() for i in range(batch_size): pair = (pairs[i][0], pairs[i][1]) - positives_index[pair] = index.get(pair).numpy() + positives_index[pair] = ( + index.get(pair).numpy() + if pair in index + else torch.IntTensor([]).numpy() + ) negative_samples = negative_samples.numpy() KgeUniformSampler._filter_and_resample_numba( negative_samples, pairs, positives_index, batch_size, int(voc_size), ) return torch.tensor(negative_samples, dtype=torch.int64) + @staticmethod @numba.njit def _filter_and_resample_numba( negative_samples, pairs, positives_index, batch_size, voc_size From 46063a71195f19c8c4e8d29558bbb27634eec16f Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Tue, 26 May 2020 20:44:46 +0200 Subject: [PATCH 08/10] Adress code review eval.py: set negativ sampling filtering split to train.split rename to TrainingLossEvaluationJob train.py: initialize only things for train if TrainingJob is used for training more ifs --- kge/job/eval.py | 4 +-- kge/job/train.py | 85 ++++++++++++++++++++++++++++++------------------ 2 files changed, 55 insertions(+), 34 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index 9143eb2b3..7f7825a82 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -211,7 +211,7 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config.set("train.split", self.eval_split) train_job_on_eval_split_config.set("negative_sampling.filtering.split", self.config.get("train.split")) self._train_job = TrainingJob.create( - config=train_job_on_eval_split_config, parent_job=self, dataset=dataset + config=train_job_on_eval_split_config, parent_job=self, dataset=dataset, initialize_for_forward_only=True, ) self._train_job_verbose = False @@ -229,7 +229,7 @@ def _run(self) -> Dict[str, Any]: epoch_time += time.time() train_trace_entry = self._train_job.run_epoch( - verbose=self._train_job_verbose, forward_only=True + verbose=self._train_job_verbose ) # compute trace trace_entry = dict( diff --git a/kge/job/train.py b/kge/job/train.py index 7a8b77181..61386c84c 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -53,7 +53,12 @@ class TrainingJob(Job): """ def __init__( - self, config: Config, dataset: Dataset, parent_job: Job = None, model=None + self, + config: Config, + dataset: Dataset, + parent_job: Job = None, + model = None, + initialize_for_forward_only=False, ) -> None: from kge.job import EvaluationJob @@ -73,8 +78,9 @@ def __init__( self.trace_batch: bool = self.config.get("train.trace_level") == "batch" self.epoch: int = 0 self.is_prepared = False + self.is_forward_only = initialize_for_forward_only - if config.get("job.type") == "train": + if not initialize_for_forward_only: self.model.train() @@ -128,21 +134,36 @@ def __init__( @staticmethod def create( - config: Config, dataset: Dataset, parent_job: Job = None, model=None + config: Config, + dataset: Dataset, + parent_job: Job = None, + model = None, + initialize_for_forward_only=False, ) -> "TrainingJob": """Factory method to create a training job.""" if config.get("train.type") == "KvsAll": - return TrainingJobKvsAll(config, dataset, parent_job, model=model) + return TrainingJobKvsAll( + config, dataset, parent_job, model, initialize_for_forward_only, + ) elif config.get("train.type") == "negative_sampling": - return TrainingJobNegativeSampling(config, dataset, parent_job, model=model) + return TrainingJobNegativeSampling( + config, dataset, parent_job, model, initialize_for_forward_only, + ) elif config.get("train.type") == "1vsAll": - return TrainingJob1vsAll(config, dataset, parent_job, model=model) + return TrainingJob1vsAll( + config, dataset, parent_job, model, initialize_for_forward_only, + ) else: # perhaps TODO: try class with specified name -> extensibility raise ValueError("train.type") def run(self) -> None: """Start/resume the training job and run to completion.""" + if self.is_forward_only: + raise Exception( + f"{self.__class__.__name__} was initialized for forward only. You can only call run_epoch()" + ) + self.config.log("Starting training...") checkpoint_every = self.config.get("train.checkpoint.every") checkpoint_keep = self.config.get("train.checkpoint.keep") @@ -300,7 +321,7 @@ def _load(self, checkpoint: Dict) -> str: ) ) - def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str, Any]: + def run_epoch(self, verbose: bool = True) -> Dict[str, Any]: "Runs an epoch and returns a trace entry." # prepare the job is not done already @@ -325,10 +346,10 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str f(self) # process batch (preprocessing + forward pass + backward pass on loss) - if not forward_only: + if not self.is_forward_only: self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch, forward_only + batch_index, batch, self.is_forward_only ) sum_loss += batch_result.avg_loss * batch_result.size @@ -346,7 +367,7 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str batch_backward_time = batch_result.backward_time - time.time() penalty = 0.0 for index, (penalty_key, penalty_value_torch) in enumerate(penalties_torch): - if not forward_only: + if not self.is_forward_only: penalty_value_torch.backward() penalty += penalty_value_torch.item() sum_penalties[penalty_key] += penalty_value_torch.item() @@ -388,7 +409,7 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str # update parameters batch_optimizer_time = 0 - if not forward_only: + if not self.is_forward_only: batch_optimizer_time = -time.time() self.optimizer.step() batch_optimizer_time += time.time() @@ -472,7 +493,7 @@ def run_epoch(self, verbose: bool= True, forward_only: bool = False) -> Dict[str other_time=other_time, event="epoch_completed", ) - if not forward_only: + if not self.is_forward_only: trace_entry.update( lr=[group["lr"] for group in self.optimizer.param_groups], ) @@ -506,7 +527,7 @@ class _ProcessBatchResult: backward_time: float def _process_batch( - self, batch_index: int, batch, forward_only: bool + self, batch_index: int, batch ) -> "TrainingJob._ProcessBatchResult": "Run forward and backward pass on batch and return results." raise NotImplementedError @@ -523,8 +544,10 @@ class TrainingJobKvsAll(TrainingJob): - Example: a query + its labels, e.g., (John,marriedTo), [Jane] """ - def __init__(self, config, dataset, parent_job=None, model=None): - super().__init__(config, dataset, parent_job, model=model) + def __init__( + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) self.label_smoothing = config.check_range( "KvsAll.label_smoothing", float("-inf"), 1.0, max_inclusive=False ) @@ -710,9 +733,7 @@ def collate(batch): return collate - def _process_batch( - self, batch_index, batch, forward_only: bool - ) -> TrainingJob._ProcessBatchResult: + def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() queries_batch = batch["queries"].to(self.device) @@ -783,7 +804,7 @@ def _process_batch( loss_value_total = loss_value.item() forward_time += time.time() backward_time -= time.time() - if not forward_only: + if not self.is_forward_only: loss_value.backward() backward_time += time.time() @@ -794,8 +815,10 @@ def _process_batch( class TrainingJobNegativeSampling(TrainingJob): - def __init__(self, config, dataset, parent_job=None, model=None): - super().__init__(config, dataset, parent_job, model=model) + def __init__( + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) self._sampler = KgeSampler.create(config, "negative_sampling", dataset) self._implementation = self.config.check( "negative_sampling.implementation", ["triple", "all", "batch", "auto"], @@ -856,9 +879,7 @@ def collate(batch): return collate - def _process_batch( - self, batch_index, batch, forward_only: bool - ) -> TrainingJob._ProcessBatchResult: + def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() batch_triples = batch["triples"].to(self.device) @@ -1028,7 +1049,7 @@ def _process_batch( # backward pass for this chunk backward_time -= time.time() - if not forward_only: + if not self.is_forward_only: loss_value_torch.backward() backward_time += time.time() @@ -1041,8 +1062,10 @@ def _process_batch( class TrainingJob1vsAll(TrainingJob): """Samples SPO pairs and queries sp_ and _po, treating all other entities as negative.""" - def __init__(self, config, dataset, parent_job=None, model=None): - super().__init__(config, dataset, parent_job, model=model) + def __init__( + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False + ): + super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) config.log("Initializing spo training job...") self.type_str = "1vsAll" @@ -1066,9 +1089,7 @@ def _prepare(self): pin_memory=self.config.get("train.pin_memory"), ) - def _process_batch( - self, batch_index, batch, forward_only: bool - ) -> TrainingJob._ProcessBatchResult: + def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: # prepare prepare_time = -time.time() triples = batch["triples"].to(self.device) @@ -1082,7 +1103,7 @@ def _process_batch( loss_value = loss_value_sp.item() forward_time += time.time() backward_time = -time.time() - if not forward_only: + if not self.is_forward_only: loss_value_sp.backward() backward_time += time.time() @@ -1093,7 +1114,7 @@ def _process_batch( loss_value += loss_value_po.item() forward_time += time.time() backward_time -= time.time() - if not forward_only: + if not self.is_forward_only: loss_value_po.backward() backward_time += time.time() From d965e9ac02e31465b4ecb44cb3e104e6f6e18ef3 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Wed, 27 May 2020 01:51:44 +0200 Subject: [PATCH 09/10] Adress code review eval.py: set splits train.py: use self.is_forward_only use dataset.index to load label_index for KvsAll sampler.py: rename filtering splits use dataset.index to load label_index for filtering splits config-default.yaml: add label_splits config for KvsAll indexing.py + misc.py: add index and helper funcs to merge KvsAll dicts --- kge/config-default.yaml | 6 +++++- kge/indexing.py | 19 +++++++++++++++++++ kge/job/eval.py | 34 +++++++++++++++++++++++++--------- kge/job/train.py | 11 +++++++++-- kge/misc.py | 31 +++++++++++++++++++++++++++++++ kge/util/sampler.py | 23 ++++++++++++++--------- 6 files changed, 103 insertions(+), 21 deletions(-) diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 9e87f6365..fe97cfb4f 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -235,6 +235,10 @@ KvsAll: s_o: False _po: True + # Dataset splits from which the labels for a query are taken from. Default: If + # nothing is specified, then the train split is used. + label_splits: [] + # Options for negative sampling training (train.type=="negative_sampling") negative_sampling: # Negative sampler to use @@ -264,7 +268,7 @@ negative_sampling: p: False # as above o: False # as above - split: '' # split containing the positives; default is train.split + splits: [] # splits containing the positives; default is train.split # Implementation to use for filtering. # standard: use slow generic implementation, available for all samplers diff --git a/kge/indexing.py b/kge/indexing.py index b4e136ef4..e394323c6 100644 --- a/kge/indexing.py +++ b/kge/indexing.py @@ -2,6 +2,7 @@ from collections import defaultdict, OrderedDict import numba import numpy as np +from kge.misc import powerset, merge_dicts_of_1dim_torch_tensors def _group_by(keys, values) -> dict: @@ -222,6 +223,15 @@ def _invert_ids(dataset, obj: str): dataset.config.log(f"Indexed {len(inv)} {obj} ids", prefix=" ") +def merge_KvsAll_indexes(dataset, split, key): + value = dict([("sp", "o"), ("po", "s"), ("so", "p")])[key] + split_combi_str = "_".join(sorted(split)) + index_name = f"{split_combi_str}_{key}_to_{value}" + indexes = [dataset.index(f"{_split}_{key}_to_{value}") for _split in split] + dataset._indexes[index_name] = merge_dicts_of_1dim_torch_tensors(indexes) + return dataset._indexes[index_name] + + def create_default_index_functions(dataset: "Dataset"): for split in dataset.files_of_type("triples"): for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]: @@ -229,6 +239,15 @@ def create_default_index_functions(dataset: "Dataset"): dataset.index_functions[f"{split}_{key}_to_{value}"] = IndexWrapper( index_KvsAll, split=split, key=key ) + # create all combinations of splits of length 2 and 3 + for split_combi in powerset(dataset.files_of_type("triples"), [2, 3]): + for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]: + split_combi_str = "_".join(sorted(split_combi)) + index_name = f"{split_combi_str}_{key}_to_{value}" + dataset.index_functions[index_name] = IndexWrapper( + merge_KvsAll_indexes, split=split_combi, key=key + ) + dataset.index_functions["relation_types"] = index_relation_types dataset.index_functions["relations_per_type"] = index_relation_types dataset.index_functions["frequency_percentiles"] = index_frequency_percentiles diff --git a/kge/job/eval.py b/kge/job/eval.py index 7f7825a82..0787bebd7 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -122,7 +122,7 @@ def run(self) -> Dict[str, Any]: "Evaluating on " + self.eval_split + " data (epoch {})...".format(self.epoch), - echo=self.verbose + echo=self.verbose, ) trace_entry = self._run() @@ -140,12 +140,16 @@ def run(self) -> Dict[str, Any]: f(self, trace_entry) # write out trace - trace_entry = self.trace(**trace_entry, echo=self.verbose, echo_prefix=" ", log=True) + trace_entry = self.trace( + **trace_entry, echo=self.verbose, echo_prefix=" ", log=True + ) # reset model and return metrics if was_training: self.model.train() - self.config.log("Finished evaluating on " + self.eval_split + " split.", echo=self.verbose) + self.config.log( + "Finished evaluating on " + self.eval_split + " split.", echo=self.verbose + ) for f in self.post_valid_hooks: f(self, trace_entry) @@ -209,11 +213,25 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config = config.clone() train_job_on_eval_split_config.set("train.split", self.eval_split) - train_job_on_eval_split_config.set("negative_sampling.filtering.split", self.config.get("train.split")) + train_job_on_eval_split_config.set( + "negative_sampling.filtering.splits", + [self.config.get("train.split"), self.eval_split] + ["valid"] + if self.eval_split == "test" + else [], + ) + train_job_on_eval_split_config.set( + "KvsAll.label_splits", + [self.config.get("train.split"), self.eval_split] + ["valid"] + if self.eval_split == "test" + else [], + ) self._train_job = TrainingJob.create( - config=train_job_on_eval_split_config, parent_job=self, dataset=dataset, initialize_for_forward_only=True, + config=train_job_on_eval_split_config, + parent_job=self, + dataset=dataset, + initialize_for_forward_only=True, ) - + self._train_job.model = model self._train_job_verbose = False if self.__class__ == TrainingLossEvaluationJob: @@ -228,9 +246,7 @@ def _run(self) -> Dict[str, Any]: self.epoch = self.parent_job.epoch epoch_time += time.time() - train_trace_entry = self._train_job.run_epoch( - verbose=self._train_job_verbose - ) + train_trace_entry = self._train_job.run_epoch(verbose=self._train_job_verbose) # compute trace trace_entry = dict( type="training_loss", diff --git a/kge/job/train.py b/kge/job/train.py index 61386c84c..fa21ca942 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -349,7 +349,7 @@ def run_epoch(self, verbose: bool = True) -> Dict[str, Any]: if not self.is_forward_only: self.optimizer.zero_grad() batch_result: TrainingJob._ProcessBatchResult = self._process_batch( - batch_index, batch, self.is_forward_only + batch_index, batch ) sum_loss += batch_result.avg_loss * batch_result.size @@ -583,6 +583,10 @@ def __init__( ) ) + self.label_splits = set(self.config.get("KvsAll.label_splits")) | set( + [self.train_split] + ) + self.queries = None self.labels = None self.label_offsets = None @@ -628,7 +632,10 @@ def _prepare(self): if query_type == "sp_" else ("so_to_p" if query_type == "s_o" else "po_to_s") ) - index = self.dataset.index(f"{self.train_split}_{index_type}") + split_combi_str = "_".join(sorted(self.label_splits)) + label_index = self.dataset.index(f"{split_combi_str}_{index_type}") + query_index = self.dataset.index(f"{self.train_split}_{index_type}") + index = {k: v for k, v in label_index.items() if k in query_index} self.num_examples += len(index) self.query_end_index[query_type] = self.num_examples diff --git a/kge/misc.py b/kge/misc.py index 90db3750c..01613d37b 100644 --- a/kge/misc.py +++ b/kge/misc.py @@ -1,3 +1,5 @@ +import torch +from itertools import chain, combinations from typing import List from torch import nn as nn @@ -6,6 +8,7 @@ import inspect import subprocess + def is_number(s, number_type): """ Returns True is string is a number. """ try: @@ -15,6 +18,33 @@ def is_number(s, number_type): return False +def merge_dicts_of_1dim_torch_tensors(keys_from_dols, values_from_dols=None): + if values_from_dols is None: + values_from_dols = keys_from_dols + keys = set(chain.from_iterable([d.keys() for d in keys_from_dols])) + no = torch.tensor([]).int() + return dict( + ( + k, + torch.cat([d.get(k, no) for d in values_from_dols]).sort()[0], + ) + for k in keys + ) + + +def powerset(iterable, filter_lens=None): + s = list(iterable) + return list( + map( + sorted, + filter( + lambda l: len(l) in filter_lens if filter_lens else True, + chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)), + ), + ) + ) + + # from https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script def get_git_revision_hash(): try: @@ -69,6 +99,7 @@ def is_exe(fpath): def kge_base_dir(): import kge + return os.path.abspath(filename_in_module(kge, "..")) diff --git a/kge/util/sampler.py b/kge/util/sampler.py index 56173e9fb..daff4a87b 100644 --- a/kge/util/sampler.py +++ b/kge/util/sampler.py @@ -3,7 +3,7 @@ import random import torch -from typing import Optional +from typing import Optional, List import numpy as np import numba @@ -29,9 +29,9 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset): "Without replacement sampling is only supported when " "shared negative sampling is enabled." ) - self.filtering_split = config.get("negative_sampling.filtering.split") - if self.filtering_split == "": - self.filtering_split = config.get("train.split") + self.filtering_splits: List[str] = config.get("negative_sampling.filtering.splits") + if len(self.filtering_splits) == 0: + self.filtering_splits.append(config.get("train.split")) for slot in SLOTS: slot_str = SLOT_STR[slot] self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}") @@ -43,7 +43,10 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset): # otherwise every worker would create every index again and again if self.filter_positives[slot]: pair = ["po", "so", "sp"][slot] - dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}") + for filtering_split in self.filtering_splits: + dataset.index(f"{filtering_split}_{pair}_to_{slot_str}") + filtering_splits_str = '_'.join(sorted(self.filtering_splits)) + dataset.index(f"{filtering_splits_str}_{pair}_to_{slot_str}") if any(self.filter_positives): if self.shared: raise ValueError( @@ -144,11 +147,12 @@ def _filter_and_resample( """Filter and resample indices until only negatives have been created. """ pair_str = ["po", "so", "sp"][slot] # holding the positive indices for the respective pair - index = self.dataset.index( - f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}" - ) cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols] + split_combi_str = "_".join(sorted(self.filtering_splits)) + index = self.dataset.index( + f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}" + ) for i in range(positive_triples.size(0)): pair = (pairs[i][0].item(), pairs[i][1].item()) positives = ( @@ -268,8 +272,9 @@ def _filter_and_resample_fast( ): pair_str = ["po", "so", "sp"][slot] # holding the positive indices for the respective pair + split_combi_str = "_".join(sorted(self.filtering_splits)) index = self.dataset.index( - f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}" + f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}" ) cols = [[P, O], [S, O], [S, P]][slot] pairs = positive_triples[:, cols].numpy() From cd3bd49d186e6c3e435ccfbc7e3deea233a957a3 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 16 May 2020 23:30:48 +0200 Subject: [PATCH 10/10] Changes from master --- kge/job/eval.py | 14 +++++--------- kge/job/train.py | 11 ++--------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index 0787bebd7..491f4fb12 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -1,14 +1,10 @@ import time -from typing import Any, Dict +from typing import Any, Optional, Dict import torch from kge import Config, Dataset -from kge import Config, Dataset -from kge.job import Job -from kge.model import KgeModel - -from typing import Dict, Union, Optional +from kge.job import Job, TrainingJob class EvaluationJob(Job): @@ -213,6 +209,7 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config = config.clone() train_job_on_eval_split_config.set("train.split", self.eval_split) + train_job_on_eval_split_config.set("verbose", False) train_job_on_eval_split_config.set( "negative_sampling.filtering.splits", [self.config.get("train.split"), self.eval_split] + ["valid"] @@ -229,9 +226,9 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): config=train_job_on_eval_split_config, parent_job=self, dataset=dataset, + model=model, initialize_for_forward_only=True, ) - self._train_job.model = model self._train_job_verbose = False if self.__class__ == TrainingLossEvaluationJob: @@ -246,7 +243,7 @@ def _run(self) -> Dict[str, Any]: self.epoch = self.parent_job.epoch epoch_time += time.time() - train_trace_entry = self._train_job.run_epoch(verbose=self._train_job_verbose) + train_trace_entry = self._train_job.run_epoch() # compute trace trace_entry = dict( type="training_loss", @@ -263,7 +260,6 @@ def _run(self) -> Dict[str, Any]: # HISTOGRAM COMPUTATION ############################################################### - def __initialize_hist(hists, key, job): """If there is no histogram with given `key` in `hists`, add an empty one.""" if key not in hists: diff --git a/kge/job/train.py b/kge/job/train.py index fa21ca942..6a0532b7a 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -1,8 +1,6 @@ -import itertools import os import math import time -import sys from collections import defaultdict from dataclasses import dataclass @@ -17,7 +15,7 @@ from kge.model import KgeModel from kge.util import KgeLoss, KgeOptimizer, KgeSampler, KgeLRScheduler -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import kge.job.util SLOTS = [0, 1, 2] @@ -67,8 +65,6 @@ def __init__( self.model: KgeModel = KgeModel.create(config, dataset) else: self.model: KgeModel = model - self.optimizer = KgeOptimizer.create(config, self.model) - self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer) self.loss = KgeLoss.create(config) self.abort_on_nan: bool = config.get("train.abort_on_nan") self.batch_size: int = config.get("train.batch_size") @@ -130,8 +126,6 @@ def __init__( for f in Job.job_created_hooks: f(self) - self.model.train() - @staticmethod def create( config: Config, @@ -823,8 +817,7 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: class TrainingJobNegativeSampling(TrainingJob): def __init__( - self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False - ): + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False): super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) self._sampler = KgeSampler.create(config, "negative_sampling", dataset) self._implementation = self.config.check(