Skip to content

Commit

Permalink
Merge branch 'main' into logger_update
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Mar 20, 2024
2 parents eea3aff + 1409fdc commit 53cc93f
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 16 deletions.
8 changes: 7 additions & 1 deletion model_compression_toolkit/core/pytorch/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.pytorch.reader.graph_builders import edges_builder, nodes_builder
from model_compression_toolkit.core.pytorch.utils import set_model
Expand Down Expand Up @@ -84,7 +85,12 @@ def fx_graph_module_generation(pytorch_model: torch.nn.Module,
A fx.GraphModule (static model) representing the Pytorch model.
"""
set_model(pytorch_model)
symbolic_traced = symbolic_trace(pytorch_model)

try:
symbolic_traced = symbolic_trace(pytorch_model)
except torch.fx.proxy.TraceError as e:
Logger.error(f'Error parsing model with torch.fx\n'
f'fx error: {e}')
inputs = next(representative_data_gen())
input_for_shape_infer = [to_tensor(i) for i in inputs]
ShapeProp(symbolic_traced).propagate(*input_for_shape_infer)
Expand Down
14 changes: 14 additions & 0 deletions tests/pytorch_tests/graph_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
53 changes: 53 additions & 0 deletions tests/pytorch_tests/graph_tests/test_fx_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import unittest
import torch
import numpy as np

from model_compression_toolkit.core.pytorch.reader.reader import fx_graph_module_generation
from model_compression_toolkit.core.pytorch.pytorch_implementation import to_torch_tensor


class BadFxModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.relu = torch.nn.ReLU()

def forward(self, inputs, flag=False):
x = self.conv(inputs)
if flag:
x = self.relu(x)
else:
x = self.relu(x) + x
return x


class TestGraphReading(unittest.TestCase):

def test_graph_reading(self):
model = BadFxModel()
try:
graph = fx_graph_module_generation(model,
lambda : np.zeros((1, 3, 20, 20)),
to_torch_tensor)
except Exception as e:
self.assertEqual(str(e).split('\n')[0], 'Error parsing model with torch.fx')


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions tests/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
TestGPTQModelBuilderWithActivationHolder as TestGPTQModelBuilderWithActivationHolderPytorch
from tests.pytorch_tests.exporter_tests.test_runner import PytorchExporterTestsRunner
from tests.data_generation_tests.pytorch.test_pytorch_data_generation_runner import PytorchDataGenerationTestRunner
from tests.pytorch_tests.graph_tests.test_fx_errors import TestGraphReading


if __name__ == '__main__':
Expand Down Expand Up @@ -160,6 +161,7 @@
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(PytorchTrainableInfrastructureTestRunner))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(PytorchExporterTestsRunner))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(PytorchDataGenerationTestRunner))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestGraphReading))

# ---------------- Join them together and run them
comboSuite = unittest.TestSuite(suiteList)
Expand Down
48 changes: 33 additions & 15 deletions tutorials/notebooks/keras/export/example_keras_export.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@
"cell_type": "code",
"source": [
"import numpy as np\n",
"from keras.applications import ResNet50\n",
"from keras.applications import MobileNetV2\n",
"import model_compression_toolkit as mct\n",
"\n",
"# Create a model\n",
"float_model = ResNet50()\n",
"float_model = MobileNetV2()\n",
"# Quantize the model.\n",
"# Notice that here the representative dataset is random for demonstration only.\n",
"quantized_exportable_model, _ = mct.ptq.keras_post_training_quantization(float_model,\n",
Expand Down Expand Up @@ -87,10 +87,8 @@
{
"cell_type": "code",
"source": [
"import tempfile\n",
"\n",
"# Path of exported model\n",
"_, keras_file_path = tempfile.mkstemp('.keras')\n",
"keras_file_path = 'exported_model_mctq.keras'\n",
"\n",
"# Export a keras model with mctq custom quantizers.\n",
"mct.exporter.keras_export_model(model=quantized_exportable_model,\n",
Expand All @@ -107,17 +105,40 @@
"source": [
"Notice that the model has the same size as the quantized exportable model as weights data types are float.\n",
"\n",
"#### Fakely-Quantized"
"#### MCTQ - Loading Exported Model\n",
"\n",
"To load the exported model with MCTQ quantizers, use `mct.keras_load_quantized_model`:"
],
"metadata": {
"id": "Bwx5rxXDF_gb"
}
},
{
"cell_type": "code",
"source": [
"loaded_model = mct.keras_load_quantized_model(keras_file_path)"
],
"metadata": {
"id": "q235XNJQmTdd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"\n",
"#### Fakely-Quantized"
],
"metadata": {
"id": "sOmDjSehlQba"
}
},
{
"cell_type": "code",
"source": [
"# Path of exported model\n",
"_, keras_file_path = tempfile.mkstemp('.keras')\n",
"keras_file_path = 'exported_model_fakequant.keras'\n",
"\n",
"# Use mode KerasExportSerializationFormat.KERAS for a .keras model\n",
"# and QuantizationFormat.FAKELY_QUANT for fakely-quantized weights\n",
Expand Down Expand Up @@ -154,10 +175,7 @@
{
"cell_type": "code",
"source": [
"import tempfile\n",
"\n",
"# Path of exported model\n",
"_, tflite_file_path = tempfile.mkstemp('.tflite')\n",
"tflite_file_path = 'exported_model_int8.tflite'\n",
"\n",
"# Use mode KerasExportSerializationFormat.TFLITE for tflite model and quantization_format.INT8.\n",
"mct.exporter.keras_export_model(model=quantized_exportable_model,\n",
Expand Down Expand Up @@ -186,12 +204,11 @@
"import os\n",
"\n",
"# Save float model to measure its size\n",
"_, float_file_path = tempfile.mkstemp('.keras')\n",
"float_file_path = 'exported_model_float.keras'\n",
"float_model.save(float_file_path)\n",
"\n",
"print(\"Float model in Mb:\", os.path.getsize(float_file_path) / float(2 ** 20))\n",
"print(\"Quantized model in Mb:\", os.path.getsize(tflite_file_path) / float(2 ** 20))\n",
"print(f'Compression ratio: {os.path.getsize(float_file_path) / os.path.getsize(tflite_file_path)}')"
"print(\"Quantized model in Mb:\", os.path.getsize(tflite_file_path) / float(2 ** 20))"
],
"metadata": {
"id": "LInM16OMGUtF"
Expand All @@ -217,7 +234,8 @@
"cell_type": "code",
"source": [
"# Path of exported model\n",
"_, tflite_file_path = tempfile.mkstemp('.tflite')\n",
"tflite_file_path = 'exported_model_fakequant.tflite'\n",
"\n",
"\n",
"# Use mode KerasExportSerializationFormat.TFLITE for tflite model and QuantizationFormat.FAKELY_QUANT for fakely-quantized weights\n",
"# and activations.\n",
Expand Down

0 comments on commit 53cc93f

Please sign in to comment.