helper.py
import numpy as np, sys, os, random, pdb, json, uuid, time, argparse
from pprint import pprint
import logging, logging.config
from collections import defaultdict as ddict
from ordered_set import OrderedSet

# PyTorch related imports
import torch
from torch.nn import functional as F
from torch.nn.init import xavier_normal_
from torch.utils.data import DataLoader
from torch.nn import Parameter
from torch_scatter import scatter_add

np.set_printoptions(precision=4)


def set_gpu(gpus):
    """
    Sets the GPU(s) to be used for the run.

    Parameters
    ----------
    gpus: Comma-separated string of GPU ids to be used for the run

    Returns
    -------
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus


def get_logger(name, log_dir, config_dir):
    """
    Creates a logger object.

    Parameters
    ----------
    name:       Name of the logger file
    log_dir:    Directory where the log file needs to be stored
    config_dir: Directory from where log_config.json needs to be read

    Returns
    -------
    A logger object which writes to both file and stdout
    """
    with open(config_dir + 'log_config.json') as f:
        config_dict = json.load(f)
    config_dict['handlers']['file_handler']['filename'] = log_dir + name.replace('/', '-')
    logging.config.dictConfig(config_dict)
    logger = logging.getLogger(name)

    std_out_format = '%(asctime)s - [%(levelname)s] - %(message)s'
    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setFormatter(logging.Formatter(std_out_format))
    logger.addHandler(consoleHandler)

    return logger
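

# Hedged sketch of the kind of dictConfig JSON get_logger expects; the actual
# log_config.json ships with the repository, so the contents below are only
# illustrative. The one hard requirement visible in the code is a
# handlers.file_handler.filename entry, which get_logger overwrites.
#
#     {
#         "version": 1,
#         "formatters": {"basic": {"format": "%(asctime)s - [%(levelname)s] - %(message)s"}},
#         "handlers": {
#             "file_handler": {
#                 "class": "logging.FileHandler",
#                 "formatter": "basic",
#                 "filename": "placeholder.log",
#                 "level": "DEBUG"
#             }
#         },
#         "root": {"level": "DEBUG", "handlers": ["file_handler"]}
#     }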


def get_combined_results(left_results, right_results):
    """
    Averages the ranking statistics of the two prediction directions (left and
    right) into a single results dict: MR, MRR, and Hits@k for k in 1..10 and
    20. The per-side values are accumulated sums, so everything is divided by
    the number of evaluated triples (count).
    """
    results = {}
    count = float(left_results['count'])

    results['left_mr']   = round(left_results['mr'] / count, 5)
    results['left_mrr']  = round(left_results['mrr'] / count, 5)
    results['right_mr']  = round(right_results['mr'] / count, 5)
    results['right_mrr'] = round(right_results['mrr'] / count, 5)
    results['mr']        = round((left_results['mr'] + right_results['mr']) / (2 * count), 5)
    results['mrr']       = round((left_results['mrr'] + right_results['mrr']) / (2 * count), 5)

    for k in range(10):
        results['left_hits@{}'.format(k + 1)]  = round(left_results['hits@{}'.format(k + 1)] / count, 5)
        results['right_hits@{}'.format(k + 1)] = round(right_results['hits@{}'.format(k + 1)] / count, 5)
        results['hits@{}'.format(k + 1)]       = round((left_results['hits@{}'.format(k + 1)] + right_results['hits@{}'.format(k + 1)]) / (2 * count), 5)

    results['left_hits@20']  = round(left_results['hits@20'] / count, 5)
    results['right_hits@20'] = round(right_results['hits@20'] / count, 5)
    results['hits@20']       = round((left_results['hits@20'] + right_results['hits@20']) / (2 * count), 5)

    results['left_ranks']  = left_results['ranks']
    results['right_ranks'] = right_results['ranks']
    results['sub'] = left_results['sub']
    results['rel'] = left_results['rel']
    results['obj'] = left_results['obj']

    return results
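

# Hedged illustration of the per-side input contract (keys inferred from the
# code above, values elided): each side's dict carries metric sums over
# `count` evaluated triples, plus the raw ranks and triple ids.
#
#     side_results = {
#         'count': ...,                    # number of evaluated triples
#         'mr': ..., 'mrr': ...,           # summed rank / reciprocal rank
#         'hits@1': ..., 'hits@10': ..., 'hits@20': ...,   # summed hit counts
#         'ranks': [...], 'sub': [...], 'rel': [...], 'obj': [...],
#     }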


def get_param(shape):
    """Creates a learnable parameter of the given shape with Xavier initialization."""
    param = Parameter(torch.Tensor(*shape))
    xavier_normal_(param.data)
    return param


def get_param_device(x, y, device):
    """Creates a learnable (x, y) parameter on the given device with Xavier initialization."""
    param = Parameter(torch.Tensor(x, y).to(device))
    xavier_normal_(param.data)
    return param
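

# Hedged usage sketch; the class and names below are illustrative additions,
# not part of the original file. Helpers like get_param are typically called
# inside an nn.Module to allocate Xavier-initialized weight matrices.
class _ExampleEmbedder(torch.nn.Module):
    def __init__(self, num_ent, embed_dim):
        super().__init__()
        # (num_ent, embed_dim) entity-embedding matrix, registered as a Parameter
        self.ent_embed = get_param((num_ent, embed_dim))

    def forward(self, ent_ids):
        # Look up embeddings for a batch of entity ids
        return self.ent_embed[ent_ids]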


def com_mult(a, b):
    """Element-wise complex multiplication; the last dimension holds (real, imag)."""
    r1, i1 = a[..., 0], a[..., 1]
    r2, i2 = b[..., 0], b[..., 1]
    return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)


def conj(a):
    """Complex conjugate; returns a new tensor rather than mutating the input in place."""
    return torch.stack([a[..., 0], -a[..., 1]], dim=-1)


def cconv(a, b):
    """Circular convolution via FFT (torch.rfft/irfft, available in PyTorch <= 1.7)."""
    return torch.irfft(com_mult(torch.rfft(a, 1), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],))


def ccorr(a, b):
    """Circular correlation via FFT (torch.rfft/irfft, available in PyTorch <= 1.7)."""
    return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],))
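

# torch.rfft/torch.irfft were removed in PyTorch 1.8. Below is a hedged
# equivalent using the torch.fft module; the *_fft names are my own additions,
# not part of the original file.
def cconv_fft(a, b):
    """Circular convolution using torch.fft (PyTorch >= 1.8)."""
    return torch.fft.irfft(torch.fft.rfft(a, dim=-1) * torch.fft.rfft(b, dim=-1), n=a.shape[-1], dim=-1)


def ccorr_fft(a, b):
    """Circular correlation using torch.fft (PyTorch >= 1.8)."""
    return torch.fft.irfft(torch.conj(torch.fft.rfft(a, dim=-1)) * torch.fft.rfft(b, dim=-1), n=a.shape[-1], dim=-1)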


def mask_softmax(a, mask):
    """Softmax over dim 1 with masked-out positions (mask == 0) forced to zero probability."""
    b = a.masked_fill(mask == 0, -np.inf)
    output = torch.softmax(b, 1)
    return output
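

# Hedged usage sketch (illustrative values only): zeroing out one column
# before a row-wise softmax.
if __name__ == '__main__':
    scores = torch.randn(2, 3)
    mask = torch.tensor([[1, 0, 1], [1, 1, 1]])
    print(mask_softmax(scores, mask))  # masked entries receive zero probability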