diff --git a/examples/example_mnist_prune.py b/examples/example_mnist_prune.py
new file mode 100644
index 00000000..fc88e45a
--- /dev/null
+++ b/examples/example_mnist_prune.py
@@ -0,0 +1,206 @@
+# Copyright 2019 Google LLC
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Example of an MNIST model with pruning.
+
+Adapted from the TF model optimization example.
+"""
+
+import tempfile
+import numpy as np
+
+import tensorflow.keras.backend as K
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.layers import Activation
+from tensorflow.keras.layers import Flatten
+from tensorflow.keras.layers import Input
+from tensorflow.keras.models import Model
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.models import save_model
+from tensorflow.keras.utils import to_categorical
+
+from qkeras import QActivation
+from qkeras import QDense
+from qkeras import QConv2D
+from qkeras import quantized_bits
+from qkeras.utils import load_qmodel
+from qkeras.utils import print_model_sparsity
+
+from tensorflow_model_optimization.python.core.sparsity.keras import prune
+from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
+from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
+
+
+batch_size = 128
+num_classes = 10
+epochs = 12
+
+prune_whole_model = True  # Prune the whole model or just the specified layers.
+
+
+def build_model(input_shape):
+  x = x_in = Input(shape=input_shape, name="input")
+  x = QConv2D(
+      32, (2, 2), strides=(2, 2),
+      kernel_quantizer=quantized_bits(4, 0, 1),
+      bias_quantizer=quantized_bits(4, 0, 1),
+      name="conv2d_0_m")(x)
+  x = QActivation("quantized_relu(4,0)", name="act0_m")(x)
+  x = QConv2D(
+      64, (3, 3), strides=(2, 2),
+      kernel_quantizer=quantized_bits(4, 0, 1),
+      bias_quantizer=quantized_bits(4, 0, 1),
+      name="conv2d_1_m")(x)
+  x = QActivation("quantized_relu(4,0)", name="act1_m")(x)
+  x = QConv2D(
+      64, (2, 2), strides=(2, 2),
+      kernel_quantizer=quantized_bits(4, 0, 1),
+      bias_quantizer=quantized_bits(4, 0, 1),
+      name="conv2d_2_m")(x)
+  x = QActivation("quantized_relu(4,0)", name="act2_m")(x)
+  x = Flatten()(x)
+  x = QDense(num_classes, kernel_quantizer=quantized_bits(4, 0, 1),
+             bias_quantizer=quantized_bits(4, 0, 1),
+             name="dense")(x)
+  x = Activation("softmax", name="softmax")(x)
+
+  model = Model(inputs=[x_in], outputs=[x])
+  return model
+
+
+def build_layerwise_model(input_shape, **pruning_params):
+  return Sequential([
+      prune.prune_low_magnitude(
+          QConv2D(
+              32, (2, 2), strides=(2, 2),
+              kernel_quantizer=quantized_bits(4, 0, 1),
+              bias_quantizer=quantized_bits(4, 0, 1),
+              name="conv2d_0_m"),
+          input_shape=input_shape,
+          **pruning_params),
+      QActivation("quantized_relu(4,0)", name="act0_m"),
+      prune.prune_low_magnitude(
+          QConv2D(
+              64, (3, 3), strides=(2, 2),
+              kernel_quantizer=quantized_bits(4, 0, 1),
+              bias_quantizer=quantized_bits(4, 0, 1),
+              name="conv2d_1_m"),
+          **pruning_params),
+      QActivation("quantized_relu(4,0)", name="act1_m"),
+      prune.prune_low_magnitude(
+          QConv2D(
+              64, (2, 2), strides=(2, 2),
+              kernel_quantizer=quantized_bits(4, 0, 1),
+              bias_quantizer=quantized_bits(4, 0, 1),
+              name="conv2d_2_m"),
+          **pruning_params),
+      QActivation("quantized_relu(4,0)", name="act2_m"),
+      Flatten(),
+      prune.prune_low_magnitude(
+          QDense(
+              num_classes, kernel_quantizer=quantized_bits(4, 0, 1),
+              bias_quantizer=quantized_bits(4, 0, 1),
+              name="dense"),
+          **pruning_params),
+      Activation("softmax", name="softmax")
+  ])
+
+
+def train_and_save(model, x_train, y_train, x_test, y_test):
+  model.compile(
+      loss="categorical_crossentropy",
+      optimizer="adam",
+      metrics=["accuracy"])
+
+  # Print the model summary.
+  model.summary()
+
+  # Add a pruning step callback to peg the pruning step to the optimizer's
+  # step. Also add a callback to write pruning summaries to TensorBoard.
+  callbacks = [
+      pruning_callbacks.UpdatePruningStep(),
+      # pruning_callbacks.PruningSummaries(log_dir=tempfile.mkdtemp())
+      pruning_callbacks.PruningSummaries(log_dir="/tmp/mnist_prune")
+  ]
+
+  model.fit(
+      x_train,
+      y_train,
+      batch_size=batch_size,
+      epochs=epochs,
+      verbose=1,
+      callbacks=callbacks,
+      validation_data=(x_test, y_test))
+  score = model.evaluate(x_test, y_test, verbose=0)
+  print("Test loss:", score[0])
+  print("Test accuracy:", score[1])
+
+  print_model_sparsity(model)
+
+  # Export and import the model. Check that accuracy persists.
+  _, keras_file = tempfile.mkstemp(".h5")
+  print("Saving model to:", keras_file)
+  save_model(model, keras_file)
+
+  print("Reloading model")
+  with prune.prune_scope():
+    loaded_model = load_qmodel(keras_file)
+  score = loaded_model.evaluate(x_test, y_test, verbose=0)
+  print("Test loss:", score[0])
+  print("Test accuracy:", score[1])
+
+
+def main():
+  # Input image dimensions.
+  img_rows, img_cols = 28, 28
+
+  # The data, shuffled and split between train and test sets.
+  (x_train, y_train), (x_test, y_test) = mnist.load_data()
+
+  if K.image_data_format() == "channels_first":
+    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
+    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
+    input_shape = (1, img_rows, img_cols)
+  else:
+    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
+    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
+    input_shape = (img_rows, img_cols, 1)
+
+  x_train = x_train.astype("float32")
+  x_test = x_test.astype("float32")
+  x_train /= 255
+  x_test /= 255
+  print("x_train shape:", x_train.shape)
+  print(x_train.shape[0], "train samples")
+  print(x_test.shape[0], "test samples")
+
+  # Convert class vectors to binary class matrices.
+  y_train = to_categorical(y_train, num_classes)
+  y_test = to_categorical(y_test, num_classes)
+
+  pruning_params = {
+      "pruning_schedule":
+          pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100)
+  }
+
+  if prune_whole_model:
+    model = build_model(input_shape)
+    model = prune.prune_low_magnitude(model, **pruning_params)
+  else:
+    model = build_layerwise_model(input_shape, **pruning_params)
+
+  train_and_save(model, x_train, y_train, x_test, y_test)
+
+
+if __name__ == "__main__":
+  main()
\ No newline at end of file
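A minimal sketch (not part of this diff) of a follow-on step the example stops short of: once training finishes, the `PruneLowMagnitude` wrappers can be removed so the exported file contains only plain QKeras layers. This assumes `prune.strip_pruning` from the same tfmot module the example already imports; the helper name `export_stripped` is hypothetical.

```python
from tensorflow_model_optimization.python.core.sparsity.keras import prune


def export_stripped(model, path):
  # Hypothetical helper: unwrap the PruneLowMagnitude layers. The pruning
  # masks have already been applied to the kernels, so the saved weights
  # keep their sparsity without the wrapper bookkeeping.
  stripped = prune.strip_pruning(model)
  stripped.save(path)
```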
name="act1_m"), + prune.prune_low_magnitude( + QConv2D( + 64, (2, 2), strides=(2,2), + kernel_quantizer=quantized_bits(4,0,1), + bias_quantizer=quantized_bits(4,0,1), + name="conv2d_2_m"), + **pruning_params), + QActivation("quantized_relu(4,0)", name="act2_m"), + Flatten(), + prune.prune_low_magnitude( + QDense( + num_classes, kernel_quantizer=quantized_bits(4,0,1), + bias_quantizer=quantized_bits(4,0,1), + name="dense"), + **pruning_params), + Activation("softmax", name="softmax") + ]) + + +def train_and_save(model, x_train, y_train, x_test, y_test): + model.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"]) + + # Print the model summary. + model.summary() + + # Add a pruning step callback to peg the pruning step to the optimizer's + # step. Also add a callback to add pruning summaries to tensorboard + callbacks = [ + pruning_callbacks.UpdatePruningStep(), + #pruning_callbacks.PruningSummaries(log_dir=tempfile.mkdtemp()) + pruning_callbacks.PruningSummaries(log_dir="/tmp/mnist_prune") + ] + + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs, + verbose=1, + callbacks=callbacks, + validation_data=(x_test, y_test)) + score = model.evaluate(x_test, y_test, verbose=0) + print("Test loss:", score[0]) + print("Test accuracy:", score[1]) + + print_model_sparsity(model) + + # Export and import the model. Check that accuracy persists. + _, keras_file = tempfile.mkstemp(".h5") + print("Saving model to: ", keras_file) + save_model(model, keras_file) + + print("Reloading model") + with prune.prune_scope(): + loaded_model = load_qmodel(keras_file) + score = loaded_model.evaluate(x_test, y_test, verbose=0) + print("Test loss:", score[0]) + print("Test accuracy:", score[1]) + + +def main(): + # input image dimensions + img_rows, img_cols = 28, 28 + + # the data, shuffled and split between train and test sets + (x_train, y_train), (x_test, y_test) = mnist.load_data() + + if K.image_data_format() == "channels_first": + x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) + x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) + else: + x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) + x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + + x_train = x_train.astype("float32") + x_test = x_test.astype("float32") + x_train /= 255 + x_test /= 255 + print("x_train shape:", x_train.shape) + print(x_train.shape[0], "train samples") + print(x_test.shape[0], "test samples") + + # convert class vectors to binary class matrices + y_train = to_categorical(y_train, num_classes) + y_test = to_categorical(y_test, num_classes) + + pruning_params = { + "pruning_schedule": + pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100) + } + + if prune_whole_model: + model = build_model(input_shape) + model = prune.prune_low_magnitude(model, **pruning_params) + else: + model = build_layerwise_model(input_shape, **pruning_params) + + train_and_save(model, x_train, y_train, x_test, y_test) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/qkeras/qconvolutional.py b/qkeras/qconvolutional.py index 0e1f11e7..337b8d79 100644 --- a/qkeras/qconvolutional.py +++ b/qkeras/qconvolutional.py @@ -28,6 +28,7 @@ from tensorflow.keras.layers import Dropout from tensorflow.keras.layers import InputSpec from tensorflow.keras.layers import Layer +from 
diff --git a/qkeras/qlayers.py b/qkeras/qlayers.py
index 7a17570b..59fa709d 100644
--- a/qkeras/qlayers.py
+++ b/qkeras/qlayers.py
@@ -45,6 +45,7 @@
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Dense
 from tensorflow.keras.layers import Layer
+from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer
 
 import numpy as np
 
@@ -60,7 +61,7 @@
 #
 
 
-class QActivation(Layer):
+class QActivation(Layer, PrunableLayer):
   """Implements quantized activation layers."""
 
   def __init__(self, activation, **kwargs):
@@ -97,6 +98,9 @@ def get_config(self):
   def compute_output_shape(self, input_shape):
     return input_shape
 
+  def get_prunable_weights(self):
+    return []
+
 
 #
 # Constraint class to clip weights and bias between -1 and 1 so that:
@@ -149,7 +153,7 @@ def get_config(self):
 #
 
 
-class QDense(Dense):
+class QDense(Dense, PrunableLayer):
   """Implements a quantized Dense layer."""
 
   # most of these parameters follow the implementation of Dense in
@@ -284,3 +288,7 @@ def get_config(self):
 
   def get_quantizers(self):
     return self.quantizers
+
+  def get_prunable_weights(self):
+    return [self.kernel]
+
diff --git a/qkeras/qnormalization.py b/qkeras/qnormalization.py
index c6663dd6..245edf2f 100644
--- a/qkeras/qnormalization.py
+++ b/qkeras/qnormalization.py
@@ -31,6 +31,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
+from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer
 
 import numpy as np
 import six
@@ -40,7 +41,7 @@
 from .safe_eval import safe_eval
 
 
-class QBatchNormalization(BatchNormalization):
+class QBatchNormalization(BatchNormalization, PrunableLayer):
   """Quantized Batch Normalization layer.
   For training, mean and variance are not quantized.
   For inference, the quantized moving mean and moving variance are used.
@@ -302,3 +303,7 @@ def compute_output_shape(self, input_shape):
 
   def get_quantizers(self):
     return self.quantizers
+
+  def get_prunable_weights(self):
+    return []
+
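`QActivation` and `QBatchNormalization` return `[]` from `get_prunable_weights()`: they implement the interface so that wrapping an entire model with `prune_low_magnitude` does not fail on them, while exposing nothing to mask. A small sketch of that behavior, assuming QKeras with this diff applied and tfmot installed:

```python
from qkeras import QActivation
from tensorflow_model_optimization.python.core.sparsity.keras import prune

act = QActivation("quantized_relu(4,0)")
wrapped = prune.prune_low_magnitude(act)  # accepted; nothing will be masked
print(act.get_prunable_weights())         # -> []
```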
diff --git a/qkeras/utils.py b/qkeras/utils.py
index 49da7947..3596c29d 100644
--- a/qkeras/utils.py
+++ b/qkeras/utils.py
@@ -18,10 +18,15 @@
 import six
 import tensorflow as tf
+import tensorflow.keras.backend as K
 from tensorflow.keras import initializers
 from tensorflow.keras.models import model_from_json
+from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
+from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
+from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
+
 import numpy as np
 
 from .qlayers import QActivation
@@ -472,3 +477,31 @@ def load_qmodel(filepath, custom_objects=None, compile=True):
   qmodel = tf.keras.models.load_model(filepath, custom_objects=custom_objects,
                                       compile=compile)
   return qmodel
+
+
+def print_model_sparsity(model):
+  """Prints sparsity for the pruned layers in the model."""
+
+  def _get_sparsity(weights):
+    return 1.0 - np.count_nonzero(weights) / float(weights.size)
+
+  print("Model Sparsity Summary ({})".format(model.name))
+  print("--")
+  for layer in model.layers:
+    if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
+      prunable_weights = layer.layer.get_prunable_weights()
+    elif isinstance(layer, prunable_layer.PrunableLayer):
+      prunable_weights = layer.get_prunable_weights()
+    elif prune_registry.PruneRegistry.supports(layer):
+      weight_names = prune_registry.PruneRegistry._weight_names(layer)
+      prunable_weights = [getattr(layer, weight) for weight in weight_names]
+    else:
+      prunable_weights = None
+    if prunable_weights:
+      print("{}: {}".format(
+          layer.name, ", ".join([
+              "({}, {})".format(weight.name,
+                                str(_get_sparsity(K.get_value(weight))))
+              for weight in prunable_weights
+          ])))
+  print("\n")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 8979d979..538b777c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ pyasn1<0.5.0,>=0.4.6
 requests<3,>=2.21.0
 pyparsing
 pytest>=4.6.9
+tensorflow-model-optimization>=0.2.1
\ No newline at end of file
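A usage sketch for the new `print_model_sparsity` helper. The layer and weight names below are illustrative, not captured from a real run; the output shape follows the `print` calls in the function above.

```python
from qkeras.utils import print_model_sparsity

# model: a pruned model that has been trained with UpdatePruningStep().
print_model_sparsity(model)
# Illustrative output, one line per prunable layer:
#   Model Sparsity Summary (model)
#   --
#   prune_low_magnitude_conv2d_0_m: (conv2d_0_m/kernel:0, 0.75)
#   ...
```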