Showing 16 changed files with 1,906 additions and 1 deletion.
@@ -1,2 +1,54 @@
# Captcha

Automatic captcha recognition with a deep-learning residual network (ResNet) model, implemented in TFLearn.




After training, the model automatically recognizes captchas like the ones above.

> With 1,000 training images and only a few minutes of training, accuracy can reach 99%.
> For the detailed process, see:
>
> 验证码识别-TFLearn 版-单字符-简化-训练.ipynb
>
> 验证码识别-TFLearn 版-单字符-简化-使用.ipynb
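The network itself is defined in resnet.py, which does not appear in the rendered part of this commit. Purely as a hedged sketch of what a single-character ResNet classifier in TFLearn can look like (the layer widths, block counts, and optimizer settings below are assumptions, not the repository's actual code), compatible with the resnet.load(IMAGE_H, IMAGE_W, IMAGE_C, LABELS, model_file) call made in test.py:

```python
# Hypothetical sketch of a resnet.py-style module; not the repository's actual code.
import tflearn


def load(image_h, image_w, image_c, labels, model_file=None):
    # One single-character tile per sample.
    net = tflearn.input_data(shape=[None, image_h, image_w, image_c])
    net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
    # A small stack of residual blocks; the depth and widths here are guesses.
    net = tflearn.residual_block(net, 3, 16)
    net = tflearn.residual_block(net, 1, 32, downsample=True)
    net = tflearn.residual_block(net, 2, 32)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, 'relu')
    net = tflearn.global_avg_pool(net)
    # One softmax unit per possible character, e.g. 10 for LABELS = "0123456789".
    net = tflearn.fully_connected(net, len(labels), activation='softmax')
    # The training log below mentions the Momentum optimizer.
    opt = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
    net = tflearn.regression(net, optimizer=opt, loss='categorical_crossentropy')

    model = tflearn.DNN(net, tensorboard_verbose=0)
    if model_file is not None:
        model.load(model_file)  # restore previously trained weights
    return model
```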
# Usage

## Training

```
$ python3 train.py -i data -m model
Training Step: 640 | total loss: 0.30894 | time: 8.667s
| Momentum | epoch: 020 | loss: 0.30894 - acc: 0.9710 | val_loss: 0.15319 - val_acc: 0.9613 -- iter: 3196/3196
--
Successfully left training! Final model accuracy: 0.9710448384284973
save trained model to model/model.tfl
Training Duration 177.041 sec
```

> The first argument, -i, points to the directory of captcha images to learn from.
>
> The second argument, -m, is the directory where the trained model is written.
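train.py is likewise not rendered in this commit view. As an illustration only (every detail below is an assumption built on the data.load helper shown later in this commit and on the resnet sketch above), its core could look roughly like this:

```python
# Hypothetical sketch of a train.py-style script; not the repository's actual code.
import getopt
import sys

import data
import resnet


def main(argv):
    image_dir, model_dir = '', ''
    opts, _ = getopt.getopt(argv, "hi:m:", ["image=", "model="])
    for opt, arg in opts:
        if opt in ("-i", "--image"):
            image_dir = arg
        elif opt in ("-m", "--model"):
            model_dir = arg

    LABELS = "0123456789"
    # 28x28x3 single-character tiles, 4 characters per captcha.
    data_sets = data.load(image_dir, 28, 28, 4, LABELS)

    model = resnet.load(28, 28, 3, LABELS)  # build an untrained model
    model.fit(data_sets.train.images, data_sets.train.labels,
              n_epoch=20,  # the log above shows "epoch: 020"
              validation_set=(data_sets.test.images, data_sets.test.labels),
              show_metric=True)
    model.save(model_dir + '/model.tfl')  # the path reported in the log above


if __name__ == "__main__":
    main(sys.argv[1:])
```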
## Prediction

```
$ python3 test.py -i data -m model
data/0308_1.png 0308
data/1576_1.png 1576
data/8414_1.png 8414
data/0735_1.png 0135
data/9866_1.png 9866
```
Several of the changed files (images, generated notebooks, and larger diffs) are not rendered in this commit view.
data.py
@@ -0,0 +1,168 @@
# -*- coding: UTF-8 -*-

import os
import cv2
import numpy as np
import pandas as pd
import tqdm
from sklearn import model_selection


def read_image(filename, IMAGE_H, IMAGE_W, LABEL_LENGTH):
    image = cv2.imread(filename)  # reading a PNG this way yields only 3 channels
    # image = mpimg.imread(filename)
    h, w = image.shape[:2]
    image = image[0:h, 12:w - 6]  # crop away the left/right border of the captcha
    # Resize so the image spans LABEL_LENGTH tiles of IMAGE_W x IMAGE_H each.
    image = cv2.resize(image, (IMAGE_W * LABEL_LENGTH, IMAGE_H), interpolation=cv2.INTER_LINEAR)

    # Convert from [0, 255] -> [0.0, 1.0].
    image = image.astype(np.float32)
    image = image / 255.0

    return image


def split_image(image, IMAGE_H, IMAGE_W, LABEL_LENGTH):
    # Cut the resized captcha into LABEL_LENGTH equal-width single-character tiles.
    images = []
    h = image.shape[0]
    sw = IMAGE_W
    for i in range(LABEL_LENGTH):
        x = sw * i
        images.append(image[0:h, x:x + sw])

    return images


# Remove noise from the captcha (currently a no-op placeholder).
def remove_noise(image):
    return image


def load_data(path, IMAGE_H, IMAGE_W, LABEL_LENGTH, LABELS):
    # One-hot encode a single character, e.g. '3' -> [0,0,0,1,0,0,0,0,0,0] for LABELS = "0123456789".
    def char_to_vec(c):
        y = np.zeros((len(LABELS),))
        y[LABELS.index(c)] = 1.0
        return y

    labels = []
    images = []
    sfiles = []
    fnames = os.listdir(path)
    with tqdm.tqdm(total=len(fnames)) as pbar:
        for i, name in enumerate(fnames):
            if name.endswith(".jpg") or name.endswith(".jpeg") or name.endswith(".png"):
                image = read_image(os.path.join(path, name), IMAGE_H, IMAGE_W, LABEL_LENGTH)

                simgs = split_image(image, IMAGE_H, IMAGE_W, LABEL_LENGTH)
                # The label comes from the file name, e.g. "0308_1.png" -> "0308".
                label = name[:LABEL_LENGTH].upper()

                for k in range(LABEL_LENGTH):
                    labels.append(char_to_vec(label[k]))
                    images.append(remove_noise(simgs[k]))
                    sfiles.append(name)
            pbar.update(1)

    images = np.array(images)
    labels = np.array(labels)
    labels = labels.reshape((labels.shape[0], -1))
    sfiles = np.array(sfiles)

    return images, labels, sfiles


class DataSet(object):
    def __init__(self, images, labels):
        assert images.shape[0] == labels.shape[0], (
            "images.shape: %s labels.shape: %s" % (images.shape,
                                                   labels.shape))
        self._num_examples = images.shape[0]
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next `batch_size` examples from this data set."""
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]


class DataSets(object):
    pass


# Convert a one-hot encoded label back to its character.
def onehot2number(label, LABELS):
    return LABELS[np.argmax(label)]


# Balance the data set.
# NOTE: as written, this only copies the samples in DataFrame order;
# no per-class re-sampling is performed yet.
def balance(images, labels, sfiles, LABELS):
    a = []
    for i, label in enumerate(labels):
        label = onehot2number(label, LABELS)
        a.append([i, label])

    df = pd.DataFrame(a, columns=['i', 'label'])

    new_images = []
    new_labels = []
    new_sfiles = []
    for i in df['i']:
        new_images.append(images[i])
        new_labels.append(labels[i])
        new_sfiles.append(sfiles[i])
    images = np.array(new_images)
    labels = np.array(new_labels)
    sfiles = np.array(new_sfiles)

    return images, labels, sfiles


def make_data_sets(images, labels):
    # Hold out 20% of the single-character samples for validation.
    trainX, testX, trainY, testY = model_selection.train_test_split(images, labels, test_size=0.20, random_state=42)

    data_sets = DataSets()
    data_sets.train = DataSet(trainX, trainY)
    data_sets.test = DataSet(testX, testY)

    return data_sets


def load(image_dir, IMAGE_H, IMAGE_W, LABEL_LENGTH, LABELS):
    images, labels, sfiles = load_data(image_dir, IMAGE_H, IMAGE_W, LABEL_LENGTH, LABELS)
    images, labels, sfiles = balance(images, labels, sfiles, LABELS)
    data_sets = make_data_sets(images, labels)
    return data_sets
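For orientation, here is a hypothetical snippet (not part of the commit) showing how these helpers fit together; the directory name and sizes simply mirror the constants used in test.py below:

```python
# Hypothetical usage of data.py; not part of the repository.
import data

LABELS = "0123456789"  # character set of the captchas
data_sets = data.load("data", 28, 28, 4, LABELS)

# Each batch holds 28x28x3 single-character tiles and their one-hot labels.
batch_images, batch_labels = data_sets.train.next_batch(64)
print(batch_images.shape, batch_labels.shape)  # e.g. (64, 28, 28, 3) (64, 10)
```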
test.py
@@ -0,0 +1,63 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import getopt
import os
import sys

import data
import resnet


def main(argv):
    image_dir = ''
    model_dir = ''
    try:
        opts, args = getopt.getopt(argv, "hi:m:", ["image=", "model="])
    except getopt.GetoptError:
        print('test.py -i <image_dir> -m <model_dir>')
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            print('test.py -i <image_dir> -m <model_dir>')
            sys.exit()
        elif opt in ("-i", "--image"):
            image_dir = arg
        elif opt in ("-m", "--model"):
            model_dir = arg

    LABEL_LENGTH = 4  # number of characters in each captcha
    # LABELS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    LABELS = "0123456789"  # character set the captchas are drawn from

    IMAGE_H = 28  # height of a single-character tile after resizing
    IMAGE_W = 28  # width of a single-character tile after resizing
    IMAGE_C = 3   # number of image channels

    model_file = model_dir + '/model.tfl'

    # Load the trained model
    model = resnet.load(IMAGE_H, IMAGE_W, IMAGE_C, LABELS, model_file)

    # Predict the label of one captcha image
    def predict(filename):
        image = data.read_image(filename, IMAGE_H, IMAGE_W, LABEL_LENGTH)
        x_data = data.split_image(image, IMAGE_H, IMAGE_W, LABEL_LENGTH)
        y_preds = model.predict(x_data)

        label = ''
        for y_pred in y_preds:
            label = label + data.onehot2number(y_pred, LABELS)

        return label

    import random
    files = os.listdir(image_dir)

    # Try five randomly chosen images from the directory.
    for i in range(5):
        filename = os.path.join(image_dir, random.choice(files))
        label = predict(filename)
        print(filename, label)


if __name__ == "__main__":
    main(sys.argv[1:])