@@ -71,6 +71,70 @@ class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
   }
 };
 
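+// Grad kernel for c_embedding on XPU: accumulates rows of d(Out) into d(W)
+// for ids that fall into this rank's partition starting at start_index.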
+template <typename DeviceContext, typename T>
+class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const int64_t start_idx = context.Attr<int64_t>("start_index");
+    auto ids_t = context.Input<phi::DenseTensor>("Ids");
+    auto d_output_t =
+        context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
+    auto table_t = context.Input<phi::DenseTensor>("W");
+    auto table_grad_t =
+        context.Output<phi::DenseTensor>(framework::GradVarName("W"));
+
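+    // Allocate the weight gradient with the same shape as the weight table.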
+    T* table_grad_data =
+        table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
+
+    size_t table_t_mem_size =
+        table_t->numel() * phi::SizeOf(table_grad_t->dtype());
+    size_t table_grad_t_mem_size =
+        table_grad_t->numel() *
+        framework::SizeOfType(
+            framework::TransToProtoVarType(table_grad_t->dtype()));
+
+    VLOG(10) << "table_dims:" << table_t->dims()
+             << ", table_t memory_size:" << table_t_mem_size
+             << ", table_grad_t memory_size:" << table_grad_t_mem_size
+             << ", start_index:" << start_idx;
+
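+    // Zero-initialize the gradient buffer before accumulating row gradients.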
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    int r = xpu::constant(
+        dev_ctx.x_context(), table_grad_data, table_grad_t_mem_size, (T)0);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    const T* d_output_data = d_output_t->data<T>();
+
+    const int64_t height = table_t->dims()[0];
+    const int64_t width = table_t->dims()[1];
+
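+    // Dispatch on the index dtype; the Ids tensor may be int32 or int64.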
+    const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
+    if (index_type == framework::proto::VarType::INT32) {
+      r = xpu::embedding_grad(dev_ctx.x_context(),
+                              d_output_data,
+                              ids_t->data<int32_t>(),
+                              table_grad_data,
+                              height,
+                              width,
+                              ids_t->numel(),
+                              -1,
+                              static_cast<int32_t>(start_idx));
+    } else if (index_type == framework::proto::VarType::INT64) {
+      r = xpu::embedding_grad(dev_ctx.x_context(),
+                              d_output_data,
+                              ids_t->data<int64_t>(),
+                              table_grad_data,
+                              height,
+                              width,
+                              ids_t->numel(),
+                              -1,
+                              static_cast<int64_t>(start_idx));
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "XPU c_embedding ids only support int32 or int64."));
+    }
+  }
+};
+
 } // namespace operators
 } // namespace paddle
 
@@ -80,3 +144,6 @@ namespace plat = paddle::platform;
 REGISTER_OP_XPU_KERNEL(
     c_embedding,
     ops::CEmbeddingOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+REGISTER_OP_XPU_KERNEL(
+    c_embedding_grad,
+    ops::CEmbeddingGradOpXPUKernel<paddle::platform::XPUDeviceContext, float>);