Skip to content

Commit 6ed4172

Browse files
authored
[Accuracy diff No.105] Fix accuracy diff for max,amax,min,amin api (#73229)
* fix accuracy max,amax,min,amin * fix tests
1 parent 58f821d commit 6ed4172

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

paddle/phi/kernels/funcs/reduce_functor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,12 @@ struct AMaxOrAMinGradFunctor {
308308
mask.sum(axis).reshape(dy->dimensions()).broadcast(dim);
309309
return;
310310
}
311+
312+
if (rank == 0) {
313+
dx->device(place) = dy->broadcast(dim) * mask;
314+
return;
315+
}
316+
311317
// axis is list, HANDLE_AXIS_DIM(broadcast_dim_size, rank)
312318
HANDLE_AXIS_DIM(3, 2);
313319
HANDLE_AXIS_DIM(4, 2);

test/legacy_test/test_max_min_amax_amin_op.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,33 @@ def _test_dygraph(func):
139139
# test two minimum or maximum elements
140140

141141

142+
class TestMaxMinAmaxAminAPI_AxisWithOne1(TestMaxMinAmaxAminAPI):
143+
def init_case(self):
144+
self.x_np = np.random.randn(1, 5, 10).astype(np.float32)
145+
self.shape = [1, 5, 10]
146+
self.dtype = 'float32'
147+
self.axis = 0
148+
self.keepdim = False
149+
150+
151+
class TestMaxMinAmaxAminAPI_AxisWithOne2(TestMaxMinAmaxAminAPI):
152+
def init_case(self):
153+
self.x_np = np.random.randn(1, 5, 10).astype(np.float32)
154+
self.shape = [1, 5, 10]
155+
self.dtype = 'float32'
156+
self.axis = 0
157+
self.keepdim = True
158+
159+
160+
class TestMaxMinAmaxAminAPI_AxisWithOne3(TestMaxMinAmaxAminAPI):
161+
def init_case(self):
162+
self.x_np = np.random.randn(1, 1, 10).astype(np.float32)
163+
self.shape = [1, 1, 10]
164+
self.dtype = 'float32'
165+
self.axis = (0, 1)
166+
self.keepdim = False
167+
168+
142169
class TestMaxMinAmaxAminAPI_ZeroDim(TestMaxMinAmaxAminAPI):
143170
def init_case(self):
144171
self.x_np = np.array(0.5)

0 commit comments

Comments
 (0)