Skip to content

Commit 0b00dad

Browse files
committed
support anakin for bitmain arch
test=develop
1 parent cd5c898 commit 0b00dad

File tree

4 files changed

+56
-2
lines changed

4 files changed

+56
-2
lines changed

paddle/fluid/inference/api/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,9 +60,14 @@ cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_
6060
ARGS --dirname=${WORD2VEC_MODEL_DIR})
6161

6262
if(ANAKIN_FOUND)
63+
# Do not turn warnings into errors.
64+
set_source_files_properties(api.cc api_anakin_engine.cc PROPERTIES COMPILE_FLAGS "-Wno-error")
6365
if (ANAKIN_MLU AND NOT WITH_GPU AND NOT ANAKIN_X86)
6466
message(STATUS "Compile with anakin mlu place.")
6567
add_definitions(-DANAKIN_MLU_PLACE)
68+
elseif(ANAKIN_BM AND NOT WITH_GPU AND NOT ANAKIN_X86)
69+
message(STATUS "Compile with anakin bm place.")
70+
add_definitions(-DANAKIN_BM_PLACE)
6671
elseif(ANAKIN_X86)
6772
message(STATUS "Compile with anakin x86 place.")
6873
add_definitions(-DANAKIN_X86_PLACE)

paddle/fluid/inference/api/api_anakin_engine.cc

Lines changed: 35 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,10 @@ extern std::once_flag PaddleInferenceAnakinPredictor<T, P, R>::init_anakin_;
3434

3535
template <typename T, Precision P, OpRunType R>
3636
void PaddleInferenceAnakinPredictor<T, P, R>::InitEnv() {
37-
anakin::TargetWrapper<T>::set_device(this->config_.device_id);
3837
std::call_once(this->init_anakin_, [this]() {
3938
anakin::Env<T>::env_init(this->config_.max_stream);
4039
});
40+
anakin::TargetWrapper<T>::set_device(this->config_.device_id);
4141
}
4242
template <typename T, Precision P, OpRunType R>
4343
void PaddleInferenceAnakinPredictor<T, P, R>::InitNet() {
@@ -194,6 +194,7 @@ template <typename T, Precision P, OpRunType R>
194194
bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
195195
const std::vector<PaddleTensor> &inputs,
196196
std::vector<PaddleTensor> *output_data) {
197+
anakin::TargetWrapper<T>::set_device(this->config_.device_id);
197198
for (const auto &input : inputs) {
198199
if (input.dtype != PaddleDType::FLOAT32) {
199200
LOG(FATAL) << "Only support float type inputs. " << input.name
@@ -326,6 +327,27 @@ void PaddleInferenceAnakinMLUPredictor<P, R>::Predict() {
326327
}
327328
#endif
328329

330+
#ifdef ANAKIN_BM_PLACE
331+
template <Precision P, OpRunType R>
332+
void PaddleInferenceAnakinBMPredictor<P, R>::OptimizeGraph() {
333+
if (!this->graph_p_->fusion_optimize()) {
334+
LOG(FATAL) << "Graph optimization error.";
335+
}
336+
}
337+
template <Precision P, OpRunType R>
338+
void PaddleInferenceAnakinBMPredictor<P, R>::InitNet() {
339+
std::unique_lock<std::mutex> lock(this->mutex_);
340+
this->executor_p_ = new anakin::Net<anakin::BM, P, R>();
341+
this->executor_p_->fusion_init(*this->graph_p_, this->ctx_p_, true);
342+
}
343+
template <Precision P, OpRunType R>
344+
void PaddleInferenceAnakinBMPredictor<P, R>::Predict() {
345+
anakin::TargetWrapper<anakin::BM>::device_sync();
346+
this->executor_p_->fusion_prediction();
347+
anakin::TargetWrapper<anakin::BM>::device_sync();
348+
}
349+
#endif
350+
329351
#ifdef PADDLE_WITH_CUDA
330352
template class PaddleInferenceAnakinPredictor<
331353
anakin::NV, anakin::Precision::FP32, ::anakin::OpRunType::ASYNC>;
@@ -338,6 +360,10 @@ template class PaddleInferenceAnakinPredictor<
338360
template class PaddleInferenceAnakinMLUPredictor<anakin::Precision::FP32,
339361
::anakin::OpRunType::SYNC>;
340362
#endif
363+
#ifdef ANAKIN_BM_PLACE
364+
template class PaddleInferenceAnakinBMPredictor<anakin::Precision::FP32,
365+
::anakin::OpRunType::ASYNC>;
366+
#endif
341367

342368
// A factory to help create difference predictor.
343369
template <>
@@ -365,6 +391,14 @@ CreatePaddlePredictor<contrib::AnakinConfig, PaddleEngineKind::kAnakin>(
365391
::anakin::OpRunType::SYNC>(
366392
config));
367393
}
394+
#endif
395+
#ifdef ANAKIN_BM_PLACE
396+
if (config.target_type == contrib::AnakinConfig::BM) {
397+
return std::unique_ptr<PaddlePredictor>(
398+
new PaddleInferenceAnakinBMPredictor<anakin::Precision::FP32,
399+
::anakin::OpRunType::ASYNC>(
400+
config));
401+
}
368402
#endif
369403
LOG(FATAL) << "Anakin Predictor create on unknown platform.";
370404
return nullptr;

paddle/fluid/inference/api/api_anakin_engine.h

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,4 +92,19 @@ class PaddleInferenceAnakinMLUPredictor final
9292
void Predict() override;
9393
};
9494
#endif
95+
96+
#ifdef ANAKIN_BM_PLACE
97+
template <Precision P, OpRunType R>
98+
class PaddleInferenceAnakinBMPredictor final
99+
: public PaddleInferenceAnakinPredictor<anakin::BM, P, R> {
100+
public:
101+
explicit PaddleInferenceAnakinBMPredictor(const AnakinConfig& config) {
102+
this->ResetConfig(config);
103+
this->InitPredictor();
104+
}
105+
void OptimizeGraph() override;
106+
void InitNet() override;
107+
void Predict() override;
108+
};
109+
#endif
95110
} // namespace paddle

paddle/fluid/inference/api/paddle_anakin_config.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ namespace paddle {
2525
namespace contrib {
2626
// Configurations for Anakin engine.
2727
struct AnakinConfig : public PaddlePredictor::Config {
28-
enum TargetType { NVGPU = 0, X86, MLU };
28+
enum TargetType { NVGPU = 0, X86, MLU, BM };
2929
int device_id{0};
3030
std::string model_file;
3131
std::map<std::string, std::vector<int>> init_inputs_shape;

0 commit comments

Comments (0)