Commit 3dffe11

[NPU] Fix concat while inputs size is 1 (PaddlePaddle#237)
* fix concat while inputs size is 1
* fix concat kernel
* fix bug
1 parent 1dcf7c1 commit 3dffe11

2 files changed: +27 -22 lines

backends/npu/kernels/concat_kernel.cc

Lines changed: 8 additions & 6 deletions
@@ -61,7 +61,8 @@ void ConcatKernel(const Context& dev_ctx,
       }
     }
     if (inputs.size() == 1) {
-      *out = inputs[0];
+      out->ResizeAndAllocate(inputs[0].dims());
+      TensorCopy(dev_ctx, inputs[0], true, out);
       return;
     }
     NpuOpRunner runner;
@@ -72,7 +73,7 @@ void ConcatKernel(const Context& dev_ctx,
         .AddAttr("N", static_cast<int>(inputs.size()));
     runner.AddInputNames(names);
     runner.Run(stream);
-
+
   } else {
     // TODO(songkai05): In CANN512, Concat doesn't support dtype double,
     // so cast double to float32 temporarily until it supports double.
@@ -98,12 +99,13 @@ void ConcatKernel(const Context& dev_ctx,
     out_fp32.set_meta(meta_fp32);
     dev_ctx.template Alloc<float>(&out_fp32);

-    if(inputs.size() == 1) {
-      int index = std::stoi(names[0].substr(1, names[0].size()-1));
-      *out = *ins[index];
+    if (inputs.size() == 1) {
+      int index = std::stoi(names[0].substr(1, names[0].size() - 1));
+      out->ResizeAndAllocate(ins[index].dims());
+      TensorCopy(dev_ctx, *ins[index], true, out);
       return;
     }
-
+
     NpuOpRunner runner;
     runner.SetType("Concat")
         .AddInput(dev_ctx, std::move(std::vector<int>(1, axis)))
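The substantive change in this file is the single-input fast path: the old code assigned *out = inputs[0], which makes the output tensor share the input's storage, while the new code gives the output its own allocation (ResizeAndAllocate) and copies the data with TensorCopy (the true argument is presumably this backend's blocking flag). The snippet below is a standalone illustration of the difference, not Paddle code; MiniTensor is a hypothetical stand-in whose assignment shares the underlying buffer, roughly the way a tensor shares its holder.

// Standalone illustration (hypothetical MiniTensor, not Paddle code):
// assignment shares storage, allocate-and-copy does not.
#include <cassert>
#include <memory>
#include <vector>

struct MiniTensor {
  std::shared_ptr<std::vector<float>> data;  // shared buffer, like a tensor holder
};

int main() {
  MiniTensor in{std::make_shared<std::vector<float>>(4, 1.0f)};

  // Old fast path (*out = inputs[0]): the output aliases the input.
  MiniTensor out_alias = in;
  (*out_alias.data)[0] = 42.0f;    // writing through the output...
  assert((*in.data)[0] == 42.0f);  // ...also mutates the input buffer.

  // New fast path: allocate the output and copy, so the buffers stay independent.
  MiniTensor out_copy{std::make_shared<std::vector<float>>(*in.data)};
  (*out_copy.data)[1] = 7.0f;
  assert((*in.data)[1] == 1.0f);   // the input is untouched.
  return 0;
}

With the alias, any later in-place write through the output would also modify the input; the allocate-and-copy version keeps the two tensors independent at the cost of one extra device copy in the single-input case.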

backends/npu/kernels/funcs/npu_funcs.h

Lines changed: 19 additions & 16 deletions
@@ -86,6 +86,9 @@ inline void TensorCopy(const Context& dev_ctx,
   C_Stream stream = static_cast<C_Stream>(dev_ctx.stream());

   auto size = src.numel() * paddle::experimental::SizeOf(src.dtype());
+  if (UNLIKELY(size) == 0) {
+    return;
+  }

   if (src_place.GetType() == phi::AllocationType::CPU &&
       dst_place_.GetType() == phi::AllocationType::CUSTOM) {
@@ -108,17 +111,16 @@ inline void TensorCopy(const Context& dev_ctx,
         dev_ctx.Wait();
       }
     } else {
-      PADDLE_THROW(phi::errors::Unimplemented(
-          "TensorCopy is not supported."));
+      PADDLE_THROW(
+          phi::errors::Unimplemented("TensorCopy is not supported."));
     }
   } else {
-    PADDLE_THROW(phi::errors::Unimplemented(
-        "TensorCopy is not supported."));
+    PADDLE_THROW(phi::errors::Unimplemented("TensorCopy is not supported."));
   }
 } else if (src_place.GetType() == phi::AllocationType::CPU &&
-           dst_place_.GetType() == phi::AllocationType::CPU) {
-  std::memcpy(dst_ptr, src_ptr, size);
-}
+           dst_place_.GetType() == phi::AllocationType::CPU) {
+  std::memcpy(dst_ptr, src_ptr, size);
+}
 }

 /**
@@ -357,9 +359,9 @@ inline void NpuBroadcast(const Context& dev_ctx,
       dev_ctx.template Alloc<T>(&tmp_tensor);
       NpuOpRunner runner;
       runner.SetType("Expand")
-          .AddInput(tmp_src)
-          .AddInput(dev_ctx, phi::vectorize<int64_t>(tmp_tensor_dims))
-          .AddOutput(tmp_tensor);
+          .AddInput(tmp_src)
+          .AddInput(dev_ctx, phi::vectorize<int64_t>(tmp_tensor_dims))
+          .AddOutput(tmp_tensor);
       auto stream = dev_ctx.stream();
       runner.Run(stream);
       tmp_src = tmp_tensor;
@@ -421,12 +423,13 @@ inline void NpuElementWiseOpBroadcast(const Context& dev_ctx,
       phi::errors::InvalidArgument(
          "Axis should be great than or equal to 0, but received axis is %d.",
          axis));
-  PADDLE_ENFORCE_LE(axis,
-      max_dim,
-      phi::errors::InvalidArgument(
-          "Axis should be less than or equal to %d, but received axis is %d.",
-          max_dim,
-          axis));
+  PADDLE_ENFORCE_LE(
+      axis,
+      max_dim,
+      phi::errors::InvalidArgument(
+          "Axis should be less than or equal to %d, but received axis is %d.",
+          max_dim,
+          axis));

   for (int i = 0; i < x_dims.size(); ++i) {
     dst_dims_vec[i + x_axis] =
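The behavioural change in this file is the early return added to TensorCopy: when the source tensor occupies zero bytes, the helper now returns before issuing any host or device copy, so empty tensors (including a possible empty input hitting the new concat fast path) become a no-op. The spelling UNLIKELY(size) == 0 puts the closing parenthesis in an unusual place, but assuming UNLIKELY is the usual __builtin_expect-style hint it is still true exactly when size == 0. The remaining hunks (PADDLE_THROW, the Expand runner chain, PADDLE_ENFORCE_LE) only reflow existing code. Below is a minimal, self-contained sketch of the guard; the UNLIKELY defined here is a local stand-in, not Paddle's macro.

#include <cstddef>
#include <cstring>
#include <vector>

// Local stand-in for the branch-prediction hint; an assumption, not Paddle's macro.
#if defined(__GNUC__) || defined(__clang__)
#define UNLIKELY(x) __builtin_expect(static_cast<bool>(x), 0)
#else
#define UNLIKELY(x) (x)
#endif

// Copy `size` bytes, skipping the call entirely for empty buffers so that
// zero-element tensors never reach the (potentially device-side) memcpy.
void CopyBytes(void* dst, const void* src, std::size_t size) {
  if (UNLIKELY(size) == 0) {  // same spelling as the diff; true only when size == 0
    return;
  }
  std::memcpy(dst, src, size);
}

int main() {
  std::vector<char> src(8, 'x'), dst(8, 0);
  CopyBytes(dst.data(), src.data(), src.size());  // normal 8-byte copy
  CopyBytes(dst.data(), src.data(), 0);           // early return, no memcpy issued
  return 0;
}

CopyBytes(dst, src, 0) simply returns; only non-empty buffers reach std::memcpy.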

0 commit comments
