Skip to content

Commit 27281e1

Browse files
authored
Addition of macro for auto_tune_base.h (#50516)
1 parent 7fe44fe commit 27281e1

File tree

4 files changed

+83
-41
lines changed

4 files changed

+83
-41
lines changed

paddle/phi/kernels/autotune/auto_tune_base.h

+42-38
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,8 @@ class AutoTuneBase {
6767
const AlgorithmType& algo,
6868
const size_t key,
6969
Args&&... args) {
70-
PADDLE_ENFORCE_GT(
71-
kernels_.size(),
72-
0,
73-
phi::errors::InvalidArgument(
74-
"kernel num must be greater than 0, now is %d", kernels_.size()));
7570
is_init_ = true;
76-
71+
CheckKernelSize();
7772
auto& cache = AutoTuneCache::Instance().Get(algo);
7873
if (cache.Find(key)) {
7974
auto best_idx = cache.Get(key);
@@ -91,19 +86,22 @@ class AutoTuneBase {
9186
}
9287
}
9388

94-
private:
89+
protected:
9590
bool is_init_{false};
9691
std::vector<KernelType> kernels_;
9792
mutable std::mutex mutex_;
9893

99-
template <typename Context, typename... Args>
100-
size_t PickBestKernel(const Context& ctx, Args&&... args) {
101-
std::lock_guard<std::mutex> lock(mutex_);
94+
void CheckKernelSize() {
10295
PADDLE_ENFORCE_GT(
10396
kernels_.size(),
10497
0,
10598
phi::errors::InvalidArgument(
10699
"kernel num must be greater than 0, now is %d", kernels_.size()));
100+
}
101+
102+
template <typename Context, typename... Args>
103+
size_t PickBestKernel(const Context& ctx, Args&&... args) {
104+
std::lock_guard<std::mutex> lock(mutex_);
107105
size_t best_idx = 0;
108106
float min_time = std::numeric_limits<float>::max();
109107

@@ -143,36 +141,42 @@ class AutoTuneBase {
143141
}
144142
};
145143

146-
template <typename T, typename ReturnType, typename... Args>
147-
static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> MakeAutoTuner(
148-
ReturnType (*func)(Args...)) {
149-
auto obj = MakeCallback<T>(func);
150-
return AutoTuneBase<T, decltype(obj)>(obj);
151-
}
152-
153-
template <typename T, typename ReturnType, typename... Args>
154-
class TransposeAutoTuner
155-
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
156-
public:
157-
static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>* Instance(
158-
ReturnType (*func)(Args...)) {
159-
static std::once_flag transpose_init_flag_;
160-
static std::unique_ptr<
161-
AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>>
162-
instance_;
163-
std::call_once(transpose_init_flag_, [&] {
164-
auto obj = MakeCallback<T>(func);
165-
instance_.reset(new AutoTuneBase<T, decltype(obj)>(obj));
166-
});
167-
return instance_.get();
144+
// To init the auto_tuner object.
145+
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
146+
template <typename T, typename ReturnType, typename... Args> \
147+
class name##AutoTuner \
148+
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> { \
149+
public: \
150+
static name##AutoTuner<T, ReturnType, Args...>* Instance( \
151+
ReturnType (*func)(Args...)) { \
152+
static std::once_flag name##_init_flag; \
153+
static std::unique_ptr<name##AutoTuner<T, ReturnType, Args...>> \
154+
instance; \
155+
std::call_once(name##_init_flag, [&] { \
156+
auto obj = MakeCallback<T>(func); \
157+
instance.reset(new name##AutoTuner<T, ReturnType, Args...>); \
158+
instance->AddCallBack(func); \
159+
}); \
160+
return instance.get(); \
161+
} \
162+
};
163+
164+
// To init the auto_tuner initial function.
165+
#define DEFINE_AUTOTUNER_FN(name) \
166+
template <typename T, typename ReturnType, typename... Args> \
167+
static name##AutoTuner<T, ReturnType, Args...>* Make##name##Tuner( \
168+
ReturnType (*func)(Args...)) { \
169+
return name##AutoTuner<T, ReturnType, Args...>::Instance(func); \
168170
}
169-
};
170171

171-
template <typename T, typename ReturnType, typename... Args>
172-
static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>*
173-
MakeTransposeTuner(ReturnType (*func)(Args...)) {
174-
return TransposeAutoTuner<T, ReturnType, Args...>::Instance(func);
175-
}
172+
#define DEFINE_AUTOTUNER(name) \
173+
DEFINE_AUTOTUNER_COMMON_OBJ(name) DEFINE_AUTOTUNER_FN(name)
174+
175+
DEFINE_AUTOTUNER(Transpose)
176+
177+
#undef DEFINE_AUTOTUNER_COMMON_OBJ
178+
#undef DEFINE_AUTOTUNER_FN
179+
#undef DEFINE_AUTOTUNER
176180

177181
} // namespace autotune
178182
} // namespace phi

paddle/phi/kernels/autotune/cache.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ size_t TransposeKey(const std::vector<int64_t>& x_dims,
2525
const std::vector<int32_t>& perm,
2626
phi::DataType dtype) {
2727
const auto rank = perm.size();
28-
return GetKey(x_dims, perm, rank, static_cast<int64_t>(dtype));
28+
return GenKey(x_dims, perm, rank, static_cast<int64_t>(dtype));
2929
}
3030

3131
std::string AlgorithmTypeString(int64_t algo_type) {

paddle/phi/kernels/autotune/cache_base.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ namespace phi {
5454
namespace autotune {
5555

5656
template <typename... Args>
57-
size_t GetKey(Args&&... args) {
57+
size_t GenKey(Args&&... args) {
5858
size_t seed = 0;
5959
HashCombine(&seed, std::forward<Args>(args)...);
6060
return seed;
@@ -79,7 +79,7 @@ struct ConvCacheKey {
7979
groups(arg_groups),
8080
data_layout(arg_data_layout) {}
8181
size_t hash_value() const {
82-
return GetKey(x_dims,
82+
return GenKey(x_dims,
8383
w_dims,
8484
strides,
8585
paddings,

python/paddle/fluid/tests/unittests/test_transpose_op.py

+38
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,44 @@ def test_check_grad(self):
157157
self.check_grad(['X'], 'Out')
158158

159159

160+
class TestAutoTuneTransposeBF16Op(OpTest):
161+
def setUp(self):
162+
self.init_op_type()
163+
self.initTestCase()
164+
self.dtype = np.uint16
165+
self.python_api = paddle.transpose
166+
x = np.random.random(self.shape).astype("float32")
167+
self.inputs = {'X': convert_float_to_uint16(x)}
168+
self.attrs = {
169+
'axis': list(self.axis),
170+
'use_mkldnn': self.use_mkldnn,
171+
}
172+
self.outputs = {
173+
'XShape': convert_float_to_uint16(
174+
np.random.random(self.shape).astype("float32")
175+
),
176+
'Out': self.inputs['X'].transpose(self.axis),
177+
}
178+
179+
def initTestCase(self):
180+
fluid.core.set_autotune_range(0, 3)
181+
fluid.core.update_autotune_status()
182+
fluid.core.enable_autotune()
183+
self.shape = (2, 8, 10)
184+
self.axis = (0, 2, 1)
185+
186+
def init_op_type(self):
187+
self.op_type = "transpose2"
188+
self.use_mkldnn = False
189+
190+
def test_check_output(self):
191+
self.check_output(no_check_set=['XShape'])
192+
fluid.core.disable_autotune()
193+
194+
def test_check_grad(self):
195+
self.check_grad(['X'], 'Out')
196+
197+
160198
class TestTransposeBF16Op(OpTest):
161199
def setUp(self):
162200
self.init_op_type()

0 commit comments

Comments
 (0)