From cd3bd49d186e6c3e435ccfbc7e3deea233a957a3 Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Sat, 16 May 2020 23:30:48 +0200 Subject: [PATCH] Changes from master --- kge/job/eval.py | 14 +++++--------- kge/job/train.py | 11 ++--------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/kge/job/eval.py b/kge/job/eval.py index 0787bebd7..491f4fb12 100644 --- a/kge/job/eval.py +++ b/kge/job/eval.py @@ -1,14 +1,10 @@ import time -from typing import Any, Dict +from typing import Any, Optional, Dict import torch from kge import Config, Dataset -from kge import Config, Dataset -from kge.job import Job -from kge.model import KgeModel - -from typing import Dict, Union, Optional +from kge.job import Job, TrainingJob class EvaluationJob(Job): @@ -213,6 +209,7 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): train_job_on_eval_split_config = config.clone() train_job_on_eval_split_config.set("train.split", self.eval_split) + train_job_on_eval_split_config.set("verbose", False) train_job_on_eval_split_config.set( "negative_sampling.filtering.splits", [self.config.get("train.split"), self.eval_split] + ["valid"] @@ -229,9 +226,9 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model): config=train_job_on_eval_split_config, parent_job=self, dataset=dataset, + model=model, initialize_for_forward_only=True, ) - self._train_job.model = model self._train_job_verbose = False if self.__class__ == TrainingLossEvaluationJob: @@ -246,7 +243,7 @@ def _run(self) -> Dict[str, Any]: self.epoch = self.parent_job.epoch epoch_time += time.time() - train_trace_entry = self._train_job.run_epoch(verbose=self._train_job_verbose) + train_trace_entry = self._train_job.run_epoch() # compute trace trace_entry = dict( type="training_loss", @@ -263,7 +260,6 @@ def _run(self) -> Dict[str, Any]: # HISTOGRAM COMPUTATION ############################################################### - def __initialize_hist(hists, key, job): """If there is no histogram with given `key` in `hists`, add an empty one.""" if key not in hists: diff --git a/kge/job/train.py b/kge/job/train.py index fa21ca942..6a0532b7a 100644 --- a/kge/job/train.py +++ b/kge/job/train.py @@ -1,8 +1,6 @@ -import itertools import os import math import time -import sys from collections import defaultdict from dataclasses import dataclass @@ -17,7 +15,7 @@ from kge.model import KgeModel from kge.util import KgeLoss, KgeOptimizer, KgeSampler, KgeLRScheduler -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import kge.job.util SLOTS = [0, 1, 2] @@ -67,8 +65,6 @@ def __init__( self.model: KgeModel = KgeModel.create(config, dataset) else: self.model: KgeModel = model - self.optimizer = KgeOptimizer.create(config, self.model) - self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer) self.loss = KgeLoss.create(config) self.abort_on_nan: bool = config.get("train.abort_on_nan") self.batch_size: int = config.get("train.batch_size") @@ -130,8 +126,6 @@ def __init__( for f in Job.job_created_hooks: f(self) - self.model.train() - @staticmethod def create( config: Config, @@ -823,8 +817,7 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult: class TrainingJobNegativeSampling(TrainingJob): def __init__( - self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False - ): + self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False): super().__init__(config, dataset, parent_job, model, initialize_for_forward_only) self._sampler = KgeSampler.create(config, "negative_sampling", dataset) self._implementation = self.config.check(