
Commit 3fb7eac

[MLU] add clip_by_norm kernel (#1338)
1 parent 8c20c2a commit 3fb7eac

2 files changed: +232, -0 lines changed
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs/elementwise_utils.h"
#include "kernels/funcs/mlu_baseop.h"
#include "kernels/funcs/mlu_funcs.h"

namespace custom_kernel {

template <typename T, typename Context>
void ClipByNormKernel(const Context& dev_ctx,
                      const phi::DenseTensor& x,
                      float max_norm,
                      phi::DenseTensor* out) {
  PADDLE_ENFORCE_NOT_NULL(&x,
                          phi::errors::InvalidArgument(
                              "Input(X) of ClipByNormOp should not be null. "
                              "Please check if it is created correctly."));
  phi::DenseTensor square_sum;
  phi::DenseTensorMeta square_sum_meta = {x.dtype(), phi::DDim({1})};
  square_sum.set_meta(square_sum_meta);
  dev_ctx.template Alloc<T>(&square_sum);

  MLUCnnlTensorDesc input_desc(x);
  MLUCnnlTensorDesc square_sum_desc(square_sum);

  // L2Loss
  MLUCnnl::L2Loss(
      dev_ctx, input_desc.get(), GetBasePtr(&x), GetBasePtr(&square_sum));

  // do mul
  phi::DenseTensor scale_tensor;
  scale_tensor.Resize({1});
  dev_ctx.template Alloc<T>(&scale_tensor);

  phi::DenseTensor bias_tensor;
  bias_tensor.Resize({1});
  dev_ctx.template Alloc<T>(&bias_tensor);

  MLUCnnlTensorDesc scale_desc(scale_tensor);
  MLUCnnlTensorDesc bias_desc(bias_tensor);
  FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(2.0f), &scale_tensor);
  FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), &bias_tensor);

  MLUCnnl::Scale(dev_ctx,
                 0,
                 square_sum_desc.get(),
                 GetBasePtr(&square_sum),
                 scale_desc.get(),
                 GetBasePtr(&scale_tensor),
                 bias_desc.get(),
                 GetBasePtr(&bias_tensor),
                 square_sum_desc.get(),
                 GetBasePtr(&square_sum));

  // sqrt
  phi::DenseTensor x_norm;
  phi::DenseTensorMeta x_norm_meta = {x.dtype(), phi::DDim({1})};
  x_norm.set_meta(x_norm_meta);
  dev_ctx.template Alloc<T>(&x_norm);

  MLUCnnlTensorDesc x_norm_desc(x_norm);
  cnnlComputationPreference_t prefer = CNNL_COMPUTATION_HIGH_PRECISION;
  MLUCnnl::Sqrt(dev_ctx,
                prefer,
                square_sum_desc.get(),
                GetBasePtr(&square_sum),
                x_norm_desc.get(),
                GetBasePtr(&x_norm));

  phi::DenseTensor x_norm_t;
  phi::DenseTensorMeta x_norm_t_meta = {
      x_norm.dtype(), x_norm.dims(), x_norm.layout()};
  x_norm_t.set_meta(x_norm_t_meta);

  // sync copy
  dev_ctx.Wait();
  TensorCopy(dev_ctx, x_norm, true, &x_norm_t, phi::CPUPlace());
  auto x_norm_v = static_cast<float>(*(x_norm_t.data<T>()));

  dev_ctx.template Alloc<T>(out);
  if (x_norm_v <= max_norm) {
    TensorCopy(dev_ctx, x, false, out);
  } else {
    auto epsilon = x_norm_v <= static_cast<float>(1e-30)
                       ? static_cast<float>(1e-6)
                       : static_cast<float>(0);

    float scaling = max_norm / (x_norm_v + epsilon);
    auto scale_t = static_cast<T>(scaling);
    phi::DenseTensor scaling_tensor;
    scaling_tensor.Resize({1});
    dev_ctx.template Alloc<T>(&scaling_tensor);
    MLUCnnlTensorDesc scaling_tensor_desc(scaling_tensor);
    MLUCnnl::Fill(dev_ctx,
                  CNNL_POINTER_MODE_HOST,
                  &scale_t,
                  scaling_tensor_desc.get(),
                  GetBasePtr(&scaling_tensor));

    auto data_type = ToCnnlDataType<T>();
    MLUCnnlTensorDesc out_desc(*out);

    // compute out = scaling_tensor * x
    MLUOpTensorKernel<T>(
        dev_ctx, scaling_tensor, x, -1, CNNL_OP_TENSOR_MUL, out);
  }
}
}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(clip_by_norm,
                          mlu,
                          ALL_LAYOUT,
                          custom_kernel::ClipByNormKernel,
                          float,
                          phi::dtype::float16) {}
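
For reference, the math the kernel above implements can be sketched in a few lines of NumPy. This is an illustrative sketch only (the function name and the standalone script are not part of the commit); it mirrors the kernel's steps: CNNL L2Loss returns sum(x^2) / 2, so the result is scaled by 2 before the square root, and a small epsilon guards the division when the norm is vanishingly small.

    import numpy as np

    def clip_by_norm_reference(x, max_norm):
        # L2Loss on MLU returns sum(x^2) / 2; scale by 2 before sqrt,
        # matching the Scale(alpha=2, beta=0) + Sqrt steps in the kernel.
        square_sum = np.sum(np.square(x)) / 2.0
        x_norm = np.sqrt(2.0 * square_sum)
        if x_norm <= max_norm:
            return x.copy()
        # Guard against division by a vanishing norm, as the kernel does.
        epsilon = 1e-6 if x_norm <= 1e-30 else 0.0
        scaling = max_norm / (x_norm + epsilon)
        return scaling * x
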
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
from tests.op_test import OpTest

paddle.enable_static()


class TestClipByNormOp(OpTest):
    def setUp(self):
        self.set_mlu()
        self.max_relative_error = 0.006
        self.init_dtype()
        self.initTestCase()
        input = np.random.random(self.shape).astype(self.dtype)
        input[np.abs(input) < self.max_relative_error] = 0.5
        self.op_type = "clip_by_norm"
        self.inputs = {
            "X": input,
        }
        self.attrs = {}
        self.attrs["max_norm"] = self.max_norm
        norm = np.sqrt(np.sum(np.square(input)))
        if norm > self.max_norm:
            output = self.max_norm * input / norm
        else:
            output = input
        self.outputs = {"Out": output}

    def set_mlu(self):
        self.__class__.use_custom_device = True
        self.place = paddle.CustomPlace("mlu", 0)

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def initTestCase(self):
        self.shape = (100,)
        self.max_norm = 1.0

    def init_dtype(self):
        self.dtype = np.float32


class TestCase1(TestClipByNormOp):
    def initTestCase(self):
        self.shape = (100,)
        self.max_norm = 1e20


class TestCase2(TestClipByNormOp):
    def initTestCase(self):
        self.shape = (16, 16)
        self.max_norm = 0.1


class TestCase3(TestClipByNormOp):
    def initTestCase(self):
        self.shape = (4, 8, 16)
        self.max_norm = 1.0


class TestClipByNormOpFp16(TestClipByNormOp):
    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-3)


class TestClipByNormOpFp16Case1(TestClipByNormOpFp16):
    def initTestCase(self):
        self.shape = (100,)
        self.max_norm = 1e20


class TestClipByNormOpFp16Case2(TestClipByNormOpFp16):
    def initTestCase(self):
        self.shape = (16, 16)
        self.max_norm = 0.1


class TestClipByNormOpFp16Case3(TestClipByNormOpFp16):
    def initTestCase(self):
        self.shape = (4, 8, 16)
        self.max_norm = 1.0


if __name__ == "__main__":
    unittest.main()
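
Beyond the OpTest above, user code usually reaches this kernel through gradient clipping rather than by invoking the op directly. A minimal dynamic-graph sketch follows; it assumes the custom MLU plugin is installed and registers the "mlu" device name, and that paddle.nn.ClipGradByNorm dispatches to the clip_by_norm op per gradient tensor (the model and sizes here are purely illustrative).

    import paddle

    paddle.set_device("mlu:0")  # assumes the MLU plugin registers this device name

    linear = paddle.nn.Linear(16, 8)
    clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)  # clips each gradient by its own L2 norm
    sgd = paddle.optimizer.SGD(learning_rate=0.1,
                               parameters=linear.parameters(),
                               grad_clip=clip)

    out = linear(paddle.rand([4, 16]))
    out.mean().backward()
    sgd.step()  # gradients pass through clip_by_norm before the parameter update
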
