Merge pull request #97 from fastmachinelearning/feature/tensor_stats
Range analysis improvements and better input shape override
maltanar authored Feb 5, 2024
2 parents 813128f + 3a33770 commit 3fd9386
Showing 3 changed files with 96 additions and 40 deletions.
12 changes: 6 additions & 6 deletions src/qonnx/transformation/extract_conv_bias.py
@@ -27,29 +27,29 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import warnings
-from onnx import TensorProto, helper
+from onnx import helper
 
 from qonnx.transformation.base import Transformation
 
 
 class ExtractBiasFromConv(Transformation):
     """
-    Extracts the (optional) Bias from a Conv node and inserts it behind the
-    Conv node as an Add node.
+    Extracts the (optional) Bias from a Conv(Transpose) node and inserts it behind the
+    Conv(Transpose) node as an Add node.
     """
 
     def apply(self, model):
         graph = model.graph
         node_ind = 0
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Conv":
+            if n.op_type in ["Conv", "ConvTranspose"]:
                 # Check if the node has a bias input
                 if len(n.input) > 2:
                     # Extract bias
                     bias = model.get_initializer(n.input[2])
                     if bias is None:
-                        warnings.warn(f"Could not extract bias from Conv node {n}")
+                        warnings.warn(f"Could not extract bias from node {n}")
                         continue
 
                     # Insert bias as Add node behind the Conv node
@@ -65,7 +65,7 @@ def apply(self, model):
 
                     act_add_tensor = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
-                        TensorProto.FLOAT,
+                        model.get_tensor_valueinfo(n.output[0]).type.tensor_type.elem_type,
                         out_shape,
                     )
                     graph.value_info.append(act_add_tensor)
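For context, a minimal usage sketch of the updated transformation (filenames are hypothetical): the bias of a Conv or ConvTranspose node is split out into a trailing Add node, and the new value_info now inherits the element type of the original output instead of being fixed to FLOAT.

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv

# hypothetical input file containing Conv/ConvTranspose nodes with a bias input
model = ModelWrapper("model_with_bias.onnx")
model = model.transform(ExtractBiasFromConv())
model.save("model_bias_extracted.onnx")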
26 changes: 18 additions & 8 deletions src/qonnx/util/cleanup.py
@@ -43,7 +43,7 @@
 from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
 
 
-def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract_conv_bias=False):
+def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_conv_bias=False):
     """Execute the transformations for the cleanup function on a model level.
     This allows the reuse of the cleanup transformations, without needing to read/write the model from/to disk.
@@ -61,6 +61,19 @@ def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract_conv_bias=False):
         preserve_qnt_optypes = ["Quant", "BipolarQuant", "QuantizeLinear", "DequantizeLinear"]
     else:
         preserve_qnt_optypes = []
+
+    if override_inpsize is not None:
+        if type(override_inpsize) is str:
+            override_inpsize = eval(override_inpsize)
+        if type(override_inpsize) is int:
+            override_batchsize = override_inpsize
+            model = model.transform(ChangeBatchSize(override_batchsize))
+        elif type(override_inpsize) is tuple:
+            override_batchsize = override_inpsize[0]
+            model = model.transform(ChangeBatchSize(override_batchsize))
+            iname = model.graph.input[0].name
+            model.set_tensor_shape(iname, override_inpsize)
+
     cleanup_transformations = [
         InferShapes(),
         GiveUniqueParameterTensors(),
@@ -80,27 +93,24 @@ def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract_conv_bias=False):
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
 
-    if override_batchsize is not None:
-        model = model.transform(ChangeBatchSize(override_batchsize))
-        model = model.transform(InferShapes())
-
     return model
 
 
-def cleanup(in_file, *, out_file=None, preserve_qnt_ops=True, override_batchsize: int = None, extract_conv_bias=False):
+def cleanup(in_file, *, out_file=None, preserve_qnt_ops=True, override_inpsize: str = None, extract_conv_bias=False):
     """Execute a set of graph transformations to clean-up the given ONNX file.
 
     :param in_file: Filename for the input ONNX model
     :param preserve_qnt_ops: Preserve weight quantization operators
     :param out_file: If set, filename for the output ONNX model. Set to in_file with _clean
         suffix otherwise.
-    :param override_batchsize: If specified, override the batch size for the ONNX graph
+    :param override_inpsize: If specified, override the input size (e.g. "(1,3,224,224)" to set all or
+        just 1 to set batchsize to 1) for the ONNX graph
    :param extract_conv_bias: If specified, separate Conv bias into its own Add node
    """
 
     model = ModelWrapper(in_file)
     model = cleanup_model(
-        model, preserve_qnt_ops=preserve_qnt_ops, override_batchsize=override_batchsize, extract_conv_bias=extract_conv_bias
+        model, preserve_qnt_ops=preserve_qnt_ops, override_inpsize=override_inpsize, extract_conv_bias=extract_conv_bias
     )
     if out_file is None:
         out_file = in_file.replace(".onnx", "_clean.onnx")
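A short sketch of how the renamed override_inpsize parameter can be used, based on the docstring above (filenames hypothetical): an int overrides only the batch size, while a tuple, or its string form, which is eval'd as shown in the diff, overrides the full input shape.

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup, cleanup_model

# file-level entry point: the string form is eval'd into a tuple
cleanup("model.onnx", out_file="model_clean.onnx", override_inpsize="(1, 3, 224, 224)")

# model-level entry point: ints and tuples are accepted directly
model = ModelWrapper("model.onnx")
model = cleanup_model(model, override_inpsize=1)  # only set the batch size to 1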
98 changes: 72 additions & 26 deletions src/qonnx/util/range_analysis.py
@@ -60,13 +60,6 @@ def calculate_matvec_accumulator_extremum(matrix: np.ndarray, vec_min, vec_max):
     return (min_values, max_values)
 
 
-def propagate_range(node, model, range_dict):
-    iname = node.input[0]
-    node_irange = range_dict[iname]
-    for oname in node.output:
-        range_dict[oname] = node_irange
-
-
 def calc_gemm_range(node, model, range_dict):
     alpha = get_by_name(node.attribute, "alpha").f
     beta = get_by_name(node.attribute, "beta").f
@@ -172,10 +165,49 @@ def calc_conv_range(node, model, range_dict):
     range_dict[oname] = ret
 
 
+def calc_convtranspose_range(node, model, range_dict):
+    iname = node.input[0]
+    wname = node.input[1]
+    assert len(node.input) == 2, "Found unsupported ConvTranspose with bias"
+    oname = node.output[0]
+    irange = range_dict[iname]
+    imin, imax = irange
+    weights = model.get_initializer(wname)
+    assert weights is not None, "Uninitialized ConvTranspose weights"
+    groups = get_by_name(node.attribute, "group")
+    if groups is None:
+        # default to dense convs
+        groups = 1
+    else:
+        groups = groups.i
+    assert groups == 1, "Only dense (non-grouped) ConvTranspose is supported"
+    # do weight reshaping to treat Conv similar to MatMul
+    # (mh, mw) = (ofm, (ifm x k0 x k1 x ...))
+    conv_ofm = weights.shape[1]
+    conv_ifm = weights.shape[0]
+    weights = weights.transpose(1, 0, 2, 3).reshape(conv_ofm, -1)
+    k_total = weights.shape[1] // conv_ifm
+    if type(imin) is np.ndarray:
+        imin_rep = np.repeat(imin, k_total)
+        imax_rep = np.repeat(imax, k_total)
+    else:
+        imin_rep = imin
+        imax_rep = imax
+    dw_ret_min = []
+    dw_ret_max = []
+    for i in range(conv_ofm):
+        w_slice = weights[i, :].reshape(1, -1)
+        dw_ret = calculate_matvec_accumulator_extremum(w_slice, imin_rep, imax_rep)
+        dw_ret_min.append(dw_ret[0].item())
+        dw_ret_max.append(dw_ret[1].item())
+    ret = (np.asarray(dw_ret_min), np.asarray(dw_ret_max))
+    range_dict[oname] = ret
+
+
 def get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis=1):
     proto_min = valueinfo_to_tensor(inp_vi)
     proto_max = valueinfo_to_tensor(inp_vi)
-    if type(irange[0]) in [float, int, np.float32, np.float64, np.uint8, np.int8]:
+    if type(irange[0]) in [float, int, np.float16, np.float32, np.float64, np.uint8, np.int8]:
         imin, imax = irange
         proto_min[...] = imin
         proto_max[...] = imax
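To illustrate the weight handling in calc_convtranspose_range: ConvTranspose weights are laid out as (IFM, OFM, kH, kW), so the transpose/reshape above yields one row per output channel, and calculate_matvec_accumulator_extremum is applied row by row. Below is a small numpy sketch of the same computation, assuming that function implements standard interval arithmetic; its body is not part of this diff, so treat the last four lines as an illustration rather than the actual implementation.

import numpy as np

# ConvTranspose weight layout is (IFM, OFM, kH, kW), unlike Conv's (OFM, IFM, kH, kW)
ifm, ofm, kh, kw = 3, 2, 2, 2
rng = np.random.default_rng(0)
w = rng.normal(size=(ifm, ofm, kh, kw)).astype(np.float32)

# same reshape as in calc_convtranspose_range: one row per output channel
w2d = w.transpose(1, 0, 2, 3).reshape(ofm, -1)
imin, imax = -1.0, 1.0  # scalar input range
vec_min = np.full(w2d.shape[1], imin)
vec_max = np.full(w2d.shape[1], imax)

# interval arithmetic: positive weights push the upper bound toward imax,
# negative weights push it toward imin (and vice versa for the lower bound)
w_pos, w_neg = np.clip(w2d, 0, None), np.clip(w2d, None, 0)
out_min = w_pos @ vec_min + w_neg @ vec_max  # per-output-channel minimum
out_max = w_pos @ vec_max + w_neg @ vec_min  # per-output-channel maximum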
@@ -211,25 +243,34 @@ def calc_monotonic_range(node, model, range_dict, i_channel_axis=1):
         inp_vi = model.get_tensor_valueinfo(inp)
         proto_vectors.append(get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis))
     # process all combinations of prototype vectors for dynamic inputs
-    running_min = None
-    running_max = None
+    running_min = [None for i in range(len(node.output))]
+    running_max = [None for i in range(len(node.output))]
     # create context for single-node execution
     ctx = {x: model.get_initializer(x) for x in node.input}
-    ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
+    for oname in node.output:
+        ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
+    # assume all outputs are homogenous wrt data layout (e.g. channel axis
+    # always lives in the same position)
     axes_to_min = [i for i in range(ctx[oname].ndim)]
     axes_to_min.remove(i_channel_axis)
     axes_to_min = tuple(axes_to_min)
     for inps in itertools.product(*proto_vectors):
         for i in range(n_dyn_inp):
             ctx[dyn_inps[i]] = inps[i]
         execute_node(node, ctx, model.graph, opset_version=opset_version)
-        # grab new output and update running min/max
-        out = ctx[oname]
-        chanwise_min = out.min(axis=axes_to_min).flatten()
-        chanwise_max = out.max(axis=axes_to_min).flatten()
-        running_min = np.minimum(chanwise_min, running_min).flatten() if running_min is not None else chanwise_min
-        running_max = np.maximum(chanwise_max, running_max).flatten() if running_max is not None else chanwise_max
-    range_dict[oname] = (running_min, running_max)
+        for oind, oname in enumerate(node.output):
+            # grab new output and update running min/max
+            out = ctx[oname]
+            chanwise_min = out.min(axis=axes_to_min).flatten()
+            chanwise_max = out.max(axis=axes_to_min).flatten()
+            running_min[oind] = (
+                np.minimum(chanwise_min, running_min[oind]).flatten() if running_min[oind] is not None else chanwise_min
+            )
+            running_max[oind] = (
+                np.maximum(chanwise_max, running_max[oind]).flatten() if running_max[oind] is not None else chanwise_max
+            )
+    for oind, oname in enumerate(node.output):
+        range_dict[oname] = (running_min[oind], running_max[oind])
 
 
 def calc_range_outdtype(node, model, range_dict):
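The reworked calc_monotonic_range keeps one running (min, max) pair per output, which is what lets multi-output ops such as the newly registered Split participate in range analysis (the single-output outcount_ok guard is dropped further below). A standalone numpy sketch of the per-output update pattern, with made-up values:

import numpy as np

# hypothetical per-corner outputs of a two-output node (e.g. Split):
# shape (n_input_corners, n_outputs, n_channels)
outs = np.array([
    [[0.2, -1.0], [3.0, 0.5]],
    [[-0.5, 2.0], [1.0, 4.0]],
])

running_min = [None] * outs.shape[1]
running_max = [None] * outs.shape[1]
for corner in outs:
    for oind, out in enumerate(corner):
        # same running update as in calc_monotonic_range, one slot per output
        running_min[oind] = out if running_min[oind] is None else np.minimum(out, running_min[oind])
        running_max[oind] = out if running_max[oind] is None else np.maximum(out, running_max[oind])
# running_min -> [array([-0.5, -1.0]), array([1.0, 0.5])]
# running_max -> [array([0.2, 2.0]), array([3.0, 4.0])]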
@@ -240,12 +281,13 @@ def calc_range_outdtype(node, model, range_dict):
 
 
 optype_to_range_calc = {
-    "Transpose": propagate_range,
+    "Transpose": calc_monotonic_range,
     "MatMul": calc_matmul_range,
     "Conv": calc_conv_range,
+    "ConvTranspose": calc_convtranspose_range,
     "QuantMaxNorm": calc_range_outdtype,
-    "Flatten": propagate_range,
-    "Reshape": propagate_range,
+    "Flatten": calc_monotonic_range,
+    "Reshape": calc_monotonic_range,
     "Quant": calc_monotonic_range,
     "BipolarQuant": calc_monotonic_range,
     "Mul": calc_monotonic_range,
@@ -254,7 +296,7 @@ def calc_range_outdtype(node, model, range_dict):
     "Add": calc_monotonic_range,
     "BatchNormalization": calc_monotonic_range,
     "Relu": calc_monotonic_range,
-    "Pad": propagate_range,
+    "Pad": calc_monotonic_range,
     "AveragePool": calc_monotonic_range,
     "Trunc": calc_range_outdtype,
     "MaxPool": calc_monotonic_range,
@@ -267,6 +309,7 @@ def calc_range_outdtype(node, model, range_dict):
     "Clip": calc_monotonic_range,
     "Sigmoid": calc_monotonic_range,
     "Concat": calc_monotonic_range,
+    "Split": calc_monotonic_range,
 }


@@ -320,8 +363,12 @@ def range_analysis(
             range_min = None
             range_max = None
         else:
-            irange = irange.split(",")
-            range_min, range_max = float(irange[0]), float(irange[1])
+            irange = eval(irange)
+            range_min, range_max = irange
+            if isinstance(range_min, list):
+                range_min = np.asarray(range_min, dtype=np.float32)
+            if isinstance(range_max, list):
+                range_max = np.asarray(range_max, dtype=np.float32)
     elif isinstance(irange, tuple):
         range_min, range_max = irange
     else:
@@ -350,9 +397,8 @@ def range_analysis(
     for node in model.graph.node:
         dyn_inputs = [x for x in node.input if is_dyn_input(x, model)]
         inprange_ok = all([x in range_dict.keys() for x in dyn_inputs])
-        outcount_ok = len(node.output) == 1
         op_ok = node.op_type in optype_to_range_calc.keys()
-        if inprange_ok and op_ok and outcount_ok:
+        if inprange_ok and op_ok:
             range_calc_fxn = optype_to_range_calc[node.op_type]
             range_calc_fxn(node, model, range_dict)
             out_range = range_dict[node.output[0]]
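Since irange is now eval'd rather than split on commas, both a scalar tuple and per-channel lists can be passed as strings. A sketch of the two accepted string forms, with hypothetical values, mirroring the parsing code in the diff:

import numpy as np

# scalar range applied to all input elements
irange = eval("(0.0, 1.0)")
range_min, range_max = irange  # plain floats

# per-channel ranges become float32 arrays, as in the diff
irange = eval("([0.0, -1.0, 0.0], [1.0, 1.0, 255.0])")
range_min, range_max = irange
if isinstance(range_min, list):
    range_min = np.asarray(range_min, dtype=np.float32)
if isinstance(range_max, list):
    range_max = np.asarray(range_max, dtype=np.float32)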
