Allow using labels from data for triple classification
Nzteb committed Jun 4, 2020
1 parent 5b1a5b4 commit 8a4416f
Showing 3 changed files with 68 additions and 7 deletions.
11 changes: 7 additions & 4 deletions examples/toy-complex-train-tripleclass.yaml
@@ -1,6 +1,6 @@
 job.type: train
-dataset.name: toy
-model: distmult
+dataset.name: wn11
+model: complex
 train:
   optimizer: Adagrad
   optimizer_args:
@@ -11,8 +11,11 @@ lookup_embedder.dim: 100
 lookup_embedder.initialize: xavier_uniform_
 eval:
   type: triple_classification
   metrics_per.relation: False
-triple_classification_random_seed: False
+triple_classification.random_seed: False
+triple_classification.negatives_from: data

 valid.metric: accuracy
 valid.every: 1

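Context for the example above: the config now targets wn11 because WN11's valid and test sets ship with labeled negative triples, which is exactly what the new data mode consumes. As a minimal sketch (ours, not part of this commit; it assumes LibKGE's Config/Dataset API and that the wn11 dataset has been downloaded into LibKGE's data folder), one might check that a dataset is usable with negatives_from=data like this:

    # Sketch only: verify that a dataset declares the extra negatives splits
    # before running triple classification with negatives_from=data.
    from kge import Config, Dataset

    config = Config()                   # loads config-default.yaml
    config.set("dataset.name", "wn11")  # assumes a local copy of the dataset
    dataset = Dataset.create(config)    # also loads the dataset's dataset.yaml

    for split in ("valid_negatives", "test_negatives"):
        # raises KeyError if dataset.yaml does not declare the split
        config.get(f"dataset.files.{split}.type")
        print(split, dataset.split(split).shape)  # Nx3 tensor of (s, p, o) ids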
10 changes: 9 additions & 1 deletion kge/config-default.yaml
@@ -330,7 +330,7 @@ eval:
   # mean_reciprocal_rank_filtered_with_test.
   filter_with_test: True

-  # Type of evaluation (entity_ranking only at the moment)
+  # Type of evaluation (entity_ranking, triple_classification)
   type: entity_ranking

   # How to handle cases with ties between the correct answer and other answers, e.g.,
@@ -423,6 +423,14 @@ valid:

 ## EVALUATION ##################################################################

+triple_classification:
+  random_seed: False
+  # How to obtain negative triple labels. Possible values are:
+  # - corruption: Create negatives by randomly corrupting existing triples (positives)
+  # - data: Obtain negative labels from the dataset. This assumes the dataset
+  #   contains the splits 'valid_negatives' and 'test_negatives'.
+  negatives_from: corruption
+
+
 ## HYPERPARAMETER SEARCH #######################################################

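For contrast with the new data mode, the default corruption mode builds negatives by perturbing positives. A rough, self-contained illustration of that idea (the helper below is ours, plain PyTorch, not the sampler the job actually uses; LibKGE additionally filters out corruptions that happen to be true triples, as the filtering.s option in the code below shows):

    import torch

    def corrupt_uniform(triples: torch.Tensor, num_entities: int) -> torch.Tensor:
        # For every positive (s, p, o), corrupt either the subject or the
        # object with an entity drawn uniformly at random.
        negatives = triples.clone()
        corrupt_subject = torch.rand(len(triples)) < 0.5
        random_entities = torch.randint(num_entities, (len(triples),))
        negatives[corrupt_subject, 0] = random_entities[corrupt_subject]    # S column
        negatives[~corrupt_subject, 2] = random_entities[~corrupt_subject]  # O column
        return negatives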
54 changes: 52 additions & 2 deletions kge/job/triple_classification.py
@@ -20,7 +20,12 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
         self.o_entities = None
         uni_sampler_config = config.clone()
         # uni_sampler_config.set("negative_sampling.num_samples.s", self.get_option("num_samples.s"))
+        # TODO this is redundant, as uniform.sample() is called with "num_samples" in self.sample()
         uni_sampler_config.set("negative_sampling.num_samples.s", 1)
+        # TODO maybe change the API of KgeSampler.sample() to also accept a "filter"
+        # parameter, as is already the case with "num_samples"; then we would not
+        # rely here on configuration options which actually belong to a training job
         uni_sampler_config.set("negative_sampling.filtering.s", True)
         # uni_sampler_config.set("negative_sampling.num_samples.o", self.get_option("num_samples.o"))
         uni_sampler_config.set("negative_sampling.num_samples.o", 1)
@@ -31,6 +36,7 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset):

     def _prepare(self,):
         train_data = self.dataset.split("train")
+        # TODO probably outdated, as it refers to commented-out code
         self.s_entities = train_data[:, S].unique().tolist()
         self.o_entities = train_data[:, O].unique().tolist()
         self._is_prepared = True
@@ -105,6 +111,20 @@ def __init__(self, config, dataset, parent_job, model):
         self.triple_classification_sampler = TripleClassificationSampler(
             config, "triple_classification", dataset
         )
+        self.config.check(
+            "triple_classification.negatives_from", ["corruption", "data"]
+        )
+        self.negatives_from = self.config.get("triple_classification.negatives_from")
+        if self.negatives_from == "data":
+            try:
+                self.config.get("dataset.files.valid_negatives.type")
+                self.config.get("dataset.files.test_negatives.type")
+            except KeyError:
+                raise KeyError(
+                    "No splits valid_negatives/test_negatives found for the dataset. "
+                    "Provide a dataset with splits valid_negatives and test_negatives, "
+                    "or run triple classification with negatives_from=corruption"
+                )

     def _prepare(self):
         """Prepare the corrupted validation and test data.
@@ -121,7 +141,8 @@ def _prepare(self):

         self.config.log("Generate data with corrupted and true triples...")

-        if self.eval_split == "test":
+        # TODO maybe this should be generalized to allow for other splits such as valid_wo_unseen
+        if self.eval_split == "test" and self.negatives_from == "corruption":
             (
                 self.tune_data,
                 self.tune_labels,
@@ -130,7 +151,25 @@ def _prepare(self):
                 self.eval_data,
                 self.eval_labels,
             ) = self.triple_classification_sampler.sample(self.dataset.split("test"))
-        else:
+
+        elif self.eval_split == "test" and self.negatives_from == "data":
+            positives_valid = self.dataset.split("valid")
+            negatives_valid = self.dataset.split("valid_negatives")
+            self.tune_data = torch.cat((positives_valid, negatives_valid)).to(self.device)
+            self.tune_labels = torch.cat(
+                (torch.ones(positives_valid.size(0)), torch.zeros(negatives_valid.size(0)))
+            ).to(self.device)
+
+            positives_test = self.dataset.split("test")
+            negatives_test = self.dataset.split("test_negatives")
+            self.eval_data = torch.cat((positives_test, negatives_test)).to(self.device)
+            self.eval_labels = torch.cat(
+                (torch.ones(positives_test.size(0)), torch.zeros(negatives_test.size(0)))
+            ).to(self.device)
+
+        elif self.eval_split == "valid" and self.negatives_from == "corruption":
             (
                 self.tune_data,
                 self.tune_labels,
@@ -140,6 +179,17 @@ def _prepare(self):
                 self.eval_labels,
             ) = self.triple_classification_sampler.sample(self.dataset.split("valid"))

+        elif self.eval_split == "valid" and self.negatives_from == "data":
+            positives = self.dataset.split("valid")
+            negatives = self.dataset.split("valid_negatives")
+            self.tune_data = torch.cat((positives, negatives)).to(self.device)
+            self.tune_labels = torch.cat(
+                (torch.ones(positives.size(0)), torch.zeros(negatives.size(0)))
+            ).to(self.device)
+
+            self.eval_data = self.tune_data
+            self.eval_labels = self.tune_labels

         # let the model add some hooks, if it wants to do so
         self.model.prepare_job(self)
         self.valid_data_is_prepared = True
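Downstream, the prepared (tune_data, tune_labels) and (eval_data, eval_labels) pairs feed the usual triple-classification protocol: tune a score threshold on the tune split, then report accuracy on the eval split. A simplified sketch of that step (ours, for illustration; one global threshold, whereas the actual job may tune thresholds per relation, as is standard for this task):

    import torch

    def accuracy_with_tuned_threshold(
        tune_scores: torch.Tensor, tune_labels: torch.Tensor,
        eval_scores: torch.Tensor, eval_labels: torch.Tensor,
    ):
        # Every tune-set score is a candidate threshold; keep the one with
        # the best tune accuracy and apply it to the eval set.
        candidates = tune_scores.sort().values
        tune_acc = torch.stack(
            [((tune_scores >= t).float() == tune_labels).float().mean() for t in candidates]
        )
        threshold = candidates[tune_acc.argmax()]
        eval_acc = ((eval_scores >= threshold).float() == eval_labels).float().mean()
        return threshold.item(), eval_acc.item()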
