Skip to content

Commit c0af05f

Browse files
refine linspace Op for dtype setting(#27071)
1 parent 418060a commit c0af05f

File tree

4 files changed

+58
-12
lines changed

4 files changed

+58
-12
lines changed

paddle/fluid/operators/linspace_op.cu

+19-4
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/data_type_transform.h"
1516
#include "paddle/fluid/framework/op_registry.h"
1617
#include "paddle/fluid/operators/linspace_op.h"
1718
#include "paddle/fluid/platform/cuda_primitives.h"
1819

1920
namespace paddle {
2021
namespace operators {
2122

23+
using Tensor = framework::Tensor;
24+
2225
template <typename T>
2326
__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
2427
CUDA_KERNEL_LOOP(index, size) {
@@ -35,15 +38,27 @@ template <typename T>
3538
class CUDALinspaceKernel : public framework::OpKernel<T> {
3639
public:
3740
void Compute(const framework::ExecutionContext& context) const override {
38-
auto* start_t = context.Input<framework::Tensor>("Start");
39-
auto* stop_t = context.Input<framework::Tensor>("Stop");
41+
auto* pre_start = context.Input<framework::Tensor>("Start");
42+
auto* pre_stop = context.Input<framework::Tensor>("Stop");
4043
auto* num_t = context.Input<framework::Tensor>("Num");
4144
auto* out = context.Output<framework::Tensor>("Out");
45+
auto dtype = static_cast<framework::proto::VarType::Type>(
46+
context.Attr<int>("dtype"));
47+
48+
Tensor start_t;
49+
Tensor stop_t;
50+
auto start_dtype =
51+
framework::OpKernelType(pre_start->type(), context.GetPlace());
52+
auto stop_dtype =
53+
framework::OpKernelType(pre_stop->type(), context.GetPlace());
54+
auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
55+
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
56+
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
4257

4358
framework::Tensor n;
44-
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
59+
framework::TensorCopy(start_t, platform::CPUPlace(), &n);
4560
T start = n.data<T>()[0];
46-
framework::TensorCopy(*stop_t, platform::CPUPlace(), &n);
61+
framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
4762
T stop = n.data<T>()[0];
4863
framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
4964
int32_t num = n.data<int32_t>()[0];

paddle/fluid/operators/linspace_op.h

+20-2
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,38 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include <functional>
17+
#include "paddle/fluid/framework/data_type_transform.h"
1718
#include "paddle/fluid/framework/op_registry.h"
1819
#include "paddle/fluid/operators/math/math_function.h"
1920

2021
namespace paddle {
2122
namespace operators {
2223

24+
using Tensor = framework::Tensor;
25+
2326
template <typename T>
2427
class CPULinspaceKernel : public framework::OpKernel<T> {
2528
public:
2629
void Compute(const framework::ExecutionContext& context) const override {
27-
T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
28-
T stop = context.Input<framework::Tensor>("Stop")->data<T>()[0];
30+
auto* pre_start = context.Input<framework::Tensor>("Start");
31+
auto* pre_stop = context.Input<framework::Tensor>("Stop");
2932
int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
3033
auto* out = context.Output<framework::Tensor>("Out");
34+
auto dtype = static_cast<framework::proto::VarType::Type>(
35+
context.Attr<int>("dtype"));
36+
37+
Tensor start_t;
38+
Tensor stop_t;
39+
auto start_dtype =
40+
framework::OpKernelType(pre_start->type(), context.GetPlace());
41+
auto stop_dtype =
42+
framework::OpKernelType(pre_stop->type(), context.GetPlace());
43+
auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
44+
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
45+
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
46+
47+
T start = start_t.data<T>()[0];
48+
T stop = stop_t.data<T>()[0];
3149
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
3250

3351
out->Resize(framework::make_ddim({num}));

python/paddle/fluid/layers/tensor.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -1488,13 +1488,18 @@ def linspace(start, stop, num, dtype=None, name=None):
14881488

14891489
helper = LayerHelper("linspace", **locals())
14901490

1491+
start_dtype = convert_dtype(tensor_start.dtype)
1492+
stop_dtype = convert_dtype(tensor_stop.dtype)
1493+
out_dtype = convert_dtype(dtype)
14911494
if isinstance(start, Variable):
1492-
check_dtype(start.dtype, 'start', (convert_dtype(dtype)), 'linspace')
1495+
check_dtype(start.dtype, 'start',
1496+
['float32', 'float64', 'int32', 'int64'], 'linspace')
14931497
else:
14941498
check_type(start, 'start', (int, float), 'linspace')
14951499

14961500
if isinstance(stop, Variable):
1497-
check_dtype(stop.dtype, 'stop', (convert_dtype(dtype)), 'linspace')
1501+
check_dtype(stop.dtype, 'stop',
1502+
['float32', 'float64', 'int32', 'int64'], 'linspace')
14981503
else:
14991504
check_type(stop, 'stop', (int, float), 'linspace')
15001505
if isinstance(num, Variable):
@@ -1503,6 +1508,14 @@ def linspace(start, stop, num, dtype=None, name=None):
15031508
check_type(num, 'num', (int), 'linspace')
15041509
check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
15051510
'linspace')
1511+
if ((stop_dtype == "float64" or start_dtype == "float64") and
1512+
out_dtype in ["float32", "int32"]) or ((stop_dtype == "int64" or
1513+
start_dtype == "int64") and
1514+
out_dtype == "int32"):
1515+
raise ValueError(
1516+
"The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
1517+
"which may cause data type overflows. Please reset attr(dtype) of linspace."
1518+
.format(start_dtype, stop_dtype, dtype))
15061519

15071520
out = helper.create_variable_for_type_inference(dtype=dtype)
15081521

python/paddle/fluid/tests/unittests/test_linspace.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -154,16 +154,16 @@ def test_step_dtype():
154154
self.assertRaises(TypeError, test_step_dtype)
155155

156156
def test_start_dtype():
157-
start = fluid.data(shape=[1], dtype="int32", name="start")
157+
start = fluid.data(shape=[1], dtype="float64", name="start")
158158
fluid.layers.linspace(start, 10, 1, dtype="float32")
159159

160-
self.assertRaises(TypeError, test_start_dtype)
160+
self.assertRaises(ValueError, test_start_dtype)
161161

162162
def test_end_dtype():
163-
end = fluid.data(shape=[1], dtype="int32", name="end")
163+
end = fluid.data(shape=[1], dtype="float64", name="end")
164164
fluid.layers.linspace(0, end, 1, dtype="float32")
165165

166-
self.assertRaises(TypeError, test_end_dtype)
166+
self.assertRaises(ValueError, test_end_dtype)
167167

168168
def test_num_dtype():
169169
num = fluid.data(shape=[1], dtype="int32", name="step")

0 commit comments

Comments
 (0)