// ---------------------------------------------------------------------------
// Patch overview. This preamble sits before the first "diff --git" header and
// is not part of any hunk; the patch text below it is unchanged.
//
// 1) paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
//    IsOpDeniedOnCpu(): the deny list that used to hold "cinn_op.argmax",
//    "cinn_op.argmin" and "pd_op.variance" becomes an empty set, so
//    banned_ops.count(op->name()) evaluates to 0 for every op. The TODO that
//    asked for argidx/variance support on CPU is removed with it.
//    NOTE(review): `std::set banned_ops` appears without a template argument
//    on both the removed and the added line; presumably `<std::string>` was
//    lost when this text was extracted - confirm against the real file.
//
// 2) paddle/cinn/runtime/cinn_runtime.cc (added before the closing extern "C")
//    ARGIDX_FUNC_MACRO_DEF_IMPL(TYPENAME, DTYPE, ITYPE) defines, per type:
//      min_<T>(sret, a, b)  - when a->value == b->value, writes *a only if
//                             a->index < b->index, else *b; otherwise writes
//                             the operand with the smaller value.
//      max_<T>(sret, a, b)  - same tie rule on equal values; otherwise writes
//                             the operand with the larger value.
//      cast_<T>(argidx)     - returns argidx->index.
//      create_<T>(sret, val, idx) - writes TYPENAME{val, idx} into *sret.
//    ARGIDX_FUNC_MACRO_DEF expands the above for an int and an int64_t index
//    type; it is instantiated for fp32, fp64, i16, i32, i64 and u8 values.
//    WELFORD_COMBINE_MACRO(TYPE_SUFFIX, DTYPE) defines, for fp32 and fp64:
//      sum_welford_<S>(sret, a, b) - merges two {mean, m2, weight} records:
//          delta     = b->mean - a->mean
//          weight    = a->weight + b->weight
//          w2_over_w = 0.5 when a->weight == b->weight, else b->weight / weight
//          mean      = a->mean + delta * w2_over_w
//          m2        = a->m2 + b->m2 + delta * delta * a->weight * w2_over_w
//        The equal-weight branch performs no division.
//      cast_welford_<S>(wf) - returns wf->m2 / wf->weight.
//        NOTE(review): no guard for wf->weight == 0 - confirm callers never
//        pass an empty record.
//      create_welford_<S>(sret, m, m2, w) - writes {m, m2, w} into *sret.
//    Both helper macro families are #undef'd right after their instantiations.
//
// 3) paddle/cinn/runtime/cinn_runtime.h (appended after "#endif // __cplusplus")
//    Inside an extern "C" block guarded by #ifdef __cplusplus:
//      - typedef structs welford_fp32 / welford_fp64 {mean, m2, weight};
//      - typedef structs argidx_<dtype>_<i32|i64> {value, index} for the same
//        six value types as in the .cc file;
//      - prototypes for the min_/max_/cast_/create_ and
//        sum_welford_/cast_welford_/create_welford_ functions.
//    NOTE(review): ARGIDX_STRUCT_MACRO takes an IINIT argument (fed 0 / 0LL via
//    the IMAX parameter) that the struct body never uses.
//    NOTE(review): WELFORD_COMBINE_DEF_MACRO emits declarations only, while
//    WELFORD_COMBINE_MACRO in the .cc emits the definitions.
//    NOTE(review): the #undef lines for ARGIDX_COMBINE_MACRO,
//    EXPAND_ARGIDX_FUNC_MACRO and ARGIDX_STRUCT_FUNC_MACRO come after the
//    closing extern "C" brace rather than next to their last use.
//
// NOTE(review): the original line breaks of this patch are not visible here
// (the text is collapsed onto a few long lines), so the hunk line counts in
// the "@@" headers could not be checked against the content.
// ---------------------------------------------------------------------------
diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index b42e756bde7db..b744318fc7be7 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -845,10 +845,9 @@ ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc( ir::ConvertExprBlockToStmtBlock(infer_shape_func->body); return infer_shape_func; } -// TODO(heqianyue): support argidx and variance op on CPU bool IsOpDeniedOnCpu(::pir::Operation* op) { - static std::set banned_ops = { - "cinn_op.argmax", "cinn_op.argmin", "pd_op.variance"}; + // no op is denied after the support for composite reduce on cpu + static std::set banned_ops = {}; return banned_ops.count(op->name()); } ir::Expr OpLowererImpl::LowerX86(const OpLoweringGroupPtr& group, diff --git a/paddle/cinn/runtime/cinn_runtime.cc b/paddle/cinn/runtime/cinn_runtime.cc index c4c25e8f86786..0c21e436be080 100644 --- a/paddle/cinn/runtime/cinn_runtime.cc +++ b/paddle/cinn/runtime/cinn_runtime.cc @@ -170,6 +170,59 @@ cinn_type_t cinn_float64_t(int num_asterisks) { return cinn_type_t(cinn_type_float, 64, num_asterisks); } +#define ARGIDX_FUNC_MACRO_DEF_IMPL(TYPENAME, DTYPE, ITYPE) \ + void min_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b) { \ + *sret = a->value == b->value ? (a->index < b->index ? *a : *b) \ + : (a->value < b->value ? *a : *b); \ + } \ + void max_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b) { \ + *sret = a->value == b->value ? (a->index < b->index ? *a : *b) \ + : (a->value > b->value ? 
*a : *b); \ + } \ + ITYPE cast_##TYPENAME(const TYPENAME* argidx) { return argidx->index; } \ + void create_##TYPENAME(TYPENAME* sret, DTYPE val, ITYPE idx) { \ + *sret = TYPENAME{val, idx}; \ + } + +#define ARGIDX_FUNC_MACRO_DEF(DNAME, DTYPE) \ + ARGIDX_FUNC_MACRO_DEF_IMPL(argidx_##DNAME##_i32, DTYPE, int) \ + ARGIDX_FUNC_MACRO_DEF_IMPL(argidx_##DNAME##_i64, DTYPE, int64_t) + +ARGIDX_FUNC_MACRO_DEF(fp32, float) +ARGIDX_FUNC_MACRO_DEF(fp64, double) +ARGIDX_FUNC_MACRO_DEF(i16, int16_t) +ARGIDX_FUNC_MACRO_DEF(i32, int) +ARGIDX_FUNC_MACRO_DEF(i64, int64_t) +ARGIDX_FUNC_MACRO_DEF(u8, uint8_t) + +#undef ARGIDX_FUNC_MACRO_DEF_IMPL +#undef ARGIDX_FUNC_MACRO_DEF + +#define WELFORD_COMBINE_MACRO(TYPE_SUFFIX, DTYPE) \ + void sum_welford_##TYPE_SUFFIX(welford_##TYPE_SUFFIX* sret, \ + const welford_##TYPE_SUFFIX* a, \ + const welford_##TYPE_SUFFIX* b) { \ + DTYPE delta = b->mean - a->mean; \ + DTYPE weight = a->weight + b->weight; \ + DTYPE w2_over_w = \ + a->weight == b->weight ? (DTYPE)0.5 : b->weight / weight; \ + DTYPE mean = a->mean + delta * w2_over_w; \ + DTYPE m2 = a->m2 + b->m2 + delta * delta * a->weight * w2_over_w; \ + *sret = {mean, m2, weight}; \ + } \ + DTYPE cast_welford_##TYPE_SUFFIX(const welford_##TYPE_SUFFIX* wf) { \ + return wf->m2 / wf->weight; \ + } \ + void create_welford_##TYPE_SUFFIX( \ + welford_##TYPE_SUFFIX* sret, DTYPE m, DTYPE m2, DTYPE w) { \ + *sret = welford_##TYPE_SUFFIX{m, m2, w}; \ + } + +WELFORD_COMBINE_MACRO(fp32, float) +WELFORD_COMBINE_MACRO(fp64, double) + +#undef WELFORD_COMBINE_MACRO + } // extern "C" struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, diff --git a/paddle/cinn/runtime/cinn_runtime.h b/paddle/cinn/runtime/cinn_runtime.h index 4a5ce5d18d179..a1c2a0d4d2211 100644 --- a/paddle/cinn/runtime/cinn_runtime.h +++ b/paddle/cinn/runtime/cinn_runtime.h @@ -611,3 +611,84 @@ template cinn_type_t cinn_type_of(); #endif // __cplusplus + +#ifdef __cplusplus +extern "C" { +#endif + +#define 
WELFORD_STRUCT_MACRO(TYPENAME, DTYPE) \ + typedef struct { \ + DTYPE mean; \ + DTYPE m2; \ + DTYPE weight; \ + } TYPENAME; + +WELFORD_STRUCT_MACRO(welford_fp32, float) +WELFORD_STRUCT_MACRO(welford_fp64, double) +#undef WELFORD_STRUCT_MACRO + +#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) \ + typedef struct { \ + DTYPE value; \ + ITYPE index; \ + } TYPENAME; + +#define EXPAND_ARGIDX_DTYPE_MACRO_IMPL(DTYPE, DNAME, ITYPE, INAME, IMAX) \ + ARGIDX_STRUCT_MACRO(argidx_##DNAME##_##INAME, DTYPE, ITYPE, IMAX) + +#define EXPAND_ARGIDX_DTYPE_MACRO(DTYPE, DNAME) \ + EXPAND_ARGIDX_DTYPE_MACRO_IMPL(DTYPE, DNAME, int, i32, 0) \ + EXPAND_ARGIDX_DTYPE_MACRO_IMPL(DTYPE, DNAME, int64_t, i64, 0LL) + +EXPAND_ARGIDX_DTYPE_MACRO(float, fp32) +EXPAND_ARGIDX_DTYPE_MACRO(double, fp64) +EXPAND_ARGIDX_DTYPE_MACRO(int16_t, i16) +EXPAND_ARGIDX_DTYPE_MACRO(int, i32) +EXPAND_ARGIDX_DTYPE_MACRO(int64_t, i64) +EXPAND_ARGIDX_DTYPE_MACRO(uint8_t, u8) + +#undef EXPAND_ARGIDX_DTYPE_MACRO +#undef EXPAND_ARGIDX_DTYPE_MACRO_IMPL +#undef ARGIDX_STRUCT_MACRO + +#define ARGIDX_STRUCT_FUNC_MACRO(TYPENAME, DTYPE, ITYPE) \ + ITYPE cast_##TYPENAME(const TYPENAME* argidx); \ + void create_##TYPENAME(TYPENAME* sret, DTYPE val, ITYPE idx); + +#define ARGIDX_COMBINE_MACRO(TYPENAME) \ + void min_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b); \ + void max_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b); + +#define EXPAND_ARGIDX_FUNC_MACRO(DTYPE, DNAME) \ + ARGIDX_COMBINE_MACRO(argidx_##DNAME##_##i32) \ + ARGIDX_COMBINE_MACRO(argidx_##DNAME##_##i64) \ + ARGIDX_STRUCT_FUNC_MACRO(argidx_##DNAME##_##i32, DTYPE, int) \ + ARGIDX_STRUCT_FUNC_MACRO(argidx_##DNAME##_##i64, DTYPE, int64_t) + +// TODO(heqianyue): fp16 not added +EXPAND_ARGIDX_FUNC_MACRO(float, fp32) +EXPAND_ARGIDX_FUNC_MACRO(double, fp64) +EXPAND_ARGIDX_FUNC_MACRO(int16_t, i16) +EXPAND_ARGIDX_FUNC_MACRO(int, i32) +EXPAND_ARGIDX_FUNC_MACRO(int64_t, i64) +EXPAND_ARGIDX_FUNC_MACRO(uint8_t, u8) + +#define 
WELFORD_COMBINE_DEF_MACRO(TYPE_SUFFIX, DTYPE) \ + void sum_welford_##TYPE_SUFFIX(welford_##TYPE_SUFFIX* sret, \ + const welford_##TYPE_SUFFIX* a, \ + const welford_##TYPE_SUFFIX* b); \ + DTYPE cast_welford_##TYPE_SUFFIX(const welford_##TYPE_SUFFIX* wf); \ + void create_welford_##TYPE_SUFFIX( \ + welford_##TYPE_SUFFIX* sret, DTYPE m, DTYPE m2, DTYPE w); + +WELFORD_COMBINE_DEF_MACRO(fp32, float) +WELFORD_COMBINE_DEF_MACRO(fp64, double) +#undef WELFORD_COMBINE_DEF_MACRO + +#ifdef __cplusplus +} // extern "C" +#endif + +#undef ARGIDX_COMBINE_MACRO +#undef EXPAND_ARGIDX_FUNC_MACRO +#undef ARGIDX_STRUCT_FUNC_MACRO