Skip to content

Commit

Permalink
Changes from master
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbroscheit authored and rgemulla committed May 28, 2020
1 parent d965e9a commit cd3bd49
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 18 deletions.
14 changes: 5 additions & 9 deletions kge/job/eval.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import time
from typing import Any, Dict
from typing import Any, Optional, Dict

import torch
from kge import Config, Dataset

from kge import Config, Dataset
from kge.job import Job
from kge.model import KgeModel

from typing import Dict, Union, Optional
from kge.job import Job, TrainingJob


class EvaluationJob(Job):
Expand Down Expand Up @@ -213,6 +209,7 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model):

train_job_on_eval_split_config = config.clone()
train_job_on_eval_split_config.set("train.split", self.eval_split)
train_job_on_eval_split_config.set("verbose", False)
train_job_on_eval_split_config.set(
"negative_sampling.filtering.splits",
[self.config.get("train.split"), self.eval_split] + ["valid"]
Expand All @@ -229,9 +226,9 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model):
config=train_job_on_eval_split_config,
parent_job=self,
dataset=dataset,
model=model,
initialize_for_forward_only=True,
)
self._train_job.model = model
self._train_job_verbose = False

if self.__class__ == TrainingLossEvaluationJob:
Expand All @@ -246,7 +243,7 @@ def _run(self) -> Dict[str, Any]:
self.epoch = self.parent_job.epoch
epoch_time += time.time()

train_trace_entry = self._train_job.run_epoch(verbose=self._train_job_verbose)
train_trace_entry = self._train_job.run_epoch()
# compute trace
trace_entry = dict(
type="training_loss",
Expand All @@ -263,7 +260,6 @@ def _run(self) -> Dict[str, Any]:

# HISTOGRAM COMPUTATION ###############################################################


def __initialize_hist(hists, key, job):
"""If there is no histogram with given `key` in `hists`, add an empty one."""
if key not in hists:
Expand Down
11 changes: 2 additions & 9 deletions kge/job/train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import itertools
import os
import math
import time
import sys
from collections import defaultdict

from dataclasses import dataclass
Expand All @@ -17,7 +15,7 @@
from kge.model import KgeModel

from kge.util import KgeLoss, KgeOptimizer, KgeSampler, KgeLRScheduler
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional
import kge.job.util

SLOTS = [0, 1, 2]
Expand Down Expand Up @@ -67,8 +65,6 @@ def __init__(
self.model: KgeModel = KgeModel.create(config, dataset)
else:
self.model: KgeModel = model
self.optimizer = KgeOptimizer.create(config, self.model)
self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
self.loss = KgeLoss.create(config)
self.abort_on_nan: bool = config.get("train.abort_on_nan")
self.batch_size: int = config.get("train.batch_size")
Expand Down Expand Up @@ -130,8 +126,6 @@ def __init__(
for f in Job.job_created_hooks:
f(self)

self.model.train()

@staticmethod
def create(
config: Config,
Expand Down Expand Up @@ -823,8 +817,7 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:

class TrainingJobNegativeSampling(TrainingJob):
def __init__(
self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False
):
self, config, dataset, parent_job=None, model=None, initialize_for_forward_only=False):
super().__init__(config, dataset, parent_job, model, initialize_for_forward_only)
self._sampler = KgeSampler.create(config, "negative_sampling", dataset)
self._implementation = self.config.check(
Expand Down

0 comments on commit cd3bd49

Please sign in to comment.