Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

fix argmax/argmin float64 error #1312

Merged
merged 5 commits into from
Mar 31, 2023

Conversation

zzk0
Copy link
Contributor

@zzk0 zzk0 commented Mar 28, 2023

argmin/argmax float64 类型输入报错

F0328 08:07:48.080608 588260 nvrtc_util.cc:107] Check failed: compile_res == NVRTC_SUCCESS (6 vs. 0)
default_program(21): error: argument of type "const double *" is incompatible with parameter of type "const float *"

错误的原因是使用了 cinn/runtime/cuda/cinn_cuda_runtime_source.cuh 中的函数 cinn_cuda_lt_num_float,float64 类型和 cinn_cuda_lt_num_float 函数声明的参数类型没有匹配,因此抛出了 double 和 float incompatible 的错误。

解决办法

添加 host/cuda intrinsics 中关于 target_lt_num_typetarget_gt_num_type 相关的函数,支持 int32, int64, float, double 支持类型更多。

  • cuda intrinsics 的修改,支持 int32, int64, float, double 四种类型
  • host intrinsics 的修改,支持 int32, int64, float, double 四种类型
  • cuda intrinsics 函数注册,注册新编写的函数
  • host intrinsics 函数注册,注册新编写的函数
  • 提供相关工具方法调用
  • 修改调用接口的地方,改为新的名称

这个 PR 只修复了 target_lt_num_typetarget_gt_num_type,类似的问题在 cinn_cuda_find_intcinn_cuda_find_int_nd 等函数上也会出现,将会在其他 PR 中修复,虽然可以通过在 op mapper 中加一个 cast 可以避免问题。

Copy link
Collaborator

@thisjiang thisjiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greater Job!

@thisjiang thisjiang merged commit 9422da4 into PaddlePaddle:develop Mar 31, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants