-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
33 lines (31 loc) · 1.11 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# -*- coding: utf-8 -*-
"""
@Time : 2020/10/20 3:10 pm
@Auth : maker
@Email:[email protected]
"""
import torch
import pandas as pd
from collections import Counter
class Dataset(torch.utils.data.Dataset):
    """Word-level next-token dataset built from the ``Joke`` column of a CSV.

    All jokes are joined into one space-separated token stream. Every item
    is a pair of equally long index windows over that stream, the second
    window being the first one shifted right by one token.

    Attributes:
        args: Configuration object; ``args.sequence_length`` is the window size.
        data_path: Path of the CSV file the tokens are read from.
        words: The full token stream, in corpus order.
        uniq_words: Vocabulary ordered by descending frequency.
        index_to_word: Maps a vocabulary index to its word.
        word_to_index: Maps a word to its vocabulary index.
        words_indexes: ``words`` translated to vocabulary indices.
    """

    # Location used when the caller does not supply ``data_path``.
    DEFAULT_DATA_PATH = 'data/reddit-cleanjokes.csv'

    def __init__(self, args, data_path=None):
        """Load the corpus and build the vocabulary.

        Args:
            args: Object exposing an integer ``sequence_length`` attribute.
            data_path: Optional CSV path; defaults to ``DEFAULT_DATA_PATH``.
        """
        super().__init__()
        self.args = args
        self.data_path = self.DEFAULT_DATA_PATH if data_path is None else data_path
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read the CSV and return its ``Joke`` column as a flat token list."""
        # Fall back to the class default so instances created without
        # running __init__ still resolve a path.
        path = getattr(self, 'data_path', self.DEFAULT_DATA_PATH)
        train_df = pd.read_csv(path)
        text = train_df['Joke'].str.cat(sep=' ')
        # NOTE: split(' ') (not split()) is kept on purpose: consecutive
        # spaces yield empty-string tokens, and changing that would shift
        # every vocabulary index.
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct tokens, most frequent first."""
        word_counts = Counter(self.words)
        # sorted() is stable, so equally frequent words keep first-seen order.
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Return the number of full windows that also have a shifted target."""
        # Clamp at zero: a negative value would make len() raise ValueError
        # when the corpus is shorter than the window.
        return max(0, len(self.words_indexes) - self.args.sequence_length)

    def __getitem__(self, index):
        """Return the ``(input, target)`` index tensors for window ``index``.

        Args:
            index: Window start position; negative values count from the end.

        Returns:
            A tuple of two 1-D tensors of length ``args.sequence_length``;
            the target is the input shifted right by one token.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            index += size
        # Slicing never fails on its own, so without this check an
        # out-of-range index silently returns truncated tensors and plain
        # iteration over the dataset never stops.
        if not 0 <= index < size:
            raise IndexError(
                'Dataset index out of range for {} items'.format(size)
            )
        stop = index + self.args.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )