diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs index a4ec4bf8deac4..48e11c755a4f0 100644 --- a/compiler/rustc_codegen_gcc/src/builder.rs +++ b/compiler/rustc_codegen_gcc/src/builder.rs @@ -1742,6 +1742,21 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { call } + fn tail_call( + &mut self, + _llty: Self::Type, + _fn_attrs: Option<&CodegenFnAttrs>, + _fn_abi: &FnAbi<'tcx, Ty<'tcx>>, + _llfn: Self::Value, + _args: &[Self::Value], + _funclet: Option<&Self::Funclet>, + _instance: Option>, + ) { + bug!( + "Guaranteed tail calls with the 'become' keyword are not implemented in the GCC backend" + ); + } + fn zext(&mut self, value: RValue<'gcc>, dest_typ: Type<'gcc>) -> RValue<'gcc> { // FIXME(antoyo): this does not zero-extend. self.gcc_int_cast(value, dest_typ) diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 0ade9edb0d2ea..56f9cd080ffd6 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -15,6 +15,7 @@ use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::*; use rustc_data_structures::small_c_str::SmallCStr; use rustc_hir::def_id::DefId; +use rustc_middle::bug; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::ty::layout::{ FnAbiError, FnAbiOfHelpers, FnAbiRequest, HasTypingEnv, LayoutError, LayoutOfHelpers, @@ -24,7 +25,7 @@ use rustc_middle::ty::{self, Instance, Ty, TyCtxt}; use rustc_sanitizers::{cfi, kcfi}; use rustc_session::config::OptLevel; use rustc_span::Span; -use rustc_target::callconv::FnAbi; +use rustc_target::callconv::{FnAbi, PassMode}; use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target}; use smallvec::SmallVec; use tracing::{debug, instrument}; @@ -1431,6 +1432,29 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { call } + fn tail_call( + &mut self, + llty: Self::Type, + fn_attrs: Option<&CodegenFnAttrs>, + fn_abi: &FnAbi<'tcx, Ty<'tcx>>, + llfn: Self::Value, + args: &[Self::Value], + funclet: Option<&Self::Funclet>, + instance: Option>, + ) { + let call = self.call(llty, fn_attrs, Some(fn_abi), llfn, args, funclet, instance); + + match &fn_abi.ret.mode { + PassMode::Ignore | PassMode::Indirect { .. } => self.ret_void(), + PassMode::Direct(_) | PassMode::Pair { .. } => self.ret(call), + mode @ PassMode::Cast { .. } => { + bug!("Encountered `PassMode::{mode:?}` during codegen") + } + } + + llvm::LLVMRustSetTailCallKind(call, llvm::TailCallKind::MustTail); + } + fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value { unsafe { llvm::LLVMBuildZExt(self.llbuilder, val, dest_ty, UNNAMED) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index edfb29dd1be72..9d4bd70549b46 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -97,6 +97,16 @@ pub(crate) enum ModuleFlagMergeBehavior { // Consts for the LLVM CallConv type, pre-cast to usize. +#[derive(Copy, Clone, PartialEq, Debug)] +#[repr(C)] +#[allow(dead_code)] +pub(crate) enum TailCallKind { + None = 0, + Tail = 1, + MustTail = 2, + NoTail = 3, +} + /// LLVM CallingConv::ID. Should we wrap this? /// /// See @@ -1186,6 +1196,7 @@ unsafe extern "C" { pub(crate) safe fn LLVMIsGlobalConstant(GlobalVar: &Value) -> Bool; pub(crate) safe fn LLVMSetGlobalConstant(GlobalVar: &Value, IsConstant: Bool); pub(crate) safe fn LLVMSetTailCall(CallInst: &Value, IsTailCall: Bool); + pub(crate) safe fn LLVMRustSetTailCallKind(CallInst: &Value, Kind: TailCallKind); // Operations on attributes pub(crate) fn LLVMCreateStringAttribute( diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index bde63fd501aa2..91b3d2a1a26fa 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -160,6 +160,7 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> { mut unwind: mir::UnwindAction, lifetime_ends_after_call: &[(Bx::Value, Size)], instance: Option>, + tail: bool, mergeable_succ: bool, ) -> MergingSucc { let tcx = bx.tcx(); @@ -221,6 +222,11 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> { } }; + if tail { + bx.tail_call(fn_ty, fn_attrs, fn_abi, fn_ptr, llargs, self.funclet(fx), instance); + return MergingSucc::False; + } + if let Some(unwind_block) = unwind_block { let ret_llbb = if let Some((_, target)) = destination { fx.llbb(target) @@ -659,6 +665,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { unwind, &[], Some(drop_instance), + false, !maybe_null && mergeable_succ, ) } @@ -747,8 +754,19 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let (fn_abi, llfn, instance) = common::build_langcall(bx, span, lang_item); // Codegen the actual panic invoke/call. - let merging_succ = - helper.do_call(self, bx, fn_abi, llfn, &args, None, unwind, &[], Some(instance), false); + let merging_succ = helper.do_call( + self, + bx, + fn_abi, + llfn, + &args, + None, + unwind, + &[], + Some(instance), + false, + false, + ); assert_eq!(merging_succ, MergingSucc::False); MergingSucc::False } @@ -778,6 +796,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { &[], Some(instance), false, + false, ); assert_eq!(merging_succ, MergingSucc::False); } @@ -845,6 +864,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { unwind, &[], Some(instance), + false, mergeable_succ, )) } @@ -860,6 +880,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { target: Option, unwind: mir::UnwindAction, fn_span: Span, + tail: bool, mergeable_succ: bool, ) -> MergingSucc { let source_info = mir::SourceInfo { span: fn_span, ..terminator.source_info }; @@ -1003,8 +1024,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // We still need to call `make_return_dest` even if there's no `target`, since // `fn_abi.ret` could be `PassMode::Indirect`, even if it is uninhabited, // and `make_return_dest` adds the return-place indirect pointer to `llargs`. - let return_dest = self.make_return_dest(bx, destination, &fn_abi.ret, &mut llargs); - let destination = target.map(|target| (return_dest, target)); + let destination = if !tail { + let return_dest = self.make_return_dest(bx, destination, &fn_abi.ret, &mut llargs); + target.map(|target| (return_dest, target)) + } else { + None + }; // Split the rust-call tupled arguments off. let (first_args, untuple) = if sig.abi() == ExternAbi::RustCall @@ -1020,6 +1045,13 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // to generate `lifetime_end` when the call returns. let mut lifetime_ends_after_call: Vec<(Bx::Value, Size)> = Vec::new(); 'make_args: for (i, arg) in first_args.iter().enumerate() { + if tail && matches!(fn_abi.args[i].mode, PassMode::Indirect { .. }) { + span_bug!( + fn_span, + "arguments using PassMode::Indirect are currently not supported for tail calls" + ); + } + let mut op = self.codegen_operand(bx, &arg.node); if let (0, Some(ty::InstanceKind::Virtual(_, idx))) = (i, instance.map(|i| i.def)) { @@ -1147,6 +1179,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { unwind, &lifetime_ends_after_call, instance, + tail, mergeable_succ, ) } @@ -1388,15 +1421,23 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { target, unwind, fn_span, + false, mergeable_succ(), ), - mir::TerminatorKind::TailCall { .. } => { - // FIXME(explicit_tail_calls): implement tail calls in ssa backend - span_bug!( - terminator.source_info.span, - "`TailCall` terminator is not yet supported by `rustc_codegen_ssa`" - ) - } + mir::TerminatorKind::TailCall { ref func, ref args, fn_span } => self + .codegen_call_terminator( + helper, + bx, + terminator, + func, + args, + mir::Place::from(mir::RETURN_PLACE), + None, + mir::UnwindAction::Unreachable, + fn_span, + true, + mergeable_succ(), + ), mir::TerminatorKind::CoroutineDrop | mir::TerminatorKind::Yield { .. } => { bug!("coroutine ops in codegen") } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 979456a6ba70f..4b18146863bf4 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -595,6 +595,18 @@ pub trait BuilderMethods<'a, 'tcx>: funclet: Option<&Self::Funclet>, instance: Option>, ) -> Self::Value; + + fn tail_call( + &mut self, + llty: Self::Type, + fn_attrs: Option<&CodegenFnAttrs>, + fn_abi: &FnAbi<'tcx, Ty<'tcx>>, + llfn: Self::Value, + args: &[Self::Value], + funclet: Option<&Self::Funclet>, + instance: Option>, + ); + fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value; fn apply_attrs_to_cleanup_callsite(&mut self, llret: Self::Value); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 82568ed4ae177..653ecfdf9b339 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -1986,3 +1986,29 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) { MD.NoHWAddress = true; GV.setSanitizerMetadata(MD); } + +enum LLVMRustTailCallKind { + LLVMRustTailCallKindNone = 0, + LLVMRustTailCallKindTail = 1, + LLVMRustTailCallKindMustTail = 2, + LLVMRustTailCallKindNoTail = 3 +}; + +extern "C" void LLVMRustSetTailCallKind(LLVMValueRef Call, + LLVMRustTailCallKind Kind) { + CallInst *CI = unwrap(Call); + switch (Kind) { + case LLVMRustTailCallKindNone: + CI->setTailCallKind(CallInst::TCK_None); + break; + case LLVMRustTailCallKindTail: + CI->setTailCallKind(CallInst::TCK_Tail); + break; + case LLVMRustTailCallKindMustTail: + CI->setTailCallKind(CallInst::TCK_MustTail); + break; + case LLVMRustTailCallKindNoTail: + CI->setTailCallKind(CallInst::TCK_NoTail); + break; + } +} diff --git a/tests/codegen-llvm/become-musttail.rs b/tests/codegen-llvm/become-musttail.rs new file mode 100644 index 0000000000000..07f3357191042 --- /dev/null +++ b/tests/codegen-llvm/become-musttail.rs @@ -0,0 +1,18 @@ +//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes +//@ needs-unwind + +#![crate_type = "lib"] +#![feature(explicit_tail_calls)] + +// CHECK-LABEL: define {{.*}}@fibonacci( +#[no_mangle] +#[inline(never)] +pub fn fibonacci(n: u64, a: u64, b: u64) -> u64 { + // CHECK: musttail call {{.*}}@fibonacci( + // CHECK-NEXT: ret i64 + match n { + 0 => a, + 1 => b, + _ => become fibonacci(n - 1, b, a + b), + } +} diff --git a/tests/ui/explicit-tail-calls/recursion-tce.rs b/tests/ui/explicit-tail-calls/recursion-tce.rs new file mode 100644 index 0000000000000..8c89ceb7869aa --- /dev/null +++ b/tests/ui/explicit-tail-calls/recursion-tce.rs @@ -0,0 +1,17 @@ +//@ run-pass +#![expect(incomplete_features)] +#![feature(explicit_tail_calls)] + +use std::hint::black_box; + +pub fn count(curr: u64, top: u64) -> u64 { + if black_box(curr) >= top { + curr + } else { + become count(curr + 1, top) + } +} + +fn main() { + println!("{}", count(0, black_box(1000000))); +}