Skip to content

Commit cb7fd37

Browse files
authored
support c_embedding_grad for kunlun (#51399)
1 parent e3826e0 commit cb7fd37

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

paddle/fluid/operators/collective/c_embedding_op_xpu.cc

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,70 @@ class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
7171
}
7272
};
7373

74+
template <typename DeviceContext, typename T>
75+
class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
76+
public:
77+
void Compute(const framework::ExecutionContext& context) const override {
78+
const int64_t start_idx = context.Attr<int64_t>("start_index");
79+
auto ids_t = context.Input<phi::DenseTensor>("Ids");
80+
auto d_output_t =
81+
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
82+
auto table_t = context.Input<phi::DenseTensor>("W");
83+
auto table_grad_t =
84+
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
85+
86+
T* table_grad_data =
87+
table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
88+
89+
size_t table_t_mem_size =
90+
table_t->numel() * phi::SizeOf(table_grad_t->dtype());
91+
size_t table_grad_t_mem_size =
92+
table_grad_t->numel() *
93+
framework::SizeOfType(
94+
framework::TransToProtoVarType(table_grad_t->dtype()));
95+
96+
VLOG(10) << "table_dims:" << table_t->dims()
97+
<< ", table_t memory_size:" << table_t_mem_size
98+
<< ", table_grad_t memory_size:" << table_grad_t_mem_size
99+
<< ", start_index:" << start_idx;
100+
101+
auto& dev_ctx = context.template device_context<DeviceContext>();
102+
int r = xpu::constant(
103+
dev_ctx.x_context(), table_grad_data, table_grad_t_mem_size, (T)0);
104+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
105+
const T* d_output_data = d_output_t->data<T>();
106+
107+
const int64_t height = table_t->dims()[0];
108+
const int64_t width = table_t->dims()[1];
109+
110+
const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
111+
if (index_type == framework::proto::VarType::INT32) {
112+
r = xpu::embedding_grad(dev_ctx.x_context(),
113+
d_output_data,
114+
ids_t->data<int32_t>(),
115+
table_grad_data,
116+
height,
117+
width,
118+
ids_t->numel(),
119+
-1,
120+
static_cast<int32_t>(start_idx));
121+
} else if (index_type == framework::proto::VarType::INT64) {
122+
r = xpu::embedding_grad(dev_ctx.x_context(),
123+
d_output_data,
124+
ids_t->data<int64_t>(),
125+
table_grad_data,
126+
height,
127+
width,
128+
ids_t->numel(),
129+
-1,
130+
static_cast<int64_t>(start_idx));
131+
} else {
132+
PADDLE_THROW(platform::errors::Unavailable(
133+
"XPU c_embedding ids only support int32 or int64."));
134+
}
135+
}
136+
};
137+
74138
} // namespace operators
75139
} // namespace paddle
76140

@@ -80,3 +144,6 @@ namespace plat = paddle::platform;
80144
REGISTER_OP_XPU_KERNEL(
81145
c_embedding,
82146
ops::CEmbeddingOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
147+
REGISTER_OP_XPU_KERNEL(
148+
c_embedding_grad,
149+
ops::CEmbeddingGradOpXPUKernel<paddle::platform::XPUDeviceContext, float>);

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ XPUOpMap& get_kl2_ops() {
9797
{"c_concat",
9898
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
9999
{"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
100+
{"c_embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})},
100101
{"c_identity",
101102
XPUKernelSet({phi::DataType::FLOAT16,
102103
phi::DataType::FLOAT32,

0 commit comments

Comments
 (0)