diff --git a/Figure_1.png b/Figure_1.png
new file mode 100644
index 0000000..4021f38
Binary files /dev/null and b/Figure_1.png differ
diff --git a/README.md b/README.md
index 1a45a35..b2b5ac9 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,109 @@
-# RCNN
-This is a Implementation of R-CNN
+# Implementation of R-CNN
+
+This is an implementation of R-CNN (Regions with Convolutional Neural Network Features).
+
+## Prerequisites
+- Python 3.5
+- PyTorch 0.4.1
+- scikit-learn 0.18.1
+- OpenCV (cv2) 3.4.0
+
+You can run the code on Windows or Linux, with either CPU or GPU.
+
+## Dataset
+
+For simplicity, rather than using the PASCAL VOC 2012 dataset, I chose the small 17flowers dataset. It can be downloaded from my Google Drive:
+
+https://drive.google.com/open?id=1qAkLZ-Wa8mca2xz-RZI3XW1Syjk2JfOB
+
+## Structure
+
+The project is structured as follows:
+
+```
+├── checkpoints/
+├── data/
+|   ├── dataset_factory.py
+|   ├── datasets.py
+├── generate/
+├── loss/
+|   ├── losses.py
+├── models/
+|   ├── model_factory.py
+|   ├── models.py
+├── networks/
+|   ├── network_factory.py
+|   ├── networks.py
+├── options/
+|   ├── base_options.py
+|   ├── test_options.py
+|   ├── train_options.py
+├── sample_dataset/
+|   ├── 2flowers
+|   ├── 17flowers
+|   ├── fine_tune_list.txt
+|   ├── train_list.txt
+├── utils/
+|   ├── selectivesearch.py
+|   ├── util.py
+├── evaluate.py
+├── train_step1.py
+├── train_step2.py
+├── train_step3.py
+├── train_step4.py
+```
+
+## Getting started
+
+### Supervised Training
+
+In this step, we take the pre-trained AlexNet shipped with PyTorch and train it on the 17flowers dataset.
+
+```
+$ python train_step1.py
+```
+
+You can run it directly with the default parameters.
+
+### Finetune
+
+In this step, we finetune the model on the 2flowers dataset.
+
+```
+$ python train_step2.py --batch_size 128 --load_alex_epoch 100 --options_dir finetune
+```
+
+Here we load the AlexNet trained for 100 epochs in the first step. `--options_dir` decides where the parameters are saved.
+
+### Train SVM
+
+In this step, we use features extracted by the finetuned network to train SVMs. As stated in the paper, class-specific SVMs are trained in this step, so there are two SVMs here, one per flower class.
+
+```
+$ python train_step3.py --load_finetune_epoch 100 --options_dir svm
+```
+
+Here we adopt the AlexNet finetuned for 100 epochs in the last step as the feature extractor.
+
+### Box Regression
+
+In this step, we train a regression network to refine the bounding boxes.
+
+```
+$ python train_step4.py --decay_rate 0.5 --options_dir regression --batch_size 512
+```
+
+### Evaluate
+
+```
+$ python evaluate.py --load_finetune_epoch 100 --load_reg_epoch 40 --img_path ./sample_dataset/2flowers/jpg/1/image_1281.jpg
+```
+
+![](https://github.com/bigbrother33/Deep-Learning/blob/master/photo/1.PNG)
+
+## References
+
+- Selective-search: https://github.com/AlpacaDB/selectivesearch
+- R-CNN with Tensorflow: https://github.com/bigbrother33/Deep-Learning
+
+Much of the code related to image proposals is copied directly from "R-CNN with Tensorflow", and that project explains R-CNN in more detail (in Chinese).
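+
+## Appendix: box parameterization
+
+For reference, the box regressor in step 4 learns the standard R-CNN offsets between a proposal box P and a ground-truth box G. Below is a minimal NumPy sketch of the encode/decode pair (the in-repo encoding lives in `data/datasets.py` and the decoding in `evaluate.py`; boxes here are assumed to be (center_x, center_y, w, h), and the function names are illustrative):
+
+```
+import numpy as np
+
+def encode_box(p, g):
+    # targets: center shifts normalized by proposal size, log-scale size ratios
+    px, py, pw, ph = p
+    gx, gy, gw, gh = g
+    return np.array([(gx - px) / pw, (gy - py) / ph,
+                     np.log(gw / pw), np.log(gh / ph)])
+
+def decode_box(p, t):
+    # inverse transform, applied to the network output at test time
+    px, py, pw, ph = p
+    return np.array([t[0] * pw + px, t[1] * ph + py,
+                     pw * np.exp(t[2]), ph * np.exp(t[3])])
+```
+
+`decode_box(p, encode_box(p, g))` recovers `g`, which is why the same parameterization appears on both the training and evaluation side.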
diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/dataset_factory.py b/data/dataset_factory.py new file mode 100644 index 0000000..62c07ee --- /dev/null +++ b/data/dataset_factory.py @@ -0,0 +1,39 @@ +import os + +class DatasetFactory: + def __init__(self): + pass + + @staticmethod + def get_by_name(dataset_name, opt): + if dataset_name == 'AlexnetDataset': + from data.datasets import AlexnetDataset + dataset = AlexnetDataset(opt) + elif dataset_name == 'FinetuneDataset': + from data.datasets import FinetuneDataset + dataset = FinetuneDataset(opt) + elif dataset_name == 'SVMDataset': + from data.datasets import SVMDataset + dataset = SVMDataset(opt) + elif dataset_name == 'RegDataset': + from data.datasets import RegDataset + dataset = RegDataset(opt) + else: + raise ValueError("Dataset [%s] not recognized." % dataset_name) + + print('Dataset {} was created'.format(dataset.name)) + return dataset + +class DatasetBase: + def __init__(self, opt): + self._name = 'BaseDataset' + self._root = opt.data_dir + self._opt = opt + + @property + def name(self): + return self._name + + @property + def path(self): + return self._root diff --git a/data/datasets.py b/data/datasets.py new file mode 100644 index 0000000..70b7c5b --- /dev/null +++ b/data/datasets.py @@ -0,0 +1,287 @@ +from data.dataset_factory import DatasetBase +from utils.util import resize_image +from utils.util import clip_pic +from utils.util import IOU +from utils.util import if_intersection +from utils.util import image_proposal +from models.model_factory import ModelsFactory +from sklearn.externals import joblib +import numpy as np +import codecs +import cv2 +import os +import torch + +class AlexnetDataset(DatasetBase): + + def __init__(self, opt): + super(AlexnetDataset, self).__init__(opt) + self._name = 'AlexnetDataset' + + self.datas = [] + + self._img_size = self._opt.image_size + self._batch_size = self._opt.batch_size + + self.epoch = 0 + self.cursor = 0 + + # read dataset + self._load_dataset() + + def _load_dataset(self): + file_path = os.path.join(self._root, self._opt.train_list) + with codecs.open(file_path, 'r', 'utf-8') as fr: + lines = fr.readlines() + for ind, line in enumerate(lines): + context = line.strip().split() + image_path = os.path.join(self._root, context[0]) + label = int(context[1]) + + img = cv2.imread(image_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = resize_image(img, self._img_size, self._img_size) + img = np.asarray(img, dtype='float32') + + self.datas.append([img, label]) + + def get_batch(self): + images = np.zeros((self._batch_size, self._img_size, self._img_size, 3)) + labels = np.zeros((self._batch_size, 1)) + count = 0 + while( count < self._batch_size): + images[count] = self.datas[self.cursor][0] + labels[count] = self.datas[self.cursor][1] + count += 1 + self.cursor += 1 + if self.cursor >= len(self.datas) : + self.cursor = 0 + self.epoch += 1 + np.random.shuffle(self.datas) + return images, labels + + def __len__(self): + return len(self.datas) + + +class FinetuneDataset(DatasetBase): + + def __init__(self, opt): + super(FinetuneDataset, self).__init__(opt) + self._name = 'FinetuneDataset' + + self.datas = [] + self.save_path = os.path.join(self._opt.generate_save_path, 'fineturn_data.npy') + self.pos_datas = [] + self.neg_datas = [] + + self._batch_size = self._opt.batch_size + self._pos_counts = int(self._batch_size / 4) + self._neg_counts = int(self._batch_size / 4 * 3) + + self._pos_cursor = 0 + self._neg_cursor = 
0 + + self._img_size = self._opt.image_size + self._threshold = self._opt.finetune_threshold + + self.epoch = 0 + self.cursor = 0 + + # read dataset + if os.path.exists(self.save_path): + self._load_from_numpy() + else: + self._load_dataset() + + def _load_dataset(self): + file_path = os.path.join(self._root, self._opt.finetune_list) + with codecs.open(file_path, 'r', 'utf-8') as fr: + lines = fr.readlines() + for ind, line in enumerate(lines): + context = line.strip().split() + image_path = os.path.join(self._root, context[0]) + index = int(context[1]) + ref_rect = context[2].split(',') + ground_truth = [int(i) for i in ref_rect] + + images, vertices, _ = image_proposal(image_path, self._img_size) + for img_float, proposal_vertice in zip(images, vertices): + iou_val = IOU(ground_truth, proposal_vertice) + if iou_val < self._threshold: + label = 0 + self.neg_datas.append([img_float, label]) + else: + label = index + self.pos_datas.append([img_float, label]) + self.datas.append([img_float, label]) + joblib.dump(self.datas, self.save_path) + + + def _load_from_numpy(self): + self.datas = joblib.load(self.save_path) + for item in self.datas: + if item[1] == 0: + self.neg_datas.append([item[0], item[1]]) + else: + self.pos_datas.append([item[0], item[1]]) + + def get_batch(self): + images = np.zeros((self._batch_size, self._img_size, self._img_size, 3)) + labels = np.zeros((self._batch_size, 1)) + + count = 0 + while (count < self._pos_counts): + images[count] = self.pos_datas[self._pos_cursor][0] + labels[count] = self.pos_datas[self._pos_cursor][1] + count += 1 + self._pos_cursor += 1 + if self._pos_cursor >= len(self.pos_datas): + self._pos_cursor = 0 + np.random.shuffle(self.pos_datas) + while (count < self._pos_counts + self._neg_counts): + images[count] = self.neg_datas[self._neg_cursor][0] + labels[count] = self.neg_datas[self._neg_cursor][1] + count += 1 + self._neg_cursor += 1 + if self._neg_cursor >= len(self.neg_datas): + self._neg_cursor = 0 + np.random.shuffle(self.neg_datas) + """ + state = np.random.get_state() + np.random.shuffle(images) + np.random.set_state(state) + np.random.shuffle(labels) + """ + return images, labels + + + def __len__(self): + return len(self.datas) + +class SVMDataset(DatasetBase): + + def __init__(self, opt): + super(SVMDataset, self).__init__(opt) + self._name = 'SVMDataset' + + self.datas = [] + self.classA_features = [] + self.classA_labels = [] + self.classB_features = [] + self.classB_labels = [] + self.save_path = os.path.join(self._opt.generate_save_path, 'svm_data.npy') + + self._img_size = self._opt.image_size + + self.cursor = 0 + + self.model = ModelsFactory.get_by_name('FineModel', self._opt, is_train=False) + + # read dataset + if os.path.exists(self.save_path): + self._load_from_numpy() + else: + self._load_dataset() + + def _load_dataset(self): + file_path = os.path.join(self._root, self._opt.finetune_list) + with codecs.open(file_path, 'r', 'utf-8') as fr: + lines = fr.readlines() + for ind, line in enumerate(lines): + + context = line.strip().split() + image_path = os.path.join(self._root, context[0]) + index = int(context[1]) + ref_rect = context[2].split(',') + ground_truth = [int(i) for i in ref_rect] + + images, vertices, _ = image_proposal(image_path, self._img_size) + for img_float, proposal_vertice in zip(images, vertices): + iou_val = IOU(ground_truth, proposal_vertice) + if iou_val < self._opt.svm_threshold: + label = 0 + else: + label = index + + px = float(proposal_vertice[0]) + float(proposal_vertice[4] / 2.0) + py = 
float(proposal_vertice[1]) + float(proposal_vertice[5] / 2.0) + ph = float(proposal_vertice[5]) + pw = float(proposal_vertice[4]) + + gx = float(ref_rect[0]) + gy = float(ref_rect[1]) + gw = float(ref_rect[2]) + gh = float(ref_rect[3]) + + box_label = np.zeros(5) + box_label[1:5] = [(gx - px) / pw, (gy - py) / ph, np.log(gw / pw), np.log(gh / ph)] + if iou_val < self._opt.reg_threshold: + box_label[0] = 0 + else: + box_label[0] = 1 + + input_data=torch.Tensor([img_float]).permute(0,3,1,2) + + feature, final = self.model._forward_test(input_data) + feature = feature.data.cpu().numpy() + + self.datas.append([index, feature, label, box_label]) + joblib.dump(self.datas, self.save_path) + + def _load_from_numpy(self): + self.datas = joblib.load(self.save_path) + for item in self.datas: + if item[0] == 1: + self.classA_features.append(item[1][0]) + self.classA_labels.append(item[2]) + elif item[0] == 2: + self.classB_features.append(item[1][0]) + self.classB_labels.append(item[2]) + + def get_datas(self): + return self.classA_features, self.classA_labels, self.classB_features, self.classB_labels + + def __len__(self): + return len(self.datas) + + +class RegDataset(DatasetBase): + + def __init__(self, opt): + super(RegDataset, self).__init__(opt) + self._name = 'RegDataset' + + self.datas = [] + self._batch_size = self._opt.batch_size + self.cursor = 0 + + self._load_dataset() + + def _load_dataset(self): + file_path = os.path.join(self._opt.generate_save_path, 'svm_data.npy') + self.datas = joblib.load(file_path) + + def get_batch(self): + images = np.zeros((self._batch_size, 4096)) + labels = np.zeros((self._batch_size, 5)) + + count = 0 + while( count < self._batch_size): + images[count] = self.datas[self.cursor][1] + labels[count] = self.datas[self.cursor][3] + count += 1 + self.cursor += 1 + if self.cursor >= len(self.datas): + self.cursor = 0 + np.random.shuffle(self.datas) + return images, labels + + def __len__(self): + return len(self.datas) + + + + + + diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..a9c95c9 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,87 @@ +from __future__ import division +from models.model_factory import ModelsFactory +from options.test_options import TestOptions +from utils.util import image_proposal +from utils.util import show_rect +import torch +import numpy as np + +class Test: + def __init__(self): + self._opt = TestOptions().parse() + self._img_path = self._opt.img_path + self._img_size = self._opt.image_size + + self.fine_model = ModelsFactory.get_by_name('FineModel', self._opt, is_train=False) + self.svm_model_A = ModelsFactory.get_by_name('SvmModel', self._opt, is_train=False) + self.svm_model_A.load('A') + self.svm_model_B = ModelsFactory.get_by_name('SvmModel', self._opt, is_train=False) + self.svm_model_B.load('B') + self.svms = [self.svm_model_A, self.svm_model_B] + self.reg_model = ModelsFactory.get_by_name('RegModel', self._opt, is_train=False) + + self.test() + + def test(self): + imgs, _, rects = image_proposal(self._img_path, self._img_size) + + show_rect(self._img_path, rects, ' ') + + input_data=torch.Tensor(imgs).permute(0,3,1,2) + features, _ = self.fine_model._forward_test(input_data) + features = features.data.cpu().numpy() + + results = [] + results_old = [] + results_label = [] + count = 0 + + flower = {1:'pancy', 2:'Tulip'} + + for f in features: + for svm in self.svms: + pred = svm.predict([f.tolist()]) + # not background + if pred[0] != 0: + results_old.append(rects[count]) + input_data=torch.Tensor(f) + box = 
self.reg_model._forward_test(input_data) + box = box.data.cpu().numpy() + if box[0] > 0.3: + px, py, pw, ph = rects[count][0], rects[count][1], rects[count][2], rects[count][3] + old_center_x, old_center_y = px + pw / 2.0, py + ph / 2.0 + x_ping, y_ping, w_suo, h_suo = box[1], box[2], box[3], box[4], + new__center_x = x_ping * pw + old_center_x + new__center_y = y_ping * ph + old_center_y + new_w = pw * np.exp(w_suo) + new_h = ph * np.exp(h_suo) + new_verts = [new__center_x, new__center_y, new_w, new_h] + results.append(new_verts) + results_label.append(pred[0]) + count += 1 + + average_center_x, average_center_y, average_w,average_h = 0, 0, 0, 0 + #use average values to represent the final result + for vert in results: + average_center_x += vert[0] + average_center_y += vert[1] + average_w += vert[2] + average_h += vert[3] + average_center_x = average_center_x / len(results) + average_center_y = average_center_y / len(results) + average_w = average_w / len(results) + average_h = average_h / len(results) + average_result = [[average_center_x, average_center_y, average_w, average_h]] + result_label = max(results_label, key=results_label.count) + show_rect(self._img_path, results_old, ' ') + show_rect(self._img_path, average_result, flower[result_label]) + + +if __name__ == "__main__": + Test() + + + + + + diff --git a/loss/__init__.py b/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/losses.py b/loss/losses.py new file mode 100644 index 0000000..ddbcd89 --- /dev/null +++ b/loss/losses.py @@ -0,0 +1,15 @@ +import torch.nn as nn +import torch + +class Regloss(nn.Module): + def __init__(self): + super(Regloss, self).__init__() + + def forward(self, y_true, y_pred): + no_object_loss = torch.pow((1 - y_true[:, 0]) * y_pred[:, 0],2).mean() + object_loss = torch.pow((y_true[:, 0]) * (y_pred[:, 0] - 1),2).mean() + + reg_loss = (y_true[:, 0] * (torch.pow(y_true[:, 1:5] - y_pred[:, 1:5],2).sum(1))).mean() + + loss = no_object_loss + object_loss + reg_loss + return loss diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/model_factory.py b/models/model_factory.py new file mode 100644 index 0000000..ca5e4fd --- /dev/null +++ b/models/model_factory.py @@ -0,0 +1,119 @@ +import os +import torch +from torch.optim import lr_scheduler + +class ModelsFactory: + def __init__(self): + pass + + @staticmethod + def get_by_name(model_name, *args, **kwargs): + model = None + + if model_name == 'AlexModel': + from .models import AlexModel + model = AlexModel(*args, **kwargs) + elif model_name == 'FineModel': + from .models import FineModel + model = FineModel(*args, **kwargs) + elif model_name == 'SvmModel': + from .models import SvmModel + model = SvmModel(*args, **kwargs) + elif model_name == 'RegModel': + from .models import RegModel + model = RegModel(*args, **kwargs) + else: + raise ValueError("Model %s not recognized." 
% model_name) + + print("Model %s was created" % model.name) + return model + + +class BaseModel(object): + + def __init__(self, opt, is_train): + self._name = 'BaseModel' + + self._opt = opt + self._gpu_ids = opt.gpu_ids + self._is_train = is_train + + self._Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor + self._LongTensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor + self._FloatTensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor + self._root_dir = opt.checkpoints_dir + + + @property + def name(self): + return self._name + + @property + def is_train(self): + return self._is_train + + def set_input(self, input): + assert False, "set_input not implemented" + + def set_train(self): + assert False, "set_train not implemented" + + def set_eval(self): + assert False, "set_eval not implemented" + + def forward(self, keep_data_for_visuals=False): + assert False, "forward not implemented" + + # used in test time, no backprop + def test(self): + assert False, "test not implemented" + + def get_image_paths(self): + return {} + + def optimize_parameters(self): + assert False, "optimize_parameters not implemented" + + def get_current_visuals(self): + return {} + + def get_current_errors(self): + return {} + + def get_current_scalars(self): + return {} + + def save(self, label): + assert False, "save not implemented" + + def load(self): + assert False, "load not implemented" + + def _save_network(self, network, network_label, epoch_label, save_dir): + save_filename = 'net_epoch_%s_id_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self._root_dir, save_dir, save_filename) + if len(self._gpu_ids) > 1: + torch.save(network.module.state_dict(), save_path) + else: + torch.save(network.state_dict(), save_path) + print ('saved net: %s' % save_path) + + def _load_network(self, network, network_label, epoch_label, load_dir): + load_filename = 'net_epoch_%s_id_%s.pth' % (epoch_label, network_label) + load_path = os.path.join(self._root_dir, load_dir, load_filename) + assert os.path.exists( + load_path), 'Weights file not found. Have you trained a model!? 
We are not providing one' % load_path + + state_dict = torch.load(load_path) + if len(self._gpu_ids) > 1: + network.module.load_state_dict(torch.load(load_path)) + else: + network.load_state_dict(torch.load(load_path)) + print ('loaded net: %s' % load_path) + + def print_network(self, network): + num_params = 0 + for param in network.parameters(): + num_params += param.numel() + print(network) + print('Total number of parameters: %d' % num_params) diff --git a/models/models.py b/models/models.py new file mode 100644 index 0000000..2e81296 --- /dev/null +++ b/models/models.py @@ -0,0 +1,353 @@ +from .model_factory import BaseModel +from networks.network_factory import NetworksFactory +from torch.autograd import Variable +from sklearn import svm +from collections import OrderedDict +import torch +import os +from sklearn.externals import joblib +from loss.losses import Regloss + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +model_urls = { + 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', +} + +class AlexModel(BaseModel): + def __init__(self, opt, is_train): + super(AlexModel, self).__init__(opt, is_train) + self._name = 'AlexModel' + + # create networks + self._init_create_networks() + + # init train variables + if self._is_train: + self._init_train_vars() + + # use pre-trained AlexNet + if self._is_train and not self._opt.load_alex_epoch > 0: + self._init_weights() + + # load networks and optimizers + if not self._is_train or self._opt.load_alex_epoch > 0: + self.load() + + if not self._is_train: + self.set_eval() + + # init loss + self._init_losses() + + def _init_create_networks(self): + network_type = 'AlexNet' + self.network = self._create_branch(network_type) + if len(self._gpu_ids) > 1: + self.network = torch.nn.DataParallel(self.network, device_ids=self._gpu_ids) + if torch.cuda.is_available(): + self.network.cuda() + + def _init_train_vars(self): + self._current_lr = self._opt.learning_rate + self._decay_rate = self._opt.decay_rate + # initialize optimizers + self._optimizer = torch.optim.SGD(self.network.parameters(), + lr=self._current_lr, + momentum=self._decay_rate) + + def _init_weights(self): + state_dict=load_state_dict_from_url(model_urls['alexnet'], progress=True) + current_state = self.network.state_dict() + keys = list(state_dict.keys()) + for key in keys: + if key.startswith('features'): + current_state[key] = state_dict[key] + current_state['fn8.weight'] = state_dict['classifier.1.weight'] + current_state['fn8.bias'] = state_dict['classifier.1.bias'] + current_state['fn9.weight'] = state_dict['classifier.4.weight'] + current_state['fn9.bias'] = state_dict['classifier.4.bias'] + + def load(self): + load_epoch = self._opt.load_alex_epoch + self._load_network(self.network, 'AlexNet', load_epoch, self._opt.alex_dir) + + def save(self, label): + # save networks + self._save_network(self.network, 'AlexNet', label, self._opt.alex_dir) + + def _create_branch(self, branch_name): + return NetworksFactory.get_by_name(branch_name, self._opt.alex_classes) + + def _init_losses(self): + # define loss function + self._cross_entropy = torch.nn.CrossEntropyLoss() + if torch.cuda.is_available(): + self._cross_entropy = self._cross_entropy.cuda() + + def set_input(self, inputs, labels): + self.batch_inputs = self._Tensor(inputs).permute(0,3,1,2) + self.labels = Variable(self._LongTensor(labels).squeeze_()) + + def optimize_parameters(self): + if 
self._is_train: + + loss = self._forward() + self._optimizer.zero_grad() + loss.backward() + self._optimizer.step() + + def _forward(self): + feature, final = self.network(self.batch_inputs) + self._loss = self._cross_entropy(final, self.labels) + return self._loss + + def _forward_test(self, input): + feature, final = self.network(input) + return feature, final + + def get_current_errors(self): + loss_dict = OrderedDict([('loss_entropy', self._loss.data[0]), + ]) + return loss_dict + + def set_train(self): + self.network.train() + self._is_train = True + + def set_eval(self): + self.network.eval() + self._is_train = False + + +class FineModel(BaseModel): + def __init__(self, opt, is_train): + super(FineModel, self).__init__(opt, is_train) + self._name = 'FineModel' + + # create networks + self._init_create_networks() + + # init train variables + if self._is_train: + self._init_train_vars() + + # use pre-trained AlexNet + if self._is_train and not self._opt.load_finetune_epoch > 0: + self._init_weights() + + # load networks and optimizers + if not self._is_train or self._opt.load_finetune_epoch > 0: + self.load() + + if not self._is_train: + self.set_eval() + + # init loss + self._init_losses() + + def _init_create_networks(self): + network_type = 'AlexNet' + self.network = self._create_branch(network_type) + if len(self._gpu_ids) > 1: + self.network = torch.nn.DataParallel(self.network, device_ids=self._gpu_ids) + if torch.cuda.is_available(): + self.network.cuda() + + def _init_train_vars(self): + self._current_lr = self._opt.learning_rate + self._decay_rate = self._opt.decay_rate + # initialize optimizers + self._optimizer = torch.optim.SGD(self.network.parameters(), + lr=self._current_lr, + momentum=self._decay_rate) + + def _init_weights(self): + + load_epoch = self._opt.load_alex_epoch + load_dir = self._opt.alex_dir + network_label = "AlexNet" + load_filename = 'net_epoch_%s_id_%s.pth' % (load_epoch, network_label) + + load_path = os.path.join(self._opt.checkpoints_dir, load_dir, load_filename) + assert os.path.exists( + load_path), 'Weights file not found. Have you trained a model!? 
We are not providing one' % load_path + state_dict = torch.load(load_path) + current_state = self.network.state_dict() + keys = list(state_dict.keys()) + for key in keys: + if not key.startswith('fn10'): + current_state[key] = state_dict[key] + + def save(self, label): + # save networks + self._save_network(self.network, 'AlexNet', label, self._opt.finetune_dir) + + def load(self): + load_epoch = self._opt.load_finetune_epoch + self._load_network(self.network, 'AlexNet', load_epoch, self._opt.finetune_dir) + + def _create_branch(self, branch_name): + return NetworksFactory.get_by_name(branch_name, self._opt.finetune_classes) + + def _init_losses(self): + # define loss function + self._cross_entropy = torch.nn.CrossEntropyLoss() + if torch.cuda.is_available(): + self._cross_entropy = self._cross_entropy.cuda() + + def set_input(self, inputs, labels): + self.batch_inputs = self._Tensor(inputs).permute(0,3,1,2) + self.labels = Variable(self._LongTensor(labels).squeeze_()) + + def optimize_parameters(self): + if self._is_train: + + loss = self._forward() + self._optimizer.zero_grad() + loss.backward() + self._optimizer.step() + + def _forward(self): + feature, final = self.network(self.batch_inputs) + self._loss = self._cross_entropy(final, self.labels) + return self._loss + + def _forward_test(self, input): + feature, final = self.network(input) + return feature, final + + def get_current_errors(self): + loss_dict = OrderedDict([('loss_entropy', self._loss.data[0]), + ]) + return loss_dict + + def set_train(self): + self.network.train() + self._is_train = True + + def set_eval(self): + self.network.eval() + self._is_train = False + + +class SvmModel: + def __init__(self, opt, is_train): + self._opt = opt + self._name = 'SvmModel' + self._save_dir = os.path.join(opt.checkpoints_dir, opt.svm_dir) + self._is_train = is_train + + + def train(self, features, labels): + self.clf = svm.LinearSVC() + self.clf.fit(features, labels) + + def save(self, name): + save_path = os.path.join(self._save_dir, "%s_svm.pkl" % name) + joblib.dump(self.clf, save_path) + + def load(self, name): + load_path = os.path.join(self._save_dir, "%s_svm.pkl" % name) + self.clf = joblib.load(load_path) + + def predict(self, features): + pred = self.clf.predict(features) + return pred + + @property + def name(self): + return self._name + +class RegModel(BaseModel): + def __init__(self, opt, is_train): + super(RegModel, self).__init__(opt, is_train) + self._name = 'RegModel' + + # create networks + self._init_create_networks() + + # init train variables + if self._is_train and not self._opt.load_reg_epoch > 0: + self._init_train_vars() + + # load networks and optimizers + if not self._is_train or self._opt.load_reg_epoch > 0: + self.load() + + if not self._is_train: + self.set_eval() + + # init loss + self._init_losses() + + def _init_create_networks(self): + self.network = self._create_branch("RegNet") + if len(self._gpu_ids) > 1: + self.network = torch.nn.DataParallel(self.network, device_ids=self._gpu_ids) + if torch.cuda.is_available(): + self.network.cuda() + + def _init_train_vars(self): + self._current_lr = self._opt.learning_rate + self._decay_rate = self._opt.decay_rate + # initialize optimizers + self._optimizer = torch.optim.SGD(self.network.parameters(), + lr=self._current_lr, + momentum=self._decay_rate) + + def load(self): + load_epoch = self._opt.load_reg_epoch + self._load_network(self.network, 'RegNet', load_epoch, self._opt.reg_dir) + + + def save(self, label): + # save networks + 
self._save_network(self.network, 'RegNet', label, self._opt.reg_dir) + + def _create_branch(self, branch_name): + return NetworksFactory.get_by_name(branch_name, self._opt.reg_classes) + + def _init_losses(self): + # define loss function + self._reg_loss = Regloss() + if torch.cuda.is_available(): + self._reg_loss = self._reg_loss.cuda() + + def set_input(self, inputs, labels): + self.batch_inputs = self._Tensor(inputs) + self.labels = Variable(self._FloatTensor(labels)) + + def optimize_parameters(self): + if self._is_train: + + loss = self._forward() + self._optimizer.zero_grad() + loss.backward() + self._optimizer.step() + + def _forward(self): + final = self.network(self.batch_inputs) + self._loss = self._reg_loss(self.labels, final) + return self._loss + + def _forward_test(self, input): + final = self.network(input) + return final + + def get_current_errors(self): + loss_dict = OrderedDict([('loss_reg', self._loss.data[0]), + ]) + return loss_dict + + def set_train(self): + self.network.train() + self._is_train = True + + def set_eval(self): + self.network.eval() + self._is_train = False + diff --git a/networks/__init__.py b/networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/network_factory.py b/networks/network_factory.py new file mode 100644 index 0000000..0593e66 --- /dev/null +++ b/networks/network_factory.py @@ -0,0 +1,36 @@ +import torch.nn as nn +import functools + +class NetworksFactory: + def __init__(self): + pass + + @staticmethod + def get_by_name(network_name, *args, **kwargs): + + if network_name == 'AlexNet': + from .networks import AlexNet + network = AlexNet(*args, **kwargs) + elif network_name == 'SVM': + from .networks import SVM + network = SVM(*args, **kwargs) + elif network_name == 'RegNet': + from .networks import RegNet + network = RegNet(*args, **kwargs) + else: + raise ValueError("Network %s not recognized." 
% network_name) + + print ("Network %s was created" % network_name) + + return network + + +class NetworkBase(nn.Module): + def __init__(self): + super(NetworkBase, self).__init__() + self._name = 'BaseNetwork' + + @property + def name(self): + return self._name + diff --git a/networks/networks.py b/networks/networks.py new file mode 100644 index 0000000..8840c0d --- /dev/null +++ b/networks/networks.py @@ -0,0 +1,92 @@ +import torch.nn as nn +from .network_factory import NetworkBase +from sklearn import svm +import torch + +class AlexNet(NetworkBase): + + def __init__(self, output_num): + super(AlexNet, self).__init__() + self._name = 'AlexNet' + self._output_num = output_num + + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) + + self.drop8 = nn.Dropout() + self.fn8 = nn.Linear(256 * 6 * 6, 4096) + self.active8 = nn.ReLU(inplace=True) + + self.drop9 = nn.Dropout() + self.fn9 = nn.Linear(4096, 4096) + self.active9 = nn.ReLU(inplace=True) + + self.fn10 = nn.Linear(4096, self._output_num) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + + x = self.drop8(x) + x = self.fn8(x) + x = self.active8(x) + + x = self.drop9(x) + x = self.fn9(x) + + feature = self.active9(x) + final = self.fn10(feature) + + return feature, final + +class SVM: + def __init__(self, opt): + self._name = 'SVM' + self._opt = opt + + def train(self, features, labels): + clf = svm.LinearSVC() + clf.fit(features, labels) + return clf + +class RegNet(NetworkBase): + + def __init__(self, output_num): + super(RegNet, self).__init__() + self._name = 'RegNet' + self._output_num = output_num + + layers = [] + fc1 = nn.Linear(4096, 4096) + fc1.weight.data.normal_(0.0, 0.01) + layers.append(fc1) + layers.append(nn.Dropout(0.5)) + layers.append(nn.Tanh()) + fc2 = nn.Linear(4096, self._output_num) + fc2.weight.data.normal_(0.0, 0.01) + layers.append(fc2) + layers.append(nn.Tanh()) + + self.logits = nn.Sequential(*layers) + + def forward(self, x): + return self.logits(x) + + + + diff --git a/options/__init__.py b/options/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/options/base_options.py b/options/base_options.py new file mode 100644 index 0000000..20b0430 --- /dev/null +++ b/options/base_options.py @@ -0,0 +1,95 @@ +import argparse +import os +from utils import util +import torch + +class BaseOptions(): + def __init__(self): + self._parser = argparse.ArgumentParser() + self._initialized = False + + def initialize(self): + + self._parser.add_argument('--data_dir', type=str, default='sample_dataset', help='path to dataset') + self._parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + self._parser.add_argument('--generate_save_path', type=str, default='./generate', help='dir to save generated data') + + self._parser.add_argument('--load_alex_epoch', type=int, default=-1, help='which epoch to load? 
set to -1 to use latest cached model') + self._parser.add_argument('--load_finetune_epoch', type=int, default=-1, help='which epoch to load? set to -1 to use latest cached model') + self._parser.add_argument('--load_reg_epoch', type=int, default=-1, help='which epoch to load? set to -1 to use latest cached model') + + self._parser.add_argument('--alex_dir', type=str, default='alexnet', help='dir for training alexnet') + self._parser.add_argument('--finetune_dir', type=str, default='finetune', help='dir for finetuning alexnet') + self._parser.add_argument('--svm_dir', type=str, default='svm', help='dir for training svm') + self._parser.add_argument('--reg_dir', type=str, default='regression', help='dir for box regression') + self._parser.add_argument('--options_dir', type=str, default='alexnet', help='dir for saving options') + + self._parser.add_argument('--alex_classes', type=int, default=17, help='# classes when training alexnet') + self._parser.add_argument('--finetune_classes', type=int, default=3, help='# classes when finetuning alexnet') + self._parser.add_argument('--reg_classes', type=int, default=5, help='# classes when training regnet') + + self._parser.add_argument('--finetune_threshold', type=float, default=0.3, help='threshold to select image region') + self._parser.add_argument('--svm_threshold', type=float, default=0.3, help='threshold to select image region') + self._parser.add_argument('--reg_threshold', type=float, default=0.6, help='threshold to select image region') + + self._parser.add_argument('--train_list', type=str, default='train_list.txt', help='training data') + self._parser.add_argument('--finetune_list', type=str, default='fine_tune_list.txt', help='training data') + + self._parser.add_argument('--image_size', type=int, default=224, help='input image size') + + + self._parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') + + + self._initialized = True + + def parse(self): + if not self._initialized: + self.initialize() + self._opt = self._parser.parse_args() + + # set is train or set + self._opt.is_train = self.is_train + + # get and set gpus + self._get_set_gpus() + + args = vars(self._opt) + + # print in terminal args + self._print(args) + + # save args to file + self._save(args) + + return self._opt + + def _get_set_gpus(self): + # get gpu ids + str_ids = self._opt.gpu_ids.split(',') + self._opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + self._opt.gpu_ids.append(id) + + # set gpu ids + if torch.cuda.is_available() and len(self._opt.gpu_ids) > 0: + torch.cuda.set_device(self._opt.gpu_ids[0]) + + def _print(self, args): + print('------------ Options -------------') + for k, v in sorted(args.items()): + print('%s: %s' % (str(k), str(v))) + print('-------------- End ----------------') + + def _save(self, args): + expr_dir = os.path.join(self._opt.checkpoints_dir, self._opt.options_dir) + print(expr_dir) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt_%s.txt' % ('train' if self.is_train else 'test')) + with open(file_name, 'wt') as opt_file: + opt_file.write('------------ Options -------------\n') + for k, v in sorted(args.items()): + opt_file.write('%s: %s\n' % (str(k), str(v))) + opt_file.write('-------------- End ----------------\n') diff --git a/options/test_options.py b/options/test_options.py new file mode 100644 index 0000000..c226cd3 --- /dev/null +++ b/options/test_options.py @@ -0,0 +1,10 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + + self._parser.add_argument('--img_path', type=str, default='./sample_dataset/2flowers/jpg/1/image_1281.jpg', help='file containing results') + + self.is_train = False diff --git a/options/train_options.py b/options/train_options.py new file mode 100644 index 0000000..0381d75 --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,13 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + + self._parser.add_argument('--total_epoch', type=int, default=100, help='total epoch for training') + self._parser.add_argument('--learning_rate', type=float, default=0.0001, help='initial learning rate') + self._parser.add_argument('--decay_rate', type=float, default=0.99, help='decay rate') + self._parser.add_argument('--batch_size', type=int, default=64, help='input batch size') + + self.is_train = True diff --git a/train_step1.py b/train_step1.py new file mode 100644 index 0000000..6fc5143 --- /dev/null +++ b/train_step1.py @@ -0,0 +1,57 @@ +from __future__ import division +from data.dataset_factory import DatasetFactory +from models.model_factory import ModelsFactory +from options.train_options import TrainOptions + +class Train: + def __init__(self): + self._opt = TrainOptions().parse() + + self._dataset_train = DatasetFactory.get_by_name("AlexnetDataset", self._opt) + self._dataset_train_size = len(self._dataset_train) + print('#train images = %d' % self._dataset_train_size) + + self._model = ModelsFactory.get_by_name("AlexModel", self._opt, is_train=True) + + self._train() + + def _train(self): + self._steps_per_epoch = int (self._dataset_train_size / self._opt.batch_size) + + for i_epoch in range(self._opt.load_alex_epoch + 1, self._opt.total_epoch + 1): + # train epoch + self._train_epoch(i_epoch) + + # save model + if i_epoch % 20 == 0: + 
print('saving the model at the end of epoch %d' % i_epoch) + self._model.save(i_epoch) + + def _train_epoch(self, i_epoch): + + for step in range(1, self._steps_per_epoch+1): + input, labels = self._dataset_train.get_batch() + + # train model + self._model.set_input(input, labels) + self._model.optimize_parameters() + + # display terminal + self._display_terminal_train(i_epoch, step) + + def _display_terminal_train(self, i_epoch, i_train_batch): + errors = self._model.get_current_errors() + message = '(epoch: %d, it: %d/%d) ' % (i_epoch, i_train_batch, self._steps_per_epoch) + for k, v in errors.items(): + message += '%s:%.3f ' % (k, v) + + print(message) + +if __name__ == "__main__": + Train() + + + + + + diff --git a/train_step2.py b/train_step2.py new file mode 100644 index 0000000..14c25ba --- /dev/null +++ b/train_step2.py @@ -0,0 +1,53 @@ +from __future__ import division +from data.dataset_factory import DatasetFactory +from models.model_factory import ModelsFactory +from options.train_options import TrainOptions + + +class Train: + def __init__(self): + self._opt = TrainOptions().parse() + + self._dataset_train = DatasetFactory.get_by_name("FinetuneDataset", self._opt) + self._dataset_train_size = len(self._dataset_train) + print('#train images = %d' % self._dataset_train_size) + + self._model = ModelsFactory.get_by_name("FineModel", self._opt, is_train=True) + self._train() + + def _train(self): + self._steps_per_epoch = int (self._dataset_train_size / self._opt.batch_size) + + for i_epoch in range(self._opt.load_finetune_epoch + 1, self._opt.total_epoch + 1): + # train epoch + self._train_epoch(i_epoch) + + # save model + if i_epoch % 20 == 0: + print('saving the model at the end of epoch %d' % i_epoch) + self._model.save(i_epoch) + + def _train_epoch(self, i_epoch): + + for step in range(1, self._steps_per_epoch+1): + input, labels = self._dataset_train.get_batch() + + # train model + self._model.set_input(input, labels) + self._model.optimize_parameters() + + # display terminal + self._display_terminal_train(i_epoch, step) + + def _display_terminal_train(self, i_epoch, i_train_batch): + errors = self._model.get_current_errors() + message = '(epoch: %d, it: %d/%d) ' % (i_epoch, i_train_batch, self._steps_per_epoch) + for k, v in errors.items(): + message += '%s:%.3f ' % (k, v) + + print(message) + +if __name__ == "__main__": + Train() + + diff --git a/train_step3.py b/train_step3.py new file mode 100644 index 0000000..e7afa01 --- /dev/null +++ b/train_step3.py @@ -0,0 +1,36 @@ +from __future__ import division +from data.dataset_factory import DatasetFactory +from models.model_factory import ModelsFactory +from options.train_options import TrainOptions +import numpy as np + +class Train: + def __init__(self): + self._opt = TrainOptions().parse() + + self._dataset_train = DatasetFactory.get_by_name("SVMDataset", self._opt) + self._dataset_train_size = len(self._dataset_train) + print('#train images = %d' % self._dataset_train_size) + + self.classA_features, self.classA_labels, self.classB_features, self.classB_labels = self._dataset_train.get_datas() + + self._modelA = ModelsFactory.get_by_name("SvmModel", self._opt, is_train=True) + self._modelB = ModelsFactory.get_by_name("SvmModel", self._opt, is_train=True) + + self._train(self._modelA, self.classA_features, self.classA_labels, "A") + self._train(self._modelB, self.classB_features, self.classB_labels, "B") + + def _train(self, model, features, labels, name): + model.train(features, labels) + model.save(name) + pred = 
model.predict(features) + + print (labels) + print (pred) + + + +if __name__ == "__main__": + Train() + + diff --git a/train_step4.py b/train_step4.py new file mode 100644 index 0000000..d1e6044 --- /dev/null +++ b/train_step4.py @@ -0,0 +1,57 @@ +from __future__ import division +from data.dataset_factory import DatasetFactory +from models.model_factory import ModelsFactory +from options.train_options import TrainOptions + +class Train: + def __init__(self): + self._opt = TrainOptions().parse() + + self._dataset_train = DatasetFactory.get_by_name("RegDataset", self._opt) + self._dataset_train_size = len(self._dataset_train) + print('#train images = %d' % self._dataset_train_size) + + self._model = ModelsFactory.get_by_name("RegModel", self._opt, is_train=True) + + self._train() + + def _train(self): + self._steps_per_epoch = int (self._dataset_train_size / self._opt.batch_size) + + for i_epoch in range(self._opt.load_reg_epoch + 1, self._opt.total_epoch + 1): + # train epoch + self._train_epoch(i_epoch) + + # save model + if i_epoch % 20 == 0: + print('saving the model at the end of epoch %d' % i_epoch) + self._model.save(i_epoch) + + def _train_epoch(self, i_epoch): + + for step in range(1, self._steps_per_epoch+1): + input, labels = self._dataset_train.get_batch() + + # train model + self._model.set_input(input, labels) + self._model.optimize_parameters() + + # display terminal + self._display_terminal_train(i_epoch, step) + + def _display_terminal_train(self, i_epoch, i_train_batch): + errors = self._model.get_current_errors() + message = '(epoch: %d, it: %d/%d) ' % (i_epoch, i_train_batch, self._steps_per_epoch) + for k, v in errors.items(): + message += '%s:%.3f ' % (k, v) + + print(message) + +if __name__ == "__main__": + Train() + + + + + + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/selectivesearch.py b/utils/selectivesearch.py new file mode 100644 index 0000000..fd393fd --- /dev/null +++ b/utils/selectivesearch.py @@ -0,0 +1,319 @@ +# -*- coding: utf-8 -*- +import skimage.io +import skimage.feature +import skimage.color +import skimage.transform +import skimage.util +import skimage.segmentation +import numpy + + +# "Selective Search for Object Recognition" by J.R.R. Uijlings et al. 
+# +# - Modified version with LBP extractor for texture vectorization + + +def _generate_segments(im_orig, scale, sigma, min_size): + """ + segment smallest regions by the algorithm of Felzenswalb and + Huttenlocher + """ + + # open the Image + im_mask = skimage.segmentation.felzenszwalb( + skimage.util.img_as_float(im_orig), scale=scale, sigma=sigma, + min_size=min_size) + + # merge mask channel to the image as a 4th channel + im_orig = numpy.append( + im_orig, numpy.zeros(im_orig.shape[:2])[:, :, numpy.newaxis], axis=2) + im_orig[:, :, 3] = im_mask + + return im_orig + + +def _sim_colour(r1, r2): + """ + calculate the sum of histogram intersection of colour + """ + return sum([min(a, b) for a, b in zip(r1["hist_c"], r2["hist_c"])]) + + +def _sim_texture(r1, r2): + """ + calculate the sum of histogram intersection of texture + """ + return sum([min(a, b) for a, b in zip(r1["hist_t"], r2["hist_t"])]) + + +def _sim_size(r1, r2, imsize): + """ + calculate the size similarity over the image + """ + return 1.0 - (r1["size"] + r2["size"]) / imsize + + +def _sim_fill(r1, r2, imsize): + """ + calculate the fill similarity over the image + """ + bbsize = ( + (max(r1["max_x"], r2["max_x"]) - min(r1["min_x"], r2["min_x"])) + * (max(r1["max_y"], r2["max_y"]) - min(r1["min_y"], r2["min_y"])) + ) + return 1.0 - (bbsize - r1["size"] - r2["size"]) / imsize + + +def _calc_sim(r1, r2, imsize): + return (_sim_colour(r1, r2) + _sim_texture(r1, r2) + + _sim_size(r1, r2, imsize) + _sim_fill(r1, r2, imsize)) + + +def _calc_colour_hist(img): + """ + calculate colour histogram for each region + + the size of output histogram will be BINS * COLOUR_CHANNELS(3) + + number of bins is 25 as same as [uijlings_ijcv2013_draft.pdf] + + extract HSV + """ + + BINS = 25 + hist = numpy.array([]) + + for colour_channel in (0, 1, 2): + + # extracting one colour channel + c = img[:, colour_channel] + + # calculate histogram for each colour and join to the result + hist = numpy.concatenate( + [hist] + [numpy.histogram(c, BINS, (0.0, 255.0))[0]]) + + # L1 normalize + hist = hist / len(img) + + return hist + + +def _calc_texture_gradient(img): + """ + calculate texture gradient for entire image + + The original SelectiveSearch algorithm proposed Gaussian derivative + for 8 orientations, but we use LBP instead. 
+ + output will be [height(*)][width(*)] + """ + ret = numpy.zeros((img.shape[0], img.shape[1], img.shape[2])) + + for colour_channel in (0, 1, 2): + ret[:, :, colour_channel] = skimage.feature.local_binary_pattern( + img[:, :, colour_channel], 8, 1.0) + + return ret + + +def _calc_texture_hist(img): + """ + calculate texture histogram for each region + + calculate the histogram of gradient for each colours + the size of output histogram will be + BINS * ORIENTATIONS * COLOUR_CHANNELS(3) + """ + BINS = 10 + + hist = numpy.array([]) + + for colour_channel in (0, 1, 2): + + # mask by the colour channel + fd = img[:, colour_channel] + + # calculate histogram for each orientation and concatenate them all + # and join to the result + hist = numpy.concatenate( + [hist] + [numpy.histogram(fd, BINS, (0.0, 1.0))[0]]) + + # L1 Normalize + hist = hist / len(img) + + return hist + + +def _extract_regions(img): + + R = {} + + # get hsv image + hsv = skimage.color.rgb2hsv(img[:, :, :3]) + + # pass 1: count pixel positions + for y, i in enumerate(img): + + for x, (r, g, b, l) in enumerate(i): + + # initialize a new region + if l not in R: + R[l] = { + "min_x": 0xffff, "min_y": 0xffff, + "max_x": 0, "max_y": 0, "labels": [l]} + + # bounding box + if R[l]["min_x"] > x: + R[l]["min_x"] = x + if R[l]["min_y"] > y: + R[l]["min_y"] = y + if R[l]["max_x"] < x: + R[l]["max_x"] = x + if R[l]["max_y"] < y: + R[l]["max_y"] = y + + # pass 2: calculate texture gradient + tex_grad = _calc_texture_gradient(img) + + # pass 3: calculate colour histogram of each region + for k, v in R.items(): + + # colour histogram + masked_pixels = hsv[:, :, :][img[:, :, 3] == k] + R[k]["size"] = len(masked_pixels / 4) + R[k]["hist_c"] = _calc_colour_hist(masked_pixels) + + # texture histogram + R[k]["hist_t"] = _calc_texture_hist(tex_grad[:, :][img[:, :, 3] == k]) + + return R + + +def _extract_neighbours(regions): + + def intersect(a, b): + if (a["min_x"] < b["min_x"] < a["max_x"] + and a["min_y"] < b["min_y"] < a["max_y"]) or ( + a["min_x"] < b["max_x"] < a["max_x"] + and a["min_y"] < b["max_y"] < a["max_y"]) or ( + a["min_x"] < b["min_x"] < a["max_x"] + and a["min_y"] < b["max_y"] < a["max_y"]) or ( + a["min_x"] < b["max_x"] < a["max_x"] + and a["min_y"] < b["min_y"] < a["max_y"]): + return True + return False + + R = regions.items() + r = [elm for elm in R] + R = r + neighbours = [] + for cur, a in enumerate(R[:-1]): + for b in R[cur + 1:]: + if intersect(a[1], b[1]): + neighbours.append((a, b)) + + return neighbours + + +def _merge_regions(r1, r2): + new_size = r1["size"] + r2["size"] + rt = { + "min_x": min(r1["min_x"], r2["min_x"]), + "min_y": min(r1["min_y"], r2["min_y"]), + "max_x": max(r1["max_x"], r2["max_x"]), + "max_y": max(r1["max_y"], r2["max_y"]), + "size": new_size, + "hist_c": ( + r1["hist_c"] * r1["size"] + r2["hist_c"] * r2["size"]) / new_size, + "hist_t": ( + r1["hist_t"] * r1["size"] + r2["hist_t"] * r2["size"]) / new_size, + "labels": r1["labels"] + r2["labels"] + } + return rt + + +def selective_search( + im_orig, scale=1.0, sigma=0.8, min_size=50): + '''Selective Search + + Parameters + ---------- + im_orig : ndarray + Input image + scale : int + Free parameter. Higher means larger clusters in felzenszwalb segmentation. + sigma : float + Width of Gaussian kernel for felzenszwalb segmentation. + min_size : int + Minimum component size for felzenszwalb segmentation. 
+ Returns + ------- + img : ndarray + image with region label + region label is stored in the 4th value of each pixel [r,g,b,(region)] + regions : array of dict + [ + { + 'rect': (left, top, right, bottom), + 'labels': [...] + }, + ... + ] + ''' + assert im_orig.shape[2] == 3, "3ch image is expected" + + # load image and get smallest regions + # region label is stored in the 4th value of each pixel [r,g,b,(region)] + img = _generate_segments(im_orig, scale, sigma, min_size) + + if img is None: + return None, {} + + imsize = img.shape[0] * img.shape[1] + R = _extract_regions(img) + + # extract neighbouring information + neighbours = _extract_neighbours(R) + + # calculate initial similarities + S = {} + for (ai, ar), (bi, br) in neighbours: + S[(ai, bi)] = _calc_sim(ar, br, imsize) + + # hierarchal search + while S != {}: + + # get highest similarity + # i, j = sorted(S.items(), cmp=lambda a, b: cmp(a[1], b[1]))[-1][0] + i, j = sorted(list(S.items()), key = lambda a: a[1])[-1][0] + + # merge corresponding regions + t = max(R.keys()) + 1.0 + R[t] = _merge_regions(R[i], R[j]) + + # mark similarities for regions to be removed + key_to_delete = [] + for k, v in S.items(): + if (i in k) or (j in k): + key_to_delete.append(k) + + # remove old similarities of related regions + for k in key_to_delete: + del S[k] + + # calculate similarity set with the new region + for k in filter(lambda a: a != (i, j), key_to_delete): + n = k[1] if k[0] in (i, j) else k[0] + S[(t, n)] = _calc_sim(R[t], R[n], imsize) + + regions = [] + for k, r in R.items(): + regions.append({ + 'rect': ( + r['min_x'], r['min_y'], + r['max_x'] - r['min_x'], r['max_y'] - r['min_y']), + 'size': r['size'], + 'labels': r['labels'] + }) + + return img, regions diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..22903e9 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,134 @@ +from __future__ import print_function +import os +import cv2 +import numpy as np +from . 
import selectivesearch
+import matplotlib.pyplot as plt
+
+def mkdirs(paths):
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+def resize_image(in_image, new_width, new_height, out_image=None, resize_mode=cv2.INTER_CUBIC):
+    '''
+    Resize an image, optionally writing the result to disk.
+    :param in_image: input image
+    :param new_width: width of the new image
+    :param new_height: height of the new image
+    :param out_image: optional path to save the new image to
+    :param resize_mode: cv2 interpolation mode
+    :return: the resized image
+    '''
+    img = cv2.resize(in_image, (new_width, new_height), interpolation=resize_mode)
+    if out_image:
+        cv2.imwrite(out_image, img)
+    return img
+
+def clip_pic(img, rect):
+    '''
+    Crop a rectangle out of an image.
+    :param img: input image
+    :param rect: the rect as (x, y, w, h)
+    :return: the clipped image and its vertices [x, y, x_1, y_1, w, h]
+    '''
+    x, y, w, h = rect[0], rect[1], rect[2], rect[3]
+    x_1 = x + w
+    y_1 = y + h
+    return img[y:y_1, x:x_1, :], [x, y, x_1, y_1, w, h]
+
+def IOU(ver1, vertice2):
+    '''
+    Calculate the IoU of two rectangles.
+    :param ver1: the first rect as (x, y, w, h)
+    :param vertice2: the second rect as [xmin, ymin, xmax, ymax, w, h]
+    :return: IoU value, or 0 if the rects do not intersect
+    '''
+    vertice1 = [ver1[0], ver1[1], ver1[0] + ver1[2], ver1[1] + ver1[3]]
+    area_inter = if_intersection(vertice1[0], vertice1[2], vertice1[1], vertice1[3], vertice2[0], vertice2[2], vertice2[1], vertice2[3])
+    if area_inter:
+        area_1 = ver1[2] * ver1[3]
+        area_2 = vertice2[4] * vertice2[5]
+        iou = float(area_inter) / (area_1 + area_2 - area_inter)
+        return iou
+    return 0
+
+def if_intersection(xmin_a, xmax_a, ymin_a, ymax_a, xmin_b, xmax_b, ymin_b, ymax_b):
+    # return the intersection area of two rects, or False if they do not intersect
+    if_intersect = False
+    if xmin_a < xmax_b <= xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
+        if_intersect = True
+    elif xmin_a <= xmin_b < xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
+        if_intersect = True
+    elif xmin_b < xmax_a <= xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
+        if_intersect = True
+    elif xmin_b <= xmin_a < xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
+        if_intersect = True
+    else:
+        return if_intersect
+    if if_intersect:
+        x_sorted_list = sorted([xmin_a, xmax_a, xmin_b, xmax_b])
+        y_sorted_list = sorted([ymin_a, ymax_a, ymin_b, ymax_b])
+        x_intersect_w = x_sorted_list[2] - x_sorted_list[1]
+        y_intersect_h = y_sorted_list[2] - y_sorted_list[1]
+        area_inter = x_intersect_w * y_intersect_h
+        return area_inter
+
+def image_proposal(img_path, img_size):
+    '''
+    Use selective search to obtain region proposals,
+    then apply some rules to filter them.
+    '''
+    img = cv2.imread(img_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img_lbl, regions = selectivesearch.selective_search(img, scale=500, sigma=0.9, min_size=10)
+    candidates = set()
+    images = []
+    vertices = []
+    rects = []
+    for r in regions:
+        # skip duplicate, tiny, or degenerate proposals
+        if r['rect'] in candidates:
+            continue
+        if r['size'] < 200:
+            continue
+        if (r['rect'][2] * r['rect'][3]) < 500:
+            continue
+        proposal_img, proposal_vertice = clip_pic(img, r['rect'])
+        if len(proposal_img) == 0:
+            continue
+        x, y, w, h = r['rect']
+        if w == 0 or h == 0:
+            continue
+        [a, b, c] = np.shape(proposal_img)
+        if a == 0 or b == 0 or c == 0:
+            continue
+        resized_proposal_img = resize_image(proposal_img, img_size, img_size)
+        candidates.add(r['rect'])
+        img_float = np.asarray(resized_proposal_img, dtype="float32")
+        images.append(img_float)
+        vertices.append(proposal_vertice)
+        rects.append(r['rect'])
+    return images, vertices, rects
+
+def show_rect(img_path, regions, message):
+    '''
+    Draw the given rects on the image with an annotation message, then display it.
+    :param img_path: path to the input image
+    :param regions: rects, each as (x, y, w, h)
+    :param message: annotation text drawn near each rect
+    '''
+    img = cv2.imread(img_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    for x, y, w, h in regions:
+        x, y, w, h = int(x), int(y), int(w), int(h)
+        cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        cv2.putText(img, message, (x + 20, y + 40), font, 1, (255, 0, 0), 2)
+    plt.imshow(img)
+    plt.show()
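
Usage note: `utils/util.py` is the glue between selective search and the models. Below is a minimal sketch of how `image_proposal` and `show_rect` compose at inference time (this mirrors what `evaluate.py` does; the image path is just the sample shipped in `sample_dataset/`):

```python
import torch
from utils.util import image_proposal, show_rect

img_path = './sample_dataset/2flowers/jpg/1/image_1281.jpg'

# region proposals: resized 224x224 crops, clipped vertices, and raw (x, y, w, h) rects
images, vertices, rects = image_proposal(img_path, 224)

# visualize all candidate boxes before any scoring
show_rect(img_path, rects, ' ')

# crops come back as HWC float32 arrays; the models expect NCHW tensors
batch = torch.Tensor(images).permute(0, 3, 1, 2)
print(batch.shape)  # e.g. torch.Size([N, 3, 224, 224])
```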