From 37c2e1e0446e8e2f92700eb38614b6107ac9bf8d Mon Sep 17 00:00:00 2001 From: reuvenp Date: Tue, 14 Jan 2025 19:05:10 +0200 Subject: [PATCH] fix tpc creation after refactor --- .../keras/example_keras_pruning_mnist.ipynb | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb b/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb index b1b594c81..8bcde15da 100644 --- a/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb +++ b/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb @@ -217,22 +217,23 @@ { "cell_type": "code", "source": [ - "from model_compression_toolkit.target_platform_capabilities.target_platform import Signedness\n", - "tp = mct.target_platform\n", + "from mct_quantizers import QuantizationMethod\n", + "from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema, TargetPlatformCapabilities, Signedness, \\\n", + " AttributeQuantizationConfig, OpQuantizationConfig\n", "\n", "simd_size = 1\n", "\n", "def get_tpc():\n", " # Define the default weight attribute configuration\n", - " default_weight_attr_config = tp.AttributeQuantizationConfig(\n", - " weights_quantization_method=tp.QuantizationMethod.UNIFORM,\n", + " default_weight_attr_config = AttributeQuantizationConfig(\n", + " weights_quantization_method=QuantizationMethod.UNIFORM,\n", " )\n", "\n", " # Define the OpQuantizationConfig\n", - " default_config = tp.OpQuantizationConfig(\n", + " default_config = OpQuantizationConfig(\n", " default_weight_attr_config=default_weight_attr_config,\n", " attr_weights_configs_mapping={},\n", - " activation_quantization_method=tp.QuantizationMethod.UNIFORM,\n", + " activation_quantization_method=QuantizationMethod.UNIFORM,\n", " activation_n_bits=8,\n", " supported_input_activation_n_bits=8,\n", " enable_activation_quantization=False,\n", @@ -240,18 +241,19 @@ " fixed_scale=None,\n", " fixed_zero_point=None,\n", " simd_size=simd_size,\n", - " signedness=Signedness.AUTO\n", + " signedness=schema.Signedness.AUTO\n", " )\n", + " \n", + " # In this tutorial, we will use the default OpQuantizationConfig for all operator sets.\n", + " operator_set=[]\n", "\n", " # Create the quantization configuration options and model\n", - " default_configuration_options = tp.QuantizationConfigOptions([default_config])\n", - " tp_model = tp.TargetPlatformCapabilities(default_configuration_options,\n", - " tpc_minor_version=1,\n", - " tpc_patch_version=0,\n", - " tpc_platform_type=\"custom_pruning_notebook_tpc\")\n", - "\n", - " # Return the target platform capabilities\n", - " tpc = tp.FrameworkQuantizationCapabilities(tp_model)\n", + " default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config]))\n", + " tpc = TargetPlatformCapabilities(default_qco=default_configuration_options,\n", + " tpc_minor_version=1,\n", + " tpc_patch_version=0,\n", + " tpc_platform_type=\"custom_pruning_notebook_tpc\",\n", + " operator_set=tuple(operator_set))\n", " return tpc\n" ], "metadata": {