@@ -34,10 +34,10 @@ extern std::once_flag PaddleInferenceAnakinPredictor<T, P, R>::init_anakin_;
34
34
35
35
// Sets up the Anakin runtime for this predictor. env_init is invoked with
// config_.max_stream inside std::call_once keyed on init_anakin_, so that
// initialization runs at most once for that flag. In this diff the
// set_device call for config_.device_id is removed from before the
// call_once (the "-" line) and re-added after it (the "+" line).
template <typename T, Precision P, OpRunType R>
36
36
void PaddleInferenceAnakinPredictor<T, P, R>::InitEnv() {
37
- anakin::TargetWrapper<T>::set_device (this ->config_ .device_id );
38
37
std::call_once (this ->init_anakin_ , [this ]() {
39
38
anakin::Env<T>::env_init (this ->config_ .max_stream );
40
39
});
40
+ anakin::TargetWrapper<T>::set_device (this ->config_ .device_id );
41
41
}
42
42
template <typename T, Precision P, OpRunType R>
43
43
void PaddleInferenceAnakinPredictor<T, P, R>::InitNet() {
@@ -194,6 +194,7 @@ template <typename T, Precision P, OpRunType R>
194
194
bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
195
195
const std::vector<PaddleTensor> &inputs,
196
196
std::vector<PaddleTensor> *output_data) {
197
+ anakin::TargetWrapper<T>::set_device (this ->config_ .device_id );
197
198
for (const auto &input : inputs) {
198
199
if (input.dtype != PaddleDType::FLOAT32) {
199
200
LOG (FATAL) << " Only support float type inputs. " << input.name
@@ -326,6 +327,27 @@ void PaddleInferenceAnakinMLUPredictor<P, R>::Predict() {
326
327
}
327
328
#endif
328
329
330
+ #ifdef ANAKIN_BM_PLACE
331
// Runs the fusion optimization pass on graph_p_ for the BM predictor.
// If fusion_optimize() reports failure (a false result), a FATAL log entry
// is emitted.
+ template <Precision P, OpRunType R>
332
+ void PaddleInferenceAnakinBMPredictor<P, R>::OptimizeGraph() {
333
+ if (!this ->graph_p_ ->fusion_optimize ()) {
334
+ LOG (FATAL) << " Graph optimization error." ;
335
+ }
336
+ }
337
// Creates the BM executor for this predictor. mutex_ is held for the whole
// call; a new anakin::Net<anakin::BM, P, R> is stored in executor_p_ and
// then initialized via fusion_init with the graph and ctx_p_.
// NOTE(review): the meaning of the trailing `true` argument is not visible
// here -- confirm against the fusion_init declaration.
// NOTE(review): executor_p_ is assigned from a raw `new`; where it is
// released (and whether a previous value can be overwritten) is not shown
// in this hunk -- verify.
+ template <Precision P, OpRunType R>
338
+ void PaddleInferenceAnakinBMPredictor<P, R>::InitNet() {
339
+ std::unique_lock<std::mutex> lock (this ->mutex_ );
340
+ this ->executor_p_ = new anakin::Net<anakin::BM, P, R>();
341
+ this ->executor_p_ ->fusion_init (*this ->graph_p_ , this ->ctx_p_ , true );
342
+ }
343
// Runs one prediction on the BM executor. device_sync() is called on the
// BM target both before and after executor_p_->fusion_prediction().
+ template <Precision P, OpRunType R>
344
+ void PaddleInferenceAnakinBMPredictor<P, R>::Predict() {
345
+ anakin::TargetWrapper<anakin::BM>::device_sync ();
346
+ this ->executor_p_ ->fusion_prediction ();
347
+ anakin::TargetWrapper<anakin::BM>::device_sync ();
348
+ }
349
+ #endif
350
+
329
351
#ifdef PADDLE_WITH_CUDA
330
352
// Explicit instantiation of the generic predictor for the anakin::NV target
// with FP32 precision and the ASYNC run type (inside the PADDLE_WITH_CUDA
// guard opened just above).
template class PaddleInferenceAnakinPredictor <
331
353
anakin::NV, anakin::Precision::FP32, ::anakin::OpRunType::ASYNC>;
@@ -338,6 +360,10 @@ template class PaddleInferenceAnakinPredictor<
338
360
// Explicit instantiation of the MLU predictor with FP32 precision and the
// SYNC run type.
template class PaddleInferenceAnakinMLUPredictor <anakin::Precision::FP32,
339
361
::anakin::OpRunType::SYNC>;
340
362
#endif
363
+ #ifdef ANAKIN_BM_PLACE
364
// Explicit instantiation of the BM predictor with FP32 precision and the
// ASYNC run type; compiled only when ANAKIN_BM_PLACE is defined (see the
// surrounding #ifdef/#endif lines).
+ template class PaddleInferenceAnakinBMPredictor <anakin::Precision::FP32,
365
+ ::anakin::OpRunType::ASYNC>;
366
+ #endif
341
367
342
368
// A factory to help create difference predictor.
343
369
template <>
@@ -365,6 +391,14 @@ CreatePaddlePredictor<contrib::AnakinConfig, PaddleEngineKind::kAnakin>(
365
391
::anakin::OpRunType::SYNC>(
366
392
config));
367
393
}
394
+ #endif
395
+ #ifdef ANAKIN_BM_PLACE
396
+ if (config.target_type == contrib::AnakinConfig::BM) {
397
+ return std::unique_ptr<PaddlePredictor>(
398
+ new PaddleInferenceAnakinBMPredictor<anakin::Precision::FP32,
399
+ ::anakin::OpRunType::ASYNC>(
400
+ config));
401
+ }
368
402
#endif
369
403
LOG (FATAL) << " Anakin Predictor create on unknown platform." ;
370
404
return nullptr ;
0 commit comments