Skip to content

Commit d4217fc

Browse files
JamesLim-sy, zhangbopd, Xreki
authored
Matmul performance optimization with cuBlasLt (#46431)
* implement of matmul using cublasLt instead of cublas * Update matmul_kernel_impl_via_blasLt.h --------- Co-authored-by: zhangbopd <1299246947@qq.com> Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
1 parent 57f6a46 commit d4217fc

File tree

5 files changed

+992
-33
lines changed

5 files changed

+992
-33
lines changed

paddle/phi/kernels/autotune/auto_tune_base.h

+42-4
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,43 @@ class AutoTuneBase {
141141
}
142142
};
143143

144-
// To init the auto_tuner object.
144+
template <typename T, typename ReturnType, typename... Args>
145+
class MatmulAutoTuner
146+
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
147+
public:
148+
static MatmulAutoTuner<T, ReturnType, Args...>* Instance(
149+
ReturnType (*func)(Args...)) {
150+
static std::once_flag matmul_init_flag;
151+
static std::unique_ptr<MatmulAutoTuner<T, ReturnType, Args...>> instance;
152+
std::call_once(matmul_init_flag, [&] {
153+
auto obj = MakeCallback<T>(func);
154+
instance.reset(new MatmulAutoTuner<T, ReturnType, Args...>);
155+
instance->AddCallBack(func);
156+
});
157+
return instance.get();
158+
}
159+
160+
template <typename Context>
161+
void Run(const Context& ctx, const size_t key, Args... args) {
162+
this->is_init_ = true;
163+
this->CheckKernelSize();
164+
auto& cache = AutoTuneCache::Instance().GetMatmul();
165+
if (cache.Find(key)) {
166+
auto best_idx = cache.Get(key);
167+
this->kernels_[best_idx].Run(args...);
168+
} else {
169+
bool use_autotune = AutoTuneStatus::Instance().UseAutoTune();
170+
if (use_autotune) {
171+
auto best_idx = this->PickBestKernel(ctx, args...);
172+
cache.Set(key, best_idx);
173+
} else {
174+
this->kernels_[0].Run(args...);
175+
}
176+
}
177+
}
178+
};
179+
180+
// Define the auto_tuner initial object.
145181
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
146182
template <typename T, typename ReturnType, typename... Args> \
147183
class name##AutoTuner \
@@ -161,18 +197,20 @@ class AutoTuneBase {
161197
} \
162198
};
163199

164-
// To init auto_tuner inital function.
200+
// Define the auto_tuner initial function.
165201
#define DEFINE_AUTOTUNER_FN(name) \
166202
template <typename T, typename ReturnType, typename... Args> \
167203
static name##AutoTuner<T, ReturnType, Args...>* Make##name##Tuner( \
168204
ReturnType (*func)(Args...)) { \
169205
return name##AutoTuner<T, ReturnType, Args...>::Instance(func); \
170206
}
171207

172-
#define DEFINE_AUTOTUNER(name) \
173-
DEFINE_AUTOTUNER_COMMON_OBJ(name) DEFINE_AUTOTUNER_FN(name)
208+
#define DEFINE_AUTOTUNER(name) \
209+
DEFINE_AUTOTUNER_COMMON_OBJ(name) \
210+
DEFINE_AUTOTUNER_FN(name)
174211

175212
DEFINE_AUTOTUNER(Transpose)
213+
DEFINE_AUTOTUNER_FN(Matmul)
176214

177215
#undef DEFINE_AUTOTUNER_COMMON_OBJECT
178216
#undef DEFINE_AUTOTUNER_FN

paddle/phi/kernels/autotune/cache.h

+16-10
Original file line numberDiff line numberDiff line change
@@ -44,28 +44,33 @@ enum class AlgorithmType {
4444
kConvBackwardData = 2,
4545
kConvBackwardFilter = 3,
4646
kTranspose = 4,
47-
#ifdef PADDLE_WITH_CUDNN_FRONTEND
48-
kConvForwardV8 = 5,
49-
kConvBackwardDataV8 = 6,
50-
kConvBackwardFilterV8 = 7,
51-
kAlgorithmCount = 8
47+
kMatmul = 5,
48+
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
49+
kAlgorithmCount = 6
5250
#else
53-
kAlgorithmCount = 5
51+
kConvForwardV8 = 6,
52+
kConvBackwardDataV8 = 7,
53+
kConvBackwardFilterV8 = 8,
54+
kAlgorithmCount = 9
5455
#endif
5556
};
5657

5758
// AlgorithmsConfigKey -> AlgorithmsID
58-
// (todo. hong) use cudnnConvolutionFwdAlgo_t
59-
using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
6059
// AlgorithmType -> AlgorithmsCache
60+
using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
6161
using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
62+
63+
// (todo. hong) use cudnnConvolutionFwdAlgo_t
6264
using ConvAlgorithmsCacheMap = ConvAlgorithmsCache<ConvAutoTuneResult>;
6365
using ConvAlgorithmsTypeMap =
6466
std::unordered_map<int64_t, ConvAlgorithmsCacheMap>;
67+
68+
using MatmulAlgorithmsCacheMap = MatmulAlgorithmsCache<size_t, int64_t>;
6569
#ifdef PADDLE_WITH_CUDNN_FRONTEND
6670
using CudnnV8AlgorithmsTypeMap =
6771
std::unordered_map<int64_t, CudnnFrontendPlanCache>;
6872
#endif
73+
6974
class AutoTuneCache {
7075
public:
7176
static AutoTuneCache& Instance() {
@@ -77,6 +82,8 @@ class AutoTuneCache {
7782
return auto_tune_map_[static_cast<int64_t>(algo_type)];
7883
}
7984

85+
MatmulAlgorithmsCacheMap& GetMatmul() { return matmul_auto_tune_map_; }
86+
8087
ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
8188
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
8289
}
@@ -87,8 +94,6 @@ class AutoTuneCache {
8794
}
8895
#endif
8996

90-
AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
91-
9297
void Clean() {
9398
for (auto& v : auto_tune_map_) {
9499
v.second.Clean();
@@ -162,6 +167,7 @@ class AutoTuneCache {
162167

163168
AlgorithmsTypeMap auto_tune_map_;
164169
ConvAlgorithmsTypeMap conv_auto_tune_map_;
170+
MatmulAlgorithmsCacheMap matmul_auto_tune_map_;
165171
#ifdef PADDLE_WITH_CUDNN_FRONTEND
166172
CudnnV8AlgorithmsTypeMap cudnn_v8_auto_tune_map_;
167173
#endif

paddle/phi/kernels/autotune/cache_base.h

+54
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,31 @@ size_t GenKey(Args&&... args) {
6060
return seed;
6161
}
6262

63+
// Fixed-size POD payload (8 x uint64_t = 64 bytes) stored per matmul
// sub-key in the autotune cache.
// NOTE(review): presumably holds a serialized algorithm descriptor for the
// cuBLASLt matmul path added by this commit — confirm against the kernel.
struct MatmulHashValueType {
  uint64_t data[8];
};
66+
67+
struct MatmulCacheKey {
68+
public:
69+
MatmulCacheKey() {}
70+
MatmulCacheKey(const std::vector<int64_t>& x_dims,
71+
const std::vector<int64_t>& y_dims,
72+
const bool trans_x,
73+
const bool trans_y,
74+
phi::DataType dtype) {
75+
key = GenKey(x_dims,
76+
y_dims,
77+
static_cast<int64_t>(trans_x),
78+
static_cast<int64_t>(trans_y),
79+
static_cast<int64_t>(dtype));
80+
}
81+
size_t GetKey() const { return key; }
82+
size_t GetSubKey(int64_t idx) const { return GenKey(key, idx); }
83+
84+
private:
85+
size_t key;
86+
};
87+
6388
struct ConvCacheKey {
6489
ConvCacheKey() {}
6590
ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
@@ -213,5 +238,34 @@ class ConvAlgorithmsCache : public AlgorithmsCache<ConvCacheKey,
213238
}
214239
};
215240

241+
template <typename KeyT, typename AlgorithmT>
242+
class MatmulAlgorithmsCache : public AlgorithmsCache<KeyT, AlgorithmT> {
243+
public:
244+
MatmulAlgorithmsCache() : AlgorithmsCache<KeyT, AlgorithmT>() {}
245+
246+
bool FindSubKey(const KeyT& sub_key) {
247+
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
248+
bool ret = (sub_hash_.find(sub_key) != sub_hash_.end()) ? true : false;
249+
return ret;
250+
}
251+
252+
void SetSubKey(const KeyT& sub_key, const MatmulHashValueType* algo) {
253+
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
254+
sub_hash_[sub_key] = *algo;
255+
}
256+
257+
MatmulHashValueType* GetSubKey(const KeyT& sub_key) {
258+
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
259+
PADDLE_ENFORCE_NE(
260+
sub_hash_.find(sub_key),
261+
sub_hash_.end(),
262+
phi::errors::PreconditionNotMet("The key does not exist."));
263+
return &(sub_hash_[sub_key]);
264+
}
265+
266+
private:
267+
std::unordered_map<KeyT, MatmulHashValueType> sub_hash_;
268+
};
269+
216270
} // namespace autotune
217271
} // namespace phi

0 commit comments

Comments
 (0)