Skip to content

Commit

Permalink
update humanseg_server (#2002)
Browse files Browse the repository at this point in the history
* update humanseg_server

* add clean func

* update save inference model
  • Loading branch information
jm12138 authored Sep 16, 2022
1 parent 8873a70 commit cf5f311
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 69 deletions.
20 changes: 12 additions & 8 deletions modules/image/semantic_segmentation/humanseg_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,19 +173,13 @@
```python
def save_inference_model(dirname='humanseg_server_model',
model_filename=None,
params_filename=None,
combined=True)
def save_inference_model(dirname)
```
- 将模型保存到指定路径。
- **参数**
* dirname: 存在模型的目录名称
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
* dirname: 模型保存路径
## 四、服务部署
Expand Down Expand Up @@ -243,11 +237,21 @@
* 1.0.0
初始发布
* 1.1.0
新增视频人像分割接口
新增视频流人像分割接口
* 1.1.1
修复cudnn为8.0.4显存泄露问题
* 1.2.0
移除 Fluid API
```shell
$ hub install humanseg_server == 1.2.0
```
23 changes: 13 additions & 10 deletions modules/image/semantic_segmentation/humanseg_server/README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,15 @@
```python
def save_inference_model(dirname='humanseg_server_model',
model_filename=None,
params_filename=None,
combined=True)
def save_inference_model(dirname)
```
- Save the model to the specified path.
- **Parameters**
* dirname: Save path.
* model\_filename: Model file name,defalt is \_\_model\_\_
* params\_filename: Parameter file name,defalt is \_\_params\_\_(Only takes effect when `combined` is True)
* combined: Whether to save the parameters to a unified file.
* dirname: Model save path.
Expand Down Expand Up @@ -242,7 +236,7 @@
- 1.0.0
First release
First release
- 1.1.0
Expand All @@ -252,4 +246,13 @@
* 1.1.1
Fix memory leakage problem of on cudnn 8.0.4
Fix memory leakage problem of on cudnn 8.0.4
* 1.2.0
Remove Fluid API
```shell
$ hub install humanseg_server == 1.2.0
```
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import cv2
import numpy as np
from PIL import Image

__all__ = ['reader', 'preprocess_v']

Expand Down
109 changes: 59 additions & 50 deletions modules/image/semantic_segmentation/humanseg_server/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

import cv2
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
import paddle
import paddle.jit
import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving

from humanseg_server.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
Expand All @@ -36,33 +37,37 @@
author="baidu-vis",
author_email="",
summary="DeepLabv3+ is a semantic segmentation model.",
version="1.1.0")
class DeeplabV3pXception65HumanSeg(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "humanseg_server_inference")
version="1.2.0")
class DeeplabV3pXception65HumanSeg:
def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "humanseg_server_inference", "model")
self._set_config()

def _set_config(self):
"""
predictor config setting
"""
self.model_file_path = os.path.join(self.default_pretrained_model_path, '__model__')
self.params_file_path = os.path.join(self.default_pretrained_model_path, '__params__')
cpu_config = AnalysisConfig(self.model_file_path, self.params_file_path)
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.model_file_path, self.params_file_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)

if paddle.get_cudnn_version() == 8004:
gpu_config.delete_pass('conv_elementwise_add_act_fuse_pass')
gpu_config.delete_pass('conv_elementwise_add2_act_fuse_pass')
self.gpu_predictor = create_predictor(gpu_config)

def segment(self,
images=None,
Expand Down Expand Up @@ -114,9 +119,16 @@ def segment(self,
pass
# feed batch image
batch_image = np.array([data['image'] for data in batch_data])
batch_image = PaddleTensor(batch_image.copy())
output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run([batch_image])
output = output[1].as_ndarray()

predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
output = output_handle.copy_to_cpu()

output = np.expand_dims(output[:, 1, :, :], axis=1)
# postprocess one by one
for i in range(len(batch_data)):
Expand Down Expand Up @@ -154,9 +166,16 @@ def video_stream_segment(self, frame_org, frame_id, prev_gray, prev_cfd, use_gpu
height = int(frame_org.shape[1])
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
frame = preprocess_v(frame_org, resize_w, resize_h)
image = PaddleTensor(np.array([frame.copy()]))
output = self.gpu_predictor.run([image]) if use_gpu else self.cpu_predictor.run([image])
score_map = output[1].as_ndarray()

predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(frame.copy()[None, ...])
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
score_map = output_handle.copy_to_cpu()

frame = np.transpose(frame, axes=[1, 2, 0])
score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0])
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
Expand All @@ -173,7 +192,7 @@ def video_stream_segment(self, frame_org, frame_id, prev_gray, prev_cfd, use_gpu
img_matting = cv2.resize(optflow_map, (height, width), cv2.INTER_LINEAR)
return [img_matting, cur_gray, optflow_map]

def video_segment(self, video_path=None, use_gpu=False, save_dir='humanseg_server_video'):
def video_segment(self, video_path=None, use_gpu=False, save_dir='humanseg_server_video_result'):
resize_h = 512
resize_w = 512
if not video_path:
Expand Down Expand Up @@ -201,9 +220,16 @@ def video_segment(self, video_path=None, use_gpu=False, save_dir='humanseg_serve
ret, frame_org = cap_video.read()
if ret:
frame = preprocess_v(frame_org, resize_w, resize_h)
image = PaddleTensor(np.array([frame.copy()]))
output = self.gpu_predictor.run([image]) if use_gpu else self.cpu_predictor.run([image])
score_map = output[1].as_ndarray()

predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(frame.copy()[None, ...])
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
score_map = output_handle.copy_to_cpu()

frame = np.transpose(frame, axes=[1, 2, 0])
score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0])
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
Expand All @@ -228,9 +254,16 @@ def video_segment(self, video_path=None, use_gpu=False, save_dir='humanseg_serve
ret, frame_org = cap_video.read()
if ret:
frame = preprocess_v(frame_org, resize_w, resize_h)
image = PaddleTensor(np.array([frame.copy()]))
output = self.gpu_predictor.run([image]) if use_gpu else self.cpu_predictor.run([image])
score_map = output[1].as_ndarray()

predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(frame.copy()[None, ...])
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
score_map = output_handle.copy_to_cpu()

frame = np.transpose(frame, axes=[1, 2, 0])
score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0])
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
Expand All @@ -252,30 +285,6 @@ def video_segment(self, video_path=None, use_gpu=False, save_dir='humanseg_serve
break
cap_video.release()

def save_inference_model(self,
dirname='humanseg_server_model',
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path,
model_filename=model_filename,
params_filename=params_filename,
executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)

@serving
def serving_method(self, images, **kwargs):
"""
Expand Down
Loading

0 comments on commit cf5f311

Please sign in to comment.