Skip to content

Offload device1 #142696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ pub(crate) fn codegen(
// binaries. So we must clone the module to produce the asm output
// if we are also producing object code.
let llmod = if let EmitObj::ObjectCode(_) = config.emit_obj {
llvm::LLVMCloneModule(llmod)
unsafe { llvm::LLVMCloneModule(llmod) }
} else {
llmod
};
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::ops::Deref;
use std::{iter, ptr};

pub(crate) mod autodiff;
pub(crate) mod gpu_device;
pub(crate) mod gpu_offload;
pub(crate) mod gpu_wrapper;

use libc::{c_char, c_uint, size_t};
use rustc_abi as abi;
Expand Down
113 changes: 113 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder/gpu_device.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::ffi::{CString, c_uint};

use llvm::Linkage::*;
use rustc_codegen_ssa::back::write::CodegenContext;

use crate::llvm::{self, Linkage};
use crate::{LlvmCodegenBackend, SimpleCx};

fn add_unnamed_global_in_addrspace<'ll>(
cx: &SimpleCx<'ll>,
name: &str,
initializer: &'ll llvm::Value,
l: Linkage,
addrspace: u32,
) -> &'ll llvm::Value {
let llglobal = add_global_in_addrspace(cx, name, initializer, l, addrspace);
llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
llglobal
}

pub(crate) fn add_global_in_addrspace<'ll>(
cx: &SimpleCx<'ll>,
name: &str,
initializer: &'ll llvm::Value,
l: Linkage,
addrspace: u32,
) -> &'ll llvm::Value {
let c_name = CString::new(name).unwrap();
let llglobal: &'ll llvm::Value = llvm::add_global_in_addrspace(
cx.llmod,
cx.val_ty(initializer),
&c_name,
addrspace as c_uint,
);
llvm::set_global_constant(llglobal, true);
llvm::set_linkage(llglobal, l);
llvm::set_initializer(llglobal, initializer);
llglobal
}

#[allow(unused)]
pub(crate) fn gen_asdf<'ll>(cgcx: &CodegenContext<LlvmCodegenBackend>, _old_cx: &SimpleCx<'ll>) {
let llcx = unsafe { llvm::LLVMRustContextCreate(false) };
let module_name = CString::new("offload.wrapper.module").unwrap();
let llmod = unsafe { llvm::LLVMModuleCreateWithNameInContext(module_name.as_ptr(), llcx) };
let cx = SimpleCx::new(llmod, llcx, cgcx.pointer_size);
let initializer = cx.get_const_i32(0);
add_unnamed_global_in_addrspace(&cx, "__omp_rtl_debug_kind", initializer, WeakODRLinkage, 1);
add_unnamed_global_in_addrspace(
&cx,
"__omp_rtl_assume_teams_oversubscription",
initializer,
WeakODRLinkage,
1,
);
add_unnamed_global_in_addrspace(
&cx,
"__omp_rtl_assume_threads_oversubscription",
initializer,
WeakODRLinkage,
1,
);
add_unnamed_global_in_addrspace(
&cx,
"__omp_rtl_assume_no_thread_state",
initializer,
WeakODRLinkage,
1,
);
add_unnamed_global_in_addrspace(
&cx,
"__oclc_ABI_version",
cx.get_const_i32(500),
WeakODRLinkage,
4,
);
unsafe {
llvm::LLVMPrintModuleToFile(
llmod,
CString::new("rustmagic-openmp-amdgcn-amd-amdhsa-gfx90a.ll").unwrap().as_ptr(),
std::ptr::null_mut(),
);

// Clean up
llvm::LLVMDisposeModule(llmod);
llvm::LLVMContextDispose(llcx);
}
// TODO: addressspace 1 or 4

Check failure on line 88 in compiler/rustc_codegen_llvm/src/builder/gpu_device.rs

View workflow job for this annotation

GitHub Actions / PR - tidy

TODO is used for tasks that should be done before merging a PR; If you want to leave a message in the codebase use FIXME

Check failure on line 88 in compiler/rustc_codegen_llvm/src/builder/gpu_device.rs

View workflow job for this annotation

GitHub Actions / PR - aarch64-gnu-llvm-19-2

TODO is used for tasks that should be done before merging a PR; If you want to leave a message in the codebase use FIXME
}
// source_filename = "mem.cpp"
// GPU: target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"
// CPU: target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
// target triple = "amdgcn-amd-amdhsa"
//
// @__omp_rtl_debug_kind = weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
// @__omp_rtl_assume_teams_oversubscription = weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
// @__omp_rtl_assume_threads_oversubscription = weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
// @__omp_rtl_assume_no_thread_state = weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
// @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
// @__oclc_ABI_version = weak_odr hidden local_unnamed_addr addrspace(4) constant i32 500
//
// !llvm.module.flags = !{!0, !1, !2, !3, !4}
// !opencl.ocl.version = !{!5}
// !llvm.ident = !{!6, !7}
//
// !0 = !{i32 1, !"amdhsa_code_object_version", i32 500}
// !1 = !{i32 1, !"wchar_size", i32 4}
// !2 = !{i32 7, !"openmp", i32 51}
// !3 = !{i32 7, !"openmp-device", i32 51}
// !4 = !{i32 8, !"PIC Level", i32 2}
// !5 = !{i32 2, i32 0}
// !6 = !{!"clang version 20.1.5-rust-1.89.0-nightly (https://github.com/rust-lang/llvm-project.git c1118fdbb3024157df7f4cfe765f2b0b4339e8a2)"}
// !7 = !{!"AMD clang version 19.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-6.4.0 25133 c7fe45cf4b819c5991fe208aaa96edf142730f1d)"}
119 changes: 108 additions & 11 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,40 @@
use crate::{LlvmCodegenBackend, SimpleCx, attributes};

pub(crate) fn handle_gpu_code<'ll>(
_cgcx: &CodegenContext<LlvmCodegenBackend>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
cx: &'ll SimpleCx<'_>,
) {
// The offload memory transfer type for each kernel
let mut o_types = vec![];
let mut kernels = vec![];
let mut region_ids = vec![];
let offload_entry_ty = add_tgt_offload_entry(&cx);
for num in 0..9 {
let kernel = cx.get_function(&format!("kernel_{num}"));
if let Some(kernel) = kernel {
o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
o_types.push(o);
region_ids.push(k);
kernels.push(kernel);
}
}
gen_call_handling(&cx, &kernels, &o_types, &region_ids);
crate::builder::gpu_wrapper::gen_image_wrapper_module(&cgcx);
}

gen_call_handling(&cx, &kernels, &o_types);
// ; Function Attrs: nounwind
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
let ti32 = cx.type_i32();
let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
let tgt_fn_ty = cx.type_func(&args, ti32);
let name = "__tgt_target_kernel";
let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
(tgt_decl, tgt_fn_ty)
}

// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
Expand Down Expand Up @@ -83,7 +101,7 @@
offload_entry_ty
}

fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, Vec<&'ll llvm::Type>) {
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
Expand Down Expand Up @@ -118,9 +136,10 @@
vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
(kernel_arguments_ty, kernel_elements)
// For now we don't handle kernels, so for now we just add a global dummy
// to make sure that the __tgt_offload_entry is defined and handled correctly.
cx.declare_global("my_struct_global2", kernel_arguments_ty);
//cx.declare_global("my_struct_global2", kernel_arguments_ty);
}

fn gen_tgt_data_mappers<'ll>(
Expand Down Expand Up @@ -187,7 +206,7 @@
kernel: &'ll llvm::Value,
offload_entry_ty: &'ll llvm::Type,
num: i64,
) -> &'ll llvm::Value {
) -> (&'ll llvm::Value, &'ll llvm::Value) {
let types = cx.func_params_types(cx.get_type_of_global(kernel));
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
// reference) types.
Expand Down Expand Up @@ -245,10 +264,10 @@
llvm::set_alignment(llglobal, Align::ONE);
let c_section_name = CString::new(".omp_offloading_entries").unwrap();
llvm::set_section(llglobal, &c_section_name);
o_types
(o_types, region_id)
}

fn declare_offload_fn<'ll>(
pub(crate) fn declare_offload_fn<'ll>(
cx: &'ll SimpleCx<'_>,
name: &str,
ty: &'ll llvm::Type,
Expand Down Expand Up @@ -287,15 +306,17 @@
cx: &'ll SimpleCx<'_>,
_kernels: &[&'ll llvm::Value],
o_types: &[&'ll llvm::Value],
region_ids: &[&'ll llvm::Value],
) {
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
let tptr = cx.type_ptr();
let ti32 = cx.type_i32();
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);

gen_tgt_kernel_global(&cx);
let (tgt_kernel_decl, tgt_kernel_types) = gen_tgt_kernel_global(&cx);
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);

let main_fn = cx.get_function("main");
Expand Down Expand Up @@ -329,6 +350,10 @@
// These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
let ty2 = cx.type_array(cx.type_i64(), num_args);
let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");

//%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

// Now we allocate once per function param, a copy to be passed to one of our maps.
let mut vals = vec![];
let mut geps = vec![];
Expand All @@ -341,7 +366,8 @@
let arg_name = format!("{name}.addr");
let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);

builder.store(p, alloca, Align::EIGHT);
let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
builder.store(v, alloca, Align::EIGHT);
let val = builder.load(in_ty, alloca, Align::EIGHT);
let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
vals.push(val);
Expand Down Expand Up @@ -421,16 +447,87 @@

// Step 3)
// Here we will add code for the actual kernel launches in a follow-up PR.
//%28 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0
//store i32 3, ptr %28, align 4
//%29 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1
//store i32 3, ptr %29, align 4
//%30 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2
//store ptr %26, ptr %30, align 8
//%31 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3
//store ptr %27, ptr %31, align 8
//%32 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4
//store ptr @.offload_sizes, ptr %32, align 8
//%33 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5
//store ptr @.offload_maptypes, ptr %33, align 8
//%34 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6
//store ptr null, ptr %34, align 8
//%35 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7
//store ptr null, ptr %35, align 8
//%36 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8
//store i64 0, ptr %36, align 8
//%37 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9
//store i64 0, ptr %37, align 8
//%38 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10
//store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %38, align 4
//%39 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11
//store [3 x i32] [i32 256, i32 0, i32 0], ptr %39, align 4
//%40 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12
//store i32 0, ptr %40, align 4
// FIXME(offload): launch kernels
let mut values = vec![];
values.push((4, cx.get_const_i32(3)));
values.push((4, cx.get_const_i32(3)));
values.push((8, geps.0));
values.push((8, geps.1));
values.push((8, geps.2));
values.push((8, o_types[0]));
values.push((8, cx.const_null(cx.type_ptr())));
values.push((8, cx.const_null(cx.type_ptr())));
values.push((8, cx.get_const_i64(0)));
values.push((8, cx.get_const_i64(0)));
let ti32 = cx.type_i32();
let ci32_0 = cx.get_const_i32(0);
values.push((8, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
values.push((8, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
values.push((4, cx.get_const_i32(0)));

for (i, value) in values.iter().enumerate() {
let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
}

let args = vec![
s_ident_t,
// MAX == -1
cx.get_const_i64(u64::MAX),
cx.get_const_i32(2097152),
cx.get_const_i32(256),
region_ids[0],
a5,
];
let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
unsafe {
let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
dbg!(&next);

Check failure on line 512 in compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

View workflow job for this annotation

GitHub Actions / PR - tidy

`dbg!` macro is intended as a debugging tool. It should not be in version control.

Check failure on line 512 in compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

View workflow job for this annotation

GitHub Actions / PR - aarch64-gnu-llvm-19-2

`dbg!` macro is intended as a debugging tool. It should not be in version control.
llvm::LLVMRustPositionAfter(builder.llbuilder, next);
let called_kernel = llvm::LLVMGetCalledValue(next).unwrap();
llvm::LLVMInstructionEraseFromParent(next);
dbg!(&called_kernel);

Check failure on line 516 in compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

View workflow job for this annotation

GitHub Actions / PR - tidy

`dbg!` macro is intended as a debugging tool. It should not be in version control.

Check failure on line 516 in compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

View workflow job for this annotation

GitHub Actions / PR - aarch64-gnu-llvm-19-2

`dbg!` macro is intended as a debugging tool. It should not be in version control.
}

// Step 4)
unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
//unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };

let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);

builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);

drop(builder);
unsafe { llvm::LLVMDeleteFunction(called) };
dbg!("survived");

Check failure on line 529 in compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

View workflow job for this annotation

GitHub Actions / PR - tidy

`dbg!` macro is intended as a debugging tool. It should not be in version control.

Check failure on line 529 in compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

View workflow job for this annotation

GitHub Actions / PR - aarch64-gnu-llvm-19-2

`dbg!` macro is intended as a debugging tool. It should not be in version control.

// With this we generated the following begin and end mappers. We could easily generate the
// update mapper in an update.
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
Expand Down
Loading
Loading