Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Canjie-Luo authored Jan 9, 2019
1 parent 5f70d9e commit 7d68a27
Show file tree
Hide file tree
Showing 17 changed files with 1,355 additions and 0 deletions.
81 changes: 81 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# MORAN: A Multi-Object Rectified Attention Network for Scene Text Recognition

![](https://img.shields.io/badge/version-v2-orange.svg)

MORAN is a network with rectification mechanism for general scene text recognition. The paper in [arXiv]() version is available now.

![](demo/MORAN_v2.gif)

## Improvements of MORAN v2:

- More stable rectification network for one-stage training
- Replace VGG backbone by ResNet
- Use bidirectional decoder (a trick borrowed from [ASTER](https://github.com/bgshih/aster))

| <center>Dataset</center> | <center>IIIT5K</center> | <center>SVT</center> | <center>IC03</center> | <center>IC13</center> | <center>SVT-P</center> | <center>CUTE80</center> | <center>IC15 (1811)</center> | <center>IC15 (2077)</center> |
| :---: | :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:|
| MORAN v1 (two-stage training) | <center>91.2</center> | <center>**88.3**</center> | <center>**95.0**</center> | <center>92.4</center> | <center>76.1</center> | <center>77.4</center> | <center>74.7</center> | <center>68.8</center> |
| <center>MORAN v2 (one-stage training)</center> | <center>**93.4**</center> | <center>**88.3**</center> | <center>94.2</center> | <center>**93.2**</center> | <center>**79.7**</center> | <center>**81.9**</center> | <center>**77.8**</center> | <center>**73.9**</center> |

## Requirements

- [PyTorch](https://pytorch.org/) 0.3.*
- [TorchVision](https://pypi.org/project/torchvision/)
- [Python](https://www.python.org/) 2.7.*
- [OpenCV](https://opencv.org/) 2.4.*

Use [pip](https://pypi.org/project/pip/) to install the following libraries.

```bash
pip install -r requirements.txt
```

- [Colour](https://pypi.org/project/colour/)
- [LMDB](https://pypi.org/project/lmdb/)
- [matplotlib](https://pypi.org/project/matplotlib/)

## Data Preparation
Please convert your own dataset to lmdb format by using the [tool](https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py) provided by [@Baoguang Shi](https://github.com/bgshih). You can also download the training and testing datasets prepared by us. The raw pictures of testing datasets can be found [here](https://github.com/chengzhanzhan/STR).

- [about 20G training datasets and testing datasets](https://pan.baidu.com/s/1TqZfvoEhyv57yf4YBjSzFg), password: l8em

## Training and Testing

Modify the path to dataset folder in `train_MORAN.sh`:

```bash
--train_nips path_to_dataset \
--train_cvpr path_to_dataset \
--valroot path_to_dataset \
```

And start training: (manually decrease the learning rate for your task)

```bash
sh train_MORAN.sh
```

## Demo

Download the model parameter file from the link above and put the `demo.pth` into root folder. Then, execute the `demo.py` for more visualizations.

```bash
python demo.py
```

## Citation

```
@article{cluo2019moran,
author = {Canjie Luo, Lianwen Jin, Zenghui Sun},
title = {MORAN: A Multi-Object Rectified Attention network for Scene Text Recognition},
journal = {Pattern Recognition},
volume = {},
number = {},
pages = {},
year = {2019},
}
```

## Acknowledgment
The repo is developed based on [@Jieru Mei's](https://github.com/meijieru) [crnn.pytorch](https://github.com/meijieru/crnn.pytorch) and [@marvis'](https://github.com/marvis) [ocr_attention](https://github.com/marvis/ocr_attention). Thanks for your contribution.
65 changes: 65 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
from torch.autograd import Variable
import tools.utils as utils
import tools.dataset as dataset
from PIL import Image
from collections import OrderedDict
import cv2
from models.moran import MORAN

model_path = './demo.pth'
img_path = './demo/0.png'
alphabet = '0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$'

MORAN = MORAN(1, len(alphabet.split(':')), 256, 32, 100, BidirDecoder=True)

if torch.cuda.is_available():
MORAN = MORAN.cuda()

print('loading pretrained model from %s' % model_path)
state_dict = torch.load(model_path)
MORAN_state_dict_rename = OrderedDict()
for k, v in state_dict.items():
name = k.replace("module.", "") # remove `module.`
MORAN_state_dict_rename[name] = v
MORAN.load_state_dict(MORAN_state_dict_rename)

for p in MORAN.parameters():
p.requires_grad = False
MORAN.eval()

converter = utils.strLabelConverterForAttention(alphabet, ':')
transformer = dataset.resizeNormalize((100, 32))
image = Image.open(img_path).convert('L')
image = transformer(image)

if torch.cuda.is_available():
image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)
text = torch.LongTensor(1 * 5)
length = torch.IntTensor(1)
text = Variable(text)
length = Variable(length)

max_iter = 20
t, l = converter.encode('0'*max_iter)
utils.loadData(text, t)
utils.loadData(length, l)
output = MORAN(image, length, text, text, test=True, debug=True)

preds, preds_reverse = output[0]
demo = output[1]

_, preds = preds.max(1)
_, preds_reverse = preds_reverse.max(1)

sim_preds = converter.decode(preds.data, length.data)
sim_preds = sim_preds.strip().split('$')[0]
sim_preds_reverse = converter.decode(preds_reverse.data, length.data)
sim_preds_reverse = sim_preds_reverse.strip().split('$')[0]

print('\nResult:\n' + 'Left to Right: ' + sim_preds + '\nRight to Left: ' + sim_preds_reverse + '\n\n')

cv2.imshow("demo", demo)
cv2.waitKey()
Binary file added demo/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/MORAN_v2.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 7d68a27

Please sign in to comment.