Matmul performance optimization with cuBlasLt #46431

Merged: 76 commits into PaddlePaddle:develop on Feb 26, 2023

Conversation

@JamesLim-sy (Contributor) commented Sep 23, 2022

PR types

Performance optimization

PR changes

OPs

Describe

  • Feature: Matmul with cuBlasLt and autotune.

  • P.S.: After Jan 13, all work was committed by @JamesLim-sy and @Xreki. However, @JamesLim-sy's Linux environment was modified by someone else, and the GitHub username was changed along with it. What a terrible mistake.

paddle-bot (bot) commented Sep 23, 2022

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@JamesLim-sy force-pushed the add_autotune_kernel_tool branch from de26777 to fbda72c on February 3, 2023 03:25
ReturnType (*func)(Args...)) {
static std::once_flag transpose_init_flag_;
static std::unique_ptr<
AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>>
static std::unique_ptr<TransposeAutoTuner<T, ReturnType, Args...>>
Contributor: As a local variable inside a function, the name should not carry a trailing _ suffix.

Contributor Author: Fixed.
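
For reference, a minimal sketch of the call_once-guarded lazy singleton this snippet implements, with the trailing-underscore suffix dropped from the locals as requested (TunerT stands in for TransposeAutoTuner<T, ReturnType, Args...>; names are illustrative, not the exact PR code):

    #include <memory>
    #include <mutex>

    template <typename TunerT, typename ReturnType, typename... Args>
    TunerT* Instance(ReturnType (*func)(Args...)) {
      static std::once_flag init_flag;          // local variable: no '_' suffix
      static std::unique_ptr<TunerT> instance;  // local variable: no '_' suffix
      // Construct the tuner exactly once, even under concurrent first calls.
      std::call_once(init_flag,
                     [&] { instance = std::make_unique<TunerT>(func); });
      return instance.get();
    }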

});
return instance_.get();
}

template <typename Context>
void RunMatmul(const Context& ctx, const size_t key, Args... args) {
Contributor: Better to wrap a RunImpl function in the base class and have the base class's Run call RunImpl directly; then override Run here. Don't change the external interface.

Contributor Author: Changed as suggested; it is enough to override Run only in the outer MatmulAutoTuner class.
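
A hedged sketch of the shape being agreed on here (illustrative names, not the PR code): the base class keeps the public Run and delegates to a RunImpl, and the matmul tuner shadows Run to add its own key handling, so external call sites are untouched.

    #include <cstddef>
    #include <iostream>
    #include <utility>

    struct AutoTuneBaseSketch {
      // Public interface stays stable; shared logic lives in RunImpl.
      template <typename... Args>
      void Run(std::size_t key, Args&&... args) {
        RunImpl(key, std::forward<Args>(args)...);
      }
      template <typename... Args>
      void RunImpl(std::size_t key, Args&&...) {
        std::cout << "shared tuning path, key=" << key << "\n";
      }
    };

    struct MatmulTunerSketch : AutoTuneBaseSketch {
      // Shadowing Run adds matmul-specific handling without changing callers.
      template <typename... Args>
      void Run(std::size_t key, Args&&... args) {
        // matmul-specific cache preparation would go here, then:
        RunImpl(key, std::forward<Args>(args)...);
      }
    };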

this->is_init_ = true;
this->CheckKernelSize();
auto& cache = AutoTuneCache::Instance().GetMatmul();
if (cache.Find(key)) {
Contributor: It seems that when AutoTune is not enabled, this adds one extra cache-lookup overhead.

Contributor Author: This is hard to avoid: the AutoTune-disabled state exists both before tuning is enabled and after it finishes, and the logic here is consistent with conv_cudnn_v7.h.
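
To make the trade-off concrete, the control flow under discussion looks roughly like this (a hedged sketch; UseAutoTune, PickBestKernel, RunBestKernel, and RunDefaultKernel are hypothetical names, not Paddle's exact API):

    if (cache.Find(key)) {
      // Tuning already ran for this shape: reuse the recorded winner.
      RunBestKernel(cache.Get(key), args...);
    } else if (UseAutoTune()) {
      // Tuning window: measure the candidates once and cache the winner.
      auto best = PickBestKernel(ctx, args...);
      cache.Set(key, best);
      RunBestKernel(best, args...);
    } else {
      // Autotune disabled (before or after the tuning window): the lookup
      // above is the one extra cost being discussed.
      RunDefaultKernel(args...);
    }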

static MatmulAutoTuner<T, ReturnType, Args...>* MakeMatmulTuner(
ReturnType (*func)(Args...)) {
return MatmulAutoTuner<T, ReturnType, Args...>::Instance(func);
}
Contributor: Define a macro for this, say DEFINE_AUTOTUNER_FN.

Contributor Author: Changed as suggested.
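
A minimal sketch of the kind of macro being suggested (the macro body below is an assumption based on the surrounding snippet, not necessarily the merged code):

    // Stamps out one Make<Name>Tuner factory per tuner kind instead of
    // repeating the boilerplate by hand.
    #define DEFINE_AUTOTUNER_FN(name)                                       \
      template <typename T, typename ReturnType, typename... Args>          \
      static name##AutoTuner<T, ReturnType, Args...>* Make##name##Tuner(    \
          ReturnType (*func)(Args...)) {                                    \
        return name##AutoTuner<T, ReturnType, Args...>::Instance(func);     \
      }

    DEFINE_AUTOTUNER_FN(Matmul)
    DEFINE_AUTOTUNER_FN(Transpose)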

const size_t GetSubKey(int64_t idx) { return GetKey(key_, idx); }

private:
int size_;
Contributor: struct members don't need the trailing _. Also, the size_ member is never used.

Contributor Author: Changed as suggested.


struct MatmulDescCreator {
public:
static void Create(cublasLtMatmulDesc_t* op_desc,
Contributor: Why define these as all-static functions? Why not make op_desc, x_desc, etc. direct class members?

Contributor Author: This follows the Conv implementation, to reduce host-side object construction and destruction.

Contributor: That's actually unnecessary: types like cublasLtMatmulDesc_t are really just opaque handles (pointers); the real cost is in the Create and Destroy calls.
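
The point about the handles being cheap to hold but expensive to create can be shown with a small RAII wrapper over the public cuBLASLt API (a sketch, not the PR's MatmulDescCreator):

    #include <cublasLt.h>

    // cublasLtMatmulDesc_t is an opaque handle (a pointer), so storing or
    // copying it is cheap; the real cost sits in the Create/Destroy calls,
    // which is why descriptors are worth caching rather than rebuilt per call.
    class MatmulDescGuard {
     public:
      MatmulDescGuard(cublasComputeType_t compute, cudaDataType_t scale) {
        cublasLtMatmulDescCreate(&desc_, compute, scale);
      }
      ~MatmulDescGuard() { cublasLtMatmulDescDestroy(desc_); }
      cublasLtMatmulDesc_t get() const { return desc_; }

     private:
      cublasLtMatmulDesc_t desc_{nullptr};
    };

    // Usage: MatmulDescGuard desc(CUBLAS_COMPUTE_32F, CUDA_R_32F);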

};

template <typename T>
struct MatmulWithCublasLt<phi::GPUContext, T> {
Contributor: cublasLt only supports GPU, so Context does not need to be a template parameter; this layer of specialization looks avoidable.

Contributor Author: This part was indeed clumsily written, pretty poor code. Fixed.


double alpha64 = 1.0, beta64 = 0.0;
float alpha32 = 1.0f, beta32 = 0.0f;
void *alpha = nullptr, *beta = nullptr;
Contributor: MPType could be used here.

Contributor Author: Fixed.
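
The applied fix presumably looks something like the following (a sketch assuming Paddle's phi::dtype::MPTypeTrait; the exact merged code may differ):

    // MPType resolves to float for float16/bfloat16 and to T itself for
    // float/double, so one pair of scalars replaces the separate
    // alpha64/alpha32 variables and the void* indirection.
    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
    MPType alpha = static_cast<MPType>(1.0);
    MPType beta = static_cast<MPType>(0.0);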

}

template <typename Context, typename T>
Contributor: What is this layer of wrapping for?

Contributor Author: Because blas supports both CPU and GPU, while blaslt currently only supports GPU; once this is made compatible with blas, the Context passed in may be either CPUContext or GPUContext.
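
A hedged sketch of the dispatch this wrapper enables (the function name is illustrative): a compile-time branch keeps the cuBLASLt path GPU-only while the generic blas path serves both backends.

    #include <type_traits>

    template <typename Context, typename T>
    void MatmulDispatch(const Context& ctx, const T* x, const T* y, T* out) {
      if constexpr (std::is_same_v<Context, phi::GPUContext>) {
        // GPU context: eligible for the cuBLASLt / autotuned path.
      } else {
        // CPU context: only the generic blas implementation applies.
      }
    }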

@@ -64,10 +64,10 @@ void SliceCompute(const Context& ctx,
}
}

funcs::CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, &starts, &ends);
// funcs::CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, &starts, &ends);
Contributor: Is this your change, or Zhang Bo's? Please use a clean branch.

Contributor Author: This change was meant for Zhang Bo to make, but I wanted to try its effect first, so I also applied it locally for testing.

@JamesLim-sy (Contributor Author) left a comment:

For now, I have only replied to the review comments on the auto_tune part.

static_cast<int64_t>(dtype_));
}

const size_t QueryKey() const { return key_; }
Contributor Author: GenKey would actually be a better name at line 57.
I'll change that in a separate PR (slipping it into PR 50516).


const size_t QueryKey() const { return key_; }
const size_t GetSize() { return x_dims_.size(); }
const size_t GetSubKey(int64_t idx) { return GetKey(key_, idx); }
Contributor Author: Yes, it has exactly the meaning of GenSubKey.
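
For illustration, a sub-key of this kind is typically a hash-combine of the cached base key with an index; a hypothetical sketch (Paddle's actual GetKey may differ):

    #include <cstddef>
    #include <cstdint>
    #include <functional>

    // Boost-style hash combine: mixes idx into the precomputed key so each
    // kernel variant gets a distinct cache slot derived from the same base key.
    inline std::size_t GetSubKeySketch(std::size_t key, std::int64_t idx) {
      return key ^ (std::hash<std::int64_t>{}(idx) + 0x9e3779b9u +
                    (key << 6) + (key >> 2));
    }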

std::vector<int64_t> y_dims_;
bool trans_x_;
bool trans_y_;
int best_algo_;
Contributor Author: Deleted, and additionally removed the trans_x_ and trans_y_ members as well.

@Xreki force-pushed the add_autotune_kernel_tool branch from 8e24aa4 to c1a7448 on February 21, 2023 06:58
@Xreki force-pushed the add_autotune_kernel_tool branch from 762dcfd to 9044737 on February 21, 2023 08:22
@Xreki force-pushed the add_autotune_kernel_tool branch from 73efccb to c35bdea on February 21, 2023 13:24
@JamesLim-sy force-pushed the add_autotune_kernel_tool branch 2 times, most recently from ca1a089 to febeb01 on February 23, 2023 02:49
@Xreki (Contributor) left a comment:

LGTM. Model accuracy verification shows no problems, so the PR is merged first; the MatMulFunctionImplWithCublasLt implementation has considerable redundancy and needs follow-up optimization.

@Xreki merged commit d4217fc into PaddlePaddle:develop Feb 26, 2023
@JamesLim-sy (Contributor Author) commented:

> LGTM. Model accuracy verification shows no problems, so the PR is merged first; the MatMulFunctionImplWithCublasLt implementation has considerable redundancy and needs follow-up optimization.

Considering the importance of the Matmul OP, the full verification work will be completed as a follow-up; any issues found will be continuously tracked and fixed.
