@@ -67,13 +67,8 @@ class AutoTuneBase {
67
67
const AlgorithmType& algo,
68
68
const size_t key,
69
69
Args&&... args) {
70
- PADDLE_ENFORCE_GT (
71
- kernels_.size (),
72
- 0 ,
73
- phi::errors::InvalidArgument (
74
- " kernel num must be greater than 0, now is %d" , kernels_.size ()));
75
70
is_init_ = true ;
76
-
71
+ CheckKernelSize ();
77
72
auto & cache = AutoTuneCache::Instance ().Get (algo);
78
73
if (cache.Find (key)) {
79
74
auto best_idx = cache.Get (key);
@@ -91,19 +86,22 @@ class AutoTuneBase {
91
86
}
92
87
}
93
88
94
- private :
89
+ protected :
95
90
bool is_init_{false };
96
91
std::vector<KernelType> kernels_;
97
92
mutable std::mutex mutex_;
98
93
99
- template <typename Context, typename ... Args>
100
- size_t PickBestKernel (const Context& ctx, Args&&... args) {
101
- std::lock_guard<std::mutex> lock (mutex_);
94
+ void CheckKernelSize () {
102
95
PADDLE_ENFORCE_GT (
103
96
kernels_.size (),
104
97
0 ,
105
98
phi::errors::InvalidArgument (
106
99
" kernel num must be greater than 0, now is %d" , kernels_.size ()));
100
+ }
101
+
102
+ template <typename Context, typename ... Args>
103
+ size_t PickBestKernel (const Context& ctx, Args&&... args) {
104
+ std::lock_guard<std::mutex> lock (mutex_);
107
105
size_t best_idx = 0 ;
108
106
float min_time = std::numeric_limits<float >::max ();
109
107
@@ -143,36 +141,42 @@ class AutoTuneBase {
143
141
}
144
142
};
145
143
146
- template <typename T, typename ReturnType, typename ... Args>
147
- static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> MakeAutoTuner (
148
- ReturnType (*func)(Args...)) {
149
- auto obj = MakeCallback<T>(func);
150
- return AutoTuneBase<T, decltype (obj)>(obj);
151
- }
152
-
153
- template <typename T, typename ReturnType, typename ... Args>
154
- class TransposeAutoTuner
155
- : public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
156
- public:
157
- static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>* Instance (
158
- ReturnType (*func)(Args...)) {
159
- static std::once_flag transpose_init_flag_;
160
- static std::unique_ptr<
161
- AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>>
162
- instance_;
163
- std::call_once (transpose_init_flag_, [&] {
164
- auto obj = MakeCallback<T>(func);
165
- instance_.reset (new AutoTuneBase<T, decltype (obj)>(obj));
166
- });
167
- return instance_.get ();
144
+ // To init the auto_tuner object.
145
+ #define DEFINE_AUTOTUNER_COMMON_OBJ (name ) \
146
+ template <typename T, typename ReturnType, typename ... Args> \
147
+ class name ##AutoTuner \
148
+ : public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> { \
149
+ public: \
150
+ static name##AutoTuner<T, ReturnType, Args...>* Instance ( \
151
+ ReturnType (*func)(Args...)) { \
152
+ static std::once_flag name##_init_flag; \
153
+ static std::unique_ptr<name##AutoTuner<T, ReturnType, Args...>> \
154
+ instance; \
155
+ std::call_once (name##_init_flag, [&] { \
156
+ auto obj = MakeCallback<T>(func); \
157
+ instance.reset (new name##AutoTuner<T, ReturnType, Args...>); \
158
+ instance->AddCallBack (func); \
159
+ }); \
160
+ return instance.get (); \
161
+ } \
162
+ };
163
+
164
+ // To init auto_tuner inital function.
165
+ #define DEFINE_AUTOTUNER_FN (name ) \
166
+ template <typename T, typename ReturnType, typename ... Args> \
167
+ static name##AutoTuner<T, ReturnType, Args...>* Make##name##Tuner( \
168
+ ReturnType (*func)(Args...)) { \
169
+ return name##AutoTuner<T, ReturnType, Args...>::Instance (func); \
168
170
}
169
- };
170
171
171
- template <typename T, typename ReturnType, typename ... Args>
172
- static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>*
173
- MakeTransposeTuner (ReturnType (*func)(Args...)) {
174
- return TransposeAutoTuner<T, ReturnType, Args...>::Instance (func);
175
- }
172
+ #define DEFINE_AUTOTUNER (name ) \
173
+ DEFINE_AUTOTUNER_COMMON_OBJ (name) DEFINE_AUTOTUNER_FN(name)
174
+
175
+ DEFINE_AUTOTUNER(Transpose)
176
+
177
+ #undef DEFINE_AUTOTUNER_COMMON_OBJECT
178
+ #undef DEFINE_AUTOTUNER_FN
179
+ #undef DEFINE_AUTOTUNER
176
180
177
181
} // namespace autotune
178
182
} // namespace phi
0 commit comments