import numpy as np

# NP_DTYPE was referenced but never defined in the original file;
# np.float32 is an assumed default for the padded batches.
NP_DTYPE = np.float32

class SimpleIterator(object):
    """Iterator without bucketing."""

    def __init__(self, data, batch_size, shuffle_every_epoch=False):
        self.x_list = data  # list of (length, n_input) arrays
        self.batch_size = batch_size
        self.shuffle_every_epoch = shuffle_every_epoch
        self.n_input = self.x_list[0].shape[-1]
        self.x_lengths = np.array([seq.shape[0] for seq in self.x_list])
        self.n_batches = int(np.ceil(float(len(self.x_lengths)) / batch_size))
        self.indices = np.arange(len(self.x_lengths))
        np.random.shuffle(self.indices)

    def __iter__(self):
        if self.shuffle_every_epoch:
            np.random.shuffle(self.indices)
        for i_batch in range(self.n_batches):
            batch_indices = self.indices[
                i_batch * self.batch_size:(i_batch + 1) * self.batch_size
            ]
            batch_x_lengths = self.x_lengths[batch_indices]

            # Pad every sequence to the maximum length in this batch
            batch_x_padded = np.zeros(
                (len(batch_indices), np.max(batch_x_lengths), self.n_input),
                dtype=NP_DTYPE
            )
            for i, length in enumerate(batch_x_lengths):
                seq = self.x_list[batch_indices[i]]
                batch_x_padded[i, :length, :] = seq

            yield (batch_x_padded, batch_x_lengths)
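

# A minimal usage sketch (not part of the original file): the sequence
# lengths, feature dimension, and batch size below are illustrative
# assumptions, just to show the padded shapes the iterator yields.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    # Three sequences of lengths 5, 3 and 7, each with 4 features per frame
    data = [rng.randn(length, 4).astype(NP_DTYPE) for length in (5, 3, 7)]
    batches = SimpleIterator(data, batch_size=2, shuffle_every_epoch=True)
    for batch_x_padded, batch_x_lengths in batches:
        # Shapes like (2, 7, 4) then (1, 5, 4), depending on the shuffle
        print(batch_x_padded.shape, batch_x_lengths)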