utils_lora.py

"""Utility helpers for LoRA fine-tuning of CLIP with Jittor: accuracy,
image transforms, zero-shot text-classifier construction, and feature
pre-extraction."""
from tqdm import tqdm

import jittor as jt
import clip
from PIL import Image
from jittor import transform as jt_transform
from jittor.transform import Compose, ImageNormalize
# `_setup_size` and `to_pil_image` are used by `Resize` below; they are
# assumed here to be provided by jittor.transform.
from jittor.transform import _setup_size, to_pil_image


def cls_acc(output, target):
    """Top-1 accuracy (in percent) of `output` logits against `target` labels."""
    # jt.Var.argmax returns an (indices, values) pair; keep the class indices.
    pred = output.argmax(dim=-1)[0]
    correct = (pred == target).sum()
    acc = 100 * correct / target.shape[0]
    return acc
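

# Usage sketch for cls_acc (hypothetical tensors, not defined in this file):
#   logits  = jt.randn(8, 10)            # [batch, num_classes] scores
#   targets = jt.randint(0, 10, (8,))    # [batch] ground-truth labels
#   acc = cls_acc(logits, targets)       # top-1 accuracy in percent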


class Resize:
    """Resize so the shorter side equals `size` (preserving aspect ratio) when
    `size` is an int, or to an explicit (h, w) pair otherwise."""

    def __init__(self, size, mode=Image.BILINEAR):
        if isinstance(size, int):
            self.size = size
        else:
            self.size = _setup_size(
                size,
                error_msg="If size is a sequence, it should have 2 values")
        self.mode = mode

    def __call__(self, img: Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        if isinstance(self.size, int):
            w, h = img.size
            short, long = (w, h) if w <= h else (h, w)
            if short == self.size:
                return img
            new_short, new_long = self.size, int(self.size * long / short)
            new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
            size = (new_h, new_w)
        else:
            size = self.size  # explicit (h, w) target
        return jt_transform.resize(img, size, self.mode)


def _convert_image_to_rgb(image):
    return image.convert("RGB")


# Transforms
tfm_train_base = jt_transform.Compose([
    # jt_transform.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=Image.BICUBIC),
    Resize(224, mode=Image.BICUBIC),
    jt_transform.RandomCrop(224),
    jt_transform.RandomHorizontalFlip(p=0.5),
    jt_transform.ToTensor()
])

tfm_test_base = jt_transform.Compose([
    Resize(224, mode=Image.BICUBIC),
    jt_transform.CenterCrop(224),
    _convert_image_to_rgb,
    jt_transform.ToTensor()
])
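

# Usage sketch for the base transforms ("image.jpg" is a hypothetical path):
# both pipelines take a PIL image and return a tensor in [0, 1]; CLIP's
# channel normalization is applied separately through `tfm_clip` defined below.
#   img = Image.open("image.jpg")
#   x_train = tfm_train_base(img)   # random 224x224 crop + horizontal flip
#   x_test  = tfm_test_base(img)    # deterministic 224x224 center crop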


def clip_classifier(classnames, template, clip_model):
    """Build zero-shot classifier weights: one L2-normalized text embedding per
    class, stacked as columns of shape [embed_dim, num_classes]. `template` is
    accepted for API compatibility, but the raw classname is tokenized as-is."""
    clip_weights = []
    jt.flags.use_cuda = 1
    with jt.no_grad():
        for classname in classnames:
            texts = [classname]
            texts = clip.tokenize(texts)
            texts = jt.array(texts)
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            clip_weights.append(class_embedding)
            del classname, class_embeddings
            jt.gc()  # release intermediate text-encoder memory
        clip_weights = jt.stack(clip_weights, dim=1)
    return clip_weights
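

# Usage sketch for clip_classifier (hypothetical classnames/template; the
# Jittor CLIP port is assumed to expose `clip.load` returning the model):
#   clip_model, _ = clip.load("ViT-B/32")
#   classnames = ["cat", "dog"]
#   template = ["a photo of a {}."]   # accepted but not applied by this version
#   clip_weights = clip_classifier(classnames, template, clip_model)
#   # clip_weights: [embed_dim, num_classes], one L2-normalized column per class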


# CLIP's standard channel-wise normalization, applied to already-tensorized batches.
tfm_clip = Compose([ImageNormalize((0.48145466, 0.4578275, 0.40821073),
                                   (0.26862954, 0.26130258, 0.27577711))])


def pre_load_features(clip_model, loader):
    """Encode every batch in `loader` once and cache the L2-normalized image
    features together with their labels."""
    features, labels = [], []
    with jt.no_grad():
        for images, target in tqdm(loader):
            image_features = clip_model.encode_image(tfm_clip(images))
            image_features /= image_features.norm(dim=-1, keepdim=True)
            features.append(image_features.cpu())
            labels.append(target.cpu())
            del image_features
            jt.gc()
            jt.sync_all()
        features, labels = jt.cat(features), jt.cat(labels)
    return features, labels
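

# Usage sketch: combining the helpers above for zero-shot evaluation
# (`test_loader` is a hypothetical loader yielding (images, labels) batches
# preprocessed with `tfm_test_base`):
#   features, labels = pre_load_features(clip_model, test_loader)
#   logits = 100. * features @ clip_weights
#   print(cls_acc(logits, labels))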