dataset.py
import torch
from tqdm import tqdm
from torch.utils.data import Dataset


class Seq2SeqDatasetBART(Dataset):
    """Seq2seq dataset for BART: reads a tab-separated file and encodes
    each example's source passage and target string with the tokenizer."""

    def __init__(self, tokenizer, path_file, max_len):
        self.tokenizer = tokenizer
        self.path_file = path_file
        self.data = []
        self.max_len = max_len
        self.setup()

    def setup(self):
        with open(self.path_file) as f:
            for line in tqdm(f, desc="Reading lines"):
                # Expected tab-separated fields: [0] id, [1] title, [2] verb,
                # [3] verb sense, [4] argument, [5] argument sense,
                # [6:] source sentences.
                line = line.strip().split("\t")
                title = line[1]
                split_title = title.split()
                # Rewrite **word** markup as " { word } ".
                for i in range(len(split_title)):
                    if "**" in split_title[i]:
                        split_title[i] = split_title[i].replace("**", "", 2)
                        split_title[i] = " { " + split_title[i] + " } "
                # Rewrite *...* markup as " {{ ... }} ": a word carrying both
                # asterisks gets both braces; a single leading/trailing
                # asterisk opens/closes the span.
                for i in range(len(split_title)):
                    if "*" in split_title[i]:
                        num_ast = split_title[i].count("*")
                        if num_ast == 2:
                            split_title[i] = split_title[i].replace("*", "", 2)
                            split_title[i] = " {{ " + split_title[i] + " }} "
                        elif num_ast == 1:
                            index_ast = split_title[i].index("*")
                            split_title[i] = split_title[i].replace("*", "", 1)
                            if index_ast == 0:
                                split_title[i] = " {{ " + split_title[i]
                            else:
                                split_title[i] = split_title[i] + " }} "
                new_title = " ".join(split_title)
                verb, verb_sense = line[2], line[3]
                arg, arg_sense = line[4], line[5]
                source = line[6:]
                marker = " : "
                point = " . "
                # Number the source sentences: "1. <sent> . 2. <sent> . ..."
                encoded_source = []
                for i, elem in enumerate(source):
                    encoded_source += [str(i + 1) + "."] + [elem] + [point]
                encoded_source = [" ".join(encoded_source)]
                # max_len selects the target layout: 175 -> full structured
                # target (title . verb : sense . arg : sense .), 20 -> title only.
                if self.max_len == 175:
                    target = [new_title + point + verb + marker + verb_sense
                              + point + arg + marker + arg_sense + point]
                elif self.max_len == 20:
                    target = [new_title]
                else:
                    raise ValueError(f"Unsupported max_len: {self.max_len}")
                encoded_s, encoded_t = self.encoded_sentences(encoded_source, target)
                self.data.append({
                    "source": encoded_s,
                    "target": encoded_t,
                })
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def encoded_sentences(self, source, target):
        # encode() tokenizes internally and adds BART's special tokens,
        # truncating to the model's 1024-token limit.
        encoded_s = torch.tensor(self.tokenizer.encode(source[0], truncation=True, max_length=1024))
        encoded_t = torch.tensor(self.tokenizer.encode(target[0], truncation=True, max_length=1024))
        return encoded_s, encoded_t
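
# Minimal usage sketch, assuming a Hugging Face BART tokenizer and a
# tab-separated data file; the path "data/train.tsv", the checkpoint
# "facebook/bart-base", and the batch size are hypothetical. Because the
# dataset yields variable-length id tensors, a padding collate function
# is needed before batching with a DataLoader.
if __name__ == "__main__":
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    dataset = Seq2SeqDatasetBART(tokenizer, "data/train.tsv", max_len=175)

    def collate(batch):
        # Pad the variable-length id sequences in a batch to a common length.
        sources = pad_sequence([ex["source"] for ex in batch],
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        targets = pad_sequence([ex["target"] for ex in batch],
                               batch_first=True,
                               padding_value=tokenizer.pad_token_id)
        return sources, targets

    loader = DataLoader(dataset, batch_size=8, collate_fn=collate)
    sources, targets = next(iter(loader))
    print(sources.shape, targets.shape)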