Skip to content

Commit

Permalink
Adress code review
Browse files Browse the repository at this point in the history
eval.py:
set splits

train.py:
use self.is_forward_only
use dataset.index to load label_index for KvsAll

sampler.py:
rename filtering splits
use dataset.index to load label_index for filtering splits

config-default.yaml:
add label_splits config for KvsAll

indexing.py + misc.py:
add index and helper funcs to merge KvsAll dicts
  • Loading branch information
samuelbroscheit authored and rgemulla committed May 28, 2020
1 parent 46063a7 commit d965e9a
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 21 deletions.
6 changes: 5 additions & 1 deletion kge/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ KvsAll:
s_o: False
_po: True

# Dataset splits from which the labels for a query are taken from. Default: If
# nothing is specified, then the train split is used.
label_splits: []

# Options for negative sampling training (train.type=="negative_sampling")
negative_sampling:
# Negative sampler to use
Expand Down Expand Up @@ -264,7 +268,7 @@ negative_sampling:
p: False # as above
o: False # as above

split: '' # split containing the positives; default is train.split
splits: [] # splits containing the positives; default is train.split

# Implementation to use for filtering.
# standard: use slow generic implementation, available for all samplers
Expand Down
19 changes: 19 additions & 0 deletions kge/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict, OrderedDict
import numba
import numpy as np
from kge.misc import powerset, merge_dicts_of_1dim_torch_tensors


def _group_by(keys, values) -> dict:
Expand Down Expand Up @@ -222,13 +223,31 @@ def _invert_ids(dataset, obj: str):
dataset.config.log(f"Indexed {len(inv)} {obj} ids", prefix=" ")


def merge_KvsAll_indexes(dataset, split, key):
value = dict([("sp", "o"), ("po", "s"), ("so", "p")])[key]
split_combi_str = "_".join(sorted(split))
index_name = f"{split_combi_str}_{key}_to_{value}"
indexes = [dataset.index(f"{_split}_{key}_to_{value}") for _split in split]
dataset._indexes[index_name] = merge_dicts_of_1dim_torch_tensors(indexes)
return dataset._indexes[index_name]


def create_default_index_functions(dataset: "Dataset"):
for split in dataset.files_of_type("triples"):
for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]:
# self assignment needed to capture the loop var
dataset.index_functions[f"{split}_{key}_to_{value}"] = IndexWrapper(
index_KvsAll, split=split, key=key
)
# create all combinations of splits of length 2 and 3
for split_combi in powerset(dataset.files_of_type("triples"), [2, 3]):
for key, value in [("sp", "o"), ("po", "s"), ("so", "p")]:
split_combi_str = "_".join(sorted(split_combi))
index_name = f"{split_combi_str}_{key}_to_{value}"
dataset.index_functions[index_name] = IndexWrapper(
merge_KvsAll_indexes, split=split_combi, key=key
)

dataset.index_functions["relation_types"] = index_relation_types
dataset.index_functions["relations_per_type"] = index_relation_types
dataset.index_functions["frequency_percentiles"] = index_frequency_percentiles
Expand Down
34 changes: 25 additions & 9 deletions kge/job/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def run(self) -> Dict[str, Any]:
"Evaluating on "
+ self.eval_split
+ " data (epoch {})...".format(self.epoch),
echo=self.verbose
echo=self.verbose,
)

trace_entry = self._run()
Expand All @@ -140,12 +140,16 @@ def run(self) -> Dict[str, Any]:
f(self, trace_entry)

# write out trace
trace_entry = self.trace(**trace_entry, echo=self.verbose, echo_prefix=" ", log=True)
trace_entry = self.trace(
**trace_entry, echo=self.verbose, echo_prefix=" ", log=True
)

# reset model and return metrics
if was_training:
self.model.train()
self.config.log("Finished evaluating on " + self.eval_split + " split.", echo=self.verbose)
self.config.log(
"Finished evaluating on " + self.eval_split + " split.", echo=self.verbose
)

for f in self.post_valid_hooks:
f(self, trace_entry)
Expand Down Expand Up @@ -209,11 +213,25 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model):

train_job_on_eval_split_config = config.clone()
train_job_on_eval_split_config.set("train.split", self.eval_split)
train_job_on_eval_split_config.set("negative_sampling.filtering.split", self.config.get("train.split"))
train_job_on_eval_split_config.set(
"negative_sampling.filtering.splits",
[self.config.get("train.split"), self.eval_split] + ["valid"]
if self.eval_split == "test"
else [],
)
train_job_on_eval_split_config.set(
"KvsAll.label_splits",
[self.config.get("train.split"), self.eval_split] + ["valid"]
if self.eval_split == "test"
else [],
)
self._train_job = TrainingJob.create(
config=train_job_on_eval_split_config, parent_job=self, dataset=dataset, initialize_for_forward_only=True,
config=train_job_on_eval_split_config,
parent_job=self,
dataset=dataset,
initialize_for_forward_only=True,
)

self._train_job.model = model
self._train_job_verbose = False

if self.__class__ == TrainingLossEvaluationJob:
Expand All @@ -228,9 +246,7 @@ def _run(self) -> Dict[str, Any]:
self.epoch = self.parent_job.epoch
epoch_time += time.time()

train_trace_entry = self._train_job.run_epoch(
verbose=self._train_job_verbose
)
train_trace_entry = self._train_job.run_epoch(verbose=self._train_job_verbose)
# compute trace
trace_entry = dict(
type="training_loss",
Expand Down
11 changes: 9 additions & 2 deletions kge/job/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def run_epoch(self, verbose: bool = True) -> Dict[str, Any]:
if not self.is_forward_only:
self.optimizer.zero_grad()
batch_result: TrainingJob._ProcessBatchResult = self._process_batch(
batch_index, batch, self.is_forward_only
batch_index, batch
)
sum_loss += batch_result.avg_loss * batch_result.size

Expand Down Expand Up @@ -583,6 +583,10 @@ def __init__(
)
)

self.label_splits = set(self.config.get("KvsAll.label_splits")) | set(
[self.train_split]
)

self.queries = None
self.labels = None
self.label_offsets = None
Expand Down Expand Up @@ -628,7 +632,10 @@ def _prepare(self):
if query_type == "sp_"
else ("so_to_p" if query_type == "s_o" else "po_to_s")
)
index = self.dataset.index(f"{self.train_split}_{index_type}")
split_combi_str = "_".join(sorted(self.label_splits))
label_index = self.dataset.index(f"{split_combi_str}_{index_type}")
query_index = self.dataset.index(f"{self.train_split}_{index_type}")
index = {k: v for k, v in label_index.items() if k in query_index}
self.num_examples += len(index)
self.query_end_index[query_type] = self.num_examples

Expand Down
31 changes: 31 additions & 0 deletions kge/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torch
from itertools import chain, combinations
from typing import List

from torch import nn as nn
Expand All @@ -6,6 +8,7 @@
import inspect
import subprocess


def is_number(s, number_type):
""" Returns True is string is a number. """
try:
Expand All @@ -15,6 +18,33 @@ def is_number(s, number_type):
return False


def merge_dicts_of_1dim_torch_tensors(keys_from_dols, values_from_dols=None):
if values_from_dols is None:
values_from_dols = keys_from_dols
keys = set(chain.from_iterable([d.keys() for d in keys_from_dols]))
no = torch.tensor([]).int()
return dict(
(
k,
torch.cat([d.get(k, no) for d in values_from_dols]).sort()[0],
)
for k in keys
)


def powerset(iterable, filter_lens=None):
s = list(iterable)
return list(
map(
sorted,
filter(
lambda l: len(l) in filter_lens if filter_lens else True,
chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)),
),
)
)


# from https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script
def get_git_revision_hash():
try:
Expand Down Expand Up @@ -69,6 +99,7 @@ def is_exe(fpath):

def kge_base_dir():
import kge

return os.path.abspath(filename_in_module(kge, ".."))


Expand Down
23 changes: 14 additions & 9 deletions kge/util/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import random
import torch
from typing import Optional
from typing import Optional, List
import numpy as np
import numba

Expand All @@ -29,9 +29,9 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
"Without replacement sampling is only supported when "
"shared negative sampling is enabled."
)
self.filtering_split = config.get("negative_sampling.filtering.split")
if self.filtering_split == "":
self.filtering_split = config.get("train.split")
self.filtering_splits: List[str] = config.get("negative_sampling.filtering.splits")
if len(self.filtering_splits) == 0:
self.filtering_splits.append(config.get("train.split"))
for slot in SLOTS:
slot_str = SLOT_STR[slot]
self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}")
Expand All @@ -43,7 +43,10 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
# otherwise every worker would create every index again and again
if self.filter_positives[slot]:
pair = ["po", "so", "sp"][slot]
dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}")
for filtering_split in self.filtering_splits:
dataset.index(f"{filtering_split}_{pair}_to_{slot_str}")
filtering_splits_str = '_'.join(sorted(self.filtering_splits))
dataset.index(f"{filtering_splits_str}_{pair}_to_{slot_str}")
if any(self.filter_positives):
if self.shared:
raise ValueError(
Expand Down Expand Up @@ -144,11 +147,12 @@ def _filter_and_resample(
"""Filter and resample indices until only negatives have been created. """
pair_str = ["po", "so", "sp"][slot]
# holding the positive indices for the respective pair
index = self.dataset.index(
f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}"
)
cols = [[P, O], [S, O], [S, P]][slot]
pairs = positive_triples[:, cols]
split_combi_str = "_".join(sorted(self.filtering_splits))
index = self.dataset.index(
f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}"
)
for i in range(positive_triples.size(0)):
pair = (pairs[i][0].item(), pairs[i][1].item())
positives = (
Expand Down Expand Up @@ -268,8 +272,9 @@ def _filter_and_resample_fast(
):
pair_str = ["po", "so", "sp"][slot]
# holding the positive indices for the respective pair
split_combi_str = "_".join(sorted(self.filtering_splits))
index = self.dataset.index(
f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}"
f"{split_combi_str}_{pair_str}_to_{SLOT_STR[slot]}"
)
cols = [[P, O], [S, O], [S, P]][slot]
pairs = positive_triples[:, cols].numpy()
Expand Down

0 comments on commit d965e9a

Please sign in to comment.