forked from bmilde/deepschmatzing
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: nolearn_addon.py
79 lines (67 loc) · 2.59 KB
/
nolearn_addon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import theano
from nolearn.lasagne import BatchIterator
def float32(k):
    """Cast a scalar or array-like ``k`` to a float32 ndarray.

    Used to feed values into Theano shared variables, which expect
    float32 on GPU.  The original implementation used
    ``np.cast['float32'](k)``; ``np.cast`` was deprecated and removed in
    NumPy 2.0, and ``np.asarray(...).astype(...)`` is the documented
    drop-in replacement with identical behavior.
    """
    return np.asarray(k).astype('float32')
def shuffle(*arrays):
    """Apply one shared random permutation to every array in ``arrays``.

    All arrays must have the same length along axis 0; rows stay aligned
    across arrays (e.g. samples and labels are shuffled in unison).
    Returns a list of the permuted arrays.
    """
    order = np.random.permutation(len(arrays[0]))
    return [a[order] for a in arrays]
class ShufflingBatchIterator(BatchIterator):
    """BatchIterator that re-shuffles X and y together before each pass."""

    def __iter__(self):
        # Permute the whole data set in unison, then delegate the actual
        # batching to the parent iterator.
        self.X, self.y = shuffle(self.X, self.y)
        for batch in super(ShufflingBatchIterator, self).__iter__():
            yield batch
class EarlyStopping(object):
    """nolearn ``on_epoch_finished`` handler that stops training once the
    validation loss has not improved for ``patience`` epochs.

    A snapshot of the network weights at the best epoch is kept and
    restored into the net before training is aborted (by raising
    ``StopIteration``, the protocol nolearn uses to stop fitting).
    """

    def __init__(self, patience=100):
        self.patience = patience
        self.best_valid = np.inf       # lowest validation loss seen so far
        self.best_valid_epoch = 0      # epoch at which it was reached
        self.best_weights = None       # parameter snapshot at that epoch

    def __call__(self, nn, train_history):
        epoch = train_history[-1]['epoch']
        valid_loss = train_history[-1]['valid_loss']
        if valid_loss < self.best_valid:
            # New best score: remember loss, epoch, and a copy of all weights.
            self.best_valid = valid_loss
            self.best_valid_epoch = epoch
            self.best_weights = [p.get_value() for p in nn.get_all_params()]
        elif epoch > self.best_valid_epoch + self.patience:
            # Patience exhausted: restore the best weights and abort.
            print("Early stopping.")
            print("Best valid loss was {:.6f} at epoch {}.".format(
                self.best_valid, self.best_valid_epoch))
            nn.load_weights_from(self.best_weights)
            raise StopIteration()
class AdjustVariable(object):
    """nolearn ``on_epoch_finished`` handler that linearly anneals a shared
    variable on the net (e.g. ``update_learning_rate``) from ``start`` to
    ``stop`` over the whole training run.
    """

    def __init__(self, name, start=0.03, stop=0.001):
        self.name = name
        self.start = start
        self.stop = stop
        self.ls = None  # per-epoch schedule; built lazily (needs nn.max_epochs)

    def __call__(self, nn, train_history):
        if self.ls is None:
            # One value per epoch, evenly spaced between start and stop.
            self.ls = np.linspace(self.start, self.stop, nn.max_epochs)
        current_epoch = train_history[-1]['epoch']
        new_value = float32(self.ls[current_epoch - 1])
        getattr(nn, self.name).set_value(new_value)
# Custom batch iterator for nolearn that yields the data in
# batch_size-sized minibatch chunks.
class ForcedEvenBatchIterator(object):
    """Minibatch iterator compatible with nolearn's BatchIterator protocol.

    Parameters
    ----------
    batch_size : int
        Number of samples per minibatch.
    forced_even : bool
        If True, silently drop the trailing partial batch so that every
        yielded batch has exactly ``batch_size`` samples.
    """

    def __init__(self, batch_size, forced_even=False):
        self.batch_size = batch_size
        self.forced_even = forced_even

    def __call__(self, X, y=None, test=False):
        # Store the data; actual batching happens lazily in __iter__.
        self.X, self.y = X, y
        self.test = test
        return self

    def __iter__(self):
        n_samples = self.X.shape[0]
        bs = self.batch_size
        # BUGFIX: use floor division. The original '(n_samples + bs - 1) / bs'
        # is true division under Python 3 and yields a float, which range()
        # rejects with TypeError; '//' is identical under Python 2 ints.
        n_batches = (n_samples + bs - 1) // bs
        for i in range(n_batches):
            sl = slice(i * bs, (i + 1) * bs)
            Xb = self.X[sl]
            if self.forced_even and len(Xb) != bs:
                # Ragged final batch: skip it when even batches are forced.
                continue
            if self.y is not None:
                yb = self.y[sl]
            else:
                yb = None
            yield self.transform(Xb, yb)

    def transform(self, Xb, yb):
        """Hook for subclasses to augment/modify a batch; identity here."""
        return Xb, yb