Merge pull request #18 from vloncar:pruning
PiperOrigin-RevId: 297473595
Change-Id: I300a5d241523d0ea09f4d2004e64f94dde6748d3
copybara-github committed Feb 27, 2020
2 parents 8f17e33 + 3fe9a31 commit bc70882
Showing 6 changed files with 269 additions and 6 deletions.
206 changes: 206 additions & 0 deletions examples/example_mnist_prune.py
@@ -0,0 +1,206 @@
# Copyright 2019 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example of mnist model with pruning.
Adapted from TF model optimization example."""

import tempfile
import numpy as np

import tensorflow.keras.backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.models import save_model
from tensorflow.keras.utils import to_categorical

from qkeras import QActivation
from qkeras import QDense
from qkeras import QConv2D
from qkeras import quantized_bits
from qkeras.utils import load_qmodel
from qkeras.utils import print_model_sparsity

from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule


batch_size = 128
num_classes = 10
epochs = 12

prune_whole_model = True # Prune whole model or just specified layers


def build_model(input_shape):
  x = x_in = Input(shape=input_shape, name="input")
  x = QConv2D(
      32, (2, 2), strides=(2,2),
      kernel_quantizer=quantized_bits(4,0,1),
      bias_quantizer=quantized_bits(4,0,1),
      name="conv2d_0_m")(x)
  x = QActivation("quantized_relu(4,0)", name="act0_m")(x)
  x = QConv2D(
      64, (3, 3), strides=(2,2),
      kernel_quantizer=quantized_bits(4,0,1),
      bias_quantizer=quantized_bits(4,0,1),
      name="conv2d_1_m")(x)
  x = QActivation("quantized_relu(4,0)", name="act1_m")(x)
  x = QConv2D(
      64, (2, 2), strides=(2,2),
      kernel_quantizer=quantized_bits(4,0,1),
      bias_quantizer=quantized_bits(4,0,1),
      name="conv2d_2_m")(x)
  x = QActivation("quantized_relu(4,0)", name="act2_m")(x)
  x = Flatten()(x)
  x = QDense(num_classes, kernel_quantizer=quantized_bits(4,0,1),
             bias_quantizer=quantized_bits(4,0,1),
             name="dense")(x)
  x = Activation("softmax", name="softmax")(x)

  model = Model(inputs=[x_in], outputs=[x])
  return model


def build_layerwise_model(input_shape, **pruning_params):
  return Sequential([
      prune.prune_low_magnitude(
          QConv2D(
              32, (2, 2), strides=(2,2),
              kernel_quantizer=quantized_bits(4,0,1),
              bias_quantizer=quantized_bits(4,0,1),
              name="conv2d_0_m"),
          input_shape=input_shape,
          **pruning_params),
      QActivation("quantized_relu(4,0)", name="act0_m"),
      prune.prune_low_magnitude(
          QConv2D(
              64, (3, 3), strides=(2,2),
              kernel_quantizer=quantized_bits(4,0,1),
              bias_quantizer=quantized_bits(4,0,1),
              name="conv2d_1_m"),
          **pruning_params),
      QActivation("quantized_relu(4,0)", name="act1_m"),
      prune.prune_low_magnitude(
          QConv2D(
              64, (2, 2), strides=(2,2),
              kernel_quantizer=quantized_bits(4,0,1),
              bias_quantizer=quantized_bits(4,0,1),
              name="conv2d_2_m"),
          **pruning_params),
      QActivation("quantized_relu(4,0)", name="act2_m"),
      Flatten(),
      prune.prune_low_magnitude(
          QDense(
              num_classes, kernel_quantizer=quantized_bits(4,0,1),
              bias_quantizer=quantized_bits(4,0,1),
              name="dense"),
          **pruning_params),
      Activation("softmax", name="softmax")
  ])


def train_and_save(model, x_train, y_train, x_test, y_test):
  model.compile(
      loss="categorical_crossentropy",
      optimizer="adam",
      metrics=["accuracy"])

  # Print the model summary.
  model.summary()

  # Add a pruning step callback to peg the pruning step to the optimizer's
  # step. Also add a callback to add pruning summaries to tensorboard
  callbacks = [
      pruning_callbacks.UpdatePruningStep(),
      #pruning_callbacks.PruningSummaries(log_dir=tempfile.mkdtemp())
      pruning_callbacks.PruningSummaries(log_dir="/tmp/mnist_prune")
  ]

  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      verbose=1,
      callbacks=callbacks,
      validation_data=(x_test, y_test))
  score = model.evaluate(x_test, y_test, verbose=0)
  print("Test loss:", score[0])
  print("Test accuracy:", score[1])

  print_model_sparsity(model)

  # Export and import the model. Check that accuracy persists.
  _, keras_file = tempfile.mkstemp(".h5")
  print("Saving model to: ", keras_file)
  save_model(model, keras_file)

  print("Reloading model")
  with prune.prune_scope():
    loaded_model = load_qmodel(keras_file)
  score = loaded_model.evaluate(x_test, y_test, verbose=0)
  print("Test loss:", score[0])
  print("Test accuracy:", score[1])


def main():
  # input image dimensions
  img_rows, img_cols = 28, 28

  # the data, shuffled and split between train and test sets
  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  if K.image_data_format() == "channels_first":
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
  else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

  x_train = x_train.astype("float32")
  x_test = x_test.astype("float32")
  x_train /= 255
  x_test /= 255
  print("x_train shape:", x_train.shape)
  print(x_train.shape[0], "train samples")
  print(x_test.shape[0], "test samples")

  # convert class vectors to binary class matrices
  y_train = to_categorical(y_train, num_classes)
  y_test = to_categorical(y_test, num_classes)

  pruning_params = {
      "pruning_schedule":
          pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100)
  }

  if prune_whole_model:
    model = build_model(input_shape)
    model = prune.prune_low_magnitude(model, **pruning_params)
  else:
    model = build_layerwise_model(input_shape, **pruning_params)

  train_and_save(model, x_train, y_train, x_test, y_test)


if __name__ == "__main__":
  main()
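
Note (not part of the diff): the .h5 written above still contains the tfmot PruneLowMagnitude wrappers, which is why the reload happens inside prune.prune_scope(). A minimal sketch, assuming tfmot's strip_pruning helper, of removing the wrappers before shipping the model; the output filename is illustrative only:

# Sketch only -- hypothetical follow-up step, not in this commit.
# strip_pruning removes the PruneLowMagnitude wrappers and keeps the
# (already zeroed-out) weights in the underlying QKeras layers.
from tensorflow_model_optimization.python.core.sparsity.keras import prune

stripped = prune.strip_pruning(model)
save_model(stripped, "mnist_pruned_stripped.h5")  # hypothetical filename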
16 changes: 13 additions & 3 deletions qkeras/qconvolutional.py
@@ -28,14 +28,15 @@
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import InputSpec
from tensorflow.keras.layers import Layer
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer

from .qlayers import Clip
from .qlayers import QActivation
from .quantizers import get_quantizer
from .quantizers import get_quantized_initializer


class QConv1D(Conv1D):
class QConv1D(Conv1D, PrunableLayer):
"""1D convolution layer (e.g. spatial convolution over images)."""

# most of these parameters follow the implementation of Conv1D in Keras,
@@ -155,8 +156,11 @@ def get_config(self):
  def get_quantizers(self):
    return self.quantizers

  def get_prunable_weights(self):
    return [self.kernel]

class QConv2D(Conv2D):

class QConv2D(Conv2D, PrunableLayer):
"""2D convolution layer (e.g. spatial convolution over images)."""

# most of these parameters follow the implementation of Conv2D in Keras,
@@ -284,8 +288,11 @@ def get_config(self):
  def get_quantizers(self):
    return self.quantizers

  def get_prunable_weights(self):
    return [self.kernel]


class QDepthwiseConv2D(DepthwiseConv2D):
class QDepthwiseConv2D(DepthwiseConv2D, PrunableLayer):
"""Creates quantized depthwise conv2d. Copied from mobilenet."""

# most of these parameters follow the implementation of DepthwiseConv2D
@@ -457,6 +464,9 @@ def get_config(self):
  def get_quantizers(self):
    return self.quantizers

  def get_prunable_weights(self):
    return []


def QSeparableConv2D(filters, # pylint: disable=invalid-name
kernel_size,
12 changes: 10 additions & 2 deletions qkeras/qlayers.py
@@ -45,6 +45,7 @@
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Layer
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer


import numpy as np
@@ -60,7 +61,7 @@
#


class QActivation(Layer):
class QActivation(Layer, PrunableLayer):
"""Implements quantized activation layers."""

def __init__(self, activation, **kwargs):
@@ -97,6 +98,9 @@ def get_config(self):
  def compute_output_shape(self, input_shape):
    return input_shape

  def get_prunable_weights(self):
    return []


#
# Constraint class to clip weights and bias between -1 and 1 so that:
@@ -149,7 +153,7 @@ def get_config(self):
#


class QDense(Dense):
class QDense(Dense, PrunableLayer):
"""Implements a quantized Dense layer."""

# most of these parameters follow the implementation of Dense in
@@ -284,3 +288,7 @@ def get_config(self):

  def get_quantizers(self):
    return self.quantizers

  def get_prunable_weights(self):
    return [self.kernel]

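The get_prunable_weights() hooks added above are what the tfmot pruning wrapper queries when a layer is wrapped. A minimal sketch of the pattern for a standalone QDense (illustrative values, not taken from the diff):

# Sketch: QDense now implements PrunableLayer, so it can be wrapped directly.
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from qkeras import QDense, quantized_bits

layer = QDense(10,
               kernel_quantizer=quantized_bits(4, 0, 1),
               bias_quantizer=quantized_bits(4, 0, 1),
               name="dense")
wrapped = prune.prune_low_magnitude(layer, input_shape=(16,))
# Only layer.kernel -- the tensor returned by get_prunable_weights() -- is
# masked; the bias stays dense. Layers that return [] (e.g. QActivation)
# pass through without any weight being pruned.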
7 changes: 6 additions & 1 deletion qkeras/qnormalization.py
@@ -31,6 +31,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer

import numpy as np
import six
@@ -40,7 +41,7 @@
from .safe_eval import safe_eval


class QBatchNormalization(BatchNormalization):
class QBatchNormalization(BatchNormalization, PrunableLayer):
"""Quantized Batch Normalization layer.
For training, mean and variance are not quantized.
For inference, the quantized moving mean and moving variance are used.
@@ -302,3 +303,7 @@ def compute_output_shape(self, input_shape):

  def get_quantizers(self):
    return self.quantizers

  def get_prunable_weights(self):
    return []

33 changes: 33 additions & 0 deletions qkeras/utils.py
@@ -18,10 +18,15 @@
import six

import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras import initializers
from tensorflow.keras.models import model_from_json

from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer

import numpy as np

from .qlayers import QActivation
@@ -472,3 +477,31 @@ def load_qmodel(filepath, custom_objects=None, compile=True):
  qmodel = tf.keras.models.load_model(filepath, custom_objects=custom_objects, compile=compile)

  return qmodel


def print_model_sparsity(model):
  """Prints sparsity for the pruned layers in the model."""

  def _get_sparsity(weights):
    return 1.0 - np.count_nonzero(weights) / float(weights.size)

  print("Model Sparsity Summary ({})".format(model.name))
  print("--")
  for layer in model.layers:
    if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
      prunable_weights = layer.layer.get_prunable_weights()
    elif isinstance(layer, prunable_layer.PrunableLayer):
      prunable_weights = layer.get_prunable_weights()
    elif prune_registry.PruneRegistry.supports(layer):
      weight_names = prune_registry.PruneRegistry._weight_names(layer)
      prunable_weights = [getattr(layer, weight) for weight in weight_names]
    else:
      prunable_weights = None
    if prunable_weights:
      print("{}: {}".format(
          layer.name, ", ".join([
              "({}, {})".format(weight.name,
                                str(_get_sparsity(K.get_value(weight))))
              for weight in prunable_weights
          ])))
  print("\n")
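As a quick sanity check of the sparsity formula used by _get_sparsity (illustrative, not part of the diff):

import numpy as np

w = np.array([0.0, -1.25, 0.0, 0.5])  # 2 of 4 entries are zero
assert 1.0 - np.count_nonzero(w) / float(w.size) == 0.5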
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,3 +8,4 @@ pyasn1<0.5.0,>=0.4.6
requests<3,>=2.21.0
pyparsing
pytest>=4.6.9
tensorflow-model-optimization>=0.2.1
