+License: Apache-2.0
diff --git a/debian/pycoral-examples.install b/debian/pycoral-examples.install
new file mode 100644
index 0000000..edb0878
--- /dev/null
+++ b/debian/pycoral-examples.install
@@ -0,0 +1,17 @@
+examples/* /usr/share/pycoral/examples/
+test_data/COPYRIGHT /usr/share/pycoral/examples/images
+test_data/bird.bmp /usr/share/pycoral/examples/images
+test_data/cat.bmp /usr/share/pycoral/examples/images
+test_data/grace_hopper.bmp /usr/share/pycoral/examples/images
+test_data/parrot.jpg /usr/share/pycoral/examples/images
+test_data/sunflower.bmp /usr/share/pycoral/examples/images
+test_data/deeplabv3_mnv2_pascal_quant_edgetpu.tflite /usr/share/pycoral/examples/models
+test_data/ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite /usr/share/pycoral/examples/models
+test_data/ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite /usr/share/pycoral/examples/models
+test_data/ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite /usr/share/pycoral/examples/models
+test_data/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite /usr/share/pycoral/examples/models
+test_data/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite /usr/share/pycoral/examples/models
+test_data/mobilenet_v2_1.0_224_quant_edgetpu.tflite /usr/share/pycoral/examples/models
+test_data/coco_labels.txt /usr/share/pycoral/examples/models
+test_data/imagenet_labels.txt /usr/share/pycoral/examples/models
+test_data/inat_bird_labels.txt /usr/share/pycoral/examples/models
diff --git a/debian/rules b/debian/rules
new file mode 100755
index 0000000..7e7c0b0
--- /dev/null
+++ b/debian/rules
@@ -0,0 +1,29 @@
+#!/usr/bin/make -f
+# -*- makefile -*-
+
+# Uncomment this to turn on verbose mode.
+# export DH_VERBOSE=1
+PYBIND_ARMHF := {destdir}/{install_dir}/pycoral/pybind/_pywrap_coral.cpython-*-arm-linux-gnueabihf.so
+PYBIND_ARM64 := {destdir}/{install_dir}/pycoral/pybind/_pywrap_coral.cpython-*-aarch64-linux-gnu.so
+PYBIND_AMD64 := {destdir}/{install_dir}/pycoral/pybind/_pywrap_coral.cpython-*-x86_64-linux-gnu.so
+
+export PYBUILD_NAME=pycoral
+
+ifeq ($(DEB_TARGET_ARCH),armhf)
+ export PYBUILD_AFTER_INSTALL=rm -f $(PYBIND_ARM64) $(PYBIND_AMD64)
+else ifeq ($(DEB_TARGET_ARCH),arm64)
+ export PYBUILD_AFTER_INSTALL=rm -f $(PYBIND_ARMHF) $(PYBIND_AMD64)
+else ifeq ($(DEB_TARGET_ARCH),amd64)
+ export PYBUILD_AFTER_INSTALL=rm -f $(PYBIND_ARMHF) $(PYBIND_ARM64)
+endif
+
+%:
+ # Assume that all coral/pybind/*.so have been already built.
+ dh $@ --with python3 --buildsystem=pybuild
+
+# Skip .so post processing.
+override_dh_strip:
+override_dh_shlibdeps:
+
+# Skip tests.
+override_dh_auto_test:
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..bc9324d
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,27 @@
+# PyCoral API docs
+
+This directory holds the source files required to build the PyCoral API
+reference with Sphinx.
+
+You can build the reference docs as follows:
+
+```
+# We require Python3, so if that's not your default, first start a virtual environment:
+python3 -m venv ~/.my_venvs/coraldocs
+source ~/.my_venvs/coraldocs/bin/activate
+
+# Navigate to the pycoral/docs/ directory and run these commands...
+
+# Install the doc build dependencies:
+pip install -r requirements.txt
+
+# Build the docs:
+bash makedocs.sh
+```
+
+The results are output in `_build/`. The `_build/preview/` files are for local
+viewing--just open the `index.html` page. The `_build/web/` files are designed
+for publishing on the Coral website.
+
+For more information about the syntax in these RST files, see the
+[reStructuredText documentation](http://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html).
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..7e1d9e7
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,122 @@
+# pylint:disable=missing-docstring,redefined-builtin
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -*- coding: utf-8 -*-
+#
+# Configuration file for the Sphinx documentation builder.
+#
+# This file does only contain a selection of the most common options. For a
+# full list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+import unittest.mock
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+# Mock modules not needed for docs
+sys.modules.update([('tflite_runtime', unittest.mock.MagicMock())])
+sys.modules.update([('tflite_runtime.interpreter', unittest.mock.MagicMock())])
+sys.modules.update([('pycoral.pybind', unittest.mock.MagicMock())])
+sys.modules.update([('pycoral.pybind._pywrap_coral',
+ unittest.mock.MagicMock())])
+
+# -- Project information -----------------------------------------------------
+
+project = 'PyCoral API'
+copyright = '2020, Google LLC'
+author = 'Google LLC'
+
+# The short X.Y version
+version = '1.0'
+# The full version, including alpha/beta/rc tags
+release = ''
+
+# -- General configuration ---------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.intersphinx', # Enables linking to other libs like Pillow
+ 'sphinx.ext.coverage',
+ 'sphinx.ext.napoleon', # Converts Google-style code comments to RST
+]
+
+# Autodoc configurations
+autoclass_content = 'both'
+
+# Intersphinx config; controls external linking to other python libraries
+intersphinx_mapping = {
+ 'python': ('https://docs.python.org/', None),
+ 'PIL': ('https://pillow.readthedocs.io/en/stable/', None),
+ 'numpy': ('https://docs.scipy.org/doc/numpy/', None)
+}
+
+# Disable rtype return values; output return type inline with description
+napoleon_use_rtype = False
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+source_parsers = {
+ '.md': 'recommonmark.parser.CommonMarkParser',
+}
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+# source_suffix = ['.rst', '.md']
+source_suffix = ['.rst', '.md']
+
+# The master toctree document.
+master_doc = 'index'
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'README*']
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = None
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'coral_theme'
+html_theme_path = ['.']
+html_file_suffix = '.md'
+html_link_suffix = '/'
diff --git a/docs/coral_theme/layout.html b/docs/coral_theme/layout.html
new file mode 100644
index 0000000..5e74d1b
--- /dev/null
+++ b/docs/coral_theme/layout.html
@@ -0,0 +1,7 @@
+----
+Title: {{ title|striptags|e }}
+----
+
+
+{% block body %}{% endblock %}
+
\ No newline at end of file
diff --git a/docs/coral_theme/search.html b/docs/coral_theme/search.html
new file mode 100644
index 0000000..53cce44
--- /dev/null
+++ b/docs/coral_theme/search.html
@@ -0,0 +1 @@
+{{ toctree() }}
\ No newline at end of file
diff --git a/docs/coral_theme/theme.conf b/docs/coral_theme/theme.conf
new file mode 100644
index 0000000..aa1d70a
--- /dev/null
+++ b/docs/coral_theme/theme.conf
@@ -0,0 +1,12 @@
+[theme]
+inherit = basic
+stylesheet = none
+pygments_style = none
+sidebars = localtoc.html, relations.html, sourcelink.html, searchbox.html
+file_suffix = md
+
+[options]
+nosidebar = false
+sidebarwidth = 230
+body_min_width = 450
+body_max_width = 800
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..97decae
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,68 @@
+PyCoral API reference
+=====================
+
+This is the API reference for the Coral Python library.
+
+
+Module summary
+--------------
+
++ :mod:`pycoral.utils.dataset`
+
+ .. automodule:: pycoral.utils.dataset
+ :noindex:
+
++ :mod:`pycoral.utils.edgetpu`
+
+ .. automodule:: pycoral.utils.edgetpu
+ :noindex:
+
++ :mod:`pycoral.adapters.common`
+
+ .. automodule:: pycoral.adapters.common
+ :noindex:
+
++ :mod:`pycoral.adapters.classify`
+
+ .. automodule:: pycoral.adapters.classify
+ :noindex:
+
++ :mod:`pycoral.adapters.detect`
+
+ .. automodule:: pycoral.adapters.detect
+ :noindex:
+
++ :mod:`pycoral.pipeline.pipelined_model_runner`
+
+ .. automodule:: pycoral.pipeline.pipelined_model_runner
+ :noindex:
+
++ :mod:`pycoral.learn.backprop.softmax_regression`
+
+ .. automodule:: pycoral.learn.backprop.softmax_regression
+ :noindex:
+
++ :mod:`pycoral.learn.imprinting.engine`
+
+ .. automodule:: pycoral.learn.imprinting.engine
+ :noindex:
+
+
+Contents
+--------
+
+.. toctree::
+ :maxdepth: 1
+
+ pycoral.utils
+ pycoral.adapters
+ pycoral.pipeline
+ pycoral.learn.backprop
+ pycoral.learn.imprinting
+
+
+API indices
+-----------
+
+* :ref:`Full index <genindex>`
+* :ref:`Module index <modindex>`
diff --git a/docs/makedocs.sh b/docs/makedocs.sh
new file mode 100644
index 0000000..a84c0a0
--- /dev/null
+++ b/docs/makedocs.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+BUILD_DIR="_build"
+PREVIEW_DIR="${BUILD_DIR}/preview"
+WEB_DIR="${BUILD_DIR}/web"
+
+makeAll() {
+ makeClean
+ makeSphinxPreview
+ makeSphinxWeb
+}
+
+makeSphinxWeb() {
+ echo "Building Sphinx files for website..."
+ sphinx-build -b html . ${WEB_DIR}
+ # Delete intermediary/unused files:
+ find ${WEB_DIR} -mindepth 1 -not -name "*.md" -delete
+ rm ${WEB_DIR}/search.md ${WEB_DIR}/genindex.md ${WEB_DIR}/py-modindex.md
+ # Some custom tweaks to the output:
+ python postprocess.py -f ${WEB_DIR}/
+ echo "All done. Web pages are in ${WEB_DIR}."
+}
+
+makeSphinxPreview() {
+ echo "Building Sphinx files for local preview..."
+ # Build the docs for local viewing (in "read the docs" style):
+ sphinx-build -b html . ${PREVIEW_DIR} \
+ -D html_theme="sphinx_rtd_theme" \
+ -D html_file_suffix=".html" \
+ -D html_link_suffix=".html"
+ echo "All done. Preview pages are in ${PREVIEW_DIR}."
+}
+
+makeClean() {
+ rm -rf ${BUILD_DIR}
+ echo "Deleted ${BUILD_DIR}."
+}
+
+usage() {
+ echo -n "Usage:
+ makedocs.sh [-a|-w|-p|-c]
+
+ Options (only one allowed):
+ -a Clean and make all docs (default)
+ -w Make Sphinx for website
+ -p Make Sphinx for local preview
+ -c Clean
+"
+}
+
+if [[ "$#" -gt 1 ]]; then
+ usage
+elif [[ "$#" -eq 1 ]]; then
+ if [[ "$1" = "-a" ]]; then
+ makeAll
+ elif [[ "$1" = "-w" ]]; then
+ makeSphinxWeb
+ elif [[ "$1" = "-p" ]]; then
+ makeSphinxPreview
+ elif [[ "$1" = "-c" ]]; then
+ makeClean
+ else
+ usage
+ fi
+else
+ makeAll
+fi
diff --git a/docs/postprocess.py b/docs/postprocess.py
new file mode 100644
index 0000000..1af8845
--- /dev/null
+++ b/docs/postprocess.py
@@ -0,0 +1,113 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Process the PyCoral docs from Sphinx to optimize them for Coral website."""
+
+import argparse
+import os
+import re
+
+from bs4 import BeautifulSoup
+
+
+def remove_title(soup):
+ """Deletes the extra H1 title."""
+ h1 = soup.find('h1')
+ if h1:
+ h1.extract()
+ return soup
+
+
+def relocate_h2id(soup):
+ """Moves the anchor ID to the H2 tag, from the wrapper DIV."""
+ for h2 in soup.find_all('h2'):
+ div = h2.find_parent('div')
+ if div.has_attr('id') and not h2.has_attr('id'):
+ # print('Move ID: ' + div['id'])
+ h2['id'] = div['id']
+ del div['id']
+ # Also delete embedded tag
+ if h2.find('a'):
+ h2.find('a').extract()
+ return soup
+
+
+def clean_pre(soup):
+  """Adds our prettyprint class to PRE and removes some troublesome tags."""
+ for pre in soup.find_all('pre'):
+ pre['class'] = 'language-cpp'
+ # This effectively deletes the wrapper DIV and P tags that cause issues
+ parent_p = pre.find_parent('p')
+ if parent_p:
+ parent_p.replace_with(pre)
+ return soup
+
+
+def remove_coral(soup):
+ """Removes 'coral' namespace link that does nothing."""
+ for a in soup.select('a[title=coral]'):
+ content = a.contents[0]
+ a.replace_with(content)
+ return soup
+
+
+def remove_init_string(soup):
+ """Removes a Sphinx-supplied description for namedtuple classes."""
+ paras = soup.find_all('p', string=re.compile(r'^Create new instance of'))
+ for p in paras:
+ p.extract()
+ return soup
+
+
+def clean_index(soup):
+ """Removes relative-URL backstep in index page links, due to website move."""
+ for link in soup.find_all('a'):
+ if link['href'].startswith('../'):
+ link['href'] = link['href'][1:]
+ return soup
+
+
+def process(file):
+ """Runs all the cleanup functions."""
+ print('Post-processing ' + file)
+ soup = BeautifulSoup(open(file), 'html.parser')
+ soup = remove_title(soup)
+ soup = relocate_h2id(soup)
+ soup = clean_pre(soup)
+ soup = remove_coral(soup)
+ soup = remove_init_string(soup)
+ if os.path.split(file)[1] == 'index.md':
+ soup = clean_index(soup)
+ with open(file, 'w') as output:
+ output.write(str(soup))
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ '-f', '--file', required=True, help='File path of HTML file(s).')
+ args = parser.parse_args()
+
+ # Accept a directory or single file
+ if os.path.isdir(args.file):
+ for file in os.listdir(args.file):
+ if os.path.splitext(file)[1] == '.md':
+ process(os.path.join(args.file, file))
+ else:
+ process(args.file)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/docs/pycoral.adapters.rst b/docs/pycoral.adapters.rst
new file mode 100644
index 0000000..6738f49
--- /dev/null
+++ b/docs/pycoral.adapters.rst
@@ -0,0 +1,30 @@
+pycoral.adapters
+================
+
+pycoral.adapters.common
+-----------------------
+
+.. automodule:: pycoral.adapters.common
+ :members:
+ :undoc-members:
+
+pycoral.adapters.classify
+-------------------------
+
+.. automodule:: pycoral.adapters.classify
+ :members: get_classes, get_classes_from_scores, get_scores, num_classes
+ :undoc-members:
+
+.. autoclass:: pycoral.adapters.classify.Class
+
+pycoral.adapters.detect
+-----------------------
+
+.. automodule:: pycoral.adapters.detect
+ :members: get_objects
+
+.. autoclass:: pycoral.adapters.detect.Object
+
+.. autoclass:: pycoral.adapters.detect.BBox
+ :members:
+ :member-order: bysource
\ No newline at end of file
diff --git a/docs/pycoral.learn.backprop.rst b/docs/pycoral.learn.backprop.rst
new file mode 100644
index 0000000..d574bc9
--- /dev/null
+++ b/docs/pycoral.learn.backprop.rst
@@ -0,0 +1,10 @@
+pycoral.learn.backprop
+======================
+
+pycoral.learn.backprop.softmax_regression
+-----------------------------------------
+
+.. automodule:: pycoral.learn.backprop.softmax_regression
+ :members:
+ :undoc-members:
+ :inherited-members:
\ No newline at end of file
diff --git a/docs/pycoral.learn.imprinting.rst b/docs/pycoral.learn.imprinting.rst
new file mode 100644
index 0000000..c83c4d1
--- /dev/null
+++ b/docs/pycoral.learn.imprinting.rst
@@ -0,0 +1,10 @@
+pycoral.learn.imprinting
+========================
+
+pycoral.learn.imprinting.engine
+-------------------------------
+
+.. automodule:: pycoral.learn.imprinting.engine
+ :members:
+ :undoc-members:
+ :inherited-members:
\ No newline at end of file
diff --git a/docs/pycoral.pipeline.rst b/docs/pycoral.pipeline.rst
new file mode 100644
index 0000000..c796b99
--- /dev/null
+++ b/docs/pycoral.pipeline.rst
@@ -0,0 +1,10 @@
+pycoral.pipeline
+================
+
+pycoral.pipeline.pipelined_model_runner
+---------------------------------------
+
+.. automodule:: pycoral.pipeline.pipelined_model_runner
+ :members:
+ :undoc-members:
+ :inherited-members:
diff --git a/docs/pycoral.utils.rst b/docs/pycoral.utils.rst
new file mode 100644
index 0000000..4baaa2d
--- /dev/null
+++ b/docs/pycoral.utils.rst
@@ -0,0 +1,19 @@
+pycoral.utils
+=============
+
+pycoral.utils.dataset
+---------------------
+
+.. automodule:: pycoral.utils.dataset
+ :members:
+ :undoc-members:
+ :inherited-members:
+
+
+pycoral.utils.edgetpu
+---------------------
+
+.. automodule:: pycoral.utils.edgetpu
+ :members:
+ :undoc-members:
+ :inherited-members:
\ No newline at end of file
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000..eea2831
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,6 @@
+# Python packages required to build the docs
+sphinx
+sphinx_rtd_theme
+recommonmark
+numpy
+beautifulsoup4
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000..a4c7c8b
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,34 @@
+# PyCoral API examples
+
+This directory contains several examples that show how to use the
+[PyCoral API](https://coral.ai/docs/edgetpu/api-intro/) to perform
+inference or on-device transfer learning.
+
+## Get the code
+
+Before you begin, you must
+[set up your Coral device](https://coral.ai/docs/setup/).
+
+Then simply clone this repo:
+
+```
+git clone https://github.com/google-coral/pycoral
+```
+
+For more pre-compiled models, see [coral.ai/models](https://coral.ai/models/).
+
+## Run the example code
+
+Each `.py` file includes documentation at the top with an example command you
+can use to run it. Notice that they all use a pre-compiled model from the
+`test_data` directory, which is a submodule dependency for this repo. So if you
+want to use those files, you can clone them within the `pycoral` repo like this:
+
+```
+cd pycoral
+
+git submodule init && git submodule update
+```
+
+For more information about building models and running inference on the Edge
+TPU, see the [Coral documentation](https://coral.ai/docs/).
diff --git a/examples/backprop_last_layer.py b/examples/backprop_last_layer.py
new file mode 100644
index 0000000..8429e1e
--- /dev/null
+++ b/examples/backprop_last_layer.py
@@ -0,0 +1,265 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""A demo for on-device backprop (transfer learning) of a classification model.
+
+This demo runs a similar task as described in TF Poets tutorial, except that
+learning happens on-device.
+https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0
+
+Here are the steps:
+1) mkdir -p /tmp/retrain/
+
+2) curl http://download.tensorflow.org/example_images/flower_photos.tgz \
+ | tar xz -C /tmp/retrain
+
+3) Start training:
+
+ python3 backprop_last_layer.py \
+ --data_dir /tmp/retrain/flower_photos \
+ --embedding_extractor_path \
+ test_data/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite
+
+ Weights for retrained last layer will be saved to /tmp/retrain/output by
+ default.
+
+4) Run an inference with the new model:
+
+ python3 classify_image.py \
+ --model /tmp/retrain/output/retrained_model_edgetpu.tflite \
+    --labels /tmp/retrain/output/label_map.txt \
+ --input test_data/sunflower.bmp
+
+For more information, see
+https://coral.ai/docs/edgetpu/retrain-classification-ondevice-backprop/
+"""
+
+import argparse
+import contextlib
+import os
+import sys
+import time
+
+import numpy as np
+from PIL import Image
+
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.learn.backprop.softmax_regression import SoftmaxRegression
+from pycoral.utils.edgetpu import make_interpreter
+
+
+@contextlib.contextmanager
+def test_image(path):
+ """Returns opened test image."""
+ with open(path, 'rb') as f:
+ with Image.open(f) as image:
+ yield image
+
+
+def save_label_map(label_map, out_path):
+ """Saves label map to a file."""
+ with open(out_path, 'w') as f:
+ for key, val in label_map.items():
+ f.write('%s %s\n' % (key, val))
+
+
+def get_image_paths(data_dir):
+ """Walks through data_dir and returns list of image paths and label map.
+
+ Args:
+ data_dir: string, path to data directory. It assumes data directory is
+ organized as, - [CLASS_NAME_0] -- image_class_0_a.jpg --
+ image_class_0_b.jpg -- ... - [CLASS_NAME_1] -- image_class_1_a.jpg -- ...
+
+ Returns:
+ A tuple of (image_paths, labels, label_map)
+ image_paths: list of string, represents image paths
+ labels: list of int, represents labels
+ label_map: a dictionary (int -> string), e.g., 0->class0, 1->class1, etc.
+ """
+ classes = None
+ image_paths = []
+ labels = []
+
+ class_idx = 0
+ for root, dirs, files in os.walk(data_dir):
+ if root == data_dir:
+ # Each sub-directory in `data_dir`
+ classes = dirs
+ else:
+ # Read each sub-directory
+ assert classes[class_idx] in root
+ print('Reading dir: %s, which has %d images' % (root, len(files)))
+ for img_name in files:
+ image_paths.append(os.path.join(root, img_name))
+ labels.append(class_idx)
+ class_idx += 1
+
+ return image_paths, labels, dict(zip(range(class_idx), classes))
+
+
+def shuffle_and_split(image_paths, labels, val_percent=0.1, test_percent=0.1):
+ """Shuffles and splits data into train, validation, and test sets.
+
+ Args:
+ image_paths: list of string, of dim num_data
+ labels: list of int of length num_data
+ val_percent: validation data set percentage.
+ test_percent: test data set percentage.
+
+ Returns:
+ Two dictionaries (train_and_val_dataset, test_dataset).
+ train_and_val_dataset has the following fields.
+ 'data_train': data_train
+ 'labels_train': labels_train
+ 'data_val': data_val
+ 'labels_val': labels_val
+ test_dataset has the following fields.
+ 'data_test': data_test
+ 'labels_test': labels_test
+ """
+ image_paths = np.array(image_paths)
+ labels = np.array(labels)
+ perm = np.random.permutation(image_paths.shape[0])
+ image_paths = image_paths[perm]
+ labels = labels[perm]
+
+ num_total = image_paths.shape[0]
+ num_val = int(num_total * val_percent)
+ num_test = int(num_total * test_percent)
+ num_train = num_total - num_val - num_test
+
+ train_and_val_dataset = {}
+ train_and_val_dataset['data_train'] = image_paths[0:num_train]
+ train_and_val_dataset['labels_train'] = labels[0:num_train]
+ train_and_val_dataset['data_val'] = image_paths[num_train:num_train + num_val]
+ train_and_val_dataset['labels_val'] = labels[num_train:num_train + num_val]
+ test_dataset = {}
+ test_dataset['data_test'] = image_paths[num_train + num_val:]
+ test_dataset['labels_test'] = labels[num_train + num_val:]
+ return train_and_val_dataset, test_dataset
+
+
+def extract_embeddings(image_paths, interpreter):
+ """Uses model to process images as embeddings.
+
+ Reads image, resizes and feeds to model to get feature embeddings. Original
+ image is discarded to keep maximum memory consumption low.
+
+ Args:
+ image_paths: ndarray, represents a list of image paths.
+ interpreter: TFLite interpreter, wraps embedding extractor model.
+
+ Returns:
+ ndarray of length image_paths.shape[0] of embeddings.
+ """
+ input_size = common.input_size(interpreter)
+ feature_dim = classify.num_classes(interpreter)
+ embeddings = np.empty((len(image_paths), feature_dim), dtype=np.float32)
+ for idx, path in enumerate(image_paths):
+ with test_image(path) as img:
+ common.set_input(interpreter, img.resize(input_size, Image.NEAREST))
+ interpreter.invoke()
+ embeddings[idx, :] = classify.get_scores(interpreter)
+
+ return embeddings
+
+
+def train(model_path, data_dir, output_dir):
+ """Trains a softmax regression model given data and embedding extractor.
+
+ Args:
+ model_path: string, path to embedding extractor.
+ data_dir: string, directory that contains training data.
+ output_dir: string, directory to save retrained tflite model and label map.
+ """
+ t0 = time.perf_counter()
+ image_paths, labels, label_map = get_image_paths(data_dir)
+ train_and_val_dataset, test_dataset = shuffle_and_split(image_paths, labels)
+ # Initializes interpreter and allocates tensors here to avoid repeatedly
+ # initialization which is time consuming.
+ interpreter = make_interpreter(model_path, device=':0')
+ interpreter.allocate_tensors()
+ print('Extract embeddings for data_train')
+ train_and_val_dataset['data_train'] = extract_embeddings(
+ train_and_val_dataset['data_train'], interpreter)
+ print('Extract embeddings for data_val')
+ train_and_val_dataset['data_val'] = extract_embeddings(
+ train_and_val_dataset['data_val'], interpreter)
+ t1 = time.perf_counter()
+ print('Data preprocessing takes %.2f seconds' % (t1 - t0))
+
+ # Construct model and start training
+ weight_scale = 5e-2
+ reg = 0.0
+ feature_dim = train_and_val_dataset['data_train'].shape[1]
+ num_classes = np.max(train_and_val_dataset['labels_train']) + 1
+ model = SoftmaxRegression(
+ feature_dim, num_classes, weight_scale=weight_scale, reg=reg)
+
+ learning_rate = 1e-2
+ batch_size = 100
+ num_iter = 500
+ model.train_with_sgd(
+ train_and_val_dataset, num_iter, learning_rate, batch_size=batch_size)
+ t2 = time.perf_counter()
+ print('Training takes %.2f seconds' % (t2 - t1))
+
+ # Append learned weights to input model and save as tflite format.
+ out_model_path = os.path.join(output_dir, 'retrained_model_edgetpu.tflite')
+ with open(out_model_path, 'wb') as f:
+ f.write(model.serialize_model(model_path))
+ print('Model %s saved.' % out_model_path)
+ label_map_path = os.path.join(output_dir, 'label_map.txt')
+ save_label_map(label_map, label_map_path)
+ print('Label map %s saved.' % label_map_path)
+ t3 = time.perf_counter()
+ print('Saving retrained model and label map takes %.2f seconds' % (t3 - t2))
+
+ retrained_interpreter = make_interpreter(out_model_path, device=':0')
+ retrained_interpreter.allocate_tensors()
+ test_embeddings = extract_embeddings(test_dataset['data_test'],
+ retrained_interpreter)
+ saved_model_acc = np.mean(
+ np.argmax(test_embeddings, axis=1) == test_dataset['labels_test'])
+ print('Saved tflite model test accuracy: %.2f%%' % (saved_model_acc * 100))
+ t4 = time.perf_counter()
+ print('Checking test accuracy takes %.2f seconds' % (t4 - t3))
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--embedding_extractor_path',
+ required=True,
+ help='Path to embedding extractor tflite model.')
+ parser.add_argument('--data_dir', required=True, help='Directory to data.')
+ parser.add_argument(
+ '--output_dir',
+ default='/tmp/retrain/output',
+ help='Path to directory to save retrained model and label map.')
+ args = parser.parse_args()
+
+ if not os.path.exists(args.data_dir):
+ sys.exit('%s does not exist!' % args.data_dir)
+
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ train(args.embedding_extractor_path, args.data_dir, args.output_dir)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/classify_image.py b/examples/classify_image.py
new file mode 100644
index 0000000..9f37ecd
--- /dev/null
+++ b/examples/classify_image.py
@@ -0,0 +1,82 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""Example using PyCoral to classify a given image using an Edge TPU.
+
+To run this code, you must attach an Edge TPU attached to the host and
+install the Edge TPU runtime (`libedgetpu.so`) and `tflite_runtime`. For
+device setup instructions, see coral.ai/docs/setup.
+
+Example usage:
+```
+python3 classify_image.py \
+ --model test_data/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite \
+ --labels test_data/inat_bird_labels.txt \
+ --input test_data/parrot.jpg
+```
+"""
+
+import argparse
+import time
+
+from PIL import Image
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.utils.dataset import read_label_file
+from pycoral.utils.edgetpu import make_interpreter
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('-m', '--model', required=True,
+ help='File path of .tflite file.')
+ parser.add_argument('-i', '--input', required=True,
+ help='Image to be classified.')
+ parser.add_argument('-l', '--labels',
+ help='File path of labels file.')
+ parser.add_argument('-k', '--top_k', type=int, default=1,
+ help='Max number of classification results')
+ parser.add_argument('-t', '--threshold', type=float, default=0.0,
+ help='Classification score threshold')
+ parser.add_argument('-c', '--count', type=int, default=5,
+ help='Number of times to run inference')
+ args = parser.parse_args()
+
+ labels = read_label_file(args.labels) if args.labels else {}
+
+ interpreter = make_interpreter(*args.model.split('@'))
+ interpreter.allocate_tensors()
+
+ size = common.input_size(interpreter)
+ image = Image.open(args.input).convert('RGB').resize(size, Image.ANTIALIAS)
+ common.set_input(interpreter, image)
+
+ print('----INFERENCE TIME----')
+ print('Note: The first inference on Edge TPU is slow because it includes',
+ 'loading the model into Edge TPU memory.')
+ for _ in range(args.count):
+ start = time.perf_counter()
+ interpreter.invoke()
+ inference_time = time.perf_counter() - start
+ classes = classify.get_classes(interpreter, args.top_k, args.threshold)
+ print('%.1fms' % (inference_time * 1000))
+
+ print('-------RESULTS--------')
+ for c in classes:
+ print('%s: %.5f' % (labels.get(c.id, c.id), c.score))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/detect_image.py b/examples/detect_image.py
new file mode 100644
index 0000000..e122862
--- /dev/null
+++ b/examples/detect_image.py
@@ -0,0 +1,106 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""Example using PyCoral to detect objects in a given image.
+
+To run this code, you must have an Edge TPU attached to the host and
+install the Edge TPU runtime (`libedgetpu.so`) and `tflite_runtime`. For
+device setup instructions, see coral.ai/docs/setup.
+
+Example usage:
+```
+python3 detect_image.py \
+ --model test_data/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \
+ --labels test_data/coco_labels.txt \
+ --input test_data/grace_hopper.bmp \
+ --output ${HOME}/grace_hopper_processed.bmp
+```
+"""
+
+import argparse
+import time
+
+from PIL import Image
+from PIL import ImageDraw
+
+from pycoral.adapters import common
+from pycoral.adapters import detect
+from pycoral.utils.dataset import read_label_file
+from pycoral.utils.edgetpu import make_interpreter
+
+
def draw_objects(draw, objs, labels):
  """Draws the bounding box and label for each object."""
  for obj in objs:
    box = obj.bbox
    corners = [(box.xmin, box.ymin), (box.xmax, box.ymax)]
    draw.rectangle(corners, outline='red')
    caption = '%s\n%.2f' % (labels.get(obj.id, obj.id), obj.score)
    draw.text((box.xmin + 10, box.ymin + 10), caption, fill='red')
+
+
def main():
  """Detects objects, prints results, and optionally saves an annotated image."""
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument(
      '-m', '--model', required=True, help='File path of .tflite file')
  parser.add_argument(
      '-i', '--input', required=True, help='File path of image to process')
  parser.add_argument('-l', '--labels', help='File path of labels file')
  parser.add_argument(
      '-t', '--threshold', type=float, default=0.4,
      help='Score threshold for detected objects')
  parser.add_argument(
      '-o', '--output',
      help='File path for the result image with annotations')
  parser.add_argument(
      '-c', '--count', type=int, default=5,
      help='Number of times to run inference')
  args = parser.parse_args()

  labels = {} if not args.labels else read_label_file(args.labels)
  interpreter = make_interpreter(args.model)
  interpreter.allocate_tensors()

  image = Image.open(args.input)
  # Resize into the input tensor; keep the scale so boxes can be mapped back.
  _, scale = common.set_resized_input(
      interpreter, image.size, lambda size: image.resize(size, Image.ANTIALIAS))

  print('----INFERENCE TIME----')
  print('Note: The first inference is slow because it includes',
        'loading the model into Edge TPU memory.')
  for _ in range(args.count):
    start = time.perf_counter()
    interpreter.invoke()
    elapsed = time.perf_counter() - start
    objs = detect.get_objects(interpreter, args.threshold, scale)
    print('%.2f ms' % (elapsed * 1000))

  print('-------RESULTS--------')
  if not objs:
    print('No objects detected')

  for obj in objs:
    print(labels.get(obj.id, obj.id))
    print('  id:    ', obj.id)
    print('  score: ', obj.score)
    print('  bbox:  ', obj.bbox)

  if args.output:
    image = image.convert('RGB')
    draw_objects(ImageDraw.Draw(image), objs, labels)
    image.save(args.output)
    image.show()


if __name__ == '__main__':
  main()
diff --git a/examples/imprinting_learning.py b/examples/imprinting_learning.py
new file mode 100644
index 0000000..d74cb69
--- /dev/null
+++ b/examples/imprinting_learning.py
@@ -0,0 +1,229 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""A demo for on-device imprinting (transfer learning) of a classification model.
+
+Here are the steps:
+1) Download the data set for transfer learning:
+ ```
+ wget https://dl.google.com/coral/sample_data/imprinting_data_script.tar.gz
+ tar zxf imprinting_data_script.tar.gz
+ ./imprinting_data_script/download_imprinting_test_data.sh ./
+ ```
+
+ This downloads 10 categories, 20 images for each category, saving it into
+ a directory named `open_image_v4_subset`.
+
+2) Start training the new classification model:
+ ```
+ python3 imprinting_learning.py \
+ --model_path test_data/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite \
+ --data open_image_v4_subset \
+ --output ${HOME}/my_model.tflite
+ ```
+
+3) Run an inference with the new model:
+ ```
+ python3 classify_image.py \
+ --model my_model.tflite \
+ --label my_model.txt \
+ --input test_data/cat.bmp
+ ```
+
+For more information, see
+https://coral.ai/docs/edgetpu/retrain-classification-ondevice/
+"""
+
+import argparse
+import os
+import numpy as np
+from PIL import Image
+
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.learn.imprinting.engine import ImprintingEngine
+from pycoral.utils.edgetpu import make_interpreter
+
+
+def _read_data(path, test_ratio):
+ """Parses data from given directory, split them into two sets.
+
+ Args:
+ path: string, path of the data set. Images are stored in sub-directory named
+ by category.
+ test_ratio: float in (0,1), ratio of data used for testing.
+
+ Returns:
+ (train_set, test_set), A tuple of two dicts. Keys are the categories and
+ values are lists of image file names.
+ """
+ train_set = {}
+ test_set = {}
+ for category in os.listdir(path):
+ category_dir = os.path.join(path, category)
+ if os.path.isdir(category_dir):
+ images = [
+ f for f in os.listdir(category_dir)
+ if os.path.isfile(os.path.join(category_dir, f))
+ ]
+ if images:
+ k = max(int(test_ratio * len(images)), 1)
+ test_set[category] = images[:k]
+ assert test_set[category], 'No images to test [{}]'.format(category)
+ train_set[category] = images[k:]
+ assert train_set[category], 'No images to train [{}]'.format(category)
+ return train_set, test_set
+
+
def _prepare_images(image_list, directory, shape):
  """Reads images and stacks them into one numpy array with given shape.

  Args:
    image_list: a list of strings storing file names.
    directory: string, path of directory storing input images.
    shape: a 2-D tuple represents the shape of required input tensor.

  Returns:
    A numpy.array holding one RGB image per input file.
  """
  tensors = []
  for filename in image_list:
    with Image.open(os.path.join(directory, filename)) as img:
      rgb = img.convert('RGB').resize(shape, Image.NEAREST)
      tensors.append(np.asarray(rgb))
  return np.array(tensors)
+
+
+def _save_labels(labels, model_path):
+ """Output labels as a txt file.
+
+ Args:
+ labels: {int : string}, map between label id and label.
+ model_path: string, path of the model.
+ """
+ label_file_name = model_path.replace('.tflite', '.txt')
+ with open(label_file_name, 'w') as f:
+ for label_id, label in labels.items():
+ f.write(str(label_id) + ' ' + label + '\n')
+ print('Labels file saved as :', label_file_name)
+
+
def _parse_args():
  """Parses args, set default values if it's not passed.

  Returns:
    Object with attributes. Each attribute represents an argument.
  """
  print('---------------------- Args ----------------------')
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model_path', help='Path to the model file.', required=True)
  parser.add_argument(
      '--data',
      # Fixed: the two fragments previously concatenated to 'storedunder'.
      help=('Path to the training set, images are stored '
            'under sub-directory named by category.'),
      required=True)
  parser.add_argument('--output', help='Name of the trained model.')
  parser.add_argument(
      '--test_ratio',
      type=float,
      help='Float number in (0,1), ratio of data used for test data.')
  parser.add_argument(
      '--keep_classes',
      action='store_true',
      help='Whether to keep base model classes.')
  args = parser.parse_args()
  if not args.output:
    # Default output name: '<model>_retrained.tflite' in the working directory.
    model_name = os.path.basename(args.model_path)
    args.output = model_name.replace('.tflite', '_retrained.tflite')
  print('Output path :', args.output)
  # By default, choose 25% data for test. Compare against None explicitly so
  # an (invalid) explicit --test_ratio 0 is rejected by the range check below
  # instead of being silently replaced, as the old truthiness test did.
  if args.test_ratio is None:
    args.test_ratio = 0.25
  assert 0.0 < args.test_ratio < 1.0, 'test_ratio must be in (0, 1)'
  print('Ratio of test images: {:.0%}'.format(args.test_ratio))
  return args
+
+
def main():
  """Trains new classes onto an imprinting model, then evaluates the result."""
  args = _parse_args()

  # The imprinting engine splits the model into a fixed embedding extractor
  # plus trainable last-layer weights.
  engine = ImprintingEngine(args.model_path, keep_classes=args.keep_classes)
  extractor = make_interpreter(engine.serialize_extractor_model(), device=':0')
  extractor.allocate_tensors()
  shape = common.input_size(extractor)

  print('--------------- Parsing data set -----------------')
  print('Dataset path:', args.data)

  train_set, test_set = _read_data(args.data, args.test_ratio)
  print('Image list successfully parsed! Category Num = ', len(train_set))

  print('---------------- Processing training data ----------------')
  print('This process may take more than 30 seconds.')
  train_input = []
  labels_map = {}
  # One stacked image tensor per category; class ids are assigned by
  # enumeration order of train_set.
  for class_id, (category, image_list) in enumerate(train_set.items()):
    print('Processing category:', category)
    train_input.append(
        _prepare_images(image_list, os.path.join(args.data, category), shape))
    labels_map[class_id] = category
  print('---------------- Start training -----------------')
  # New classes are appended after the base model's existing classes.
  num_classes = engine.num_classes
  for class_id, tensors in enumerate(train_input):
    for tensor in tensors:
      common.set_input(extractor, tensor)
      extractor.invoke()
      embedding = classify.get_scores(extractor)
      engine.train(embedding, class_id=num_classes + class_id)
  print('---------------- Training finished! -----------------')

  with open(args.output, 'wb') as f:
    f.write(engine.serialize_model())
  print('Model saved as : ', args.output)
  _save_labels(labels_map, args.output)

  print('------------------ Start evaluating ------------------')
  interpreter = make_interpreter(args.output)
  interpreter.allocate_tensors()
  size = common.input_size(interpreter)

  # Cumulative top-k accuracy counters: correct[i] counts hits within the
  # first i+1 candidates.
  top_k = 5
  correct = [0] * top_k
  wrong = [0] * top_k
  for category, image_list in test_set.items():
    print('Evaluating category [', category, ']')
    for img_name in image_list:
      img = Image.open(os.path.join(args.data, category,
                                    img_name)).resize(size, Image.NEAREST)
      common.set_input(interpreter, img)
      interpreter.invoke()
      candidates = classify.get_classes(interpreter, top_k, score_threshold=0.1)
      recognized = False
      # NOTE(review): labels_map is keyed 0..N-1 but training used
      # class_id = num_classes + class_id; with --keep_classes this lookup
      # presumably needs the same offset — verify against ImprintingEngine.
      for i in range(top_k):
        if i < len(candidates) and labels_map[candidates[i].id] == category:
          recognized = True
        if recognized:
          correct[i] = correct[i] + 1
        else:
          wrong[i] = wrong[i] + 1
  print('---------------- Evaluation result -----------------')
  for i in range(top_k):
    print('Top {} : {:.0%}'.format(i + 1, correct[i] / (correct[i] + wrong[i])))


if __name__ == '__main__':
  main()
diff --git a/examples/model_pipelining_classify_image.py b/examples/model_pipelining_classify_image.py
new file mode 100644
index 0000000..230a5b1
--- /dev/null
+++ b/examples/model_pipelining_classify_image.py
@@ -0,0 +1,159 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""Example to classify a given image using model pipelining with two Edge TPUs.
+
+To run this code, you must have two Edge TPUs attached to the host and
+install the Edge TPU runtime (`libedgetpu.so`) and `tflite_runtime`. For
+device setup instructions, see g.co/coral/setup.
+
+Example usage (use `install_requirements.sh` to get these files):
+```
+python3 model_pipelining_classify_image.py \
+ --models test_data/inception_v3_299_quant_segment_%d_of_2_edgetpu.tflite \
+ --labels test_data/imagenet_labels.txt \
+ --input test_data/parrot.jpg
+```
+"""
+
+import argparse
+import re
+import threading
+import time
+
+import numpy as np
+from PIL import Image
+
+from pycoral.adapters import classify
+from pycoral.adapters import common
+import pycoral.pipeline.pipelined_model_runner as pipeline
+from pycoral.utils.dataset import read_label_file
+from pycoral.utils.edgetpu import list_edge_tpus
+from pycoral.utils.edgetpu import make_interpreter
+
+
def _get_devices(num_devices):
  """Returns `num_devices` device specifiers in pci:N / usb:N form.

  PCI-attached Edge TPUs are preferred and listed first.

  Args:
    num_devices: int, number of devices expected

  Returns:
    list of devices in pci:N and/or usb:N format

  Raises:
    RuntimeError: if not enough devices are available
  """
  edge_tpus = list_edge_tpus()

  if len(edge_tpus) < num_devices:
    raise RuntimeError(
        'Not enough Edge TPUs detected, expected %d, detected %d.' %
        (num_devices, len(edge_tpus)))

  num_pci = sum(tpu['type'] == 'pci' for tpu in edge_tpus)
  pci_count = min(num_devices, num_pci)
  usb_count = num_devices - pci_count

  devices = ['pci:%d' % i for i in range(pci_count)]
  devices += ['usb:%d' % i for i in range(usb_count)]
  return devices
+
+
def _make_runner(model_paths, devices):
  """Builds a PipelinedModelRunner from model segments mapped onto devices."""
  print('Using devices: ', devices)
  print('Using models: ', model_paths)

  if len(model_paths) != len(devices):
    raise ValueError('# of devices and # of model_paths should match')

  interpreters = []
  for model_path, device in zip(model_paths, devices):
    interpreters.append(make_interpreter(model_path, device))
  for interpreter in interpreters:
    interpreter.allocate_tensors()
  return pipeline.PipelinedModelRunner(interpreters)
+
+
def main():
  """Classifies an image with a model pipelined across multiple Edge TPUs."""
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument(
      '-m',
      '--models',
      required=True,
      help=('File path template of .tflite model segments, e.g.,'
            'inception_v3_299_quant_segment_%d_of_2_edgetpu.tflite'))
  parser.add_argument(
      '-i', '--input', required=True, help='Image to be classified.')
  parser.add_argument(
      '-l', '--labels', help='File path of labels file.')
  parser.add_argument(
      '-k', '--top_k', type=int, default=1,
      help='Max number of classification results')
  parser.add_argument(
      '-t', '--threshold', type=float, default=0.0,
      help='Classification score threshold')
  parser.add_argument(
      '-c', '--count', type=int, default=5,
      help='Number of times to run inference')
  args = parser.parse_args()
  labels = read_label_file(args.labels) if args.labels else {}

  # Recover the segment count from the template name. The original pattern,
  # `(?P[0-9]+)`, was missing the group name and raised re.error at runtime;
  # the dot before 'tflite' is now escaped so it can only match a literal dot.
  result = re.search(r'^.*_segment_%d_of_(?P<num_segments>[0-9]+)_.*\.tflite',
                     args.models)
  if not result:
    raise ValueError(
        '--models should follow *_segment_%d_of_[num_segments]_*.tflite '
        'pattern')
  num_segments = int(result.group('num_segments'))
  model_paths = [args.models % i for i in range(num_segments)]
  devices = _get_devices(num_segments)
  runner = _make_runner(model_paths, devices)

  size = common.input_size(runner.interpreters()[0])
  image = np.array(
      Image.open(args.input).convert('RGB').resize(size, Image.ANTIALIAS))

  def producer():
    # Feed the same image `count` times; an empty push signals end-of-stream.
    for _ in range(args.count):
      runner.push([image])
    runner.push([])

  def consumer():
    # Dequantize the last segment's output and report the top classes.
    output_details = runner.interpreters()[-1].get_output_details()[0]
    scale, zero_point = output_details['quantization']
    while True:
      result = runner.pop()
      if not result:
        break
      scores = scale * (result[0][0].astype(np.int64) - zero_point)
      classes = classify.get_classes_from_scores(scores, args.top_k,
                                                 args.threshold)
      print('-------RESULTS--------')
      for klass in classes:
        print('%s: %.5f' % (labels.get(klass.id, klass.id), klass.score))

  start = time.perf_counter()
  producer_thread = threading.Thread(target=producer)
  consumer_thread = threading.Thread(target=consumer)
  producer_thread.start()
  consumer_thread.start()
  producer_thread.join()
  consumer_thread.join()
  average_time_ms = (time.perf_counter() - start) / args.count * 1000
  print('Average inference time (over %d iterations): %.1fms' %
        (args.count, average_time_ms))


if __name__ == '__main__':
  main()
diff --git a/examples/semantic_segmentation.py b/examples/semantic_segmentation.py
new file mode 100644
index 0000000..6f372b7
--- /dev/null
+++ b/examples/semantic_segmentation.py
@@ -0,0 +1,132 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""An example of semantic segmentation.
+
+The following command runs this script and saves a new image showing the
+segmented pixels at the location specified by `output`:
+
+python3 examples/semantic_segmentation.py \
+ --model test_data/deeplabv3_mnv2_pascal_quant_edgetpu.tflite \
+ --input test_data/bird.bmp \
+ --keep_aspect_ratio \
+ --output ${HOME}/segmentation_result.jpg
+"""
+
+import argparse
+
+import numpy as np
+from PIL import Image
+
+from pycoral.adapters import common
+from pycoral.adapters import segment
+from pycoral.utils.edgetpu import make_interpreter
+
+
def create_pascal_label_colormap():
  """Creates a label colormap used in PASCAL VOC segmentation benchmark.

  Each label index is mapped to an RGB color by distributing the index's
  bits, three at a time, across the high bits of the three channels.

  Returns:
    A (256, 3) integer array mapping label index to RGB color.
  """
  colormap = np.zeros((256, 3), dtype=int)
  remaining = np.arange(256, dtype=int)

  for shift in range(7, -1, -1):
    for channel in (0, 1, 2):
      colormap[:, channel] |= ((remaining >> channel) & 1) << shift
    remaining >>= 3

  return colormap
+
+
def label_to_color_image(label):
  """Adds color defined by the dataset colormap to the label.

  Args:
    label: A 2D array with integer type, storing the segmentation label.

  Returns:
    result: A 2D array with floating type. The element of the array
      is the color indexed by the corresponding element in the input label
      to the PASCAL color map.

  Raises:
    ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
  if label.ndim != 2:
    raise ValueError('Expect 2-D input label')

  color_table = create_pascal_label_colormap()

  max_label = np.max(label)
  if max_label >= len(color_table):
    raise ValueError('label value too large.')

  return color_table[label]
+
+
def main():
  """Segments an image and saves the input/mask side-by-side comparison."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--model', required=True,
                      help='Path of the segmentation model.')
  parser.add_argument('--input', required=True,
                      help='File path of the input image.')
  parser.add_argument('--output', default='semantic_segmentation_result.jpg',
                      help='File path of the output image.')
  parser.add_argument(
      '--keep_aspect_ratio',
      action='store_true',
      default=False,
      help=(
          'keep the image aspect ratio when down-sampling the image by adding '
          'black pixel padding (zeros) on bottom or right. '
          'By default the image is resized and reshaped without cropping. This '
          'option should be the same as what is applied on input images during '
          'model training. Otherwise the accuracy may be affected and the '
          'bounding box of detection result may be stretched.'))
  args = parser.parse_args()

  interpreter = make_interpreter(args.model, device=':0')
  interpreter.allocate_tensors()
  width, height = common.input_size(interpreter)

  source = Image.open(args.input)
  if args.keep_aspect_ratio:
    model_input, _ = common.set_resized_input(
        interpreter, source.size,
        lambda size: source.resize(size, Image.ANTIALIAS))
  else:
    model_input = source.resize((width, height), Image.ANTIALIAS)
    common.set_input(interpreter, model_input)

  interpreter.invoke()

  seg_map = segment.get_output(interpreter)
  if len(seg_map.shape) == 3:
    # Collapse per-class score maps into a single label map.
    seg_map = np.argmax(seg_map, axis=-1)

  # If keep_aspect_ratio, we need to remove the padding area.
  new_width, new_height = model_input.size
  seg_map = seg_map[:new_height, :new_width]
  mask = Image.fromarray(label_to_color_image(seg_map).astype(np.uint8))

  # Concat resized input image and processed segmentation results.
  combined = Image.new('RGB', (2 * new_width, new_height))
  combined.paste(model_input, (0, 0))
  combined.paste(mask, (width, 0))
  combined.save(args.output)
  print('Please check ', args.output)


if __name__ == '__main__':
  main()
diff --git a/examples/small_object_detection.py b/examples/small_object_detection.py
new file mode 100644
index 0000000..820ed8d
--- /dev/null
+++ b/examples/small_object_detection.py
@@ -0,0 +1,242 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""An example to perform object detection with an image with added supports for smaller objects.
+
+The following command runs this example for object detection using a
+MobileNet model trained with the COCO dataset (it can detect 90 types
+of objects):
+```
+python3 small_object_detection.py \
+ --model test_data/ssd_mobilenet_v2_coco_quant_no_nms_edgetpu.tflite \
+ --label test_data/coco_labels.txt \
+ --input test_data/kite_and_cold.jpg \
+  --tile_sizes 1352x900,500x500,250x250 \
+ --tile_overlap 50 \
+ --score_threshold 0.5 \
+ --output ${HOME}/object_detection_results.jpg
+```
+
+Note: this example demonstrates small object detection, using the method of
+splitting the original image into tiles with some added overlaps in consecutive
+tiles. The tile size can also be specified in multiple layers as
+demonstrated on the above command. With the overlapping tiles and layers,
+some object candidates may have overlapping bounding boxes. The example then
+uses Non-Maximum-Suppressions to suppress the overlapping bounding boxes on the
+same objects. It then saves the result of the given image at the location
+specified by `output`, with bounding boxes drawn around each detected object.
+
+In order to boost performance, the model has non_max_suppression stripped from
+the post processing operator. To do this, we can re-export the checkpoint by
+setting the iou_threshold to 1. By doing so, we see an overall speedup of about
+2x on average.
+"""
+
+import argparse
+import collections
+
+import numpy as np
+from PIL import Image
+from PIL import ImageDraw
+
+from pycoral.adapters import common
+from pycoral.adapters import detect
+from pycoral.utils.dataset import read_label_file
+from pycoral.utils.edgetpu import make_interpreter
+
+Object = collections.namedtuple('Object', ['label', 'score', 'bbox'])
+
+
def tiles_location_gen(img_size, tile_size, overlap):
  """Yields tile coordinates covering an image, with adjacent tiles overlapping.

  Args:
    img_size (int, int): size of original image as width x height.
    tile_size (int, int): size of the returned tiles as width x height.
    overlap (int): The number of pixels to overlap the tiles.

  Yields:
    A list of points representing the coordinates of the tile in xmin, ymin,
    xmax, ymax.
  """
  tile_width, tile_height = tile_size
  img_width, img_height = img_size
  # Step by the tile size minus the requested overlap; clamp at image edges.
  for ymin in range(0, img_height, tile_height - overlap):
    for xmin in range(0, img_width, tile_width - overlap):
      yield [
          xmin, ymin,
          min(img_width, xmin + tile_width),
          min(img_height, ymin + tile_height)
      ]
+
+
def non_max_suppression(objects, threshold):
  """Returns a list of indexes of objects passing the NMS.

  Args:
    objects: result candidates.
    threshold: the threshold of overlapping IoU to merge the boxes.

  Returns:
    A list of indexes containing the objects that pass the NMS.
  """
  if not objects:
    # Guard: the original crashed on an empty candidate list because
    # np.array([])[:, 0] is an invalid index expression.
    return []
  if len(objects) == 1:
    return [0]

  boxes = np.array([o.bbox for o in objects])
  xmins = boxes[:, 0]
  ymins = boxes[:, 1]
  xmaxs = boxes[:, 2]
  ymaxs = boxes[:, 3]

  areas = (xmaxs - xmins) * (ymaxs - ymins)
  scores = [o.score for o in objects]
  idxs = np.argsort(scores)  # Ascending; best candidate is last.

  selected_idxs = []
  while idxs.size != 0:
    # Keep the highest-scoring remaining candidate...
    selected_idx = idxs[-1]
    selected_idxs.append(selected_idx)

    # ...then compute its IoU against every other remaining candidate.
    overlapped_xmins = np.maximum(xmins[selected_idx], xmins[idxs[:-1]])
    overlapped_ymins = np.maximum(ymins[selected_idx], ymins[idxs[:-1]])
    overlapped_xmaxs = np.minimum(xmaxs[selected_idx], xmaxs[idxs[:-1]])
    overlapped_ymaxs = np.minimum(ymaxs[selected_idx], ymaxs[idxs[:-1]])

    w = np.maximum(0, overlapped_xmaxs - overlapped_xmins)
    h = np.maximum(0, overlapped_ymaxs - overlapped_ymins)

    intersections = w * h
    unions = areas[idxs[:-1]] + areas[selected_idx] - intersections
    ious = intersections / unions

    # Drop the selected candidate and everything overlapping it too much.
    idxs = np.delete(
        idxs, np.concatenate(([len(idxs) - 1], np.where(ious > threshold)[0])))

  return selected_idxs
+
+
def draw_object(draw, obj):
  """Draws one detection candidate (box, label, score) on the image.

  Args:
    draw: the PIL.ImageDraw object that draw on the image.
    obj: The detection candidate.
  """
  left, bottom = obj.bbox[0], obj.bbox[3]
  draw.rectangle(obj.bbox, outline='red')
  draw.text((left, bottom), obj.label, fill='#0000')
  draw.text((left, bottom + 10), str(obj.score), fill='#0000')
+
+
def reposition_bounding_box(bbox, tile_location):
  """Translates a tile-relative bbox into original-image coordinates, in place.

  Args:
    bbox (int, int, int, int): bounding box relative to tile_location as xmin,
      ymin, xmax, ymax.
    tile_location (int, int, int, int): tile_location in the original image as
      xmin, ymin, xmax, ymax.

  Returns:
    The same (mutated) bbox list, shifted by the tile's top-left corner.
  """
  x_offset, y_offset = tile_location[0], tile_location[1]
  bbox[0] += x_offset
  bbox[1] += y_offset
  bbox[2] += x_offset
  bbox[3] += y_offset
  return bbox
+
+
def main():
  """Runs tiled small-object detection and draws the merged results."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model',
      required=True,
      help='Detection SSD model path (must have post-processing operator).')
  parser.add_argument('--label', help='Labels file path.')
  parser.add_argument(
      '--score_threshold',
      help='Threshold for returning the candidates.',
      type=float,
      default=0.1)
  parser.add_argument(
      '--tile_sizes',
      help=('Sizes of the tiles to split, could be more than one layer as a '
            'list a with comma delimiter in widthxheight. Example: '
            '"300x300,250x250,.."'),
      required=True)
  parser.add_argument(
      '--tile_overlap',
      help=('Number of pixels to overlap the tiles. tile_overlap should be >= '
            'than half of the min desired object size, otherwise small objects '
            'could be missed on the tile boundary.'),
      type=int,
      default=15)
  parser.add_argument(
      '--iou_threshold',
      # Fixed help-text typo: 'duing' -> 'during'.
      help=('threshold to merge bounding box during nms'),
      type=float,
      default=.1)
  parser.add_argument('--input', help='Input image path.', required=True)
  parser.add_argument('--output', help='Output image path.')
  args = parser.parse_args()

  interpreter = make_interpreter(args.model)
  interpreter.allocate_tensors()
  labels = read_label_file(args.label) if args.label else {}

  # Open image.
  img = Image.open(args.input).convert('RGB')
  draw = ImageDraw.Draw(img)

  objects_by_label = {}
  img_size = img.size
  # Materialize each 'WxH' spec as a tuple. The original kept lazy `map`
  # objects, which are one-shot iterables and easy to exhaust accidentally.
  tile_sizes = [
      tuple(map(int, tile_size.split('x')))
      for tile_size in args.tile_sizes.split(',')
  ]
  for tile_size in tile_sizes:
    for tile_location in tiles_location_gen(img_size, tile_size,
                                            args.tile_overlap):
      tile = img.crop(tile_location)
      # Bind `tile` via a default argument so the callback captures this
      # iteration's tile (late-binding closure pitfall).
      _, scale = common.set_resized_input(
          interpreter, tile.size,
          lambda size, img=tile: img.resize(size, Image.NEAREST))
      interpreter.invoke()
      objs = detect.get_objects(interpreter, args.score_threshold, scale)

      for obj in objs:
        bbox = [obj.bbox.xmin, obj.bbox.ymin, obj.bbox.xmax, obj.bbox.ymax]
        bbox = reposition_bounding_box(bbox, tile_location)

        label = labels.get(obj.id, '')
        objects_by_label.setdefault(label,
                                    []).append(Object(label, obj.score, bbox))

  # Merge overlapping candidates per label and draw the survivors.
  for label, objects in objects_by_label.items():
    idxs = non_max_suppression(objects, args.iou_threshold)
    for idx in idxs:
      draw_object(draw, objects[idx])

  img.show()
  if args.output:
    img.save(args.output)


if __name__ == '__main__':
  main()
diff --git a/examples/two_models_inference.py b/examples/two_models_inference.py
new file mode 100644
index 0000000..5264121
--- /dev/null
+++ b/examples/two_models_inference.py
@@ -0,0 +1,193 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Demo to show running two models on one/two Edge TPU devices.
+
+This is a dummy example that compares running two different models using one
+Edge TPU vs two Edge TPUs. It requires that your system includes two Edge TPU
+devices.
+
+You give the script one classification model and one
+detection model, and it runs each model the number of times specified with the
+`num_inferences` argument, using the same image. It then reports the time
+spent using either one or two Edge TPU devices.
+
+Note: Running two models alternately with one Edge TPU is cache unfriendly,
+as each model continuously kicks the other model off the device's cache when
+they each run. In this case, running several inferences with one model in a
+batch before switching to another model can help to some extent. It's also
+possible to co-compile both models so they can be cached simultaneously
+(if they fit; read more at coral.ai/docs/edgetpu/compiler/). But using two
+Edge TPUs with two threads can help more.
+"""
+
+import argparse
+import contextlib
+import threading
+import time
+from PIL import Image
+
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.adapters import detect
+from pycoral.utils.edgetpu import list_edge_tpus
+from pycoral.utils.edgetpu import make_interpreter
+
+
+@contextlib.contextmanager
+def open_image(path):
+ with open(path, 'rb') as f:
+ with Image.open(f) as image:
+ yield image
+
+
+def run_two_models_one_tpu(classification_model, detection_model, image_name,
+ num_inferences, batch_size):
+ """Runs two models ALTERNATELY using one Edge TPU.
+
+ It runs the classification model `batch_size` times and then switches to
+ run the detection model `batch_size` times, until each model has run
+ `num_inferences` times.
+
+ Args:
+ classification_model: string, path to classification model
+ detection_model: string, path to detection model.
+ image_name: string, path to input image.
+ num_inferences: int, number of inferences to run for each model.
+ batch_size: int, indicates how many inferences to run one model before
+ switching to the other one.
+
+ Returns:
+ double, wall time it takes to finish the job.
+ """
+ start_time = time.perf_counter()
+ interpreter_a = make_interpreter(classification_model, device=':0')
+ interpreter_a.allocate_tensors()
+ interpreter_b = make_interpreter(detection_model, device=':0')
+ interpreter_b.allocate_tensors()
+
+ with open_image(image_name) as image:
+ size_a = common.input_size(interpreter_a)
+ common.set_input(interpreter_a, image.resize(size_a, Image.NEAREST))
+ _, scale_b = common.set_resized_input(
+ interpreter_b, image.size,
+ lambda size: image.resize(size, Image.NEAREST))
+
+ num_iterations = (num_inferences + batch_size - 1) // batch_size
+ for _ in range(num_iterations):
+ for _ in range(batch_size):
+ interpreter_a.invoke()
+ classify.get_classes(interpreter_a, top_k=1)
+ for _ in range(batch_size):
+ interpreter_b.invoke()
+ detect.get_objects(interpreter_b, score_threshold=0., image_scale=scale_b)
+ return time.perf_counter() - start_time
+
+
+def run_two_models_two_tpus(classification_model, detection_model, image_name,
+ num_inferences):
+ """Runs two models using two Edge TPUs with two threads.
+
+ Args:
+ classification_model: string, path to classification model
+ detection_model: string, path to detection model.
+ image_name: string, path to input image.
+ num_inferences: int, number of inferences to run for each model.
+
+ Returns:
+ double, wall time it takes to finish the job.
+ """
+
+ def classification_job(classification_model, image_name, num_inferences):
+ """Runs classification job."""
+ interpreter = make_interpreter(classification_model, device=':0')
+ interpreter.allocate_tensors()
+ size = common.input_size(interpreter)
+ with open_image(image_name) as image:
+ common.set_input(interpreter, image.resize(size, Image.NEAREST))
+
+ for _ in range(num_inferences):
+ interpreter.invoke()
+ classify.get_classes(interpreter, top_k=1)
+
+ def detection_job(detection_model, image_name, num_inferences):
+ """Runs detection job."""
+ interpreter = make_interpreter(detection_model, device=':1')
+ interpreter.allocate_tensors()
+ with open_image(image_name) as image:
+ _, scale = common.set_resized_input(
+ interpreter, image.size,
+ lambda size: image.resize(size, Image.NEAREST))
+
+ for _ in range(num_inferences):
+ interpreter.invoke()
+ detect.get_objects(interpreter, score_threshold=0., image_scale=scale)
+
+ start_time = time.perf_counter()
+ classification_thread = threading.Thread(
+ target=classification_job,
+ args=(classification_model, image_name, num_inferences))
+ detection_thread = threading.Thread(
+ target=detection_job, args=(detection_model, image_name, num_inferences))
+
+ classification_thread.start()
+ detection_thread.start()
+ classification_thread.join()
+ detection_thread.join()
+ return time.perf_counter() - start_time
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--classification_model',
+ help='Path of classification model.',
+ required=True)
+ parser.add_argument(
+ '--detection_model', help='Path of detection model.', required=True)
+ parser.add_argument('--image', help='Path of the image.', required=True)
+ parser.add_argument(
+ '--num_inferences',
+ help='Number of inferences to run.',
+ type=int,
+ default=2000)
+ parser.add_argument(
+ '--batch_size',
+ help='Runs one model batch_size times before switching to the other.',
+ type=int,
+ default=10)
+
+ args = parser.parse_args()
+
+ if len(list_edge_tpus()) <= 1:
+ raise RuntimeError('This demo requires at least two Edge TPUs available.')
+
+ print('Running %s and %s with one Edge TPU, # inferences %d, batch_size %d.' %
+ (args.classification_model, args.detection_model, args.num_inferences,
+ args.batch_size))
+ cost_one_tpu = run_two_models_one_tpu(args.classification_model,
+ args.detection_model, args.image,
+ args.num_inferences, args.batch_size)
+ print('Running %s and %s with two Edge TPUs, # inferences %d.' %
+ (args.classification_model, args.detection_model, args.num_inferences))
+ cost_two_tpus = run_two_models_two_tpus(args.classification_model,
+ args.detection_model, args.image,
+ args.num_inferences)
+
+ print('Inference with one Edge TPU costs %.2f seconds.' % cost_one_tpu)
+ print('Inference with two Edge TPUs costs %.2f seconds.' % cost_two_tpus)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/libcoral b/libcoral
new file mode 160000
index 0000000..9824265
--- /dev/null
+++ b/libcoral
@@ -0,0 +1 @@
+Subproject commit 982426546dfa10128376d0c24fd8a8b161daac97
diff --git a/libedgetpu b/libedgetpu
new file mode 160000
index 0000000..14eee1a
--- /dev/null
+++ b/libedgetpu
@@ -0,0 +1 @@
+Subproject commit 14eee1a076aa1af7ec1ae3c752be79ae2604a708
diff --git a/libedgetpu_bin/LICENSE.txt b/libedgetpu_bin/LICENSE.txt
new file mode 100644
index 0000000..9dadd0b
--- /dev/null
+++ b/libedgetpu_bin/LICENSE.txt
@@ -0,0 +1,7 @@
+Copyright 2019 Google LLC. This software is provided as-is, without warranty
+or representation for any use or purpose. Your use of it is subject to your
+agreements with Google covering this software, or if no such agreement
+applies, your use is subject to a limited, non-transferable, non-exclusive
+license solely to run the software for your testing use, unless and until
+revoked by Google.
+
diff --git a/libedgetpu_bin/Makefile b/libedgetpu_bin/Makefile
new file mode 100644
index 0000000..18a8eb4
--- /dev/null
+++ b/libedgetpu_bin/Makefile
@@ -0,0 +1,21 @@
+SHELL := /bin/bash
+MAKEFILE_DIR := $(realpath $(dir $(lastword $(MAKEFILE_LIST))))
+
+.PHONY: all \
+ deb \
+ help
+
+all: deb
+
+deb:
+ dpkg-buildpackage -rfakeroot -us -uc -tc -b
+ dpkg-buildpackage -rfakeroot -us -uc -tc -b -a armhf -d
+ dpkg-buildpackage -rfakeroot -us -uc -tc -b -a arm64 -d
+ mkdir -p $(MAKEFILE_DIR)/../dist
+ mv $(MAKEFILE_DIR)/../*.{deb,changes,buildinfo} \
+ $(MAKEFILE_DIR)/../dist
+
+help:
+ @echo "make all - Build Debian packages for all platforms"
+ @echo "make help - Print help message"
+
diff --git a/libedgetpu_bin/debian/changelog b/libedgetpu_bin/debian/changelog
new file mode 100644
index 0000000..e99e249
--- /dev/null
+++ b/libedgetpu_bin/debian/changelog
@@ -0,0 +1,57 @@
+libedgetpu (15.0) stable; urgency=medium
+ * New release
+ -- Coral Mon, 02 Nov 2020 10:58:23 -0800
+libedgetpu (14.1) stable; urgency=medium
+ * New release
+ -- Coral Tue, 07 Jul 2020 13:47:32 -0700
+libedgetpu (14.0) stable; urgency=medium
+ * New release
+ -- Coral Wed, 25 Mar 2020 14:25:24 -0700
+libedgetpu (13.0) stable; urgency=medium
+ * New release
+ -- Coral Tue, 28 Jan 2020 15:58:19 -0700
+libedgetpu (12.1-1) mendel-chef; urgency=medium
+ * New release
+ -- Coral Wed, 30 Oct 2019 15:58:16 -0700
+libedgetpu (12-1) mendel-chef; urgency=medium
+ * New release
+ -- Coral Mon, 16 Sep 2019 13:27:18 -0700
+libedgetpu (11-1) mendel-chef; urgency=medium
+ * New release
+ -- Coral Mon, 15 Jul 2019 15:52:14 -0700
+libedgetpu (10-2) mendel-chef; urgency=medium
+ * New release
+ -- Coral Thu, 18 Apr 2019 13:37:19 -0700
+libedgetpu (9-2) mendel-chef; urgency=medium
+ * New release
+ -- Coral Wed, 03 Apr 2019 14:11:47 -0800
+libedgetpu (8-2) mendel-chef; urgency=medium
+ * New release
+ -- Coral Tue, 02 Apr 2019 14:11:47 -0800
+libedgetpu (7-2) mendel-chef; urgency=medium
+ * New release
+ -- Coral Thu, 28 Mar 2019 14:11:47 -0800
+libedgetpu (6-2) mendel-chef; urgency=medium
+ * New release
+ -- Coral Tue, 19 Mar 2019 17:08:25 -0700
+libedgetpu (5-2) mendel-beaker; urgency=medium
+ * New release
+ -- Coral Fri, 08 Mar 2019 14:11:47 -0800
+libedgetpu (4-2) mendel-beaker; urgency=medium
+ * New release
+ -- Coral Wed, 27 Feb 2019 11:00:25 -0800
+libedgetpu (3-2) mendel-beaker; urgency=medium
+ * New release
+ -- Coral Mon, 04 Feb 2019 11:20:13 -0800
+libedgetpu (3-1) mendel-beaker; urgency=medium
+ * New release
+ -- Coral Mon, 28 Jan 2019 14:10:00 -0800
+libedgetpu (2-1) mendel-beaker; urgency=medium
+ * New release
+ -- Coral Tue, 22 Jan 2019 10:42:07 -0800
+libedgetpu (1-1) mendel-beaker; urgency=medium
+ * New release
+ -- Coral Wed, 16 Jan 2019 12:00:00 -0800
+libedgetpu (0.1) UNRELEASED; urgency=medium
+ * Initial release.
+ -- Coral Mon, 04 Jun 2018 16:14:00 -0800
diff --git a/libedgetpu_bin/debian/compat b/libedgetpu_bin/debian/compat
new file mode 100644
index 0000000..f599e28
--- /dev/null
+++ b/libedgetpu_bin/debian/compat
@@ -0,0 +1 @@
+10
diff --git a/libedgetpu_bin/debian/control b/libedgetpu_bin/debian/control
new file mode 100644
index 0000000..71fd6b2
--- /dev/null
+++ b/libedgetpu_bin/debian/control
@@ -0,0 +1,45 @@
+Source: libedgetpu
+Maintainer: Coral
+Priority: optional
+Build-Depends: debhelper (>= 9)
+Standards-Version: 3.9.6
+Homepage: https://coral.ai/
+
+Package: libedgetpu1-std
+Provides: libedgetpu1 (= ${binary:Version})
+Conflicts: libedgetpu1, libedgetpu1-legacy
+Section: misc
+Priority: optional
+Architecture: any
+Multi-Arch: same
+Depends: libc6,
+ libgcc1,
+ libstdc++6,
+ libusb-1.0-0,
+ ${misc:Depends}
+Description: Support library for Edge TPU
+ Support library (standard speed) for the Edge TPU
+
+Package: libedgetpu1-max
+Provides: libedgetpu1 (= ${binary:Version})
+Conflicts: libedgetpu1, libedgetpu1-legacy
+Section: misc
+Priority: optional
+Architecture: any
+Multi-Arch: same
+Depends: libc6,
+ libgcc1,
+ libstdc++6,
+ libusb-1.0-0,
+ ${misc:Depends}
+Description: Support library for Edge TPU
+ Support library (max speed) for the Edge TPU
+
+Package: libedgetpu-dev
+Section: libdevel
+Priority: optional
+Architecture: any
+Depends: libedgetpu1-std (= ${binary:Version}) | libedgetpu1 (= ${binary:Version}),
+ ${misc:Depends}
+Description: Development files for libedgetpu
+ This package contains C++ Header files for libedgetpu.so
diff --git a/libedgetpu_bin/debian/copyright b/libedgetpu_bin/debian/copyright
new file mode 100644
index 0000000..d4bf188
--- /dev/null
+++ b/libedgetpu_bin/debian/copyright
@@ -0,0 +1,7 @@
+Format: http://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
+Upstream-Name: edgetpu
+Source: https://github.com/google-coral/edgetpu
+
+Files: *
+Copyright: Copyright 2018 Google, LLC
+License: Apache-2.0
diff --git a/libedgetpu_bin/debian/libedgetpu-dev.install b/libedgetpu_bin/debian/libedgetpu-dev.install
new file mode 100644
index 0000000..e7c66ba
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu-dev.install
@@ -0,0 +1,2 @@
+edgetpu.h /usr/include
+edgetpu_c.h /usr/include
diff --git a/libedgetpu_bin/debian/libedgetpu1-max.lintian-overrides b/libedgetpu_bin/debian/libedgetpu1-max.lintian-overrides
new file mode 100644
index 0000000..3591ed2
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-max.lintian-overrides
@@ -0,0 +1,4 @@
+# We provide two conflicting package variants with the same soname inside.
+libedgetpu1-max: package-name-doesnt-match-sonames libedgetpu1
+libedgetpu1-max: missing-debconf-dependency-for-preinst
+libedgetpu1-max: too-long-short-description-in-templates libedgetpu/accepted-eula
diff --git a/libedgetpu_bin/debian/libedgetpu1-max.preinst b/libedgetpu_bin/debian/libedgetpu1-max.preinst
new file mode 100644
index 0000000..5a7a763
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-max.preinst
@@ -0,0 +1,24 @@
+#!/bin/sh -e
+
+. /usr/share/debconf/confmodule
+
+db_version 2.0
+
+db_get libedgetpu/accepted-eula
+if [ "$RET" = "true" ]; then
+ exit 0 # already accepted
+fi
+
+db_fset libedgetpu/accepted-eula seen false
+db_input critical libedgetpu/accepted-eula
+db_go || true
+
+db_get libedgetpu/accepted-eula
+if [ "$RET" = "true" ]; then
+ exit 0 # accepted
+fi
+
+db_input critical libedgetpu/error-eula
+db_go || true
+
+exit 1 # not accepted
diff --git a/libedgetpu_bin/debian/libedgetpu1-max.templates b/libedgetpu_bin/debian/libedgetpu1-max.templates
new file mode 100644
index 0000000..6c276be
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-max.templates
@@ -0,0 +1,21 @@
+Template: libedgetpu/accepted-eula
+Type: boolean
+Default: false
+Description: Continue to install the Edge TPU runtime that runs at the maximum operating frequency?
+ You're about to install the Edge TPU runtime that runs at the maximum operating frequency.
+ .
+ Warning: If you're using the Coral USB Accelerator, it may heat up during operation, depending
+ on the computation workloads and operating frequency. Touching the metal part of the USB
+ Accelerator after it has been operating for an extended period of time may lead to discomfort
+ and/or skin burns. As such, if you install the Edge TPU runtime using the maximum operating
+ frequency, the USB Accelerator should be operated at an ambient temperature of 25°C or less.
+ (If you instead install the Edge TPU runtime using the reduced operating frequency, then the
+ device is intended to safely operate at an ambient temperature of 35°C or less.)
+ .
+ Google does not accept any responsibility for any loss or damage if the device is operated
+ outside of the recommended ambient temperature range.
+
+Template: libedgetpu/error-eula
+Type: error
+Description: Install aborted.
+ For help setting up your device, see g.co/coral/setup.
diff --git a/libedgetpu_bin/debian/libedgetpu1-max.triggers b/libedgetpu_bin/debian/libedgetpu1-max.triggers
new file mode 100644
index 0000000..dd86603
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-max.triggers
@@ -0,0 +1 @@
+activate-noawait ldconfig
diff --git a/libedgetpu_bin/debian/libedgetpu1-max.udev b/libedgetpu_bin/debian/libedgetpu1-max.udev
new file mode 120000
index 0000000..e52da57
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-max.udev
@@ -0,0 +1 @@
+../edgetpu-accelerator.rules
\ No newline at end of file
diff --git a/libedgetpu_bin/debian/libedgetpu1-std.lintian-overrides b/libedgetpu_bin/debian/libedgetpu1-std.lintian-overrides
new file mode 100644
index 0000000..cfa2624
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-std.lintian-overrides
@@ -0,0 +1,2 @@
+# We provide two conflicting package variants with the same soname inside.
+libedgetpu1-std: package-name-doesnt-match-sonames libedgetpu1
diff --git a/libedgetpu_bin/debian/libedgetpu1-std.triggers b/libedgetpu_bin/debian/libedgetpu1-std.triggers
new file mode 100644
index 0000000..dd86603
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-std.triggers
@@ -0,0 +1 @@
+activate-noawait ldconfig
diff --git a/libedgetpu_bin/debian/libedgetpu1-std.udev b/libedgetpu_bin/debian/libedgetpu1-std.udev
new file mode 120000
index 0000000..e52da57
--- /dev/null
+++ b/libedgetpu_bin/debian/libedgetpu1-std.udev
@@ -0,0 +1 @@
+../edgetpu-accelerator.rules
\ No newline at end of file
diff --git a/libedgetpu_bin/debian/rules b/libedgetpu_bin/debian/rules
new file mode 100755
index 0000000..4fbaee5
--- /dev/null
+++ b/libedgetpu_bin/debian/rules
@@ -0,0 +1,46 @@
+#!/usr/bin/make -f
+# -*- makefile -*-
+
+# Uncomment this to turn on verbose mode.
+# export DH_VERBOSE=1
+FILENAME := libedgetpu.so.1.0
+SONAME := libedgetpu.so.1
+LIB_DEV := debian/libedgetpu-dev/usr/lib/$(DEB_HOST_GNU_TYPE)
+LIB_STD := debian/libedgetpu1-std/usr/lib/$(DEB_HOST_GNU_TYPE)
+LIB_MAX := debian/libedgetpu1-max/usr/lib/$(DEB_HOST_GNU_TYPE)
+
+ifeq ($(DEB_TARGET_ARCH),armhf)
+ CPU := armv7a
+else ifeq ($(DEB_TARGET_ARCH),arm64)
+ CPU := aarch64
+else ifeq ($(DEB_TARGET_ARCH),amd64)
+ CPU := k8
+endif
+
+%:
+ dh $@
+
+override_dh_auto_install:
+ dh_auto_install
+
+ mkdir -p $(LIB_DEV)
+ ln -fs $(FILENAME) $(LIB_DEV)/libedgetpu.so
+
+ mkdir -p $(LIB_STD)
+ cp -f throttled/$(CPU)/$(FILENAME) $(LIB_STD)/$(FILENAME)
+ ln -fs $(FILENAME) $(LIB_STD)/$(SONAME)
+
+ mkdir -p $(LIB_MAX)
+ cp -f direct/$(CPU)/$(FILENAME) $(LIB_MAX)/$(FILENAME)
+ ln -fs $(FILENAME) $(LIB_MAX)/$(SONAME)
+
+# Skip auto build and auto clean.
+override_dh_auto_clean:
+override_dh_auto_build:
+
+# Skip .so post processing.
+override_dh_strip:
+override_dh_shlibdeps:
+
+# Skip tests.
+override_dh_auto_test:
diff --git a/libedgetpu_bin/direct/aarch64/libedgetpu.so.1 b/libedgetpu_bin/direct/aarch64/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/direct/aarch64/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/direct/aarch64/libedgetpu.so.1.0 b/libedgetpu_bin/direct/aarch64/libedgetpu.so.1.0
new file mode 100755
index 0000000..84b140c
Binary files /dev/null and b/libedgetpu_bin/direct/aarch64/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/direct/armv6/libedgetpu.so.1 b/libedgetpu_bin/direct/armv6/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/direct/armv6/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/direct/armv6/libedgetpu.so.1.0 b/libedgetpu_bin/direct/armv6/libedgetpu.so.1.0
new file mode 100755
index 0000000..6b4f3fc
Binary files /dev/null and b/libedgetpu_bin/direct/armv6/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/direct/armv7a/libedgetpu.so.1 b/libedgetpu_bin/direct/armv7a/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/direct/armv7a/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/direct/armv7a/libedgetpu.so.1.0 b/libedgetpu_bin/direct/armv7a/libedgetpu.so.1.0
new file mode 100755
index 0000000..4b8c73e
Binary files /dev/null and b/libedgetpu_bin/direct/armv7a/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/direct/darwin/libedgetpu.1.0.dylib b/libedgetpu_bin/direct/darwin/libedgetpu.1.0.dylib
new file mode 100755
index 0000000..3561512
Binary files /dev/null and b/libedgetpu_bin/direct/darwin/libedgetpu.1.0.dylib differ
diff --git a/libedgetpu_bin/direct/darwin/libedgetpu.1.dylib b/libedgetpu_bin/direct/darwin/libedgetpu.1.dylib
new file mode 120000
index 0000000..e2c7584
--- /dev/null
+++ b/libedgetpu_bin/direct/darwin/libedgetpu.1.dylib
@@ -0,0 +1 @@
+libedgetpu.1.0.dylib
\ No newline at end of file
diff --git a/libedgetpu_bin/direct/k8/libedgetpu.so.1 b/libedgetpu_bin/direct/k8/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/direct/k8/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/direct/k8/libedgetpu.so.1.0 b/libedgetpu_bin/direct/k8/libedgetpu.so.1.0
new file mode 100755
index 0000000..8d5cd6a
Binary files /dev/null and b/libedgetpu_bin/direct/k8/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/direct/x64_windows/edgetpu.dll b/libedgetpu_bin/direct/x64_windows/edgetpu.dll
new file mode 100644
index 0000000..c07d541
Binary files /dev/null and b/libedgetpu_bin/direct/x64_windows/edgetpu.dll differ
diff --git a/libedgetpu_bin/direct/x64_windows/edgetpu.dll.if.lib b/libedgetpu_bin/direct/x64_windows/edgetpu.dll.if.lib
new file mode 100644
index 0000000..7a5bf61
Binary files /dev/null and b/libedgetpu_bin/direct/x64_windows/edgetpu.dll.if.lib differ
diff --git a/libedgetpu_bin/edgetpu-accelerator.rules b/libedgetpu_bin/edgetpu-accelerator.rules
new file mode 100644
index 0000000..60e034c
--- /dev/null
+++ b/libedgetpu_bin/edgetpu-accelerator.rules
@@ -0,0 +1,2 @@
+SUBSYSTEM=="usb",ATTRS{idVendor}=="1a6e",GROUP="plugdev"
+SUBSYSTEM=="usb",ATTRS{idVendor}=="18d1",GROUP="plugdev"
diff --git a/libedgetpu_bin/edgetpu.h b/libedgetpu_bin/edgetpu.h
new file mode 100644
index 0000000..ed4e4e2
--- /dev/null
+++ b/libedgetpu_bin/edgetpu.h
@@ -0,0 +1,317 @@
+/*
+Copyright 2018 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// @cond BEGIN doxygen exclude
+// This header file defines EdgeTpuManager and EdgeTpuContext.
+// See below for more details.
+
+#ifndef TFLITE_PUBLIC_EDGETPU_H_
+#define TFLITE_PUBLIC_EDGETPU_H_
+
+// If the ABI changes in a backward-incompatible way, please increment the
+// version number in the BUILD file.
+
+#include
+#include
+#include
+#include
+#include
+
+#include "tensorflow/lite/context.h"
+
+#if defined(_WIN32)
+#ifdef EDGETPU_COMPILE_LIBRARY
+#define EDGETPU_EXPORT __declspec(dllexport)
+#else
+#define EDGETPU_EXPORT __declspec(dllimport)
+#endif // EDGETPU_COMPILE_LIBRARY
+#else
+#define EDGETPU_EXPORT __attribute__((visibility("default")))
+#endif // _WIN32
+// END doxygen exclude @endcond
+
+namespace edgetpu {
+
+// Edge TPU custom op.
+static const char kCustomOp[] = "edgetpu-custom-op";
+
+// The device interface used with the host
+enum class DeviceType {
+ // PCIe Gen2 x1
+ kApexPci = 0,
+ // USB 2.0 or 3.1 Gen1
+ kApexUsb = 1,
+};
+
+class EdgeTpuContext;
+
+// Singleton Edge TPU manager for allocating new TPU contexts.
+// Functions in this interface are thread-safe.
+class EDGETPU_EXPORT EdgeTpuManager {
+ public:
+ // See EdgeTpuContext::GetDeviceOptions().
+ using DeviceOptions = std::unordered_map;
+ // Details about a particular Edge TPU
+ struct DeviceEnumerationRecord {
+ // The Edge TPU device type, either PCIe or USB
+ DeviceType type;
+ // System path for the Edge TPU device
+ std::string path;
+
+ // Returns true if two enumeration records point to the same device.
+ friend bool operator==(const DeviceEnumerationRecord& lhs,
+ const DeviceEnumerationRecord& rhs) {
+ return (lhs.type == rhs.type) && (lhs.path == rhs.path);
+ }
+
+ // Returns true if two enumeration records point to different devices.
+ friend bool operator!=(const DeviceEnumerationRecord& lhs,
+ const DeviceEnumerationRecord& rhs) {
+ return !(lhs == rhs);
+ }
+ };
+
+ // Returns a pointer to the singleton object, or nullptr if not supported on
+ // this platform.
+ static EdgeTpuManager* GetSingleton();
+
+ // @cond BEGIN doxygen exclude for deprecated APIs.
+
+ // NewEdgeTpuContext family functions has been deprecated and will be removed
+ // in the future. Please use OpenDevice for new code.
+ //
+ // These functions return an unique_ptr to EdgeTpuContext, with
+ // the intention that the device will be closed, and associate resources
+ // released, when the unique_ptr leaves scope.
+ //
+ // These functions seek exclusive ownership of the opened devices. As they
+ // cannot open devices already opened by OpenDevice, and vice versa.
+ // Devices opened through these functions would have attribute
+ // "ExclusiveOwnership", which can be queried through
+ // #EdgeTpuContext::GetDeviceOptions().
+
+ // Creates a new Edge TPU context to be assigned to Tflite::Interpreter. The
+ // Edge TPU context is associated with the default TPU device. May be null
+ // if underlying device cannot be found or open. Caller owns the returned new
+ // context and should destroy the context either implicitly or explicitly after
+ // all interpreters sharing this context are destroyed.
+ virtual std::unique_ptr NewEdgeTpuContext() = 0;
+
+ // Same as above, but the created context is associated with the specified
+ // type.
+ virtual std::unique_ptr NewEdgeTpuContext(
+ DeviceType device_type) = 0;
+
+ // Same as above, but the created context is associated with the specified
+ // type and device path.
+ virtual std::unique_ptr NewEdgeTpuContext(
+ DeviceType device_type, const std::string& device_path) = 0;
+
+ // Same as above, but the created context is associated with the given device
+ // type, path and options.
+ //
+ // Available options are:
+ // - "Performance": ["Low", "Medium", "High", "Max"] (Default is "Max")
+ // - "Usb.AlwaysDfu": ["True", "False"] (Default is "False")
+ // - "Usb.MaxBulkInQueueLength": ["0",.., "255"] (Default is "32")
+ virtual std::unique_ptr NewEdgeTpuContext(
+ DeviceType device_type, const std::string& device_path,
+ const DeviceOptions& options) = 0;
+ // END doxygen exclude for deprecated APIs @endcond
+
+
+ // Enumerates all connected Edge TPU devices.
+ virtual std::vector EnumerateEdgeTpu() const = 0;
+
+ // Opens the default Edge TPU device.
+ //
+ // All `OpenDevice` functions return a shared_ptr to EdgeTpuContext, with
+ // the intention that the device can be shared among multiple software
+ // components. The device is closed after the last reference leaves scope.
+ //
+ // Multiple invocations of this function could return handle to the same
+ // device, but there is no guarantee.
+ //
+ // You cannot open devices opened by `NewEdgeTpuContext`, and vice versa.
+ //
+ // @return A shared pointer to Edge TPU device. The shared_ptr could point to
+ // nullptr in case of error.
+ virtual std::shared_ptr OpenDevice() = 0;
+
+ // Same as above, but the returned context is associated with the specified
+ // type.
+ //
+ // @param device_type The DeviceType you want to open.
+ virtual std::shared_ptr OpenDevice(
+ DeviceType device_type) = 0;
+
+ // Same as above, but the returned context is associated with the specified
+ // type and device path. If path is empty, any device of the specified type
+ // could be returned.
+ //
+ // @param device_type The DeviceType you want to open.
+ // @param device_path A path to the device you want.
+ //
+ // @return A shared pointer to Edge TPU device. The shared_ptr could point to
+ // nullptr in case of error.
+ virtual std::shared_ptr OpenDevice(
+ DeviceType device_type, const std::string& device_path) = 0;
+
+ // Same as above, but the specified options are used to create a new context
+ // if no existing device is compatible with the specified type and path.
+ //
+ // If a device of compatible type and path is not found, the options could be
+ // ignored. It is the caller's responsibility to verify if the returned
+ // context is desirable, through EdgeTpuContext::GetDeviceOptions().
+ //
+ // @param device_type The DeviceType you want to open.
+ // @param device_path A path to the device you want.
+ // @param options Specific criteria for the device you want.
+ // Available options are:
+ // - "Performance": ["Low", "Medium", "High", "Max"] (Default is "Max")
+ // - "Usb.AlwaysDfu": ["True", "False"] (Default is "False")
+ // - "Usb.MaxBulkInQueueLength": ["0",.., "255"] (Default is "32")
+ //
+ // @return A shared pointer to Edge TPU device. The shared_ptr could point to
+ // nullptr in case of error.
+ virtual std::shared_ptr OpenDevice(
+ DeviceType device_type, const std::string& device_path,
+ const DeviceOptions& options) = 0;
+
+ // Returns a snapshot of currently opened shareable devices.
+ // Exclusively owned Edge TPU devices cannot be returned here, as they're
+ // owned by unique pointers.
+ virtual std::vector> GetOpenedDevices()
+ const = 0;
+
+ // Sets the verbosity of operating logs related to each Edge TPU.
+ //
+ // @param verbosity The verbosity level, which may be 0 to 10.
+ // 10 is the most verbose; 0 is the default.
+ virtual TfLiteStatus SetVerbosity(int verbosity) = 0;
+
+ // Returns the version of the Edge TPU runtime stack.
+ virtual std::string Version() const = 0;
+
+ protected:
+ // No deletion for this singleton instance.
+ virtual ~EdgeTpuManager() = default;
+};
+
+// EdgeTpuContext is an object associated with one or more tflite::Interpreter.
+// Instances of this class should be allocated with EdgeTpuManager::OpenDevice.
+//
+// More than one Interpreter instances can point to the same context. This means
+// the tasks from both would be executed under the same TPU context.
+// The lifetime of this context must be longer than all associated
+// tflite::Interpreter instances.
+//
+// Functions in this interface are thread-safe.
+//
+// Typical usage with Coral:
+//
+// ```
+// // Sets up the tpu_context.
+// auto tpu_context =
+// edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice();
+//
+// std::unique_ptr interpreter;
+// tflite::ops::builtin::BuiltinOpResolver resolver;
+// auto model =
+// tflite::FlatBufferModel::BuildFromFile(model_file_name.c_str());
+//
+// // Registers Edge TPU custom op handler with Tflite resolver.
+// resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
+//
+// tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+//
+// // Binds a context with a specific interpreter.
+// interpreter->SetExternalContext(kTfLiteEdgeTpuContext,
+// tpu_context.get());
+//
+// // Note that all edge TPU context set ups should be done before this
+// // function is called.
+// interpreter->AllocateTensors();
+// .... (Prepare input tensors)
+// interpreter->Invoke();
+// .... (retrieving the result from output tensors)
+//
+// // Releases interpreter instance to free up resources associated with
+// // this custom op.
+// interpreter.reset();
+//
+// // Closes the edge TPU.
+// tpu_context.reset();
+// ```
+//
+// Typical usage with Android NNAPI:
+//
+// ```
+// std::unique_ptr interpreter;
+// tflite::ops::builtin::BuiltinOpResolver resolver;
+// auto model =
+// tflite::FlatBufferModel::BuildFromFile(model_file_name.c_str());
+//
+// // Registers Edge TPU custom op handler with Tflite resolver.
+// resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
+//
+// tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+//
+// interpreter->AllocateTensors();
+// .... (Prepare input tensors)
+// interpreter->Invoke();
+// .... (retrieving the result from output tensors)
+//
+// // Releases interpreter instance to free up resources associated with
+// // this custom op.
+// interpreter.reset();
+// ```
+class EdgeTpuContext : public TfLiteExternalContext {
+ public:
+ virtual ~EdgeTpuContext() = 0;
+
+ // Returns a pointer to the device enumeration record for this device,
+ // if available.
+ virtual const EdgeTpuManager::DeviceEnumerationRecord& GetDeviceEnumRecord()
+ const = 0;
+
+ // Returns a snapshot of the options used to open this
+ // device, and current state, if available.
+ //
+ // Supported attributes are:
+ // - "ExclusiveOwnership": present when it is under exclusive ownership
+ // (unique_ptr returned by NewEdgeTpuContext).
+ // - "IsReady": present when it is ready for further requests.
+ virtual EdgeTpuManager::DeviceOptions GetDeviceOptions() const = 0;
+
+ // Returns true if the device is most likely ready to accept requests.
+ // When there are fatal errors, including unplugging of an USB device, the
+ // state of this device would be changed.
+ virtual bool IsReady() const = 0;
+};
+
+// Returns pointer to an instance of TfLiteRegistration to handle
+// Edge TPU custom ops, to be used with
+// tflite::ops::builtin::BuiltinOpResolver::AddCustom
+EDGETPU_EXPORT TfLiteRegistration* RegisterCustomOp();
+
+// Inserts name of device type into ostream. Returns the modified ostream.
+EDGETPU_EXPORT std::ostream& operator<<(std::ostream& out,
+ DeviceType device_type);
+
+} // namespace edgetpu
+
+
+#endif // TFLITE_PUBLIC_EDGETPU_H_
diff --git a/libedgetpu_bin/edgetpu_c.h b/libedgetpu_bin/edgetpu_c.h
new file mode 100644
index 0000000..64b73de
--- /dev/null
+++ b/libedgetpu_bin/edgetpu_c.h
@@ -0,0 +1,116 @@
+/*
+Copyright 2019 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+//
+// This header defines C API to provide edge TPU support for TensorFlow Lite
+// framework. It is only available for non-NNAPI use cases.
+//
+// Typical API usage from C++ code involves several steps:
+//
+// 1. Create tflite::FlatBufferModel which may contain edge TPU custom op.
+//
+// auto model =
+// tflite::FlatBufferModel::BuildFromFile(model_file_name.c_str());
+//
+// 2. Create tflite::Interpreter.
+//
+// tflite::ops::builtin::BuiltinOpResolver resolver;
+//   std::unique_ptr<tflite::Interpreter> interpreter;
+// tflite::InterpreterBuilder(model, resolver)(&interpreter);
+//
+// 3. Enumerate edge TPU devices.
+//
+// size_t num_devices;
+//   std::unique_ptr<edgetpu_device, decltype(&edgetpu_free_devices)> devices(
+// edgetpu_list_devices(&num_devices), &edgetpu_free_devices);
+//
+// assert(num_devices > 0);
+// const auto& device = devices.get()[0];
+//
+// 4. Modify interpreter with the delegate.
+//
+// auto* delegate =
+// edgetpu_create_delegate(device.type, device.path, nullptr, 0);
+// interpreter->ModifyGraphWithDelegate({delegate, edgetpu_free_delegate});
+//
+// 5. Prepare input tensors and run inference.
+//
+// interpreter->AllocateTensors();
+// .... (Prepare input tensors)
+// interpreter->Invoke();
+// .... (Retrieve the result from output tensors)
+
+#ifndef TFLITE_PUBLIC_EDGETPU_C_H_
+#define TFLITE_PUBLIC_EDGETPU_C_H_
+
+#include "tensorflow/lite/c/common.h"
+
+#if defined(_WIN32)
+#ifdef EDGETPU_COMPILE_LIBRARY
+#define EDGETPU_EXPORT __declspec(dllexport)
+#else
+#define EDGETPU_EXPORT __declspec(dllimport)
+#endif // EDGETPU_COMPILE_LIBRARY
+#else
+#define EDGETPU_EXPORT __attribute__((visibility("default")))
+#endif // _WIN32
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+enum edgetpu_device_type {
+ EDGETPU_APEX_PCI = 0,
+ EDGETPU_APEX_USB = 1,
+};
+
+struct edgetpu_device {
+ enum edgetpu_device_type type;
+ const char* path;
+};
+
+struct edgetpu_option {
+ const char* name;
+ const char* value;
+};
+
+// Returns array of connected edge TPU devices.
+EDGETPU_EXPORT struct edgetpu_device* edgetpu_list_devices(size_t* num_devices);
+
+// Frees array returned by `edgetpu_list_devices`.
+EDGETPU_EXPORT void edgetpu_free_devices(struct edgetpu_device* dev);
+
+// Creates a delegate which handles all edge TPU custom ops inside
+// `tflite::Interpreter`. Options must be available only during the call of this
+// function.
+EDGETPU_EXPORT TfLiteDelegate* edgetpu_create_delegate(
+ enum edgetpu_device_type type, const char* name,
+ const struct edgetpu_option* options, size_t num_options);
+
+// Frees delegate returned by `edgetpu_create_delegate`.
+EDGETPU_EXPORT void edgetpu_free_delegate(TfLiteDelegate* delegate);
+
+// Sets verbosity of operating logs related to edge TPU.
+// Verbosity level can be set to [0-10], in which 10 is the most verbose.
+EDGETPU_EXPORT void edgetpu_verbosity(int verbosity);
+
+// Returns the version of edge TPU runtime stack.
+EDGETPU_EXPORT const char* edgetpu_version();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // TFLITE_PUBLIC_EDGETPU_C_H_
diff --git a/libedgetpu_bin/throttled/aarch64/libedgetpu.so.1 b/libedgetpu_bin/throttled/aarch64/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/throttled/aarch64/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/throttled/aarch64/libedgetpu.so.1.0 b/libedgetpu_bin/throttled/aarch64/libedgetpu.so.1.0
new file mode 100755
index 0000000..f093e95
Binary files /dev/null and b/libedgetpu_bin/throttled/aarch64/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/throttled/armv6/libedgetpu.so.1 b/libedgetpu_bin/throttled/armv6/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/throttled/armv6/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/throttled/armv6/libedgetpu.so.1.0 b/libedgetpu_bin/throttled/armv6/libedgetpu.so.1.0
new file mode 100755
index 0000000..fedd538
Binary files /dev/null and b/libedgetpu_bin/throttled/armv6/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/throttled/armv7a/libedgetpu.so.1 b/libedgetpu_bin/throttled/armv7a/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/throttled/armv7a/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/throttled/armv7a/libedgetpu.so.1.0 b/libedgetpu_bin/throttled/armv7a/libedgetpu.so.1.0
new file mode 100755
index 0000000..ef2aa71
Binary files /dev/null and b/libedgetpu_bin/throttled/armv7a/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/throttled/darwin/libedgetpu.1.0.dylib b/libedgetpu_bin/throttled/darwin/libedgetpu.1.0.dylib
new file mode 100755
index 0000000..a00b2e5
Binary files /dev/null and b/libedgetpu_bin/throttled/darwin/libedgetpu.1.0.dylib differ
diff --git a/libedgetpu_bin/throttled/darwin/libedgetpu.1.dylib b/libedgetpu_bin/throttled/darwin/libedgetpu.1.dylib
new file mode 120000
index 0000000..e2c7584
--- /dev/null
+++ b/libedgetpu_bin/throttled/darwin/libedgetpu.1.dylib
@@ -0,0 +1 @@
+libedgetpu.1.0.dylib
\ No newline at end of file
diff --git a/libedgetpu_bin/throttled/k8/libedgetpu.so.1 b/libedgetpu_bin/throttled/k8/libedgetpu.so.1
new file mode 120000
index 0000000..90ac68c
--- /dev/null
+++ b/libedgetpu_bin/throttled/k8/libedgetpu.so.1
@@ -0,0 +1 @@
+libedgetpu.so.1.0
\ No newline at end of file
diff --git a/libedgetpu_bin/throttled/k8/libedgetpu.so.1.0 b/libedgetpu_bin/throttled/k8/libedgetpu.so.1.0
new file mode 100755
index 0000000..feb6c64
Binary files /dev/null and b/libedgetpu_bin/throttled/k8/libedgetpu.so.1.0 differ
diff --git a/libedgetpu_bin/throttled/x64_windows/edgetpu.dll b/libedgetpu_bin/throttled/x64_windows/edgetpu.dll
new file mode 100644
index 0000000..d7eb6b4
Binary files /dev/null and b/libedgetpu_bin/throttled/x64_windows/edgetpu.dll differ
diff --git a/libedgetpu_bin/throttled/x64_windows/edgetpu.dll.if.lib b/libedgetpu_bin/throttled/x64_windows/edgetpu.dll.if.lib
new file mode 100644
index 0000000..1a1591d
Binary files /dev/null and b/libedgetpu_bin/throttled/x64_windows/edgetpu.dll.if.lib differ
diff --git a/pycoral/__init__.py b/pycoral/__init__.py
new file mode 100644
index 0000000..b3e7fbc
--- /dev/null
+++ b/pycoral/__init__.py
@@ -0,0 +1,17 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Version information for Coral Python APIs."""
+
+__version__ = "1.0.0"
diff --git a/pycoral/adapters/__init__.py b/pycoral/adapters/__init__.py
new file mode 100644
index 0000000..3025788
--- /dev/null
+++ b/pycoral/adapters/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pycoral/adapters/classify.py b/pycoral/adapters/classify.py
new file mode 100644
index 0000000..4a477ee
--- /dev/null
+++ b/pycoral/adapters/classify.py
@@ -0,0 +1,106 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions to work with a classification model."""
+
+import collections
+import operator
+import numpy as np
+
+
+Class = collections.namedtuple('Class', ['id', 'score'])
+"""Represents a single classification, with the following fields:
+
+ .. py:attribute:: id
+
+ The class id.
+
+ .. py:attribute:: score
+
+ The prediction score.
+"""
+
+
+def num_classes(interpreter):
+ """Gets the number of classes output by a classification model.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` holding the model.
+
+ Returns:
+ The total number of classes output by the model.
+ """
+ return np.prod(interpreter.get_output_details()[0]['shape'])
+
+
+def get_scores(interpreter):
+ """Gets the output (all scores) from a classification model, dequantizing it if necessary.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` to query for output.
+
+ Returns:
+ The output tensor (flattened and dequantized) as :obj:`numpy.array`.
+ """
+ output_details = interpreter.get_output_details()[0]
+ output_data = interpreter.tensor(output_details['index'])().flatten()
+
+ if np.issubdtype(output_details['dtype'], np.integer):
+ scale, zero_point = output_details['quantization']
+ # Always convert to np.int64 to avoid overflow on subtraction.
+ return scale * (output_data.astype(np.int64) - zero_point)
+
+ return output_data
+
+
+def get_classes_from_scores(scores,
+ top_k=float('inf'),
+ score_threshold=-float('inf')):
+ """Gets results from a classification model as a list of ordered classes, based on given scores.
+
+ Args:
+ scores: The output from a classification model. Must be flattened and
+ dequantized.
+ top_k (int): The number of top results to return.
+ score_threshold (float): The score threshold for results. All returned
+ results have a score greater-than-or-equal-to this value.
+
+ Returns:
+ A list of :obj:`Class` objects representing the classification results,
+ ordered by scores.
+ """
+ top_k = min(top_k, len(scores))
+ classes = [
+ Class(i, scores[i])
+ for i in np.argpartition(scores, -top_k)[-top_k:]
+ if scores[i] >= score_threshold
+ ]
+ return sorted(classes, key=operator.itemgetter(1), reverse=True)
+
+
+def get_classes(interpreter, top_k=float('inf'), score_threshold=-float('inf')):
+ """Gets results from a classification model as a list of ordered classes.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` to query for results.
+ top_k (int): The number of top results to return.
+ score_threshold (float): The score threshold for results. All returned
+ results have a score greater-than-or-equal-to this value.
+
+ Returns:
+ A list of :obj:`Class` objects representing the classification results,
+ ordered by scores.
+ """
+ return get_classes_from_scores(
+ get_scores(interpreter), top_k, score_threshold)
diff --git a/pycoral/adapters/common.py b/pycoral/adapters/common.py
new file mode 100644
index 0000000..e1f55e2
--- /dev/null
+++ b/pycoral/adapters/common.py
@@ -0,0 +1,100 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions to work with any model."""
+
+import numpy as np
+
+
+def output_tensor(interpreter, i):
+ """Gets a model's ith output tensor.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` holding the model.
+ i (int): The index position of an output tensor.
+ Returns:
+ The output tensor at the specified position.
+ """
+ return interpreter.tensor(interpreter.get_output_details()[i]['index'])()
+
+
+def input_details(interpreter, key):
+ """Gets a model's input details by specified key.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` holding the model.
+ key (int): The index position of an input tensor.
+ Returns:
+ The input details.
+ """
+ return interpreter.get_input_details()[0][key]
+
+
+def input_size(interpreter):
+ """Gets a model's input size as (width, height) tuple.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` holding the model.
+ Returns:
+ The input tensor size as (width, height) tuple.
+ """
+ _, height, width, _ = input_details(interpreter, 'shape')
+ return width, height
+
+
+def input_tensor(interpreter):
+ """Gets a model's input tensor view as numpy array of shape (height, width, 3).
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` holding the model.
+ Returns:
+ The input tensor view as :obj:`numpy.array` (height, width, 3).
+ """
+ tensor_index = input_details(interpreter, 'index')
+ return interpreter.tensor(tensor_index)()[0]
+
+
+def set_input(interpreter, data):
+ """Copies data to a model's input tensor.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` to update.
+ data: The input tensor.
+ """
+ input_tensor(interpreter)[:, :] = data
+
+
+def set_resized_input(interpreter, size, resize):
+ """Copies a resized and properly zero-padded image to a model's input tensor.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` to update.
+ size (tuple): The original image size as (width, height) tuple.
+ resize: A function that takes a (width, height) tuple, and returns an
+ image resized to those dimensions.
+
+ Returns:
+ The resized tensor with zero-padding as tuple
+ (resized_tensor, resize_ratio).
+ """
+ width, height = input_size(interpreter)
+ w, h = size
+ scale = min(width / w, height / h)
+ w, h = int(w * scale), int(h * scale)
+ tensor = input_tensor(interpreter)
+ tensor.fill(0) # padding
+ _, _, channel = tensor.shape
+ result = resize((w, h))
+ tensor[:h, :w] = np.reshape(result, (h, w, channel))
+ return result, (scale, scale)
diff --git a/pycoral/adapters/detect.py b/pycoral/adapters/detect.py
new file mode 100644
index 0000000..b36dfbf
--- /dev/null
+++ b/pycoral/adapters/detect.py
@@ -0,0 +1,208 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions to work with a detection model."""
+
+import collections
+from pycoral.adapters import common
+
+Object = collections.namedtuple('Object', ['id', 'score', 'bbox'])
+"""Represents a detected object.
+
+ .. py:attribute:: id
+
+ The object's class id.
+
+ .. py:attribute:: score
+
+ The object's prediction score.
+
+ .. py:attribute:: bbox
+
+ A :obj:`BBox` object defining the object's location.
+"""
+
+
+class BBox(collections.namedtuple('BBox', ['xmin', 'ymin', 'xmax', 'ymax'])):
+ """The bounding box for a detected object.
+
+ .. py:attribute:: xmin
+
+ X-axis start point
+
+ .. py:attribute:: ymin
+
+ Y-axis start point
+
+ .. py:attribute:: xmax
+
+ X-axis end point
+
+ .. py:attribute:: ymax
+
+ Y-axis end point
+ """
+ __slots__ = ()
+
+ @property
+ def width(self):
+ """The bounding box width."""
+ return self.xmax - self.xmin
+
+ @property
+ def height(self):
+ """The bounding box height."""
+ return self.ymax - self.ymin
+
+ @property
+ def area(self):
+ """The bound box area."""
+ return self.width * self.height
+
+ @property
+ def valid(self):
+ """Indicates whether bounding box is valid or not (boolean).
+
+ A valid bounding box has xmin <= xmax and ymin <= ymax (equivalent
+ to width >= 0 and height >= 0).
+ """
+ return self.width >= 0 and self.height >= 0
+
+ def scale(self, sx, sy):
+ """Scales the bounding box.
+
+ Args:
+ sx (float): Scale factor for the x-axis.
+ sy (float): Scale factor for the y-axis.
+ Returns:
+ A :obj:`BBox` object with the rescaled dimensions.
+ """
+ return BBox(xmin=sx * self.xmin,
+ ymin=sy * self.ymin,
+ xmax=sx * self.xmax,
+ ymax=sy * self.ymax)
+
+ def translate(self, dx, dy):
+ """Translates the bounding box position.
+
+ Args:
+ dx (int): Number of pixels to move the box on the x-axis.
+ dy (int): Number of pixels to move the box on the y-axis.
+ Returns:
+ A :obj:`BBox` object at the new position.
+ """
+ return BBox(xmin=dx + self.xmin,
+ ymin=dy + self.ymin,
+ xmax=dx + self.xmax,
+ ymax=dy + self.ymax)
+
+ def map(self, f):
+ """Maps all box coordinates to a new position using a given function.
+
+ Args:
+ f: A function that takes a single coordinate and returns a new one.
+ Returns:
+ A :obj:`BBox` with the new coordinates.
+ """
+ return BBox(xmin=f(self.xmin),
+ ymin=f(self.ymin),
+ xmax=f(self.xmax),
+ ymax=f(self.ymax))
+
+ @staticmethod
+ def intersect(a, b):
+ """Gets a box representing the intersection between two boxes.
+
+ Args:
+ a: :obj:`BBox` A.
+ b: :obj:`BBox` B.
+ Returns:
+ A :obj:`BBox` representing the area where the two boxes intersect
+ (may be an invalid box, check with :func:`valid`).
+ """
+ return BBox(xmin=max(a.xmin, b.xmin),
+ ymin=max(a.ymin, b.ymin),
+ xmax=min(a.xmax, b.xmax),
+ ymax=min(a.ymax, b.ymax))
+
+ @staticmethod
+ def union(a, b):
+ """Gets a box representing the union of two boxes.
+
+ Args:
+ a: :obj:`BBox` A.
+ b: :obj:`BBox` B.
+ Returns:
+ A :obj:`BBox` representing the unified area of the two boxes
+ (always a valid box).
+ """
+ return BBox(xmin=min(a.xmin, b.xmin),
+ ymin=min(a.ymin, b.ymin),
+ xmax=max(a.xmax, b.xmax),
+ ymax=max(a.ymax, b.ymax))
+
+ @staticmethod
+ def iou(a, b):
+ """Gets the intersection-over-union value for two boxes.
+
+ Args:
+ a: :obj:`BBox` A.
+ b: :obj:`BBox` B.
+ Returns:
+ The intersection-over-union value: 1.0 meaning the two boxes are
+ perfectly aligned, 0 if not overlapping at all (invalid intersection).
+ """
+ intersection = BBox.intersect(a, b)
+ if not intersection.valid:
+ return 0.0
+ area = intersection.area
+ return area / (a.area + b.area - area)
+
+
+def get_objects(interpreter,
+ score_threshold=-float('inf'),
+ image_scale=(1.0, 1.0)):
+ """Gets results from a detection model as a list of detected objects.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` to query for results.
+ score_threshold (float): The score threshold for results. All returned
+ results have a score greater-than-or-equal-to this value.
+ image_scale (float, float): Scaling factor to apply to the bounding boxes
+ as (x-scale-factor, y-scale-factor), where each factor is from 0 to 1.0.
+
+ Returns:
+ A list of :obj:`Object` objects, which each contains the detected object's
+ id, score, and bounding box as :obj:`BBox`.
+ """
+ boxes = common.output_tensor(interpreter, 0)[0]
+ class_ids = common.output_tensor(interpreter, 1)[0]
+ scores = common.output_tensor(interpreter, 2)[0]
+ count = int(common.output_tensor(interpreter, 3)[0])
+
+ width, height = common.input_size(interpreter)
+ image_scale_x, image_scale_y = image_scale
+ sx, sy = width / image_scale_x, height / image_scale_y
+
+ def make(i):
+ ymin, xmin, ymax, xmax = boxes[i]
+ return Object(
+ id=int(class_ids[i]),
+ score=float(scores[i]),
+ bbox=BBox(xmin=xmin,
+ ymin=ymin,
+ xmax=xmax,
+ ymax=ymax).scale(sx, sy).map(int))
+
+ return [make(i) for i in range(count) if scores[i] >= score_threshold]
diff --git a/pycoral/adapters/segment.py b/pycoral/adapters/segment.py
new file mode 100644
index 0000000..56135f1
--- /dev/null
+++ b/pycoral/adapters/segment.py
@@ -0,0 +1,21 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions to work with segmentation models."""
+import numpy as np
+
+
+def get_output(interpreter):
+ output_details = interpreter.get_output_details()[0]
+ return interpreter.tensor(output_details['index'])()[0].astype(np.uint8)
diff --git a/pycoral/learn/__init__.py b/pycoral/learn/__init__.py
new file mode 100644
index 0000000..086a24e
--- /dev/null
+++ b/pycoral/learn/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pycoral/learn/backprop/__init__.py b/pycoral/learn/backprop/__init__.py
new file mode 100644
index 0000000..086a24e
--- /dev/null
+++ b/pycoral/learn/backprop/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pycoral/learn/backprop/softmax_regression.py b/pycoral/learn/backprop/softmax_regression.py
new file mode 100644
index 0000000..f2fd372
--- /dev/null
+++ b/pycoral/learn/backprop/softmax_regression.py
@@ -0,0 +1,143 @@
+# Lint as: python3
+# pylint:disable=g-doc-args,g-short-docstring-punctuation,g-no-space-after-docstring-summary
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A softmax regression model for on-device backpropagation of the last layer."""
+from pycoral.pybind import _pywrap_coral
+
+
+class SoftmaxRegression:
+ """An implementation of the softmax regression function (multinominal logistic
+
+ regression) that operates as the last layer of your classification model, and
+ allows for on-device training with backpropagation (for this layer only).
+
+ The input for this layer must be an image embedding, which should be the
+ output of your embedding extractor (the backbone of your model). Once given
+ here, the input is fed to a fully-connected layer where weights and bias are
+ applied, and then passed to the softmax function to receive the final
+ probability distribution based on the number of classes for your model:
+
+ training/inference input (image embedding) --> fully-connected layer -->
+ softmax function
+
+ When you're conducting training with :func:`train_with_sgd`, the process uses
+ a cross-entropy loss function to measure the error and then update the weights
+ of the fully-connected layer (backpropagation).
+
+ When you're satisfied with the inference accuracy, call
+ :func:`serialize_model` to create a new model in `bytes` with this
+ retrained layer appended to your embedding extractor. You can then run
+ inferences with this new model as usual (using TensorFlow Lite interpreter
+ API).
+
+ .. note::
+
+ This last layer (FC + softmax) in the retrained model always runs on the
+ host CPU instead of the Edge TPU. As long as the rest of your embedding
+ extractor model is compiled for the Edge TPU, then running this last layer
+ on the CPU should not significantly affect the inference speed.
+
+
+ """
+
+ def __init__(self,
+ feature_dim=None,
+ num_classes=None,
+ weight_scale=0.01,
+ reg=0.0):
+ """For more detail, see the `Stanford CS231 explanation of the softmax
+ classifier `_.
+
+ Args:
+ feature_dim (int): The dimension of the input feature (length of the
+ feature vector).
+ num_classes (int): The number of output classes.
+ weight_scale (float): A weight factor for computing new weights. The
+ backpropagated weights are drawn from standard normal distribution, then
+ multiplied by this number to keep the scale small.
+ reg (float): The regularization strength.
+ """
+ self.model = _pywrap_coral.SoftmaxRegressionModelWrapper(
+ feature_dim, num_classes, weight_scale, reg)
+
+ def serialize_model(self, in_model_path):
+ """Appends learned weights to your TensorFlow Lite model and serializes it.
+
+ Beware that learned weights and biases are quantized from float32 to uint8.
+
+ Args:
+ in_model_path (str): Path to the embedding extractor model (``.tflite``
+ file).
+
+ Returns:
+ The TF Lite model with new weights, as a `bytes` object.
+ """
+ return self.model.AppendLayersToEmbeddingExtractor(in_model_path)
+
+ def get_accuracy(self, mat_x, labels):
+ """Calculates the model's accuracy (percentage correct).
+
+    The calculation is based on performing inferences on the given data and labels.
+
+ Args:
+ mat_x (:obj:`numpy.array`): The input data (image embeddings) to test,
+ as a matrix of shape ``NxD``, where ``N`` is number of inputs to test
+ and ``D`` is the dimension of the input feature (length of the feature
+ vector).
+ labels (:obj:`numpy.array`): An array of the correct label indices that
+ correspond to the test data passed in ``mat_x`` (class label index in
+ one-hot vector).
+
+ Returns:
+ The accuracy (the percent correct) as a float.
+ """
+ return self.model.GetAccuracy(mat_x, labels)
+
+ def train_with_sgd(self,
+ data,
+ num_iter,
+ learning_rate,
+ batch_size=100,
+ print_every=100):
+ """Trains your model using stochastic gradient descent (SGD).
+
+ The training data must be structured in a dictionary as specified in the
+ ``data`` argument below. Notably, the training/validation images must be
+ passed as image embeddings, not as the original image input. That is, run
+ the images through your embedding extractor (the backbone of your graph) and
+ use the resulting image embeddings here.
+
+ Args:
+ data (dict): A dictionary that maps ``'data_train'`` to an array of
+ training image embeddings, ``'labels_train'`` to an array of training
+ labels, ``'data_val'`` to an array of validation image embeddings, and
+ ``'labels_val'`` to an array of validation labels.
+ num_iter (int): The number of iterations to train.
+ learning_rate (float): The learning rate (step size) to use in training.
+ batch_size (int): The number of training examples to use in each
+ iteration.
+ print_every (int): The number of iterations for which to print the loss,
+ and training/validation accuracy. For example, ``20`` prints the stats
+ for every 20 iterations. ``0`` disables printing.
+ """
+ train_config = _pywrap_coral.TrainConfigWrapper(num_iter, batch_size,
+ print_every)
+
+ training_data = _pywrap_coral.TrainingDataWrapper(data['data_train'],
+ data['data_val'],
+ data['labels_train'],
+ data['labels_val'])
+
+ self.model.Train(training_data, train_config, learning_rate)
diff --git a/pycoral/learn/imprinting/__init__.py b/pycoral/learn/imprinting/__init__.py
new file mode 100644
index 0000000..086a24e
--- /dev/null
+++ b/pycoral/learn/imprinting/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pycoral/learn/imprinting/engine.py b/pycoral/learn/imprinting/engine.py
new file mode 100644
index 0000000..2ac9b15
--- /dev/null
+++ b/pycoral/learn/imprinting/engine.py
@@ -0,0 +1,78 @@
+# Lint as: python3
+# pylint:disable=g-doc-args,g-short-docstring-punctuation,invalid-name,missing-class-docstring
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A weight imprinting engine that performs low-shot transfer-learning for image classification models.
+
+For more information about how to use this API and how to create the type of
+model required, see
+`Retrain a classification model on-device with weight imprinting
+<https://coral.ai/docs/edgetpu/retrain-classification-ondevice/>`_.
+"""
+
+from pycoral.pybind import _pywrap_coral
+
+
+class ImprintingEngine:
+
+ def __init__(self, model_path, keep_classes=False):
+ """Performs weight imprinting (transfer learning) with the given model.
+
+ Args:
+ model_path (str): Path to the model you want to retrain. This model must
+ be a ``.tflite`` file output by the ``join_tflite_models`` tool. For
+ more information about how to create a compatible model, read `Retrain
+ an image classification model on-device
+        <https://coral.ai/docs/edgetpu/retrain-classification-ondevice/>`_.
+ keep_classes (bool): If True, keep the existing classes from the
+ pre-trained model (and use training to add additional classes). If
+ False, drop the existing classes and train the model to include new
+ classes only.
+ """
+ self._engine = _pywrap_coral.ImprintingEnginePythonWrapper(
+ model_path, keep_classes)
+
+ @property
+ def embedding_dim(self):
+ """Returns number of embedding dimensions."""
+ return self._engine.EmbeddingDim()
+
+ @property
+ def num_classes(self):
+ """Returns number of currently trained classes."""
+ return self._engine.NumClasses()
+
+ def serialize_extractor_model(self):
+ """Returns embedding extractor model as `bytes` object."""
+ return self._engine.SerializeExtractorModel()
+
+ def serialize_model(self):
+ """Returns newly trained model as `bytes` object."""
+ return self._engine.SerializeModel()
+
+ def train(self, embedding, class_id):
+ """Trains the model with the given embedding for specified class.
+
+ You can use this to add new classes to the model or retrain classes that you
+ previously added using this imprinting API.
+
+ Args:
+ embedding (:obj:`numpy.array`): The embedding vector for training
+ specified single class.
+ class_id (int): The label id for this class. The index must be either the
+ number of existing classes (to add a new class to the model) or the
+ index of an existing class that was trained using this imprinting API
+ (you can't retrain classes from the pre-trained model).
+ """
+ self._engine.Train(embedding, class_id)
diff --git a/pycoral/pipeline/__init__.py b/pycoral/pipeline/__init__.py
new file mode 100644
index 0000000..3025788
--- /dev/null
+++ b/pycoral/pipeline/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pycoral/pipeline/pipelined_model_runner.py b/pycoral/pipeline/pipelined_model_runner.py
new file mode 100644
index 0000000..05e0be1
--- /dev/null
+++ b/pycoral/pipeline/pipelined_model_runner.py
@@ -0,0 +1,167 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The pipeline API allows you to run a segmented model across multiple Edge TPUs.
+
+For more information, see `Pipeline a model with multiple Edge
+TPUs <https://coral.ai/docs/edgetpu/pipeline/>`_.
+"""
+
+import numpy as np
+
+from pycoral.pybind import _pywrap_coral
+
+
+def _get_names(details):
+ """Returns a set of names given input/output tensor details."""
+ return {d['name'] for d in details}
+
+
+class PipelinedModelRunner:
+ """Manages the model pipeline.
+
+ To create an instance::
+
+ interpreter_a = tflite.Interpreter(model_path=model_segment_a,
+ experimental_delegates=delegate_a)
+ interpreter_a.allocate_tensors()
+ interpreter_b = tflite.Interpreter(model_path=model_segment_b,
+ experimental_delegates=delegate_b)
+ interpreter_b.allocate_tensors()
+ interpreters = [interpreter_a, interpreter_b]
+ runner = PipelinedModelRunner(interpreters)
+ """
+
+ def __init__(self, interpreters):
+ """Be sure you first call ``allocate_tensors()`` on each interpreter.
+
+ Args:
+ interpreters: A list of ``tf.lite.Interpreter`` objects, one for each
+ segment in the pipeline.
+ """
+ self._runner = None
+
+ if not interpreters:
+ raise ValueError('At least one interpreter expected')
+
+ # It requires that the inputs of interpreter[i] is a subset of outputs of
+ # interpreter[j], where j=0,...,i-1.
+ prev_outputs = _get_names(interpreters[0].get_input_details())
+ for index, interpreter in enumerate(interpreters):
+ inputs = _get_names(interpreter.get_input_details())
+ if not inputs.issubset(prev_outputs):
+ raise ValueError(
+ 'Interpreter {} can not get its input tensors'.format(index))
+ prev_outputs.update(_get_names(interpreter.get_output_details()))
+
+ self._interpreters = interpreters
+ self._runner = _pywrap_coral.PipelinedModelRunnerWrapper(
+ [i._native_handle() for i in interpreters])
+ self._input_types = [
+ d['dtype'] for d in self._interpreters[0].get_input_details()
+ ]
+ self._output_shapes = [
+ d['shape'] for d in self._interpreters[-1].get_output_details()
+ ]
+
+ def __del__(self):
+ if self._runner:
+ # Push empty request to stop the pipeline in case user forgot.
+ self.push([])
+ num_unconsumed = 0
+ # Release any unconsumed tensors if any.
+ while self.pop():
+ num_unconsumed += 1
+ if num_unconsumed:
+ print(
+ 'WARNING: {} unconsumed results in the pipeline during destruction!'
+ .format(num_unconsumed))
+
+ def set_input_queue_size(self, size):
+ """Sets the maximum number of inputs that may be queued for inference.
+
+ By default, input queue size is unlimited.
+
+ Note: It's OK to change the queue size max when PipelinedModelRunner is
+ active. If the new max is smaller than current queue size, pushes to
+ the queue are blocked until the current queue size drops below the new max.
+
+ Args:
+ size (int): The input queue size max
+ """
+ self._runner.SetInputQueueSize(size)
+
+ def set_output_queue_size(self, size):
+ """Sets the maximum number of outputs that may be unconsumed.
+
+ By default, output queue size is unlimited.
+
+ Note: It's OK to change the queue size max when PipelinedModelRunner is
+ active. If the new max is smaller than current queue size, pushes to the
+ queue are blocked until the current queue size drops below the new max.
+
+ Args:
+ size (int): The output queue size max
+ """
+ self._runner.SetOutputQueueSize(size)
+
+ def push(self, input_tensors):
+ """Pushes input tensors to trigger inference.
+
+ Pushing an empty list is allowed, which signals the class that no more
+ inputs will be added (the function will return false if inputs were pushed
+ after this special push). This special push allows the ``pop()`` consumer to
+ properly drain unconsumed output tensors.
+
+ Caller will be blocked if the current input queue size is greater than the
+ queue size max (use ``set_input_queue_size()``). By default, input queue
+ size threshold is unlimited, in this case, call to push() is non-blocking.
+
+ Args:
+ input_tensors: A list of :obj:`numpy.array` as the input for the given
+ model, in the appropriate order.
+
+ Returns:
+ True if push is successful; False otherwise.
+ """
+ if input_tensors and len(input_tensors) != len(self._input_types):
+ raise ValueError('Expected input of length {}, but got {}'.format(
+ len(self._input_types), len(input_tensors)))
+
+ for tensor, input_type in zip(input_tensors, self._input_types):
+ if not isinstance(tensor, np.ndarray) or tensor.dtype != input_type:
+ raise ValueError(
+ 'Input should be a list of numpy array of type {}'.format(
+ input_type))
+
+ return self._runner.Push(input_tensors)
+
+ def pop(self):
+ """Returns a single inference result.
+
+ This function blocks the calling thread until a result is returned.
+
+ Returns:
+ List of :obj:`numpy.array` objects representing the model's output
+ tensor. Returns None when a ``push()`` receives an empty list, indicating
+ there are no more output tensors available.
+ """
+ result = self._runner.Pop()
+ if result:
+ result = [r.reshape(s) for r, s in zip(result, self._output_shapes)]
+ return result
+
+ def interpreters(self):
+ """Returns list of interpreters that constructed PipelinedModelRunner."""
+ return self._interpreters
diff --git a/pycoral/utils/__init__.py b/pycoral/utils/__init__.py
new file mode 100644
index 0000000..086a24e
--- /dev/null
+++ b/pycoral/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pycoral/utils/dataset.py b/pycoral/utils/dataset.py
new file mode 100644
index 0000000..ceba145
--- /dev/null
+++ b/pycoral/utils/dataset.py
@@ -0,0 +1,45 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to help process a dataset."""
+
+import re
+
+
+def read_label_file(file_path):
+ """Reads labels from a text file and returns it as a dictionary.
+
+ This function supports label files with the following formats:
+
+ + Each line contains id and description separated by colon or space.
+ Example: ``0:cat`` or ``0 cat``.
+ + Each line contains a description only. The returned label id's are based on
+ the row number.
+
+ Args:
+ file_path (str): path to the label file.
+
+ Returns:
+ Dict of (int, string) which maps label id to description.
+ """
+ with open(file_path, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+ ret = {}
+ for row_number, content in enumerate(lines):
+ pair = re.split(r'[:\s]+', content.strip(), maxsplit=1)
+ if len(pair) == 2 and pair[0].strip().isdigit():
+ ret[int(pair[0])] = pair[1].strip()
+ else:
+ ret[row_number] = pair[0].strip()
+ return ret
diff --git a/pycoral/utils/edgetpu.py b/pycoral/utils/edgetpu.py
new file mode 100644
index 0000000..ea8659f
--- /dev/null
+++ b/pycoral/utils/edgetpu.py
@@ -0,0 +1,187 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for using the TensorFlow Lite Interpreter with Edge TPU."""
+
+import contextlib
+import ctypes
+import ctypes.util
+
+import numpy as np
+
+# pylint:disable=unused-import
+from pycoral.pybind._pywrap_coral import GetRuntimeVersion as get_runtime_version
+from pycoral.pybind._pywrap_coral import InvokeWithBytes as invoke_with_bytes
+from pycoral.pybind._pywrap_coral import InvokeWithDmaBuffer as invoke_with_dmabuffer
+from pycoral.pybind._pywrap_coral import InvokeWithMemBuffer as invoke_with_membuffer
+from pycoral.pybind._pywrap_coral import ListEdgeTpus as list_edge_tpus
+from pycoral.pybind._pywrap_coral import SupportsDmabuf as supports_dmabuf
+import platform
+import tflite_runtime.interpreter as tflite
+
+_EDGETPU_SHARED_LIB = {
+ 'Linux': 'libedgetpu.so.1',
+ 'Darwin': 'libedgetpu.1.dylib',
+ 'Windows': 'edgetpu.dll'
+}[platform.system()]
+
+
+def load_edgetpu_delegate(options=None):
+ """Loads the Edge TPU delegate with the given options."""
+ return tflite.load_delegate(_EDGETPU_SHARED_LIB, options or {})
+
+
+def make_interpreter(model_path_or_content, device=None):
+ """Returns a new interpreter instance.
+
+ Interpreter is created from either model path or model content and attached
+ to an Edge TPU device.
+
+ Args:
+ model_path_or_content (str or bytes): `str` object is interpreted as
+ model path, `bytes` object is interpreted as model content.
+ device (str): The type of Edge TPU device you want:
+
+ + None -- use any Edge TPU
+ + ":" -- use N-th Edge TPU
+ + "usb" -- use any USB Edge TPU
+ + "usb:" -- use N-th USB Edge TPU
+ + "pci" -- use any PCIe Edge TPU
+ + "pci:" -- use N-th PCIe Edge TPU
+
+ Returns:
+ New ``tf.lite.Interpreter`` instance.
+ """
+ delegates = [load_edgetpu_delegate({'device': device} if device else {})]
+ if isinstance(model_path_or_content, bytes):
+ return tflite.Interpreter(
+ model_content=model_path_or_content, experimental_delegates=delegates)
+ else:
+ return tflite.Interpreter(
+ model_path=model_path_or_content, experimental_delegates=delegates)
+
+
+# ctypes definition of GstMapInfo. This is a stable API, guaranteed to be
+# ABI compatible for any past and future GStreamer 1.0 releases.
+# Used to get the underlying memory pointer without any copies, and without
+# native library linking against libgstreamer.
+class _GstMapInfo(ctypes.Structure):
+ _fields_ = [
+ ('memory', ctypes.c_void_p), # GstMemory *memory
+ ('flags', ctypes.c_int), # GstMapFlags flags
+ ('data', ctypes.c_void_p), # guint8 *data
+ ('size', ctypes.c_size_t), # gsize size
+ ('maxsize', ctypes.c_size_t), # gsize maxsize
+ ('user_data', ctypes.c_void_p * 4), # gpointer user_data[4]
+ ('_gst_reserved', ctypes.c_void_p * 4)
+ ] # GST_PADDING
+
+
+# Try to import GStreamer but don't fail if it's not available. If not available
+# we're probably not getting GStreamer buffers as input anyway.
+_libgst = None
+try:
+ # pylint:disable=g-import-not-at-top
+ import gi
+ gi.require_version('Gst', '1.0')
+ gi.require_version('GstAllocators', '1.0')
+ # pylint:disable=g-multiple-import
+ from gi.repository import Gst, GstAllocators
+ _libgst = ctypes.CDLL(ctypes.util.find_library('gstreamer-1.0'))
+ _libgst.gst_buffer_map.argtypes = [
+ ctypes.c_void_p,
+ ctypes.POINTER(_GstMapInfo), ctypes.c_int
+ ]
+ _libgst.gst_buffer_map.restype = ctypes.c_int
+ _libgst.gst_buffer_unmap.argtypes = [
+ ctypes.c_void_p, ctypes.POINTER(_GstMapInfo)
+ ]
+ _libgst.gst_buffer_unmap.restype = None
+except (ImportError, ValueError, OSError):
+ pass
+
+
+def _is_valid_ctypes_input(input_data):
+ if not isinstance(input_data, tuple):
+ return False
+ pointer, size = input_data
+ if not isinstance(pointer, ctypes.c_void_p):
+ return False
+ return isinstance(size, int)
+
+
+@contextlib.contextmanager
+def _gst_buffer_map(buffer):
+ """Yields gst buffer map."""
+ mapping = _GstMapInfo()
+ ptr = hash(buffer)
+ success = _libgst.gst_buffer_map(ptr, mapping, Gst.MapFlags.READ)
+ if not success:
+ raise RuntimeError('gst_buffer_map failed')
+ try:
+ yield ctypes.c_void_p(mapping.data), mapping.size
+ finally:
+ _libgst.gst_buffer_unmap(ptr, mapping)
+
+
+def _check_input_size(input_size, expected_input_size):
+ if input_size != expected_input_size:
+ raise ValueError('input size={}, expected={}.'.format(
+ input_size, expected_input_size))
+
+
+def run_inference(interpreter, input_data):
+ """Performs interpreter ``invoke()`` with a raw input tensor.
+
+ Args:
+ interpreter: The ``tf.lite.Interpreter`` to invoke.
+ input_data: A 1-D array as the input tensor. Input data must be uint8
+ format. Data may be ``Gst.Buffer`` or :obj:`numpy.ndarray`.
+ """
+ input_shape = interpreter.get_input_details()[0]['shape']
+ expected_input_size = np.prod(input_shape)
+
+ interpreter_handle = interpreter._native_handle() # pylint:disable=protected-access
+ if isinstance(input_data, bytes):
+ _check_input_size(len(input_data), expected_input_size)
+ invoke_with_bytes(interpreter_handle, input_data)
+ elif _is_valid_ctypes_input(input_data):
+ pointer, actual_size = input_data
+ _check_input_size(actual_size, expected_input_size)
+ invoke_with_membuffer(interpreter_handle, pointer.value,
+ expected_input_size)
+ elif _libgst and isinstance(input_data, Gst.Buffer):
+ memory = input_data.peek_memory(0)
+ map_buffer = not GstAllocators.is_dmabuf_memory(
+ memory) or not supports_dmabuf(interpreter_handle)
+ if not map_buffer:
+ _check_input_size(memory.size, expected_input_size)
+ fd = GstAllocators.dmabuf_memory_get_fd(memory)
+ try:
+ invoke_with_dmabuffer(interpreter_handle, fd, expected_input_size)
+ except RuntimeError:
+ # dma-buf input didn't work, likely due to old kernel driver. This
+ # situation can't be detected until one inference has been tried.
+ map_buffer = True
+ if map_buffer:
+ with _gst_buffer_map(input_data) as (pointer, actual_size):
+ assert actual_size >= expected_input_size
+ invoke_with_membuffer(interpreter_handle, pointer.value,
+ expected_input_size)
+ elif isinstance(input_data, np.ndarray):
+ _check_input_size(len(input_data), expected_input_size)
+ invoke_with_membuffer(interpreter_handle, input_data.ctypes.data,
+ expected_input_size)
+ else:
+ raise TypeError('input data type is not supported.')
diff --git a/scripts/build.sh b/scripts/build.sh
new file mode 100755
index 0000000..838ebc6
--- /dev/null
+++ b/scripts/build.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+#
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -ex
+
+readonly SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+readonly MAKEFILE="${SCRIPT_DIR}/../Makefile"
+readonly DOCKER_CPUS="${DOCKER_CPUS:=k8 aarch64 armv7a}"
+PYTHON_VERSIONS="35 36 37 38"
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --clean)
+ make -f "${MAKEFILE}" clean
+ shift
+ ;;
+ --python_versions)
+ PYTHON_VERSIONS=$2
+ shift
+ shift
+ ;;
+ *)
+ shift
+ ;;
+ esac
+done
+
+function docker_image {
+ case $1 in
+ 35) echo "ubuntu:16.04" ;;
+ 36) echo "ubuntu:18.04" ;;
+ 37) echo "debian:buster" ;;
+ 38) echo "ubuntu:20.04" ;;
+ *) echo "Unsupported python version: $1" 1>&2; exit 1 ;;
+ esac
+}
+
+for python_version in ${PYTHON_VERSIONS}; do
+ make DOCKER_CPUS="${DOCKER_CPUS}" \
+ DOCKER_IMAGE=$(docker_image "${python_version}") \
+ DOCKER_TARGETS="pybind tflite wheel tflite-wheel" \
+ -f "${MAKEFILE}" \
+ docker-build
+done
diff --git a/scripts/build_deb.sh b/scripts/build_deb.sh
new file mode 100755
index 0000000..9966ca1
--- /dev/null
+++ b/scripts/build_deb.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+#
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -ex
+
+readonly SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+readonly MAKEFILE="${SCRIPT_DIR}/../Makefile"
+readonly CMD="make deb tflite-deb && make -C libedgetpu_bin deb"
+
+"${SCRIPT_DIR}/build.sh"
+make DOCKER_SHELL_COMMAND="${CMD}" -f "${MAKEFILE}" docker-shell
diff --git a/scripts/runtime/install.sh b/scripts/runtime/install.sh
new file mode 100755
index 0000000..e4657da
--- /dev/null
+++ b/scripts/runtime/install.sh
@@ -0,0 +1,187 @@
+#!/bin/bash
+#
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+if [[ -d "${SCRIPT_DIR}/libedgetpu" ]]; then
+ LIBEDGETPU_DIR="${SCRIPT_DIR}/libedgetpu"
+else
+ LIBEDGETPU_DIR="${SCRIPT_DIR}/../../libedgetpu_bin"
+fi
+
+function info {
+ echo -e "\033[0;32m${1}\033[0m" # green
+}
+
+function warn {
+ echo -e "\033[0;33m${1}\033[0m" # yellow
+}
+
+function error {
+ echo -e "\033[0;31m${1}\033[0m" # red
+}
+
+function install_file {
+ local name="${1}"
+ local src="${2}"
+ local dst="${3}"
+
+ info "Installing ${name} [${dst}]..."
+ if [[ -f "${dst}" ]]; then
+ warn "File already exists. Replacing it..."
+ rm -f "${dst}"
+ fi
+ cp -a "${src}" "${dst}"
+}
+
+if [[ "${EUID}" != 0 ]]; then
+ error "Please use sudo to run as root."
+ exit 1
+fi
+
+if [[ -f /etc/mendel_version ]]; then
+ error "Looks like you're using a Coral Dev Board. You should instead use Debian packages to manage Edge TPU software."
+ exit 1
+fi
+
+readonly OS="$(uname -s)"
+readonly MACHINE="$(uname -m)"
+
+if [[ "${OS}" == "Linux" ]]; then
+ case "${MACHINE}" in
+ x86_64)
+ HOST_GNU_TYPE=x86_64-linux-gnu
+ CPU=k8
+ ;;
+ armv6l)
+ HOST_GNU_TYPE=arm-linux-gnueabihf
+ CPU=armv6
+ ;;
+ armv7l)
+ HOST_GNU_TYPE=arm-linux-gnueabihf
+ CPU=armv7a
+ ;;
+ aarch64)
+ HOST_GNU_TYPE=aarch64-linux-gnu
+ CPU=aarch64
+ ;;
+ *)
+ error "Your Linux platform is not supported."
+ exit 1
+ ;;
+ esac
+elif [[ "${OS}" == "Darwin" ]]; then
+ CPU=darwin
+
+ MACPORTS_PATH_AUTO="$(command -v port || true)"
+ MACPORTS_PATH="${MACPORTS_PATH_AUTO:-/opt/local/bin/port}"
+
+ BREW_PATH_AUTO="$(command -v brew || true)"
+ BREW_PATH="${BREW_PATH_AUTO:-/usr/local/bin/brew}"
+
+ if [[ -x "${MACPORTS_PATH}" ]]; then
+ DARWIN_INSTALL_COMMAND="${MACPORTS_PATH}"
+ DARWIN_INSTALL_USER="$(whoami)"
+ elif [[ -x "${BREW_PATH}" ]]; then
+ DARWIN_INSTALL_COMMAND="${BREW_PATH}"
+ DARWIN_INSTALL_USER="${SUDO_USER}"
+ else
+ error "You need to install either Homebrew or MacPorts first."
+ exit 1
+ fi
+else
+ error "Your operating system is not supported."
+ exit 1
+fi
+
+cat << EOM
+Warning: If you're using the Coral USB Accelerator, it may heat up during operation, depending
+on the computation workloads and operating frequency. Touching the metal part of the USB
+Accelerator after it has been operating for an extended period of time may lead to discomfort
+and/or skin burns. As such, if you enable the Edge TPU runtime using the maximum operating
+frequency, the USB Accelerator should be operated at an ambient temperature of 25°C or less.
+Alternatively, if you enable the Edge TPU runtime using the reduced operating frequency, then
+the device is intended to safely operate at an ambient temperature of 35°C or less.
+
+Google does not accept any responsibility for any loss or damage if the device
+is operated outside of the recommended ambient temperature range.
+
+Note: This question affects only USB-based Coral devices, and is irrelevant for PCIe devices.
+................................................................................
+Would you like to enable the maximum operating frequency for your Coral USB device? Y/N
+EOM
+
+read USE_MAX_FREQ
+case "${USE_MAX_FREQ}" in
+ [yY])
+ info "Using the maximum operating frequency for Coral USB devices."
+ FREQ_DIR=direct
+ ;;
+ *)
+ info "Using the reduced operating frequency for Coral USB devices."
+ FREQ_DIR=throttled
+ ;;
+esac
+
+if [[ "${CPU}" == "darwin" ]]; then
+ sudo -u "${DARWIN_INSTALL_USER}" "${DARWIN_INSTALL_COMMAND}" install libusb
+
+ DARWIN_INSTALL_LIB_DIR="$(dirname "$(dirname "${DARWIN_INSTALL_COMMAND}")")/lib"
+ LIBEDGETPU_LIB_DIR="/usr/local/lib"
+ mkdir -p "${LIBEDGETPU_LIB_DIR}"
+
+ install_file "Edge TPU runtime library" \
+ "${LIBEDGETPU_DIR}/${FREQ_DIR}/darwin/libedgetpu.1.0.dylib" \
+ "${LIBEDGETPU_LIB_DIR}"
+
+ install_file "Edge TPU runtime library symlink" \
+ "${LIBEDGETPU_DIR}/${FREQ_DIR}/darwin/libedgetpu.1.dylib" \
+ "${LIBEDGETPU_LIB_DIR}"
+
+ install_name_tool -id "${LIBEDGETPU_LIB_DIR}/libedgetpu.1.dylib" \
+ "${LIBEDGETPU_LIB_DIR}/libedgetpu.1.0.dylib"
+
+ install_name_tool -change "/opt/local/lib/libusb-1.0.0.dylib" \
+ "${DARWIN_INSTALL_LIB_DIR}/libusb-1.0.0.dylib" \
+ "${LIBEDGETPU_LIB_DIR}/libedgetpu.1.0.dylib"
+else
+ for pkg in libc6 libgcc1 libstdc++6 libusb-1.0-0; do
+ if ! dpkg -l "${pkg}" > /dev/null; then
+ PACKAGES+=" ${pkg}"
+ fi
+ done
+
+ if [[ -n "${PACKAGES}" ]]; then
+ info "Installing library dependencies:${PACKAGES}..."
+ apt-get update && apt-get install -y ${PACKAGES}
+ info "Done."
+ fi
+
+ if [[ -x "$(command -v udevadm)" ]]; then
+ install_file "device rule file" \
+ "${LIBEDGETPU_DIR}/edgetpu-accelerator.rules" \
+ "/etc/udev/rules.d/99-edgetpu-accelerator.rules"
+ udevadm control --reload-rules && udevadm trigger
+ info "Done."
+ fi
+
+ install_file "Edge TPU runtime library" \
+ "${LIBEDGETPU_DIR}/${FREQ_DIR}/${CPU}/libedgetpu.so.1.0" \
+ "/usr/lib/${HOST_GNU_TYPE}/libedgetpu.so.1.0"
+ ldconfig # Generates libedgetpu.so.1 symlink
+ info "Done."
+fi
diff --git a/scripts/runtime/uninstall.sh b/scripts/runtime/uninstall.sh
new file mode 100755
index 0000000..9ec0126
--- /dev/null
+++ b/scripts/runtime/uninstall.sh
@@ -0,0 +1,103 @@
+#!/bin/bash
+#
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+function info {
+ echo -e "\033[0;32m${1}\033[0m" # green
+}
+
+function warn {
+ echo -e "\033[0;33m${1}\033[0m" # yellow
+}
+
+function error {
+ echo -e "\033[0;31m${1}\033[0m" # red
+}
+
+if [[ "${EUID}" != 0 ]]; then
+ error "Please use sudo to run as root."
+ exit 1
+fi
+
+if [[ -f /etc/mendel_version ]]; then
+ error "Looks like you're using a Coral Dev Board. You should instead use Debian packages to manage Edge TPU software."
+ exit 1
+fi
+
+readonly OS="$(uname -s)"
+readonly MACHINE="$(uname -m)"
+
+if [[ "${OS}" == "Linux" ]]; then
+ case "${MACHINE}" in
+ x86_64)
+ HOST_GNU_TYPE=x86_64-linux-gnu
+ CPU_DIR=k8
+ ;;
+ armv7l)
+ HOST_GNU_TYPE=arm-linux-gnueabihf
+ CPU_DIR=armv7a
+ ;;
+ aarch64)
+ HOST_GNU_TYPE=aarch64-linux-gnu
+ CPU_DIR=aarch64
+ ;;
+ *)
+ error "Your Linux platform is not supported. There's nothing to uninstall."
+ exit 1
+ ;;
+ esac
+elif [[ "${OS}" == "Darwin" ]]; then
+ CPU=darwin
+else
+ error "Your operating system is not supported. There's nothing to uninstall."
+ exit 1
+fi
+
+if [[ "${CPU}" == "darwin" ]]; then
+ LIBEDGETPU_LIB_DIR="/usr/local/lib"
+
+ if [[ -f "${LIBEDGETPU_LIB_DIR}/libedgetpu.1.0.dylib" ]]; then
+ info "Uninstalling Edge TPU runtime library..."
+ rm -f "${LIBEDGETPU_LIB_DIR}/libedgetpu.1.0.dylib"
+ info "Done"
+ fi
+
+ if [[ -L "${LIBEDGETPU_LIB_DIR}/libedgetpu.1.dylib" ]]; then
+ info "Uninstalling Edge TPU runtime library symlink..."
+ rm -f "${LIBEDGETPU_LIB_DIR}/libedgetpu.1.dylib"
+ info "Done"
+ fi
+else
+ if [[ -x "$(command -v udevadm)" ]]; then
+ UDEV_RULE_PATH="/etc/udev/rules.d/99-edgetpu-accelerator.rules"
+ if [[ -f "${UDEV_RULE_PATH}" ]]; then
+ info "Uninstalling device rule file [${UDEV_RULE_PATH}]..."
+ rm -f "${UDEV_RULE_PATH}"
+ udevadm control --reload-rules && udevadm trigger
+ info "Done."
+ fi
+ fi
+
+ LIBEDGETPU_DST="/usr/lib/${HOST_GNU_TYPE}/libedgetpu.so.1.0"
+ if [[ -f "${LIBEDGETPU_DST}" ]]; then
+ info "Uninstalling Edge TPU runtime library [${LIBEDGETPU_DST}]..."
+ rm -f "${LIBEDGETPU_DST}"
+ ldconfig
+ info "Done."
+ fi
+fi
diff --git a/scripts/windows/build.bat b/scripts/windows/build.bat
new file mode 100644
index 0000000..505689e
--- /dev/null
+++ b/scripts/windows/build.bat
@@ -0,0 +1,85 @@
+@echo off
+setlocal enabledelayedexpansion
+
+if not defined PYTHON ( set PYTHON=python )
+
+set BAZEL_CMD=bazel
+if defined BAZEL_OUTPUT_BASE (
+ set BAZEL_CMD=%BAZEL_CMD% --output_base=%BAZEL_OUTPUT_BASE%
+)
+
+set BAZEL_INFO_FLAGS=^
+--experimental_repo_remote_exec
+
+set BAZEL_VS=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools
+set BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC
+call "%BAZEL_VC%\Auxiliary\Build\vcvars64.bat"
+
+for /f %%i in ('%BAZEL_CMD% info %BAZEL_INFO_FLAGS% output_base') do set "BAZEL_OUTPUT_BASE=%%i"
+for /f %%i in ('%BAZEL_CMD% info %BAZEL_INFO_FLAGS% output_path') do set "BAZEL_OUTPUT_PATH=%%i"
+for /f %%i in ('%PYTHON% -c "import sys;print(str(sys.version_info.major)+str(sys.version_info.minor))"') do set "PY3_VER=%%i"
+for /f %%i in ('%PYTHON% -c "import sys;print(sys.executable)"') do set "PYTHON_BIN_PATH=%%i"
+for /f %%i in ('%PYTHON% -c "import sys;print(sys.base_prefix)"') do set "PYTHON_LIB_PATH=%%i\Lib"
+
+set BAZEL_OUTPUT_PATH=%BAZEL_OUTPUT_PATH:/=\%
+set BAZEL_OUTPUT_BASE=%BAZEL_OUTPUT_BASE:/=\%
+set CPU=x64_windows
+set COMPILATION_MODE=opt
+set LIBEDGETPU_VERSION=direct
+
+set ROOTDIR=%~dp0\..\..\
+set BAZEL_OUT_DIR=%BAZEL_OUTPUT_PATH%\%CPU%-%COMPILATION_MODE%\bin
+set PYBIND_OUT_DIR=%ROOTDIR%\pycoral\pybind
+set TFLITE_WRAPPER_OUT_DIR=%ROOTDIR%\tflite_runtime
+set LIBEDGETPU_DIR=%ROOTDIR%\libedgetpu_bin\%LIBEDGETPU_VERSION%\x64_windows
+
+set TFLITE_WRAPPER_NAME=_pywrap_tensorflow_interpreter_wrapper.cp%PY3_VER%-win_amd64.pyd
+set PYBIND_WRAPPER_NAME=_pywrap_coral.cp%PY3_VER%-win_amd64.pyd
+set LIBEDGETPU_DLL_NAME=edgetpu.dll
+
+set TFLITE_WRAPPER_PATH=%TFLITE_WRAPPER_OUT_DIR%\%TFLITE_WRAPPER_NAME%
+set PYBIND_WRAPPER_PATH=%PYBIND_OUT_DIR%\%PYBIND_WRAPPER_NAME%
+set LIBEDGETPU_DLL_PATH=%LIBEDGETPU_DIR%\%LIBEDGETPU_DLL_NAME%
+
+:PROCESSARGS
+set ARG=%1
+if defined ARG (
+ if "%ARG%"=="/DBG" (
+ set COMPILATION_MODE=dbg
+ )
+ shift
+ goto PROCESSARGS
+)
+
+for /f "tokens=3" %%i in ('type %ROOTDIR%\WORKSPACE ^| findstr /C:"TENSORFLOW_COMMIT ="') do set "TENSORFLOW_COMMIT=%%i"
+set BAZEL_BUILD_FLAGS= ^
+--compilation_mode=%COMPILATION_MODE% ^
+--copt=/DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION ^
+--copt=/D_HAS_DEPRECATED_RESULT_OF ^
+--linkopt=/DEFAULTLIB:%LIBEDGETPU_DLL_PATH%.if.lib ^
+--experimental_repo_remote_exec ^
+--copt=/std:c++latest
+set BAZEL_QUERY_FLAGS=^
+--experimental_repo_remote_exec
+
+rem PYBIND
+%BAZEL_CMD% build %BAZEL_BUILD_FLAGS% ^
+ //src:edgetpu.res || goto :exit
+%BAZEL_CMD% build %BAZEL_BUILD_FLAGS% ^
+ --embed_label=%TENSORFLOW_COMMIT% ^
+ --stamp ^
+ //src:_pywrap_coral || goto :exit
+if not exist %PYBIND_OUT_DIR% md %PYBIND_OUT_DIR%
+type NUL >%PYBIND_OUT_DIR%\__init__.py
+copy %BAZEL_OUT_DIR%\src\_pywrap_coral.pyd %PYBIND_WRAPPER_PATH% >NUL
+
+rem TfLite
+%BAZEL_CMD% build %BAZEL_BUILD_FLAGS% ^
+ @org_tensorflow//tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper || goto :exit
+if not exist %TFLITE_WRAPPER_OUT_DIR% md %TFLITE_WRAPPER_OUT_DIR%
+copy %BAZEL_OUT_DIR%\external\org_tensorflow\tensorflow\lite\python\interpreter_wrapper\_pywrap_tensorflow_interpreter_wrapper.pyd ^
+ %TFLITE_WRAPPER_PATH% >NUL
+copy %BAZEL_OUTPUT_BASE%\external\org_tensorflow\tensorflow\lite\python\interpreter.py %TFLITE_WRAPPER_OUT_DIR%
+
+:exit
+exit /b %ERRORLEVEL%
diff --git a/scripts/windows/build_wheel.bat b/scripts/windows/build_wheel.bat
new file mode 100644
index 0000000..6da516e
--- /dev/null
+++ b/scripts/windows/build_wheel.bat
@@ -0,0 +1,16 @@
+@echo off
+setlocal enabledelayedexpansion
+
+if not defined PYTHON ( set PYTHON=python )
+set ROOTDIR=%~dp0\..\..\
+for /f %%i in ('%PYTHON% -c "import sys;print(str(sys.version_info.major)+str(sys.version_info.minor))"') do set "PY3_VER=%%i"
+set WRAPPER_NAME=_pywrap_coral.cp%PY3_VER%-win_amd64.pyd
+
+rem Build the code, in case it doesn't exist yet.
+call %ROOTDIR%\scripts\windows\build.bat || goto :exit
+
+%PYTHON% %ROOTDIR%\setup.py bdist_wheel -d %ROOTDIR%\dist
+rd /s /q build
+
+:exit
+exit /b %ERRORLEVEL%
\ No newline at end of file
diff --git a/scripts/windows/clean.bat b/scripts/windows/clean.bat
new file mode 100644
index 0000000..6e404f4
--- /dev/null
+++ b/scripts/windows/clean.bat
@@ -0,0 +1,11 @@
+@echo off
+setlocal enabledelayedexpansion
+
+set ROOTDIR=%~dp0\..\..\
+
+bazel clean
+
+for /f %%i in ('dir /a:d /b %ROOTDIR%\bazel-*') do rd /q %%i
+rd /s /q %ROOTDIR%\pycoral\pybind
+rd /s /q %ROOTDIR%\tflite_runtime
+rd /s /q %ROOTDIR%\coral.egg-info
\ No newline at end of file
diff --git a/scripts/windows/docker_build.bat b/scripts/windows/docker_build.bat
new file mode 100644
index 0000000..a2f53f6
--- /dev/null
+++ b/scripts/windows/docker_build.bat
@@ -0,0 +1,23 @@
+@echo off
+setlocal enabledelayedexpansion
+
+if not defined PY3_VER set PY3_VER=38
+set ROOTDIR=%~dp0\..\..\
+set TEST_DATA_DIR=%ROOTDIR%\..\test_data
+set LIBCORAL_DIR=%ROOTDIR%\..\libcoral
+set LIBEDGETPU_DIR=%ROOTDIR%\..\libedgetpu
+for /f %%i in ("%ROOTDIR%") do set "ROOTDIR=%%~fi"
+for /f %%i in ("%TEST_DATA_DIR%") do set "TEST_DATA_DIR=%%~fi"
+for /f %%i in ("%LIBCORAL_DIR%") do set "LIBCORAL_DIR=%%~fi"
+for /f %%i in ("%LIBEDGETPU_DIR%") do set "LIBEDGETPU_DIR=%%~fi"
+for /f "tokens=2 delims==" %%i in ('wmic os get /format:value ^| findstr TotalVisibleMemorySize') do set /A "MEM_KB=%%i >> 1"
+
+docker run -m %MEM_KB%KB --cpus %NUMBER_OF_PROCESSORS% --rm ^
+ -v %ROOTDIR%:c:\edgetpu ^
+ -v %TEST_DATA_DIR%:c:\edgetpu\test_data ^
+ -v %LIBCORAL_DIR%:c:\edgetpu\libcoral ^
+ -v %LIBEDGETPU_DIR%:c:\edgetpu\libedgetpu ^
+ -w c:\edgetpu ^
+ -e PYTHON=c:\python%PY3_VER%\python.exe ^
+ -e BAZEL_OUTPUT_BASE=c:\temp\edgetpu ^
+ edgetpu-win scripts\windows\build.bat
diff --git a/scripts/windows/docker_make_image.bat b/scripts/windows/docker_make_image.bat
new file mode 100644
index 0000000..bc9cee0
--- /dev/null
+++ b/scripts/windows/docker_make_image.bat
@@ -0,0 +1,6 @@
+@echo off
+setlocal enabledelayedexpansion
+
+set ROOTDIR=%~dp0\..\..\
+
+docker build -t edgetpu-win -f %ROOTDIR%\docker\Dockerfile.windows %ROOTDIR%\docker
diff --git a/scripts/windows/docker_make_wheels.bat b/scripts/windows/docker_make_wheels.bat
new file mode 100644
index 0000000..902e11e
--- /dev/null
+++ b/scripts/windows/docker_make_wheels.bat
@@ -0,0 +1,56 @@
+@echo off
+setlocal enabledelayedexpansion
+
+set ROOTDIR=%~dp0\..\..\
+set TEST_DATA_DIR=%ROOTDIR%\..\test_data
+set LIBCORAL_DIR=%ROOTDIR%\..\libcoral
+set LIBEDGETPU_DIR=%ROOTDIR%\..\libedgetpu
+for /f %%i in ("%ROOTDIR%") do set "ROOTDIR=%%~fi"
+for /f %%i in ("%TEST_DATA_DIR%") do set "TEST_DATA_DIR=%%~fi"
+for /f %%i in ("%LIBCORAL_DIR%") do set "LIBCORAL_DIR=%%~fi"
+for /f %%i in ("%LIBEDGETPU_DIR%") do set "LIBEDGETPU_DIR=%%~fi"
+for /f "tokens=2 delims==" %%i in ('wmic os get /format:value ^| findstr TotalVisibleMemorySize') do set /A "MEM_KB=%%i >> 1"
+
+rem Build Python 3.5 wheel
+call %ROOTDIR%\scripts\windows\clean.bat
+docker run -m %MEM_KB%KB --cpus %NUMBER_OF_PROCESSORS% --rm ^
+ -v %ROOTDIR%:c:\edgetpu ^
+ -v %TEST_DATA_DIR%:c:\edgetpu\test_data ^
+ -v %LIBCORAL_DIR%:c:\edgetpu\libcoral ^
+ -v %LIBEDGETPU_DIR%:c:\edgetpu\libedgetpu ^
+ -w c:\edgetpu ^
+ -e BAZEL_OUTPUT_BASE=c:\temp\edgetpu ^
+ -e PYTHON=c:\python35\python.exe edgetpu-win scripts\windows\build_wheel.bat
+
+rem Build Python 3.6 wheel
+call %ROOTDIR%\scripts\windows\clean.bat
+docker run -m %MEM_KB%KB --cpus %NUMBER_OF_PROCESSORS% --rm ^
+ -v %ROOTDIR%:c:\edgetpu ^
+ -v %TEST_DATA_DIR%:c:\edgetpu\test_data ^
+ -v %LIBCORAL_DIR%:c:\edgetpu\libcoral ^
+ -v %LIBEDGETPU_DIR%:c:\edgetpu\libedgetpu ^
+ -w c:\edgetpu ^
+ -e BAZEL_OUTPUT_BASE=c:\temp\edgetpu ^
+ -e PYTHON=c:\python36\python.exe edgetpu-win scripts\windows\build_wheel.bat
+
+rem Build Python 3.7 wheel
+call %ROOTDIR%\scripts\windows\clean.bat
+docker run -m %MEM_KB%KB --cpus %NUMBER_OF_PROCESSORS% --rm ^
+ -v %ROOTDIR%:c:\edgetpu ^
+ -v %TEST_DATA_DIR%:c:\edgetpu\test_data ^
+ -v %LIBCORAL_DIR%:c:\edgetpu\libcoral ^
+ -v %LIBEDGETPU_DIR%:c:\edgetpu\libedgetpu ^
+ -w c:\edgetpu ^
+ -e BAZEL_OUTPUT_BASE=c:\temp\edgetpu ^
+ -e PYTHON=c:\python37\python.exe edgetpu-win scripts\windows\build_wheel.bat
+
+rem Build Python 3.8 wheel
+call %ROOTDIR%\scripts\windows\clean.bat
+docker run -m %MEM_KB%KB --cpus %NUMBER_OF_PROCESSORS% --rm ^
+ -v %ROOTDIR%:c:\edgetpu ^
+ -v %TEST_DATA_DIR%:c:\edgetpu\test_data ^
+ -v %LIBCORAL_DIR%:c:\edgetpu\libcoral ^
+ -v %LIBEDGETPU_DIR%:c:\edgetpu\libedgetpu ^
+ -w c:\edgetpu ^
+ -e BAZEL_OUTPUT_BASE=c:\temp\edgetpu ^
+ -e PYTHON=c:\python38\python.exe edgetpu-win scripts\windows\build_wheel.bat
diff --git a/scripts/windows/install.bat b/scripts/windows/install.bat
new file mode 100644
index 0000000..711d844
--- /dev/null
+++ b/scripts/windows/install.bat
@@ -0,0 +1,54 @@
+@echo off
+setlocal enabledelayedexpansion
+
+rem Check for Admin privileges
+fsutil dirty query %systemdrive% >NUL
+if not %ERRORLEVEL% == 0 (
+ powershell Start-Process -FilePath '%0' -ArgumentList "elevated" -verb runas
+ exit /b
+)
+
+if exist "%~dp0\libedgetpu" (
+ rem Running with the script in the root
+ set ROOTDIR=%~dp0
+) else (
+ rem Running with the script in scripts\windows
+ set ROOTDIR=%~dp0\..\..\
+)
+
+cd /d "%ROOTDIR%"
+set ROOTDIR=%CD%
+
+echo Warning: During normal operation, the Edge TPU Accelerator may heat up,
+echo depending on the computation workloads and operating frequency. Touching the
+echo metal part of the device after it has been operating for an extended period of
+echo time may lead to discomfort and/or skin burns. As such, when running at the
+echo default operating frequency, the device is intended to safely operate at an
+echo ambient temperature of 35C or less. Or when running at the maximum operating
+echo frequency, it should be operated at an ambient temperature of 25C or less.
+echo.
+echo Google does not accept any responsibility for any loss or damage if the device
+echo is operated outside of the recommended ambient temperature range.
+echo ................................................................................
+set /p USE_MAX_FREQ="Would you like to enable the maximum operating frequency for the USB Accelerator? Y/N "
+if "%USE_MAX_FREQ%" == "y" set FREQ_DIR=direct
+if "%USE_MAX_FREQ%" == "Y" set FREQ_DIR=direct
+if not defined FREQ_DIR set FREQ_DIR=throttled
+
+echo Installing UsbDk
+start /wait msiexec /i "%ROOTDIR%\third_party\usbdk\UsbDk_1.0.21_x64.msi" /quiet /qb! /norestart
+
+echo Installing Windows drivers
+pnputil /add-driver "%ROOTDIR%\third_party\coral_accelerator_windows\*.inf" /install
+
+echo Installing performance counters
+lodctr /M:"%ROOTDIR%\third_party\coral_accelerator_windows\coral.man"
+
+echo Copying edgetpu and libusb to System32
+copy "%ROOTDIR%\libedgetpu\%FREQ_DIR%\x64_windows\edgetpu.dll" c:\windows\system32
+copy "%ROOTDIR%\third_party\libusb_win\libusb-1.0.dll" c:\windows\system32
+
+echo Install complete!
+rem If %1 is elevated, this means we were re-invoked to gain Administrator.
+rem In this case, we're in a new window, so call pause to allow the user to view output.
+if "%1" == "elevated" pause
diff --git a/scripts/windows/uninstall.bat b/scripts/windows/uninstall.bat
new file mode 100644
index 0000000..d54f67d
--- /dev/null
+++ b/scripts/windows/uninstall.bat
@@ -0,0 +1,43 @@
+@echo off
+setlocal enabledelayedexpansion
+
+rem Check for Admin privileges
+fsutil dirty query %systemdrive% >NUL
+if not %ERRORLEVEL% == 0 (
+ powershell Start-Process -FilePath '%0' -ArgumentList "elevated" -verb runas
+ exit /b
+)
+
+if exist "%~dp0\libedgetpu" (
+ rem Running with the script in the root
+ set ROOTDIR=%~dp0
+) else (
+ rem Running with the script in scripts\windows
+ set ROOTDIR=%~dp0\..\..\
+)
+
+cd /d "%ROOTDIR%""
+set ROOTDIR=%CD%
+
+echo Deleting edgetpu and libusb from System32
+del c:\windows\system32\edgetpu.dll
+del c:\windows\system32\libusb-1.0.dll
+
+echo Uninstalling WinUSB drivers
+for /f "tokens=3" %%a in ('pnputil /enum-devices /class {88bae032-5a81-49f0-bc3d-a4ff138216d6} ^| findstr /b "Driver Name:"') do (
+ set infs=%%a !infs!
+)
+set infs=%infs:---=inf%
+echo %infs%
+for %%a in (%infs%) do (
+ echo %%a
+ pnputil /delete-driver %%a /uninstall
+)
+
+echo Uninstalling UsbDk
+start /wait msiexec /x "%ROOTDIR%\third_party\usbdk\UsbDk_1.0.21_x64.msi" /quiet /qb! /norestart
+
+echo Uninstall complete!
+rem If %1 is elevated, this means we were re-invoked to gain Administrator.
+rem In this case, we're in a new window, so call pause to allow the user to view output.
+if "%1" == "elevated" pause
diff --git a/scripts/windows/version.ps1 b/scripts/windows/version.ps1
new file mode 100644
index 0000000..23bdbc3
--- /dev/null
+++ b/scripts/windows/version.ps1
@@ -0,0 +1,33 @@
+param (
+ $File
+)
+
+function Get-ResourceString {
+ param (
+ $Key,
+ $File
+ )
+ $FileContent = ((Get-Content $File) -replace "`0", "")
+ $Lines = foreach ($Line in $FileContent) {
+ if ($Line -match "($Key)") {
+ $Line
+ }
+ }
+ $Lines = foreach ($Line in $Lines) {
+ if ($Line -match "($Key).*$") {
+ $Matches[0]
+ }
+ }
+ $Lines = $Lines -replace '[\W]', "`r`n"
+ $Line = foreach ($Line in $Lines) {
+ if ($Line -match "($Key).*\s") {
+ $Matches[0]
+ }
+ }
+ $Rest = $Line.Replace($Key, "")
+ $Output = "$Key`: $Rest".Trim()
+ Write-Output $Output
+}
+
+Get-ResourceString -File $File -Key 'CL_NUMBER'
+Get-ResourceString -File $File -Key 'TENSORFLOW_COMMIT'
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..9c65691
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,70 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import importlib.machinery
+
+from setuptools import setup, find_packages
+from setuptools.dist import Distribution
+
+def read(filename):
+ path = os.path.join(os.path.abspath(os.path.dirname(__file__)), filename)
+ with open(path , 'r') as f:
+ return f.read()
+
+def find_version(text):
+ match = re.search(r"^__version__\s*=\s*['\"](.*)['\"]\s*$", text,
+ re.MULTILINE)
+ return match.group(1)
+
+setup(
+ name='pycoral',
+ description='Coral Python API',
+ long_description=read('README.md'),
+ license='Apache 2',
+ version=find_version(read('pycoral/__init__.py')),
+ author='Coral',
+ author_email='coral-support@google.com',
+ url='https://github.com/google-coral/pycoral',
+ classifiers=[
+ 'Development Status :: 5 - Production/Stable',
+ 'Intended Audience :: Developers',
+ 'Intended Audience :: Education',
+ 'Intended Audience :: Science/Research',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Topic :: Scientific/Engineering',
+ 'Topic :: Scientific/Engineering :: Mathematics',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Software Development',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
+ ],
+ packages=find_packages(),
+ package_data={
+ '': [os.environ.get('WRAPPER_NAME', '*' + importlib.machinery.EXTENSION_SUFFIXES[-1])]
+ },
+ install_requires=[
+ 'numpy>=1.12.1',
+ 'Pillow>=4.0.0',
+ 'tflite-runtime==2.5.0',
+ ],
+ **({'has_ext_modules': lambda: True} if 'WRAPPER_NAME' in os.environ else {}),
+ python_requires='>=3.5.2',
+ )
diff --git a/src/BUILD b/src/BUILD
new file mode 100644
index 0000000..a67c96a
--- /dev/null
+++ b/src/BUILD
@@ -0,0 +1,87 @@
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+config_setting(
+ name = "windows",
+ values = {
+ "cpu": "x64_windows",
+ },
+)
+
+cc_library(
+ name = "builddata",
+ linkstamp = select({
+ ":windows": None, # Linkstamp doesn't work on Windows, https://github.com/bazelbuild/bazel/issues/6997
+ "//conditions:default": "builddata.cc",
+ }),
+)
+
+genrule(
+ name = "rc_tpl",
+ srcs = [
+ "builddata.cc",
+ "edgetpu.rc.tpl",
+ "edgetpu_rc.ps1",
+ ],
+ outs = [
+ "edgetpu.rc",
+ ],
+ cmd_ps = "$(location edgetpu_rc.ps1) " +
+ "-BuildDataFile $(location builddata.cc) " +
+ "-BuildStatus bazel-out\\stable-status.txt " +
+ "-ResFileTemplate $(location edgetpu.rc.tpl) " +
+ "-OutputFile $(location edgetpu.rc)",
+ stamp = 1,
+)
+
+genrule(
+ name = "dll_res_gen",
+ srcs = [
+ "edgetpu.rc",
+ ],
+ outs = [
+ "edgetpu.res",
+ ],
+ cmd_bat = "rc.exe /nologo /fo $(location edgetpu.res) $(location edgetpu.rc)",
+)
+
+pybind_extension(
+ name = "_pywrap_coral",
+ srcs = [
+ "coral_wrapper.cc",
+ ],
+ hdrs = [],
+ data = select({
+ ":windows": [":edgetpu.res"],
+ "//conditions:default": [],
+ }),
+ linkopts = select({
+ ":windows": ["$(location :edgetpu.res)"],
+ "//conditions:default": [],
+ }),
+ module_name = "_pywrap_coral",
+ deps = [
+ ":builddata",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ "@libcoral//coral:bbox",
+ "@libcoral//coral:tflite_utils",
+ "@libcoral//coral/learn:imprinting_engine",
+ "@libcoral//coral/learn:utils",
+ "@libcoral//coral/learn/backprop:softmax_regression_model",
+ "@libcoral//coral/pipeline:allocator",
+ "@libcoral//coral/pipeline:common",
+ "@libcoral//coral/pipeline:pipelined_model_runner",
+ "@libedgetpu//tflite/public:edgetpu",
+ "@org_tensorflow//tensorflow/lite:stateful_error_reporter",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@pybind11",
+ "@python",
+ ],
+)
diff --git a/src/builddata.cc b/src/builddata.cc
new file mode 100644
index 0000000..460109b
--- /dev/null
+++ b/src/builddata.cc
@@ -0,0 +1,21 @@
+extern "C" const char kPythonWrapperBuildEmbedLabel[];
+const char kPythonWrapperBuildEmbedLabel[] = BUILD_EMBED_LABEL;
+
+extern "C" const char kPythonWrapperBaseChangeList[];
+const char kPythonWrapperBaseChangeList[] = "CL_NUMBER=340495397";
+
+namespace {
+// Build a type whose constructor will contain references to all the build data
+// variables, preventing them from being GC'ed by the linker.
+struct KeepBuildDataVariables {
+ KeepBuildDataVariables() {
+ volatile int opaque_flag = 0;
+ if (!opaque_flag) return;
+
+ const void* volatile capture;
+ capture = &kPythonWrapperBuildEmbedLabel;
+ capture = &kPythonWrapperBaseChangeList;
+    static_cast<void>(capture);
+ }
+} dummy;
+} // namespace
diff --git a/src/coral_wrapper.cc b/src/coral_wrapper.cc
new file mode 100644
index 0000000..5e1cb9a
--- /dev/null
+++ b/src/coral_wrapper.cc
@@ -0,0 +1,425 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "absl/types/span.h"
+#include "coral/bbox.h"
+#include "coral/learn/backprop/softmax_regression_model.h"
+#include "coral/learn/imprinting_engine.h"
+#include "coral/learn/utils.h"
+#include "coral/pipeline/allocator.h"
+#include "coral/pipeline/common.h"
+#include "coral/pipeline/pipelined_model_runner.h"
+#include "coral/tflite_utils.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/stateful_error_reporter.h"
+#include "tflite/public/edgetpu.h"
+
+namespace {
+namespace py = pybind11;
+
+template <int size, bool is_signed>
+struct NumPyTypeImpl;
+
+template <>
+struct NumPyTypeImpl<4, true> {
+ enum { type = NPY_INT32 };
+};
+
+template <>
+struct NumPyTypeImpl<4, false> {
+ enum { type = NPY_UINT32 };
+};
+
+template <>
+struct NumPyTypeImpl<8, true> {
+ enum { type = NPY_INT64 };
+};
+
+template <>
+struct NumPyTypeImpl<8, false> {
+ enum { type = NPY_UINT64 };
+};
+
+template <typename T>
+struct NumPyType {
+  enum { type = NumPyTypeImpl<sizeof(T), std::is_signed<T>::value>::type };
+};
+
+template
+PyObject* PyArrayFromSpan(absl::Span span) {
+ npy_intp size = span.size();
+ void* pydata = malloc(size * sizeof(T));
+ std::memcpy(pydata, span.data(), size * sizeof(T));
+
+ PyObject* obj = PyArray_SimpleNewFromData(
+ 1, &size, NumPyType::type>::type, pydata);
+ PyArray_ENABLEFLAGS(reinterpret_cast(obj), NPY_ARRAY_OWNDATA);
+ return obj;
+}
+
+py::object Pyo(PyObject* ptr) { return py::reinterpret_steal(ptr); }
+
+using Strides = Eigen::Stride;
+using Scalar = Eigen::MatrixXf::Scalar;
+constexpr bool kRowMajor = Eigen::MatrixXf::Flags & Eigen::RowMajorBit;
+
+Eigen::MatrixXf TensorFromPyBuf(const py::buffer& b) {
+ py::buffer_info info = b.request();
+ if (info.format != py::format_descriptor::format())
+ throw std::runtime_error("Incompatible format: expected a float array!");
+ if (info.ndim != 2)
+ throw std::runtime_error("Incompatible buffer dimension!");
+ auto strides = Strides(info.strides[kRowMajor ? 0 : 1] / sizeof(Scalar),
+ info.strides[kRowMajor ? 1 : 0] / sizeof(Scalar));
+ auto map = Eigen::Map(
+ static_cast(info.ptr), info.shape[0], info.shape[1], strides);
+ return Eigen::MatrixXf(map);
+}
+
+template
+absl::Span BufferInfoSpan(const py::buffer_info& info) {
+ return absl::MakeSpan(reinterpret_cast(info.ptr), info.size);
+}
+
+std::unique_ptr LoadModel(
+ const std::string& model_path) {
+ auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
+ if (!model) throw std::invalid_argument("Failed to open file: " + model_path);
+ return model;
+}
+
+template
+py::bytes SerializeModel(T& engine) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto status = engine.SerializeModel(&fbb);
+ if (!status.ok()) throw std::runtime_error(std::string(status.message()));
+ return py::bytes(reinterpret_cast(fbb.GetBufferPointer()),
+ fbb.GetSize());
+}
+
+std::string GetRuntimeVersion() {
+ return ::edgetpu::EdgeTpuManager::GetSingleton()->Version();
+}
+
+TfLiteType NumpyDtypeToTfLiteType(const std::string& format) {
+ static std::unordered_map* type_map =
+ new std::unordered_map{
+ {py::format_descriptor::format(), kTfLiteFloat32},
+ {py::format_descriptor::format(), kTfLiteInt32},
+ {py::format_descriptor::format(), kTfLiteUInt8},
+ {py::format_descriptor::format(), kTfLiteInt64},
+ {py::format_descriptor::format(), kTfLiteInt16},
+ {py::format_descriptor::format(), kTfLiteInt8},
+ {py::format_descriptor::format(), kTfLiteFloat64},
+ };
+ const auto it = type_map->find(format);
+ if (it == type_map->end()) {
+ throw std::runtime_error("Unexpected numpy dtype: " + format);
+ } else {
+ return it->second;
+ }
+}
+
+py::dtype TfLiteTypeToNumpyDtype(const TfLiteType& type) {
+ // std::hash is added here because of a defect in std::unordered_map API,
+ // which is fixed in C++14 and newer version of libstdc++.
+ // https://stackoverflow.com/a/29618545
+ static std::unordered_map>* type_map =
+ new std::unordered_map>{
+ {kTfLiteFloat32, py::format_descriptor::format()},
+ {kTfLiteInt32, py::format_descriptor::format()},
+ {kTfLiteUInt8, py::format_descriptor::format()},
+ {kTfLiteInt64, py::format_descriptor::format()},
+ {kTfLiteInt16, py::format_descriptor::format()},
+ {kTfLiteInt8, py::format_descriptor::format()},
+ {kTfLiteFloat64, py::format_descriptor::format()},
+ };
+ const auto it = type_map->find(type);
+ if (it == type_map->end()) {
+ throw std::runtime_error("Unexpected TfLiteType: " +
+ std::string(TfLiteTypeGetName(type)));
+ } else {
+ return py::dtype(it->second);
+ }
+}
+
+class MallocBuffer : public coral::Buffer {
+ public:
+ explicit MallocBuffer(void* ptr) : ptr_(ptr) {}
+
+ void* ptr() override { return ptr_; }
+
+ private:
+ void* ptr_ = nullptr;
+};
+
+// Allocator with leaky `free` function. Caller should use std::free() to free
+// the underlying memory allocated by std::malloc; otherwise there will be
+// memory leaks.
+class LeakyMallocAllocator : public coral::Allocator {
+ public:
+ LeakyMallocAllocator() = default;
+
+ coral::Buffer* Alloc(size_t size) override {
+ return new MallocBuffer(std::malloc(size));
+ }
+
+ void Free(coral::Buffer* buffer) override {
+ // Note: the memory allocated by std::malloc is not freed here.
+ delete buffer;
+ }
+};
+
+} // namespace
+
+PYBIND11_MODULE(_pywrap_coral, m) {
+ // This function must be called in the initialization section of a module that
+ // will make use of the C-API (PyArray_SimpleNewFromData).
+ // It imports the module where the function-pointer table is stored and points
+ // the correct variable to it.
+ // Different with import_array() import_array1() has return value.
+ // https://docs.scipy.org/doc/numpy-1.14.2/reference/c-api.array.html
+ import_array1();
+
+ m.def("InvokeWithMemBuffer",
+ [](py::object interpreter_handle, intptr_t buffer, size_t size) {
+ auto* interpreter = reinterpret_cast(
+ interpreter_handle.cast());
+ auto status = coral::InvokeWithMemBuffer(
+ interpreter, reinterpret_cast(buffer), size,
+ static_cast(
+ interpreter->error_reporter()));
+ if (!status.ok())
+ throw std::runtime_error(std::string(status.message()));
+ });
+
+ m.def("InvokeWithBytes",
+ [](py::object interpreter_handle, py::bytes input_data) {
+ auto* interpreter = reinterpret_cast(
+ interpreter_handle.cast());
+ char* buffer;
+ ssize_t length;
+ PyBytes_AsStringAndSize(input_data.ptr(), &buffer, &length);
+ auto status = coral::InvokeWithMemBuffer(
+ interpreter, buffer, static_cast(length),
+ static_cast(
+ interpreter->error_reporter()));
+ if (!status.ok())
+ throw std::runtime_error(std::string(status.message()));
+ });
+
+ m.def("InvokeWithDmaBuffer",
+ [](py::object interpreter_handle, int dma_fd, size_t size) {
+ auto* interpreter = reinterpret_cast(
+ interpreter_handle.cast());
+ auto status = coral::InvokeWithDmaBuffer(
+ interpreter, dma_fd, size,
+ static_cast(
+ interpreter->error_reporter()));
+ if (!status.ok())
+ throw std::runtime_error(std::string(status.message()));
+ });
+
+ m.def("SupportsDmabuf", [](py::object interpreter_handle) {
+ auto* interpreter = reinterpret_cast(
+ interpreter_handle.cast());
+ auto* context = interpreter->primary_subgraph().context();
+ auto* edgetpu_context = static_cast(
+ context->GetExternalContext(context, kTfLiteEdgeTpuContext));
+ if (!edgetpu_context) return false;
+ auto device = edgetpu_context->GetDeviceEnumRecord();
+ return device.type == edgetpu::DeviceType::kApexPci;
+ });
+
+ m.def("GetRuntimeVersion", &GetRuntimeVersion,
+ R"pbdoc(
+ Returns the Edge TPU runtime (libedgetpu.so) version.
+
+ This runtime version is dynamically retrieved from the shared object.
+
+ Returns:
+ A string for the version name.
+ )pbdoc");
+
+ m.def(
+ "ListEdgeTpus",
+ []() {
+ py::list device_list;
+ for (const auto& item :
+ edgetpu::EdgeTpuManager::GetSingleton()->EnumerateEdgeTpu()) {
+ py::dict device;
+ device["type"] =
+ item.type == edgetpu::DeviceType::kApexPci ? "pci" : "usb";
+ device["path"] = item.path;
+ device_list.append(device);
+ }
+ return device_list;
+ },
+ R"pbdoc(
+ Lists all available Edge TPU devices.
+
+ Returns:
+ A list of dictionary, each representing a device record of type and path.
+ )pbdoc");
+
+ py::class_(m, "ImprintingEnginePythonWrapper")
+ .def(py::init([](const std::string& model_path, bool keep_classes) {
+ std::unique_ptr model;
+ auto status = coral::ImprintingModel::Create(
+ *LoadModel(model_path)->GetModel(), &model);
+ if (!status.ok())
+ throw std::invalid_argument(std::string(status.message()));
+ return coral::ImprintingEngine::Create(std::move(model), keep_classes);
+ }))
+ .def("EmbeddingDim",
+ [](coral::ImprintingEngine& self) { return self.embedding_dim(); })
+ .def("NumClasses",
+ [](coral::ImprintingEngine& self) {
+ return self.GetClasses().size();
+ })
+ .def("SerializeExtractorModel",
+ [](coral::ImprintingEngine& self) {
+ auto buffer = self.ExtractorModelBuffer();
+ return py::bytes(buffer.data(), buffer.size());
+ })
+ .def("SerializeModel",
+ [](coral::ImprintingEngine& self) { return SerializeModel(self); })
+ .def("Train", [](coral::ImprintingEngine& self,
+ py::array_t weights_array, int class_id) {
+ auto request = weights_array.request();
+ if (request.shape != std::vector{self.embedding_dim()})
+ throw std::runtime_error("Invalid weights array shape.");
+
+ const auto* weights = reinterpret_cast(request.ptr);
+ auto status =
+ self.Train(absl::MakeSpan(weights, self.embedding_dim()), class_id);
+ if (!status.ok())
+ throw std::runtime_error(std::string(status.message()));
+ });
+ py::class_(m, "TrainConfigWrapper")
+ .def(py::init());
+ py::class_(m, "TrainingDataWrapper")
+ .def(py::init<>([](const py::buffer& training_data,
+ const py::buffer& validation_data,
+ const std::vector& training_labels,
+ const std::vector& validation_labels) {
+ auto self = absl::make_unique();
+ self->training_data = TensorFromPyBuf(training_data);
+ self->validation_data = TensorFromPyBuf(validation_data);
+ self->training_labels = training_labels;
+ self->validation_labels = validation_labels;
+ return self;
+ }));
+ py::class_(m, "SoftmaxRegressionModelWrapper")
+ .def(py::init())
+ .def("Train",
+ [](coral::SoftmaxRegressionModel& self,
+ const coral::TrainingData& training_data,
+ const coral::TrainConfig& train_config, float learning_rate) {
+ return self.Train(training_data, train_config, learning_rate);
+ })
+ .def("GetAccuracy",
+ [](coral::SoftmaxRegressionModel& self,
+ const py::buffer& training_data,
+ const std::vector& training_labels) {
+ return self.GetAccuracy(TensorFromPyBuf(training_data),
+ training_labels);
+ })
+ .def("AppendLayersToEmbeddingExtractor",
+ [](coral::SoftmaxRegressionModel& self,
+ const std::string& in_model_path) {
+ flatbuffers::FlatBufferBuilder fbb;
+ self.AppendLayersToEmbeddingExtractor(
+ *LoadModel(in_model_path)->GetModel(), &fbb);
+ return py::bytes(reinterpret_cast(fbb.GetBufferPointer()),
+ fbb.GetSize());
+ });
+
+ py::class_(m, "PipelinedModelRunnerWrapper")
+ .def(py::init([](const py::list& list) {
+ static coral::Allocator* output_tensor_allocator =
+ new LeakyMallocAllocator();
+ std::vector interpreters(list.size());
+ for (int i = 0; i < list.size(); ++i) {
+ interpreters[i] =
+ reinterpret_cast(list[i].cast());
+ }
+ return absl::make_unique(
+ interpreters, /*input_tensor_allocator=*/nullptr,
+ output_tensor_allocator);
+ }))
+ .def("SetInputQueueSize", &coral::PipelinedModelRunner::SetInputQueueSize)
+ .def("SetOutputQueueSize",
+ &coral::PipelinedModelRunner::SetOutputQueueSize)
+ .def("Push",
+ [](coral::PipelinedModelRunner& self, py::list& list) -> bool {
+ std::vector input_tensors(list.size());
+ for (int i = 0; i < list.size(); ++i) {
+ const auto info = list[i].cast().request();
+ input_tensors[i].type = NumpyDtypeToTfLiteType(info.format);
+ input_tensors[i].bytes = info.size * info.itemsize;
+ input_tensors[i].buffer = self.GetInputTensorAllocator()->Alloc(
+ input_tensors[i].bytes);
+ std::memcpy(input_tensors[i].buffer->ptr(), info.ptr,
+ input_tensors[i].bytes);
+ }
+ // Release GIL because Push can be blocking (if input queue size is
+ // bigger than input queue size threshold).
+ py::gil_scoped_release release;
+ auto push_status = self.Push(input_tensors);
+ py::gil_scoped_acquire acquire;
+ return push_status;
+ })
+ .def("Pop", [](coral::PipelinedModelRunner& self) -> py::object {
+ std::vector output_tensors;
+
+ // Release GIL because Pop is blocking.
+ py::gil_scoped_release release;
+ self.Pop(&output_tensors);
+ py::gil_scoped_acquire acquire;
+
+ if (output_tensors.empty()) {
+ return py::none();
+ }
+
+ py::list result;
+ for (auto tensor : output_tensors) {
+ // Underlying memory's ownership is passed to numpy object.
+ py::capsule free_when_done(tensor.buffer->ptr(),
+ [](void* ptr) { std::free(ptr); });
+ result.append(py::array(TfLiteTypeToNumpyDtype(tensor.type),
+ /*shape=*/{tensor.bytes},
+ /*strides=*/{1}, tensor.buffer->ptr(),
+ free_when_done));
+ self.GetOutputTensorAllocator()->Free(tensor.buffer);
+ }
+ return result;
+ });
+}
diff --git a/src/edgetpu.rc.tpl b/src/edgetpu.rc.tpl
new file mode 100644
index 0000000..17e7c98
--- /dev/null
+++ b/src/edgetpu.rc.tpl
@@ -0,0 +1,36 @@
+// Redefine some constants we would usually get from winver.h
+// Bazel doesn't know the correct include paths to pass along
+// to rc to pick up the header.
+#define VS_VERSION_INFO 1
+#define VS_FFI_FILEFLAGSMASK 0x3FL
+#define VS_FF_DEBUG 0x1L
+#define VOS__WINDOWS32 0x4L
+#define VFT_DLL 0x2L
+
+#define CL_NUMBER_STR "CL_NUMBER_TEMPLATE\040"
+#define TENSORFLOW_COMMIT_STR "TENSORFLOW_COMMIT_TEMPLATE\040"
+
+VS_VERSION_INFO VERSIONINFO
+FILEFLAGSMASK VS_FFI_FILEFLAGSMASK
+FILEFLAGS 0
+FILEOS VOS__WINDOWS32
+FILETYPE VFT_DLL
+FILESUBTYPE 0
+BEGIN
+ BLOCK "StringFileInfo"
+ BEGIN
+ BLOCK "040904E4"
+ BEGIN
+ VALUE "FileDescription", "EdgeTPU Python library\0"
+ VALUE "InternalName", "_pywrap_coral.pyd\0"
+ VALUE "LegalCopyright", "(C) 2019-2020 Google, LLC\0"
+ VALUE "ProductName", "edgetpu\0"
+ VALUE "CL_NUMBER", CL_NUMBER_STR
+ VALUE "TENSORFLOW_COMMIT", TENSORFLOW_COMMIT_STR
+ END
+ END
+ BLOCK "VarFileInfo"
+ BEGIN
+ VALUE "Translation", 0x0409, 1252
+ END
+END
diff --git a/src/edgetpu_rc.ps1 b/src/edgetpu_rc.ps1
new file mode 100644
index 0000000..9f26702
--- /dev/null
+++ b/src/edgetpu_rc.ps1
@@ -0,0 +1,15 @@
+param (
+ $BuildStatusFile,
+ $BuildDataFile,
+ $ResFileTemplate,
+ $OutputFile
+)
+
+$BuildStatus = Get-Content $BuildStatusFile
+$BuildData = Get-Content $BuildDataFile
+$ResFile = Get-Content $ResFileTemplate
+$ClNumber = (((-split ($BuildData -match 'CL_NUMBER'))[-1]) -split '=')[-1].Trim("`";")
+$TensorflowCommit = (-split ($BuildStatus -match 'BUILD_EMBED_LABEL'))[-1]
+$ResFile = $ResFile.Replace('CL_NUMBER_TEMPLATE', $ClNumber)
+$ResFile = $ResFile.Replace('TENSORFLOW_COMMIT_TEMPLATE', $TensorflowCommit)
+Out-File -FilePath $OutputFile -InputObject $ResFile -Encoding unicode
diff --git a/test_data b/test_data
new file mode 160000
index 0000000..c21de44
--- /dev/null
+++ b/test_data
@@ -0,0 +1 @@
+Subproject commit c21de4450f88a20ac5968628d375787745932a5a
diff --git a/tests/classify_test.py b/tests/classify_test.py
new file mode 100644
index 0000000..460de7d
--- /dev/null
+++ b/tests/classify_test.py
@@ -0,0 +1,124 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from PIL import Image
+import unittest
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.utils.edgetpu import make_interpreter
+from tests.test_utils import coral_test_main
+from tests.test_utils import test_data_path
+
+CHICKADEE = 20
+TABBY_CAT = 282
+TIGER_CAT = 283
+EGYPTIAN_CAT = 286
+
+EFFICIENTNET_IMAGE_QUANTIZATION = (1 / 128, 127)
+
+
+def test_image(image_file, size):
+ return Image.open(test_data_path(image_file)).resize(size, Image.NEAREST)
+
+
+def rescale_image(image, image_quantization, tensor_quatization, tensor_dtype):
+ scale0, zero_point0 = image_quantization
+ scale, zero_point = tensor_quatization
+
+ min_value = np.iinfo(tensor_dtype).min
+ max_value = np.iinfo(tensor_dtype).max
+
+ def rescale(x):
+ # The following is the same as y = (x - a) / b, where
+ # b = scale / scale0 and a = zero_point0 - b * zero_point.
+ y = int(zero_point + (scale0 * (x - zero_point0)) / scale)
+ return max(min_value, min(y, max_value))
+
+ rescale = np.vectorize(rescale, otypes=[tensor_dtype])
+ return rescale(image)
+
+
+def classify_image(model_file, image_file, image_quantization=None):
+ """Runs image classification and returns result with the highest score.
+
+ Args:
+ model_file: string, model file name.
+ image_file: string, image file name.
+ image_quantization: (scale: float, zero_point: float), assumed image
+ quantization parameters.
+
+ Returns:
+ Classification result with the highest score as (index, score) tuple.
+ """
+ interpreter = make_interpreter(test_data_path(model_file))
+ interpreter.allocate_tensors()
+ image = test_image(image_file, common.input_size(interpreter))
+
+ input_type = common.input_details(interpreter, 'dtype')
+ if np.issubdtype(input_type, np.floating):
+ # This preprocessing is specific to MobileNet V1 with floating point input.
+ image = (input_type(image) - 127.5) / 127.5
+
+ if np.issubdtype(input_type, np.integer) and image_quantization:
+ image = rescale_image(image, image_quantization,
+ common.input_details(interpreter, 'quantization'),
+ input_type)
+
+ common.set_input(interpreter, image)
+ interpreter.invoke()
+ return classify.get_classes(interpreter)[0]
+
+
+def mobilenet_v1(depth_multiplier, input_size):
+ return 'mobilenet_v1_%s_%d_quant_edgetpu.tflite' % (depth_multiplier,
+ input_size)
+
+
+def mobilenet_v1_float_io(depth_multiplier, input_size):
+ return 'mobilenet_v1_%s_%d_ptq_float_io_legacy_edgetpu.tflite' % (
+ depth_multiplier, input_size)
+
+
+def efficientnet(input_type):
+ return 'efficientnet-edgetpu-%s_quant_edgetpu.tflite' % input_type
+
+
+class TestClassify(unittest.TestCase):
+
+ def test_mobilenet_v1_100_224(self):
+ index, score = classify_image(mobilenet_v1(1.0, 224), 'cat.bmp')
+ self.assertEqual(index, EGYPTIAN_CAT)
+ self.assertGreater(score, 0.78)
+
+ def test_mobilenet_v1_050_160(self):
+ index, score = classify_image(mobilenet_v1(0.5, 160), 'cat.bmp')
+ self.assertEqual(index, EGYPTIAN_CAT)
+ self.assertGreater(score, 0.67)
+
+ def test_mobilenet_v1_float_224(self):
+ index, score = classify_image(mobilenet_v1_float_io(1.0, 224), 'cat.bmp')
+ self.assertEqual(index, EGYPTIAN_CAT)
+ self.assertGreater(score, 0.7)
+
+ def test_efficientnet_l(self):
+ index, score = classify_image(
+ efficientnet('L'), 'cat.bmp', EFFICIENTNET_IMAGE_QUANTIZATION)
+ self.assertEqual(index, EGYPTIAN_CAT)
+ self.assertGreater(score, 0.65)
+
+
+if __name__ == '__main__':
+ coral_test_main()
diff --git a/tests/detect_test.py b/tests/detect_test.py
new file mode 100644
index 0000000..05b4317
--- /dev/null
+++ b/tests/detect_test.py
@@ -0,0 +1,143 @@
+# Lint as: python3
+# pylint:disable=g-generic-assert
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from PIL import Image
+
+import unittest
+from pycoral.adapters import common
+from pycoral.adapters import detect
+from pycoral.utils.edgetpu import make_interpreter
+from tests.test_utils import coral_test_main
+from tests.test_utils import test_data_path
+
+BBox = detect.BBox
+
+CAT = 16 # coco_labels.txt
+ABYSSINIAN = 0 # pet_labels.txt
+
+
+def get_objects(model_file, image_file, score_threshold=0.0):
+ interpreter = make_interpreter(test_data_path(model_file))
+ interpreter.allocate_tensors()
+ image = Image.open(test_data_path(image_file))
+ _, scale = common.set_resized_input(
+ interpreter, image.size, lambda size: image.resize(size, Image.ANTIALIAS))
+ interpreter.invoke()
+ return detect.get_objects(
+ interpreter, score_threshold=score_threshold, image_scale=scale)
+
+
+def face_model():
+ return 'ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite'
+
+
+def coco_model(version):
+ return 'ssd_mobilenet_v%d_coco_quant_postprocess_edgetpu.tflite' % version
+
+
+def fine_tuned_model():
+ return 'ssd_mobilenet_v1_fine_tuned_pet_edgetpu.tflite'
+
+
+class BBoxTest(unittest.TestCase):
+
+ def test_basic(self):
+ bbox = BBox(100, 110, 200, 210)
+ self.assertEqual(bbox.xmin, 100)
+ self.assertEqual(bbox.ymin, 110)
+ self.assertEqual(bbox.xmax, 200)
+ self.assertEqual(bbox.ymax, 210)
+
+ self.assertTrue(bbox.valid)
+
+ self.assertEqual(bbox.width, 100)
+ self.assertEqual(bbox.height, 100)
+
+ self.assertEqual(bbox.area, 10000)
+
+ def test_scale(self):
+ self.assertEqual(BBox(1, 1, 10, 20).scale(3, 4), BBox(3, 4, 30, 80))
+
+ def test_translate(self):
+ self.assertEqual(BBox(1, 1, 10, 20).translate(10, 20), BBox(11, 21, 20, 40))
+
+ def test_map(self):
+ self.assertEqual(BBox(1.1, 2.1, 3.1, 4.1).map(int), BBox(1, 2, 3, 4))
+ self.assertEqual(BBox(1.9, 2.9, 3.9, 4.9).map(int), BBox(1, 2, 3, 4))
+
+ def test_intersect_valid(self):
+ a = BBox(0, 0, 200, 200)
+ b = BBox(100, 100, 300, 300)
+
+ self.assertAlmostEqual(BBox.iou(a, b), 0.14286, delta=0.0001)
+ self.assertEqual(BBox.intersect(a, b), BBox(100, 100, 200, 200))
+
+ def test_intersect_invalid(self):
+ a = BBox(0, 0, 10, 20)
+ b = BBox(20, 30, 25, 35)
+ self.assertAlmostEqual(BBox.iou(a, b), 0.0)
+ self.assertEqual(BBox.intersect(a, b), BBox(20, 30, 10, 20))
+
+ def test_union(self):
+ self.assertEqual(
+ BBox.union(BBox(0, 0, 10, 20), BBox(50, 50, 60, 70)),
+ BBox(0, 0, 60, 70))
+
+
+class DetectTest(unittest.TestCase):
+
+ def assert_bbox_almost_equal(self, first, second, overlap_factor=0.95):
+ self.assertGreaterEqual(
+ BBox.iou(first, second),
+ overlap_factor,
+ msg='iou(%s, %s) is less than expected' % (first, second))
+
+ def test_face(self):
+ objs = get_objects(face_model(), 'grace_hopper.bmp')
+ self.assertEqual(len(objs), 1)
+ self.assertGreater(objs[0].score, 0.996)
+ self.assert_bbox_almost_equal(objs[0].bbox,
+ BBox(xmin=125, ymin=40, xmax=402, ymax=363))
+
+ def test_coco_v1(self):
+ objs = get_objects(coco_model(version=1), 'cat.bmp')
+ self.assertGreater(len(objs), 0)
+ obj = objs[0]
+ self.assertEqual(obj.id, CAT)
+ self.assertGreater(obj.score, 0.7)
+ self.assert_bbox_almost_equal(obj.bbox,
+ BBox(xmin=29, ymin=39, xmax=377, ymax=347))
+
+ def test_coco_v2(self):
+ objs = get_objects(coco_model(version=2), 'cat.bmp')
+ self.assertGreater(len(objs), 0)
+ obj = objs[0]
+ self.assertEqual(obj.id, CAT)
+ self.assertGreater(obj.score, 0.9)
+ self.assert_bbox_almost_equal(obj.bbox,
+ BBox(xmin=43, ymin=35, xmax=358, ymax=333))
+
+ def test_fine_tuned(self):
+ objs = get_objects(fine_tuned_model(), 'cat.bmp')
+ self.assertGreater(len(objs), 0)
+ obj = objs[0]
+ self.assertEqual(obj.id, ABYSSINIAN)
+ self.assertGreater(obj.score, 0.88)
+ self.assert_bbox_almost_equal(obj.bbox,
+ BBox(xmin=177, ymin=37, xmax=344, ymax=216))
+
+
+if __name__ == '__main__':
+ coral_test_main()
diff --git a/tests/edgetpu_utils_test.py b/tests/edgetpu_utils_test.py
new file mode 100644
index 0000000..0bc75e2
--- /dev/null
+++ b/tests/edgetpu_utils_test.py
@@ -0,0 +1,188 @@
+# Lint as: python3
+# pylint:disable=g-generic-assert
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import ctypes
+import ctypes.util
+import io
+
+import numpy as np
+
+from pycoral.utils import edgetpu
+from tests import test_utils
+import unittest
+
+
+# Detect whether GStreamer is available.
+# This code section is copied from utils/edgetpu.py.
+class _GstMapInfo(ctypes.Structure):
+ _fields_ = [
+ ('memory', ctypes.c_void_p), # GstMemory *memory
+ ('flags', ctypes.c_int), # GstMapFlags flags
+ ('data', ctypes.c_void_p), # guint8 *data
+ ('size', ctypes.c_size_t), # gsize size
+ ('maxsize', ctypes.c_size_t), # gsize maxsize
+ ('user_data', ctypes.c_void_p * 4), # gpointer user_data[4]
+ ('_gst_reserved', ctypes.c_void_p * 4)
+ ] # GST_PADDING
+
+
+_libgst = None
+try:
+ # pylint:disable=g-import-not-at-top
+ import gi
+ gi.require_version('Gst', '1.0')
+ from gi.repository import Gst
+ _libgst = ctypes.CDLL(ctypes.util.find_library('gstreamer-1.0'))
+ _libgst.gst_buffer_map.argtypes = [
+ ctypes.c_void_p,
+ ctypes.POINTER(_GstMapInfo), ctypes.c_int
+ ]
+ _libgst.gst_buffer_map.restype = ctypes.c_int
+ _libgst.gst_buffer_unmap.argtypes = [
+ ctypes.c_void_p, ctypes.POINTER(_GstMapInfo)
+ ]
+ _libgst.gst_buffer_unmap.restype = None
+ Gst.init(None)
+except (ImportError, ValueError, OSError):
+ pass
+
+
+def read_file(filename):
+ with open(filename, mode='rb') as f:
+ return f.read()
+
+
+def required_input_array_size(interpreter):
+ input_shape = interpreter.get_input_details()[0]['shape']
+ return np.prod(input_shape)
+
+
+# Use --config=asan for better coverage.
+class TestEdgeTpuUtils(unittest.TestCase):
+
+ def _default_test_model_path(self):
+ return test_utils.test_data_path(
+ 'mobilenet_v1_1.0_224_quant_edgetpu.tflite')
+
+ def test_load_from_model_file(self):
+ edgetpu.make_interpreter(self._default_test_model_path())
+
+ def test_load_from_model_content(self):
+ with io.open(self._default_test_model_path(), 'rb') as model_file:
+ edgetpu.make_interpreter(model_file.read())
+
+ def test_load_from_invalid_model_path(self):
+ with self.assertRaisesRegex(
+ ValueError, 'Could not open \'invalid_model_path.tflite\'.'):
+ edgetpu.make_interpreter('invalid_model_path.tflite')
+
+ def test_load_with_device(self):
+ edgetpu.make_interpreter(self._default_test_model_path(), device=':0')
+
+ def test_load_with_nonexistent_device(self):
+ # Assume that there can not be 1000 Edge TPU devices connected.
+ with self.assertRaisesRegex(ValueError, 'Failed to load delegate'):
+ edgetpu.make_interpreter(self._default_test_model_path(), device=':1000')
+
+ def test_load_with_invalid_device_str(self):
+ with self.assertRaisesRegex(ValueError, 'Failed to load delegate'):
+ edgetpu.make_interpreter(self._default_test_model_path(), device='foo')
+
+ def _run_inference_with_different_input_types(self, interpreter, input_data):
+ """Tests inference with different input types.
+
+ It doesn't check correctness of inference. Instead it checks inference
+ repeatability with different input types.
+
+ Args:
+ interpreter : A tflite interpreter.
+ input_data (list): A 1-D list as the input tensor.
+ """
+ output_index = interpreter.get_output_details()[0]['index']
+ # numpy array
+ np_input = np.asarray(input_data, np.uint8)
+ edgetpu.run_inference(interpreter, np_input)
+ ret = interpreter.tensor(output_index)()
+ ret0 = np.copy(ret)
+ self.assertTrue(np.array_equal(ret0, ret))
+ # bytes
+ bytes_input = bytes(input_data)
+ edgetpu.run_inference(interpreter, bytes_input)
+ ret = interpreter.tensor(output_index)()
+ self.assertTrue(np.array_equal(ret0, ret))
+ # ctypes
+ edgetpu.run_inference(
+ interpreter, (np_input.ctypes.data_as(ctypes.c_void_p), np_input.size))
+ ret = interpreter.tensor(output_index)()
+ self.assertTrue(np.array_equal(ret0, ret))
+ # Gst buffer
+ if _libgst:
+ gst_input = Gst.Buffer.new_wrapped(bytes_input)
+ edgetpu.run_inference(interpreter, gst_input)
+ self.assertTrue(np.array_equal(ret0, ret))
+ else:
+ print('Can not import gi. Skip test on Gst.Buffer input type.')
+
+ def test_run_inference_with_different_types(self):
+ interpreter = edgetpu.make_interpreter(self._default_test_model_path())
+ interpreter.allocate_tensors()
+ input_size = required_input_array_size(interpreter)
+ input_data = test_utils.generate_random_input(1, input_size)
+ self._run_inference_with_different_input_types(interpreter, input_data)
+
+ def test_run_inference_larger_input_size(self):
+ interpreter = edgetpu.make_interpreter(self._default_test_model_path())
+ interpreter.allocate_tensors()
+ input_size = required_input_array_size(interpreter)
+ input_data = test_utils.generate_random_input(1, input_size + 1)
+ with self.assertRaisesRegex(ValueError,
+ 'input size=150529, expected=150528'):
+ self._run_inference_with_different_input_types(interpreter, input_data)
+
+ def test_run_inference_smaller_input_size(self):
+ interpreter = edgetpu.make_interpreter(self._default_test_model_path())
+ interpreter.allocate_tensors()
+ input_size = required_input_array_size(interpreter)
+ input_data = test_utils.generate_random_input(1, input_size - 1)
+ with self.assertRaisesRegex(ValueError,
+ 'input size=150527, expected=150528'):
+ self._run_inference_with_different_input_types(interpreter, input_data)
+
+ def test_invoke_with_dma_buffer_model_not_ready(self):
+ interpreter = edgetpu.make_interpreter(self._default_test_model_path())
+ input_size = 224 * 224 * 3
+ # Note: Exception is triggered because interpreter.allocate_tensors() is not
+ # called.
+ with self.assertRaisesRegex(RuntimeError,
+ 'Invoke called on model that is not ready.'):
+ edgetpu.invoke_with_dmabuffer(interpreter._native_handle(), 0, input_size)
+
+ def test_invoke_with_mem_buffer_model_not_ready(self):
+ interpreter = edgetpu.make_interpreter(self._default_test_model_path())
+ input_size = 224 * 224 * 3
+ np_input = np.zeros(input_size, dtype=np.uint8)
+ # Note: Exception is triggered because interpreter.allocate_tensors() is not
+ # called.
+ with self.assertRaisesRegex(RuntimeError,
+ 'Invoke called on model that is not ready.'):
+ edgetpu.invoke_with_membuffer(interpreter._native_handle(),
+ np_input.ctypes.data, input_size)
+
+ def test_list_edge_tpu_paths(self):
+ self.assertGreater(len(edgetpu.list_edge_tpus()), 0)
+
+
+if __name__ == '__main__':
+ test_utils.coral_test_main()
diff --git a/tests/imprinting_engine_test.py b/tests/imprinting_engine_test.py
new file mode 100644
index 0000000..2c097a0
--- /dev/null
+++ b/tests/imprinting_engine_test.py
@@ -0,0 +1,174 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+from PIL import Image
+
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.learn.imprinting.engine import ImprintingEngine
+from pycoral.utils.edgetpu import make_interpreter
+from tests import test_utils
+import unittest
+
+_MODEL_LIST = [
+ 'mobilenet_v1_1.0_224_l2norm_quant.tflite',
+ 'mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite'
+]
+
+TrainPoint = collections.namedtuple('TrainPoint', ['images', 'class_id'])
+TestPoint = collections.namedtuple('TestPoint', ['image', 'class_id', 'score'])
+
+
+def set_input(interpreter, image):
+ size = common.input_size(interpreter)
+ common.set_input(interpreter, image.resize(size, Image.NEAREST))
+
+
+class TestImprintingEnginePythonAPI(unittest.TestCase):
+
+ def _train_and_test(self, model_path, train_points, test_points,
+ keep_classes):
+ # Train.
+ engine = ImprintingEngine(model_path, keep_classes)
+
+ extractor = make_interpreter(
+ engine.serialize_extractor_model(), device=':0')
+ extractor.allocate_tensors()
+
+ for point in train_points:
+ for image in point.images:
+ with test_utils.test_image('imprinting', image) as img:
+ set_input(extractor, img)
+ extractor.invoke()
+ embedding = classify.get_scores(extractor)
+ self.assertEqual(len(embedding), engine.embedding_dim)
+ engine.train(embedding, point.class_id)
+
+ # Test.
+ trained_model = engine.serialize_model()
+ classifier = make_interpreter(trained_model, device=':0')
+ classifier.allocate_tensors()
+
+ self.assertEqual(len(classifier.get_output_details()), 1)
+
+ if not keep_classes:
+ self.assertEqual(len(train_points), classify.num_classes(classifier))
+
+ for point in test_points:
+ with test_utils.test_image('imprinting', point.image) as img:
+ set_input(classifier, img)
+ classifier.invoke()
+ top = classify.get_classes(classifier, top_k=1)[0]
+ self.assertEqual(top.id, point.class_id)
+ self.assertGreater(top.score, point.score)
+
+ return trained_model
+
+ # Test full model, not keeping base model classes.
+ def test_training_l2_norm_model_not_keep_classes(self):
+ train_points = [
+ TrainPoint(images=['cat_train_0.bmp'], class_id=0),
+ TrainPoint(images=['dog_train_0.bmp'], class_id=1),
+ TrainPoint(
+ images=['hotdog_train_0.bmp', 'hotdog_train_1.bmp'], class_id=2),
+ ]
+ test_points = [
+ TestPoint(image='cat_test_0.bmp', class_id=0, score=0.99),
+ TestPoint(image='dog_test_0.bmp', class_id=1, score=0.99),
+ TestPoint(image='hotdog_test_0.bmp', class_id=2, score=0.99)
+ ]
+ for model_path in _MODEL_LIST:
+ with self.subTest(model_path=model_path):
+ self._train_and_test(
+ test_utils.test_data_path(model_path),
+ train_points,
+ test_points,
+ keep_classes=False)
+
+ # Test full model, keeping base model classes.
+ def test_training_l2_norm_model_keep_classes(self):
+ train_points = [
+ TrainPoint(images=['cat_train_0.bmp'], class_id=1001),
+ TrainPoint(images=['dog_train_0.bmp'], class_id=1002),
+ TrainPoint(
+ images=['hotdog_train_0.bmp', 'hotdog_train_1.bmp'], class_id=1003)
+ ]
+ test_points = [
+ TestPoint(image='cat_test_0.bmp', class_id=1001, score=0.99),
+ TestPoint(image='hotdog_test_0.bmp', class_id=1003, score=0.92)
+ ]
+ for model_path in _MODEL_LIST:
+ with self.subTest(model_path=model_path):
+ self._train_and_test(
+ test_utils.test_data_path(model_path),
+ train_points,
+ test_points,
+ keep_classes=True)
+
+ def test_incremental_training(self):
+ train_points = [TrainPoint(images=['cat_train_0.bmp'], class_id=0)]
+ retrain_points = [
+ TrainPoint(images=['dog_train_0.bmp'], class_id=1),
+ TrainPoint(
+ images=['hotdog_train_0.bmp', 'hotdog_train_1.bmp'], class_id=2)
+ ]
+ test_points = [
+ TestPoint(image='cat_test_0.bmp', class_id=0, score=0.99),
+ TestPoint(image='dog_test_0.bmp', class_id=1, score=0.99),
+ TestPoint(image='hotdog_test_0.bmp', class_id=2, score=0.99)
+ ]
+ for model_path in _MODEL_LIST:
+ with self.subTest(model_path=model_path):
+ model = self._train_and_test(
+ test_utils.test_data_path(model_path),
+ train_points, [],
+ keep_classes=False)
+
+ with test_utils.temporary_file(suffix='.tflite') as new_model_file:
+ new_model_file.write(model)
+ # Retrain based on cat only model.
+ self._train_and_test(
+ new_model_file.name,
+ retrain_points,
+ test_points,
+ keep_classes=True)
+
+ def test_imprinting_engine_saving_without_training(self):
+ model_list = [
+ 'mobilenet_v1_1.0_224_l2norm_quant.tflite',
+ 'mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite'
+ ]
+ for model in model_list:
+ engine = ImprintingEngine(
+ test_utils.test_data_path(model), keep_classes=False)
+ with self.assertRaisesRegex(RuntimeError, 'Model is not trained.'):
+ engine.serialize_model()
+
+ def test_imprinting_engine_invalid_model_path(self):
+ with self.assertRaisesRegex(
+ ValueError, 'Failed to open file: invalid_model_path.tflite'):
+ ImprintingEngine('invalid_model_path.tflite')
+
+ def test_imprinting_engine_load_extractor_with_wrong_format(self):
+ expected_message = ('Unsupported model architecture. Input model must have '
+ 'an L2Norm layer.')
+ with self.assertRaisesRegex(ValueError, expected_message):
+ ImprintingEngine(
+ test_utils.test_data_path('mobilenet_v1_1.0_224_quant.tflite'))
+
+
+if __name__ == '__main__':
+ test_utils.coral_test_main()
diff --git a/tests/imprinting_evaluation_test.py b/tests/imprinting_evaluation_test.py
new file mode 100644
index 0000000..9fe8932
--- /dev/null
+++ b/tests/imprinting_evaluation_test.py
@@ -0,0 +1,152 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluates the accuracy of imprinting based transfer learning model."""
+
+import contextlib
+import os
+from PIL import Image
+
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.learn.imprinting.engine import ImprintingEngine
+from pycoral.utils.edgetpu import make_interpreter
+from tests import test_utils
+import unittest
+
+
+@contextlib.contextmanager
+def test_image(path):
+ with open(path, 'rb') as f:
+ with Image.open(f) as image:
+ yield image
+
+
+class ImprintingEngineEvaluationTest(unittest.TestCase):
+
+ def _transfer_learn_and_evaluate(self, model_path, keep_classes, dataset_path,
+ test_ratio, top_k_range):
+ """Transfer-learns with given params and returns the evaluation result.
+
+ Args:
+ model_path: string, path of the base model.
+ keep_classes: bool, whether to keep base model classes.
+ dataset_path: string, path to the directory of dataset. The images should
+ be put under sub-directory named by category.
+ test_ratio: float, the ratio of images used for test.
+ top_k_range: int, top_k range to be evaluated. The function will return
+ accuracy from top 1 to top k.
+
+ Returns:
+ list of float numbers.
+ """
+ engine = ImprintingEngine(model_path, keep_classes)
+
+ extractor = make_interpreter(engine.serialize_extractor_model())
+ extractor.allocate_tensors()
+
+ num_classes = engine.num_classes
+
+ print('--------------- Parsing dataset ----------------')
+ print('Dataset path:', dataset_path)
+
+ # train in fixed order to ensure the same evaluation result.
+ train_set, test_set = test_utils.prepare_data_set_from_directory(
+ dataset_path, test_ratio, True)
+
+ print('Image list successfully parsed! Number of Categories = ',
+ len(train_set))
+ print('--------------- Processing training data ----------------')
+ print('This process may take more than 30 seconds.')
+ train_input = []
+ labels_map = {}
+ for class_id, (category, image_list) in enumerate(train_set.items()):
+ print('Processing {} ({} images)'.format(category, len(image_list)))
+ train_input.append(
+ [os.path.join(dataset_path, category, image) for image in image_list])
+ labels_map[num_classes + class_id] = category
+
+ # train
+ print('---------------- Start training -----------------')
+ size = common.input_size(extractor)
+ for class_id, images in enumerate(train_input):
+ for image in images:
+ with test_image(image) as img:
+ common.set_input(extractor, img.resize(size, Image.NEAREST))
+ extractor.invoke()
+ engine.train(classify.get_scores(extractor),
+ class_id=num_classes + class_id)
+
+ print('---------------- Training finished -----------------')
+ with test_utils.temporary_file(suffix='.tflite') as output_model_path:
+ output_model_path.write(engine.serialize_model())
+
+ # Evaluate
+ print('---------------- Start evaluating -----------------')
+ classifier = make_interpreter(output_model_path.name)
+ classifier.allocate_tensors()
+
+ # top[i] represents number of top (i+1) correct inference.
+ top_k_correct_count = [0] * top_k_range
+ image_num = 0
+ for category, image_list in test_set.items():
+ n = len(image_list)
+ print('Evaluating {} ({} images)'.format(category, n))
+ for image_name in image_list:
+ with test_image(os.path.join(dataset_path, category,
+ image_name)) as img:
+ # Set threshold as a negative number to ensure we get top k
+ # candidates even if its score is 0.
+ size = common.input_size(classifier)
+ common.set_input(classifier, img.resize(size, Image.NEAREST))
+ classifier.invoke()
+ candidates = classify.get_classes(classifier, top_k=top_k_range)
+
+ for i in range(len(candidates)):
+ candidate = candidates[i]
+ if candidate.id in labels_map and \
+ labels_map[candidate.id] == category:
+ top_k_correct_count[i] += 1
+ break
+ image_num += n
+ for i in range(1, top_k_range):
+ top_k_correct_count[i] += top_k_correct_count[i - 1]
+
+ return [top_k_correct_count[i] / image_num for i in range(top_k_range)]
+
+ def _test_oxford17_flowers_single(self, model_path, keep_classes, expected):
+ top_k_range = len(expected)
+ ret = self._transfer_learn_and_evaluate(
+ test_utils.test_data_path(model_path), keep_classes,
+ test_utils.test_data_path('oxford_17flowers'), 0.25, top_k_range)
+ for i in range(top_k_range):
+ self.assertGreaterEqual(ret[i], expected[i])
+
+ # Evaluate with L2Norm full model, not keeping base model classes.
+ def test_oxford17_flowers_l2_norm_model_not_keep_classes(self):
+ self._test_oxford17_flowers_single(
+ 'mobilenet_v1_1.0_224_l2norm_quant.tflite',
+ keep_classes=False,
+ expected=[0.86, 0.94, 0.96, 0.97, 0.97])
+
+ # Evaluate with L2Norm full model, keeping base model classes.
+ def test_oxford17_flowers_l2_norm_model_keep_classes(self):
+ self._test_oxford17_flowers_single(
+ 'mobilenet_v1_1.0_224_l2norm_quant.tflite',
+ keep_classes=True,
+ expected=[0.86, 0.94, 0.96, 0.96, 0.97])
+
+
+if __name__ == '__main__':
+ test_utils.coral_test_main()
diff --git a/tests/multiple_tpus_test.py b/tests/multiple_tpus_test.py
new file mode 100644
index 0000000..8fed115
--- /dev/null
+++ b/tests/multiple_tpus_test.py
@@ -0,0 +1,94 @@
+# Lint as: python3
+# pylint:disable=g-generic-assert
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+
+from PIL import Image
+
+from pycoral.adapters import classify
+from pycoral.adapters import common
+from pycoral.adapters import detect
+from pycoral.utils.dataset import read_label_file
+from pycoral.utils.edgetpu import make_interpreter
+from tests import test_utils
+import unittest
+
+
+class MultipleTpusTest(unittest.TestCase):
+
+ def test_run_classification_and_detection(self):
+
+ def classification_task(num_inferences):
+ tid = threading.get_ident()
+ print('Thread: %d, %d inferences for classification task' %
+ (tid, num_inferences))
+ labels = read_label_file(test_utils.test_data_path('imagenet_labels.txt'))
+ model_name = 'mobilenet_v1_1.0_224_quant_edgetpu.tflite'
+ interpreter = make_interpreter(
+ test_utils.test_data_path(model_name), device=':0')
+ interpreter.allocate_tensors()
+ size = common.input_size(interpreter)
+ print('Thread: %d, using device 0' % tid)
+ with test_utils.test_image('cat.bmp') as img:
+ for _ in range(num_inferences):
+ common.set_input(interpreter, img.resize(size, Image.NEAREST))
+ interpreter.invoke()
+ ret = classify.get_classes(interpreter, top_k=1)
+ self.assertEqual(len(ret), 1)
+ self.assertEqual(labels[ret[0].id], 'Egyptian cat')
+ print('Thread: %d, done classification task' % tid)
+
+ def detection_task(num_inferences):
+ tid = threading.get_ident()
+ print('Thread: %d, %d inferences for detection task' %
+ (tid, num_inferences))
+ model_name = 'ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite'
+ interpreter = make_interpreter(
+ test_utils.test_data_path(model_name), device=':1')
+ interpreter.allocate_tensors()
+ print('Thread: %d, using device 1' % tid)
+ with test_utils.test_image('cat.bmp') as img:
+ for _ in range(num_inferences):
+ _, scale = common.set_resized_input(
+ interpreter,
+ img.size,
+ lambda size, image=img: image.resize(size, Image.ANTIALIAS))
+ interpreter.invoke()
+ ret = detect.get_objects(
+ interpreter, score_threshold=0.7, image_scale=scale)
+ self.assertEqual(len(ret), 1)
+ self.assertEqual(ret[0].id, 16) # cat
+ expected_bbox = detect.BBox(
+ xmin=int(0.1 * img.size[0]),
+ ymin=int(0.1 * img.size[1]),
+ xmax=int(0.7 * img.size[0]),
+ ymax=int(1.0 * img.size[1]))
+ self.assertGreaterEqual(
+ detect.BBox.iou(expected_bbox, ret[0].bbox), 0.85)
+ print('Thread: %d, done detection task' % tid)
+
+ num_inferences = 2000
+ t1 = threading.Thread(target=classification_task, args=(num_inferences,))
+ t2 = threading.Thread(target=detection_task, args=(num_inferences,))
+
+ t1.start()
+ t2.start()
+
+ t1.join()
+ t2.join()
+
+
+if __name__ == '__main__':
+ test_utils.coral_test_main()
diff --git a/tests/pipelined_model_runner_test.py b/tests/pipelined_model_runner_test.py
new file mode 100644
index 0000000..f419b23
--- /dev/null
+++ b/tests/pipelined_model_runner_test.py
@@ -0,0 +1,204 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import threading
+import time
+
+import numpy as np
+
+import pycoral.pipeline.pipelined_model_runner as pipeline
+from pycoral.utils.edgetpu import list_edge_tpus
+from pycoral.utils.edgetpu import make_interpreter
+from tests import test_utils
+import unittest
+
+
+def _get_ref_result(ref_model, input_tensors):
+ interpreter = make_interpreter(test_utils.test_data_path(ref_model))
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+ assert len(input_details) == 1
+ output_details = interpreter.get_output_details()
+ assert len(output_details) == 1
+
+ interpreter.tensor(input_details[0]['index'])()[0][:, :] = input_tensors[0]
+ interpreter.invoke()
+ return np.array(interpreter.tensor(output_details[0]['index'])())
+
+
+def _get_devices(num_devices):
+ """Returns list of device names in usb:N or pci:N format.
+
+ This function prefers returning PCI Edge TPU first.
+
+ Args:
+ num_devices: int, number of devices expected
+
+ Returns:
+ list of devices in pci:N and/or usb:N format
+
+ Raises:
+ RuntimeError: if not enough devices are available
+ """
+ edge_tpus = list_edge_tpus()
+
+ if len(edge_tpus) < num_devices:
+ raise RuntimeError(
+ 'Not enough Edge TPUs detected, expected %d, detected %d.' %
+ (num_devices, len(edge_tpus)))
+
+ num_pci_devices = sum(1 for device in edge_tpus if device['type'] == 'pci')
+
+ return ['pci:%d' % i for i in range(min(num_devices, num_pci_devices))] + [
+ 'usb:%d' % i for i in range(max(0, num_devices - num_pci_devices))
+ ]
+
+
+def _make_runner(model_paths, devices):
+ print('Using devices: ', devices)
+ print('Using models: ', model_paths)
+
+ if len(model_paths) != len(devices):
+ raise ValueError('# of devices and # of model_paths should match')
+
+ interpreters = [
+ make_interpreter(test_utils.test_data_path(m), d)
+ for m, d in zip(model_paths, devices)
+ ]
+ for interpreter in interpreters:
+ interpreter.allocate_tensors()
+ return pipeline.PipelinedModelRunner(interpreters)
+
+
+class PipelinedModelRunnerTest(unittest.TestCase):
+
+ def setUp(self):
+ super(PipelinedModelRunnerTest, self).setUp()
+ model_segments = [
+ 'pipeline/inception_v3_299_quant_segment_0_of_2_edgetpu.tflite',
+ 'pipeline/inception_v3_299_quant_segment_1_of_2_edgetpu.tflite',
+ ]
+ self.runner = _make_runner(model_segments,
+ _get_devices(len(model_segments)))
+
+ input_details = self.runner.interpreters()[0].get_input_details()
+ self.assertEqual(len(input_details), 1)
+ self.input_shape = input_details[0]['shape']
+
+ np.random.seed(0)
+ self.input_tensors = [
+ np.random.randint(0, 256, size=self.input_shape, dtype=np.uint8)
+ ]
+
+ ref_model = 'inception_v3_299_quant_edgetpu.tflite'
+ self.ref_result = _get_ref_result(ref_model, self.input_tensors)
+
+ def test_bad_segments(self):
+ model_segments = [
+ 'pipeline/inception_v3_299_quant_segment_1_of_2_edgetpu.tflite',
+ 'pipeline/inception_v3_299_quant_segment_0_of_2_edgetpu.tflite',
+ ]
+ with self.assertRaisesRegex(
+ ValueError, r'Interpreter [\d]+ can not get its input tensors'):
+ unused_runner = _make_runner(model_segments, [None] * len(model_segments))
+
+ def test_unsupported_input_type(self):
+ with self.assertRaisesRegex(
+ ValueError, 'Input should be a list of numpy array of type*'):
+ self.runner.push([np.random.random(self.input_shape)])
+
+ def test_check_unconsumed_tensor(self):
+ # Everything should work fine without crashing.
+ self.runner.push(self.input_tensors)
+
+ def test_push_and_pop(self):
+ self.assertTrue(self.runner.push(self.input_tensors))
+ result = self.runner.pop()
+ self.assertEqual(len(result), 1)
+ np.testing.assert_equal(result[0], self.ref_result)
+
+ # Check after [] is pushed.
+ self.assertTrue(self.runner.push([]))
+ self.assertFalse(self.runner.push(self.input_tensors))
+ self.assertIsNone(self.runner.pop())
+
+ def test_producer_and_consumer_threads(self):
+ num_requests = 5
+
+ def producer(self):
+ for _ in range(num_requests):
+ self.runner.push(self.input_tensors)
+ self.runner.push([])
+
+ def consumer(self):
+ while True:
+ result = self.runner.pop()
+ if not result:
+ break
+ np.testing.assert_equal(result[0], self.ref_result)
+
+ producer_thread = threading.Thread(target=producer, args=(self,))
+ consumer_thread = threading.Thread(target=consumer, args=(self,))
+
+ producer_thread.start()
+ consumer_thread.start()
+ producer_thread.join()
+ consumer_thread.join()
+
+ def test_set_input_and_output_queue_size(self):
+ self.runner.set_input_queue_size(1)
+ self.runner.set_output_queue_size(1)
+ num_segments = len(self.runner.interpreters())
+
+ # When both input and output queue size are set to 1, the max number of
+ # requests pipeline runner can buffer is 2*num_segments+1. This is because
+ # the intermediate queues need to be filled as well.
+ max_buffered_requests = 2 * num_segments + 1
+
+ # Push `max_buffered_requests` to pipeline, such that the next `push` will
+ # be blocking as there is no consumer to process the results at the moment.
+ for _ in range(max_buffered_requests):
+ self.assertTrue(self.runner.push(self.input_tensors))
+
+ # Sleep for `max_buffered_requests` seconds to make sure the first request
+ # already reaches the last segments. This assumes that it takes 1 second for
+ # each segment to return inference result (which is a generous upper bound).
+ time.sleep(max_buffered_requests)
+
+ def push_new_request(self):
+ self.assertTrue(self.runner.push(self.input_tensors))
+ self.assertTrue(self.runner.push([]))
+
+ producer_thread = threading.Thread(target=push_new_request, args=(self,))
+ producer_thread.start()
+ # If runner's input queue has room, push is non-blocking and should return
+ # immediately. If producer_thread is still alive after join() with some
+ # `timeout`, it means the thread is blocked.
+ producer_thread.join(1.0)
+ self.assertTrue(producer_thread.is_alive())
+
+ processed_requests = 0
+ while True:
+ result = self.runner.pop()
+ if not result:
+ break
+ processed_requests += 1
+ self.assertEqual(processed_requests, max_buffered_requests + 1)
+ producer_thread.join(1.0)
+ self.assertFalse(producer_thread.is_alive())
+
+
+if __name__ == '__main__':
+ test_utils.coral_test_main()
diff --git a/tests/segment_test.py b/tests/segment_test.py
new file mode 100644
index 0000000..d76caf6
--- /dev/null
+++ b/tests/segment_test.py
@@ -0,0 +1,99 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from PIL import Image
+import unittest
+from pycoral.adapters import common
+from pycoral.adapters import segment
+from pycoral.utils.edgetpu import make_interpreter
+from tests.test_utils import coral_test_main
+from tests.test_utils import test_data_path
+
+
+def deeplab_model_dm05(tpu):
+ suffix = '_edgetpu' if tpu else ''
+ return 'deeplabv3_mnv2_dm05_pascal_quant%s.tflite' % suffix
+
+
+def deeplab_model_dm10(tpu):
+ suffix = '_edgetpu' if tpu else ''
+ return 'deeplabv3_mnv2_pascal_quant%s.tflite' % suffix
+
+
+def keras_post_training_unet_mv2(tpu, size):
+ suffix = '_edgetpu' if tpu else ''
+ return 'keras_post_training_unet_mv2_%d_quant%s.tflite' % (size, suffix)
+
+
+def array_iou(a, b):
+ count = (a == b).sum()
+ return count / (a.size + b.size - count)
+
+
+def segment_image(model_file, image_file, mask_file):
+ interpreter = make_interpreter(test_data_path(model_file))
+ interpreter.allocate_tensors()
+
+ image = Image.open(test_data_path(image_file)).resize(
+ common.input_size(interpreter), Image.ANTIALIAS)
+ common.set_input(interpreter, image)
+ interpreter.invoke()
+
+ result = segment.get_output(interpreter)
+ if len(result.shape) > 2:
+ result = np.argmax(result, axis=2)
+
+ reference = np.asarray(Image.open(test_data_path(mask_file)))
+ return array_iou(result, reference)
+
+
+class SegmentTest(unittest.TestCase):
+
+ def test_deeplab_dm10(self):
+ for tpu in [False, True]:
+ with self.subTest(tpu=tpu):
+ self.assertGreater(
+ segment_image(
+ deeplab_model_dm10(tpu), 'bird_segmentation.bmp',
+ 'bird_segmentation_mask.bmp'), 0.90)
+
+ def test_deeplab_dm05(self):
+ for tpu in [False, True]:
+ with self.subTest(tpu=tpu):
+ self.assertGreater(
+ segment_image(
+ deeplab_model_dm05(tpu), 'bird_segmentation.bmp',
+ 'bird_segmentation_mask.bmp'), 0.90)
+
+ def test_keras_post_training_unet_mv2_128(self):
+ for tpu in [False, True]:
+ with self.subTest(tpu=tpu):
+ self.assertGreater(
+ segment_image(
+ keras_post_training_unet_mv2(tpu, 128), 'dog_segmentation.bmp',
+ 'dog_segmentation_mask.bmp'), 0.86)
+
+ def test_keras_post_training_unet_mv2_256(self):
+ for tpu in [False, True]:
+ with self.subTest(tpu=tpu):
+ self.assertGreater(
+ segment_image(
+ keras_post_training_unet_mv2(tpu, 256),
+ 'dog_segmentation_256.bmp', 'dog_segmentation_mask_256.bmp'),
+ 0.81)
+
+
+if __name__ == '__main__':
+ coral_test_main()
diff --git a/tests/softmax_regression_test.py b/tests/softmax_regression_test.py
new file mode 100644
index 0000000..df8d81c
--- /dev/null
+++ b/tests/softmax_regression_test.py
@@ -0,0 +1,151 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests SoftmaxRegression class.
+
+Generates some fake data and tries to overfit the data with SoftmaxRegression.
+"""
+import numpy as np
+
+from pycoral.learn.backprop.softmax_regression import SoftmaxRegression
+from tests import test_utils
+import unittest
+
+
+def generate_fake_data(class_sizes, means, cov_mats):
+ """Generates fake data for training and testing.
+
+  Examples from the same class are drawn from the same MultiVariate Normal (MVN)
+ distribution.
+
+ # classes = len(class_sizes) = len(means) = len(cov_mats)
+ dim of MVN = cov_mats[0].shape[0]
+
+ Args:
+ class_sizes: list of ints, number of examples to draw from each class.
+ means: list of list of floats, mean value of each MVN distribution.
+ cov_mats: list of ndarray, each element is a k by k ndarray, which
+ represents the covariance matrix in MVN distribution, k is the dimension
+ of MVN distribution.
+
+ Returns:
+ a tuple of data and labels. data and labels are shuffled.
+ """
+ # Some sanity checks.
+ assert len(class_sizes) == len(means)
+ assert len(class_sizes) == len(cov_mats)
+
+ num_data = np.sum(class_sizes)
+ feature_dim = len(means[0])
+ data = np.empty((num_data, feature_dim)).astype(np.float32)
+ labels = np.empty((num_data), dtype=int)
+
+ start_idx = 0
+ class_idx = 0
+ for size, mean, cov_mat in zip(class_sizes, means, cov_mats):
+ data[start_idx:start_idx + size] = np.random.multivariate_normal(
+ mean, cov_mat, size)
+ labels[start_idx:start_idx + size] = np.ones(size, dtype=int) * class_idx
+ start_idx += size
+ class_idx += 1
+
+ perm = np.random.permutation(data.shape[0])
+ data = data[perm, :]
+ labels = labels[perm]
+
+ return data, labels
+
+
+class SoftmaxRegressionTest(unittest.TestCase):
+
+ def test_softmax_regression_linear_separable_data(self):
+    # Fake data is generated from 3 MVN distributions; these MVN distributions
+ # are tuned to be well-separated, such that it can be separated by
+ # SoftmaxRegression model (which is a linear classifier).
+ num_train = 200
+ num_val = 30
+ # Let's distribute data evenly among different classes.
+ num_classes = 3
+ class_sizes = ((num_train + num_val) // num_classes) * np.ones(
+ num_classes, dtype=int)
+ class_sizes[-1] = (num_train + num_val) - np.sum(class_sizes[0:-1])
+
+    # 3 is chosen such that each pair of means is over 6 `sigma` apart,
+    # which makes it harder for the classes to `touch` each other.
+ # https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule
+ means = np.array([[1, 1], [-1, -1], [1, -1]]) * 3
+ feature_dim = len(means[0])
+ cov_mats = [np.eye(feature_dim)] * num_classes
+
+ model = SoftmaxRegression(feature_dim, num_classes)
+ np.random.seed(12345)
+ all_data, all_labels = generate_fake_data(class_sizes, means, cov_mats)
+
+ dataset = {}
+ dataset['data_train'] = all_data[0:num_train]
+ dataset['labels_train'] = all_labels[0:num_train]
+ dataset['data_val'] = all_data[num_train:]
+ dataset['labels_val'] = all_labels[num_train:]
+ # train with SGD.
+ num_iter = 20
+ learning_rate = 0.01
+ model.train_with_sgd(
+ dataset, num_iter, learning_rate, batch_size=100, print_every=5)
+ self.assertGreater(
+ model.get_accuracy(dataset['data_train'], dataset['labels_train']),
+ 0.99)
+
+ def test_softmax_regression_linear_non_separable_data(self):
+ # Fake data is generated from 3 MVN distributions, these MVN distributions
+ # are NOT well-separated.
+ num_train = 200
+ num_val = 30
+ # Let's distribute data evenly among different classes.
+ num_classes = 3
+ class_sizes = ((num_train + num_val) // num_classes) * np.ones(
+ num_classes, dtype=int)
+ class_sizes[-1] = (num_train + num_val) - np.sum(class_sizes[0:-1])
+
+ means = np.array([[1, 1], [-1, -1], [1, -1]])
+ feature_dim = len(means[0])
+ cov_mats = [np.eye(feature_dim)] * num_classes
+
+ model = SoftmaxRegression(feature_dim, num_classes)
+ np.random.seed(54321)
+ all_data, all_labels = generate_fake_data(class_sizes, means, cov_mats)
+
+ dataset = {}
+ dataset['data_train'] = all_data[0:num_train]
+ dataset['labels_train'] = all_labels[0:num_train]
+ dataset['data_val'] = all_data[num_train:]
+ dataset['labels_val'] = all_labels[num_train:]
+ # train with SGD.
+ num_iter = 50
+ learning_rate = 0.1
+ model.train_with_sgd(
+ dataset, num_iter, learning_rate, batch_size=100, print_every=5)
+ self.assertGreater(
+ model.get_accuracy(dataset['data_train'], dataset['labels_train']), 0.8)
+
+ def test_softmax_regression_serialize_model(self):
+ feature_dim = 1024
+ num_classes = 5
+ model = SoftmaxRegression(feature_dim, num_classes)
+ in_model_path = test_utils.test_data_path(
+ 'mobilenet_v1_1.0_224_quant_embedding_extractor.tflite')
+ model.serialize_model(in_model_path)
+
+
+if __name__ == '__main__':
+ test_utils.coral_test_main()
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..fd4953b
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,164 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test utils for benchmark and manual tests."""
+
+import argparse
+import collections
+import contextlib
+import os
+import random
+import sys
+import tempfile
+
+import numpy as np
+from PIL import Image
+import unittest
+
+_TEST_DATA_DIR = ''
+
+
+def test_data_path(path, *paths):
+ """Returns absolute path for a given test file."""
+ return os.path.abspath(os.path.join(_TEST_DATA_DIR, path, *paths))
+
+
+@contextlib.contextmanager
+def test_image(path, *paths):
+ """Returns opened test image."""
+ with open(test_data_path(path, *paths), 'rb') as f:
+ with Image.open(f) as image:
+ yield image
+
+
+@contextlib.contextmanager
+def temporary_file(suffix=None):
+ """Creates a named temp file, and deletes after going out of scope.
+
+ Exists to work around issues with passing the result of
+ tempfile.NamedTemporaryFile to native code on Windows,
+ if delete was set to True.
+
+ Args:
+ suffix: If provided, the file name will end with suffix.
+
+ Yields:
+    A file-like object.
+ """
+ resource = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
+ try:
+ yield resource
+ finally:
+ resource.close()
+ os.unlink(resource.name)
+
+
+def generate_random_input(seed, n):
+ """Generates a list with n uint8 numbers."""
+ random.seed(a=seed)
+ return [random.randint(0, 255) for _ in range(n)]
+
+
+def prepare_images(image_list, directory, shape):
+ """Reads images and converts them to numpy array with specified shape.
+
+ Args:
+ image_list: a list of strings storing file names.
+ directory: string, path of directory storing input images.
+ shape: a 2-D tuple represents the shape of required input tensor.
+
+ Returns:
+ A list of numpy.array.
+ """
+ ret = []
+ for filename in image_list:
+ with open(os.path.join(directory, filename), 'rb') as f:
+ with Image.open(f) as img:
+ img = img.resize(shape, Image.NEAREST)
+ ret.append(np.asarray(img).flatten())
+ return np.array(ret)
+
+
+def area(box):
+ """Calculates area of a given bounding box."""
+ assert box[1][0] >= box[0][0]
+ assert box[1][1] >= box[0][1]
+ return float((box[1][0] - box[0][0]) * (box[1][1] - box[0][1]))
+
+
+def iou(box_a, box_b):
+ """Calculates intersection area / union area for two bounding boxes."""
+ assert area(box_a) > 0
+ assert area(box_b) > 0
+ intersect = np.array(
+ [[max(box_a[0][0], box_b[0][0]),
+ max(box_a[0][1], box_b[0][1])],
+ [min(box_a[1][0], box_b[1][0]),
+ min(box_a[1][1], box_b[1][1])]])
+ return area(intersect) / (area(box_a) + area(box_b) - area(intersect))
+
+
+def prepare_data_set_from_directory(path, test_ratio, fixed_order):
+ """Parses data set from given directory, split them into train/test sets.
+
+ Args:
+    path: string, path of the data set. Images are stored in sub-directories named
+ by category.
+ test_ratio: float in (0,1), ratio of data used for testing.
+    fixed_order: bool, whether to split the data set in fixed order.
+
+ Returns:
+ (train_set, test_set), A tuple of two OrderedDicts. Keys are the categories
+ and values are lists of image file names.
+ """
+ train_set = collections.OrderedDict()
+ test_set = collections.OrderedDict()
+ sub_dirs = os.listdir(path)
+ if fixed_order:
+ sub_dirs.sort()
+ for category in sub_dirs:
+ category_dir = os.path.join(path, category)
+ if os.path.isdir(category_dir):
+ images = [
+ f for f in os.listdir(category_dir)
+ if os.path.isfile(os.path.join(category_dir, f))
+ ]
+ if images:
+ if fixed_order:
+ images.sort()
+ k = int(test_ratio * len(images))
+ test_set[category] = images[:k]
+ assert test_set[category], 'No images to test [{}]'.format(category)
+ train_set[category] = images[k:]
+ assert train_set[category], 'No images to train [{}]'.format(category)
+ return train_set, test_set
+
+
+def coral_test_main():
+ """Test main to get test_data_dir flag from commend line.
+
+ In edgetpu GoB repo:
+ the Python test files are under edgetpu/tests.
+ test_data is under edgetpu/test_data.
+ """
+
+ global _TEST_DATA_DIR
+ test_data_dir_default = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'test_data')
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--test_data_dir', default=test_data_dir_default, type=str)
+ args, sys.argv = parser.parse_known_args(sys.argv)
+ _TEST_DATA_DIR = args.test_data_dir
+ unittest.main()
diff --git a/third_party/python/BUILD b/third_party/python/BUILD
new file mode 100644
index 0000000..d371c5a
--- /dev/null
+++ b/third_party/python/BUILD
@@ -0,0 +1,23 @@
+config_setting(
+ name = "windows",
+ values = {
+ "cpu": "x64_windows",
+ }
+)
+
+config_setting(
+ name = "darwin",
+ values = {
+ "cpu": "darwin",
+ }
+)
+
+cc_library(
+ name = "python",
+ deps = select({
+ ":windows": ["@python_windows"],
+ ":darwin": ["@python_darwin//:python3-headers"],
+ "//conditions:default": ["@python_linux//:python3-headers"],
+ }),
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/python/darwin/BUILD b/third_party/python/darwin/BUILD
new file mode 100644
index 0000000..8253b9b
--- /dev/null
+++ b/third_party/python/darwin/BUILD
@@ -0,0 +1,48 @@
+config_setting(
+ name = "py35",
+ define_values = {"PY3_VER": "35"}
+)
+
+config_setting(
+ name = "py36",
+ define_values = {"PY3_VER": "36"}
+)
+
+config_setting(
+ name = "py37",
+ define_values = {"PY3_VER": "37"}
+)
+
+config_setting(
+ name = "py38",
+ define_values = {"PY3_VER": "38"}
+)
+
+# sudo port install python35 python36 python37 python38
+# sudo port install py35-numpy py36-numpy py37-numpy py38-numpy
+cc_library(
+ name = "python3-headers",
+ hdrs = select({
+ "py35": glob(["3.5/include/python3.5m/*.h",
+ "3.5/lib/python3.5/site-packages/numpy/core/include/numpy/*.h"]),
+ "py36": glob(["3.6/include/python3.6m/*.h",
+ "3.6/lib/python3.6/site-packages/numpy/core/include/numpy/*.h"]),
+ "py37": glob(["3.7/include/python3.7m/*.h",
+ "3.7/lib/python3.7/site-packages/numpy/core/include/numpy/*.h"]),
+ "py38": glob(["3.8/include/python3.8/*.h",
+ "3.8/include/python3.8/cpython/*.h",
+ "3.8/lib/python3.8/site-packages/numpy/core/include/numpy/*.h"]),
+ }, no_match_error = "PY3_VER is not specified"),
+ includes = select({
+ "py35": ["3.5/include/python3.5m",
+ "3.5/lib/python3.5/site-packages/numpy/core/include"],
+ "py36": ["3.6/include/python3.6m",
+ "3.6/lib/python3.6/site-packages/numpy/core/include"],
+ "py37": ["3.7/include/python3.7m",
+ "3.7/lib/python3.7/site-packages/numpy/core/include"],
+ "py38": ["3.8/include/python3.8",
+ "3.8/include/python3.8/cpython",
+ "3.8/lib/python3.8/site-packages/numpy/core/include"],
+ }, no_match_error = "PY3_VER is not specified"),
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/python/linux/BUILD b/third_party/python/linux/BUILD
new file mode 100644
index 0000000..8606801
--- /dev/null
+++ b/third_party/python/linux/BUILD
@@ -0,0 +1,48 @@
+config_setting(
+ name = "py35",
+ define_values = {"PY3_VER": "35"}
+)
+
+config_setting(
+ name = "py36",
+ define_values = {"PY3_VER": "36"}
+)
+
+config_setting(
+ name = "py37",
+ define_values = {"PY3_VER": "37"}
+)
+
+config_setting(
+ name = "py38",
+ define_values = {"PY3_VER": "38"}
+)
+
+cc_library(
+ name = "python3-headers",
+ hdrs = select({
+ "py35": glob(["python3.5m/*.h",
+ "python3.5m/numpy/*.h",
+ "aarch64-linux-gnu/python3.5m/*.h",
+ "arm-linux-gnueabihf/python3.5m/*.h"]),
+ "py36": glob(["python3.6m/*.h",
+ "python3.6m/numpy/*.h",
+ "aarch64-linux-gnu/python3.6m/*.h",
+ "arm-linux-gnueabihf/python3.6m/*.h"]),
+ "py37": glob(["python3.7m/*.h",
+ "python3.7m/numpy/*.h",
+ "aarch64-linux-gnu/python3.7m/*.h",
+ "arm-linux-gnueabihf/python3.7m/*.h"]),
+ "py38": glob(["python3.8m/*.h",
+ "python3.8m/numpy/*.h",
+ "aarch64-linux-gnu/python3.8m/*.h",
+ "arm-linux-gnueabihf/python3.8m/*.h"]),
+ }, no_match_error = "PY3_VER is not specified"),
+ includes = select({
+ "py35": [".", "python3.5m"],
+ "py36": [".", "python3.6m"],
+ "py37": [".", "python3.7m"],
+ "py38": [".", "python3.8m"],
+ }, no_match_error = "PY3_VER is not specified"),
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/python/windows/BUILD b/third_party/python/windows/BUILD
new file mode 100644
index 0000000..1d4f23b
--- /dev/null
+++ b/third_party/python/windows/BUILD
@@ -0,0 +1,8 @@
+cc_library(
+ name = "python_windows",
+ deps = [
+ "@local_config_python//:python_headers",
+ "@local_config_python//:numpy_headers",
+ ],
+ visibility = ["//visibility:public"],
+)