@@ -119,6 +119,21 @@ Node *tanh_handler(Graph *graph, Node *node) {
  return activation_op_handler(graph, node, "popart_tanh");
}

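+// brelu(x) = min(max(x, t_min), t_max), expressed as popart_clip with
+// scalar min/max constants.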
+Node *brelu_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto t_min_ = BOOST_GET_CONST(float, op->GetAttr("t_min"));
+  auto t_max_ = BOOST_GET_CONST(float, op->GetAttr("t_max"));
+  auto x = GetInputVarNode("X", node);
+  auto clip_min = CreateConst(graph, node, std::vector<float>{t_min_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_max = CreateConst(graph, node, std::vector<float>{t_max_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_clip", {x, clip_min, clip_max},
+                      node->outputs);
+}
+
Node *gelu_handler(Graph *graph, Node *node) {
  auto *op = node->Op();
  auto approximate_ = BOOST_GET_CONST(bool, op->GetAttr("approximate"));
@@ -160,6 +175,245 @@ Node *log_softmax_handler(Graph *graph, Node *node) {
                      node->outputs);
}

+Node *elu_handler(Graph *graph, Node *node) {
+  auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
+  return CreateBaseOp(graph, node, "popart_elu", node->inputs, node->outputs,
+                      {
+                          {"alpha", alpha_},
+                      });
+}
+
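+// hard_shrink uses popart_shrink with bias = 0: out = x if |x| > threshold,
+// else 0.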
+Node *hard_shrink_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
+                      {
+                          {"lambd", threshold_},
+                          {"bias", 0.0f},
+                      });
+}
+
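+// Paddle's (slope, offset) map to popart_hardsigmoid's (alpha, beta):
+// out = max(0, min(1, alpha * x + beta)).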
+Node *hard_sigmoid_handler(Graph *graph, Node *node) {
+  auto slope_ = BOOST_GET_CONST(float, node->Op()->GetAttr("slope"));
+  auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
+  return CreateBaseOp(graph, node, "popart_hardsigmoid", node->inputs,
+                      node->outputs,
+                      {
+                          {"alpha", slope_},
+                          {"beta", offset_},
+                      });
+}
+
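+// hard_swish(x) = x * clip(x + offset, 0, threshold) / scale, composed from
+// popart_add, popart_clip, popart_mul and popart_div.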
+Node *hard_swish_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
+  auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  auto scale_node =
+      CreateConst(graph, node, std::vector<float>{scale_}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto offset_node =
+      CreateConst(graph, node, std::vector<float>{offset_}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto add_node = CreateBaseOp(graph, node, "popart_add", {x, offset_node}, {})
+                      ->outputs.front();
+  auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_node = CreateBaseOp(graph, node, "popart_clip",
+                                {add_node, clip_min, clip_max}, {})
+                       ->outputs.front();
+  auto mul_node = CreateBaseOp(graph, node, "popart_mul", {x, clip_node}, {})
+                      ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_div", {mul_node, scale_node},
+                      {GetOutputVarNode("Out", node)});
+}
+
+Node *leaky_relu_handler(Graph *graph, Node *node) {
+  auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
+  return CreateBaseOp(graph, node, "popart_leakyrelu", node->inputs,
+                      node->outputs,
+                      {
+                          {"alpha", alpha_},
+                      });
+}
+
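+// log10(x) = ln(x) / ln(10); PopART only provides the natural log.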
+Node *log10_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  float ln10 = 2.30258509299404568401;
+  auto ln10_tensor =
+      CreateConst(graph, node, std::vector<float>{ln10}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_div", {log, ln10_tensor},
+                      node->outputs);
+}
+
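+// log1p(x) = ln(1 + x), composed from popart_add and popart_log.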
+Node *log1p_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto one =
+      CreateConst(graph, node, std::vector<float>{1.0}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto add =
+      CreateBaseOp(graph, node, "popart_add", {x, one}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_log", {add}, node->outputs);
+}
+
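+// log2(x) = ln(x) / ln(2).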
+Node *log2_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  float ln2 = 0.693147180559945309;
+  auto ln2_tensor =
+      CreateConst(graph, node, std::vector<float>{ln2}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_div", {log, ln2_tensor},
+                      node->outputs);
+}
+
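+// logsigmoid(x) = ln(sigmoid(x)).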
+Node *logsigmoid_handler(Graph *graph, Node *node) {
+  auto sigmoid = CreateBaseOp(graph, node, "popart_sigmoid",
+                              {GetInputVarNode("X", node)}, {})
+                     ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_log", {sigmoid}, node->outputs);
+}
+
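+// mish(x) = x * tanh(softplus(x)); only Paddle's default threshold of 20.0
+// is accepted.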
+Node *mish_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  if (!is_float_equal(threshold_, 20.0f)) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "For mish op, only threshold = 20.0 is supported."));
+  }
+  auto x = GetInputVarNode("X", node);
+  auto softplus =
+      CreateBaseOp(graph, node, "popart_softplus", {x}, {})->outputs.front();
+  auto tanh =
+      CreateBaseOp(graph, node, "popart_tanh", {softplus}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_mul", {x, tanh}, node->outputs);
+}
+
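+// popart_prelu needs alpha to broadcast against x. A 1-D alpha holding
+// per-channel weights is reshaped to {-1, 1, ..., 1} so it broadcasts over
+// the channel dimension of an NCHW input.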
+Node *prelu_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto alpha = GetInputVarNode("Alpha", node);
+  auto out = GetOutputVarNode("Out", node);
+  auto x_rank = x->Var()->GetShape().size();
+  auto alpha_rank = alpha->Var()->GetShape().size();
+  if (x_rank != alpha_rank) {
+    if (alpha_rank > 1) {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "For prelu op, only rank of alpha <= 1 is supported when "
+          "rank of alpha is not equal to rank of input."));
+    }
+    if (x_rank <= 1) {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "For prelu op, rank of input should be no less than 2."));
+    }
+    auto shape = std::vector<int64_t>(x_rank - 1, 1);
+    shape[0] = -1;
+    int64_t size = shape.size();
+    auto dim = std::vector<int64_t>{size};
+    auto reshape_const =
+        CreateConst(graph, node, shape, dim, ONNXDataType::INT64)
+            ->outputs.front();
+    alpha =
+        CreateBaseOp(graph, node, "popart_reshape", {alpha, reshape_const}, {})
+            ->outputs.front();
+  }
+  return CreateBaseOp(graph, node, "popart_prelu", {x, alpha}, {out});
+}
+
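+// relu6(x) = clip(x, 0, threshold).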
+Node *relu6_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
+                              ONNXDataType::FLOAT)
+                      ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_clip",
+                      {GetInputVarNode("X", node), clip_min, clip_max},
+                      node->outputs);
+}
+
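+// rsqrt(x) = 1 / sqrt(x), composed from popart_sqrt and popart_reciprocal.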
+Node *rsqrt_handler(Graph *graph, Node *node) {
+  auto rsqrt =
+      CreateBaseOp(graph, node, "popart_sqrt", {GetInputVarNode("X", node)}, {})
+          ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_reciprocal", {rsqrt}, node->outputs);
+}
+
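+// Direct mapping; Paddle's scale attribute corresponds to popart_selu's gamma.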
+Node *selu_handler(Graph *graph, Node *node) {
+  auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
+  auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
+  return CreateBaseOp(graph, node, "popart_selu", node->inputs, node->outputs,
+                      {
+                          {"alpha", alpha_},
+                          {"gamma", scale_},
+                      });
+}
+
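+// silu(x) = x * sigmoid(x).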
+Node *silu_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto sigmoid =
+      CreateBaseOp(graph, node, "popart_sigmoid", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid}, node->outputs);
+}
+
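+// softshrink uses popart_shrink with bias = lambd:
+// out = x - sign(x) * lambd if |x| > lambd, else 0.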
+Node *softshrink_handler(Graph *graph, Node *node) {
+  auto lambda_ = BOOST_GET_CONST(float, node->Op()->GetAttr("lambda"));
+  return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
+                      {
+                          {"lambd", lambda_},
+                          {"bias", lambda_},
+                      });
+}
+
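+// square(x) = x * x, lowered to a self-multiply.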
+Node *square_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  return CreateBaseOp(graph, node, "popart_mul", {x, x}, node->outputs);
+}
+
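+// swish(x) = x * sigmoid(beta * x).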
+Node *swish_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto out = GetOutputVarNode("Out", node);
+  auto beta_ = BOOST_GET_CONST(float, node->Op()->GetAttr("beta"));
+  auto beta_node =
+      CreateConst(graph, node, std::vector<float>{beta_}, {1}, GetVarDType(x))
+          ->outputs.front();
+  auto beta_x_node = CreateBaseOp(graph, node, "popart_mul", {x, beta_node}, {})
+                         ->outputs.front();
+  auto sigmoid_node =
+      CreateBaseOp(graph, node, "popart_sigmoid", {beta_x_node}, {})
+          ->outputs.front();
+  return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid_node}, {out});
+}
+
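+// tanh_shrink(x) = x - tanh(x).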
+Node *tanh_shrink_handler(Graph *graph, Node *node) {
+  auto x = GetInputVarNode("X", node);
+  auto tanh =
+      CreateBaseOp(graph, node, "popart_tanh", {x}, {})->outputs.front();
+  return CreateBaseOp(graph, node, "popart_sub", {x, tanh}, node->outputs);
+}
+
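+// Paddle's threshold attribute maps to popart_thresholdedrelu's alpha.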
+Node *thresholded_relu_handler(Graph *graph, Node *node) {
+  auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
+  auto x = GetInputVarNode("X", node);
+  return CreateBaseOp(graph, node, "popart_thresholdedrelu", {x}, node->outputs,
+                      {
+                          {"alpha", threshold_},
+                      });
+}
+
} // namespace
} // namespace ipu
} // namespace platform
@@ -188,5 +442,26 @@ REGISTER_HANDLER(softsign, softsign_handler);
REGISTER_HANDLER(sqrt, sqrt_handler);
REGISTER_HANDLER(tan, tan_handler);
REGISTER_HANDLER(tanh, tanh_handler);
+REGISTER_HANDLER(brelu, brelu_handler);
REGISTER_HANDLER(gelu, gelu_handler);
REGISTER_HANDLER(log_softmax, log_softmax_handler);
+REGISTER_HANDLER(elu, elu_handler);
+REGISTER_HANDLER(hard_shrink, hard_shrink_handler);
+REGISTER_HANDLER(hard_sigmoid, hard_sigmoid_handler);
+REGISTER_HANDLER(hard_swish, hard_swish_handler);
+REGISTER_HANDLER(leaky_relu, leaky_relu_handler);
+REGISTER_HANDLER(log10, log10_handler);
+REGISTER_HANDLER(log1p, log1p_handler);
+REGISTER_HANDLER(log2, log2_handler);
+REGISTER_HANDLER(logsigmoid, logsigmoid_handler);
+REGISTER_HANDLER(mish, mish_handler);
+REGISTER_HANDLER(prelu, prelu_handler);
+REGISTER_HANDLER(relu6, relu6_handler);
+REGISTER_HANDLER(rsqrt, rsqrt_handler);
+REGISTER_HANDLER(selu, selu_handler);
+REGISTER_HANDLER(silu, silu_handler);
+REGISTER_HANDLER(softshrink, softshrink_handler);
+REGISTER_HANDLER(square, square_handler);
+REGISTER_HANDLER(swish, swish_handler);
+REGISTER_HANDLER(tanh_shrink, tanh_shrink_handler);
+REGISTER_HANDLER(thresholded_relu, thresholded_relu_handler);