forked from puke3615/SceneClassify
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_image_generator.py
68 lines (59 loc) · 1.86 KB
/
test_image_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# coding=utf-8
from generator import ImageDataGenerator
from im_utils import *
from config import *
"""
Test the image generator.
"""
def image_generator(train=True):
    """Build the ImageDataGenerator used by this script.

    Args:
        train: When truthy, every image is first passed through
            ``aug_images`` and then through ``scene_preprocess_input``;
            otherwise only ``scene_preprocess_input`` is applied.

    Returns:
        An ``ImageDataGenerator`` whose ``preprocessing_function`` applies
        the steps described above.
    """

    def wrap(value):
        # Used only by the disabled augmentation options below: gives back
        # `value` when float(train) is non-zero, otherwise that zero float.
        enabled = float(train)
        return value if enabled else enabled

    def preprocess(image):
        # aug_images is called with a one-element list and its first result
        # is used, so a single image goes in and a single image comes out.
        if train:
            image = aug_images([image])[0]
        return scene_preprocess_input(image)

    return ImageDataGenerator(
        # contrast_stretching=True, #####
        # histogram_equalization=False, #####
        # adaptive_equalization=False, #####
        # channel_shift_range=wrap(25.5),
        # rotation_range=wrap(15.),
        # width_shift_range=wrap(0.2),
        # height_shift_range=wrap(0.2),
        # shear_range=wrap(0.2),
        # zoom_range=wrap(0.2),
        # horizontal_flip=train,
        # preprocessing_function=im_utils.scene_preprocess_input,
        preprocessing_function=preprocess,
    )
def data_generator(path_image, train=True, num_classes=80,
                   target_size=(299, 299), batch_size=32,
                   save_to_dir='/Users/zijiao/Desktop/1'):
    """Create a directory-backed batch iterator over labelled images.

    The class sub-directories are expected to be named with zero-padded
    two-digit indices ('00', '01', ...), one per class.

    Args:
        path_image: Root directory handed to ``flow_from_directory``.
        train: Forwarded to ``image_generator``; also selects the
            ``save_prefix`` ('train' when truthy, 'val' otherwise).
        num_classes: Number of class sub-directories to enumerate.
            Defaults to 80, the value previously hard-coded.
        target_size: Target size forwarded to ``flow_from_directory``.
            Defaults to (299, 299).
        batch_size: Batch size forwarded to ``flow_from_directory``.
            Defaults to 32.
        save_to_dir: Forwarded unchanged as ``save_to_dir``.
            NOTE(review): the default is a machine-specific absolute path
            kept only for backward compatibility; callers on other machines
            should pass their own directory.

    Returns:
        Whatever ``ImageDataGenerator.flow_from_directory`` returns (the
        script iterates it to obtain ``(x, y)`` batches).
    """
    class_names = ['%02d' % index for index in range(num_classes)]
    save_prefix = 'train' if train else 'val'
    return image_generator(train).flow_from_directory(
        path_image,
        classes=class_names,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        crop_mode=None,
        save_prefix=save_prefix,
        save_to_dir=save_to_dir,
    )
# Smoke test: build the training generator and report the number of labels
# in its first batch.
generator = data_generator(PATH_TRAIN_IMAGES, train=True)
for i, (x, y) in enumerate(generator):
    if i == 0:
        print(len(y))
    else:
        # NOTE(review): by the time we stop, a second batch has already been
        # pulled from the generator; it is simply discarded.
        break
# from PIL import Image
# from skimage import exposure
# import numpy as np
# path = os.path.join(PATH_TRAIN_IMAGES, '00/0d8575935a771b6a64aa0bf769ae87453beefcbf.jpg')
# im = Image.open(path)
# # im.show()
#
# im = np.array(im)
# p2, p98 = np.percentile(im, (2, 98)) #####
# # im = exposure.rescale_intensity(im, in_range=(p2, p98)) #####
#
# # im = exposure.equalize_adapthist(im, clip_limit=0.03) #####
# # im *= 255
#
# # im = exposure.equalize_hist(im).astype(np.uint8) #####
# # im *= 255
#
# im = Image.fromarray(im.astype(np.uint8))
# im.show()