-
Notifications
You must be signed in to change notification settings - Fork 47
Onnxscript implementation of BatchNormToAffine #185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Onnxscript implementation of BatchNormToAffine #185
Conversation
… match in replace pattern
@@ -1,4 +1,4 @@ | |||
# Copyright (c) 2020 Xilinx, Inc. | |||
# Copyright (c) 2025 Advanced Micro Devices, Inc. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please do not remove older copyrights, you can just add the new one for 2025 AMD as the 2nd line
@@ -11,7 +11,7 @@ | |||
# this list of conditions and the following disclaimer in the documentation | |||
# and/or other materials provided with the distribution. | |||
# | |||
# * Neither the name of Xilinx nor the names of its | |||
# * Neither the name of AMD nor the names of its |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually this one should be QONNX, not Xilinx or AMD
# Get epsilon from matched pattern | ||
batch_norm = kwargs['match'].nodes[0] | ||
epsilon_attr = batch_norm.attributes.get('epsilon', None) | ||
epsilon_value = 1e-5 if epsilon_attr is None else epsilon_attr.value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be better to have the 1e-5 default configurable with e.g. a top level variable in the module (or if possible, passed in as an optional arg to the Transformation with this default value)
input_shape = x.shape | ||
assert input_shape is not None and len(input_shape) >= 2 | ||
n_spatial_dims = len(input_shape) - 2 | ||
axes = [0] + [i + 2 for i in range(n_spatial_dims)] | ||
A = op.Unsqueeze(A, axes=axes) | ||
B = op.Unsqueeze(B, axes=axes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- is this always safe? does it make sense to add more testcases with different dimensionalties? (now much easier to create dummy models with onnxscript)
- the original transformation removes surrounding squeeze/unsqueeze nodes if they were present around the BatchNorm, does this new version still have the same effect? (also another good thing to test/check)
@@ -117,3 +117,7 @@ def test_batchnorm_to_affine_epsilon(epsilon): | |||
output_lowered = output_dict[output_node_name] | |||
|
|||
assert (output_original == output_lowered).all() | |||
|
|||
op_types = list(map(lambda x: x.op_type, model_lowered.graph.node)) | |||
assert "BatchNormalization" not in op_types |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check for no Unsqueeze/Squeeze left here as well?
batch_norm = kwargs['match'].nodes[0] | ||
epsilon_attr = batch_norm.attributes.get('epsilon', None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this be handled by something like https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/cast_constant_of_shape.py#L13-L22 to get the attribute instead, removing the need for the special util?
Uh oh!
There was an error while loading. Please reload this page.