We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 06888bb commit fba6a10Copy full SHA for fba6a10
paddle/framework/data_transform.cc
@@ -87,11 +87,20 @@ void TransDataLayout(const platform::DeviceContext* ctx,
87
auto* dst = out->GetMutable<Tensor>();
88
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
89
90
- dst->Resize(src.dims());
+ auto src_dim = src.dims();
91
+ dst->Resize(src_dim);
92
auto place = kernel_pair.second.place_;
93
CopyFrom(src, place, *ctx, dst);
94
const std::vector<int> axis = {0, 2, 3, 1};
95
96
+ std::vector<int64_t> dst_dim;
97
+ dst_dim.resize(axis.size());
98
+ for (size_t i = 0; i < axis.size(); i++) {
99
+ dst_dim[i] = src_dim[axis[i]];
100
+ }
101
+
102
+ dst->Resize(make_ddim(dst_dim));
103
104
auto src_type = kernel_pair.first.data_type_;
105
framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis));
106
paddle/framework/data_transform_test.cc
@@ -32,18 +32,18 @@ using namespace platform;
32
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
33
*/
34
35
-std::array<proto::DataType, 2> kDataType = {proto::DataType::FP32,
36
- proto::DataType::FP64};
+std::array<proto::DataType, 2> kDataType = {
+ {proto::DataType::FP32, proto::DataType::FP64}};
37
38
-std::array<Place, 2> kPlace = {CPUPlace(), CUDAPlace(0)};
+std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
39
40
-std::array<DataLayout, 2> kDataLayout = {
+std::array<DataLayout, 2> kDataLayout = {{
41
DataLayout::kNHWC, DataLayout::kNCHW,
42
-};
+}};
43
44
-std::array<LibraryType, 2> kLibraryType = {
+std::array<LibraryType, 2> kLibraryType = {{
45
LibraryType::kPlain, LibraryType::kMKLDNN,
46
47
48
OpKernelType GenFromBit(const std::vector<bool> bits) {
49
return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
0 commit comments