# pile.py (forked from EleutherAI/the-pile)
import random

import lm_dataformat as lmd
from pytablewriter import MarkdownTableWriter
from tqdm import tqdm

from utils import humanbytes
from datasets import *
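
# utf8len and cycle_documents are small helpers this file depends on. In the
# upstream repository they live in the shared utils/datasets modules; they are
# sketched here, under that assumption, so the file is self-contained.
def utf8len(s):
    """Length of a string in bytes when encoded as UTF-8."""
    return len(s.encode('utf-8'))


def cycle_documents(dataset):
    """Iterate over a dataset's documents indefinitely, restarting at the end."""
    while True:
        yield from dataset.documents()


# Each entry pairs a dataset with a sampling weight; a component's effective
# share of the mix is proportional to weight * size.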
datasets = [
    (BibliotikDataset()    , 1. ),
    (PubMedCentralDataset(), 1. ),
    (ArXivDataset()        , 1. ),
    (FreeLawDataset()      , 1. ),
    (OpenWebTextDataset()  , 1. ),
    (StackExchangeDataset(), 1. ),
    (USPTODataset()        , 1. ),
    (PubMedDataset()       , 1. ),
    (WikipediaDataset()    , 1. ),
    (OpensubtitlesDataset(), 1. ),
    (GutenbergDataset()    , 1. ),
    (LiteroticaDataset()   , 1. ),
    (DMMathDataset()       , 1. ),
    (BookCorpusDataset()   , 1. ),
    (UbuntuIRCDataset()    , 1. ),
    (EuroParlDataset()     , 1. ),
    (PhilPapersDataset()   , 1. ),
    (ExPorterDataset()     , 1. ),
    (EnronEmailsDataset()  , 1. ),
]
def take(n, iterator):
    """Return up to n items from an iterator (fewer if it is exhausted)."""
    ret = []
    for _ in range(n):
        try:
            ret.append(next(iterator))
        except StopIteration:
            break
    return ret
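
# The same behaviour is available from the standard library; a minimal
# equivalent sketch using itertools.islice:
#
#     from itertools import islice
#     def take(n, iterator):
#         return list(islice(iterator, n))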
def get_samples():
    """Print a random ~8 KB excerpt of each dataset as a LaTeX verbatim block."""
    for dset, _ in datasets:
        print('\\subsection{' + dset.name() + '}')
        print()

        # Sample from the first 1000 documents, shuffled for variety.
        docs = take(1000, dset.documents())
        random.shuffle(docs)

        # Concatenate documents until the excerpt exceeds the size limit,
        # then cut a random window of exactly `limit` characters.
        limit = 8192
        res = ''
        for doc in docs:
            if len(res) > limit:
                break
            res += doc + '<|endoftext|>'

        if len(res) > limit:
            i = random.randrange(0, len(res) - limit)
            res = res[i:i+limit]

        print('\\begin{verbatim}\n' + res + '\n\\end{verbatim}')
def mk_table(datasets):
    """Build the component overview table (size, weight, epochs, mean doc size)."""
    values = []

    total_weight = sum([x[1] * x[0].size() for x in datasets])
    train_chars = 1.2e12  # target size of the training set, in characters

    for dataset, weight in datasets:
        size = dataset.size()
        relative_weight = size * weight / total_weight
        values.append([
            dataset.name(),
            size,
            '{:.2%}'.format(relative_weight),
            train_chars / size * relative_weight,
            humanbytes(size / dataset.num_docs()),
        ])

    # Sort components by raw size, descending, then append a totals row.
    values.sort(key=lambda x: -x[1])
    values.append([
        '**Total**',
        sum([x[1] for x in values]),
        "",
        "",
        humanbytes(sum([x[1] for x in values]) / sum(x[0].num_docs() for x in datasets)),
    ])
    values = [[x[0], humanbytes(x[1]), x[2], x[3], x[4]] for x in values]

    writer = MarkdownTableWriter()
    writer.table_name = "The Pile™"
    writer.headers = ["Component", "Size", "Weight", "Epochs", "Mean Document Size"]
    writer.value_matrix = values
    return writer.dumps()
class ThePile:
    """Mix of the component datasets, sampled in proportion to size * weight.

    Exposes the same interface as the individual datasets:
    name() / documents() / clean() / size().
    """

    def __init__(self, datasets, dataset_bytes):
        self.datasets = datasets
        self.dataset_bytes = dataset_bytes

    def name(self):
        return "The Pile"

    def documents(self):
        datasets = []
        weights = []

        # Calculate the relative weight of each component.
        total_weight = sum([x[1] * x[0].size() for x in self.datasets])
        for dataset, weight in self.datasets:
            size = dataset.size()
            relative_weight = size * weight / total_weight
            datasets.append(cycle_documents(dataset))
            weights.append(relative_weight)

        random.seed(42)

        # Yield documents, drawn by relative weight, until dataset_bytes
        # bytes have been produced in total.
        total_bytes = 0
        pbar = tqdm(total=self.dataset_bytes, unit='B', unit_scale=True, unit_divisor=1024)
        while True:
            # Draw a batch of source datasets at once; much cheaper than
            # calling random.choices() once per document.
            chunk = random.choices(population=datasets, weights=weights, k=1000)

            for dset in chunk:
                doc = next(dset)

                size = utf8len(doc)
                total_bytes += size
                pbar.update(size)

                yield doc

                if total_bytes > self.dataset_bytes:
                    return

    def clean(self):
        for dataset, _ in self.datasets:
            dataset.clean()

    def size(self):
        return sum(map(lambda x: x[0].size(), tqdm(self.datasets)))
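
# lm_dataformat is imported above but never used in this script; a minimal
# sketch of how the sampled documents could be written to disk with it. The
# function name, output path, and the idea of persisting the mix this way are
# assumptions for illustration, not part of the original pipeline.
def write_pile(pile, out_dir='pile_output'):
    """Hypothetical helper: dump the mixed documents as compressed jsonl chunks."""
    ar = lmd.Archive(out_dir)  # buffers documents, writes chunks on commit()
    for doc in pile.documents():
        ar.add_data(doc)
    ar.commit()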
if __name__ == '__main__':
    random.seed(42)
    print(mk_table(datasets))

    pile = ThePile(datasets, int(1.2e12))

    get_samples()

    # Drain the generator so the progress bar reflects a full pass.
    for x in pile.documents():
        pass