Onnxscript implementation of BatchNormToAffine #185


Open · wants to merge 2 commits into base: main

Conversation


@alanbacellar (Collaborator) commented on May 28, 2025

  • onnxscript batchnormtoaffine
  • util for receiving match in replace pattern
  • added check batchnorm op not in graph in pytest
  • added onnxscript==0.2.6 and brevitas to install requires

@@ -1,4 +1,4 @@
-# Copyright (c) 2020 Xilinx, Inc.
+# Copyright (c) 2025 Advanced Micro Devices, Inc.

Please do not remove older copyrights; you can just add the new one for 2025 AMD as the second line.

@@ -11,7 +11,7 @@
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
-# * Neither the name of Xilinx nor the names of its
+# * Neither the name of AMD nor the names of its

Actually, this one should be QONNX, not Xilinx or AMD.

# Get epsilon from matched pattern
batch_norm = kwargs['match'].nodes[0]
epsilon_attr = batch_norm.attributes.get('epsilon', None)
epsilon_value = 1e-5 if epsilon_attr is None else epsilon_attr.value
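
For context on what the replacement computes: BatchNormToAffine folds a BatchNormalization node into a per-channel Mul/Add. A quick NumPy sanity check of that algebra (illustrative only, not the PR's code):

```python
# Sketch: the affine parameters A and B that a BatchNormToAffine rewrite
# derives from the matched BatchNormalization inputs, checked against the
# reference batch-norm formula. All names here are illustrative.
import numpy as np

def batchnorm_to_affine_params(scale, bias, mean, var, epsilon=1e-5):
    """Return (A, B) such that A * x + B == batchnorm(x) per channel."""
    A = scale / np.sqrt(var + epsilon)
    B = bias - mean * A
    return A, B

# Per-channel parameters for a 3-channel input.
scale = np.array([1.0, 2.0, 0.5])
bias = np.array([0.1, -0.2, 0.0])
mean = np.array([0.5, 1.0, -1.0])
var = np.array([4.0, 1.0, 0.25])

A, B = batchnorm_to_affine_params(scale, bias, mean, var)
x = np.random.default_rng(0).normal(size=3)
reference = (x - mean) / np.sqrt(var + 1e-5) * scale + bias
assert np.allclose(A * x + B, reference)
```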

It would be better to make the 1e-5 default configurable, e.g. via a top-level variable in the module (or, if possible, passed in as an optional arg to the Transformation with this default value).
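
A minimal sketch of this suggestion, with hypothetical names (the PR's actual Transformation class may differ):

```python
# Illustrative only: a module-level default that a caller can override,
# plus an optional constructor arg, as suggested in the review.
DEFAULT_BN_EPSILON = 1e-5  # top-level default, easy to patch or import

class BatchNormToAffine:
    def __init__(self, default_epsilon=DEFAULT_BN_EPSILON):
        # Used only when the matched node carries no epsilon attribute.
        self.default_epsilon = default_epsilon

    def resolve_epsilon(self, epsilon_attr_value):
        """Return the attribute value if present, else the configured default."""
        return self.default_epsilon if epsilon_attr_value is None else epsilon_attr_value

assert BatchNormToAffine().resolve_epsilon(None) == 1e-5
assert BatchNormToAffine().resolve_epsilon(2e-5) == 2e-5
assert BatchNormToAffine(default_epsilon=1e-3).resolve_epsilon(None) == 1e-3
```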

Comment on lines +62 to +67
input_shape = x.shape
assert input_shape is not None and len(input_shape) >= 2
n_spatial_dims = len(input_shape) - 2
axes = [0] + [i + 2 for i in range(n_spatial_dims)]
A = op.Unsqueeze(A, axes=axes)
B = op.Unsqueeze(B, axes=axes)

  1. Is this always safe? Does it make sense to add more testcases with different dimensionalities? (Now much easier to create dummy models with onnxscript.)
  2. The original transformation removes surrounding Squeeze/Unsqueeze nodes if they were present around the BatchNorm; does this new version still have the same effect? (Also another good thing to test/check.)
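
On point 1, the axes formula from the diff can at least be exercised standalone across input ranks; a sketch, assuming the same computation as the PR:

```python
# For an input of shape (N, C, d1, ..., dk), the per-channel parameters
# must broadcast along every axis except the channel axis (axis 1).
# Same formula as in the diff, extracted into a testable helper.
import numpy as np

def unsqueeze_axes(input_rank):
    assert input_rank >= 2
    n_spatial_dims = input_rank - 2
    return [0] + [i + 2 for i in range(n_spatial_dims)]

assert unsqueeze_axes(2) == [0]           # (N, C): only the batch axis
assert unsqueeze_axes(4) == [0, 2, 3]     # (N, C, H, W)
assert unsqueeze_axes(5) == [0, 2, 3, 4]  # (N, C, D, H, W)

# A (C,) parameter unsqueezed at these axes broadcasts against (N, C, H, W):
p = np.ones(3)
assert np.expand_dims(p, unsqueeze_axes(4)).shape == (1, 3, 1, 1)
```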

@@ -117,3 +117,7 @@ def test_batchnorm_to_affine_epsilon(epsilon):
output_lowered = output_dict[output_node_name]

assert (output_original == output_lowered).all()

op_types = list(map(lambda x: x.op_type, model_lowered.graph.node))
assert "BatchNormalization" not in op_types

Check that no Unsqueeze/Squeeze nodes are left here as well?
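
A sketch of the extra check, in the style of the existing test; `assert_ops_absent` and the stand-in nodes are illustrative, not part of the PR:

```python
# Assert that the lowered graph contains none of the ops the transformation
# is expected to absorb. In the real test, pass model_lowered.graph.node.
from types import SimpleNamespace

def assert_ops_absent(graph_nodes,
                      forbidden=("BatchNormalization", "Squeeze", "Unsqueeze")):
    op_types = [node.op_type for node in graph_nodes]
    for op in forbidden:
        assert op not in op_types, f"{op} still present after lowering"

# Stand-in for model_lowered.graph.node: only Mul/Add should remain.
nodes = [SimpleNamespace(op_type="Mul"), SimpleNamespace(op_type="Add")]
assert_ops_absent(nodes)
```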

Comment on lines +52 to +53
batch_norm = kwargs['match'].nodes[0]
epsilon_attr = batch_norm.attributes.get('epsilon', None)

can this be handled by something like https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/cast_constant_of_shape.py#L13-L22 to get the attribute instead, removing the need for the special util?
