Skip to content

[CINN] CodeGenLLVM supports composite reduce #72368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 150 additions & 1 deletion paddle/cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ CodeGenLLVM::CodeGenLLVM(llvm::Module *m,
md_tbaa_root_ = md_builder_->createTBAARoot("cinn-tbaa");
md_tbaa_alias_set_ = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_);
InitTarget(target_);
RegisterCustomizedPODStructType();
}

CodeGenLLVM::~CodeGenLLVM() {}
Expand Down Expand Up @@ -382,6 +383,91 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Min *op) {
return Select(p, lhs, rhs);
}

// Returns true when `callee` has the hand-rolled "struct return" shape used by
// this file: a void function whose first parameter is a pointer. This is a
// purely structural check on the function type; the LLVM `sret` parameter
// attribute is not consulted.
//
// Functions without parameters are never treated as struct-return callees.
//
// NOTE(review): any void function taking a pointer first argument matches this
// test, whether or not that pointer is really a result slot - confirm callers
// only pass callees for which the heuristic holds.
inline bool FuncHasStructSRet(const llvm::Function *callee) {
  const llvm::FunctionType *func_type = callee->getFunctionType();
  // Guard first: getParamType(0) is out of range for a zero-parameter callee.
  if (func_type->getNumParams() == 0) {
    return false;
  }
  return func_type->getReturnType()->isVoidTy() &&
         func_type->getParamType(0)->isPointerTy();
}

std::vector<llvm::Value *> AdaptABIArguments(
llvm::IRBuilder<> *builder,
llvm::Function *callee,
std::vector<llvm::Value *> &&input_args) {
// this function decompose the input arguments to primitive elements
// and reconstruct the input arguments according to callee func type
// for example: welford_fp64 (welford_fp64, welford_fp32) ->
// (*struct.welford_fp64*, *struct.welford_fp64*, {<2x float>, float})
// struct_to_int64: if true, struct of 64bit size will be turned into i64
// by bit casting
std::vector<llvm::Value *> output_args;

auto input_it = input_args.begin();
const auto *func_type = callee->getFunctionType();
const int num_params = func_type->getNumParams();

output_args.reserve(num_params);

int param_idx = 0;
// Case 1: check whether we have a pointer for struct ret
if (FuncHasStructSRet(callee)) {
llvm::Type *param_ty = func_type->getParamType(0);
llvm::Value *sret_ptr = builder->CreateAlloca(
param_ty->getPointerElementType(), nullptr, "sret.ptr");
output_args.push_back(sret_ptr);
param_idx++;
}

// traverse all the param in the target func type
for (; param_idx < num_params; param_idx++) {
llvm::Type *param_ty = func_type->getParamType(param_idx);

if (input_it == input_args.end()) break;
llvm::Value *current_input = *input_it;
llvm::Type *input_type = current_input->getType();

if (input_type == param_ty) {
// Case 2: if type matches, just pass it directly
output_args.push_back(current_input);
input_it++;
} else if (param_ty->isPointerTy()) {
// Case 3: if the input type is pointer,
// we need to create a local buffer
PADDLE_ENFORCE_EQ(input_type,
param_ty->getPointerElementType(),
::common::errors::PreconditionNotMet(
"Pointer parameter type mismatch"));
auto *input_ptr = builder->CreateAlloca(input_type, nullptr, "input.ptr");
builder->CreateStore(current_input, input_ptr);
output_args.push_back(input_ptr);
input_it++;
} else if (param_ty->isIntegerTy() && input_type->isIntegerTy()) {
// Case 4: for argidx type, tensor index might be deterministically casted
// to int64_t, so we need to cast the integer type here
if (input_type->getPrimitiveSizeInBits() <
param_ty->getPrimitiveSizeInBits()) { // sign ext
output_args.push_back(
builder->CreateSExt(current_input, param_ty, "sext"));
} else {
output_args.push_back(
builder->CreateTrunc(current_input, param_ty, "trunc"));
}
input_it++;
} else {
PADDLE_THROW(
::common::errors::Fatal("Unhandled case for ABI param adaptation."));
}
}

PADDLE_ENFORCE_EQ(input_it,
input_args.end(),
::common::errors::PreconditionNotMet(
"Not all input args are consumed by the callee"));
return output_args;
}

llvm::Value *CodeGenLLVM::Visit(const ir::Max *op) {
auto *lhs = Visit(&op->a());
auto *rhs = Visit(&op->b());
Expand All @@ -407,10 +493,39 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Not *op) {
return Not(Visit(&op->v()));
}

// Emits the cast of a composite (customized) value by calling the module
// function named "cast_<type name>".
//
// Only customized types whose name contains "argidx_" or "welford_" are
// handled; for every other expression nothing is emitted and nullptr is
// returned, which the caller treats as "not a composite cast".
//
// @param op_v  Expression being cast; it is visited to produce the argument.
// @return      The cast result: the call's return value, or - when the callee
//              matches FuncHasStructSRet - the value loaded from the result
//              slot passed as its first argument. nullptr when not handled.
llvm::Value *CodeGenLLVM::CastCompositeType(const ir::Expr &op_v) {
  if (!op_v.type().is_customized_type()) {
    return nullptr;
  }

  const std::string from_type_name = op_v.type().to_string();
  const bool is_composite =
      from_type_name.find("argidx_") != std::string::npos ||
      from_type_name.find("welford_") != std::string::npos;
  if (!is_composite) {
    return nullptr;
  }

  auto *callee = m_->getFunction("cast_" + from_type_name);
  CHECK(callee) << "type casting function is null";

  auto *value = Visit(&op_v);
  auto params = AdaptABIArguments(b_, callee, {value});

  const bool returns_via_sret = FuncHasStructSRet(callee);
  // A struct-return callee returns void, and LLVM does not allow void values
  // to carry a name, so the call is only named when it yields a value.
  llvm::Value *call_handle =
      Call(callee, params, returns_via_sret ? "" : "pod_value_cast");
  if (!returns_via_sret) {
    return call_handle;
  }

  // The result was written to the stack slot passed as the first argument;
  // load it back to obtain the cast value.
  auto *sret_ptr = params[0];
  auto *sret_ty = sret_ptr->getType()->getPointerElementType();
  return b_->CreateLoad(sret_ty, sret_ptr);
}

llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) {
auto from = op->v().type();
auto to = op->type();

if (auto cast_call = CastCompositeType(op->v())) {
return cast_call;
}
llvm::Type *source = CinnTypeToLLVMType(from, m_);
llvm::Type *target = CinnTypeToLLVMType(to, m_);
CHECK(source) << "source ir type is null";
Expand Down Expand Up @@ -759,7 +874,16 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
args[0] = BitCast(args[0], ll_void_p_ty(), "cast_to_void_p");
}

return Call(callee, std::move(args));
auto params = AdaptABIArguments(b_, callee, std::move(args));
llvm::Value *ret = Call(callee, params);
if (FuncHasStructSRet(callee)) {
// void return type and the first param is a pointer
// will be considered as sret callee
auto *sret_ptr = params[0];
auto *sret_ty = sret_ptr->getType()->getPointerElementType();
ret = b_->CreateLoad(sret_ty, params.front());
}
return ret;
}

llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) {
Expand Down Expand Up @@ -1695,5 +1819,30 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::PodValueToX *op) {
return Call(callee, std::vector<llvm::Value *>({value}), "pod_value_cast");
}

void CodeGenLLVM::RegisterCustomizedPODStructType() {
// some of the POD struct type defined in cinn_runtime.h
// might be missing when loaded by LLVM, for unknown reasons
// To make sure they can be found, we explicitly register them here
#define REGISTER_STRUCT_TYPE(name, ...) \
llvm::StructType::create({__VA_ARGS__}, "struct." #name, /*isPacked=*/false);

#define REGISTER_ARGIDX_TYPE(dname, dtype) \
REGISTER_STRUCT_TYPE(argidx_##dname##_i32, dtype, ll_int32_ty()) \
REGISTER_STRUCT_TYPE(argidx_##dname##_i64, dtype, ll_int64_ty())

REGISTER_STRUCT_TYPE(welford_fp32, ll_fp32_ty(), ll_fp32_ty(), ll_fp32_ty())
REGISTER_STRUCT_TYPE(welford_fp64, ll_fp64_ty(), ll_fp64_ty(), ll_fp64_ty())

REGISTER_ARGIDX_TYPE(fp32, ll_fp32_ty())
REGISTER_ARGIDX_TYPE(fp64, ll_fp64_ty())
REGISTER_ARGIDX_TYPE(i16, ll_int16_ty())
REGISTER_ARGIDX_TYPE(i32, ll_int32_ty())
REGISTER_ARGIDX_TYPE(i64, ll_int64_ty())
REGISTER_ARGIDX_TYPE(u8, ll_uint8_ty())

#undef REGISTER_ARGIDX_TYPE
#undef REGISTER_STRUCT_TYPE
}

} // namespace backends
} // namespace cinn
3 changes: 3 additions & 0 deletions paddle/cinn/backends/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin<CodeGenLLVM> {
void Scalarize(const Expr &e,
std::function<void(int i, llvm::Value *v)> flambda);

llvm::Value *CastCompositeType(const ir::Expr &op_v);
void RegisterCustomizedPODStructType();

llvm::Module *m_;
llvm::IRBuilder<> *b_;
// Current function
Expand Down