diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index 0f7f38f7..22dbd6eb 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -144,3 +144,18 @@ def apply(self, model): break model = model.transform(InferShapes()) return (model, graph_modified) + + +class RemoveSuccessiveIdenticalQuant(Transformation): + def apply(self, model): + graph = model.graph + graph_modified = False + for n in graph.node: + if n.op_type == "Quant": + successors = model.find_direct_successors(n) + if successors is not None and len(successors) == 1 and successors[0].op_type == "Quant": + init_node = [model.get_initializer(i) for i in n.input] + init_succ = [model.get_initializer(i) for i in successors[0].input] + if init_node == init_succ: + remove_node_and_rewire(model, n) + return (model, graph_modified)