
Matmul performance optimization with cuBlasLt #46431


Merged: 76 commits (Feb 26, 2023)

Changes from all commits

Commits (76)
7f42952
for 1st time interface combine.
JamesLim-sy Mar 17, 2022
1dba1a6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JamesLim-sy Sep 23, 2022
96cf58a
another first commit
JamesLim-sy Sep 23, 2022
07677e3
first commit
JamesLim-sy Sep 26, 2022
ee801c3
first commit
JamesLim-sy Sep 26, 2022
67bf57c
merge alloc together
JamesLim-sy Sep 26, 2022
64ee6d7
remove the autotune.h file
JamesLim-sy Sep 26, 2022
de873b4
add CheckEighResult for both sysej and evd kernel
JamesLim-sy Sep 26, 2022
3aa505a
profile reduce kernel for fp16 and reduceHigherdim
zhangbopd Oct 13, 2022
c6c5ca2
Merge branch 'PaddlePaddle:develop' into develop
zhangbopd Oct 21, 2022
4187e28
Merge branch 'PaddlePaddle:develop' into develop
zhangbopd Oct 27, 2022
14180cc
use reinterpret_cast
zhangbopd Oct 27, 2022
2c6eaa5
fix for CI on ROCm
zhangbopd Oct 28, 2022
1eaf75f
add Macro for ROCm
zhangbopd Oct 28, 2022
5f8c72b
ROCm CI config
zhangbopd Nov 3, 2022
444b1c4
ROCm CI config
zhangbopd Nov 3, 2022
fbb8361
unit test repair
zhangbopd Nov 7, 2022
19de67a
Merge branch 'PaddlePaddle:develop' into develop
zhangbopd Nov 8, 2022
427e98c
Merge branch 'PaddlePaddle:develop' into develop
zhangbopd Nov 16, 2022
cbf1f3d
Merge branch 'PaddlePaddle:develop' into develop
zhangbopd Nov 21, 2022
c6dbe30
pull
zhangbopd Nov 21, 2022
2a9ef0a
add common_funcs.h
zhangbopd Nov 22, 2022
ba99367
reduceType
zhangbopd Nov 22, 2022
d326b58
Update reduce_function.h
zhangbopd Nov 22, 2022
2ccb0ea
not higher
zhangbopd Nov 22, 2022
2a14bdb
conflict fix
zhangbopd Nov 22, 2022
ff38003
rename
zhangbopd Nov 23, 2022
3c7e544
Merge branch 'PaddlePaddle:develop' into develop
zhangbopd Nov 30, 2022
66475ea
Merge branch 'PaddlePaddle:develop' into develop
zhangbopd Dec 5, 2022
e3fd59b
implement of matmul using cublasLt instead of cublas
zhangbopd Dec 7, 2022
9c2b658
Merge branch 'PaddlePaddle:develop' into matmul-autotune
zhangbopd Dec 7, 2022
218990e
cublasLt bugfix
zhangbopd Dec 7, 2022
40d66f9
Merge branch 'matmul-autotune' of https://github.com/zhangbopd/Paddle…
zhangbopd Dec 7, 2022
f49b23d
Update matmul_kernel_impl.h
zhangbopd Dec 13, 2022
8e8dda6
Update matmul_kernel_impl_via_blasLt.h
zhangbopd Dec 13, 2022
75e83bb
for-loop-algo
zhangbopd Dec 27, 2022
192a1a8
PR comments changes
zhangbopd Jan 4, 2023
e636886
add macro
zhangbopd Jan 4, 2023
5783696
ci unused variable isCublasLt
zhangbopd Jan 4, 2023
11bf150
ci unused variable isCublasLt macro
zhangbopd Jan 4, 2023
8eb3aa8
split matmul to autotune
zhangbopd Jan 10, 2023
9405067
Merge branch 'matmul-autotune' of https://github.com/zhangbopd//Paddl…
JamesLim-sy Jan 17, 2023
c780a94
[WIP]: temporary storage of codes
zhangbopd Jan 18, 2023
307c89e
[WIP] temp storage
zhangbopd Jan 18, 2023
3dae0f9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbopd Jan 18, 2023
e0c40bc
temp storage
zhangbopd Jan 30, 2023
2e8a684
temp storage
zhangbopd Jan 30, 2023
65a77d9
add some changes
zhangbopd Jan 30, 2023
d876e7e
add some changes
zhangbopd Jan 30, 2023
faaa937
temp storage for changing cublasLtWithBatch computation
zhangbopd Jan 30, 2023
5285192
temp storage of compile-time debug
zhangbopd Jan 31, 2023
f02d2e9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbopd Jan 31, 2023
adec3bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbopd Jan 31, 2023
4372c0d
add some changes
zhangbopd Feb 1, 2023
6791ff5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbopd Feb 1, 2023
7f5b526
fix bugs for ci
zhangbopd Feb 1, 2023
80fafc5
revert the case number written style
zhangbopd Feb 1, 2023
8fe0afe
revert the case number written style
zhangbopd Feb 1, 2023
fbda72c
add some changes
JamesLim-sy Feb 3, 2023
6ec9106
add some changes for matmul_auto_tune
JamesLim-sy Feb 6, 2023
ad58d06
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbopd Feb 8, 2023
c4a540d
revise the data format
JamesLim-sy Feb 13, 2023
fea6614
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JamesLim-sy Feb 14, 2023
335c134
fix according to CI and review advices
JamesLim-sy Feb 16, 2023
cb7c608
fix bugs according to CI
JamesLim-sy Feb 17, 2023
9cdfa2c
change according to ci
JamesLim-sy Feb 17, 2023
e2ed925
Merge branch 'develop' into add_autotune_kernel_tool
Xreki Feb 21, 2023
c1a7448
Polish codes.
Xreki Feb 21, 2023
eef4555
Warp the matmul function and revert the change of matmul_grad_kernel.
Xreki Feb 21, 2023
9044737
Simplify the codes.
Xreki Feb 21, 2023
16864be
Fix typo.
Xreki Feb 21, 2023
cc539d7
Fix compiling error.
Xreki Feb 21, 2023
cf85133
Merge branch 'develop' into add_autotune_kernel_tool
Xreki Feb 21, 2023
c35bdea
Fix compiling error when no gpu.
Xreki Feb 21, 2023
e863cbe
Add the missing argument.
Xreki Feb 22, 2023
febeb01
Merge branch 'develop' into add_autotune_kernel_tool
Xreki Feb 22, 2023
46 changes: 42 additions & 4 deletions paddle/phi/kernels/autotune/auto_tune_base.h
@@ -141,7 +141,43 @@ class AutoTuneBase {
}
};

// To init the auto_tuner object.
template <typename T, typename ReturnType, typename... Args>
class MatmulAutoTuner
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
public:
static MatmulAutoTuner<T, ReturnType, Args...>* Instance(
ReturnType (*func)(Args...)) {
static std::once_flag matmul_init_flag;
static std::unique_ptr<MatmulAutoTuner<T, ReturnType, Args...>> instance;
std::call_once(matmul_init_flag, [&] {
auto obj = MakeCallback<T>(func);
instance.reset(new MatmulAutoTuner<T, ReturnType, Args...>);
instance->AddCallBack(func);
});
return instance.get();
}

template <typename Context>
void Run(const Context& ctx, const size_t key, Args... args) {
this->is_init_ = true;
this->CheckKernelSize();
auto& cache = AutoTuneCache::Instance().GetMatmul();
if (cache.Find(key)) {
[Inline review comment on the cache lookup above]

Contributor: It seems that when AutoTune is not enabled, this adds one extra cache lookup.

Contributor (Author): This is hard to avoid. The AutoTune-off state exists both before and after the tuning phase is enabled; the logic here is consistent with conv_cudnn_v7.h.
auto best_idx = cache.Get(key);
this->kernels_[best_idx].Run(args...);
} else {
bool use_autotune = AutoTuneStatus::Instance().UseAutoTune();
if (use_autotune) {
auto best_idx = this->PickBestKernel(ctx, args...);
cache.Set(key, best_idx);
} else {
this->kernels_[0].Run(args...);
}
}
}
};

// Define the auto_tuner initial object.
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
template <typename T, typename ReturnType, typename... Args> \
class name##AutoTuner \
@@ -161,18 +197,20 @@ class AutoTuneBase {
} \
};

// Define the auto_tuner initial function.
#define DEFINE_AUTOTUNER_FN(name) \
template <typename T, typename ReturnType, typename... Args> \
static name##AutoTuner<T, ReturnType, Args...>* Make##name##Tuner( \
ReturnType (*func)(Args...)) { \
return name##AutoTuner<T, ReturnType, Args...>::Instance(func); \
}

#define DEFINE_AUTOTUNER(name)      \
  DEFINE_AUTOTUNER_COMMON_OBJ(name) \
  DEFINE_AUTOTUNER_FN(name)

DEFINE_AUTOTUNER(Transpose)
DEFINE_AUTOTUNER_FN(Matmul)

#undef DEFINE_AUTOTUNER_COMMON_OBJ
#undef DEFINE_AUTOTUNER_FN
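The dispatch pattern implemented by MatmulAutoTuner::Run above, as a self-contained sketch (toy types only, none of this is Paddle code): look up the problem key in a key-to-best-kernel-index cache; on a hit run the cached winner; on a miss either benchmark all registered kernels (when autotuning is enabled) or fall back to kernels_[0].

// Toy sketch of MatmulAutoTuner::Run's control flow; all names here are
// illustrative, not Paddle APIs.
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

struct TunerSketch {
  std::vector<std::function<void()>> kernels;  // kernels[0] is the default
  std::unordered_map<size_t, size_t> cache;    // problem key -> best index
  bool use_autotune = false;                   // mirrors AutoTuneStatus

  // Placeholder: the real tuner times every candidate and returns the
  // index of the fastest one.
  size_t PickBestKernel() {
    size_t best = 0;
    for (size_t i = 0; i < kernels.size(); ++i) {
      kernels[i]();  // run (and, in the real code, time) each kernel
      // best = index with the smallest measured time
    }
    return best;
  }

  void Run(size_t key) {
    auto it = cache.find(key);
    if (it != cache.end()) {
      kernels[it->second]();          // cache hit: run the known winner
    } else if (use_autotune) {
      cache[key] = PickBestKernel();  // tune once, remember the winner
    } else {
      kernels[0]();                   // tuning off: default implementation
    }
  }
};

This also makes the review comment above concrete: when use_autotune is false, the cache lookup still runs on every invocation before falling back to the default kernel.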
26 changes: 16 additions & 10 deletions paddle/phi/kernels/autotune/cache.h
@@ -44,28 +44,33 @@ enum class AlgorithmType {
kConvBackwardData = 2,
kConvBackwardFilter = 3,
kTranspose = 4,
kMatmul = 5,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount = 6
#else
kConvForwardV8 = 6,
kConvBackwardDataV8 = 7,
kConvBackwardFilterV8 = 8,
kAlgorithmCount = 9
#endif
};

// AlgorithmsConfigKey -> AlgorithmsID
using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
// AlgorithmType -> AlgorithmsCache
using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;

// (todo. hong) use cudnnConvolutionFwdAlgo_t
using ConvAlgorithmsCacheMap = ConvAlgorithmsCache<ConvAutoTuneResult>;
using ConvAlgorithmsTypeMap =
std::unordered_map<int64_t, ConvAlgorithmsCacheMap>;

using MatmulAlgorithmsCacheMap = MatmulAlgorithmsCache<size_t, int64_t>;
#ifdef PADDLE_WITH_CUDNN_FRONTEND
using CudnnV8AlgorithmsTypeMap =
std::unordered_map<int64_t, CudnnFrontendPlanCache>;
#endif

class AutoTuneCache {
public:
static AutoTuneCache& Instance() {
@@ -77,6 +82,8 @@ class AutoTuneCache {
return auto_tune_map_[static_cast<int64_t>(algo_type)];
}

MatmulAlgorithmsCacheMap& GetMatmul() { return matmul_auto_tune_map_; }

ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
}
@@ -87,8 +94,6 @@ class AutoTuneCache {
}
#endif

void Clean() {
for (auto& v : auto_tune_map_) {
v.second.Clean();
@@ -162,6 +167,7 @@ class AutoTuneCache {

AlgorithmsTypeMap auto_tune_map_;
ConvAlgorithmsTypeMap conv_auto_tune_map_;
MatmulAlgorithmsCacheMap matmul_auto_tune_map_;
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnV8AlgorithmsTypeMap cudnn_v8_auto_tune_map_;
#endif
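For callers, the new matmul cache is reached through the AutoTuneCache singleton, exactly as the Run() method in auto_tune_base.h does. A minimal sketch under that assumption (the helper function is hypothetical and only compiles inside Paddle, where this header exists):

#include <cstddef>
#include <cstdint>

#include "paddle/phi/kernels/autotune/cache.h"

// Hypothetical helper: return the cached best-kernel index for a matmul
// problem key, or 0 (the default implementation) on a cache miss.
int64_t BestMatmulIndex(size_t key) {
  auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
  return cache.Find(key) ? cache.Get(key) : 0;
}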
54 changes: 54 additions & 0 deletions paddle/phi/kernels/autotune/cache_base.h
@@ -60,6 +60,31 @@ size_t GenKey(Args&&... args) {
return seed;
}

struct MatmulHashValueType {
uint64_t data[8];
};

struct MatmulCacheKey {
public:
MatmulCacheKey() {}
MatmulCacheKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
const bool trans_x,
const bool trans_y,
phi::DataType dtype) {
key = GenKey(x_dims,
y_dims,
static_cast<int64_t>(trans_x),
static_cast<int64_t>(trans_y),
static_cast<int64_t>(dtype));
}
size_t GetKey() const { return key; }
size_t GetSubKey(int64_t idx) const { return GenKey(key, idx); }

private:
size_t key;
};

struct ConvCacheKey {
ConvCacheKey() {}
ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
@@ -213,5 +238,34 @@ class ConvAlgorithmsCache : public AlgorithmsCache<ConvCacheKey,
}
};

template <typename KeyT, typename AlgorithmT>
class MatmulAlgorithmsCache : public AlgorithmsCache<KeyT, AlgorithmT> {
public:
MatmulAlgorithmsCache() : AlgorithmsCache<KeyT, AlgorithmT>() {}

bool FindSubKey(const KeyT& sub_key) {
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
bool ret = (sub_hash_.find(sub_key) != sub_hash_.end()) ? true : false;
return ret;
}

void SetSubKey(const KeyT& sub_key, const MatmulHashValueType* algo) {
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
sub_hash_[sub_key] = *algo;
}

MatmulHashValueType* GetSubKey(const KeyT& sub_key) {
std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
PADDLE_ENFORCE_NE(
sub_hash_.find(sub_key),
sub_hash_.end(),
phi::errors::PreconditionNotMet("The key does not exist."));
return &(sub_hash_[sub_key]);
}

private:
std::unordered_map<KeyT, MatmulHashValueType> sub_hash_;
};

} // namespace autotune
} // namespace phi
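The pieces added to this file compose as follows: MatmulCacheKey hashes the problem description (shapes, transpose flags, dtype) into a primary key, GetSubKey(idx) derives per-variant keys from it, and MatmulAlgorithmsCache stores MatmulHashValueType blobs under those sub-keys; uint64_t data[8] is 64 bytes, the same size as a cublasLtMatmulAlgo_t, so a serialized cuBLASLt algo descriptor fits in one entry. A usage sketch (the shapes, dtype, and helper function are illustrative only, not part of the PR):

#include <cstddef>
#include <vector>

#include "paddle/phi/kernels/autotune/cache_base.h"

using phi::autotune::MatmulAlgorithmsCache;
using phi::autotune::MatmulCacheKey;
using phi::autotune::MatmulHashValueType;

// Illustrative round trip: derive a sub-key for algo slot 0, store a blob
// under it once, and read it back later.
void SubKeyRoundTrip(MatmulAlgorithmsCache<size_t, int64_t>* cache,
                     const MatmulHashValueType& algo_blob) {
  MatmulCacheKey key({64, 128}, {128, 256},
                     /*trans_x=*/false, /*trans_y=*/false,
                     phi::DataType::FLOAT16);
  size_t sub_key = key.GetSubKey(/*idx=*/0);
  if (!cache->FindSubKey(sub_key)) {
    cache->SetSubKey(sub_key, &algo_blob);  // cache the winning algo blob
  }
  MatmulHashValueType* algo = cache->GetSubKey(sub_key);  // later retrieval
  (void)algo;
}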