-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathsamplers.py
32 lines (25 loc) · 881 Bytes
/
samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import numpy as np
class CategoriesSampler:
    """Episodic batch sampler: each batch draws ``n_per`` indices from each of
    ``n_cls`` randomly chosen classes.

    Each yielded batch is a 1-D ``torch`` index tensor of length
    ``n_cls * n_per``. It is ordered sample-major: the first ``n_cls``
    entries are one index from each chosen class, the next ``n_cls`` are a
    second index from each class, and so on. This comes from stacking to
    ``(n_cls, n_per)``, transposing, and flattening.

    Args:
        label: Sequence of integer class labels, one per dataset item.
            Classes are enumerated as ``0..max(label)``, so negative labels
            are never sampled.
        n_batch: Number of batches yielded per iteration (also ``len(self)``).
        n_cls: Number of distinct classes per batch.
        n_per: Number of indices drawn, without replacement, per class.
        generator: Optional ``torch.Generator`` used for all random draws,
            for reproducible sampling. ``None`` uses torch's global RNG.

    Raises:
        ValueError: If ``label`` is empty, or if fewer than ``n_cls`` classes
            have at least ``n_per`` items.
    """

    def __init__(self, label, n_batch, n_cls, n_per, generator=None):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per
        self.generator = generator
        label = np.array(label)
        if label.size == 0:
            raise ValueError("label must contain at least one item")
        # m_ind[c] holds the dataset indices whose label equals c. It may be
        # empty when a value in 0..max(label) never occurs.
        self.m_ind = []
        for c in range(int(label.max()) + 1):
            ind = np.argwhere(label == c).reshape(-1)
            self.m_ind.append(torch.from_numpy(ind))
        # Only classes that can fill all n_per slots are eligible. Drawing a
        # smaller class would make torch.stack fail on mismatched row sizes.
        self._valid_classes = [
            c for c, ind in enumerate(self.m_ind) if len(ind) >= n_per
        ]
        if len(self._valid_classes) < n_cls:
            raise ValueError(
                f"need at least {n_cls} classes with >= {n_per} items each, "
                f"found {len(self._valid_classes)}"
            )

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        n_valid = len(self._valid_classes)
        for _ in range(self.n_batch):
            picks = torch.randperm(n_valid, generator=self.generator)
            batch = []
            for k in picks[: self.n_cls].tolist():
                members = self.m_ind[self._valid_classes[k]]
                pos = torch.randperm(len(members), generator=self.generator)
                batch.append(members[pos[: self.n_per]])
            # Stack to (n_cls, n_per), then transpose and flatten so the
            # indices interleave across classes.
            yield torch.stack(batch).t().reshape(-1)