 namespace custom_kernel {

 template <typename T, typename Context>
-void AclopSplitKernel(const Context& dev_ctx,
-                      const phi::DenseTensor& x,
-                      const phi::IntArray& num_or_sections,
-                      const phi::Scalar& axis_scalar,
-                      std::vector<phi::DenseTensor*> outs) {
+void SplitKernel(const Context& dev_ctx,
+                 const phi::DenseTensor& x,
+                 const phi::IntArray& num_or_sections,
+                 const phi::Scalar& axis_scalar,
+                 std::vector<phi::DenseTensor*> outs) {
   // need to infershape output
   auto sections = num_or_sections.GetData();
   int axis = axis_scalar.to<int>();
@@ -77,58 +77,6 @@ void AclopSplitKernel(const Context& dev_ctx,
   }
 }

-template <typename T, typename Context>
-void SplitKernel(const Context& dev_ctx,
-                 const phi::DenseTensor& x,
-                 const phi::IntArray& num_or_sections,
-                 const phi::Scalar& axis_scalar,
-                 std::vector<phi::DenseTensor*> outs) {
-  auto sections = num_or_sections.GetData();
-  int64_t axis = axis_scalar.to<int64_t>();
-  if (axis < 0) {
-    axis = x.dims().size() + axis;
-  }
-
-  std::vector<phi::DenseTensor*> outputs;
-  for (size_t j = 0; j < outs.size(); ++j) {
-    dev_ctx.template Alloc<T>(outs[j]);
-    outputs.push_back(outs[j]);
-  }
-
-  if (!num_or_sections.FromTensor() && !axis_scalar.FromTensor() &&
-      sections.size() == 1 && outs.size() == sections[0]) {
-    DO_COMPATIBILITY(aclnnSplitTensor,
-                     (custom_kernel::AclopSplitKernel<T, Context>(
-                         dev_ctx, x, num_or_sections, axis_scalar, outs)));
-
-    uint64_t splitSections = x.dims()[axis] / sections[0];
-    EXEC_NPU_CMD(aclnnSplitTensor, dev_ctx, x, splitSections, axis, outputs);
-  } else {
-    DO_COMPATIBILITY(aclnnSplitWithSize,
-                     (custom_kernel::AclopSplitKernel<T, Context>(
-                         dev_ctx, x, num_or_sections, axis_scalar, outs)));
-
-    std::vector<int64_t> sections_;
-    int sum = 0;
-    int minusOneIndex = -1;
-    int all = x.dims()[axis];
-
-    for (int i = 0; i < sections.size(); ++i) {
-      sections_.push_back(sections[i]);
-      if (sections_[i] == -1) {
-        minusOneIndex = i;
-      } else {
-        sum += sections_[i];
-      }
-    }
-    if (minusOneIndex != -1) {
-      sections_[minusOneIndex] = all - sum;
-    }
-
-    EXEC_NPU_CMD(aclnnSplitWithSize, dev_ctx, x, sections_, axis, outputs);
-  }
-}
-
 template <typename T, typename Context>
 void SplitWithNumKernel(const Context& dev_ctx,
                         const phi::DenseTensor& x,
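Note (not part of the diff): the removed aclnn-based SplitKernel dispatched to aclnnSplitTensor when a single section count implied equal-sized splits, and otherwise to aclnnSplitWithSize after resolving any -1 entry in sections to the remaining extent along the split axis. Below is a minimal standalone sketch of that -1 resolution step; the helper name ResolveSplitSections and the demo values are hypothetical and not taken from the repository.

// Hypothetical standalone sketch: mirrors the section-resolution logic
// from the removed aclnnSplitWithSize branch. A single -1 entry in
// `sections` is replaced by whatever size remains along the split axis
// after all explicit sections are accounted for.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> ResolveSplitSections(const std::vector<int64_t>& sections,
                                          int64_t dim_size_along_axis) {
  std::vector<int64_t> resolved(sections);
  int64_t sum = 0;
  int minus_one_index = -1;
  for (size_t i = 0; i < resolved.size(); ++i) {
    if (resolved[i] == -1) {
      minus_one_index = static_cast<int>(i);
    } else {
      sum += resolved[i];
    }
  }
  if (minus_one_index != -1) {
    resolved[minus_one_index] = dim_size_along_axis - sum;
  }
  return resolved;
}

int main() {
  // Splitting a dimension of size 10 into {2, -1, 3} resolves to {2, 5, 3}.
  for (int64_t s : ResolveSplitSections({2, -1, 3}, 10)) {
    std::cout << s << " ";
  }
  std::cout << std::endl;
  return 0;
}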