From fc1b0b5f4f495354c10dbbc818f5b3ba61e39cd2 Mon Sep 17 00:00:00 2001 From: Michal Hornicky Date: Thu, 17 Jul 2025 11:07:10 +0200 Subject: [PATCH] Propagate musttail attribute to LLVM backend for guaranteed tail calls This commit implements guaranteed tail calls for the `become` expression by marking the emitted call with LLVM's `musttail` tail-call marker, which requires the call to be lowered as a tail call rather than leaving it as an optimization hint (as the plain `tail` marker does). Changes: - Add `set_tail_call` method to BuilderMethods trait - Add FFI wrapper LLVMRustSetTailCallKind to access LLVM's setTailCallKind API - Implement tail call handling in LLVM backend using musttail - Implement TailCall terminator codegen in rustc_codegen_ssa - Make GCC backend fail explicitly on tail calls (not yet supported) - Add codegen tests to verify musttail is properly emitted - Add runtime tests for deep recursion and mutual recursion The `musttail` marker is critical for the correctness of the `become` expression: it guarantees that the call does not grow the stack, preventing stack overflow in deeply recursive code. 
--- compiler/rustc_codegen_gcc/src/builder.rs | 7 ++ compiler/rustc_codegen_llvm/src/builder.rs | 5 + compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 12 +++ compiler/rustc_codegen_ssa/src/mir/block.rs | 100 ++++++++++++++++-- .../rustc_codegen_ssa/src/traits/builder.rs | 4 + .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 26 +++++ tests/codegen/tail-call-become.rs | 68 ++++++++++++ tests/codegen/tail-call-musttail.rs | 33 ++++++ .../explicit-tail-calls/llvm-ir-tail-call.rs | 35 ++++++ .../explicit-tail-calls/mutual-recursion.rs | 63 +++++++++++ .../runtime-deep-recursion.rs | 38 +++++++ 11 files changed, 385 insertions(+), 6 deletions(-) create mode 100644 tests/codegen/tail-call-become.rs create mode 100644 tests/codegen/tail-call-musttail.rs create mode 100644 tests/ui/explicit-tail-calls/llvm-ir-tail-call.rs create mode 100644 tests/ui/explicit-tail-calls/mutual-recursion.rs create mode 100644 tests/ui/explicit-tail-calls/runtime-deep-recursion.rs diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs index 28d1ec7d89564..1e368cd774b68 100644 --- a/compiler/rustc_codegen_gcc/src/builder.rs +++ b/compiler/rustc_codegen_gcc/src/builder.rs @@ -1751,6 +1751,13 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { // FIXME(bjorn3): implement } + fn set_tail_call(&mut self, _call_inst: RValue<'gcc>) { + // Explicitly fail when this method is called + bug!( + "Guaranteed tail calls with the 'become' keyword are not implemented in the GCC backend" + ); + } + fn set_span(&mut self, _span: Span) {} fn from_immediate(&mut self, val: Self::Value) -> Self::Value { diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 514923ad6f37f..89c3182030b10 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -1371,6 +1371,11 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let cold_inline 
= llvm::AttributeKind::Cold.create_attr(self.llcx); attributes::apply_to_callsite(llret, llvm::AttributePlace::Function, &[cold_inline]); } + + fn set_tail_call(&mut self, call_inst: &'ll Value) { + // Use musttail for guaranteed tail call optimization required by 'become' + llvm::LLVMRustSetTailCallKind(call_inst, llvm::TailCallKind::MustTail); + } } impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> { diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 0b1e632cbc42c..09743e3f1dc02 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -97,6 +97,17 @@ pub(crate) enum ModuleFlagMergeBehavior { // Consts for the LLVM CallConv type, pre-cast to usize. +/// LLVM TailCallKind for musttail support +#[derive(Copy, Clone, PartialEq, Debug)] +#[repr(C)] +#[allow(dead_code)] +pub(crate) enum TailCallKind { + None = 0, + Tail = 1, + MustTail = 2, + NoTail = 3, +} + /// LLVM CallingConv::ID. Should we wrap this? /// /// See @@ -1181,6 +1192,7 @@ unsafe extern "C" { pub(crate) safe fn LLVMIsGlobalConstant(GlobalVar: &Value) -> Bool; pub(crate) safe fn LLVMSetGlobalConstant(GlobalVar: &Value, IsConstant: Bool); pub(crate) safe fn LLVMSetTailCall(CallInst: &Value, IsTailCall: Bool); + pub(crate) safe fn LLVMRustSetTailCallKind(CallInst: &Value, Kind: TailCallKind); // Operations on attributes pub(crate) fn LLVMCreateStringAttribute( diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index bde63fd501aa2..b8df6a5aae0a3 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -342,6 +342,97 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> { /// Codegen implementations for some terminator variants. 
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { + fn codegen_tail_call_terminator( + &mut self, + bx: &mut Bx, + func: &mir::Operand<'tcx>, + args: &[Spanned>], + fn_span: Span, + ) { + // We don't need source_info as we already have fn_span for diagnostics + let func = self.codegen_operand(bx, func); + let fn_ty = func.layout.ty; + + // Create the callee. This is a fn ptr or zero-sized and hence a kind of scalar. + let (fn_ptr, fn_abi, instance) = match *fn_ty.kind() { + ty::FnDef(def_id, substs) => { + let instance = ty::Instance::expect_resolve( + bx.tcx(), + bx.typing_env(), + def_id, + substs, + fn_span, + ); + let fn_ptr = bx.get_fn_addr(instance); + let fn_abi = bx.fn_abi_of_instance(instance, ty::List::empty()); + (fn_ptr, fn_abi, Some(instance)) + } + ty::FnPtr(..) => { + let sig = fn_ty.fn_sig(bx.tcx()); + let extra_args = bx.tcx().mk_type_list(&[]); + let fn_ptr = func.immediate(); + let fn_abi = bx.fn_abi_of_fn_ptr(sig, extra_args); + (fn_ptr, fn_abi, None) + } + _ => bug!("{} is not callable", func.layout.ty), + }; + + let mut llargs = Vec::with_capacity(args.len()); + let mut lifetime_ends_after_call = Vec::new(); + + // Process arguments + for arg in args { + let op = self.codegen_operand(bx, &arg.node); + let arg_idx = llargs.len(); + + if arg_idx < fn_abi.args.len() { + self.codegen_argument( + bx, + op, + &mut llargs, + &fn_abi.args[arg_idx], + &mut lifetime_ends_after_call, + ); + } else { + // This can happen in case of C-variadic functions + let is_immediate = match op.val { + Immediate(_) => true, + _ => false, + }; + + if is_immediate { + llargs.push(op.immediate()); + } else { + let temp = PlaceRef::alloca(bx, op.layout); + op.val.store(bx, temp); + llargs.push(bx.load( + bx.backend_type(op.layout), + temp.val.llval, + temp.val.align, + )); + } + } + } + + // Call the function + let fn_ty = bx.fn_decl_backend_type(fn_abi); + let fn_attrs = if let Some(instance) = instance + && 
bx.tcx().def_kind(instance.def_id()).has_codegen_attrs() + { + Some(bx.tcx().codegen_fn_attrs(instance.def_id())) + } else { + None + }; + + // Perform the actual function call + let llret = bx.call(fn_ty, fn_attrs, Some(fn_abi), fn_ptr, &llargs, None, instance); + + // Mark as tail call - this is the critical part + bx.set_tail_call(llret); + + // Return the result - musttail requires ret immediately after the call + bx.ret(llret); + } /// Generates code for a `Resume` terminator. fn codegen_resume_terminator(&mut self, helper: TerminatorCodegenHelper<'tcx>, bx: &mut Bx) { if let Some(funclet) = helper.funclet(self) { @@ -1390,12 +1481,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { fn_span, mergeable_succ(), ), - mir::TerminatorKind::TailCall { .. } => { - // FIXME(explicit_tail_calls): implement tail calls in ssa backend - span_bug!( - terminator.source_info.span, - "`TailCall` terminator is not yet supported by `rustc_codegen_ssa`" - ) + mir::TerminatorKind::TailCall { ref func, ref args, fn_span } => { + self.codegen_tail_call_terminator(bx, func, args, fn_span); + MergingSucc::False } mir::TerminatorKind::CoroutineDrop | mir::TerminatorKind::Yield { .. 
} => { bug!("coroutine ops in codegen") diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 979456a6ba70f..c8eed8e660af2 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -595,6 +595,10 @@ pub trait BuilderMethods<'a, 'tcx>: funclet: Option<&Self::Funclet>, instance: Option>, ) -> Self::Value; + + /// Mark a call instruction as a tail call (guaranteed tail call optimization) + /// Used for implementing the `become` expression + fn set_tail_call(&mut self, call_inst: Self::Value); fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value; fn apply_attrs_to_cleanup_callsite(&mut self, llret: Self::Value); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 90aa9188c8300..ad2e0c0ad5baa 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -51,6 +51,14 @@ // //===----------------------------------------------------------------------=== +// Define TailCallKind enum values to match LLVM's +enum LLVMRustTailCallKind { + LLVMRustTailCallKindNone = 0, + LLVMRustTailCallKindTail = 1, + LLVMRustTailCallKindMustTail = 2, + LLVMRustTailCallKindNoTail = 3 +}; + using namespace llvm; using namespace llvm::sys; using namespace llvm::object; @@ -1949,3 +1957,21 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) { MD.NoHWAddress = true; GV.setSanitizerMetadata(MD); } + +extern "C" void LLVMRustSetTailCallKind(LLVMValueRef Call, LLVMRustTailCallKind Kind) { + CallInst *CI = unwrap(Call); + switch (Kind) { + case LLVMRustTailCallKindNone: + CI->setTailCallKind(CallInst::TCK_None); + break; + case LLVMRustTailCallKindTail: + CI->setTailCallKind(CallInst::TCK_Tail); + break; + case LLVMRustTailCallKindMustTail: + CI->setTailCallKind(CallInst::TCK_MustTail); + break; + case 
LLVMRustTailCallKindNoTail: + CI->setTailCallKind(CallInst::TCK_NoTail); + break; + } +} diff --git a/tests/codegen/tail-call-become.rs b/tests/codegen/tail-call-become.rs new file mode 100644 index 0000000000000..de634d9347506 --- /dev/null +++ b/tests/codegen/tail-call-become.rs @@ -0,0 +1,68 @@ +//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes +//@ needs-llvm-components: x86 + +#![feature(explicit_tail_calls)] +#![crate_type = "lib"] + +// CHECK-LABEL: define {{.*}}@with_tail( +#[no_mangle] +#[inline(never)] +pub fn with_tail(n: u32) -> u32 { + // CHECK: tail call {{.*}}@with_tail( + if n == 0 { 0 } else { become with_tail(n - 1) } +} + +// CHECK-LABEL: define {{.*}}@no_tail( +#[no_mangle] +#[inline(never)] +pub fn no_tail(n: u32) -> u32 { + // CHECK-NOT: tail call + // CHECK: call {{.*}}@no_tail( + if n == 0 { 0 } else { no_tail(n - 1) } +} + +// CHECK-LABEL: define {{.*}}@even_with_tail( +#[no_mangle] +#[inline(never)] +pub fn even_with_tail(n: u32) -> bool { + // CHECK: tail call {{.*}}@odd_with_tail( + match n { + 0 => true, + _ => become odd_with_tail(n - 1), + } +} + +// CHECK-LABEL: define {{.*}}@odd_with_tail( +#[no_mangle] +#[inline(never)] +pub fn odd_with_tail(n: u32) -> bool { + // CHECK: tail call {{.*}}@even_with_tail( + match n { + 0 => false, + _ => become even_with_tail(n - 1), + } +} + +// CHECK-LABEL: define {{.*}}@even_no_tail( +#[no_mangle] +#[inline(never)] +pub fn even_no_tail(n: u32) -> bool { + // CHECK-NOT: tail call + // CHECK: call {{.*}}@odd_no_tail( + match n { + 0 => true, + _ => odd_no_tail(n - 1), + } +} + +// CHECK-LABEL: define {{.*}}@odd_no_tail( +#[no_mangle] +#[inline(never)] +pub fn odd_no_tail(n: u32) -> bool { + // CHECK-NOT: tail call + // CHECK: call {{.*}}@even_no_tail( + match n { + 0 => false, + _ => even_no_tail(n - 1), + } +} diff --git a/tests/codegen/tail-call-musttail.rs b/tests/codegen/tail-call-musttail.rs new file mode 100644 index 0000000000000..618a7eacb3c6e --- /dev/null +++ 
b/tests/codegen/tail-call-musttail.rs @@ -0,0 +1,33 @@ +//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes +//@ needs-unwind + +#![crate_type = "lib"] +#![feature(explicit_tail_calls)] + +// Ensure that explicit tail calls use musttail in LLVM + +// CHECK-LABEL: define {{.*}}@simple_tail_call( +#[no_mangle] +#[inline(never)] +pub fn simple_tail_call(n: i32) -> i32 { + // CHECK: musttail call {{.*}}@simple_tail_call( + // CHECK-NEXT: ret i32 + if n <= 0 { + 0 + } else { + become simple_tail_call(n - 1) + } +} + +// CHECK-LABEL: define {{.*}}@tail_call_with_args( +#[no_mangle] +#[inline(never)] +pub fn tail_call_with_args(a: i32, b: i32, c: i32) -> i32 { + // CHECK: musttail call {{.*}}@tail_call_with_args( + // CHECK-NEXT: ret i32 + if a == 0 { + b + c + } else { + become tail_call_with_args(a - 1, b + 1, c) + } +} \ No newline at end of file diff --git a/tests/ui/explicit-tail-calls/llvm-ir-tail-call.rs b/tests/ui/explicit-tail-calls/llvm-ir-tail-call.rs new file mode 100644 index 0000000000000..a0818fda47f96 --- /dev/null +++ b/tests/ui/explicit-tail-calls/llvm-ir-tail-call.rs @@ -0,0 +1,35 @@ +//@ compile-flags: -O +//@ run-pass +#![expect(incomplete_features)] +#![feature(explicit_tail_calls)] + +// A deep recursive function that uses explicit tail calls +// This will cause stack overflow without tail call optimization +fn deep_recursion(n: u32) -> u32 { + match n { + 0 => 0, + _ => become deep_recursion(n - 1) + } +} + +// A deep recursive function without explicit tail calls +// This will overflow the stack for large values +fn deep_recursion_no_tail(n: u32) -> u32 { + match n { + 0 => 0, + _ => deep_recursion_no_tail(n - 1) + } +} + +fn main() { + // Verify correctness for small values + assert_eq!(deep_recursion(10), 0); + assert_eq!(deep_recursion_no_tail(10), 0); + + // This will succeed only if tail call optimization is working + // It would overflow the stack otherwise + println!("Starting deep recursion with 50,000 calls"); + let 
result = deep_recursion(50_000); + assert_eq!(result, 0); + println!("Successfully completed 50,000 recursive calls with tail call optimization"); +} diff --git a/tests/ui/explicit-tail-calls/mutual-recursion.rs b/tests/ui/explicit-tail-calls/mutual-recursion.rs new file mode 100644 index 0000000000000..cb037ce2f20b0 --- /dev/null +++ b/tests/ui/explicit-tail-calls/mutual-recursion.rs @@ -0,0 +1,63 @@ +//@ run-pass +#![expect(incomplete_features)] +#![feature(explicit_tail_calls)] + +// This is a classic example of mutual recursion: +// even(n) calls odd(n-1), and odd(n) calls even(n-1) +// Without tail calls, this would quickly overflow the stack for large inputs + +// Check if a number is even using mutual recursion +fn is_even(n: u64) -> bool { + match n { + 0 => true, + _ => become is_odd(n - 1) + } +} + +// Check if a number is odd using mutual recursion +fn is_odd(n: u64) -> bool { + match n { + 0 => false, + _ => become is_even(n - 1) + } +} + +// Versions without tail calls for comparison +fn is_even_no_tail(n: u64) -> bool { + match n { + 0 => true, + _ => is_odd_no_tail(n - 1) + } +} + +fn is_odd_no_tail(n: u64) -> bool { + match n { + 0 => false, + _ => is_even_no_tail(n - 1) + } +} + +fn main() { + // Verify correctness for small values + assert_eq!(is_even(0), true); + assert_eq!(is_odd(0), false); + assert_eq!(is_even(1), false); + assert_eq!(is_odd(1), true); + assert_eq!(is_even(10), true); + assert_eq!(is_odd(10), false); + + // Test with an extremely large number that would definitely overflow the stack + // without tail call optimization - each call creates 2 stack frames (alternating between functions) + // so 100,000 would create 200,000 stack frames total + assert_eq!(is_even(100_000), true); + assert_eq!(is_odd(100_000), false); + assert_eq!(is_even(100_001), false); + assert_eq!(is_odd(100_001), true); + + println!("Deep mutual recursion test passed with 100,000 alternating recursive calls!"); + + // Verify non-tail versions work for small 
values + assert_eq!(is_even_no_tail(10), true); + assert_eq!(is_odd_no_tail(10), false); + // But would overflow for large values (not tested to avoid crashing) +} \ No newline at end of file diff --git a/tests/ui/explicit-tail-calls/runtime-deep-recursion.rs b/tests/ui/explicit-tail-calls/runtime-deep-recursion.rs new file mode 100644 index 0000000000000..5ab45c76875bc --- /dev/null +++ b/tests/ui/explicit-tail-calls/runtime-deep-recursion.rs @@ -0,0 +1,38 @@ +//@ run-pass +#![expect(incomplete_features)] +#![feature(explicit_tail_calls)] + +// A function that causes deep recursion - counts down from n to 0 +// Without tail calls, this would overflow the stack for large n values +fn countdown(n: u32) -> u32 { + match n { + 0 => 0, + _ => become countdown(n - 1) + } +} + +// Same function but without tail call optimization +fn countdown_no_tail(n: u32) -> u32 { + match n { + 0 => 0, + _ => countdown_no_tail(n - 1) + } +} + +// This test is specifically designed to verify tail call optimization +// We use an extremely large recursion depth (500,000) that would +// absolutely overflow the stack without tail call optimization +fn main() { + // Small test to verify correctness + assert_eq!(countdown(10), 0); + + // Regular recursion would overflow here (500,000 stack frames) + // Only works if tail call optimization is actually happening + let result = countdown(500_000); + assert_eq!(result, 0); + println!("Successfully completed 500,000 recursive calls with tail call optimization"); + + // We can't test the non-tail version with a large number as it would crash, + // but we can verify it works for small values + assert_eq!(countdown_no_tail(10), 0); +} \ No newline at end of file