Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
SixQuant committed Jan 16, 2019
1 parent 062fc71 commit df7b090
Show file tree
Hide file tree
Showing 16 changed files with 1,906 additions and 1 deletion.
54 changes: 53 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,54 @@
# captcha
# Captcha

Recognize captcha using deep learning ResNet model and TFLearn

用深度学习残差网络(ResNet)模型实现的验证码自动识别,TFLearn 代码实现

![2565_1](assets/2565_1.png)

![3426_1](assets/3426_1.png)

![6071_1](assets/6071_1.png)

经过学习后自动识别类似上面这样的验证码

> 1000个训练数据,经过短短几分钟的训练,正确率可以达到 99%
详细过程请参考

> 验证码识别-TFLearn 版-单字符-简化-训练.ipynb
>
> 验证码识别-TFLearn 版-单字符-简化-使用.ipynb


# 使用

## 训练

```python
$ python3 train.py -i data -m model
Training Step: 640 | total loss: 0.30894 | time: 8.667s
| Momentum | epoch: 020 | loss: 0.30894 - acc: 0.9710 | val_loss: 0.15319 - val_acc: 0.9613 -- iter: 3196/3196
--
Successfully left training! Final model accuracy: 0.9710448384284973
save trained model to model/model.tfl
Training Duration 177.041 sec
```

> 第一个参数 -i 指向需要学习的验证码目录
>
> 第二个参数 -m 为学习完成后输出的模型
## 预测

```python
$ python3 test.py -i data -m model
data/0308_1.png 0308
data/1576_1.png 1576
data/8414_1.png 8414
data/0735_1.png 0135
data/9866_1.png 9866
```

##
Binary file added assets/2565_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/3426_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/6071_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data.zip
Binary file not shown.
15 changes: 15 additions & 0 deletions src/.idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions src/.idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions src/.idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions src/.idea/src.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

303 changes: 303 additions & 0 deletions src/.idea/workspace.xml

Large diffs are not rendered by default.

168 changes: 168 additions & 0 deletions src/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# -*- coding: UTF-8 -*-

import os
import cv2
import numpy as np
import pandas as pd
import tqdm
from sklearn import model_selection


def read_image(filename, IMAGE_H, IMAGE_W, LABEL_LENGTH):
image = cv2.imread(filename) # 读 PNG 只会读到3个通道
# image = mpimg.imread(filename)
h, w = image.shape[:2]
image = image[0:h, 12:w - 6]
image = cv2.resize(image, (IMAGE_W * LABEL_LENGTH, IMAGE_H), cv2.INTER_LINEAR) # 缩放大小

# Convert from [0, 255] -> [0.0, 1.0].
image = image.astype(np.float32)
image = image / 255.0

return image


def split_image(image, IMAGE_H, IMAGE_W, LABEL_LENGTH):
images = []
h = image.shape[0]
sw = IMAGE_W
for i in range(LABEL_LENGTH):
x = sw * i
images.append(image[0:h, x:x + sw])

return images


# 验证码去燥
def remove_noise(image):
return image


def load_data(path, IMAGE_H, IMAGE_W, LABEL_LENGTH, LABELS):
# OneHot
def char_to_vec(c):
y = np.zeros((len(LABELS),))
y[LABELS.index(c)] = 1.0
return y

labels = []
images = []
sfiles = []
fnames = os.listdir(path)
with tqdm.tqdm(total=len(fnames)) as pbar:
for i, name in enumerate(fnames):
if name.endswith(".jpg") or name.endswith(".jpeg") or name.endswith(".png"):
image = read_image(os.path.join(path, name), IMAGE_H, IMAGE_W, LABEL_LENGTH)

simgs = split_image(image, IMAGE_H, IMAGE_W, LABEL_LENGTH)
label = name[:LABEL_LENGTH].upper()

for k in range(LABEL_LENGTH):
labels.append(char_to_vec(label[k]))
images.append(remove_noise(simgs[k]))
sfiles.append(name)
pbar.update(1)

images = np.array(images)
labels = np.array(labels)
labels = labels.reshape((labels.shape[0], -1))
sfiles = np.array(sfiles)

return images, labels, sfiles


#
class DataSet(object):
def __init__(self, images, labels):
assert images.shape[0] == labels.shape[0], (
"images.shape: %s labels.shape: %s" % (images.shape,
labels.shape))
self._num_examples = images.shape[0]
self._images = images
self._labels = labels
self._epochs_completed = 0
self._index_in_epoch = 0

@property
def images(self):
return self._images

@property
def labels(self):
return self._labels

@property
def num_examples(self):
return self._num_examples

@property
def epochs_completed(self):
return self._epochs_completed

def next_batch(self, batch_size):
"""Return the next `batch_size` examples from this data set."""
start = self._index_in_epoch
self._index_in_epoch += batch_size
if self._index_in_epoch > self._num_examples:
# Finished epoch
self._epochs_completed += 1
# Shuffle the data
perm = np.arange(self._num_examples)
np.random.shuffle(perm)
self._images = self._images[perm]
self._labels = self._labels[perm]
# Start next epoch
start = 0
self._index_in_epoch = batch_size
assert batch_size <= self._num_examples
end = self._index_in_epoch
return self._images[start:end], self._labels[start:end]


class DataSets(object):
pass


# Onehot 编码转换回字符串
def onehot2number(label, LABELS):
return LABELS[np.argmax(label)]


# 进行数据平衡
def balance(images, labels, sfiles, LABELS):
a = []
for i, label in enumerate(labels):
label = onehot2number(label, LABELS)
a.append([i, label])

df = pd.DataFrame(a, columns=['i', 'label'])

new_images = []
new_labels = []
new_sfiles = []
for i in df['i']:
new_images.append(images[i])
new_labels.append(labels[i])
new_sfiles.append(sfiles[i])
images = np.array(new_images)
labels = np.array(new_labels)
sfiles = np.array(new_sfiles)

return images, labels, sfiles


def make_data_sets(images, labels):
trainX, testX, trainY, testY = model_selection.train_test_split(images, labels, test_size=0.20, random_state=42)

data_sets = DataSets()
data_sets.train = DataSet(trainX, trainY)
data_sets.test = DataSet(testX, testY)

return data_sets


def load(image_dir, IMAGE_H, IMAGE_W, LABEL_LENGTH, LABELS):
images, labels, sfiles = load_data(image_dir, IMAGE_H, IMAGE_W, LABEL_LENGTH, LABELS)
images, labels, sfiles = balance(images, labels, sfiles, LABELS)
data_sets = make_data_sets(images, labels)
return data_sets
63 changes: 63 additions & 0 deletions src/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import getopt
import os
import sys

import data
import resnet


def main(argv):
image_dir = ''
model_dir = ''
try:
opts, args = getopt.getopt(argv, "hi:m:", ["image=", "model="])
except getopt.GetoptError:
print('test.py -i <image_dir> -m <model_dir>')
sys.exit(2)
for opt, arg in opts:
if opt == '-h':
print('test.py -i <image_dir> -m <model_dir>')
sys.exit()
elif opt in ("-i", "--image"):
image_dir = arg
elif opt in ("-m", "--model"):
model_dir = arg

LABEL_LENGTH = 4 # 验证码字符数
# LABELS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
LABELS = "0123456789" # 验证码字符组成

IMAGE_H = 28 # 缩放后单个字符图片大小
IMAGE_W = 28 # 缩放后单个字符图片大小
IMAGE_C = 3 # 图片通道数

model_file = model_dir + '/model.tfl'

# 加载模型
model = resnet.load(IMAGE_H, IMAGE_W, IMAGE_C, LABELS, model_file)

# 预测
def predict(filename):
image = data.read_image(filename, IMAGE_H, IMAGE_W, LABEL_LENGTH)
x_data = data.split_image(image, IMAGE_H, IMAGE_W, LABEL_LENGTH)
y_preds = model.predict(x_data)

label = ''
for y_pred in y_preds:
label = label + data.onehot2number(y_pred, LABELS)

return label

import random
files = os.listdir(image_dir)

for i in range(5):
filename = os.path.join(image_dir, files[random.randint(0, len(files))])
label = predict(filename)
print(filename, label)


if __name__ == "__main__":
main(sys.argv[1:])
Loading

0 comments on commit df7b090

Please sign in to comment.