Commit 26414ec

[MLU] change log of cncl & fix kernels. (PaddlePaddle#565)
1 parent 3c3c9e5 · commit 26414ec

File tree: 5 files changed, +58 -29 lines

backends/mlu/kernels/adam_kernel.cc (+16 -22)
@@ -231,29 +231,23 @@ void AdamWKernel(const Context& dev_ctx,
   VLOG(3) << "Skip update" << skip_update_ << ", With decay: " << with_decay;

   if (!skip_update_ && with_decay) {
-    if (master_param.is_initialized()) {
-      PADDLE_THROW(
-          phi::errors::Unimplemented("Master Param is not supported on MLU"));
-    } else {
-      // update param with decay coeff: mul(-1 * lr, coeff * param) + param
-      MLUCnnlTensorDesc lr_desc(learning_rate);
-      MLUCnnlTensorDesc param_desc(param);
-      MLUCnnlOpTensorDesc mul_op_desc(
-          CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
+    MLUCnnlTensorDesc lr_desc(learning_rate);
+    MLUCnnlTensorDesc param_desc(param);
+    MLUCnnlOpTensorDesc mul_op_desc(
+        CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);

-      MLUCnnl::OpTensor(dev_ctx,
-                        mul_op_desc.get(),
-                        lr_desc.get(),
-                        GetBasePtr(&learning_rate),
-                        param_desc.get(),
-                        GetBasePtr(&param),
-                        param_desc.get(),
-                        const_cast<void*>(GetBasePtr(&param)),
-                        ToCnnlDataType<T>(),
-                        /*alpha1*/ -1.f,
-                        /*alpha2*/ coeff,
-                        /*beta*/ 1.f);
-    }
+    MLUCnnl::OpTensor(dev_ctx,
+                      mul_op_desc.get(),
+                      lr_desc.get(),
+                      GetBasePtr(&learning_rate),
+                      param_desc.get(),
+                      GetBasePtr(&param),
+                      param_desc.get(),
+                      const_cast<void*>(GetBasePtr(&param)),
+                      ToCnnlDataType<T>(),
+                      /*alpha1*/ -1.f,
+                      /*alpha2*/ coeff,
+                      /*beta*/ 1.f);
   }

   custom_kernel::AdamKernel<T, Context>(dev_ctx,
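The retained OpTensor call implements the decay step described by the removed comment, mul(-1 * lr, coeff * param) + param, i.e. param is scaled by (1 - lr * coeff) before the plain Adam update. A minimal host-side sketch of that arithmetic (plain C++, illustrative names only, not MLU/CNNL code):

// Host-side sketch of the decay step the OpTensor call performs:
// out = (alpha1 * lr) * (alpha2 * param) + beta * param with
// alpha1 = -1, alpha2 = coeff, beta = 1, i.e. param -= lr * coeff * param.
#include <vector>

void ApplyDecoupledWeightDecay(std::vector<float>& param, float lr, float coeff) {
  for (float& p : param) {
    p = (-1.0f * lr) * (coeff * p) + p;  // same formula, one element at a time
  }
}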

backends/mlu/kernels/elementwise_div_kernel.cc (+8 -3)
@@ -43,8 +43,13 @@ void DivideGradKernel(const Context& dev_ctx,
                       int axis,
                       phi::DenseTensor* dx,
                       phi::DenseTensor* dy) {
-  const auto& x_dims = x.dims();
-  const auto& y_dims = y.dims();
+  Tensor x_t, y_t;
+  x_t = x;
+  y_t = y;
+  if (x.dims().size() == 0) x_t.Resize(phi::make_ddim({1}));
+  if (y.dims().size() == 0) y_t.Resize(phi::make_ddim({1}));
+  const auto& x_dims = x_t.dims();
+  const auto& y_dims = y_t.dims();
   axis =
       (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) : axis);
   int max_dim = std::max(x_dims.size(), y_dims.size());
@@ -74,7 +79,7 @@ void DivideGradKernel(const Context& dev_ctx,
                     dout_desc.get(),
                     GetBasePtr(&dout),
                     y_desc.get(),
-                    GetBasePtr(&y),
+                    GetBasePtr(&y_t),
                     dout_desc.get(),
                     GetBasePtr(&dout_div_y));
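The x_t / y_t copies exist only so rank-0 inputs get a shape of {1} before the axis and broadcast arithmetic below them. A standalone sketch of that normalization (plain C++, hypothetical helper names):

#include <cstdint>
#include <cstdlib>
#include <vector>

// Treat a rank-0 (scalar) shape as {1}: same element count, rank >= 1.
std::vector<int64_t> PromoteScalarShape(std::vector<int64_t> dims) {
  if (dims.empty()) dims.push_back(1);
  return dims;
}

// Same axis normalization as in DivideGradKernel, valid once both ranks >= 1.
int NormalizeAxis(int axis, int x_rank, int y_rank) {
  return axis < 0 ? std::abs(x_rank - y_rank) + axis + 1 : axis;
}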

backends/mlu/kernels/funcs/reduce_op.h (+5 -0)
@@ -15,6 +15,7 @@
 #pragma once

 #include "kernels/funcs/mlu_baseop.h"
+#include "kernels/funcs/mlu_funcs.h"

 namespace custom_kernel {

@@ -27,6 +28,10 @@ void MLUReduceOp(const Context& dev_ctx,
                  const std::string& reduce_name,
                  phi::DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
+  if (x.dims().size() == 0) {
+    TensorCopy(dev_ctx, x, true, out);
+    return;
+  }

   auto dims = axes;
   auto input_dims = phi::vectorize(x.dims());

backends/mlu/kernels/split_kernel.cc (+26 -0)
@@ -69,6 +69,23 @@ void SplitKernel(const Context& dev_ctx,
                 vct_tensor.data());
 }

+template <typename T, typename Context>
+void SplitWithNumKernel(const Context& dev_ctx,
+                        const phi::DenseTensor& x,
+                        int num,
+                        const phi::Scalar& axis_scalar,
+                        std::vector<phi::DenseTensor*> outs) {
+  int axis_value = axis_scalar.to<int>();
+  auto input_axis_dim = x.dims().at(axis_value);
+  std::vector<int64_t> sections_vec;
+  for (int i = 0; i < num; ++i) {
+    sections_vec.push_back(input_axis_dim / num);
+  }
+  phi::IntArray sections(sections_vec);
+  custom_kernel::SplitKernel<T, Context>(
+      dev_ctx, x, sections, axis_scalar, outs);
+}
+
 }  // namespace custom_kernel

 PD_REGISTER_PLUGIN_KERNEL(split,
@@ -80,3 +97,12 @@ PD_REGISTER_PLUGIN_KERNEL(split,
                           int,
                           bool,
                           phi::dtype::float16) {}
+PD_REGISTER_PLUGIN_KERNEL(split_with_num,
+                          mlu,
+                          ALL_LAYOUT,
+                          custom_kernel::SplitWithNumKernel,
+                          float,
+                          int64_t,
+                          int,
+                          bool,
+                          phi::dtype::float16) {}
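SplitWithNumKernel only turns the requested number of outputs into an even sections list and then delegates to the existing SplitKernel. A standalone sketch of that computation (plain C++, hypothetical helper name), assuming the axis extent is divisible by num:

#include <cstdint>
#include <vector>

// e.g. EvenSections(9, 3) -> {3, 3, 3}; mirrors the loop in SplitWithNumKernel.
std::vector<int64_t> EvenSections(int64_t axis_dim, int num) {
  std::vector<int64_t> sections(num, axis_dim / num);
  return sections;
}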

backends/mlu/runtime/runtime.cc (+3 -4)
@@ -324,17 +324,16 @@ C_Status XcclCommInitRank(size_t nranks,
   PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&dev_id));
   int dev_list[] = {dev_id};
   int rank_list[] = {rank};
-  VLOG(4) << "[CNCL] create comm.";
+  VLOG(4) << "[CNCL] create comm with rank: " << rank << " clique: "
+          << reinterpret_cast<cnclCliqueId *>(unique_id->data)->hash;
   PADDLE_ENFORCE_MLU_SUCCESS(
       cnclInitComms(reinterpret_cast<cnclComm_t *>(comm),
                     1,
                     dev_list,
                     rank_list,
                     nranks,
                     reinterpret_cast<cnclCliqueId *>(unique_id->data)));
-  VLOG(4) << "[CNCL] comm inited: " << reinterpret_cast<cnclComm_t>(*comm)
-          << " clique: "
-          << reinterpret_cast<cnclCliqueId *>(unique_id->data)->hash;
+  VLOG(4) << "[CNCL] comm inited: " << reinterpret_cast<cnclComm_t>(*comm);
   return C_SUCCESS;
 }