Skip to content

Commit fba6a10

Browse files
authored
fix bug in TransDataLayout (#7137)
1 parent 06888bb commit fba6a10

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

paddle/framework/data_transform.cc

+10-1
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,20 @@ void TransDataLayout(const platform::DeviceContext* ctx,
8787
auto* dst = out->GetMutable<Tensor>();
8888
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
8989

90-
dst->Resize(src.dims());
90+
auto src_dim = src.dims();
91+
dst->Resize(src_dim);
9192
auto place = kernel_pair.second.place_;
9293
CopyFrom(src, place, *ctx, dst);
9394
const std::vector<int> axis = {0, 2, 3, 1};
9495

96+
std::vector<int64_t> dst_dim;
97+
dst_dim.resize(axis.size());
98+
for (size_t i = 0; i < axis.size(); i++) {
99+
dst_dim[i] = src_dim[axis[i]];
100+
}
101+
102+
dst->Resize(make_ddim(dst_dim));
103+
95104
auto src_type = kernel_pair.first.data_type_;
96105
framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis));
97106

paddle/framework/data_transform_test.cc

+7-7
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@ using namespace platform;
3232
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
3333
*/
3434

35-
std::array<proto::DataType, 2> kDataType = {proto::DataType::FP32,
36-
proto::DataType::FP64};
35+
std::array<proto::DataType, 2> kDataType = {
36+
{proto::DataType::FP32, proto::DataType::FP64}};
3737

38-
std::array<Place, 2> kPlace = {CPUPlace(), CUDAPlace(0)};
38+
std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
3939

40-
std::array<DataLayout, 2> kDataLayout = {
40+
std::array<DataLayout, 2> kDataLayout = {{
4141
DataLayout::kNHWC, DataLayout::kNCHW,
42-
};
42+
}};
4343

44-
std::array<LibraryType, 2> kLibraryType = {
44+
std::array<LibraryType, 2> kLibraryType = {{
4545
LibraryType::kPlain, LibraryType::kMKLDNN,
46-
};
46+
}};
4747

4848
OpKernelType GenFromBit(const std::vector<bool> bits) {
4949
return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],

0 commit comments

Comments
 (0)