Skip to content

Commit 56eead2

Browse files
add tensor support for gaussian_random_op test=develop (#24389) (#24500)
1 parent f6050da commit 56eead2

File tree

8 files changed

+354
-133
lines changed

8 files changed

+354
-133
lines changed

paddle/fluid/operators/fill_constant_op.h

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,48 +20,26 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/data_type.h"
2121
#include "paddle/fluid/framework/op_registry.h"
2222
#include "paddle/fluid/operators/math/math_function.h"
23+
#include "paddle/fluid/operators/utils.h"
2324

2425
namespace paddle {
2526
namespace operators {
2627

2728
using Tensor = framework::Tensor;
2829

29-
inline framework::DDim GetShape(const framework::ExecutionContext &ctx) {
30+
inline framework::DDim GetShape(const framework::ExecutionContext &ctx,
31+
std::string op_type) {
3032
// 1. shape is a Tensor
3133
if (ctx.HasInput("ShapeTensor")) {
3234
auto *shape_tensor = ctx.Input<framework::LoDTensor>("ShapeTensor");
33-
auto *shape_data = shape_tensor->data<int>();
34-
framework::Tensor cpu_shape_tensor;
35-
if (platform::is_gpu_place(shape_tensor->place())) {
36-
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
37-
shape_data = cpu_shape_tensor.data<int>();
38-
}
39-
auto vec_shape =
40-
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
35+
auto vec_shape = GetDataFromTensor<int>(shape_tensor);
4136
return framework::make_ddim(vec_shape);
4237
}
4338

4439
// 2. shape is a list/tuple containing Tensor
4540
auto shape_tensor_list = ctx.MultiInput<framework::Tensor>("ShapeTensorList");
4641
if (shape_tensor_list.size() > 0) {
47-
std::vector<int> vec_shape;
48-
for (size_t i = 0; i < shape_tensor_list.size(); ++i) {
49-
auto tensor = shape_tensor_list[i];
50-
PADDLE_ENFORCE_EQ(
51-
tensor->dims(), framework::make_ddim({1}),
52-
platform::errors::InvalidArgument(
53-
"If the element type of 'shape'(tensor_list type) in "
54-
"FillConstantOp is Tensor, the shape of this Tensor element must "
55-
"be [1]. But received the Tensor element's shape is [%s]",
56-
tensor->dims()));
57-
if (platform::is_gpu_place(tensor->place())) {
58-
framework::Tensor temp;
59-
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
60-
vec_shape.push_back(*temp.data<int>());
61-
} else {
62-
vec_shape.push_back(*tensor->data<int>());
63-
}
64-
}
42+
auto vec_shape = GetDataFromTensorList(shape_tensor_list);
6543
return framework::make_ddim(vec_shape);
6644
}
6745

@@ -115,7 +93,8 @@ class FillConstantKernel : public framework::OpKernel<T> {
11593
}
11694
value = tensor_data[0];
11795
}
118-
auto shape = GetShape(ctx);
96+
const std::string op_type = "fill_constant";
97+
auto shape = GetShape(ctx, op_type);
11998

12099
if (out_var->IsType<framework::LoDTensor>()) {
121100
tensor = out_var->GetMutable<framework::LoDTensor>();

paddle/fluid/operators/gaussian_random_op.cc

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,45 @@ limitations under the License. */
1414

1515
#include <random>
1616
#include "paddle/fluid/framework/op_registry.h"
17-
17+
#include "paddle/fluid/operators/fill_constant_op.h"
1818
#ifdef PADDLE_WITH_MKLDNN
1919
#include "paddle/fluid/platform/mkldnn_helper.h"
2020
#endif
2121

2222
namespace paddle {
2323
namespace operators {
2424

25+
using Tensor = framework::Tensor;
2526
template <typename T>
2627
class CPUGaussianRandomKernel : public framework::OpKernel<T> {
28+
public:
29+
void Compute(const framework::ExecutionContext& context) const override {
30+
float mean = context.Attr<float>("mean");
31+
float std = context.Attr<float>("std");
32+
auto* tensor = context.Output<framework::Tensor>("Out");
33+
34+
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
35+
std::minstd_rand engine;
36+
if (seed == 0) {
37+
seed = std::random_device()();
38+
}
39+
engine.seed(seed);
40+
std::normal_distribution<T> dist(mean, std);
41+
42+
const std::string op_type = "gaussian_random";
43+
auto shape = GetShape(context, op_type);
44+
tensor->Resize(shape);
45+
int64_t size = tensor->numel();
46+
T* data = tensor->mutable_data<T>(context.GetPlace());
47+
48+
for (int64_t i = 0; i < size; ++i) {
49+
data[i] = dist(engine);
50+
}
51+
}
52+
};
53+
54+
template <typename T>
55+
class CPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
2756
public:
2857
void Compute(const framework::ExecutionContext& context) const override {
2958
float mean = context.Attr<float>("mean");
@@ -58,12 +87,26 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
5887
for (auto dim : shape) {
5988
temp.push_back(static_cast<int64_t>(dim));
6089
}
61-
PADDLE_ENFORCE_GT(
62-
shape.size(), 0UL,
63-
platform::errors::InvalidArgument(
64-
"Attribute(shape) of GaussianRandomOp must be set "
65-
"and shape.size() > 0, but reveived shape.size() is %d",
66-
shape.size()));
90+
if (shape.empty() && ctx->HasInput("ShapeTensor")) {
91+
auto shape_dims = ctx->GetInputDim("ShapeTensor");
92+
int num_ele = 1;
93+
for (int i = 0; i < shape_dims.size(); ++i) {
94+
num_ele *= shape_dims[i];
95+
}
96+
auto vec_dims = std::vector<int>(num_ele, -1);
97+
ctx->SetOutputDim("Out", framework::make_ddim(vec_dims));
98+
99+
return;
100+
}
101+
if (!(ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList"))) {
102+
PADDLE_ENFORCE_GT(
103+
shape.size(), 0UL,
104+
platform::errors::InvalidArgument(
105+
"Attribute(shape) of GaussianRandomOp must be set "
106+
"and shape.size() > 0, but reveived shape.size() is %d",
107+
shape.size()));
108+
}
109+
67110
ctx->SetOutputDim("Out", framework::make_ddim(temp));
68111
}
69112

@@ -85,6 +128,16 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
85128
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
86129
ctx.device_context(), layout, library);
87130
}
131+
132+
framework::OpKernelType GetKernelTypeForVar(
133+
const std::string& var_name, const Tensor& tensor,
134+
const framework::OpKernelType& expected_kernel_type) const override {
135+
if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
136+
return expected_kernel_type;
137+
}
138+
return framework::OpKernelType(expected_kernel_type.data_type_,
139+
tensor.place(), tensor.layout());
140+
}
88141
};
89142

90143
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -94,7 +147,18 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
94147

95148
AddAttr<std::vector<int64_t>>("shape",
96149
"(vector<int64_t>) "
97-
"The dimension of random tensor.");
150+
"The dimension of random tensor.")
151+
.SetDefault({});
152+
AddInput("ShapeTensor",
153+
"(Tensor<int>), optional). The shape of the output."
154+
"It has a higher priority than Attr(shape).")
155+
.AsDispensable();
156+
AddInput("ShapeTensorList",
157+
"(vector<Tensor<int>>, optional). The shape of the output. "
158+
"It has a higher priority than Attr(shape)."
159+
"The shape of the element in vector must be [1].")
160+
.AsDuplicable()
161+
.AsDispensable();
98162
AddAttr<float>("mean",
99163
"(float, default 0.0) "
100164
"mean of random tensor.")
@@ -135,5 +199,5 @@ REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
135199
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>,
136200
ops::CPUGaussianRandomKernel<double>);
137201
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
138-
ops::CPUGaussianRandomKernel<float>,
139-
ops::CPUGaussianRandomKernel<double>);
202+
ops::CPUGaussianRandomBatchSizeLikeKernel<float>,
203+
ops::CPUGaussianRandomBatchSizeLikeKernel<double>);

paddle/fluid/operators/gaussian_random_op.cu

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include <thrust/transform.h>
1616
#include "paddle/fluid/framework/op_registry.h"
1717
#include "paddle/fluid/framework/operator.h"
18+
#include "paddle/fluid/operators/fill_constant_op.h"
1819

1920
namespace paddle {
2021
namespace operators {
@@ -41,7 +42,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
4142
public:
4243
void Compute(const framework::ExecutionContext& context) const override {
4344
auto* tensor = context.Output<framework::Tensor>("Out");
44-
T* data = tensor->mutable_data<T>(context.GetPlace());
4545
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
4646
if (seed == 0) {
4747
std::random_device rd;
@@ -50,19 +50,45 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
5050
T mean = static_cast<T>(context.Attr<float>("mean"));
5151
T std = static_cast<T>(context.Attr<float>("std"));
5252
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
53+
const std::string op_type = "gaussian_random";
54+
auto shape = GetShape(context, op_type);
55+
tensor->Resize(shape);
56+
T* data = tensor->mutable_data<T>(context.GetPlace());
57+
5358
int64_t size = tensor->numel();
5459
thrust::transform(index_sequence_begin, index_sequence_begin + size,
5560
thrust::device_ptr<T>(data),
5661
GaussianGenerator<T>(mean, std, seed));
5762
}
5863
};
5964

65+
template <typename T>
66+
class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
67+
public:
68+
void Compute(const framework::ExecutionContext& context) const override {
69+
auto* tensor = context.Output<framework::Tensor>("Out");
70+
T* data = tensor->mutable_data<T>(context.GetPlace());
71+
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
72+
if (seed == 0) {
73+
std::random_device rd;
74+
seed = rd();
75+
}
76+
T mean = static_cast<T>(context.Attr<float>("mean"));
77+
T std = static_cast<T>(context.Attr<float>("std"));
78+
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
79+
int64_t size = tensor->numel();
80+
thrust::transform(index_sequence_begin, index_sequence_begin + size,
81+
thrust::device_ptr<T>(data),
82+
GaussianGenerator<T>(mean, std, seed));
83+
}
84+
};
6085
} // namespace operators
6186
} // namespace paddle
6287

6388
REGISTER_OP_CUDA_KERNEL(gaussian_random,
6489
paddle::operators::GPUGaussianRandomKernel<float>,
6590
paddle::operators::GPUGaussianRandomKernel<double>);
66-
REGISTER_OP_CUDA_KERNEL(gaussian_random_batch_size_like,
67-
paddle::operators::GPUGaussianRandomKernel<float>,
68-
paddle::operators::GPUGaussianRandomKernel<double>);
91+
REGISTER_OP_CUDA_KERNEL(
92+
gaussian_random_batch_size_like,
93+
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<float>,
94+
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<double>);

paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include <string>
16+
#include "paddle/fluid/operators/fill_constant_op.h"
1617
#include "paddle/fluid/operators/mean_op.h"
1718

1819
namespace paddle {
@@ -26,7 +27,6 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
2627
float mean = context.Attr<float>("mean");
2728
float std = context.Attr<float>("std");
2829
auto* tensor = context.Output<framework::Tensor>("Out");
29-
T* data = tensor->mutable_data<T>(context.GetPlace());
3030

3131
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
3232
std::minstd_rand engine;
@@ -35,6 +35,11 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
3535
}
3636
engine.seed(seed);
3737
std::normal_distribution<T> dist(mean, std);
38+
39+
const std::string op_type = "gaussian_random";
40+
auto shape = GetShape(context, op_type);
41+
tensor->Resize(shape);
42+
T* data = tensor->mutable_data<T>(context.GetPlace());
3843
int64_t size = tensor->numel();
3944
for (int64_t i = 0; i < size; ++i) {
4045
data[i] = dist(engine);

python/paddle/fluid/layers/distributions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,9 @@ def sample(self, shape, seed=0):
357357
output_shape = shape + batch_shape
358358
zero_tmp = tensor.fill_constant_batch_size_like(
359359
self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.)
360-
normal_random_tmp = nn.gaussian_random_batch_size_like(
361-
zero_tmp, zero_tmp.shape, mean=0., std=1., seed=seed)
360+
zero_tmp_shape = nn.shape(zero_tmp)
361+
normal_random_tmp = nn.gaussian_random(
362+
zero_tmp_shape, mean=0., std=1., seed=seed)
362363
output = normal_random_tmp * (zero_tmp + self.scale) + self.loc
363364
return nn.reshape(output, output_shape)
364365
else:

0 commit comments

Comments
 (0)