Commit 71ee897

[0-size Tensor No.81、138、141] Add 0-size Tensor support for blha_get_max_len (#72937)
1 parent 5c1b862 commit 71ee897
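
In brief: this commit adds explicit handling for 0-size tensors (numel() == 0) in three areas: the blha_get_max_len fused kernels (GPU and XPU), the nanmedian forward/backward kernels together with their InferMeta and symbolic-shape rules, and the mv forward/backward kernels. A zero-batch unit test for blha_get_max_len is added as well. The per-file diffs follow.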

File tree

14 files changed: +270 −13 lines

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 33 additions & 8 deletions

@@ -2314,15 +2314,40 @@ bool NanmedianOpInferSymbolicShape(
   if (mode == "avg") {
     median_shape.emplace_back(2);
   }
-  infer_context->SetShapeOrDataForValue(
-      op->result(0),
-      symbol::ShapeOrDataDimExprs{
-          symbol::TensorShapeOrDataDimExprs(out_shape)});
-  infer_context->SetShapeOrDataForValue(
-      op->result(1),
-      symbol::ShapeOrDataDimExprs{
-          symbol::TensorShapeOrDataDimExprs(median_shape)});
 
+  const auto &IsZero = [&](const symbol::DimExpr &dim_expr) {
+    if (dim_expr.isa<int64_t>()) {
+      return dim_expr.dyn_cast<int64_t>() == static_cast<int64_t>(0);
+    }
+    return false;
+  };
+  bool size_0 = false;
+  for (size_t i = 0; i < x_shape.size(); i++) {
+    if (IsZero(x_shape.at(i))) {
+      size_0 = true;
+      break;
+    }
+  }
+  if (size_0) {
+    std::vector<symbol::DimExpr> x_numel_0_shape = {};
+    infer_context->SetShapeOrDataForValue(
+        op->result(0),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(x_numel_0_shape)});
+    infer_context->SetShapeOrDataForValue(
+        op->result(1),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(x_numel_0_shape)});
+  } else {
+    infer_context->SetShapeOrDataForValue(
+        op->result(0),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(out_shape)});
+    infer_context->SetShapeOrDataForValue(
+        op->result(1),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(median_shape)});
+  }
   return true;
 }
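
To illustrate the intent of the zero-size branch, a minimal Python-level sketch (not part of the commit; the expected behavior is an assumption based on this rule plus the InferMeta and kernel changes below):

import paddle

# A 0-size input: one dimension is 0, so x.numel() == 0.
x = paddle.full([0, 4], 0.0)
out = paddle.nanmedian(x)
# Per the branch above, the output shape is inferred as empty (a 0-D
# tensor); per the kernel change below, its value is filled with NaN.
print(out.shape)  # expected: []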

paddle/phi/infermeta/unary.cc

Lines changed: 5 additions & 0 deletions

@@ -2854,6 +2854,11 @@ void NanmedianInferMeta(const MetaTensor& x,
   }
   median_index->set_dtype(DataType::INT64);
   median_index->set_dims(make_ddim(median_dim));
+
+  if (x.numel() == 0) {
+    out->set_dims(make_ddim({}));
+    median_index->set_dims(make_ddim({}));
+  }
 }
 
 void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) {
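
Note that this fallback deliberately produces 0-D outputs (empty dims) rather than 0-size ones, mirroring the empty DimExpr vector used in the symbolic-shape rule above.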

paddle/phi/kernels/cpu/mv_grad_kernel.cc

Lines changed: 16 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 
 namespace phi {
@@ -30,6 +31,21 @@ void MvGradKernel(const Context& dev_ctx,
   auto dout = out_grad;
   auto dx = x_grad;
   auto dvec = vec_grad;
+  if (x.numel() == 0 || vec.numel() == 0) {
+    if (dx) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dx->dims())),
+                            static_cast<T>(0),
+                            dx);
+    }
+    if (dvec) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dvec->dims())),
+                            static_cast<T>(0),
+                            dvec);
+    }
+    return;
+  }
 
   const auto& dim_x = x.dims();
   int m = static_cast<int>(dim_x[0]);
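
As a sketch of the behavior this early return targets (an assumption, not from the commit): gradients of zero-size operands are materialized as zero-filled tensors of the right shape instead of reaching the BLAS calls below, which would otherwise see empty operands.

import paddle

# Zero-size operands for mv: x has shape [10, 0], vec has shape [0].
x = paddle.zeros([10, 0])
x.stop_gradient = False
vec = paddle.zeros([0])
vec.stop_gradient = False
out = paddle.mv(x, vec)  # forward zero-fills out, shape [10]
out.sum().backward()
# The kernel above zero-fills dx and dvec and returns early.
print(x.grad.shape, vec.grad.shape)  # expected: [10, 0] and [0]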

paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc

Lines changed: 4 additions & 0 deletions

@@ -96,6 +96,10 @@ void NanmedianGradKernel(const Context& dev_ctx,
                          bool keepdim UNUSED,
                          const std::string& mode,
                          DenseTensor* x_grad) {
+  if (x_grad && x_grad->numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    return;
+  }
   DenseTensor tmp_x;
   auto rank = x.dims().size();
   if ((axes.size() == 0) || rank <= 1) {
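
Here the early return only needs an Alloc: when x_grad itself is 0-size there are no elements to write, so no fill is required.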

paddle/phi/kernels/cpu/nanmedian_kernel.cc

Lines changed: 11 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
 #include "paddle/phi/kernels/top_k_kernel.h"
 
@@ -218,6 +219,16 @@ void NanmedianKernel(const Context& dev_ctx,
                      const std::string& mode,
                      DenseTensor* out,
                      DenseTensor* median_index) {
+  if (x.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
+    phi::Full<int64_t, Context>(
+        dev_ctx,
+        phi::IntArray(common::vectorize(median_index->dims())),
+        0,
+        median_index);
+    return;
+  }
   DenseTensor tmp_x;
   auto rank = x.dims().size();
   if ((axes.size() == 0) || rank <= 1) {
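
For a 0-size input, the median is defined here as NaN and the accompanying index as 0, written into the 0-D outputs that NanmedianInferMeta sets up above.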

paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu

Lines changed: 23 additions & 2 deletions

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/flash_attn_kernel.h"
@@ -49,13 +50,33 @@ void BlhaGetMaxLenKernel(const Context& dev_ctx,
                          const phi::DenseTensor& batch_size,
                          DenseTensor* max_enc_len_this_time,
                          DenseTensor* max_dec_len_this_time) {
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+  auto& dev_ctx_cpu = *pool.Get(phi::CPUPlace());
   // decoder
   max_dec_len_this_time->Resize({{1}});
-  GetMaxLenTensor(dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time);
+  if (seq_lens_decoder.numel() > 0) {
+    GetMaxLenTensor(
+        dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time);
+  } else {
+    phi::Full<int, phi::CPUContext>(
+        reinterpret_cast<const phi::CPUContext&>(dev_ctx_cpu),
+        phi::IntArray(common::vectorize(max_dec_len_this_time->dims())),
+        0,
+        max_dec_len_this_time);
+  }
 
   // encoder
   max_enc_len_this_time->Resize({{1}});
-  GetMaxLenTensor(dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time);
+  if (seq_lens_encoder.numel() > 0) {
+    GetMaxLenTensor(
+        dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time);
+  } else {
+    phi::Full<int, phi::CPUContext>(
+        reinterpret_cast<const phi::CPUContext&>(dev_ctx_cpu),
+        phi::IntArray(common::vectorize(max_enc_len_this_time->dims())),
+        0,
+        max_enc_len_this_time);
+  }
 }
 }  // namespace fusion
 }  // namespace phi
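
When a sequence-length tensor is empty, the corresponding maximum length is defined as 0. The fallback writes it through a CPU Full kernel obtained from the DeviceContextPool, which suggests these {1}-shaped outputs live on the host; GetMaxLenTensor is only invoked when there is at least one element to reduce over.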

paddle/phi/kernels/fusion/xpu/blha_get_max_len_kernel.cc

Lines changed: 24 additions & 2 deletions

@@ -14,9 +14,11 @@
 
 #include <paddle/phi/backends/xpu/xpu_context.h>
 #include "glog/logging.h"
+#include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/memcpy_kernel.h"
 #include "xpu/xdnn.h"
 
@@ -49,13 +51,33 @@ void BlhaGetMaxLenKernel(const Context& dev_ctx,
                          const phi::DenseTensor& batch_size,
                          DenseTensor* max_enc_len_this_time,
                          DenseTensor* max_dec_len_this_time) {
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+  auto& dev_ctx_cpu = *pool.Get(phi::CPUPlace());
   // decoder
   max_dec_len_this_time->Resize({{1}});
-  GetMaxLenTensor(dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time);
+  if (seq_lens_decoder.numel() > 0) {
+    GetMaxLenTensor(
+        dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time);
+  } else {
+    phi::Full<int, phi::CPUContext>(
+        reinterpret_cast<const phi::CPUContext&>(dev_ctx_cpu),
+        phi::IntArray(common::vectorize(max_dec_len_this_time->dims())),
+        0,
+        max_dec_len_this_time);
+  }
 
   // encoder
   max_enc_len_this_time->Resize({{1}});
-  GetMaxLenTensor(dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time);
+  if (seq_lens_encoder.numel() > 0) {
+    GetMaxLenTensor(
+        dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time);
+  } else {
+    phi::Full<int, phi::CPUContext>(
+        reinterpret_cast<const phi::CPUContext&>(dev_ctx_cpu),
+        phi::IntArray(common::vectorize(max_enc_len_this_time->dims())),
+        0,
+        max_enc_len_this_time);
+  }
 }
 }  // namespace fusion
 }  // namespace phi

paddle/phi/kernels/gpu/mv_grad_kernel.cu

Lines changed: 16 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 
 namespace phi {
@@ -41,6 +42,21 @@ void MvGradKernel(const Context &dev_ctx,
   auto dout = out_grad;
   auto dx = x_grad;
   auto dvec = vec_grad;
+  if (x.numel() == 0 || vec.numel() == 0) {
+    if (dx) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dx->dims())),
+                            static_cast<T>(0),
+                            dx);
+    }
+    if (dvec) {
+      phi::Full<T, Context>(dev_ctx,
+                            phi::IntArray(common::vectorize(dvec->dims())),
+                            static_cast<T>(0),
+                            dvec);
+    }
+    return;
+  }
 
   auto dim_x = x.dims();
   int m = dim_x[0];

paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu

Lines changed: 4 additions & 0 deletions

@@ -111,6 +111,10 @@ void NanmedianGradKernel(const Context& dev_ctx,
                          bool keepdim UNUSED,
                          const std::string& mode,
                          DenseTensor* x_grad) {
+  if (x_grad && x_grad->numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    return;
+  }
   DenseTensor tmp_x;
   auto rank = x.dims().size();
   if ((axes.size() == 0) || rank <= 1) {

paddle/phi/kernels/gpu/nanmedian_kernel.cu

Lines changed: 10 additions & 0 deletions

@@ -356,6 +356,16 @@ void NanmedianKernel(const Context& dev_ctx,
                      const std::string& mode,
                      DenseTensor* out,
                      DenseTensor* median_index) {
+  if (x.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
+    phi::Full<int64_t, Context>(
+        dev_ctx,
+        phi::IntArray(common::vectorize(median_index->dims())),
+        0,
+        median_index);
+    return;
+  }
   DenseTensor tmp_x;
   auto rank = x.dims().size();
   if ((axes.size() == 0) || rank <= 1) {

paddle/phi/kernels/impl/mv_kernel_impl.h

Lines changed: 12 additions & 1 deletion

@@ -14,8 +14,8 @@
 
 #pragma once
 
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
-
 namespace phi {
 
 template <typename T, typename Context>
@@ -29,6 +29,17 @@ void MvKernel(const Context& dev_ctx,
   const T* x_data = x.data<T>();
   const T* vec_data = vec.data<T>();
   T* out_data = dev_ctx.template Alloc<T>(out);
+  if (out && out->numel() == 0) {
+    return;
+  }
+  // x.shape [10, 0], vec.shape [0], out.shape [10]
+  if (vec.numel() == 0) {
+    phi::Full<T, Context>(dev_ctx,
+                          phi::IntArray(common::vectorize(out->dims())),
+                          static_cast<T>(0),
+                          out);
+    return;
+  }
 
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
 
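Following the shape comment in the diff, a minimal sketch of the forward behavior (the expected output is an assumption, not from the commit):

import paddle

# Mirrors the comment above: x.shape [10, 0], vec.shape [0] -> out.shape [10].
x = paddle.zeros([10, 0])
vec = paddle.zeros([0])
out = paddle.mv(x, vec)
# Each output element is a sum over an empty contraction dimension, so the
# kernel zero-fills out instead of entering the BLAS path below.
print(out.shape)    # expected: [10]
print(out.numpy())  # expected: all zeros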
test/legacy_test/test_blha_get_max_len_op.py

Lines changed: 77 additions & 0 deletions

@@ -108,5 +108,82 @@ def test_static_api(self):
         )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(),
+    "Only support XPU or GPU in CUDA mode.",
+)
+class TestBlhaGetMaxLenOp_ZeroSize(unittest.TestCase):
+    def setUp(self):
+        self.name = "TestBlhaGetMaxLenOpDynamic_ZeroSize"
+        if paddle.is_compiled_with_cuda():
+            place = paddle.CUDAPlace(0)
+        elif paddle.device.is_compiled_with_xpu():
+            place = paddle.device.XPUPlace(0)
+        else:
+            raise ValueError("Only support CUDA or XPU Place.")
+        self.batch_size = 0
+        self.test_encoder_data = np.random.randint(
+            1, 100, size=self.batch_size
+        ).astype("int32")
+        self.test_decoder_data = np.random.randint(
+            1, 100, size=self.batch_size
+        ).astype("int32")
+
+    def test_dynamic_api(self):
+        paddle.disable_static()
+        seq_lens_encoder = paddle.to_tensor(
+            self.test_encoder_data,
+            "int32",
+        )
+        seq_lens_decoder = paddle.to_tensor(
+            self.test_decoder_data,
+            "int32",
+        )
+        batch_size_tensor = paddle.ones([self.batch_size])
+        max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len(
+            seq_lens_encoder,
+            seq_lens_decoder,
+            batch_size_tensor,
+        )
+        assert tuple(max_enc_len_this_time.shape) == (1,) and tuple(
+            max_dec_len_this_time.shape
+        ) == (1,)
+
+    def test_static_api(self):
+        paddle.enable_static()
+
+        if paddle.is_compiled_with_cuda():
+            place = paddle.CUDAPlace(0)
+        elif paddle.device.is_compiled_with_xpu():
+            place = paddle.device.XPUPlace(0)
+        else:
+            raise ValueError("Only support CUDA or XPU Place.")
+
+        with paddle.static.program_guard(paddle.static.Program()):
+            seq_lens_encoder = paddle.static.data(
+                "seq_lens_encoder", self.test_encoder_data.shape, "int32"
+            )
+            seq_lens_decoder = paddle.static.data(
+                "seq_lens_decoder", self.test_decoder_data.shape, "int32"
+            )
+            batch_size_tensor = paddle.ones([self.batch_size], "int32")
+            max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len(
+                seq_lens_encoder,
+                seq_lens_decoder,
+                batch_size_tensor,
+            )
+            exe = paddle.static.Executor(place)
+            res_max_enc_len_this_time, res_max_dec_len_this_time = exe.run(
+                feed={
+                    "seq_lens_encoder": self.test_encoder_data,
+                    "seq_lens_decoder": self.test_decoder_data,
+                },
+                fetch_list=[max_enc_len_this_time, max_dec_len_this_time],
+            )
+            assert tuple(res_max_enc_len_this_time.shape) == (1,) and tuple(
+                res_max_dec_len_this_time.shape
+            ) == (1,)
+
+
 if __name__ == '__main__':
     unittest.main()
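
The new test drives both code paths with batch_size = 0, so seq_lens_encoder, seq_lens_decoder, and batch_size_tensor are all 0-size; the assertions pin down that both outputs keep their (1,)-shaped form, matching the Resize({{1}}) plus Full fallback in the kernels above.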
