Skip to content

[CINN] x86 runtime intrinsics for composite reduce #72371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -845,10 +845,9 @@ ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc(
ir::ConvertExprBlockToStmtBlock(infer_shape_func->body);
return infer_shape_func;
}
// Returns true when the name of `op` appears in the CPU deny-list.
// The list is intentionally empty: no op is denied after the support for
// composite reduce on cpu. Add an op name to the set to deny it again.
bool IsOpDeniedOnCpu(::pir::Operation* op) {
  static const std::set<std::string> banned_ops = {};
  return banned_ops.count(op->name()) > 0;
}
ir::Expr OpLowererImpl::LowerX86(const OpLoweringGroupPtr& group,
Expand Down
53 changes: 53 additions & 0 deletions paddle/cinn/runtime/cinn_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,59 @@ cinn_type_t cinn_float64_t(int num_asterisks) {
return cinn_type_t(cinn_type_float, 64, num_asterisks);
}

// Defines the helper functions for one (value, index) pair struct TYPENAME
// whose fields are `DTYPE value` and `ITYPE index`:
//   min_<TYPENAME> / max_<TYPENAME>
//       Write to *sret whichever of *a, *b holds the smaller / larger value.
//       When the values compare equal the entry with the smaller index is
//       taken. Whenever the deciding comparison does not favor *a, *b is
//       written.
//   cast_<TYPENAME>
//       Returns the index stored in the pair.
//   create_<TYPENAME>
//       Fills *sret with the given value and index.
#define ARGIDX_FUNC_MACRO_DEF_IMPL(TYPENAME, DTYPE, ITYPE) \
  void min_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b) { \
    const bool same_value = a->value == b->value; \
    const bool take_a = \
        same_value ? a->index < b->index : a->value < b->value; \
    *sret = take_a ? *a : *b; \
  } \
  void max_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b) { \
    const bool same_value = a->value == b->value; \
    const bool take_a = \
        same_value ? a->index < b->index : a->value > b->value; \
    *sret = take_a ? *a : *b; \
  } \
  ITYPE cast_##TYPENAME(const TYPENAME* argidx) { return argidx->index; } \
  void create_##TYPENAME(TYPENAME* sret, DTYPE val, ITYPE idx) { \
    sret->value = val; \
    sret->index = idx; \
  }

// Instantiates ARGIDX_FUNC_MACRO_DEF_IMPL twice for one value type: once for
// the struct argidx_<DNAME>_i32 (int index) and once for argidx_<DNAME>_i64
// (int64_t index).
#define ARGIDX_FUNC_MACRO_DEF(DNAME, DTYPE) \
ARGIDX_FUNC_MACRO_DEF_IMPL(argidx_##DNAME##_i32, DTYPE, int) \
ARGIDX_FUNC_MACRO_DEF_IMPL(argidx_##DNAME##_i64, DTYPE, int64_t)

// Emit the min_/max_/cast_/create_ helpers for every listed value type, each
// with both index widths.
ARGIDX_FUNC_MACRO_DEF(fp32, float)
ARGIDX_FUNC_MACRO_DEF(fp64, double)
ARGIDX_FUNC_MACRO_DEF(i16, int16_t)
ARGIDX_FUNC_MACRO_DEF(i32, int)
ARGIDX_FUNC_MACRO_DEF(i64, int64_t)
ARGIDX_FUNC_MACRO_DEF(u8, uint8_t)

// The generator macros are only needed for the instantiations above.
#undef ARGIDX_FUNC_MACRO_DEF_IMPL
#undef ARGIDX_FUNC_MACRO_DEF

// Defines the helper functions for welford_<TYPE_SUFFIX>, a struct with the
// DTYPE fields `mean`, `m2` and `weight`:
//   sum_welford_<TYPE_SUFFIX>
//       Merges the statistics *a and *b into *sret. The share of b in the
//       merged weight is b->weight / (a->weight + b->weight), except that
//       equal weights use the constant 0.5 directly, which also skips the
//       division when both weights are zero. All inputs are read before
//       *sret is written.
//   cast_welford_<TYPE_SUFFIX>
//       Returns m2 / weight.
//   create_welford_<TYPE_SUFFIX>
//       Fills *sret from (m, m2, w).
#define WELFORD_COMBINE_MACRO(TYPE_SUFFIX, DTYPE) \
  void sum_welford_##TYPE_SUFFIX(welford_##TYPE_SUFFIX* sret, \
                                 const welford_##TYPE_SUFFIX* a, \
                                 const welford_##TYPE_SUFFIX* b) { \
    const DTYPE mean_gap = b->mean - a->mean; \
    const DTYPE total_weight = a->weight + b->weight; \
    DTYPE b_share = static_cast<DTYPE>(0.5); \
    if (a->weight != b->weight) { \
      b_share = b->weight / total_weight; \
    } \
    const DTYPE merged_mean = a->mean + mean_gap * b_share; \
    const DTYPE merged_m2 = \
        a->m2 + b->m2 + mean_gap * mean_gap * a->weight * b_share; \
    sret->mean = merged_mean; \
    sret->m2 = merged_m2; \
    sret->weight = total_weight; \
  } \
  DTYPE cast_welford_##TYPE_SUFFIX(const welford_##TYPE_SUFFIX* wf) { \
    return wf->m2 / wf->weight; \
  } \
  void create_welford_##TYPE_SUFFIX( \
      welford_##TYPE_SUFFIX* sret, DTYPE m, DTYPE m2, DTYPE w) { \
    sret->mean = m; \
    sret->m2 = m2; \
    sret->weight = w; \
  }

// Emit sum_welford_*, cast_welford_* and create_welford_* for the float and
// double statistics structs (welford_fp32 and welford_fp64).
WELFORD_COMBINE_MACRO(fp32, float)
WELFORD_COMBINE_MACRO(fp64, double)

// The generator macro is not needed past this point.
#undef WELFORD_COMBINE_MACRO

} // extern "C"

struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device,
Expand Down
81 changes: 81 additions & 0 deletions paddle/cinn/runtime/cinn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,84 @@ template <typename T>
cinn_type_t cinn_type_of();

#endif // __cplusplus

#ifdef __cplusplus
extern "C" {
#endif

// Welford statistics record in single precision: three float fields named
// mean, m2 and weight, in that order.
typedef struct {
  float mean;
  float m2;
  float weight;
} welford_fp32;

// Welford statistics record in double precision: same field names and order
// as welford_fp32, with double fields.
typedef struct {
  double mean;
  double m2;
  double weight;
} welford_fp64;

// Declares one (value, index) pair struct named TYPENAME with the fields
// `DTYPE value` followed by `ITYPE index`.
#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE) \
  typedef struct { \
    DTYPE value; \
    ITYPE index; \
  } TYPENAME;

// Declares both index-width variants for one value type:
// argidx_<DNAME>_i32 (int index) and argidx_<DNAME>_i64 (int64_t index).
#define EXPAND_ARGIDX_DTYPE_MACRO(DTYPE, DNAME) \
  ARGIDX_STRUCT_MACRO(argidx_##DNAME##_i32, DTYPE, int) \
  ARGIDX_STRUCT_MACRO(argidx_##DNAME##_i64, DTYPE, int64_t)

EXPAND_ARGIDX_DTYPE_MACRO(float, fp32)
EXPAND_ARGIDX_DTYPE_MACRO(double, fp64)
EXPAND_ARGIDX_DTYPE_MACRO(int16_t, i16)
EXPAND_ARGIDX_DTYPE_MACRO(int, i32)
EXPAND_ARGIDX_DTYPE_MACRO(int64_t, i64)
EXPAND_ARGIDX_DTYPE_MACRO(uint8_t, u8)

// The struct generators are only needed for the declarations above.
#undef EXPAND_ARGIDX_DTYPE_MACRO
#undef ARGIDX_STRUCT_MACRO

// Declares, for one argidx struct TYPENAME, cast_<TYPENAME> (takes a pointer
// to the struct and returns ITYPE) and create_<TYPENAME> (fills *sret from a
// DTYPE value and an ITYPE index).
#define ARGIDX_STRUCT_FUNC_MACRO(TYPENAME, DTYPE, ITYPE) \
ITYPE cast_##TYPENAME(const TYPENAME* argidx); \
void create_##TYPENAME(TYPENAME* sret, DTYPE val, ITYPE idx);

// Declares the two combiners min_<TYPENAME> and max_<TYPENAME>; both take two
// input structs by pointer and write their result through `sret`.
#define ARGIDX_COMBINE_MACRO(TYPENAME) \
void min_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b); \
void max_##TYPENAME(TYPENAME* sret, const TYPENAME* a, const TYPENAME* b);

// Declares all four functions for both index widths (i32 and i64) of one
// value type. These macros stay defined until the #undef lines that follow
// the closing of the extern "C" block.
#define EXPAND_ARGIDX_FUNC_MACRO(DTYPE, DNAME) \
ARGIDX_COMBINE_MACRO(argidx_##DNAME##_##i32) \
ARGIDX_COMBINE_MACRO(argidx_##DNAME##_##i64) \
ARGIDX_STRUCT_FUNC_MACRO(argidx_##DNAME##_##i32, DTYPE, int) \
ARGIDX_STRUCT_FUNC_MACRO(argidx_##DNAME##_##i64, DTYPE, int64_t)

// TODO(heqianyue): fp16 not added
EXPAND_ARGIDX_FUNC_MACRO(float, fp32)
EXPAND_ARGIDX_FUNC_MACRO(double, fp64)
EXPAND_ARGIDX_FUNC_MACRO(int16_t, i16)
EXPAND_ARGIDX_FUNC_MACRO(int, i32)
EXPAND_ARGIDX_FUNC_MACRO(int64_t, i64)
EXPAND_ARGIDX_FUNC_MACRO(uint8_t, u8)

// Single-precision Welford helpers: merge two statistics records into *sret,
// reduce a record to a float result, and fill a record from (m, m2, w).
void sum_welford_fp32(welford_fp32* sret,
                      const welford_fp32* a,
                      const welford_fp32* b);
float cast_welford_fp32(const welford_fp32* wf);
void create_welford_fp32(welford_fp32* sret, float m, float m2, float w);

// Double-precision counterparts with the same parameter layout.
void sum_welford_fp64(welford_fp64* sret,
                      const welford_fp64* a,
                      const welford_fp64* b);
double cast_welford_fp64(const welford_fp64* wf);
void create_welford_fp64(welford_fp64* sret, double m, double m2, double w);

#ifdef __cplusplus
} // extern "C"
#endif

// Remove the argidx declaration-generator macros so they are no longer
// defined after this point in the header.
#undef ARGIDX_COMBINE_MACRO
#undef EXPAND_ARGIDX_FUNC_MACRO
#undef ARGIDX_STRUCT_FUNC_MACRO