Skip to content

测试用,无需review #50893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 82 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
ce8067d
remove utils
Liyulingyue Feb 20, 2023
2f818e3
remove utils
Liyulingyue Feb 20, 2023
ce3d484
remove utils
Liyulingyue Feb 20, 2023
500616c
remove utils
Liyulingyue Feb 20, 2023
f58156a
Update get_data_from_tensor.h
Liyulingyue Feb 21, 2023
d41826b
Merge pull request #2 from Liyulingyue/phi2
Liyulingyue Feb 21, 2023
fcf2eaf
Update rnn_functor.h
Liyulingyue Feb 21, 2023
46a9fb7
Update rnn_grad_kernel.cu.cc
Liyulingyue Feb 21, 2023
651dff3
Update rnn_kernel.cu.cc
Liyulingyue Feb 21, 2023
64ff5c6
Update rnn_kernel.cc
Liyulingyue Feb 21, 2023
b29339d
Update rnn_grad_kernel.cu.cc
Liyulingyue Feb 21, 2023
e4403de
Update rnn_functor.h
Liyulingyue Feb 21, 2023
e735745
Update rnn_kernel.cu.cc
Liyulingyue Feb 21, 2023
792ecd7
Update rnn_kernel.cc
Liyulingyue Feb 21, 2023
b3613d5
remove utils
Liyulingyue Feb 21, 2023
03fa265
Update rnn_functor.h
Liyulingyue Feb 21, 2023
e3f94ab
remove utils
Liyulingyue Feb 21, 2023
2589560
remove utils
Liyulingyue Feb 21, 2023
e486c93
remove utils
Liyulingyue Feb 21, 2023
c6bec24
remove utils
Liyulingyue Feb 21, 2023
815e2c3
remove utils
Liyulingyue Feb 21, 2023
e80f3ba
Update rnn_functor.h
Liyulingyue Feb 21, 2023
431d6a0
Update unsqueeze_op.h
Liyulingyue Feb 21, 2023
5d0cfe1
Update utils.h
Liyulingyue Feb 21, 2023
0a59f2d
roll back
Liyulingyue Feb 21, 2023
36a744f
Merge remote-tracking branch 'origin/phi3' into phi3
Liyulingyue Feb 21, 2023
dbc0540
Update tensor_utils.h
Liyulingyue Feb 22, 2023
18d5646
Update tensor_utils.h
Liyulingyue Feb 22, 2023
a5feb29
Update tensor_utils.h
Liyulingyue Feb 22, 2023
57d858b
Update tensor_utils.h
Liyulingyue Feb 22, 2023
9bf4a98
Update tensor_utils.h
Liyulingyue Feb 22, 2023
bea088a
use TensorToVector
Liyulingyue Feb 22, 2023
07c4dc3
use TensorToVector
Liyulingyue Feb 22, 2023
7bbf25b
use TensorToVector
Liyulingyue Feb 22, 2023
950b772
use TensorToVector
Liyulingyue Feb 22, 2023
c3f7746
use TensorToVector
Liyulingyue Feb 22, 2023
4644892
Update rnn_kernel.cc
Liyulingyue Feb 22, 2023
de02e58
Update rnn_grad_kernel.cc
Liyulingyue Feb 22, 2023
89fd1ea
Update rnn_functor.h
Liyulingyue Feb 22, 2023
f2e1aa7
Update rnn_grad_kernel.cu.cc
Liyulingyue Feb 22, 2023
2d9de72
Update rnn_kernel.cu.cc
Liyulingyue Feb 22, 2023
03b7d38
Update rnn_functor.h
Liyulingyue Feb 22, 2023
5b112a2
Update rnn_grad_kernel.cu.cc
Liyulingyue Feb 22, 2023
aa809ab
Update rnn_kernel.cu.cc
Liyulingyue Feb 22, 2023
6466fca
Update rnn_functor.h
Liyulingyue Feb 22, 2023
c6a90b1
Update rnn_grad_kernel.cu.cc
Liyulingyue Feb 22, 2023
2d1dab8
Update rnn_kernel.cu.cc
Liyulingyue Feb 22, 2023
d1e51f8
Merge branch 'phi5' into phi4
Liyulingyue Feb 22, 2023
787eb8b
add TensorToVector
Liyulingyue Feb 22, 2023
89eafa1
Merge remote-tracking branch 'origin/phi4' into phi4
Liyulingyue Feb 22, 2023
ced35ef
roll back
Liyulingyue Feb 22, 2023
acd87e1
Merge remote-tracking branch 'origin/phi3' into phi3
Liyulingyue Feb 22, 2023
405c0e9
Update tensor_utils.h
Liyulingyue Feb 22, 2023
5ca9183
Merge branch 'phi4' into phi3
Liyulingyue Feb 23, 2023
23018e8
Update rnn_functor.h
Liyulingyue Feb 23, 2023
a9c5d25
Update rnn_grad_kernel.cu.cc
Liyulingyue Feb 23, 2023
2823262
Update tensor_utils.h
Liyulingyue Feb 23, 2023
84b7b12
Update rnn_kernel.cu.cc
Liyulingyue Feb 23, 2023
56be7c1
Update rnn_grad_kernel.cc
Liyulingyue Feb 23, 2023
4cdf203
Update rnn_kernel.cc
Liyulingyue Feb 23, 2023
da37650
Update rnn_grad_kernel.cu.cc
Liyulingyue Feb 23, 2023
12c47dd
Update rnn_kernel.cu.cc
Liyulingyue Feb 23, 2023
6fd81b8
Update rnn_grad_kernel.cc
Liyulingyue Feb 23, 2023
50370c9
Update rnn_kernel.cc
Liyulingyue Feb 23, 2023
f7aeb2b
TensorCopySync to phi::Copy
Liyulingyue Feb 23, 2023
6110034
fix codestyle
Liyulingyue Feb 23, 2023
881527f
Merge branch 'PaddlePaddle:develop' into phi5
Liyulingyue Feb 23, 2023
38fc26a
rnn_kernel.cc: add ;
Liyulingyue Feb 24, 2023
5f7c6f6
replace all GetDataFromTensor with phi::GetVectorFromTensor
Liyulingyue Feb 24, 2023
1573f14
rollback unsqueeze_op.h and utils.h
Liyulingyue Feb 24, 2023
13ce91f
rollback everything; remain changes in slice_op, split_op, stride_op,…
Liyulingyue Feb 24, 2023
fa6b7d3
utils.h: use GetDataFromTensor in it GetShape
Liyulingyue Feb 24, 2023
3dbd23e
delete the import of tensor_utils.h in unsqueeze_op.h; remain using o…
Liyulingyue Feb 24, 2023
706b9a1
remove all include of phi in fluid; remain the change of split; add c…
Liyulingyue Feb 25, 2023
ac02dbb
remove phi in strided_slice_op_mlu.cc(which should done in previous c…
Liyulingyue Feb 25, 2023
01704fb
add phi in concat_op_mlu.cc, cudnn_lstm_op.cu.cc, reshape_op_mlu.cc, …
Liyulingyue Feb 25, 2023
3087abc
add phi in slice_op_npu.cc
Liyulingyue Feb 25, 2023
834ddcf
add phi in rnn_op_mlu.cc, slice_op_mlu.cc, strided_slice_op_mlu.cc, u…
Liyulingyue Feb 25, 2023
be670d7
rollback unsqueeze_op.h, and import phi in utils.h
Liyulingyue Feb 25, 2023
e3cc185
rollback util.h, and import phi in strided_slice_op_npu.cc
Liyulingyue Feb 25, 2023
a03b7d7
import phi in reshape_op_npu.cc
Liyulingyue Feb 25, 2023
fb4e7a8
add phi in utils.h, not delete the GetDataFromTensor
Liyulingyue Feb 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/fluid/operators/concat_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace operators {
Expand All @@ -32,7 +33,7 @@ class ConcatMLUKernel : public framework::OpKernel<T> {
bool need_resize_out_dims = false;
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<phi::DenseTensor>("AxisTensor");
axis = GetDataFromTensor<int>(axis_tensor)[0];
axis = phi::GetVectorFromTensor<int>(axis_tensor)[0];
need_resize_out_dims = true;
}
axis = ComputeAxis(static_cast<int64_t>(axis),
Expand Down Expand Up @@ -97,7 +98,7 @@ class ConcatGradMLUKernel : public framework::OpKernel<T> {

if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<phi::DenseTensor>("AxisTensor");
axis = GetDataFromTensor<int>(axis_tensor)[0];
axis = phi::GetVectorFromTensor<int>(axis_tensor)[0];
}

axis = ComputeAxis(static_cast<int64_t>(axis),
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/cudnn_lstm_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
Expand Down Expand Up @@ -242,7 +242,7 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<phi::DenseTensor>("SequenceLength");
SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
SequenceLength = phi::GetVectorFromTensor<int>(sequence_length);
}

auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
Expand Down Expand Up @@ -532,7 +532,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
std::vector<int> SequenceLength;
if (has_seq_length) {
auto *sequence_length = ctx.Input<phi::DenseTensor>("SequenceLength");
SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
SequenceLength = phi::GetVectorFromTensor<int>(sequence_length);
}

int seq_length = input_dims[0];
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/operators/interpolate_v2_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -96,7 +96,7 @@ class InterpolateV2MLUKernel : public framework::OpKernel<T> {
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
std::vector<float> scale_data;
scale_data = GetDataFromTensor<float>(scale_tensor);
scale_data = phi::GetVectorFromTensor<float>(scale_tensor);

if (scale_data.size() > 1 && scale_data.size() <= 2) {
scale_h = scale_data[0];
Expand Down Expand Up @@ -147,7 +147,7 @@ class InterpolateV2MLUKernel : public framework::OpKernel<T> {
auto out_size = ctx.Input<phi::DenseTensor>("OutSize");
if (out_size != nullptr) {
std::vector<int32_t> out_size_data;
out_size_data = GetDataFromTensor<int>(out_size);
out_size_data = phi::GetVectorFromTensor<int>(out_size);
if (out_size_data.size() <= 2) {
out_h = out_size_data[0];
out_w = out_size_data[1];
Expand Down Expand Up @@ -398,7 +398,7 @@ class InterpolateV2GradMLUKernel : public framework::OpKernel<T> {
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
std::vector<float> scale_data;
scale_data = GetDataFromTensor<float>(scale_tensor);
scale_data = phi::GetVectorFromTensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
Expand Down Expand Up @@ -430,7 +430,7 @@ class InterpolateV2GradMLUKernel : public framework::OpKernel<T> {
auto out_size = ctx.Input<phi::DenseTensor>("OutSize");
if (out_size != nullptr) {
std::vector<int32_t> out_size_data;
out_size_data = GetDataFromTensor<int>(out_size);
out_size_data = phi::GetVectorFromTensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/one_hot_v2_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace operators {
Expand All @@ -31,8 +31,8 @@ class OneHotV2MLUKernel : public framework::OpKernel<T> {
int depth = ctx.Attr<int>("depth");
if (ctx.HasInput("depth_tensor")) {
std::vector<int32_t> depth_data;
depth_data =
GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("depth_tensor"));
depth_data = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("depth_tensor"));
depth = depth_data[0];

auto out_dims = out->dims();
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reshape_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace operators {
Expand All @@ -38,14 +38,15 @@ class Reshape2MLUKernel : public framework::OpKernel<T> {
"shape is [%d]",
shape_tensor->dims().size()));

target_shape_vector.push_back(GetDataFromTensor<int>(shape_tensor)[0]);
target_shape_vector.push_back(
phi::GetVectorFromTensor<int>(shape_tensor)[0]);
}
} else {
auto* shape_tensor = ctx.HasInput("Shape")
? ctx.Input<phi::DenseTensor>("Shape")
: nullptr;
if (shape_tensor) {
target_shape_vector = GetDataFromTensor<int>(shape_tensor);
target_shape_vector = phi::GetVectorFromTensor<int>(shape_tensor);
} else {
target_shape_vector = ctx.Attr<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reshape_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -46,14 +46,15 @@ class Reshape2NPUKernel : public framework::OpKernel<T> {
"shape is [%d]",
shape_tensor->dims().size()));

target_shape_vector.push_back(GetDataFromTensor<int>(shape_tensor)[0]);
target_shape_vector.push_back(
phi::GetVectorFromTensor<int>(shape_tensor)[0]);
}
} else {
auto* shape_tensor = ctx.HasInput("Shape")
? ctx.Input<phi::DenseTensor>("Shape")
: nullptr;
if (shape_tensor) {
target_shape_vector = GetDataFromTensor<int>(shape_tensor);
target_shape_vector = phi::GetVectorFromTensor<int>(shape_tensor);
} else {
target_shape_vector = ctx.Attr<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/rnn_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
Expand Down Expand Up @@ -97,7 +97,7 @@ class RNNMLUKernel : public framework::OpKernel<T> {
std::vector<int> seq_len_vec(batch_size, seq_len);
if (has_seq_length) { // set seq_len if no padding, otherwise seq_len for
// each element.
seq_len_vec = operators::GetDataFromTensor(sequence_length);
seq_len_vec = phi::GetVectorFromTensor(sequence_length);
}
cnnlDirectionMode_t direction =
is_bidirec ? CNNL_RNN_BIDIRECTIONAL : CNNL_RNN_UNIDIRECTIONAL;
Expand Down Expand Up @@ -480,7 +480,7 @@ class RNNMLUGradKernel : public framework::OpKernel<T> {

std::vector<int> seq_len_vec(batch_size, seq_len);
if (has_seq_length) {
seq_len_vec = operators::GetDataFromTensor(sequence_length);
seq_len_vec = phi::GetVectorFromTensor(sequence_length);
}
cnnlDirectionMode_t direction =
is_bidirec ? CNNL_RNN_BIDIRECTIONAL : CNNL_RNN_UNIDIRECTIONAL;
Expand Down
15 changes: 9 additions & 6 deletions paddle/fluid/operators/slice_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace paddle {
Expand All @@ -38,15 +39,16 @@ class SliceMLUKernel : public framework::OpKernel<T> {
auto starts_tensor_list =
ctx.MultiInput<phi::DenseTensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts =
GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("StartsTensor"));
starts = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int>(starts_tensor_list);
}

auto ends_tensor_list = ctx.MultiInput<phi::DenseTensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("EndsTensor"));
ends = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int>(ends_tensor_list);
}
Expand Down Expand Up @@ -141,15 +143,16 @@ class SliceGradMLUKernel : public framework::OpKernel<T> {
auto starts_tensor_list =
ctx.MultiInput<phi::DenseTensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts =
GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("StartsTensor"));
starts = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int>(starts_tensor_list);
}

auto ends_tensor_list = ctx.MultiInput<phi::DenseTensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("EndsTensor"));
ends = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int>(ends_tensor_list);
}
Expand Down
15 changes: 9 additions & 6 deletions paddle/fluid/operators/slice_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace paddle {
Expand Down Expand Up @@ -77,15 +78,16 @@ class SliceNPUKernel : public framework::OpKernel<T> {
auto starts_tensor_list =
ctx.MultiInput<phi::DenseTensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts =
GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("StartsTensor"));
starts = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int>(starts_tensor_list);
}

auto ends_tensor_list = ctx.MultiInput<phi::DenseTensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("EndsTensor"));
ends = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int>(ends_tensor_list);
}
Expand Down Expand Up @@ -172,15 +174,16 @@ class SliceGradNPUKernel : public framework::OpKernel<T> {
auto starts_tensor_list =
ctx.MultiInput<phi::DenseTensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts =
GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("StartsTensor"));
starts = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int>(starts_tensor_list);
}

auto ends_tensor_list = ctx.MultiInput<phi::DenseTensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int>(ctx.Input<phi::DenseTensor>("EndsTensor"));
ends = phi::GetVectorFromTensor<int>(
ctx.Input<phi::DenseTensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int>(ends_tensor_list);
}
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/split_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/split_op.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace operators {
Expand All @@ -35,7 +36,7 @@ class SplitMLUKernel : public framework::OpKernel<T> {
bool need_resize_outs_dims = false;
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<phi::DenseTensor>("AxisTensor");
axis = GetDataFromTensor(axis_tensor)[0];
axis = phi::GetVectorFromTensor(axis_tensor)[0];
need_resize_outs_dims = true;
}
auto sections_tensor_list =
Expand Down
13 changes: 7 additions & 6 deletions paddle/fluid/operators/strided_slice_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"

namespace paddle {
Expand Down Expand Up @@ -168,21 +169,21 @@ class StridedSliceMLUKernel : public framework::OpKernel<T> {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (ctx.HasInput("StartsTensor")) {
auto* starts_tensor = ctx.Input<phi::DenseTensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
starts = phi::GetVectorFromTensor<int64_t>(starts_tensor);
}

if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (ctx.HasInput("EndsTensor")) {
auto* ends_tensor = ctx.Input<phi::DenseTensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
ends = phi::GetVectorFromTensor<int64_t>(ends_tensor);
}

if (list_new_strides_tensor.size() > 0) {
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (ctx.HasInput("StridesTensor")) {
auto* strides_tensor = ctx.Input<phi::DenseTensor>("StridesTensor");
strides = GetDataFromTensor<int64_t>(strides_tensor);
strides = phi::GetVectorFromTensor<int64_t>(strides_tensor);
}

// out dims calculation
Expand Down Expand Up @@ -336,21 +337,21 @@ class StridedSliceGradMLUKernel : public framework::OpKernel<T> {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (ctx.HasInput("StartsTensor")) {
auto* starts_tensor = ctx.Input<phi::DenseTensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
starts = phi::GetVectorFromTensor<int64_t>(starts_tensor);
}

if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (ctx.HasInput("EndsTensor")) {
auto* ends_tensor = ctx.Input<phi::DenseTensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
ends = phi::GetVectorFromTensor<int64_t>(ends_tensor);
}

if (list_new_strides_tensor.size() > 0) {
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (ctx.HasInput("StridesTensor")) {
auto* strides_tensor = ctx.Input<phi::DenseTensor>("StridesTensor");
strides = GetDataFromTensor<int64_t>(strides_tensor);
strides = phi::GetVectorFromTensor<int64_t>(strides_tensor);
}

std::vector<int64_t> out_dims_vector(input_dims.size(), -1);
Expand Down
Loading