Skip to content

Implement support for become and explicit tail call codegen for the LLVM backend #144232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1742,6 +1742,21 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
call
}

fn tail_call(
&mut self,
_llty: Self::Type,
_fn_attrs: Option<&CodegenFnAttrs>,
_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
_llfn: Self::Value,
_args: &[Self::Value],
_funclet: Option<&Self::Funclet>,
_instance: Option<Instance<'tcx>>,
) {
bug!(
"Guaranteed tail calls with the 'become' keyword are not implemented in the GCC backend"
);
}

fn zext(&mut self, value: RValue<'gcc>, dest_typ: Type<'gcc>) -> RValue<'gcc> {
// FIXME(antoyo): this does not zero-extend.
self.gcc_int_cast(value, dest_typ)
Expand Down
26 changes: 25 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::*;
use rustc_data_structures::small_c_str::SmallCStr;
use rustc_hir::def_id::DefId;
use rustc_middle::bug;
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs;
use rustc_middle::ty::layout::{
FnAbiError, FnAbiOfHelpers, FnAbiRequest, HasTypingEnv, LayoutError, LayoutOfHelpers,
Expand All @@ -23,7 +24,7 @@ use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use rustc_sanitizers::{cfi, kcfi};
use rustc_session::config::OptLevel;
use rustc_span::Span;
use rustc_target::callconv::FnAbi;
use rustc_target::callconv::{FnAbi, PassMode};
use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target};
use smallvec::SmallVec;
use tracing::{debug, instrument};
Expand Down Expand Up @@ -1362,6 +1363,29 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
call
}

fn tail_call(
&mut self,
llty: Self::Type,
fn_attrs: Option<&CodegenFnAttrs>,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
llfn: Self::Value,
args: &[Self::Value],
funclet: Option<&Self::Funclet>,
instance: Option<Instance<'tcx>>,
) {
let call = self.call(llty, fn_attrs, Some(fn_abi), llfn, args, funclet, instance);

match &fn_abi.ret.mode {
PassMode::Ignore | PassMode::Indirect { .. } => self.ret_void(),
PassMode::Direct(_) | PassMode::Pair { .. } => self.ret(call),
mode @ PassMode::Cast { .. } => {
bug!("Encountered `PassMode::{mode:?}` during codegen")
}
}

llvm::LLVMRustSetTailCallKind(call, llvm::TailCallKind::MustTail);
}

fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMBuildZExt(self.llbuilder, val, dest_ty, UNNAMED) }
}
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ pub(crate) enum ModuleFlagMergeBehavior {

// Consts for the LLVM CallConv type, pre-cast to usize.

#[derive(Copy, Clone, PartialEq, Debug)]
#[repr(C)]
#[allow(dead_code)]
pub(crate) enum TailCallKind {
None = 0,
Tail = 1,
MustTail = 2,
NoTail = 3,
}

/// LLVM CallingConv::ID. Should we wrap this?
///
/// See <https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/CallingConv.h>
Expand Down Expand Up @@ -1181,6 +1191,7 @@ unsafe extern "C" {
pub(crate) safe fn LLVMIsGlobalConstant(GlobalVar: &Value) -> Bool;
pub(crate) safe fn LLVMSetGlobalConstant(GlobalVar: &Value, IsConstant: Bool);
pub(crate) safe fn LLVMSetTailCall(CallInst: &Value, IsTailCall: Bool);
pub(crate) safe fn LLVMRustSetTailCallKind(CallInst: &Value, Kind: TailCallKind);

// Operations on attributes
pub(crate) fn LLVMCreateStringAttribute(
Expand Down
63 changes: 52 additions & 11 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> {
mut unwind: mir::UnwindAction,
lifetime_ends_after_call: &[(Bx::Value, Size)],
instance: Option<Instance<'tcx>>,
tail: bool,
mergeable_succ: bool,
) -> MergingSucc {
let tcx = bx.tcx();
Expand Down Expand Up @@ -221,6 +222,11 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> {
}
};

if tail {
bx.tail_call(fn_ty, fn_attrs, fn_abi, fn_ptr, llargs, self.funclet(fx), instance);
return MergingSucc::False;
}

if let Some(unwind_block) = unwind_block {
let ret_llbb = if let Some((_, target)) = destination {
fx.llbb(target)
Expand Down Expand Up @@ -659,6 +665,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
unwind,
&[],
Some(drop_instance),
false,
!maybe_null && mergeable_succ,
)
}
Expand Down Expand Up @@ -747,8 +754,19 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let (fn_abi, llfn, instance) = common::build_langcall(bx, span, lang_item);

// Codegen the actual panic invoke/call.
let merging_succ =
helper.do_call(self, bx, fn_abi, llfn, &args, None, unwind, &[], Some(instance), false);
let merging_succ = helper.do_call(
self,
bx,
fn_abi,
llfn,
&args,
None,
unwind,
&[],
Some(instance),
false,
false,
);
assert_eq!(merging_succ, MergingSucc::False);
MergingSucc::False
}
Expand Down Expand Up @@ -778,6 +796,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
&[],
Some(instance),
false,
false,
);
assert_eq!(merging_succ, MergingSucc::False);
}
Expand Down Expand Up @@ -845,6 +864,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
unwind,
&[],
Some(instance),
false,
mergeable_succ,
))
}
Expand All @@ -860,6 +880,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
target: Option<mir::BasicBlock>,
unwind: mir::UnwindAction,
fn_span: Span,
tail: bool,
mergeable_succ: bool,
) -> MergingSucc {
let source_info = mir::SourceInfo { span: fn_span, ..terminator.source_info };
Expand Down Expand Up @@ -1003,8 +1024,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
// We still need to call `make_return_dest` even if there's no `target`, since
// `fn_abi.ret` could be `PassMode::Indirect`, even if it is uninhabited,
// and `make_return_dest` adds the return-place indirect pointer to `llargs`.
let return_dest = self.make_return_dest(bx, destination, &fn_abi.ret, &mut llargs);
let destination = target.map(|target| (return_dest, target));
let destination = if !tail {
let return_dest = self.make_return_dest(bx, destination, &fn_abi.ret, &mut llargs);
target.map(|target| (return_dest, target))
} else {
None
};

// Split the rust-call tupled arguments off.
let (first_args, untuple) = if sig.abi() == ExternAbi::RustCall
Expand All @@ -1020,6 +1045,13 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
// to generate `lifetime_end` when the call returns.
let mut lifetime_ends_after_call: Vec<(Bx::Value, Size)> = Vec::new();
'make_args: for (i, arg) in first_args.iter().enumerate() {
if tail && matches!(fn_abi.args[i].mode, PassMode::Indirect { .. }) {
span_bug!(
fn_span,
"arguments using PassMode::Indirect are currently not supported for tail calls"
);
}

let mut op = self.codegen_operand(bx, &arg.node);

if let (0, Some(ty::InstanceKind::Virtual(_, idx))) = (i, instance.map(|i| i.def)) {
Expand Down Expand Up @@ -1147,6 +1179,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
unwind,
&lifetime_ends_after_call,
instance,
tail,
mergeable_succ,
)
}
Expand Down Expand Up @@ -1388,15 +1421,23 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
target,
unwind,
fn_span,
false,
mergeable_succ(),
),
mir::TerminatorKind::TailCall { .. } => {
// FIXME(explicit_tail_calls): implement tail calls in ssa backend
span_bug!(
terminator.source_info.span,
"`TailCall` terminator is not yet supported by `rustc_codegen_ssa`"
)
}
mir::TerminatorKind::TailCall { ref func, ref args, fn_span } => self
.codegen_call_terminator(
helper,
bx,
terminator,
func,
args,
mir::Place::from(mir::RETURN_PLACE),
None,
mir::UnwindAction::Unreachable,
fn_span,
true,
mergeable_succ(),
),
mir::TerminatorKind::CoroutineDrop | mir::TerminatorKind::Yield { .. } => {
bug!("coroutine ops in codegen")
}
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_codegen_ssa/src/traits/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,18 @@ pub trait BuilderMethods<'a, 'tcx>:
funclet: Option<&Self::Funclet>,
instance: Option<Instance<'tcx>>,
) -> Self::Value;

fn tail_call(
&mut self,
llty: Self::Type,
fn_attrs: Option<&CodegenFnAttrs>,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
llfn: Self::Value,
args: &[Self::Value],
funclet: Option<&Self::Funclet>,
instance: Option<Instance<'tcx>>,
);

fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value;

fn apply_attrs_to_cleanup_callsite(&mut self, llret: Self::Value);
Expand Down
26 changes: 26 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1949,3 +1949,29 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
MD.NoHWAddress = true;
GV.setSanitizerMetadata(MD);
}

enum LLVMRustTailCallKind {
LLVMRustTailCallKindNone = 0,
LLVMRustTailCallKindTail = 1,
LLVMRustTailCallKindMustTail = 2,
LLVMRustTailCallKindNoTail = 3
};

extern "C" void LLVMRustSetTailCallKind(LLVMValueRef Call,
LLVMRustTailCallKind Kind) {
CallInst *CI = unwrap<CallInst>(Call);
switch (Kind) {
case LLVMRustTailCallKindNone:
CI->setTailCallKind(CallInst::TCK_None);
break;
case LLVMRustTailCallKindTail:
CI->setTailCallKind(CallInst::TCK_Tail);
break;
case LLVMRustTailCallKindMustTail:
CI->setTailCallKind(CallInst::TCK_MustTail);
break;
case LLVMRustTailCallKindNoTail:
CI->setTailCallKind(CallInst::TCK_NoTail);
break;
}
}
18 changes: 18 additions & 0 deletions tests/codegen/become-musttail.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes
//@ needs-unwind

#![crate_type = "lib"]
#![feature(explicit_tail_calls)]

// CHECK-LABEL: define {{.*}}@fibonacci(
#[no_mangle]
#[inline(never)]
pub fn fibonacci(n: u64, a: u64, b: u64) -> u64 {
// CHECK: musttail call {{.*}}@fibonacci(
// CHECK-NEXT: ret i64
match n {
0 => a,
1 => b,
_ => become fibonacci(n - 1, b, a + b),
}
}
15 changes: 15 additions & 0 deletions tests/ui/explicit-tail-calls/recursion-no-tce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//@ run-crash

use std::hint::black_box;

pub fn count(curr: u64, top: u64) -> u64 {
if black_box(curr) >= top {
curr
} else {
count(curr + 1, top)
}
}

fn main() {
println!("{}", count(0, black_box(1000000)));
}
17 changes: 17 additions & 0 deletions tests/ui/explicit-tail-calls/recursion-tce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//@ run-pass
#![expect(incomplete_features)]
#![feature(explicit_tail_calls)]

use std::hint::black_box;

pub fn count(curr: u64, top: u64) -> u64 {
if black_box(curr) >= top {
curr
} else {
become count(curr + 1, top)
}
}

fn main() {
println!("{}", count(0, black_box(1000000)));
}
Loading