-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
309 lines (254 loc) · 11.7 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import os
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from transformers import ViTFeatureExtractor, ViTModel
# Defaults for the classification head and the optimiser.
NUM_CLASSES = 2
LEARNING_RATE = 1e-4
DROPOUT_P = 0.1

# Hugging Face hub names of the DINO checkpoints this module recognises.
PATCH_SIZE_S16_MODEL = 'facebook/dino-vits16'
PATCH_SIZE_S8_MODEL = 'facebook/dino-vits8'
PATCH_SIZE_B16_MODEL = 'facebook/dino-vitb16'
PATCH_SIZE_B8_MODEL = 'facebook/dino-vitb8'
# Checkpoint used when no "dino_model" is configured.
DINO_MODEL = PATCH_SIZE_B16_MODEL

# Embedding widths of the DINO backbones, by architecture size.
DINO_EMBEDDING_DIM_SMALL = 384  # for dino-vits models
DINO_EMBEDDING_DIM_BASE = 768  # for dino-vitb models

# Input width of the ResNet-152 `fc` layer that ResNetModel replaces.
RESNET_OUT_DIM = 2048

# Process-global list of per-step training losses; the LightningModules'
# training_step appends to it and PrintCallback prints it when training ends.
losses = []

# Import-time side effects: configure root logging and report CUDA availability.
logging.basicConfig(level=logging.INFO)  # DEBUG, INFO, WARNING, ERROR, CRITICAL
print("CUDA available:", torch.cuda.is_available())
class SelfSupervisedDinoTransformerModel(nn.Module):
    """Classifier head on top of a pretrained DINO ViT backbone.

    The backbone (a Hugging Face ``ViTModel``) embeds already pre-processed
    pixel values. The first token of its last hidden state (the ViT ``[CLS]``
    token) is passed through a small MLP (``fc1`` -> ReLU -> dropout ->
    ``fc2``) that produces raw class logits.

    Args:
        num_classes: Number of output logits.
        loss_fn: Callable ``loss_fn(logits, label)``, e.g.
            ``nn.CrossEntropyLoss()`` (which expects raw, un-softmaxed logits).
        hidden_dim: Width of the hidden MLP layer.
        dropout_p: Dropout probability applied to the hidden MLP activations.
        dino_model: Hugging Face name of the DINO checkpoint to load.
    """

    def __init__(
        self,
        num_classes,
        loss_fn,
        hidden_dim=512,
        dropout_p=0.1,
        dino_model=DINO_MODEL,
    ):
        super().__init__()
        # NOTE: We run the feature extractor in the pl.LightningModule, i.e.
        # this model receives the image pixel values (from the feature
        # extractor) as input; alternatively, that could be done here by
        # refactoring the pl.LightningModule and uncommenting the following
        # self.feature_extractor = ViTFeatureExtractor.from_pretrained(dino_model)
        self.dino_model = ViTModel.from_pretrained(dino_model)
        # `dino_model` is the checkpoint name (str), whereas `self.dino_model`
        # is the loaded ViTModel object.
        if dino_model in (PATCH_SIZE_S16_MODEL, PATCH_SIZE_S8_MODEL):
            dino_embedding_dim = DINO_EMBEDDING_DIM_SMALL
        elif dino_model in (PATCH_SIZE_B16_MODEL, PATCH_SIZE_B8_MODEL):
            dino_embedding_dim = DINO_EMBEDDING_DIM_BASE
        else:
            # Any other checkpoint: read the embedding width from the loaded
            # model's config rather than leaving it as None (which made
            # nn.Linear fail with an unhelpful TypeError).
            dino_embedding_dim = self.dino_model.config.hidden_size
        self.fc1 = torch.nn.Linear(in_features=dino_embedding_dim, out_features=hidden_dim)
        self.fc2 = torch.nn.Linear(in_features=hidden_dim, out_features=num_classes)
        self.loss_fn = loss_fn
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, image_pixel_values, label=None):
        """Classify a batch of pre-processed images.

        Args:
            image_pixel_values: Pixel values accepted by ``ViTModel(pixel_values=...)``.
            label: Target class indices for ``loss_fn``, or ``None`` for pure inference.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``label`` is ``None``.
        """
        dino_embedding = self.dino_model(pixel_values=image_pixel_values)
        # Token 0 of the last hidden state is the [CLS] embedding.
        cls_embedding = dino_embedding.last_hidden_state[:, 0]
        hidden = torch.nn.functional.relu(self.fc1(cls_embedding))
        # Apply the configured dropout so that `dropout_p` actually takes effect
        # (it is a no-op in eval mode).
        hidden = self.dropout(hidden)
        # nn.CrossEntropyLoss expects raw logits as model output, NOT torch.nn.functional.softmax(logits, dim=1)
        # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        pred = self.fc2(hidden)
        loss = None if label is None else self.loss_fn(pred, label)
        return (pred, loss)
class SelfSupervisedDinoIDCDetectionModel(pl.LightningModule):
    """PyTorch Lightning wrapper that trains/tests a DINO-backbone IDC classifier.

    Recognised ``hparams`` keys (each optional):
        ``learning_rate``: Adam learning rate (default ``LEARNING_RATE``).
        ``dino_model``: Hugging Face DINO checkpoint name (default ``DINO_MODEL``).
        ``num_classes``: Number of output classes (default ``NUM_CLASSES``).
    """

    def __init__(self, hparams=None):
        super().__init__()
        if hparams:
            # Cannot reassign self.hparams in pl.LightningModule; must use update()
            # https://github.com/PyTorchLightning/pytorch-lightning/discussions/7525
            self.hparams.update(hparams)
        logging.info(self.hparams)
        self.learning_rate = self.hparams.get("learning_rate", LEARNING_RATE)
        self.dino_model = self.hparams.get("dino_model", DINO_MODEL)
        # Only forward() uses the feature extractor; training_step/test_step pass
        # batch["image"] straight to the model, so those batches presumably already
        # hold pixel values -- TODO confirm against the data pipeline.
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.dino_model)
        self.model = SelfSupervisedDinoTransformerModel(
            num_classes=self.hparams.get("num_classes", NUM_CLASSES),
            loss_fn=torch.nn.CrossEntropyLoss(),
            dino_model=self.dino_model,
        )
        logging.info(self.dino_model)
        # When reloading the model for evaluation, we will need the
        # hyperparameters as they are now
        self.save_hyperparameters()

    # Required for pl.LightningModule
    def forward(self, image, label):
        """Pre-process raw images with the ViT feature extractor and classify them.

        Returns the ``(logits, loss)`` tuple produced by the wrapped model.
        """
        # pl.Lightning convention: forward() defines prediction for inference
        pixel_values = self.feature_extractor(image)["pixel_values"]
        # Stack the per-image arrays into one batch (the previous code passed
        # `axis=0` to torch.tensor instead of np.stack, which raised TypeError)
        # and place it on the same device as the model's parameters.
        batch = torch.tensor(np.stack(pixel_values, axis=0), device=self.device)
        return self.model(batch, label)

    # Required for pl.LightningModule
    def training_step(self, batch, batch_idx):
        """Compute, log and record the training loss for one batch."""
        # pl.Lightning convention: training_step() defines prediction and
        # accompanying loss for training, independent of forward()
        image, label = batch["image"], batch["label"]
        pred, loss = self.model(image, label)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        print(loss.item())
        # Appending to the module-level list does not rebind it, so no `global` is needed.
        losses.append(loss.item())
        return loss

    # Optional for pl.LightningModule
    def training_step_end(self, batch_parts):
        """Average the loss(es) returned by training_step.

        training_step returns a scalar loss. When Lightning hands that scalar
        straight back (e.g. single-device training) it is a 0-d tensor, on
        which ``sum()``/``len()`` raise TypeError; under a data-parallel
        strategy the per-GPU losses arrive as K entries. ``.mean()`` covers
        both tensor cases; any other sequence is averaged as before.
        """
        if isinstance(batch_parts, torch.Tensor):
            return batch_parts.mean()
        return sum(batch_parts) / len(batch_parts)

    # Optional for pl.LightningModule
    def test_step(self, batch, batch_idx):
        """Return the loss and accuracy for one test batch."""
        image, label = batch["image"], batch["label"]
        pred, loss = self.model(image, label)
        pred_label = torch.argmax(pred, dim=1)
        # Fraction of correct predictions, computed as a tensor on the same
        # device as the predictions (the old hard-coded .cuda() failed on CPU).
        accuracy = (pred_label == label).float().mean()
        output = {
            'test_loss': loss,
            'test_acc': accuracy
        }
        print(loss.item(), output['test_acc'])
        return output

    # Optional for pl.LightningModule
    def test_epoch_end(self, outputs):
        """Average the per-batch test loss and accuracy over the epoch."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_accuracy = torch.stack([x["test_acc"] for x in outputs]).mean()
        logs = {
            'test_loss': avg_loss,
            'test_acc': avg_accuracy
        }
        # pl.LightningModule has some issues displaying the results automatically
        # As a workaround, we can store the result logs as an attribute of the
        # class instance and display them manually at the end of testing
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1088
        self.test_results = logs
        return {
            'avg_test_loss': avg_loss,
            'avg_test_acc': avg_accuracy,
            'log': logs,
            'progress_bar': logs
        }

    # Required for pl.LightningModule
    def configure_optimizers(self):
        """Optimise all parameters (backbone and head) with Adam."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate, momentum=0.9)
        return optimizer
class ResNetModel(nn.Module):
    """Image classifier built on a pretrained ResNet-152.

    The ResNet's final ``fc`` layer is replaced by a projection to
    ``image_feature_dim``; the projected features go through ReLU and a small
    MLP (``fc1`` -> ReLU -> dropout -> ``fc2``) that produces raw class logits.

    Args:
        num_classes: Number of output logits.
        loss_fn: Callable ``loss_fn(logits, label)``, e.g. ``nn.CrossEntropyLoss()``.
        image_feature_dim: Output width of the replacement ResNet ``fc`` layer.
        hidden_dim: Width of the hidden MLP layer.
        dropout_p: Dropout probability applied to the hidden MLP activations.
    """

    def __init__(
        self,
        num_classes,
        loss_fn,
        image_feature_dim,
        hidden_dim=512,
        dropout_p=0.1,
    ):
        super().__init__()
        # `pretrained=True` is deprecated in newer torchvision in favour of
        # `weights=`; kept as-is for compatibility with older releases.
        self.image_encoder = torchvision.models.resnet152(pretrained=True)
        # Overwrite last layer to get features (rather than classification)
        self.image_encoder.fc = torch.nn.Linear(
            in_features=RESNET_OUT_DIM, out_features=image_feature_dim)
        self.fc1 = torch.nn.Linear(in_features=image_feature_dim, out_features=hidden_dim)
        self.fc2 = torch.nn.Linear(in_features=hidden_dim, out_features=num_classes)
        self.loss_fn = loss_fn
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, image, label=None):
        """Classify a batch of images.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``label`` is ``None``.
        """
        image_features = torch.nn.functional.relu(self.image_encoder(image))
        hidden = torch.nn.functional.relu(self.fc1(image_features))
        # Apply the configured dropout; previously it was constructed but never
        # used, so the `dropout_p` hyperparameter had no effect.
        hidden = self.dropout(hidden)
        # nn.CrossEntropyLoss expects raw logits as model output, NOT torch.nn.functional.softmax(logits, dim=1)
        # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        pred = self.fc2(hidden)
        loss = None if label is None else self.loss_fn(pred, label)
        return (pred, loss)
class ResNetIDCDetectionModel(pl.LightningModule):
    """PyTorch Lightning wrapper that trains/tests a ResNet-152 IDC classifier.

    Recognised ``hparams`` keys (each optional):
        ``embedding_dim``: Stored as ``self.embedding_dim`` (default 768); not
            read elsewhere in this class.
        ``image_feature_dim``: Width of the ResNet feature projection (default 300).
        ``learning_rate``: Adam learning rate (default ``LEARNING_RATE``).
        ``num_classes``: Number of output classes (default ``NUM_CLASSES``).
        ``dropout_p``: Dropout probability for the MLP head (default ``DROPOUT_P``).
    """

    def __init__(self, hparams=None):
        super().__init__()
        if hparams:
            # Cannot reassign self.hparams in pl.LightningModule; must use update()
            # https://github.com/PyTorchLightning/pytorch-lightning/discussions/7525
            self.hparams.update(hparams)
        self.embedding_dim = self.hparams.get("embedding_dim", 768)
        self.image_feature_dim = self.hparams.get("image_feature_dim", 300)
        self.learning_rate = self.hparams.get("learning_rate", LEARNING_RATE)
        self.model = ResNetModel(
            num_classes=self.hparams.get("num_classes", NUM_CLASSES),
            loss_fn=torch.nn.CrossEntropyLoss(),
            image_feature_dim=self.image_feature_dim,
            dropout_p=self.hparams.get("dropout_p", DROPOUT_P)
        )

    # Required for pl.LightningModule
    def forward(self, image, label):
        """Return the ``(logits, loss)`` tuple produced by the wrapped model."""
        # pl.Lightning convention: forward() defines prediction for inference
        return self.model(image, label)

    # Required for pl.LightningModule
    def training_step(self, batch, batch_idx):
        """Compute, log and record the training loss for one batch."""
        # pl.Lightning convention: training_step() defines prediction and
        # accompanying loss for training, independent of forward()
        image, label = batch["image"], batch["label"]
        pred, loss = self.model(image, label)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        print(loss.item())
        # Appending to the module-level list does not rebind it, so no `global` is needed.
        losses.append(loss.item())
        return loss

    # Optional for pl.LightningModule
    def training_step_end(self, batch_parts):
        """Average the loss(es) returned by training_step.

        training_step returns a scalar loss. When Lightning hands that scalar
        straight back (e.g. single-device training) it is a 0-d tensor, on
        which ``sum()``/``len()`` raise TypeError; under a data-parallel
        strategy the per-GPU losses arrive as K entries. ``.mean()`` covers
        both tensor cases; any other sequence is averaged as before.
        """
        if isinstance(batch_parts, torch.Tensor):
            return batch_parts.mean()
        return sum(batch_parts) / len(batch_parts)

    # Optional for pl.LightningModule
    def test_step(self, batch, batch_idx):
        """Return the loss and accuracy for one test batch."""
        image, label = batch["image"], batch["label"]
        pred, loss = self.model(image, label)
        pred_label = torch.argmax(pred, dim=1)
        # Fraction of correct predictions, computed as a tensor on the same
        # device as the predictions (the old hard-coded .cuda() failed on CPU).
        accuracy = (pred_label == label).float().mean()
        output = {
            'test_loss': loss,
            'test_acc': accuracy
        }
        print(loss.item(), output['test_acc'])
        return output

    # Optional for pl.LightningModule
    def test_epoch_end(self, outputs):
        """Average the per-batch test loss and accuracy over the epoch."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_accuracy = torch.stack([x["test_acc"] for x in outputs]).mean()
        logs = {
            'test_loss': avg_loss,
            'test_acc': avg_accuracy
        }
        # pl.LightningModule has some issues displaying the results automatically
        # As a workaround, we can store the result logs as an attribute of the
        # class instance and display them manually at the end of testing
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1088
        self.test_results = logs
        return {
            'avg_test_loss': avg_loss,
            'avg_test_acc': avg_accuracy,
            'log': logs,
            'progress_bar': logs
        }

    # Required for pl.LightningModule
    def configure_optimizers(self):
        """Optimise all parameters (backbone and head) with Adam."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate, momentum=0.9)
        return optimizer
class PrintCallback(Callback):
    """Lightning callback that reports training start/end on stdout.

    At the end of training it also prints every value collected in the
    module-level ``losses`` list, one value per line.
    """

    def on_train_start(self, trainer, pl_module):
        print("Training started...")

    def on_train_end(self, trainer, pl_module):
        print("Training done...")
        # Reading the module-level list needs no `global`; printing with
        # sep="\n" emits exactly one line per recorded loss.
        if losses:
            print(*losses, sep="\n")