forked from rasbt/faster-pytorch-blog
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlocal_dataset_utilities.py
102 lines (74 loc) · 2.88 KB
/
local_dataset_utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import sys
import tarfile
import time
import numpy as np
import pandas as pd
from packaging import version
from torch.utils.data import Dataset
from tqdm import tqdm
import urllib
def reporthook(count, block_size, total_size):
    """Write a single-line download progress report to stdout.

    The signature matches the ``reporthook`` callback of
    ``urllib.request.urlretrieve``. The call with ``count == 0`` only records
    the start time in the module-level ``start_time``; every later call
    rewrites the current terminal line (leading ``\\r``) with the percentage,
    transferred megabytes, average speed and elapsed seconds.

    Args:
        count: Number of blocks transferred so far.
        block_size: Size of one block in bytes.
        total_size: Total size in bytes. A value of zero or below is treated
            as "unknown" and shown as ``?%``.
    """
    global start_time
    if count == 0:
        start_time = time.time()
        return

    duration = time.time() - start_time
    progress_size = int(count * block_size)
    progress_mb = progress_size / (1024.0**2)

    # A coarse clock can report zero elapsed time on the first data block.
    speed = progress_mb / duration if duration > 0 else 0.0

    if total_size > 0:
        # The last block is usually only partly filled, so count * block_size
        # can overshoot the real size; cap the display at 100%.
        percent = f"{min(100, int(progress_size * 100.0 / total_size))}%"
    else:
        # No usable total size was supplied, so a percentage cannot be computed.
        percent = "?%"

    sys.stdout.write(
        f"\r{percent} | {progress_mb:.2f} MB "
        f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed"
    )
    sys.stdout.flush()
def download_dataset():
    """Download and unpack the IMDB review archive in the working directory.

    Any existing ``aclImdb_v1.tar.gz`` is deleted first. If an ``aclImdb``
    directory is already present nothing else happens; otherwise the archive
    is fetched (progress is reported through ``reporthook``) and extracted
    into the current directory.

    Raises:
        urllib.error.URLError: If the download fails.
        tarfile.TarError: If the archive cannot be read or a member is
            rejected by the extraction filter.
    """
    # ``import urllib`` alone does not guarantee that the ``request``
    # submodule is loaded, so import it explicitly where it is used.
    import urllib.request

    source = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    target = "aclImdb_v1.tar.gz"

    # NOTE(review): presumably this discards a possibly truncated archive from
    # an interrupted run so it is never extracted — confirm intent.
    if os.path.exists(target):
        os.remove(target)

    if os.path.isdir("aclImdb"):
        return

    # The archive was removed above, so a fresh download is always required.
    urllib.request.urlretrieve(source, target, reporthook)

    with tarfile.open(target, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Refuse members that would escape the destination directory or
            # create special files; available on patched/newer Pythons.
            tar.extractall(filter="data")
        else:
            tar.extractall()
def load_dataset_into_to_dataframe():
    """Read the extracted IMDB reviews into a shuffled DataFrame.

    Files are read from ``aclImdb/<split>/<sentiment>`` for the splits
    ``test`` and ``train`` and the sentiments ``pos`` (label 1) and ``neg``
    (label 0), in sorted filename order. The rows are then permuted with
    NumPy's global RNG seeded to 0 (this reseeds the global NumPy RNG), and
    the per-class counts are printed.

    Returns:
        pandas.DataFrame: Columns ``text`` (file contents) and ``label``
        (0 or 1), carrying the permuted original row numbers as index.

    Raises:
        OSError: If one of the expected directories or files cannot be read.
    """
    basepath = "aclImdb"
    labels = {"pos": 1, "neg": 0}

    # Gather rows in a list and build the frame once: growing a DataFrame one
    # file at a time copies every earlier row on each step.
    rows = []
    with tqdm(total=50000) as pbar:
        for split in ("test", "train"):
            for sentiment in ("pos", "neg"):
                path = os.path.join(basepath, split, sentiment)
                for filename in sorted(os.listdir(path)):
                    with open(
                        os.path.join(path, filename), "r", encoding="utf-8"
                    ) as infile:
                        rows.append((infile.read(), labels[sentiment]))
                    pbar.update()

    # A fresh 0..n-1 index is what makes the permutation below reorder rows;
    # with duplicate index labels the reindex has nothing to shuffle.
    df = pd.DataFrame(rows, columns=["text", "label"])

    np.random.seed(0)
    df = df.reindex(np.random.permutation(df.index))

    print("Class distribution:")
    # minlength=2 keeps both classes visible even if one has no rows.
    print(np.bincount(df["label"].to_numpy(dtype=np.int64), minlength=2))

    return df
def partition_dataset(df):
    """Shuffle ``df`` and write it out as three CSV files.

    The rows are shuffled with ``random_state=1`` and the previous index is
    kept as a regular column by ``reset_index()``. Rows 0-34,999 go to
    ``train.csv``, rows 35,000-39,999 to ``val.csv`` and everything from
    row 40,000 on to ``test.csv``, all in the current working directory.

    Args:
        df: DataFrame to split.
    """
    shuffled = df.sample(frac=1, random_state=1).reset_index()

    # Output file -> positional row range of the shuffled frame.
    layout = (
        ("train.csv", slice(None, 35_000)),
        ("val.csv", slice(35_000, 40_000)),
        ("test.csv", slice(40_000, None)),
    )
    for filename, row_range in layout:
        subset = shuffled.iloc[row_range]
        subset.to_csv(filename, index=False, encoding="utf-8")
class IMDBDataset(Dataset):
    """Expose one entry of ``dataset_dict`` through the ``Dataset`` interface.

    NOTE(review): the selected partition must support indexing and provide a
    ``num_rows`` attribute; presumably it is a Hugging Face dataset split —
    confirm against the caller.
    """

    def __init__(self, dataset_dict, partition_key="train"):
        """Keep the partition stored under ``partition_key``."""
        selected = dataset_dict[partition_key]
        self.partition = selected

    def __getitem__(self, index):
        """Return whatever the partition yields for ``index``."""
        item = self.partition[index]
        return item

    def __len__(self):
        """Return the partition's ``num_rows``."""
        partition = self.partition
        return partition.num_rows