
Commit

Merge pull request #344 from mil-tokyo/dev
versioning 1.1.0
Kiikurage authored Jul 1, 2017
2 parents c561547 + 928ecfd commit e6ab747
Showing 343 changed files with 15,308 additions and 6,901 deletions.
28 changes: 18 additions & 10 deletions bin/convert_caffe.py
@@ -3,7 +3,6 @@
 """
 
 import argparse
-import ast
 import os
 import sys
 from os import path
@@ -14,8 +13,10 @@
 import numpy as np
 
 from webdnn.backend.interface.generator import generate_descriptor
-from webdnn.graph.converters.chainer import ChainerGraphConverter
+from webdnn.frontend.chainer import ChainerConverter
 from webdnn.graph.graph import Graph
+from webdnn.graph.shape import Shape
+from webdnn.util import console
 
 
 def parse_input_blob(args):
@@ -28,7 +29,9 @@ def parse_input_blob(args):
     else:
         if not args.input_shape:
             raise ValueError("input_npy or input_shapes must be specified to determine input")
-        input_shape = ast.literal_eval(args.input_shape)
+        input_shape, placeholders = Shape.parse(args.input_shape)
+        if len(placeholders) > 0:
+            raise ValueError("caffe converter does not support an input with placeholder")
         input_blob = chainer.Variable(np.zeros(input_shape, dtype=np.float32))
     return input_blob, input_filled
 
@@ -57,14 +60,19 @@ def main():
     input_blob, input_filled = parse_input_blob(args)
     output_names = args.output_names.split(",")
 
-    sys.stderr.write("Loading caffe model... (usually takes several minutes)\n")
+    console.stderr("[convert_caffe] Loading caffe model... (usually takes several minutes)")
     link = chainer.links.caffe.CaffeFunction(args.caffemodel)
 
-    sys.stderr.write("Generating feedforward graph\n")
-    output_blobs = list(
-        link(inputs={args.input_name: input_blob}, outputs=output_names, train=False))  # list of Variable
+    console.stderr("[convert_caffe] Generating feedforward graph")
+    if chainer.__version__ >= "2.":
+        chainer.using_config("train", False)
+        output_blobs = list(
+            link(inputs={args.input_name: input_blob}, outputs=output_names))  # list of Variable
+    else:
+        output_blobs = list(
+            link(inputs={args.input_name: input_blob}, outputs=output_names, train=False))  # list of Variable
     chainer_cg = chainer.computational_graph.build_computational_graph(output_blobs)
-    converter = ChainerGraphConverter()
+    converter = ChainerConverter()
     graph = converter.convert(chainer_cg, [input_blob], output_blobs)  # type: Graph
 
     if args.out:
@@ -78,15 +86,15 @@ def main():
     output_arrays = {output_name: output_blob.data for output_name, output_blob in zip(output_names, output_blobs)}
     np.savez(path.join(output_dir, "example_output.npz"), **output_arrays)
 
-    sys.stderr.write("Generating descriptors\n")
+    console.stderr("[convert_caffe] Generating descriptors")
    any_backend_failed = False
     for backend in args.backend.split(","):
         try:
             graph_exec_data = generate_descriptor(backend, graph, constant_encoder_name=args.encoding)
             graph_exec_data.save(output_dir)
         except Exception as ex:
             any_backend_failed = True
-            sys.stderr.write(f"Failed generating descriptor for backend {backend}: {str(ex)}\n")
+            console.error(f"[convert_caffe] Failed generating descriptor for backend {backend}: {str(ex)}")
 
     if any_backend_failed:
         sys.exit(1)
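For context, a minimal sketch of the Chainer-side flow above, using only the calls that appear in this diff. The caffemodel path, blob names, and shape string are placeholders, and the Chainer 2.x mode switch is written in its usual context-manager form, since using_config only affects chainer.config inside a with block.

    import chainer
    import chainer.computational_graph
    import numpy as np
    from chainer.links.caffe import CaffeFunction

    from webdnn.frontend.chainer import ChainerConverter
    from webdnn.graph.shape import Shape

    # Shape string in the format accepted by --input_shape (placeholder value);
    # the caffe converter rejects shapes that contain placeholders.
    input_shape, placeholders = Shape.parse("(1,3,224,224)")
    assert len(placeholders) == 0

    input_blob = chainer.Variable(np.zeros(input_shape, dtype=np.float32))

    # Placeholder model path and blob names (BVLC GoogLeNet conventions).
    link = CaffeFunction("bvlc_googlenet.caffemodel")

    # Chainer 2.x: inference mode comes from the global config; using_config
    # only takes effect inside a with block.
    with chainer.using_config("train", False):
        output_blobs = list(link(inputs={"data": input_blob}, outputs=["loss3/classifier"]))

    # Trace Chainer's computational graph and convert it to a WebDNN IR graph,
    # mirroring the calls in convert_caffe.py above.
    chainer_cg = chainer.computational_graph.build_computational_graph(output_blobs)
    graph = ChainerConverter().convert(chainer_cg, [input_blob], output_blobs)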
52 changes: 40 additions & 12 deletions bin/convert_keras.py
@@ -3,16 +3,25 @@
 """
 
 import argparse
-import ast
+import importlib.util
 import os
 import sys
+import traceback
 from os import path
 
 import h5py
 
 from webdnn.backend.interface.generator import generate_descriptor
-from webdnn.graph.converters.keras import KerasGraphConverter
-from webdnn.graph.graph import Graph
+from webdnn.frontend.keras import KerasConverter
+from webdnn.graph.shape import Shape
+from webdnn.graph.traverse import dump_dot
+from webdnn.util import flags, console
 
 
+def _load_plugin(filepath: str):
+    spec = importlib.util.spec_from_file_location("_plugin", filepath)
+    plugin = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(plugin)
+
+
 def main():
@@ -27,13 +36,19 @@ def main():
     parser.add_argument("--out",
                         help="output directory (default: <model>/webdnn_graph_descriptor)")
     parser.add_argument("--encoding", help="name of weight encoder")
+    parser.add_argument("--visualize_ir", action="store_true")
+    parser.add_argument("--plugin", action="append", help="plugin python files which are imported before transpiling")
     args = parser.parse_args()
 
-    sys.stderr.write("Generating feedforward graph\n")
-    input_shape = ast.literal_eval(args.input_shape)
+    console.stderr(f"[{path.basename(__file__)}] Generating feedforward graph")
+    if args.plugin:
+        for plugin_path in args.plugin:
+            _load_plugin(plugin_path)
+
+    input_shape, _ = Shape.parse(args.input_shape)
     input_shapes = [input_shape]
     model = h5py.File(args.kerasmodel, "r")
-    converter = KerasGraphConverter()
+    converter = KerasConverter()
     graph = converter.convert(model, input_shapes)
 
     if args.out:
@@ -42,20 +57,33 @@ def main():
         output_dir = path.join(path.dirname(args.kerasmodel), "webdnn_graph_descriptor")
     os.makedirs(output_dir, exist_ok=True)
 
-    sys.stderr.write("Generating descriptors\n")
+    if args.visualize_ir:
+        ir_dot_path = path.join(output_dir, "ir.dot")
+        with open(ir_dot_path, "w") as f:
+            f.write(dump_dot(graph))
+        console.stderr(f"IR graph can be visualized with graphviz command: 'dot {ir_dot_path} -T png -o output.png'")
+
+    console.stderr(f"[{path.basename(__file__)}] Generating graph descriptor")
+
     any_backend_failed = False
-    last_backend_exception = None
-    for backend in args.backend.split(","):
+    backends = args.backend.split(",")
+    for i, backend in enumerate(backends):
+        console.stderr(f"[{path.basename(__file__)}] Backend: {console.colorize(backend, console.Color.Cyan)}")
         try:
             graph_exec_data = generate_descriptor(backend, graph, constant_encoder_name=args.encoding)
             graph_exec_data.save(output_dir)
         except Exception as ex:
+            if flags.DEBUG:
+                raise ex
+
             any_backend_failed = True
-            last_backend_exception = ex
-            sys.stderr.write(f"Failed generating descriptor for backend {backend}: {str(ex)}\n")
+            console.error(f"[{path.basename(__file__)}] Failed generating descriptor for {backend} backend")
+            console.stderr(traceback.format_exc())
+            continue
 
     if any_backend_failed:
-        raise last_backend_exception
+        exit(1)
+        # raise last_backend_exception
 
 
 if __name__ == "__main__":
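For context, a condensed sketch of the flow convert_keras.py now follows, including the optional IR dump enabled by --visualize_ir. The model path, shape string, output directory name, and backend list are placeholder values, and constant_encoder_name=None is assumed to correspond to running without --encoding.

    import os

    import h5py

    from webdnn.backend.interface.generator import generate_descriptor
    from webdnn.frontend.keras import KerasConverter
    from webdnn.graph.shape import Shape
    from webdnn.graph.traverse import dump_dot

    # Placeholder Keras HDF5 model and input shape string.
    input_shape, _ = Shape.parse("(1,224,224,3)")
    model = h5py.File("resnet50.h5", "r")

    # Convert the Keras model into a WebDNN IR graph.
    graph = KerasConverter().convert(model, [input_shape])

    output_dir = "webdnn_graph_descriptor"
    os.makedirs(output_dir, exist_ok=True)

    # Optional IR dump; render it with: dot ir.dot -T png -o output.png
    with open(os.path.join(output_dir, "ir.dot"), "w") as f:
        f.write(dump_dot(graph))

    # Generate and save one graph descriptor per backend.
    for backend in ["webgpu", "webassembly"]:
        graph_exec_data = generate_descriptor(backend, graph, constant_encoder_name=None)
        graph_exec_data.save(output_dir)

The --plugin option works the same way as _load_plugin above: each given Python file is executed through importlib before conversion, presumably so that custom converter handlers can register themselves at import time.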
Binary file removed dist/webdnn-1.0.0-py3.6.egg
Binary file added dist/webdnn-1.1.0-py3.6.egg
(Diffs for the remaining changed files are not shown.)
