Commit 2353db3

[IPU] add activation ops (#43662)
* add argmin and argsort ops (#800)
  * add argmin and argsort ops
* Add dot bmm ops (#803)
  * add bmm
  * add dot op
  * clean CreateConst
  * clean CreateCast
* add activation ops (#808)
  * add activation ops
  * fix function-redefined error
1 parent 2a795df commit 2353db3

14 files changed: +986 / -63 lines

paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc

Lines changed: 275 additions & 0 deletions
@@ -119,6 +119,21 @@ Node *tanh_handler(Graph *graph, Node *node) {
   return activation_op_handler(graph, node, "popart_tanh");
 }

+Node *brelu_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto t_min_ = BOOST_GET_CONST(float, op->GetAttr("t_min"));
+  auto t_max_ = BOOST_GET_CONST(float, op->GetAttr("t_max"));
+  auto x = GetInputVarNode("X", node);
+  auto clip_min = CreateConst(graph, node, std::vector<float>{t_min_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_max = CreateConst(graph, node, std::vector<float>{t_max_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_clip", {x, clip_min, clip_max},
+                      node->outputs);
+}
+
 Node *gelu_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto approximate_ = BOOST_GET_CONST(bool, op->GetAttr("approximate"));
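
The brelu op is Paddle's bounded ReLU. There is no dedicated popart op for it, so the handler materialises t_min and t_max as scalar constants and lowers the op to popart_clip, i.e. brelu(x) = min(max(x, t_min), t_max).
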
@@ -160,6 +175,245 @@ Node *log_softmax_handler(Graph *graph, Node *node) {
                       node->outputs);
 }

+Node *elu_handler(Graph *graph, Node *node) {
+  auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
+  return CreateBaseOp(graph, node, "popart_elu", node->inputs, node->outputs,
+                      {
+                          {"alpha", alpha_},
+                      });
+}
+
+Node *hard_shrink_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
+                      {
+                          {"lambd", threshold_},
+                          {"bias", 0.0f},
+                      });
+}
+
+Node *hard_sigmoid_handler(Graph *graph, Node *node) {
+  auto slope_ = BOOST_GET_CONST(float, node->Op()->GetAttr("slope"));
+  auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
+  return CreateBaseOp(graph, node, "popart_hardsigmoid", node->inputs,
+                      node->outputs,
+                      {
+                          {"alpha", slope_},
+                          {"beta", offset_},
+                      });
+}
+
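
elu, hard_shrink and hard_sigmoid (and, further down, leaky_relu, selu, softshrink and thresholded_relu) translate one-to-one onto existing popart ops; the handlers only rename attributes, e.g. Paddle's threshold becomes popart_shrink's lambd, and hard_sigmoid's slope/offset become alpha/beta.
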
+Node *hard_swish_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
+  auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  auto scale_node =
+      CreateConst(graph, node, std::vector<float>{scale_}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto offset_node =
+      CreateConst(graph, node, std::vector<float>{offset_}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto add_node = CreateBaseOp(graph, node, "popart_add", {x, offset_node}, {})
+                      ->outputs.front();
+  auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_node = CreateBaseOp(graph, node, "popart_clip",
+                                {add_node, clip_min, clip_max}, {})
+                       ->outputs.front();
+  auto mul_node = CreateBaseOp(graph, node, "popart_mul", {x, clip_node}, {})
+                      ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_div", {mul_node, scale_node},
+                      {GetOutputVarNode("Out", node)});
+}
+
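
hard_swish has no single popart op, so the handler chains popart_add, popart_clip, popart_mul and popart_div to compute out = x * clip(x + offset, 0, threshold) / scale. The offset and scale constants are created with GetVarDType(x) so they match the input's dtype, while the clip bounds stay plain FLOAT scalars. A standalone scalar sketch of the same chain (plain C++, not Paddle code; the 3/6/6 defaults are an assumption, the handler reads the real values from the op's attributes):

#include <algorithm>
#include <cstdio>

// Mirrors the op chain built by hard_swish_handler on a single value.
// offset/threshold/scale default to 3/6/6 here, which is an assumption.
float hard_swish(float x, float offset = 3.0f, float threshold = 6.0f,
                 float scale = 6.0f) {
  float add = x + offset;                                 // popart_add
  float clip = std::min(std::max(add, 0.0f), threshold);  // popart_clip
  float mul = x * clip;                                   // popart_mul
  return mul / scale;                                     // popart_div
}

int main() {
  for (float x : {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f})
    std::printf("hard_swish(%g) = %g\n", x, hard_swish(x));
  return 0;
}
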
+Node *leaky_relu_handler(Graph *graph, Node *node) {
+  auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
+  return CreateBaseOp(graph, node, "popart_leakyrelu", node->inputs,
+                      node->outputs,
+                      {
+                          {"alpha", alpha_},
+                      });
+}
+
+Node *log10_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  float ln10 = 2.30258509299404568401;
+  auto ln10_tensor =
+      CreateConst(graph, node, std::vector<float>{ln10}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_div", {log, ln10_tensor},
+                      node->outputs);
+}
+
+Node *log1p_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto one =
+      CreateConst(graph, node, std::vector<float>{1.0}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto add =
+      CreateBaseOp(graph, node, "popart_add", {x, one}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_log", {add}, node->outputs);
+}
+
+Node *log2_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  float ln2 = 0.693147180559945309;
+  auto ln2_tensor =
+      CreateConst(graph, node, std::vector<float>{ln2}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_div", {log, ln2_tensor},
+                      node->outputs);
+}
+
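
Only a natural-log primitive (popart_log) is used here, so log10 and log2 go through the change-of-base identity log_b(x) = ln(x) / ln(b), with ln(10) and ln(2) baked in as constants, and log1p is built as log(1 + x). A minimal standalone check of the identity (plain C++, not Paddle code):

#include <cmath>
#include <cstdio>

int main() {
  // log_b(x) = ln(x) / ln(b), the identity used by log10_handler / log2_handler.
  for (double x : {0.5, 1.0, 2.0, 10.0, 1000.0}) {
    std::printf("x=%-8g ln(x)/ln(10)=%.12f log10(x)=%.12f\n", x,
                std::log(x) / std::log(10.0), std::log10(x));
  }
  return 0;
}
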
+Node *logsigmoid_handler(Graph *graph, Node *node) {
+  auto sigmoid = CreateBaseOp(graph, node, "popart_sigmoid",
+                              {GetInputVarNode("X", node)}, {})
+                     ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_log", {sigmoid}, node->outputs);
+}
+
+Node *mish_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  if (!is_float_equal(threshold_, 20.0f)) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "For mish op, only threshold = 20.0 is supported"));
+  }
+  auto x = GetInputVarNode("X", node);
+  auto softplus =
+      CreateBaseOp(graph, node, "popart_softplus", {x}, {})->outputs.front();
+  auto tanh =
+      CreateBaseOp(graph, node, "popart_tanh", {softplus}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_mul", {x, tanh}, node->outputs);
+}
+
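
logsigmoid is composed as log(sigmoid(x)), and mish as x * tanh(softplus(x)). Paddle's mish also carries a threshold attribute (inputs above it skip the exact softplus for numerical stability); since the composed graph always uses popart_softplus exactly, the handler only accepts the default threshold of 20.0 and rejects anything else.
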
+Node *prelu_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto alpha = GetInputVarNode("Alpha", node);
+  auto out = GetOutputVarNode("Out", node);
+  auto x_rank = x->Var()->GetShape().size();
+  auto alpha_rank = alpha->Var()->GetShape().size();
+  if (x_rank != alpha_rank) {
+    if (alpha_rank > 1) {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "For prelu op, only rank of alpha <= 1 is supported when "
+          "Rank(alpha) != Rank(input)."));
+    }
+    if (x_rank <= 1) {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "For prelu op, rank of input should be greater than 1 when "
+          "Rank(alpha) != Rank(input)."));
+    }
+    auto shape = std::vector<int64_t>(x_rank - 1, 1);
+    shape[0] = -1;
+    int64_t size = shape.size();
+    auto dim = std::vector<int64_t>{size};
+    auto reshape_const =
+        CreateConst(graph, node, shape, dim, ONNXDataType::INT64)
+            ->outputs.front();
+    alpha =
+        CreateBaseOp(graph, node, "popart_reshape", {alpha, reshape_const}, {})
+            ->outputs.front();
+  }
+  return CreateBaseOp(graph, node, "popart_prelu", {x, alpha}, {out});
+}
+
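
popart_prelu needs Alpha to broadcast against X. When the ranks differ, only a rank-0/1 alpha is accepted and X must have rank greater than 1; the handler then reshapes alpha to a rank(x) - 1 shape of {-1, 1, ..., 1}, so for a 4-D NCHW input a length-C alpha is viewed as [C, 1, 1] and applied per channel. A small standalone sketch of that target-shape computation (plain C++, not Paddle code):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int x_rank = 4;  // e.g. an NCHW input
  // Same construction as prelu_handler: rank(x) - 1 ones, leading dim set to -1.
  std::vector<int64_t> shape(x_rank - 1, 1);
  shape[0] = -1;
  for (auto d : shape) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");  // prints: -1 1 1
  return 0;
}
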
+Node *relu6_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_clip",
+                      {GetInputVarNode("X", node), clip_min, clip_max},
+                      node->outputs);
+}
+
+Node *rsqrt_handler(Graph *graph, Node *node) {
+  auto sqrt =
+      CreateBaseOp(graph, node, "popart_sqrt", {GetInputVarNode("X", node)}, {})
+          ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_reciprocal", {sqrt}, node->outputs);
+}
+
+Node *selu_handler(Graph *graph, Node *node) {
+  auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
+  auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
+  return CreateBaseOp(graph, node, "popart_selu", node->inputs, node->outputs,
+                      {
+                          {"alpha", alpha_},
+                          {"gamma", scale_},
+                      });
+}
+
+Node *silu_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto sigmoid =
+      CreateBaseOp(graph, node, "popart_sigmoid", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid}, node->outputs);
+}
+
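
relu6 is again lowered to popart_clip, with constant bounds 0 and threshold; rsqrt is composed as reciprocal(sqrt(x)); selu maps directly onto popart_selu (Paddle's scale becomes popart's gamma); and silu is built as x * sigmoid(x).
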
+Node *softshrink_handler(Graph *graph, Node *node) {
+  auto lambda_ = BOOST_GET_CONST(float, node->Op()->GetAttr("lambda"));
+  return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
+                      {
+                          {"lambd", lambda_},
+                          {"bias", lambda_},
+                      });
+}
+
+Node *square_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  return CreateBaseOp(graph, node, "popart_mul", {x, x}, node->outputs);
+}
+
+Node *swish_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto out = GetOutputVarNode("Out", node);
+  auto beta_ = BOOST_GET_CONST(float, node->Op()->GetAttr("beta"));
+  auto beta_node =
+      CreateConst(graph, node, std::vector<float>{beta_}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto beta_x_node = CreateBaseOp(graph, node, "popart_mul", {x, beta_node}, {})
+                         ->outputs.front();
+  auto sigmoid_node =
+      CreateBaseOp(graph, node, "popart_sigmoid", {beta_x_node}, {})
+          ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid_node}, {out});
+}
+
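
softshrink reuses popart_shrink with lambd = bias = lambda; square is simply x * x via popart_mul; and swish is composed as x * sigmoid(beta * x), with beta materialised as a constant of the input's dtype.
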
+Node *tanh_shrink_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto tanh =
+      CreateBaseOp(graph, node, "popart_tanh", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_sub", {x, tanh}, node->outputs);
+}
+
+Node *thresholded_relu_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  auto x = GetInputVarNode("X", node);
+  return CreateBaseOp(graph, node, "popart_thresholdedrelu", {x}, node->outputs,
+                      {
+                          {"alpha", threshold_},
+                      });
+}
+
 }  // namespace
 }  // namespace ipu
 }  // namespace platform
@@ -188,5 +442,26 @@ REGISTER_HANDLER(softsign, softsign_handler);
 REGISTER_HANDLER(sqrt, sqrt_handler);
 REGISTER_HANDLER(tan, tan_handler);
 REGISTER_HANDLER(tanh, tanh_handler);
+REGISTER_HANDLER(brelu, brelu_handler);
 REGISTER_HANDLER(gelu, gelu_handler);
 REGISTER_HANDLER(log_softmax, log_softmax_handler);
+REGISTER_HANDLER(elu, elu_handler);
+REGISTER_HANDLER(hard_shrink, hard_shrink_handler);
+REGISTER_HANDLER(hard_sigmoid, hard_sigmoid_handler);
+REGISTER_HANDLER(hard_swish, hard_swish_handler);
+REGISTER_HANDLER(leaky_relu, leaky_relu_handler);
+REGISTER_HANDLER(log10, log10_handler);
+REGISTER_HANDLER(log1p, log1p_handler);
+REGISTER_HANDLER(log2, log2_handler);
+REGISTER_HANDLER(logsigmoid, logsigmoid_handler);
+REGISTER_HANDLER(mish, mish_handler);
+REGISTER_HANDLER(prelu, prelu_handler);
+REGISTER_HANDLER(relu6, relu6_handler);
+REGISTER_HANDLER(rsqrt, rsqrt_handler);
+REGISTER_HANDLER(selu, selu_handler);
+REGISTER_HANDLER(silu, silu_handler);
+REGISTER_HANDLER(softshrink, softshrink_handler);
+REGISTER_HANDLER(square, square_handler);
+REGISTER_HANDLER(swish, swish_handler);
+REGISTER_HANDLER(tanh_shrink, tanh_shrink_handler);
+REGISTER_HANDLER(thresholded_relu, thresholded_relu_handler);
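
Each REGISTER_HANDLER entry ties a Paddle op type to its handler so the canonicalization pass can look the handler up by op name; brelu is registered next to the existing gelu and log_softmax entries, while the remaining new ops are appended at the end of the list.
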

paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc

Lines changed: 11 additions & 6 deletions
@@ -117,15 +117,20 @@ const bool is_float_equal(float a, float b, float eps) {
   return std::fabs(a - b) <= eps;
 }

-const int GetOutputVarDType(const Node *node, const std::string &output_name) {
-  auto out_node = GetOutputVarNode(output_name, node);
-  PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
-                                        "Node's out node does not exist."));
-  auto var = out_node->Var();
+const ONNXDataType GetVarDType(const Node *node) {
+  auto var = node->Var();
   PADDLE_ENFORCE_NOT_NULL(
       var, platform::errors::Unavailable("Node is not a variable."));
   auto proto_var_type = var->GetDataType();
-  return static_cast<int>(VarType2OnnxDType(proto_var_type));
+  return VarType2OnnxDType(proto_var_type);
+}
+
+const ONNXDataType GetOutputVarDType(const Node *node,
+                                     const std::string &output_name) {
+  auto out_node = GetOutputVarNode(output_name, node);
+  PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
+                                        "Node's out node does not exist."));
+  return GetVarDType(out_node);
 }

 }  // namespace ipu

paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h

Lines changed: 3 additions & 2 deletions
@@ -78,8 +78,9 @@ Node *GetOutputVarNodeByVarName(const std::string &var_name,
                                 const Node *op_node);

 const bool is_float_equal(float a, float b, float eps = 1e-8);
-const int GetOutputVarDType(const Node *node,
-                            const std::string &output_name = "Out");
+const ONNXDataType GetVarDType(const Node *node);
+const ONNXDataType GetOutputVarDType(const Node *node,
+                                     const std::string &output_name = "Out");

 }  // namespace ipu
 }  // namespace platform
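
The helper change extracts GetVarDType so a handler can ask any variable node (typically an input) for its ONNX dtype directly; several of the new handlers use it to create constants that match the input's dtype. GetOutputVarDType now returns ONNXDataType instead of a plain int and simply delegates to GetVarDType, and the header declarations are updated to match.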
