[converter] tflite schema update (#57)
* [converter] TFLite schema update

* [tests] refine

* [converter] use new schema in transformable
peterjc123 authored Mar 15, 2022
1 parent de5445b commit 6c85458
Showing 25 changed files with 7,378 additions and 175 deletions.
7 changes: 4 additions & 3 deletions .flake8
@@ -6,9 +6,10 @@ ignore =
extend-ignore = E203
per-file-ignores =
tinynn/converter/operators/tflite/generated_ops.py: E501
tinynn/converter/operators/torch/aten_schema.py: E501
tinynn/converter/operators/torch/quantized_schema.py: E501
tinynn/converter/operators/torch/torchvision_schema.py: E501
tinynn/converter/schemas/tflite/schema_generated.py: E301, E302, E501, E704
tinynn/converter/schemas/torch/aten_schema.py: E501
tinynn/converter/schemas/torch/quantized_schema.py: E501
tinynn/converter/schemas/torch/torchvision_schema.py: E501
tests/converter_op_test.py: E704
examples/*: E402
exclude =
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -9,6 +9,7 @@ repos:
rev: 22.1.0
hooks:
- id: black
exclude: ^tinynn/converter/schemas
- repo: https://github.com/pycqa/flake8
rev: 4.0.1
hooks:
5 changes: 3 additions & 2 deletions docs/FAQ.md
@@ -100,8 +100,9 @@ There are a large number of operators in PyTorch. We cannot cover all operators,
- OP schema (without I/O tensors): [generated_ops.py](../tinynn/converter/operators/tflite/generated_ops.py)
- Full schema: https://www.tensorflow.org/mlir/tfl_ops
- TorchScript
- ATen schema [aten_schema.py](../tinynn/converter/operators/torch/aten_schema.py)
- Quantized schema [quantized_schema.py](../tinynn/converter/operators/torch/quantized_schema.py)
- ATen schema [aten_schema.py](../tinynn/converter/schemas/torch/aten_schema.py)
- Quantized schema [quantized_schema.py](../tinynn/converter/schemas/torch/quantized_schema.py)
- Torchvision schema [torchvision_schema.py](../tinynn/converter/schemas/torch/torchvision_schema.py)
- Translation logic
- ATen OPs [aten.py](../tinynn/converter/operators/torch/aten.py)
- Quantized OPs [quantized.py](../tinynn/converter/operators/torch/quantized.py)
5 changes: 3 additions & 2 deletions docs/FAQ_zh-CN.md
@@ -100,8 +100,9 @@ A: There are generally two solutions
- OP schema (without I/O tensors): [generated_ops.py](../tinynn/converter/operators/tflite/generated_ops.py)
- Full schema: https://www.tensorflow.org/mlir/tfl_ops
- TorchScript
- ATen schema [aten_schema.py](../tinynn/converter/operators/torch/aten_schema.py)
- Quantized schema [quantized_schema.py](../tinynn/converter/operators/torch/quantized_schema.py)
- ATen schema [aten_schema.py](../tinynn/converter/schemas/torch/aten_schema.py)
- Quantized schema [quantized_schema.py](../tinynn/converter/schemas/torch/quantized_schema.py)
- Torchvision schema [torchvision_schema.py](../tinynn/converter/schemas/torch/torchvision_schema.py)
- Corresponding translation code for both
- ATen OPs [aten.py](../tinynn/converter/operators/torch/aten.py)
- Quantized OPs [quantized.py](../tinynn/converter/operators/torch/quantized.py)
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -31,7 +31,8 @@ extend-exclude = '''
| ^/build/
| ^/debugging/
| ^/tinynn/converter/operators/tflite/generated_ops\.py
| ^/tinynn/converter/operators/torch/aten_schema\.py
| ^/tinynn/converter/operators/torch/quantized_schema\.py
| ^/tinynn/converter/operators/torch/torchvision_schema\.py
| ^/tinynn/converter/schemas/tflite/schema_generated\.py
| ^/tinynn/converter/schemas/torch/aten_schema\.py
| ^/tinynn/converter/schemas/torch/quantized_schema\.py
| ^/tinynn/converter/schemas/torch/torchvision_schema\.py
'''
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,4 +3,4 @@ PyYAML>=5.3.1
ruamel.yaml>=0.16.12
futures3>=1.0.0; python_version < '3.7'
python-igraph>=0.9.6
tflite==2.3.0
flatbuffers>=1.12
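
The dependency swap pairs with the vendored schema introduced in this commit: instead of pinning the third-party `tflite` parser package, the converter now ships its own flatbuffers-generated module and only needs the `flatbuffers` runtime. A minimal sketch of the resulting import change, assuming the vendored module exposes the same generated symbols as the old package:

```python
# Before: every generated class lived in its own module of the pinned
# third-party package.
# from tflite.BuiltinOperator import BuiltinOperator

# After: the schema is vendored with the converter and depends only on
# the `flatbuffers` runtime.
from tinynn.converter.schemas.tflite.schema_generated import BuiltinOperator

print(BuiltinOperator.ADD)  # builtin operator codes are plain integers
```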
2 changes: 2 additions & 0 deletions tests/converter_op_test.py
@@ -3027,6 +3027,7 @@ def setUp(self):
continue
if backend in torch.backends.quantized.supported_engines:
self.backend = backend
torch.backends.quantized.engine = backend
return
self.skipTest('No quantization backend is found')
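
The added `torch.backends.quantized.engine = backend` line matters because picking a supported backend is not enough on its own: PyTorch dispatches quantized kernels through whichever engine is currently active, so the test setup must activate the backend it found. A standalone sketch of the same pattern:

```python
import torch

# Activate the first quantization engine this PyTorch build supports, so
# that quantized kernels dispatch to the backend the tests expect.
for backend in ('fbgemm', 'qnnpack'):
    if backend in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = backend
        break
else:
    raise RuntimeError('No quantization backend is found')

# Quantized ops run after this point use the selected engine.
x = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.quint8)
print(torch.backends.quantized.engine, x.dtype)
```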

@@ -3440,6 +3441,7 @@ def forward(self, x):
assert_close(dummy_output, tfl_output, atol=1, rtol=1)

@unittest.skipIf(not hasattr(torch.nn.quantized, 'ELU'), 'Quantized elu is not supported')
@unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.4.1'), 'Quantized elu is not supported')
def test_quantized_elu_int8(self):
class Model(nn.Module):
def __init__(self) -> None:
95 changes: 49 additions & 46 deletions tests/converter_optimizer_test.py

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions tinynn/converter/README.md
@@ -32,12 +32,17 @@ To solve the above problems, we implement this converter that translates models
+ [transformable.py](operators/tflite/transformable.py) : Transformable operators, such as BatchNorm, Conv2d, and other composite operators composed of multiple TFLite operators
+ [torch](operators/torch) : PyTorch related classes
+ [base.py](operators/torch/base.py) : The base data structure needed for TorchScript parsing
+ [aten_schema.py](operators/torch/aten_schema.py) : Wrapper classes generated from ATen schema
+ [quantized_schema.py](operators/torch/quantized_schema.py) : Wrapper class generated from quantized schema
+ [aten.py](operators/torch/aten.py) : Translation of ATen-related operators
+ [quantized.py](operators/torch/quantized.py) : Translation of quantized-related operators
+ [base.py](operators/base.py) : Definition of generic operators
+ [graph.py](operators/graph.py) : Computation graph related infrastructure
+ [op_version.py](operators/op_version.py) : Handler for operator version
+ [optimize.py](operators/optimize.py) : Computation graph optimization
+ [schemas](schemas): Most of the schemas used by the converter
+ [tflite](schemas/tflite) : TFLite related schemas
+ [schema_generated.py](schemas/tflite/schema_generated.py) : TFLite schema parsers
+ [torch](schemas/torch) : PyTorch related schemas
+ [aten_schema.py](schemas/torch/aten_schema.py) : Wrapper classes generated from ATen schema
+ [quantized_schema.py](schemas/torch/quantized_schema.py) : Wrapper classes generated from the quantized schema
+ [torchvision_schema.py](schemas/torch/torchvision_schema.py) : Wrapper classes generated from the Torchvision schema
+ [base.py](base.py): Entry class `TFLiteConverter`
9 changes: 7 additions & 2 deletions tinynn/converter/README_zh-CN.md
@@ -33,12 +33,17 @@
+ [transformable.py](operators/tflite/transformable.py) : Transformable operators, such as BatchNorm, Conv2d, and other composite operators composed of multiple TFLite operators
+ [torch](operators/torch) : PyTorch-related classes
+ [base.py](operators/torch/base.py) : Basic data structures needed for TorchScript parsing
+ [aten_schema.py](operators/torch/aten_schema.py) : Wrapper classes generated from the ATen schema
+ [quantized_schema.py](operators/torch/quantized_schema.py) : Wrapper classes generated from the quantized schema
+ [aten.py](operators/torch/aten.py) : Translation of ATen-related operators
+ [quantized.py](operators/torch/quantized.py) : Translation of quantized operators
+ [base.py](operators/base.py) : Definition of generic operators
+ [graph.py](operators/graph.py) : Computation graph related infrastructure
+ [op_version.py](operators/op_version.py) : Sets the operator version
+ [optimize.py](operators/optimize.py) : Computation graph optimization
+ [schemas](schemas): Schemas used by the converter
+ [tflite](schemas/tflite) : TFLite-related schemas
+ [schema_generated.py](schemas/tflite/schema_generated.py) : TFLite schema parsers
+ [torch](schemas/torch) : PyTorch-related schemas
+ [aten_schema.py](schemas/torch/aten_schema.py) : Wrapper classes generated from the ATen schema
+ [quantized_schema.py](schemas/torch/quantized_schema.py) : Wrapper classes generated from the quantized schema
+ [torchvision_schema.py](schemas/torch/torchvision_schema.py) : Wrapper classes generated from the Torchvision schema
+ [base](base.py): Entry class `TFLiteConverter`
4 changes: 1 addition & 3 deletions tinynn/converter/operators/base.py
@@ -1,10 +1,8 @@
import inspect
import sys

from enum import IntEnum

from tflite.ActivationFunctionType import ActivationFunctionType
from tflite.BuiltinOperator import BuiltinOperator
from ..schemas.tflite.schema_generated import ActivationFunctionType, BuiltinOperator

# In Python 3.6, we cannot make ExtendedOperator derive from IntEnum
if sys.version_info >= (3, 7):
7 changes: 3 additions & 4 deletions tinynn/converter/operators/op_version.py
@@ -1,8 +1,7 @@
from .base import ExtendedOperator, FUSE_ACTIVATION_MAP
from .graph import CommonGraph
from ..schemas.tflite import schema_generated as tfl_schema
from . import tflite as tfl

import tflite as tfl_schema
from .base import FUSE_ACTIVATION_MAP, ExtendedOperator
from .graph import CommonGraph


class OPVersioner(object):
4 changes: 2 additions & 2 deletions tinynn/converter/operators/optimize.py
@@ -7,10 +7,10 @@

import igraph as ig
import numpy as np
from tinynn.util.util import class_conditional, get_logger

from tflite.ActivationFunctionType import ActivationFunctionType
from tinynn.util.util import class_conditional, get_logger

from ..schemas.tflite.schema_generated import ActivationFunctionType
from . import tflite as tfl
from .base import FUSE_ACTIVATION_MAP, ExtendedOperator
from .graph import CommonGraph
29 changes: 15 additions & 14 deletions tinynn/converter/operators/tflite/base.py
@@ -1,13 +1,11 @@
import copy
import tflite
import typing

import flatbuffers
import numpy as np
import torch
import flatbuffers
import pkg_resources

if pkg_resources.get_distribution('tflite').version != '2.3.0':
raise AssertionError('tflite==2.3.0 is required. Please run `python3 -m pip install tflite==2.3.0`.')
from ...schemas.tflite import schema_generated as tflite

Offset = int

@@ -31,7 +29,10 @@ def build(self, builder: flatbuffers.Builder) -> Offset:
custom_code = create_string(builder, tflite.OperatorCode.CustomCode, self.custom_code)

tflite.OperatorCodeStart(builder)
tflite.OperatorCodeAddBuiltinCode(builder, self.code)
if self.code < tflite.BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES:
tflite.OperatorCodeAddDeprecatedBuiltinCode(builder, self.code)
else:
tflite.OperatorCodeAddBuiltinCode(builder, self.code)
tflite.OperatorCodeAddVersion(builder, self.version)

if custom_code is not None:
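
The new branch in `build` follows the TFLite schema's operator-code convention: `builtin_code` was originally a byte-sized field, so when operator IDs outgrew it the schema renamed it to `deprecated_builtin_code` and added a wider `builtin_code`. Codes below `PLACEHOLDER_FOR_GREATER_OP_CODES` go into the deprecated field so older runtimes can still read the model. A sketch of the rule in isolation, with the constant inlined as an assumption:

```python
# PLACEHOLDER mirrors BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES;
# 127 is assumed here rather than read from the generated schema.
PLACEHOLDER = 127

def builtin_code_field(code: int) -> str:
    """Pick the OperatorCode field that should carry a builtin code."""
    if code < PLACEHOLDER:
        # Byte-sized legacy field, readable by pre-placeholder runtimes.
        return 'deprecated_builtin_code'
    # Wider int32 field for operators added after the schema change.
    return 'builtin_code'

assert builtin_code_field(0) == 'deprecated_builtin_code'  # e.g. ADD
assert builtin_code_field(150) == 'builtin_code'
```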
@@ -375,10 +376,10 @@ def create_offset_vector(builder: flatbuffers.Builder, prop: typing.Callable, ve
vec = list(vec)

prop_name = prop.__name__
cls_name = prop.__module__.split('.')[-1]
cls_name = prop.__qualname__.split('.')[0]
func_name = f'{cls_name}Start{prop_name}Vector'
if not hasattr(tflite, func_name):
assert False, "invalid prop is given"
assert False, f"invalid prop is given, {prop.__qualname__}"

start_vec_func = getattr(tflite, func_name)
start_vec_func(builder, len(vec))
@@ -397,10 +398,10 @@ def create_numpy_array(builder: flatbuffers.Builder, prop: typing.Callable, vec:
assert False, "type of vec unexpected, expected: list or tuple or ndarray"

prop_name = prop.__name__
cls_name = prop.__module__.split('.')[-1]
cls_name = prop.__qualname__.split('.')[0]
func_name = f'{cls_name}Start{prop_name}Vector'
if not hasattr(tflite, func_name):
assert False, "invalid prop is given"
assert False, f"invalid prop is given, {prop.__qualname__}"

arr = np.asarray(vec, dtype=dtype)
return builder.CreateNumpyVector(arr)
@@ -411,10 +412,10 @@ def create_string(builder: flatbuffers.Builder, prop: typing.Callable, val: str)
assert False, "type of val unexpected, expected: str"

prop_name = prop.__name__
cls_name = prop.__module__.split('.')[-1]
cls_name = prop.__qualname__.split('.')[0]
func_name = f'{cls_name}Add{prop_name}'
if not hasattr(tflite, func_name):
assert False, "invalid prop is given"
assert False, f"invalid prop is given, {prop.__qualname__}"

return builder.CreateString(val)

@@ -424,10 +425,10 @@ def create_byte_array(builder: flatbuffers.Builder, prop: typing.Callable, val:
assert False, "type of val unexpected, expected: bytes or bytearray"

prop_name = prop.__name__
cls_name = prop.__module__.split('.')[-1]
cls_name = prop.__qualname__.split('.')[0]
func_name = f'{cls_name}Start{prop_name}Vector'
if not hasattr(tflite, func_name):
assert False, "invalid prop is given"
assert False, f"invalid prop is given, {prop.__qualname__}"

return builder.CreateByteVector(val)
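
All four helpers above recover the generated class name from the property they are handed. With the old `tflite` package each class lived in its own module, so `prop.__module__.split('.')[-1]` yielded the class name; in the single-file `schema_generated.py` every class shares one module, so the name must come from `prop.__qualname__` instead. A small illustration with a stand-in class:

```python
# `Tensor` is a stand-in for any class generated into schema_generated.py.
class Tensor:
    def Shape(self):
        ...

prop = Tensor.Shape
# Old layout (one class per module, e.g. tflite/Tensor.py):
#   prop.__module__ == 'tflite.Tensor'  ->  split('.')[-1] == 'Tensor'
# New layout (all classes in schema_generated.py):
#   prop.__module__ ends in 'schema_generated', not the class name,
#   but prop.__qualname__ == 'Tensor.Shape' still carries it.
cls_name = prop.__qualname__.split('.')[0]
func_name = f'{cls_name}Start{prop.__name__}Vector'
assert func_name == 'TensorStartShapeVector'
```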
