Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding integration tests package with ptq e2e test example and converting symmetric selection unit test to pytest #1347

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/run_keras_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ jobs:

- name: Run pytest
run: |
pytest tests_pytest/keras
pytest tests_pytest/unit_tests/keras
pytest tests_pytest/integration_tests/keras

2 changes: 1 addition & 1 deletion .github/workflows/run_pytorch_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ jobs:
python -m unittest discover tests/pytorch_tests -v
- name: Run pytest
run: |
pytest tests_pytest/pytorch
pytest tests_pytest/unit_tests/pytorch

7 changes: 4 additions & 3 deletions .github/workflows/run_tests_suite_coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ jobs:
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest discover tests/common_tests -v

- name: Run common tests (pytest)
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/common
run: |
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/unit_tests/common tests_pytest/integration_tests/common

- name: Set up TensorFlow environment
run: |
Expand All @@ -56,7 +57,7 @@ jobs:
- name: Run TensorFlow tests (pytest)
run: |
source tf_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/keras
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/unit_tests/keras tests_pytest/integration_tests/keras

- name: Set up PyTorch environment
run: |
Expand All @@ -74,7 +75,7 @@ jobs:
- name: Run PyTorch tests (pytest)
run: |
source torch_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/pytorch
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/unit_tests/pytorch tests_pytest/integration_tests/pytorch

- name: Combine Multiple Coverage Files
run: coverage combine
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/tests_common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ jobs:
run: python -m unittest discover tests/common_tests -v

- name: Run pytest
run: pytest tests_pytest/common
run: |
pytest tests_pytest/unit_tests/common

This file was deleted.

16 changes: 0 additions & 16 deletions tests/keras_tests/feature_networks_tests/test_features_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@
from tests.keras_tests.feature_networks_tests.feature_networks.softmax_shift_test import SoftmaxShiftTest
from tests.keras_tests.feature_networks_tests.feature_networks.split_concatenate_test import SplitConcatenateTest
from tests.keras_tests.feature_networks_tests.feature_networks.split_conv_bug_test import SplitConvBugTest
from tests.keras_tests.feature_networks_tests.feature_networks.symmetric_threshold_selection_activation_test import \
SymmetricThresholdSelectionActivationTest, SymmetricThresholdSelectionBoundedActivationTest
from tests.keras_tests.feature_networks_tests.feature_networks.test_depthwise_conv2d_replacement import \
DwConv2dReplacementTest
from tests.keras_tests.feature_networks_tests.feature_networks.test_kmeans_quantizer import \
Expand Down Expand Up @@ -754,20 +752,6 @@ def test_gptq_conv_group_dilation(self):
def test_split_conv_bug(self):
SplitConvBugTest(self).run_test()

def test_symmetric_threshold_selection_activation(self):
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.NOCLIPPING).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.MSE).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.MAE).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.LP).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.KL).run_test()

def test_symmetric_threshold_selection_softmax_activation(self):
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.NOCLIPPING).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.MSE).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.MAE).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.LP).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.KL).run_test()

def test_uniform_range_selection_activation(self):
UniformRangeSelectionActivationTest(self, QuantizationErrorMethod.NOCLIPPING).run_test()
UniformRangeSelectionActivationTest(self, QuantizationErrorMethod.MSE).run_test()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import keras
import pytest
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

from mct_quantizers import QuantizationMethod, KerasQuantizationWrapper
from mct_quantizers.keras.metadata import MetadataLayer
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.core.keras.constants import KERNEL, DEPTHWISE_KERNEL
from model_compression_toolkit.ptq import keras_post_training_quantization
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
AttributeQuantizationConfig, Signedness
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
from tests.common_tests.helpers.tpcs_for_tests.v4.tpc import generate_tpc

INPUT_SHAPE = (224, 224, 3)


@pytest.fixture
def rep_data_gen():
np.random.seed(42)

def reppresentative_dataset():
for _ in range(2):
yield [np.random.randn(2, *INPUT_SHAPE)]

return reppresentative_dataset


def model_basic():
inputs = layers.Input(shape=INPUT_SHAPE)
x = layers.Conv2D(2, 3, padding='same')(inputs)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
return tf.keras.models.Model(inputs=inputs, outputs=x)


def model_residual():
inputs = layers.Input(shape=INPUT_SHAPE)
x1 = layers.Conv2D(2, 3, padding='same')(inputs)
x1 = layers.ReLU()(x1)

x2 = layers.Conv2D(2, 3, padding='same')(x1)
x2 = layers.BatchNormalization()(x2)
x2 = layers.ReLU()(x2)

x = layers.Add()([x1, x2])

x = layers.Flatten()(x)
x = layers.Dense(units=10, activation='softmax')(x)

return keras.Model(inputs=inputs, outputs=x)


def set_tpc(weights_quantizer, per_channel):
# TODO: currently, running E2E test with IMX500 V4 TPC from tests package
# we need to select a default TPC for tests, which is the one we want to verify e2e for.

att_cfg_noquant = AttributeQuantizationConfig()
att_cfg_quant = AttributeQuantizationConfig(weights_quantization_method=weights_quantizer,
weights_n_bits=8,
weights_per_channel_threshold=per_channel,
enable_weights_quantization=True)

op_cfg = OpQuantizationConfig(default_weight_attr_config=att_cfg_quant,
attr_weights_configs_mapping={KERNEL_ATTR: att_cfg_quant,
BIAS_ATTR: att_cfg_noquant},
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
activation_n_bits=8,
supported_input_activation_n_bits=8,
enable_activation_quantization=False, # No activation quantization
quantization_preserving=False,
fixed_scale=None,
fixed_zero_point=None,
simd_size=32,
signedness=Signedness.AUTO)


tpc = generate_tpc(default_config=op_cfg, base_config=op_cfg, mixed_precision_cfg_list=[op_cfg], name="test_tpc")

return tpc


@pytest.fixture
def tpc_factory():
def _tpc_factory(quant_method, per_channel):
return set_tpc(quant_method, per_channel)
return _tpc_factory


def _verify_weights_quantizer_params(quant_method, weights_quantizer, params_shape, per_channel):
assert weights_quantizer.per_channel == per_channel
assert weights_quantizer.quantization_method[0] == quant_method

if quant_method == QuantizationMethod.POWER_OF_TWO:
assert len(weights_quantizer.threshold) == params_shape
for t in weights_quantizer.threshold:
assert np.log2(np.abs(t)).astype(int) == np.log2(np.abs(t))
elif quant_method == QuantizationMethod.SYMMETRIC:
assert len(weights_quantizer.threshold) == params_shape
elif quant_method == QuantizationMethod.UNIFORM:
assert len(weights_quantizer.min_range) == params_shape
assert len(weights_quantizer.max_range) == params_shape


class TestPostTrainingQuantizationApi:
# TODO: add tests for:
# 1) activation only, W&A, LUT quantizer (separate)
# 2) extend to also test with different settings features (bc, snc, etc.)
# 3) advanced models and operators


def _verify_quantized_model_structure(self, model, q_model, quantization_info):
assert q_model is not None and isinstance(q_model, keras.Model)
assert quantization_info is not None and isinstance(quantization_info, UserInformation)

# Assert quantized model structure
assert len([l for l in q_model.layers if isinstance(l, layers.BatchNormalization)]) == 0, \
"Expects BN folding in quantized model."
assert len([l for l in q_model.layers if isinstance(l, MetadataLayer)]) == 1, \
"Expects quantized model to have a metadata stored in a dedicated layer."
original_conv_layers = [l for l in model.layers if
isinstance(l, (layers.Conv2D, layers.DepthwiseConv2D, layers.Dense))]
quantized_conv_layers = [l for l in q_model.layers if isinstance(l, KerasQuantizationWrapper)]
assert len(original_conv_layers) == len(quantized_conv_layers), \
"Expects all conv layers from the original model to be wrapped with a KerasQuantizationWrapper."


@pytest.mark.parametrize("quant_method", [QuantizationMethod.POWER_OF_TWO,
QuantizationMethod.SYMMETRIC,
QuantizationMethod.UNIFORM])
@pytest.mark.parametrize("per_channel", [True, False])
@pytest.mark.parametrize("model", [model_basic(), model_residual()])
def test_ptq_pot_weights_only(self, model, rep_data_gen, tpc_factory, quant_method, per_channel):

tpc = tpc_factory(quant_method, per_channel)
q_model, quantization_info = keras_post_training_quantization(model, rep_data_gen,
target_platform_capabilities=tpc)

self._verify_quantized_model_structure(model, q_model, quantization_info)

# Assert quantization properties
quantized_conv_layers = [l for l in q_model.layers if isinstance(l, KerasQuantizationWrapper)]
for quantize_wrapper in quantized_conv_layers:
assert isinstance(quantize_wrapper.layer,
(layers.Conv2D, layers.DepthwiseConv2D, layers.Dense, layers.Conv2DTranspose))

if isinstance(quantize_wrapper.layer, layers.DepthwiseConv2D):
weights_quantizer = quantize_wrapper.weights_quantizers[DEPTHWISE_KERNEL]
num_output_channels = (quantize_wrapper.layer.depthwise_kernel.shape[-1]
* quantize_wrapper.layer.depthwise_kernel.shape[-2])
else:
weights_quantizer = quantize_wrapper.weights_quantizers[KERNEL]
num_output_channels = quantize_wrapper.layer.kernel.shape[-1]

params_shape = num_output_channels if per_channel else 1
_verify_weights_quantizer_params(quant_method, weights_quantizer, params_shape, per_channel)





Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Loading