Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
申瑞珉 (Ruimin Shen) committed Feb 11, 2019
1 parent ec65527 commit f850084
Show file tree
Hide file tree
Showing 16 changed files with 72 additions and 90 deletions.
4 changes: 2 additions & 2 deletions convert_caffe_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def main():
stages=inference.stages.state_dict(),
), 0)
finally:
for stage, output in enumerate(inference(torch.autograd.Variable(tensor, volatile=True))):
for stage, output in enumerate(inference(tensor)):
for name, feature in output.items():
val = feature.data.numpy()
val = feature.detach().numpy()
print('\t'.join(map(str, [
'stage%d/%s' % (stage, name),
'x'.join(map(str, val.shape)),
Expand Down
9 changes: 4 additions & 5 deletions convert_tf_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,14 @@ def main():
'x'.join(map(str, val.shape)),
utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
])))
_tensor = torch.autograd.Variable(tensor, volatile=True)
val = dnn(_tensor).data.numpy()
val = dnn(tensor).detach().numpy()
print('\t'.join(map(str, [
'x'.join(map(str, val.shape)),
utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
])))
for stage, output in enumerate(inference(_tensor)):
for stage, output in enumerate(inference(tensor)):
for name, feature in output.items():
val = feature.data.numpy()
val = feature.detach().numpy()
print('\t'.join(map(str, [
'stage%d/%s' % (stage, name),
'x'.join(map(str, val.shape)),
Expand All @@ -144,7 +143,7 @@ def main():
forward = inference.forward
inference.forward = lambda self, *x: list(forward(self, *x)[-1].values())
with SummaryWriter(model_dir) as writer:
writer.add_graph(inference, (_tensor,))
writer.add_graph(inference, (tensor,))


def make_args():
Expand Down
2 changes: 1 addition & 1 deletion convert_torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def main():
inference = model.Inference(config, dnn, stages)
inference.eval()
logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in inference.state_dict().values())))
image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width), volatile=True)
image = torch.randn(args.batch_size, 3, height, width)
path = model_dir + '.onnx'
logging.info('save ' + path)
forward = inference.forward
Expand Down
5 changes: 3 additions & 2 deletions demo_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ def main():
utils.modify_config(config, cmd)
with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
logging.config.dictConfig(yaml.load(f))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cache_dir = utils.get_cache_dir(config)
_, num_parts = utils.get_dataset_mappers(config)
limbs_index = utils.get_limbs_index(config)
dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config))
dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config)).to(device)
draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split())
_draw_points = utils.visualize.DrawPoints(limbs_index, thickness=1)
draw_bbox = utils.visualize.DrawBBox()
Expand All @@ -63,7 +64,7 @@ def main():
except configparser.NoOptionError:
workers = multiprocessing.cpu_count()
sizes = utils.train.load_sizes(config)
feature_sizes = [model.feature_size(dnn, *size) for size in sizes]
feature_sizes = [dnn(torch.randn(1, 3, *size).to(device)).size()[-2:] for size in sizes]
collate_fn = utils.data.Collate(
config,
transform.parse_transform(config, config.get('transform', 'resize_train')),
Expand Down
5 changes: 3 additions & 2 deletions demo_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ def main():
utils.modify_config(config, cmd)
with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
logging.config.dictConfig(yaml.load(f))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cache_dir = utils.get_cache_dir(config)
_, num_parts = utils.get_dataset_mappers(config)
limbs_index = utils.get_limbs_index(config)
dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config))
dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config)).to(device)
logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values())))
size = tuple(map(int, config.get('image', 'size').split()))
draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split())
Expand All @@ -110,7 +111,7 @@ def main():
collate_fn = utils.data.Collate(
config,
transform.parse_transform(config, config.get('transform', 'resize_train')),
[size], [model.feature_size(dnn, *size)],
[size], [dnn(torch.randn(1, 3, *size).to(device)).size()[-2:]],
maintain=config.getint('data', 'maintain'),
transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()),
)
Expand Down
27 changes: 16 additions & 11 deletions estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@
import torch.optim
import torch.utils.data
import torch.nn as nn
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace
try:
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace
except ImportError:
pass
import humanize
import pybenchmark
import cv2
Expand All @@ -47,6 +50,7 @@ class Estimate(object):
def __init__(self, args, config):
self.args = args
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.cache_dir = utils.get_cache_dir(config)
self.model_dir = utils.get_model_dir(config)
_, self.num_parts = utils.get_dataset_mappers(config)
Expand All @@ -69,7 +73,7 @@ def __init__(self, args, config):
with open(os.path.join(self.model_dir, 'predict_net.pb'), 'rb') as f:
predict_net.ParseFromString(f.read())
p = workspace.Predictor(init_net, predict_net)
self.inference = lambda tensor: [{'parts': torch.autograd.Variable(torch.from_numpy(parts)), 'limbs': torch.autograd.Variable(torch.from_numpy(limbs))} for parts, limbs in zip(*[iter(p.run([tensor.data.cpu().numpy()]))] * 2)]
self.inference = lambda tensor: [{'parts': torch.from_numpy(parts), 'limbs': torch.from_numpy(limbs)} for parts, limbs in zip(*[iter(p.run([tensor.detach().cpu().numpy()]))] * 2)]
else:
self.step, self.epoch, self.dnn, self.stages = self.load()
self.inference = model.Inference(config, self.dnn, self.stages)
Expand Down Expand Up @@ -133,13 +137,13 @@ def __call__(self):
image_resized = self.resize(image_bgr, self.height, self.width)
image = self.transform_image(image_resized)
tensor = self.transform_tensor(image)
tensor = utils.ensure_device(tensor.unsqueeze(0))
outputs = pybenchmark.profile('inference')(self.inference)(torch.autograd.Variable(tensor, volatile=True))
tensor = tensor.unsqueeze(0).to(self.device)
outputs = pybenchmark.profile('inference')(self.inference)(tensor)
if hasattr(self, 'draw_cluster'):
output = outputs[-1]
parts, limbs = (output[name][0].data for name in 'parts, limbs'.split(', '))
parts, limbs = (output[name][0] for name in 'parts, limbs'.split(', '))
parts = parts[:-1]
parts, limbs = (t.cpu().numpy() for t in (parts, limbs))
parts, limbs = (t.detach().cpu().numpy() for t in (parts, limbs))
try:
interpolation = getattr(cv2, 'INTER_' + self.config.get('estimate', 'interpolation').upper())
parts, limbs = (np.stack([cv2.resize(feature, (self.width, self.height), interpolation=interpolation) for feature in a]) for a in (parts, limbs))
Expand All @@ -159,7 +163,7 @@ def __call__(self):
image_result = self.draw_cluster(image_result, cluster)
else:
image_result = image_resized.copy()
feature = self.get_feature(outputs).data.cpu().numpy()
feature = self.get_feature(outputs).detach().cpu().numpy()
image_result = self.draw_feature(image_result, feature)
if self.args.output:
if not hasattr(self, 'writer'):
Expand All @@ -183,10 +187,11 @@ def main():
utils.modify_config(config, cmd)
with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
logging.config.dictConfig(yaml.load(f))
detect = Estimate(args, config)
estimate = Estimate(args, config)
try:
while detect.cap.isOpened():
detect()
with torch.no_grad():
while estimate.cap.isOpened():
estimate()
except KeyboardInterrupt:
logging.warning('interrupted')
finally:
Expand Down
12 changes: 2 additions & 10 deletions model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,6 @@ def channel_dict(num_parts, num_limbs):
])


def feature_size(dnn, height, width):
image = torch.autograd.Variable(torch.randn(1, 3, height, width), volatile=True)
if next(dnn.parameters()).is_cuda:
image = image.cuda()
feature = dnn(image)
return feature.size()[-2:]


class Inference(nn.Module):
def __init__(self, config, dnn, stages):
nn.Module.__init__(self)
Expand All @@ -81,10 +73,10 @@ def __init__(self, config, data, limbs_index, height, width):
self.width = width

def __call__(self, **kwargs):
mask = torch.autograd.Variable(self.data['mask'].float())
mask = self.data['mask'].float()
batch_size, rows, cols = mask.size()
mask = mask.view(batch_size, 1, rows, cols)
data = {name: torch.autograd.Variable(self.data[name]) for name in kwargs}
data = {name: self.data[name] for name in kwargs}
return {name: self.loss(mask, data[name], feature) for name, feature in kwargs.items()}

def loss(self, mask, label, feature):
Expand Down
6 changes: 3 additions & 3 deletions model/dnn/inception4.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,10 @@ def init(self, config_channels):
beta = True
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight = nn.init.kaiming_normal(m.weight)
m.weight = nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.weight.fill_(1)
m.bias.zero_()
m.weight.requires_grad = gamma
m.bias.requires_grad = beta
try:
Expand Down
6 changes: 3 additions & 3 deletions model/dnn/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def __init__(self, config_channels):

for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight = nn.init.kaiming_normal(m.weight)
m.weight = nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.weight.fill_(1)
m.bias.zero_()

def forward(self, x):
return self.layers(x)
9 changes: 4 additions & 5 deletions model/dnn/mobilenet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,11 @@ def _initialize_weights(self):
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.weight.data.fill_(1) # PyTorch's bug
m.bias.data.zero_() # PyTorch's bug
elif isinstance(m, nn.Linear):
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
m.weight.normal_(0, 0.01)
m.bias.zero_()


class MobileNet2Dilate2(MobileNet2):
Expand Down
6 changes: 3 additions & 3 deletions model/dnn/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def __init__(self, config_channels, anchors, num_cls, block, layers):

for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight = nn.init.kaiming_normal(m.weight)
m.weight = nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.weight.fill_(1)
m.bias.zero_()

def _make_layer(self, config_channels, prefix, block, channels, blocks, stride=1):
layers = []
Expand Down
12 changes: 6 additions & 6 deletions model/stages/openpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def __init__(self, config_channels, channel_dict, channels_dnn, prefix):
def init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight = nn.init.xavier_normal(m.weight)
m.weight = nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.weight.fill_(1)
m.bias.zero_()

def forward(self, x, **kwargs):
return {name: var(x) for name, var in self._modules.items()}
Expand All @@ -85,10 +85,10 @@ def __init__(self, config_channels, channels, channels_dnn, prefix):
def init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight = nn.init.xavier_normal(m.weight)
m.weight = nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.weight.fill_(1)
m.bias.zero_()

def forward(self, x, **kwargs):
x = torch.cat([kwargs[name] for name in ('limbs', 'parts')] + [x], 1)
Expand Down
10 changes: 6 additions & 4 deletions receptive_field_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Analyzer(object):
def __init__(self, args, config):
self.args = args
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model_dir = utils.get_model_dir(config)
_, self.num_parts = utils.get_dataset_mappers(config)
self.limbs_index = utils.get_limbs_index(config)
Expand All @@ -64,7 +65,8 @@ def __init__(self, args, config):
if torch.cuda.is_available():
self.inference.cuda()
self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
output = self.dnn(torch.autograd.Variable(utils.ensure_device(torch.zeros(1, 3, self.height, self.width)), volatile=True))
t = torch.zeros(1, 3, self.height, self.width).to(self.device)
output = self.dnn(t)
_, _, self.rows, self.cols = output.size()
self.i, self.j = self.rows // 2, self.cols // 2
self.output = output[:, :, self.i, self.j]
Expand All @@ -83,11 +85,11 @@ def __call__(self):
for i, _yx in enumerate(torch.unbind(yx)):
y, x = torch.unbind(_yx)
tensor[i, :, y, x] = 1
tensor = utils.ensure_device(tensor)
output = self.dnn(torch.autograd.Variable(tensor, volatile=True))
tensor = tensor.to(self.device)
output = self.dnn(tensor)
output = output[:, :, self.i, self.j]
cmp = output == self.output
cmp = torch.prod(cmp, -1).data
cmp = torch.prod(cmp, -1)
for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)):
y, x = torch.unbind(_yx)
changed[y, x] = c
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
tqdm
pybenchmark
graphviz
torch<=0.3.1
torch>=0.4.0
pandas
onnx
onnx_caffe2
Expand All @@ -17,7 +17,7 @@ Pillow
PyQt5
scipy
skimage
tensorboardX
tensorboardX>=1.2
tensorflow
PyYAML
pycocotools
Loading

0 comments on commit f850084

Please sign in to comment.