18
18
from ..fluid import core , layers
19
19
20
20
# TODO: define searching & indexing functions of a tensor
21
- from ..fluid .layers import argmin #DEFINE_ALIAS
22
21
from ..fluid .layers import has_inf #DEFINE_ALIAS
23
22
from ..fluid .layers import has_nan #DEFINE_ALIAS
24
23
@@ -123,7 +122,7 @@ def argsort(x, axis=-1, descending=False, name=None):
123
122
return ids
124
123
125
124
126
- def argmax (x , axis = None , dtype = None , keepdim = False , name = None ):
125
+ def argmax (x , axis = None , keepdim = False , dtype = "int64" , name = None ):
127
126
"""
128
127
This OP computes the indices of the max elements of the input tensor's
129
128
element along the provided axis.
@@ -134,10 +133,10 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
134
133
axis(int, optional): Axis to compute indices along. The effective range
135
134
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
136
135
as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index.
137
- dtype(str): Data type of the output tensor which can
138
- be int32, int64. The default value is None, and it will
139
- return the int64 indices.
140
136
keepdim(bool, optional): Keep the axis that selecting max. The defalut value is False.
137
+ dtype(str|np.dtype, optional): Data type of the output tensor which can
138
+ be int32, int64. The default value is 'int64', and it will
139
+ return the int64 indices.
141
140
name(str, optional): The default value is None. Normally there is no
142
141
need for user to set this property. For more information, please
143
142
refer to :ref:`api_guide_Name`.
@@ -163,48 +162,39 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
163
162
print(out3.numpy())
164
163
# [2 3 1]
165
164
"""
165
+ if axis is not None and not isinstance (axis , int ):
166
+ raise TypeError (
167
+ "The type of 'axis' must be int or None in argmax, but received %s."
168
+ % (type (axis )))
169
+ var_dtype = convert_np_dtype_to_dtype_ (dtype )
170
+ check_dtype (var_dtype , 'dtype' , ['int32' , 'int64' ], 'argmin' )
166
171
flatten = False
167
172
if axis is None :
168
173
flatten = True
169
174
axis = 0
170
175
171
176
if in_dygraph_mode ():
172
- if dtype != None :
173
- var_dtype = convert_np_dtype_to_dtype_ (dtype )
174
- out = core .ops .arg_max (x , 'axis' , axis , 'dtype' , var_dtype ,
175
- 'keepdim' , keepdim , 'flatten' , flatten )
176
- else :
177
- out = core .ops .arg_max (x , 'axis' , axis , 'keepdim' , keepdim ,
178
- 'flatten' , flatten )
177
+ out = core .ops .arg_max (x , 'axis' , axis , 'dtype' , var_dtype , 'keepdims' ,
178
+ keepdim , 'flatten' , flatten )
179
179
return out
180
180
181
181
helper = LayerHelper ("argmax" , ** locals ())
182
182
check_variable_and_dtype (
183
183
x , 'x' , ['float32' , 'float64' , 'int16' , 'int32' , 'int64' , 'uint8' ],
184
184
'paddle.argmax' )
185
- var_dtype = None
186
185
attrs = {}
187
- if dtype is not None :
188
- if dtype not in ['int32' , 'int64' ]:
189
- raise ValueError (
190
- "The value of 'dtype' in argmax op must be int32, int64, but received of {}" .
191
- format (dtype ))
192
- var_dtype = convert_np_dtype_to_dtype_ (dtype )
193
- attrs ["dtype" ] = var_dtype
194
- else :
195
- var_dtype = VarDesc .VarType .INT64
196
-
197
186
out = helper .create_variable_for_type_inference (var_dtype )
198
187
attrs ['keepdims' ] = keepdim
199
188
attrs ['axis' ] = axis
200
189
attrs ['flatten' ] = flatten
190
+ attrs ['dtype' ] = var_dtype
201
191
helper .append_op (
202
192
type = 'arg_max' , inputs = {'X' : x }, outputs = {'Out' : [out ]}, attrs = attrs )
203
193
out .stop_gradient = True
204
194
return out
205
195
206
196
207
- def argmin (x , axis = None , dtype = None , keepdim = False , name = None ):
197
+ def argmin (x , axis = None , keepdim = False , dtype = "int64" , name = None ):
208
198
"""
209
199
This OP computes the indices of the min elements of the input tensor's
210
200
element along the provided axis.
@@ -215,10 +205,10 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
215
205
axis(int, optional): Axis to compute indices along. The effective range
216
206
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
217
207
as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index.
208
+ keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False.
218
209
dtype(str): Data type of the output tensor which can
219
- be int32, int64. The default value is None , and it will
210
+ be int32, int64. The default value is 'int64' , and it will
220
211
return the int64 indices.
221
- keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False.
222
212
name(str, optional): The default value is None. Normally there is no
223
213
need for user to set this property. For more information, please
224
214
refer to :ref:`api_guide_Name`.
@@ -244,41 +234,32 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
244
234
print(out3.numpy())
245
235
# [0 0 2]
246
236
"""
237
+ if axis is not None and not isinstance (axis , int ):
238
+ raise TypeError (
239
+ "The type of 'axis' must be int or None in argmin, but received %s."
240
+ % (type (axis )))
241
+ var_dtype = convert_np_dtype_to_dtype_ (dtype )
242
+ check_dtype (var_dtype , 'dtype' , ['int32' , 'int64' ], 'argmin' )
247
243
flatten = False
248
244
if axis is None :
249
245
flatten = True
250
246
axis = 0
251
247
252
248
if in_dygraph_mode ():
253
- if dtype != None :
254
- var_dtype = convert_np_dtype_to_dtype_ (dtype )
255
- out = core .ops .arg_min (x , 'axis' , axis , 'dtype' , var_dtype ,
256
- 'keepdim' , keepdim , 'flatten' , flatten )
257
- else :
258
- out = core .ops .arg_min (x , 'axis' , axis , 'keepdim' , keepdim ,
259
- 'flatten' , flatten )
249
+ out = core .ops .arg_min (x , 'axis' , axis , 'dtype' , var_dtype , 'keepdims' ,
250
+ keepdim , 'flatten' , flatten )
260
251
return out
261
252
262
253
helper = LayerHelper ("argmin" , ** locals ())
263
254
check_variable_and_dtype (
264
255
x , 'x' , ['float32' , 'float64' , 'int16' , 'int32' , 'int64' , 'uint8' ],
265
256
'paddle.argmin' )
266
- var_dtype = None
267
- attrs = {}
268
- if dtype is not None :
269
- if dtype not in ['int32' , 'int64' ]:
270
- raise ValueError (
271
- "The value of 'dtype' in argmin op must be int32, int64, but received of {}" .
272
- format (dtype ))
273
- var_dtype = convert_np_dtype_to_dtype_ (dtype )
274
- attrs ["dtype" ] = var_dtype
275
- else :
276
- var_dtype = VarDesc .VarType .INT64
277
-
278
257
out = helper .create_variable_for_type_inference (var_dtype )
258
+ attrs = {}
279
259
attrs ['keepdims' ] = keepdim
280
260
attrs ['axis' ] = axis
281
261
attrs ['flatten' ] = flatten
262
+ attrs ['dtype' ] = var_dtype
282
263
helper .append_op (
283
264
type = 'arg_min' , inputs = {'X' : x }, outputs = {'Out' : [out ]}, attrs = attrs )
284
265
out .stop_gradient = True
0 commit comments