forked from amazon-science/mix-generation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmixgen.py
27 lines (23 loc) · 788 Bytes
/
mixgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""
MixGen: A New Multi-Modal Data Augmentation
https://arxiv.org/abs/2206.08358
Apache-2.0 License, Copyright 2022 Amazon
"""
import random
import numpy as np
def mixgen(image, text, num, lam=0.5):
    """Apply default MixGen to the first ``num`` samples of a batch, in place.

    Sample ``i`` (for ``0 <= i < num``) becomes the interpolated image
    ``lam * image[i] + (1 - lam) * image[i + num]`` paired with the caption
    ``text[i] + " " + text[i + num]``. Samples at positions ``num`` and
    beyond are left unchanged, so the batch keeps its size.

    Args:
        image: Batch of images indexed along the first axis (a torch tensor,
            as the original ``.unsqueeze`` call implied, or a numpy array).
            Modified in place.
        text: Mutable sequence (e.g. a list) of caption strings aligned with
            ``image``. Modified in place.
        num: Number of new samples to generate. Requires at least
            ``2 * num`` entries in both ``image`` and ``text``;
            ``num <= 0`` is a no-op.
        lam: Interpolation weight given to the first image of each pair.

    Returns:
        The same ``(image, text)`` objects, after mutation.

    Raises:
        ValueError: If ``image`` or ``text`` holds fewer than ``2 * num``
            entries. This is checked before anything is written, so an
            invalid ``num`` never leaves the batch half-mixed.
    """
    if num <= 0:
        # Matches the original loop, which ran zero times for num <= 0.
        return image, text
    if 2 * num > min(len(image), len(text)):
        raise ValueError(
            f"mixgen needs at least 2 * num = {2 * num} samples, got "
            f"{len(image)} images and {len(text)} captions"
        )
    # Rows [num, 2*num) are only read and never written, so one sliced
    # expression is equivalent to the per-row loop. The slice shapes
    # already match, so no unsqueeze is needed.
    image[:num] = lam * image[:num] + (1 - lam) * image[num:2 * num]
    text[:num] = [
        first + " " + second
        for first, second in zip(text[:num], text[num:2 * num])
    ]
    return image, text
def mixgen_batch(image, text, num, lam=0.5):
    """Apply MixGen to every sample of a batch with a random pairing, in place.

    A random permutation ``index`` of the batch is drawn from numpy's global
    RNG via ``np.random.permutation``. Sample ``i`` becomes
    ``lam * image[i] + (1 - lam) * image[index[i]]`` with caption
    ``text[i] + " " + text[index[i]]``. Both operands always refer to the
    ORIGINAL batch, so every new caption joins exactly two source captions.
    ``index[i]`` may equal ``i``; in that case the image is unchanged and the
    caption is its own text repeated.

    Args:
        image: Batch of images indexed along the first axis (a torch tensor,
            as the original ``.size()`` call implied, or a numpy array).
            Modified in place.
        text: Mutable sequence (e.g. a list) of caption strings with at
            least ``len(image)`` entries. Only the first ``len(image)``
            entries are rewritten.
        num: Unused. Kept so the signature matches ``mixgen``.
        lam: Interpolation weight given to each sample's own image.

    Returns:
        The same ``(image, text)`` objects, after mutation.
    """
    batch_size = len(image)
    order = np.random.permutation(batch_size).tolist()
    if not order:
        return image, text
    # Build every mixed sample from the untouched batch before writing any
    # result back. Updating row by row would let a row pick a partner that
    # had already been mixed. That yields images blended from three or more
    # sources and captions with three or more parts.
    mixed_image = lam * image + (1 - lam) * image[order]
    mixed_text = [text[i] + " " + text[j] for i, j in enumerate(order)]
    image[:] = mixed_image
    text[:batch_size] = mixed_text
    return image, text