Skip to content

Onnxscript implementation of BatchNormToAffine #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ install_requires =
onnxruntime>=1.16.1
sigtools>=4.0.1
toposort>=1.7.0
onnxscript==0.2.6


[options.packages.find]
Expand Down
111 changes: 46 additions & 65 deletions src/qonnx/transformation/batchnorm_to_affine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020 Xilinx, Inc.
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please do not remove older copyrights, you can just add the new one for 2025 AMD as the 2nd line

#
# Redistribution and use in source and binary forms, with or without
Expand All @@ -11,7 +11,7 @@
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of Xilinx nor the names of its
# * Neither the name of AMD nor the names of its
# contributors may be used to endorse or promote products derived from
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually this one should be QONNX, not Xilinx or AMD

# this software without specific prior written permission.
#
Expand All @@ -27,78 +27,59 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
from onnx import TensorProto
from onnx import TensorProto, helper
from onnx import helper as oh

from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.util.basic import get_by_name

from qonnx.core.modelwrapper import ModelWrapper
from onnxscript import opset15 as op
from onnxscript import script
from onnxscript.rewriter import pattern, rewrite
from onnxscript import ir

from qonnx.util.onnxscript import ReplacePattern

def target_pattern(op, x, scale, bias, mean, var):
return op.BatchNormalization(x, scale, bias, mean, var)

def replace_pattern(op, x, scale, bias, mean, var, **kwargs):

# Get epsilon from matched pattern
batch_norm = kwargs['match'].nodes[0]
epsilon_attr = batch_norm.attributes.get('epsilon', None)
epsilon_value = 1e-5 if epsilon_attr is None else epsilon_attr.value
Comment on lines +52 to +53
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be handled by something like https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/cast_constant_of_shape.py#L13-L22 to get the attribute instead, removing the need for the special util?

epsilon_tensor = helper.make_tensor("epsilon", TensorProto.FLOAT, (1,), [epsilon_value])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be better to have the 1e-5 default configurable with e.g. a top level variable in the module (or if possible, passed in as an optional arg to the Transformation with this default value)

epsilon = op.Constant(value=epsilon_tensor)

A = op.Div(scale, op.Sqrt(op.Add(var, epsilon)))
B = op.Sub(bias, op.Mul(A, mean))

# Unsqueeze A and B
input_shape = x.shape
assert input_shape is not None and len(input_shape) >= 2
n_spatial_dims = len(input_shape) - 2
axes = [0] + [i + 2 for i in range(n_spatial_dims)]
A = op.Unsqueeze(A, axes=axes)
B = op.Unsqueeze(B, axes=axes)

Comment on lines +62 to +67
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. is this always safe? does it make sense to add more testcases with different dimensionalties? (now much easier to create dummy models with onnxscript)
  2. the original transformation removes surrounding squeeze/unsqueeze nodes if they were present around the BatchNorm, does this new version still have the same effect? (also another good thing to test/check)

return op.Add(op.Mul(x, A), B)

rule1 = pattern.RewriteRule(target_pattern, ReplacePattern(replace_pattern), verbose=10)
rewrite_rules = pattern.RewriteRuleSet([rule1])

class BatchNormToAffine(Transformation):
"""Replaces any test-time BatchNorm layers with Mul-Add layers."""

def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "BatchNormalization":
graph_modified = True
bn_input = n.input[0]
bn_output = n.output[0]
# extract batchnorm parameters as numpy arrays
scale = model.get_initializer(n.input[1])
bias = model.get_initializer(n.input[2])
mean = model.get_initializer(n.input[3])
variance = model.get_initializer(n.input[4])
epsilon = get_by_name(n.attribute, "epsilon")
epsilon = getattr(epsilon, "f", 1e-5)
# find A and B to compute batchnorm as affine transpose Ax+B
# TODO is a division by moving avg factor needed for variance?
A = scale / np.sqrt(epsilon + variance)
B = bias - (A * mean)
# see if we have surrounding Unsqueeze/Squeeze nodes we can remove
producer = model.find_producer(bn_input)
if producer is not None:
if producer.op_type == "Unsqueeze":
bn_input = producer.input[0]
consumer = model.find_consumer(bn_output)
if consumer is not None:
if consumer.op_type == "Squeeze":
bn_output = consumer.output[0]
data_shape = model.get_tensor_shape(bn_input)
assert A.ndim == B.ndim, "Unexpected mul/add dims in BatchNormToAffine"
assert len(data_shape) >= A.ndim, "Unexpected number of dims found in BatchNormToAffine"
# reshape the mul/add constants to match the data shape/dims
# by adding (1,) dimensions to the right
n_spatial_dims = len(data_shape) - 2
target_shape = (1, -1) + tuple(1 for i in range(n_spatial_dims))
A = A.reshape(target_shape)
B = B.reshape(target_shape)
# create value_info and initializers for Mul and Add constants
mul_const = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape)
graph.value_info.append(mul_const)
model.set_initializer(mul_const.name, A)
mul_output = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape)
graph.value_info.append(mul_output)
add_const = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape)
graph.value_info.append(add_const)
model.set_initializer(add_const.name, B)
# create Mul and Add nodes to replace the batchnorm
mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name])
add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output])
# insert where the batchnorm is to preserve topological ordering
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
# remove old nodes
graph.node.remove(n)
if consumer is not None:
if consumer.op_type == "Squeeze":
graph.node.remove(consumer)
if producer is not None:
if producer.op_type == "Unsqueeze":
graph.node.remove(producer)
model = ir.from_proto(model.model)
model = rewrite(model, pattern_rewrite_rules=rewrite_rules)
model = ir.to_proto(model)
model = ModelWrapper(model)
model = model.transform(InferShapes())
return (model, graph_modified)
model = model.transform(FoldConstants())
return (model, False)

51 changes: 51 additions & 0 deletions src/qonnx/util/onnxscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of AMD nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from onnxscript.rewriter._rewrite_rule import ReplacementPatternFunction, ReplacementSubgraph
from typing import Sequence
from onnxscript.ir import _convenience, _tape

RewriterContext = _tape.Builder

class ReplacePattern(ReplacementPatternFunction):
"""Utility wrapper that provides matched pattern information to the replacement function.
The matched pattern is passed as the 'match' keyword argument."""

def __init__(self, func):
super().__init__(func)

def get_replacement(self, match):
context = RewriterContext()
new_outputs = self._function(context, match=match, **match.bindings)
if new_outputs is None:
return None
if not isinstance(new_outputs, Sequence):
new_outputs = [new_outputs]
return ReplacementSubgraph(
match, new_outputs, context.nodes, context.initializers, context.used_opsets
)
4 changes: 4 additions & 0 deletions tests/transformation/test_batchnorm_to_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,7 @@ def test_batchnorm_to_affine_epsilon(epsilon):
output_lowered = output_dict[output_node_name]

assert (output_original == output_lowered).all()

op_types = list(map(lambda x: x.op_type, model_lowered.graph.node))
assert "BatchNormalization" not in op_types
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check for no Unsqueeze/Squeeze left here as well?


Loading