File tree Expand file tree Collapse file tree 2 files changed +33
-0
lines changed Expand file tree Collapse file tree 2 files changed +33
-0
lines changed Original file line number Diff line number Diff line change @@ -308,6 +308,12 @@ struct AMaxOrAMinGradFunctor {
308
308
mask.sum (axis).reshape (dy->dimensions ()).broadcast (dim);
309
309
return ;
310
310
}
311
+
312
+ if (rank == 0 ) {
313
+ dx->device (place) = dy->broadcast (dim) * mask;
314
+ return ;
315
+ }
316
+
311
317
// axis is list, HANDLE_AXIS_DIM(broadcast_dim_size, rank)
312
318
HANDLE_AXIS_DIM (3 , 2 );
313
319
HANDLE_AXIS_DIM (4 , 2 );
Original file line number Diff line number Diff line change @@ -139,6 +139,33 @@ def _test_dygraph(func):
139
139
# test two minimum or maximum elements
140
140
141
141
142
+ class TestMaxMinAmaxAminAPI_AxisWithOne1 (TestMaxMinAmaxAminAPI ):
143
+ def init_case (self ):
144
+ self .x_np = np .random .randn (1 , 5 , 10 ).astype (np .float32 )
145
+ self .shape = [1 , 5 , 10 ]
146
+ self .dtype = 'float32'
147
+ self .axis = 0
148
+ self .keepdim = False
149
+
150
+
151
+ class TestMaxMinAmaxAminAPI_AxisWithOne2 (TestMaxMinAmaxAminAPI ):
152
+ def init_case (self ):
153
+ self .x_np = np .random .randn (1 , 5 , 10 ).astype (np .float32 )
154
+ self .shape = [1 , 5 , 10 ]
155
+ self .dtype = 'float32'
156
+ self .axis = 0
157
+ self .keepdim = True
158
+
159
+
160
+ class TestMaxMinAmaxAminAPI_AxisWithOne3 (TestMaxMinAmaxAminAPI ):
161
+ def init_case (self ):
162
+ self .x_np = np .random .randn (1 , 1 , 10 ).astype (np .float32 )
163
+ self .shape = [1 , 1 , 10 ]
164
+ self .dtype = 'float32'
165
+ self .axis = (0 , 1 )
166
+ self .keepdim = False
167
+
168
+
142
169
class TestMaxMinAmaxAminAPI_ZeroDim (TestMaxMinAmaxAminAPI ):
143
170
def init_case (self ):
144
171
self .x_np = np .array (0.5 )
You can’t perform that action at this time.
0 commit comments