-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataset.py
73 lines (66 loc) · 2 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from PIL import Image
from transformers import GPT2TokenizerFast
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
# The stock GPT-2 tokenizer ships without a padding token, yet collate_fn
# relies on tokenizer.pad() and tokenizer.pad_token_id. Fall back to the
# end-of-text token so batching works out of the box; code that assigns a
# different pad token after importing this module still takes precedence.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
class Dataset:
    """Image-caption dataset backed by a table of image locations and captions.

    Each item is the transformed image together with the tokenized caption
    and its next-token targets.

    Args:
        df: Table indexed positionally through ``.iloc``; every row must
            provide an ``'image'`` entry (anything ``PIL.Image.open``
            accepts) and a ``'caption'`` entry.
        tfms: Callable invoked as ``tfms(image=<RGB ndarray>)`` that returns
            a mapping holding the transformed image under ``'image'``.
    """

    def __init__(self, df, tfms):
        self.df = df
        self.tfms = tfms

    def __len__(self) -> int:
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, transform and tokenize the sample at position ``idx``.

        Returns:
            A tuple ``(image, input_ids, labels)``: ``image`` is whatever
            ``tfms`` produced, ``input_ids`` is the list of token ids for
            the caption with ``<|endoftext|>`` appended, and ``labels`` is
            ``input_ids`` shifted left by one position (the final entry is
            left unchanged).
        """
        sample = self.df.iloc[idx, :]

        # Open inside a context manager so the file handle is released
        # deterministically instead of whenever the image object happens to
        # be garbage collected. ``convert`` returns a new image, so the
        # pixel data stays valid after the file is closed.
        with Image.open(sample['image']) as img:
            pixels = np.array(img.convert('RGB'))
        image = self.tfms(image=pixels)['image']

        caption = f"{sample['caption']}<|endoftext|>"
        input_ids = tokenizer(caption, truncation=True)['input_ids']

        # Next-token targets: position i is trained to predict token i + 1.
        labels = input_ids.copy()
        labels[:-1] = input_ids[1:]
        return image, input_ids, labels
def collate_fn(batch):
    """Merge ``(image, input_ids, labels)`` samples into one padded batch.

    Args:
        batch: Sequence of items whose first three entries are an image
            tensor, a list of token ids and a list of target ids.

    Returns:
        A tuple ``(images, input_ids, labels)``. ``images`` is the samples'
        images stacked along a new leading dimension. ``input_ids`` and
        ``labels`` are padded to the longest sequence in the batch, and
        every ``labels`` entry whose matching ``input_ids`` entry equals the
        tokenizer's pad token id is replaced with ``-100``.
    """

    def _pad_to_longest(sequences):
        # Pad a list of id lists to a common length and return the tensor.
        padded = tokenizer.pad(
            {'input_ids': sequences},
            padding='longest',
            return_attention_mask=False,
            return_tensors='pt'
        )
        return padded['input_ids']

    pixel_list = [item[0] for item in batch]
    token_list = [item[1] for item in batch]
    target_list = [item[2] for item in batch]

    images = torch.stack(pixel_list, dim=0)
    token_batch = _pad_to_longest(token_list)
    target_batch = _pad_to_longest(target_list)

    # Overwrite targets at padded input positions with -100.
    is_padding = token_batch == tokenizer.pad_token_id
    target_batch[is_padding] = -100
    return images, token_batch, target_batch
def _resize_normalize_to_tensor():
    """Build a fresh list of the steps that close both pipelines.

    The steps resize to 224x224, normalize every channel with mean 0.5 and
    std 0.5, and convert the result with ``ToTensorV2``. New transform
    objects are created on every call so the two pipelines never share
    instances.
    """
    # NOTE(review): ``always_apply`` may be deprecated in newer albumentations
    # releases - confirm against the pinned version.
    return [
        A.Resize(224, 224),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], always_apply=True),
        ToTensorV2(),
    ]


# Augmentations placed ahead of the shared closing steps in ``train_tfms``;
# ``valid_tfms`` does not include them.
sample_tfms = [
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(),
    A.ColorJitter(),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.3, rotate_limit=45, p=0.5),
    A.HueSaturationValue(p=0.3),
]

# Training pipeline: augmentations followed by the shared closing steps.
train_tfms = A.Compose(sample_tfms + _resize_normalize_to_tensor())

# Validation pipeline: shared closing steps only.
valid_tfms = A.Compose(_resize_normalize_to_tensor())