-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmodels.py
325 lines (285 loc) · 11.9 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
from keras.models import Model
from keras.layers import Dense, Conv1D, Layer, Input, Concatenate
from keras.metrics import categorical_accuracy, categorical_crossentropy
from keras.utils import to_categorical
from keras.optimizers import RMSprop
from keras.regularizers import l2
import keras.backend as K
from keras.engine import InputSpec
import numpy
__author__ = 'Romain Tavenard romain.tavenard[at]univ-rennes2.fr'
class GlobalMinPooling1D(Layer):
    """Global min pooling operation for temporal data.

    Each feature channel is reduced to its minimum value over the time
    (steps) axis.

    # Input shape
        3D tensor with shape: `(batch_size, steps, features)`.

    # Output shape
        2D tensor with shape:
        `(batch_size, features)`
    """

    def __init__(self, **kwargs):
        super(GlobalMinPooling1D, self).__init__(**kwargs)
        # Reject anything that is not a rank-3 input.
        self.input_spec = InputSpec(ndim=3)

    def call(self, inputs, **kwargs):
        time_axis = 1
        return K.min(inputs, axis=time_axis)

    def compute_output_shape(self, input_shape):
        # The steps dimension (index 1) is pooled away.
        batch_dim = input_shape[0]
        feature_dim = input_shape[2]
        return batch_dim, feature_dim
class LocalSquaredDistanceLayer(Layer):
    """Pairwise (squared) distance computation between local patches and shapelets.

    The layer owns a trainable `kernel` of shape
    `(n_shapelets, features)`; each row is one shapelet. For every time step
    and every shapelet, it outputs the squared Euclidean distance between
    the input feature vector at that step and the shapelet.

    # Input shape
        3D tensor with shape: `(batch_size, steps, features)`.

    # Output shape
        3D tensor with shape:
        `(batch_size, steps, n_shapelets)`
    """

    def __init__(self, n_shapelets, **kwargs):
        self.n_shapelets = n_shapelets
        super(LocalSquaredDistanceLayer, self).__init__(**kwargs)
        self.input_spec = InputSpec(ndim=3)

    def build(self, input_shape):
        n_features = input_shape[2]
        self.kernel = self.add_weight(name='kernel',
                                      shape=(self.n_shapelets, n_features),
                                      initializer='uniform',
                                      trainable=True)
        super(LocalSquaredDistanceLayer, self).build(input_shape)

    def call(self, x, **kwargs):
        # Expand ||p - s||^2 as ||p||^2 + ||s||^2 - 2 <p, s>. This way every
        # (patch, shapelet) pair is handled by one matrix product plus
        # broadcasting.
        patch_sq_norms = K.expand_dims(K.sum(x ** 2, axis=2), axis=-1)
        shapelet_sq_norms = K.reshape(K.sum(self.kernel ** 2, axis=1),
                                      (1, 1, self.n_shapelets))
        cross_terms = K.dot(x, K.transpose(self.kernel))
        return patch_sq_norms + shapelet_sq_norms - 2 * cross_terms

    def compute_output_shape(self, input_shape):
        batch_dim, n_steps = input_shape[0], input_shape[1]
        return batch_dim, n_steps, self.n_shapelets
def grabocka_params_to_shapelet_size_dict(ts_sz, n_classes, l, r):
    """Compute number and length of shapelets the way it is done in [1]_.

    Parameters
    ----------
    ts_sz: int
        Length of time series in the dataset
    n_classes: int
        Number of classes in the dataset
    l: float
        Fraction of the length of time series to be used for base shapelet length
    r: int
        Number of different shapelet lengths to use

    Returns
    -------
    dict
        Dictionary giving, for each shapelet length, the number of such
        shapelets to be generated.

        Candidate lengths are ``base * 1, ..., base * r``, where
        ``base = max(1, int(l * ts_sz))``. Lengths that exceed ``ts_sz`` are
        skipped. At least one shapelet is generated for every kept length.

    Examples
    --------
    >>> d = grabocka_params_to_shapelet_size_dict(ts_sz=100, n_classes=3, l=0.1, r=2)
    >>> keys = sorted(d.keys())
    >>> print(keys)
    [10, 20]
    >>> print([d[k] for k in keys])
    [3, 3]

    References
    ----------
    .. [1] J. Grabocka et al. Learning Time-Series Shapelets. SIGKDD 2014.
    """
    # A base length of 0 (l * ts_sz < 1) would make every key 0; never go
    # below a single time step.
    base_size = max(1, int(l * ts_sz))
    d = {}
    for sz_idx in range(r):
        shp_sz = base_size * (sz_idx + 1)
        if shp_sz > ts_sz:
            # A shapelet longer than the series cannot be matched anywhere.
            # log10 of a non-positive number would also yield nan or -inf.
            # Sizes only grow with sz_idx, so no later size can fit either.
            break
        n_shapelets = int(numpy.log10(ts_sz - shp_sz + 1) * (n_classes - 1))
        # The heuristic can round down to 0 (e.g. binary problems or long
        # shapelets). Keep at least one shapelet so the layer is usable.
        d[shp_sz] = max(1, n_shapelets)
    return d
class ShapeletModel:
    """Learning Time-Series Shapelets model as presented in [1]_.

    This implementation only accepts mono-dimensional time series as inputs.

    Parameters
    ----------
    n_shapelets_per_size: dict
        Dictionary giving, for each shapelet size (key),
        the number of such shapelets to be trained (value)
    max_iter: int (default: 1000)
        Number of training epochs.
    batch_size: int (default:256)
        Batch size to be used.
    verbose_level: {0, 1, 2} (default: 2)
        `keras` verbose level.
    optimizer: str or keras.optimizers.Optimizer (default: "sgd")
        `keras` optimizer to use for training.
    weight_regularizer: float or None (default: 0.)
        Strength of the `keras` L2 regularizer used when training the
        classification (softmax) layer.
        If None or 0., no regularization is performed.

    Attributes
    ----------
    shapelets_: numpy.ndarray (dtype=object)
        Set of time-series shapelets, one 1D array per shapelet, grouped by
        increasing shapelet size.
    classes_: numpy.ndarray or None
        Sorted distinct labels seen by `fit` when `y` is a 1D label array.
        None if `y` was given already one-hot encoded.

    Examples
    --------
    >>> from tslearn.generators import random_walk_blobs
    >>> X, y = random_walk_blobs(n_ts_per_blob=100, sz=256, d=1, n_blobs=3)
    >>> clf = ShapeletModel(n_shapelets_per_size={10: 5}, max_iter=1, verbose_level=0)
    >>> clf.fit(X, y).shapelets_.shape
    (5,)
    >>> clf.shapelets_[0].shape
    (10,)
    >>> clf.predict(X).shape
    (300,)
    >>> clf.transform(X).shape
    (300, 5)
    >>> clf2 = ShapeletModel(n_shapelets_per_size={10: 5, 20: 10}, max_iter=1, verbose_level=0)
    >>> clf2.fit(X, y).shapelets_.shape
    (15,)
    >>> clf2.shapelets_[0].shape
    (10,)
    >>> clf2.shapelets_[5].shape
    (20,)
    >>> clf2.predict(X).shape
    (300,)
    >>> clf2.transform(X).shape
    (300, 15)

    References
    ----------
    .. [1] J. Grabocka et al. Learning Time-Series Shapelets. SIGKDD 2014.
    """

    def __init__(self, n_shapelets_per_size,
                 max_iter=1000,
                 batch_size=256,
                 verbose_level=2,
                 optimizer="sgd",
                 weight_regularizer=0.):
        self.n_shapelets_per_size = n_shapelets_per_size
        self.n_classes = None
        self.optimizer = optimizer
        self.epochs = max_iter
        self.weight_regularizer = weight_regularizer
        self.model = None
        self.transformer_model = None
        self.batch_size = batch_size
        self.verbose_level = verbose_level
        self.categorical_y = False
        self.classes_ = None

    @property
    def _n_shapelet_sizes(self):
        return len(self.n_shapelets_per_size)

    @property
    def shapelets_(self):
        """Learned shapelets as a 1D object array, ordered by increasing size."""
        total_n_shp = sum(self.n_shapelets_per_size.values())
        # An object array lets shapelets of different lengths live side by side.
        shapelets = numpy.empty((total_n_shp, ), dtype=object)
        idx = 0
        # Layer "shapelets_i" holds the i-th smallest shapelet size (see
        # _set_model_layers); each kernel row is one shapelet.
        for i in range(self._n_shapelet_sizes):
            for shp in self.model.get_layer("shapelets_%d" % i).get_weights()[0]:
                shapelets[idx] = shp
                idx += 1
        assert idx == total_n_shp
        return shapelets

    def fit(self, X, y):
        """Learn shapelets and classification weights from a training set.

        Parameters
        ----------
        X : array-like of shape (n_ts, sz, d)
            Training time series. Only d == 1 is supported.
        y : array-like of shape (n_ts, ) or (n_ts, n_classes)
            Class labels, or targets that are already one-hot encoded.

        Returns
        -------
        ShapeletModel
            The fitted model (self).

        Raises
        ------
        ValueError
            If the time series are not mono-dimensional.
        """
        X = numpy.asarray(X)
        _, sz, d = X.shape
        if d != 1:
            raise ValueError("ShapeletModel only accepts mono-dimensional "
                             "time series (got d=%d)" % d)
        y = numpy.asarray(y)
        if y.ndim == 1:
            # Encode arbitrary labels (strings, 1-based ints, ...) as
            # 0..K-1. to_categorical then yields exactly one column per
            # class, and predict() can map back via classes_.
            self.classes_, y_ind = numpy.unique(y, return_inverse=True)
            y_ = to_categorical(y_ind)
            self.categorical_y = False
        else:
            self.classes_ = None
            y_ = y
            self.categorical_y = True
        n_classes = y_.shape[1]
        self.n_classes = n_classes

        self._set_model_layers(ts_sz=sz, d=d, n_classes=n_classes)
        self.model.compile(loss="categorical_crossentropy",
                           optimizer=self.optimizer,
                           metrics=[categorical_accuracy,
                                    categorical_crossentropy])
        self.transformer_model.compile(loss="mean_squared_error",
                                       optimizer=self.optimizer)
        # The patch-extraction convolutions are frozen. Their fixed identity
        # weights must be set after the layers exist.
        self._set_weights_false_conv(d=d)
        self.model.fit(X, y_,
                       batch_size=self.batch_size,
                       epochs=self.epochs,
                       verbose=self.verbose_level)
        return self

    def predict(self, X):
        """Predict classes for the given time series.

        Returns
        -------
        numpy.ndarray
            If `fit` received one-hot targets, the softmax outputs of shape
            (n_ts, n_classes). Otherwise, predicted labels of shape (n_ts, ),
            drawn from the labels seen during `fit`.
        """
        categorical_preds = self.model.predict(X,
                                               batch_size=self.batch_size,
                                               verbose=self.verbose_level)
        if self.categorical_y:
            return categorical_preds
        class_idx = categorical_preds.argmax(axis=1)
        if self.classes_ is None:
            return class_idx
        return self.classes_[class_idx]

    def transform(self, X):
        """Return the shapelet-distance features (output of the min pooling).

        The result is the input to the softmax layer, with one column per
        shapelet.
        """
        return self.transformer_model.predict(X,
                                              batch_size=self.batch_size,
                                              verbose=self.verbose_level)

    def _set_weights_false_conv(self, d):
        """Fill each frozen "false_conv_i" layer with identity kernels.

        The Conv1D kernel has shape (sz, d, sz). Output channel j at step t
        then equals the input at step t + j, so every output step holds one
        length-sz window of the input. This assumes Conv1D's default 'valid'
        padding; TODO confirm.
        """
        shapelet_sizes = sorted(self.n_shapelets_per_size.keys())
        for i, sz in enumerate(shapelet_sizes):
            weights_false_conv = numpy.empty((sz, d, sz))
            for di in range(d):
                weights_false_conv[:, di, :] = numpy.eye(sz)
            layer = self.model.get_layer("false_conv_%d" % i)
            layer.set_weights([weights_false_conv])

    def _set_model_layers(self, ts_sz, d, n_classes):
        """Build self.model (classifier) and self.transformer_model (features).

        For each shapelet size (ascending), the layers are chained as follows:
        - frozen "false_conv_i" extracts sliding patches;
        - "shapelets_i" computes squared patch/shapelet distances;
        - "min_pooling_i" keeps the best match per shapelet.

        The pooled features are concatenated and fed to a softmax layer.
        """
        inputs = Input(shape=(ts_sz, d), name="input")
        shapelet_sizes = sorted(self.n_shapelets_per_size.keys())
        pool_layers = []
        for i, sz in enumerate(shapelet_sizes):
            transformer_layer = Conv1D(filters=sz,
                                       kernel_size=sz,
                                       trainable=False,
                                       use_bias=False,
                                       name="false_conv_%d" % i)(inputs)
            shapelet_layer = LocalSquaredDistanceLayer(self.n_shapelets_per_size[sz],
                                                       name="shapelets_%d" % i)(transformer_layer)
            pool_layers.append(GlobalMinPooling1D(name="min_pooling_%d" % i)(shapelet_layer))
        if len(shapelet_sizes) > 1:
            concatenated_features = Concatenate()(pool_layers)
        else:
            concatenated_features = pool_layers[0]

        dense_kwargs = {}
        # None and 0. both mean "no regularization". The explicit None check
        # avoids the TypeError that `None > 0.` raises on Python 3.
        if self.weight_regularizer is not None and self.weight_regularizer > 0.:
            dense_kwargs["kernel_regularizer"] = l2(self.weight_regularizer)
        outputs = Dense(units=n_classes,
                        activation="softmax",
                        name="softmax",
                        **dense_kwargs)(concatenated_features)
        self.model = Model(inputs=inputs, outputs=outputs)
        self.transformer_model = Model(inputs=inputs, outputs=concatenated_features)

    def get_weights(self, layer_name=None):
        """Return model weights (or weights for a given layer if `layer_name` is provided).

        Parameters
        ----------
        layer_name: str or None (default: None)
            Name of the layer for which weights should be returned.
            If None, all model weights are returned.
            Available layer names with weights are:
            - "shapelets_i" with i an integer for the sets of shapelets
              corresponding to each shapelet size (sorted in ascending order)
            - "softmax" for the final classification layer

        Returns
        -------
        list
            list of model (or layer) weights

        Examples
        --------
        >>> from tslearn.generators import random_walk_blobs
        >>> X, y = random_walk_blobs(n_ts_per_blob=100, sz=256, d=1, n_blobs=3)
        >>> clf = ShapeletModel(n_shapelets_per_size={10: 5}, max_iter=1, verbose_level=0)
        >>> clf.fit(X, y).get_weights("softmax")[0].shape
        (5, 3)
        """
        if layer_name is None:
            return self.model.get_weights()
        else:
            return self.model.get_layer(layer_name).get_weights()
if __name__ == "__main__":
    # Demo: fit a ShapeletModel on the "Trace" dataset, then print the
    # training time, the learned shapelet shapes, the number of correctly
    # predicted test series and the shape of the transformed training set.
    from tslearn.datasets import CachedDatasets
    from tslearn.preprocessing import TimeSeriesScalerMeanVariance
    import time

    X_train, y_train, X_test, y_test = CachedDatasets().load_dataset("Trace")
    # Each split is rescaled by its own, freshly created scaler.
    scaled_splits = [TimeSeriesScalerMeanVariance().fit_transform(split)
                     for split in (X_train, X_test)]
    X_train, X_test = scaled_splits

    # l (base length fraction) and r (number of sizes) for dataset Trace,
    # taken from the Table at:
    # http://fs.ismll.de/publicspace/LearningShapelets/
    base_fraction, n_sizes = 0.1, 2
    shapelet_sizes = grabocka_params_to_shapelet_size_dict(X_train.shape[1],
                                                           len(set(y_train)),
                                                           base_fraction,
                                                           n_sizes)

    start = time.time()
    clf = ShapeletModel(n_shapelets_per_size=shapelet_sizes,
                        max_iter=1000,
                        optimizer=RMSprop(lr=.001),
                        weight_regularizer=.01,
                        verbose_level=0)
    clf.fit(X_train, y_train)
    elapsed = time.time() - start
    print("Total time for training: %fs" % elapsed)
    print([shapelet.shape for shapelet in clf.shapelets_])

    n_correct = numpy.sum(y_test == clf.predict(X_test))
    print(n_correct)
    print(clf.transform(X_train).shape)