Skip to content

Commit 90ef114

Browse files
authored
[flang][cuda] Add cuf.set_allocator_idx for device component (#148750)
1 parent 7cde974 commit 90ef114

File tree

3 files changed

+97
-3
lines changed

3 files changed

+97
-3
lines changed

flang/include/flang/Semantics/tools.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,8 @@ bool IsPolymorphic(const Symbol &);
199199
bool IsUnlimitedPolymorphic(const Symbol &);
200200
bool IsPolymorphicAllocatable(const Symbol &);
201201

202+
// NOTE(review): the definition is not visible here. Judging by the name and by
// its use on derived-type components during lowering, this presumably reports
// whether `symbol` is an allocatable entity carrying the CUDA DEVICE
// attribute -- confirm against the implementation.
bool IsDeviceAllocatable(const Symbol &symbol);
203+
202204
inline bool IsCUDADeviceContext(const Scope *scope) {
203205
if (scope) {
204206
if (const Symbol * symbol{scope->symbol()}) {

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 74 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -771,9 +771,80 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
771771
return builder.create<cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
772772
indices);
773773

774-
if (!cuf::isCUDADeviceContext(builder.getRegion()))
775-
return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
776-
lenParams, indices);
774+
if (!cuf::isCUDADeviceContext(builder.getRegion())) {
775+
mlir::Value alloc = builder.create<cuf::AllocOp>(
776+
loc, ty, nm, symNm, dataAttr, lenParams, indices);
777+
if (const auto *details{
778+
ultimateSymbol
779+
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
780+
const Fortran::semantics::DeclTypeSpec *type{details->type()};
781+
const Fortran::semantics::DerivedTypeSpec *derived{
782+
type ? type->AsDerived() : nullptr};
783+
if (derived) {
784+
Fortran::semantics::UltimateComponentIterator components{*derived};
785+
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
786+
787+
mlir::Type fieldTy;
788+
llvm::SmallVector<mlir::Value> coordinates;
789+
for (const auto &sym : components) {
790+
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
791+
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
792+
mlir::Type fieldTy;
793+
std::vector<mlir::Value> coordinates;
794+
795+
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
796+
// Field found in the base record type.
797+
auto fieldName = recTy.getTypeList()[fieldIdx].first;
798+
fieldTy = recTy.getTypeList()[fieldIdx].second;
799+
mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
800+
loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
801+
recTy,
802+
/*typeParams=*/mlir::ValueRange{});
803+
coordinates.push_back(fieldIndex);
804+
} else {
805+
// Field not found in base record type, search in potential
806+
// record type components.
807+
for (auto component : recTy.getTypeList()) {
808+
if (auto childRecTy =
809+
mlir::dyn_cast<fir::RecordType>(component.second)) {
810+
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
811+
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
812+
mlir::Value parentFieldIndex =
813+
builder.create<fir::FieldIndexOp>(
814+
loc, fir::FieldType::get(childRecTy.getContext()),
815+
component.first, recTy,
816+
/*typeParams=*/mlir::ValueRange{});
817+
coordinates.push_back(parentFieldIndex);
818+
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
819+
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
820+
mlir::Value childFieldIndex =
821+
builder.create<fir::FieldIndexOp>(
822+
loc, fir::FieldType::get(fieldTy.getContext()),
823+
fieldName, childRecTy,
824+
/*typeParams=*/mlir::ValueRange{});
825+
coordinates.push_back(childFieldIndex);
826+
break;
827+
}
828+
}
829+
}
830+
}
831+
832+
if (coordinates.empty())
833+
TODO(loc, "device resident component in complex derived-type "
834+
"hierarchy");
835+
836+
mlir::Value comp = builder.create<fir::CoordinateOp>(
837+
loc, builder.getRefType(fieldTy), alloc, coordinates);
838+
cuf::DataAttributeAttr dataAttr =
839+
Fortran::lower::translateSymbolCUFDataAttribute(
840+
builder.getContext(), sym);
841+
builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
842+
}
843+
}
844+
}
845+
}
846+
return alloc;
847+
}
777848
}
778849

779850
// Let the builder do all the heavy lifting.
Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,21 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
module m1
4+
type ty_device
5+
integer, device, allocatable, dimension(:) :: x
6+
integer :: y
7+
integer, device, allocatable, dimension(:) :: z
8+
end type
9+
contains
10+
subroutine sub1()
11+
type(ty_device) :: a
12+
end subroutine
13+
14+
! CHECK-LABEL: func.func @_QMm1Psub1()
15+
! CHECK: %[[DT:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
16+
! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]], x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
17+
! CHECK: cuf.set_allocator_idx %[[X]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
18+
! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]], z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
19+
! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
20+
21+
end module

0 commit comments

Comments
 (0)