Skip to content

Commit 897e574

Browse files
authored
cherry-pick from the develop PR#26792, fix the argmin, argmax
cherry-pick from the develop PR#26792, fix the argmin, argmax
1 parent 2c298d6 commit 897e574

File tree

5 files changed

+108
-56
lines changed

5 files changed

+108
-56
lines changed

paddle/fluid/operators/arg_max_op.cc

+18
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/op_version_registry.h"
1516
#include "paddle/fluid/operators/arg_min_max_op_base.h"
1617

1718
REGISTER_OPERATOR(
@@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL(
3132
int16_t>,
3233
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
3334
uint8_t>);
35+
REGISTER_OP_VERSION(arg_max)
36+
.AddCheckpoint(
37+
R"ROC(
38+
Upgrade argmax add a new attribute [flatten] and modify the attribute of dtype)ROC",
39+
paddle::framework::compatible::OpVersionDesc()
40+
.NewAttr("flatten",
41+
"In order to compute the argmax over the flattened array "
42+
"when the "
43+
"argument `axis` in python API is None.",
44+
false)
45+
.ModifyAttr(
46+
"dtype",
47+
"change the default value of dtype, the older version "
48+
"is -1, means return the int64 indices."
49+
"The new version is 3, return the int64 indices directly."
50+
"And supporting the dtype of -1 in new version.",
51+
3));

paddle/fluid/operators/arg_min_max_op_base.h

+28-7
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ struct VisitDataArgMinMaxFunctor {
7070
auto axis = ctx.Attr<int64_t>("axis");
7171
auto keepdims = ctx.Attr<bool>("keepdims");
7272
const bool& flatten = ctx.Attr<bool>("flatten");
73+
// paddle do not have the scalar tensor, just return the shape [1] tensor
74+
if (flatten) keepdims = true;
7375

7476
// if flatten, will construct the new dims for the calculation
7577
framework::DDim x_dims;
@@ -164,15 +166,30 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
164166
platform::errors::InvalidArgument(
165167
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
166168

169+
auto x_rank = x_dims.size();
170+
if (axis < 0) axis += x_rank;
171+
if (ctx->IsRuntime()) {
172+
const int& dtype = ctx->Attrs().Get<int>("dtype");
173+
if (dtype == framework::proto::VarType::INT32) {
174+
int64_t all_element_num = 0;
175+
if (flatten) {
176+
all_element_num = framework::product(x_dims);
177+
178+
} else {
179+
all_element_num = x_dims[axis];
180+
}
181+
PADDLE_ENFORCE_LE(
182+
all_element_num, INT_MAX,
183+
"The element num of the argmin/argmax input at axis is "
184+
"%d, is larger than int32 maximum value:%d, you must "
185+
"set the dtype of argmin/argmax to 'int64'.",
186+
all_element_num, INT_MAX);
187+
}
188+
}
167189
std::vector<int64_t> vec;
168190
if (flatten) {
169-
// if is flatten, will return the only one element
170-
if (keepdims) {
171-
vec.emplace_back(static_cast<int64_t>(1));
172-
}
191+
vec.emplace_back(static_cast<int64_t>(1));
173192
} else {
174-
auto x_rank = x_dims.size();
175-
if (axis < 0) axis += x_rank;
176193
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
177194
if (keepdims) {
178195
vec.emplace_back(static_cast<int64_t>(1));
@@ -194,10 +211,14 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
194211
AddOutput("Out", "Output tensor.");
195212
AddAttr<int64_t>("axis", "The axis in which to compute the arg indices.");
196213
AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
197-
AddAttr<int>("dtype", "Keep the dim that to reduce.").SetDefault(-1);
198214
AddAttr<bool>("flatten",
199215
"Flatten the input value, and search the min or max indices")
200216
.SetDefault(false);
217+
AddAttr<int>("dtype",
218+
"(int, 3), the dtype of indices, the indices dtype must be "
219+
"int32, int64."
220+
"default dtype is int64, and proto value is 3.")
221+
.SetDefault(3);
201222
AddComment(string::Sprintf(R"DOC(
202223
%s Operator.
203224

paddle/fluid/operators/arg_min_op.cc

+18
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/op_version_registry.h"
1516
#include "paddle/fluid/operators/arg_min_max_op_base.h"
1617

1718
REGISTER_OPERATOR(
@@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL(
3132
int16_t>,
3233
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
3334
uint8_t>);
35+
REGISTER_OP_VERSION(arg_min)
36+
.AddCheckpoint(
37+
R"ROC(
38+
Upgrade argmin add a new attribute [flatten] and modify the attribute of dtype)ROC",
39+
paddle::framework::compatible::OpVersionDesc()
40+
.NewAttr("flatten",
41+
"In order to compute the argmin over the flattened array "
42+
"when the "
43+
"argument `axis` in python API is None.",
44+
false)
45+
.ModifyAttr(
46+
"dtype",
47+
"change the default value of dtype, the older version "
48+
"is -1, means return the int64 indices."
49+
"The new version is 3, return the int64 indices directly."
50+
"And supporting the dtype of -1 in new version.",
51+
3));

python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def run_static(self, place):
218218
self.assertTrue("test_arg_api" in result.name)
219219

220220
def run_dygraph(self, place):
221-
paddle.disable_static()
221+
paddle.disable_static(place)
222222
op = eval("paddle.%s" % (op_type))
223223
data_tensor = paddle.to_tensor(self.input_data)
224224

@@ -240,7 +240,7 @@ def run_dygraph(self, place):
240240
#case 4
241241
result_data = op(data_tensor, axis=-1, keepdim=True)
242242
excepted_data = self.numpy_op(self.input_data, axis=-1)
243-
excepted_data = excepted_data.reshape((10))
243+
excepted_data = excepted_data.reshape((10, 1))
244244
self.assertTrue((result_data.numpy() == excepted_data).all(), True)
245245

246246
#case 5
@@ -299,14 +299,28 @@ def test_argmax_attr_type():
299299
name="test_argmax", shape=[10], dtype="float32")
300300
output = paddle.argmax(x=data, dtype="float32")
301301

302-
self.assertRaises(ValueError, test_argmax_attr_type)
302+
self.assertRaises(TypeError, test_argmax_attr_type)
303303

304304
def test_argmin_attr_type():
305305
data = paddle.static.data(
306306
name="test_argmax", shape=[10], dtype="float32")
307307
output = paddle.argmin(x=data, dtype="float32")
308308

309-
self.assertRaises(ValueError, test_argmin_attr_type)
309+
self.assertRaises(TypeError, test_argmin_attr_type)
310+
311+
def test_argmax_axis_type():
312+
data = paddle.static.data(
313+
name="test_argmax", shape=[10], dtype="float32")
314+
output = paddle.argmax(x=data, axis=1.2)
315+
316+
self.assertRaises(TypeError, test_argmax_axis_type)
317+
318+
def test_argmin_axis_type():
319+
data = paddle.static.data(
320+
name="test_argmin", shape=[10], dtype="float32")
321+
output = paddle.argmin(x=data, axis=1.2)
322+
323+
self.assertRaises(TypeError, test_argmin_axis_type)
310324

311325

312326
if __name__ == '__main__':

python/paddle/tensor/search.py

+26-45
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from ..fluid import core, layers
1919

2020
# TODO: define searching & indexing functions of a tensor
21-
from ..fluid.layers import argmin #DEFINE_ALIAS
2221
from ..fluid.layers import has_inf #DEFINE_ALIAS
2322
from ..fluid.layers import has_nan #DEFINE_ALIAS
2423

@@ -123,7 +122,7 @@ def argsort(x, axis=-1, descending=False, name=None):
123122
return ids
124123

125124

126-
def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
125+
def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
127126
"""
128127
This OP computes the indices of the max elements of the input tensor's
129128
element along the provided axis.
@@ -134,10 +133,10 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
134133
axis(int, optional): Axis to compute indices along. The effective range
135134
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
136135
as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the max value index.
137-
dtype(str): Data type of the output tensor which can
138-
be int32, int64. The default value is None, and it will
139-
return the int64 indices.
140136
keepdim(bool, optional): Keep the axis that selecting max. The default value is False.
137+
dtype(str|np.dtype, optional): Data type of the output tensor which can
138+
be int32, int64. The default value is 'int64', and it will
139+
return the int64 indices.
141140
name(str, optional): The default value is None. Normally there is no
142141
need for user to set this property. For more information, please
143142
refer to :ref:`api_guide_Name`.
@@ -163,48 +162,39 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
163162
print(out3.numpy())
164163
# [2 3 1]
165164
"""
165+
if axis is not None and not isinstance(axis, int):
166+
raise TypeError(
167+
"The type of 'axis' must be int or None in argmax, but received %s."
168+
% (type(axis)))
169+
var_dtype = convert_np_dtype_to_dtype_(dtype)
170+
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
166171
flatten = False
167172
if axis is None:
168173
flatten = True
169174
axis = 0
170175

171176
if in_dygraph_mode():
172-
if dtype != None:
173-
var_dtype = convert_np_dtype_to_dtype_(dtype)
174-
out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype,
175-
'keepdim', keepdim, 'flatten', flatten)
176-
else:
177-
out = core.ops.arg_max(x, 'axis', axis, 'keepdim', keepdim,
178-
'flatten', flatten)
177+
out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
178+
keepdim, 'flatten', flatten)
179179
return out
180180

181181
helper = LayerHelper("argmax", **locals())
182182
check_variable_and_dtype(
183183
x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
184184
'paddle.argmax')
185-
var_dtype = None
186185
attrs = {}
187-
if dtype is not None:
188-
if dtype not in ['int32', 'int64']:
189-
raise ValueError(
190-
"The value of 'dtype' in argmax op must be int32, int64, but received of {}".
191-
format(dtype))
192-
var_dtype = convert_np_dtype_to_dtype_(dtype)
193-
attrs["dtype"] = var_dtype
194-
else:
195-
var_dtype = VarDesc.VarType.INT64
196-
197186
out = helper.create_variable_for_type_inference(var_dtype)
198187
attrs['keepdims'] = keepdim
199188
attrs['axis'] = axis
200189
attrs['flatten'] = flatten
190+
attrs['dtype'] = var_dtype
201191
helper.append_op(
202192
type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
203193
out.stop_gradient = True
204194
return out
205195

206196

207-
def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
197+
def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
208198
"""
209199
This OP computes the indices of the min elements of the input tensor's
210200
element along the provided axis.
@@ -215,10 +205,10 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
215205
axis(int, optional): Axis to compute indices along. The effective range
216206
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
217207
as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index.
208+
keepdim(bool, optional): Keep the axis that selecting min. The default value is False.
218209
dtype(str): Data type of the output tensor which can
219-
be int32, int64. The default value is None, and it will
210+
be int32, int64. The default value is 'int64', and it will
220211
return the int64 indices.
221-
keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False.
222212
name(str, optional): The default value is None. Normally there is no
223213
need for user to set this property. For more information, please
224214
refer to :ref:`api_guide_Name`.
@@ -244,41 +234,32 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
244234
print(out3.numpy())
245235
# [0 0 2]
246236
"""
237+
if axis is not None and not isinstance(axis, int):
238+
raise TypeError(
239+
"The type of 'axis' must be int or None in argmin, but received %s."
240+
% (type(axis)))
241+
var_dtype = convert_np_dtype_to_dtype_(dtype)
242+
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
247243
flatten = False
248244
if axis is None:
249245
flatten = True
250246
axis = 0
251247

252248
if in_dygraph_mode():
253-
if dtype != None:
254-
var_dtype = convert_np_dtype_to_dtype_(dtype)
255-
out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype,
256-
'keepdim', keepdim, 'flatten', flatten)
257-
else:
258-
out = core.ops.arg_min(x, 'axis', axis, 'keepdim', keepdim,
259-
'flatten', flatten)
249+
out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
250+
keepdim, 'flatten', flatten)
260251
return out
261252

262253
helper = LayerHelper("argmin", **locals())
263254
check_variable_and_dtype(
264255
x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
265256
'paddle.argmin')
266-
var_dtype = None
267-
attrs = {}
268-
if dtype is not None:
269-
if dtype not in ['int32', 'int64']:
270-
raise ValueError(
271-
"The value of 'dtype' in argmin op must be int32, int64, but received of {}".
272-
format(dtype))
273-
var_dtype = convert_np_dtype_to_dtype_(dtype)
274-
attrs["dtype"] = var_dtype
275-
else:
276-
var_dtype = VarDesc.VarType.INT64
277-
278257
out = helper.create_variable_for_type_inference(var_dtype)
258+
attrs = {}
279259
attrs['keepdims'] = keepdim
280260
attrs['axis'] = axis
281261
attrs['flatten'] = flatten
262+
attrs['dtype'] = var_dtype
282263
helper.append_op(
283264
type='arg_min', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
284265
out.stop_gradient = True

0 commit comments

Comments
 (0)