Skip to content

Commit ac57ded

Browse files
committed
unstack: fix large-tensor issue (widen int to int64_t)
1 parent 1d581f2 commit ac57ded

File tree

4 files changed

+22
-13
lines changed

4 files changed

+22
-13
lines changed

paddle/phi/infermeta/backward.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616
#include "paddle/phi/common/type_traits.h"
1717
#include "paddle/phi/core/utils/data_type.h"
1818
#include "paddle/phi/kernels/funcs/axis_utils.h"
19+
#include <glog/logging.h>
1920

2021
namespace phi {
2122

@@ -1747,8 +1748,10 @@ void UnStackGradInferMeta(const std::vector<const MetaTensor*>& out_grad,
17471748
rank));
17481749
if (axis < 0) axis += (rank + 1);
17491750

1750-
auto vec = common::vectorize<int>(input_dims[0]);
1751-
vec.insert(vec.begin() + axis, static_cast<int>(input_dims.size()));
1751+
auto vec = common::vectorize<int64_t>(input_dims[0]);
1752+
vec.insert(vec.begin() + axis, static_cast<int64_t>(input_dims.size()));
1753+
for (size_t i =0;i<vec.size();i++)
1754+
VLOG(1) << "!!!!!!" << vec[i];
17521755
x_grad->set_dims(common::make_ddim(vec));
17531756
x_grad->set_dtype(out_grad[0]->dtype());
17541757
}

paddle/phi/infermeta/unary.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ limitations under the License. */
3232
#include "paddle/phi/kernels/funcs/unsqueeze.h"
3333
#include "paddle/phi/kernels/impl/einsum_impl.h"
3434

35+
#include <glog/logging.h>
36+
3537
namespace phi {
3638

3739
namespace detail {
@@ -5962,7 +5964,7 @@ void UnStackInferMeta(const MetaTensor& x,
59625964
x_dim[axis],
59635965
num));
59645966
}
5965-
auto vec = common::vectorize<int>(x_dim);
5967+
auto vec = common::vectorize<int64_t>(x_dim);
59665968
vec.erase(vec.begin() + axis);
59675969
for (size_t i = 0; i < output_count; i++) {
59685970
outs[i]->set_dims(common::make_ddim(vec));

paddle/phi/kernels/funcs/stack_and_unstack.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ void LaunchUnStackKernel(const Context& ctx,
210210
constexpr int kWarpSize = 32;
211211
constexpr int kMaxOut = 16;
212212

213-
int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
213+
int64_t tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
214214
if (split_dim < kMaxOut) {
215215
tid_y = split_dim;
216216
tid_x =
@@ -219,10 +219,13 @@ void LaunchUnStackKernel(const Context& ctx,
219219
} else {
220220
tid_y = kMaxOut;
221221
tid_x = kWarpSize;
222-
bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
222+
bid_y = backends::gpu::DivUp<int64_t>(split_dim, kMaxOut);
223223
}
224-
int tile_x_num = backends::gpu::DivUp<int>(out_row, tid_x);
225-
bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
224+
int64_t tile_x_num = backends::gpu::DivUp<int64_t>(out_row, tid_x);
225+
if (tile_x_num < static_cast<int64_t>(backends::gpu::kMultiDimslimit))
226+
bid_x = tile_x_num;
227+
else
228+
bid_x = backends::gpu::kMultiDimslimit;
226229
dim3 blocks(tid_x, tid_y, 1);
227230
dim3 grids(bid_x, bid_y, 1);
228231

paddle/phi/kernels/impl/unstack_grad_kernel_impl.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,16 @@ void UnStackGradKernel(const Context &dev_ctx,
2626
const std::vector<const DenseTensor *> &x,
2727
int axis,
2828
DenseTensor *x_grad) {
29+
VLOG(1) << "in";
2930
if (axis < 0) axis += (x[0]->dims().size() + 1);
3031

31-
int n = static_cast<int>(x.size());
32+
int64_t n = static_cast<int64_t>(x.size());
3233
auto *x_grad_data = dev_ctx.template Alloc<T>(x_grad);
3334
std::vector<const T *> x_datas(n);
34-
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
35+
for (int64_t i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
3536

36-
int pre = 1;
37-
int post = 1;
37+
int64_t pre = 1;
38+
int64_t post = 1;
3839
auto &dim = x[0]->dims();
3940
for (auto i = 0; i < axis; ++i) pre *= dim[i];
4041
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
@@ -56,8 +57,8 @@ void UnStackGradKernel(const Context &dev_ctx,
5657

5758
size_t x_offset = 0;
5859
size_t y_offset = 0;
59-
for (int i = 0; i < pre; i++) {
60-
for (int j = 0; j < n; j++) {
60+
for (int64_t i = 0; i < pre; i++) {
61+
for (int64_t j = 0; j < n; j++) {
6162
std::memcpy(
6263
x_grad_data + y_offset, x_data_arr[j] + x_offset, post * sizeof(T));
6364
y_offset += post;

0 commit comments

Comments (0)