utils.py
# Copyright (c) 2021 Cognizant Digital Business, Cognizant AI Labs
# Issued under this Academic Public License: github.com/cognizant-ai-labs/tom-release/LICENSE.
"""
Utility functions for TOM experiments
"""
import numpy as np
import pandas as pd
import torch


def load_dataset(path):
    """
    Load a csv dataset into the format TOM expects.
    """
    dataset_df = pd.read_csv(path)
    feature_columns = [c for c in dataset_df.columns if 'feature' in c]
    class_columns = [c for c in dataset_df.columns if 'class' in c]
    data_columns = feature_columns + class_columns
    dataset = {}
    for split in ['train', 'val', 'test']:
        split_df = dataset_df[dataset_df.split == split]
        data_df = split_df[data_columns]
        data_arr = data_df.values
        data_tensor = torch.from_numpy(data_arr).float()
        dataset[split] = data_tensor
    dataset['true_input_variable_indices'] = np.arange(len(feature_columns))
    dataset['true_output_variable_indices'] = np.arange(len(feature_columns),
                                                        len(data_columns))
    # Load origin so the oracle can use it. This is specific to the ch problem.
    origin_row = dataset_df[dataset_df.split == 'origin']
    origin = origin_row[feature_columns].iloc[0].values
    dataset['origin'] = origin
    return dataset
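
# Example usage (a sketch, not part of the original pipeline): assuming a
# hypothetical 'my_dataset.csv' whose columns include 'feature_*' and
# 'class_*' names plus a 'split' column taking the values
# 'train'/'val'/'test'/'origin':
#
#     >>> dataset = load_dataset('my_dataset.csv')
#     >>> dataset['train'].shape   # (num train rows, num features + num classes)
#     >>> dataset['origin']        # feature values of the origin row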


def squared_hinge(pred, target):
    """
    The squared hinge loss as an alternative to crossentropy in classification.
    """
    # Targets are one-hot in {0, 1}; (2 * target - 1) maps them to the
    # {-1, +1} labels the hinge formulation expects.
    return torch.mean(torch.sum(torch.clamp(1 - (2 * target - 1) * pred, 0.)**2, dim=1))
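
# Sanity check (a sketch): a confidently correct prediction incurs zero loss,
# since both hinge terms clamp to zero:
#
#     >>> squared_hinge(torch.tensor([[2.0, -2.0]]),
#     ...               torch.tensor([[1.0, 0.0]]))
#     tensor(0.)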


def compute_loss_and_accuracy(model,
                              dataset,
                              split,
                              batch_size,
                              soft_model=False,
                              loss_fn=squared_hinge):
    """
    Evaluate `model` on one split of `dataset`, returning (mean loss, accuracy).
    """
    with torch.no_grad():
        data_set_tensor = dataset[split]
        context_tensor = dataset['context_tensor']
        num_samples = data_set_tensor.shape[0]
        output_contexts = context_tensor
        input_contexts = context_tensor[:, :, dataset['true_input_variable_indices']]
        # Compute number of validation steps
        nsteps = int(np.ceil(num_samples / batch_size))
        # Run validation step-by-step
        correct = 0
        total_loss = 0.
        for val_step in range(nsteps):
            start_idx = val_step * batch_size
            end_idx = start_idx + batch_size
            # Pull out batch
            batch_input = data_set_tensor[start_idx:end_idx]
            batch_input = batch_input[:, dataset['true_input_variable_indices']]
            # Forward pass to get prediction
            if soft_model:
                pred = model(batch_input, input_contexts, output_contexts, dataset['dataset_idx'])
            else:
                pred = model(batch_input, input_contexts, output_contexts)
            # Pull out target
            target = data_set_tensor[start_idx:end_idx]
            # Compute number correct
            class_pred = pred[:, dataset['true_output_variable_indices']]
            class_target = target[:, dataset['true_output_variable_indices']]
            pred_label = torch.argmax(class_pred, dim=1)
            target_label = torch.argmax(class_target, dim=1)
            correct += (pred_label == target_label).sum().item()
            # Compute loss, weighted by the actual batch length, since the
            # final batch may be smaller than batch_size
            loss = loss_fn(class_pred, class_target)
            total_loss += loss.item() * batch_input.shape[0]
        accuracy = correct / float(num_samples)
        mean_loss = total_loss / float(num_samples)
        return mean_loss, accuracy
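
# Example evaluation call (a sketch; `model`, and the 'context_tensor' and
# 'dataset_idx' entries, are assumed to be produced by the surrounding TOM
# training code and are not defined in this file):
#
#     >>> dataset = load_dataset('my_dataset.csv')        # hypothetical path
#     >>> dataset['context_tensor'] = context_tensor      # built elsewhere
#     >>> dataset['dataset_idx'] = 0                      # only used if soft_model=True
#     >>> val_loss, val_acc = compute_loss_and_accuracy(
#     ...     model, dataset, 'val', batch_size=32)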