This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
argmin/argmax float64 类型输入报错
错误的原因是使用了
cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
中的函数cinn_cuda_lt_num_float
,float64 类型和cinn_cuda_lt_num_float
函数声明的参数类型没有匹配,因此抛出了 double 和 float incompatible 的错误。解决办法
添加 host/cuda intrinsics 中关于
target_lt_num_type
和target_gt_num_type
相关的函数,支持 int32, int64, float, double 支持类型更多。这个 PR 只修复了
target_lt_num_type
和target_gt_num_type
,类似的问题在cinn_cuda_find_int
和cinn_cuda_find_int_nd
等函数上也会出现,将会在其他 PR 中修复,虽然可以通过在 op mapper 中加一个 cast 可以避免问题。