-
Notifications
You must be signed in to change notification settings - Fork 47
Onnxscript implementation of BatchNormToAffine #185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) 2020 Xilinx, Inc. | ||
# Copyright (c) 2025 Advanced Micro Devices, Inc. | ||
# All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
|
@@ -11,7 +11,7 @@ | |
# this list of conditions and the following disclaimer in the documentation | ||
# and/or other materials provided with the distribution. | ||
# | ||
# * Neither the name of Xilinx nor the names of its | ||
# * Neither the name of AMD nor the names of its | ||
# contributors may be used to endorse or promote products derived from | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually this one should be QONNX, not Xilinx or AMD |
||
# this software without specific prior written permission. | ||
# | ||
|
@@ -27,78 +27,59 @@ | |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
import numpy as np | ||
from onnx import TensorProto | ||
from onnx import TensorProto, helper | ||
from onnx import helper as oh | ||
|
||
from qonnx.transformation.base import Transformation | ||
from qonnx.transformation.infer_shapes import InferShapes | ||
from qonnx.transformation.fold_constants import FoldConstants | ||
from qonnx.util.basic import get_by_name | ||
|
||
from qonnx.core.modelwrapper import ModelWrapper | ||
from onnxscript import opset15 as op | ||
from onnxscript import script | ||
from onnxscript.rewriter import pattern, rewrite | ||
from onnxscript import ir | ||
|
||
from qonnx.util.onnxscript import ReplacePattern | ||
|
||
def target_pattern(op, x, scale, bias, mean, var): | ||
return op.BatchNormalization(x, scale, bias, mean, var) | ||
|
||
def replace_pattern(op, x, scale, bias, mean, var, **kwargs): | ||
|
||
# Get epsilon from matched pattern | ||
batch_norm = kwargs['match'].nodes[0] | ||
epsilon_attr = batch_norm.attributes.get('epsilon', None) | ||
epsilon_value = 1e-5 if epsilon_attr is None else epsilon_attr.value | ||
Comment on lines
+52
to
+53
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can this be handled by something like https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/cast_constant_of_shape.py#L13-L22 to get the attribute instead, removing the need for the special util? |
||
epsilon_tensor = helper.make_tensor("epsilon", TensorProto.FLOAT, (1,), [epsilon_value]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be better to have the 1e-5 default configurable with e.g. a top level variable in the module (or if possible, passed in as an optional arg to the Transformation with this default value) |
||
epsilon = op.Constant(value=epsilon_tensor) | ||
|
||
A = op.Div(scale, op.Sqrt(op.Add(var, epsilon))) | ||
B = op.Sub(bias, op.Mul(A, mean)) | ||
|
||
# Unsqueeze A and B | ||
input_shape = x.shape | ||
assert input_shape is not None and len(input_shape) >= 2 | ||
n_spatial_dims = len(input_shape) - 2 | ||
axes = [0] + [i + 2 for i in range(n_spatial_dims)] | ||
A = op.Unsqueeze(A, axes=axes) | ||
B = op.Unsqueeze(B, axes=axes) | ||
|
||
Comment on lines
+62
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
return op.Add(op.Mul(x, A), B) | ||
|
||
rule1 = pattern.RewriteRule(target_pattern, ReplacePattern(replace_pattern), verbose=10) | ||
rewrite_rules = pattern.RewriteRuleSet([rule1]) | ||
|
||
class BatchNormToAffine(Transformation): | ||
"""Replaces any test-time BatchNorm layers with Mul-Add layers.""" | ||
|
||
def apply(self, model): | ||
graph = model.graph | ||
node_ind = 0 | ||
graph_modified = False | ||
for n in graph.node: | ||
node_ind += 1 | ||
if n.op_type == "BatchNormalization": | ||
graph_modified = True | ||
bn_input = n.input[0] | ||
bn_output = n.output[0] | ||
# extract batchnorm parameters as numpy arrays | ||
scale = model.get_initializer(n.input[1]) | ||
bias = model.get_initializer(n.input[2]) | ||
mean = model.get_initializer(n.input[3]) | ||
variance = model.get_initializer(n.input[4]) | ||
epsilon = get_by_name(n.attribute, "epsilon") | ||
epsilon = getattr(epsilon, "f", 1e-5) | ||
# find A and B to compute batchnorm as affine transpose Ax+B | ||
# TODO is a division by moving avg factor needed for variance? | ||
A = scale / np.sqrt(epsilon + variance) | ||
B = bias - (A * mean) | ||
# see if we have surrounding Unsqueeze/Squeeze nodes we can remove | ||
producer = model.find_producer(bn_input) | ||
if producer is not None: | ||
if producer.op_type == "Unsqueeze": | ||
bn_input = producer.input[0] | ||
consumer = model.find_consumer(bn_output) | ||
if consumer is not None: | ||
if consumer.op_type == "Squeeze": | ||
bn_output = consumer.output[0] | ||
data_shape = model.get_tensor_shape(bn_input) | ||
assert A.ndim == B.ndim, "Unexpected mul/add dims in BatchNormToAffine" | ||
assert len(data_shape) >= A.ndim, "Unexpected number of dims found in BatchNormToAffine" | ||
# reshape the mul/add constants to match the data shape/dims | ||
# by adding (1,) dimensions to the right | ||
n_spatial_dims = len(data_shape) - 2 | ||
target_shape = (1, -1) + tuple(1 for i in range(n_spatial_dims)) | ||
A = A.reshape(target_shape) | ||
B = B.reshape(target_shape) | ||
# create value_info and initializers for Mul and Add constants | ||
mul_const = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape) | ||
graph.value_info.append(mul_const) | ||
model.set_initializer(mul_const.name, A) | ||
mul_output = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape) | ||
graph.value_info.append(mul_output) | ||
add_const = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape) | ||
graph.value_info.append(add_const) | ||
model.set_initializer(add_const.name, B) | ||
# create Mul and Add nodes to replace the batchnorm | ||
mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name]) | ||
add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output]) | ||
# insert where the batchnorm is to preserve topological ordering | ||
graph.node.insert(node_ind, mul_node) | ||
graph.node.insert(node_ind + 1, add_node) | ||
# remove old nodes | ||
graph.node.remove(n) | ||
if consumer is not None: | ||
if consumer.op_type == "Squeeze": | ||
graph.node.remove(consumer) | ||
if producer is not None: | ||
if producer.op_type == "Unsqueeze": | ||
graph.node.remove(producer) | ||
model = ir.from_proto(model.model) | ||
model = rewrite(model, pattern_rewrite_rules=rewrite_rules) | ||
model = ir.to_proto(model) | ||
model = ModelWrapper(model) | ||
model = model.transform(InferShapes()) | ||
return (model, graph_modified) | ||
model = model.transform(FoldConstants()) | ||
return (model, False) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) 2025 Advanced Micro Devices, Inc. | ||
# All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions are met: | ||
# | ||
# * Redistributions of source code must retain the above copyright notice, this | ||
# list of conditions and the following disclaimer. | ||
# | ||
# * Redistributions in binary form must reproduce the above copyright notice, | ||
# this list of conditions and the following disclaimer in the documentation | ||
# and/or other materials provided with the distribution. | ||
# | ||
# * Neither the name of AMD nor the names of its | ||
# contributors may be used to endorse or promote products derived from | ||
# this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
from onnxscript.rewriter._rewrite_rule import ReplacementPatternFunction, ReplacementSubgraph | ||
from typing import Sequence | ||
from onnxscript.ir import _convenience, _tape | ||
|
||
RewriterContext = _tape.Builder | ||
|
||
class ReplacePattern(ReplacementPatternFunction): | ||
"""Utility wrapper that provides matched pattern information to the replacement function. | ||
The matched pattern is passed as the 'match' keyword argument.""" | ||
|
||
def __init__(self, func): | ||
super().__init__(func) | ||
|
||
def get_replacement(self, match): | ||
context = RewriterContext() | ||
new_outputs = self._function(context, match=match, **match.bindings) | ||
if new_outputs is None: | ||
return None | ||
if not isinstance(new_outputs, Sequence): | ||
new_outputs = [new_outputs] | ||
return ReplacementSubgraph( | ||
match, new_outputs, context.nodes, context.initializers, context.used_opsets | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -117,3 +117,7 @@ def test_batchnorm_to_affine_epsilon(epsilon): | |
output_lowered = output_dict[output_node_name] | ||
|
||
assert (output_original == output_lowered).all() | ||
|
||
op_types = list(map(lambda x: x.op_type, model_lowered.graph.node)) | ||
assert "BatchNormalization" not in op_types | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check for no Unsqueeze/Squeeze left here as well? |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please do not remove older copyrights, you can just add the new one for 2025 AMD as the 2nd line