Skip to content

Commit 36c37b0

Browse files
authored
[flang][OpenMP] Restore reduction processor behavior broken by #145837 (#150178)
Fixes #149089 and #149700. Before #145837, when processing a reduction symbol not yet supported by OpenMP lowering, the reduction processor would simply skip filling in the reduction symbols and variables. With #145837, this behvaior was slightly changed because the reduction symbols are populated before invoking the reduction processor (this is more convenient to shared the code with `do concurrent`). This PR restores the previous behavior.
1 parent 77b1b95 commit 36c37b0

File tree

5 files changed

+53
-21
lines changed

5 files changed

+53
-21
lines changed

flang/include/flang/Lower/Support/ReductionProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class ReductionProcessor {
124124
/// Creates a reduction declaration and associates it with an OpenMP block
125125
/// directive.
126126
template <typename OpType, typename RedOperatorListTy>
127-
static void processReductionArguments(
127+
static bool processReductionArguments(
128128
mlir::Location currentLocation, lower::AbstractConverter &converter,
129129
const RedOperatorListTy &redOperatorList,
130130
llvm::SmallVectorImpl<mlir::Value> &reductionVars,

flang/lib/Lower/Bridge.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2125,9 +2125,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
21252125

21262126
llvm::SmallVector<mlir::Value> reduceVars;
21272127
Fortran::lower::omp::ReductionProcessor rp;
2128-
rp.processReductionArguments<fir::DeclareReductionOp>(
2128+
bool result = rp.processReductionArguments<fir::DeclareReductionOp>(
21292129
toLocation(), *this, info.reduceOperatorList, reduceVars,
21302130
reduceVarByRef, reductionDeclSymbols, info.reduceSymList);
2131+
assert(result && "Failed to process `do concurrent` reductions");
21312132

21322133
doConcurrentLoopOp.getReduceVarsMutable().assign(reduceVars);
21332134
doConcurrentLoopOp.setReduceSymsAttr(

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,11 +1116,12 @@ bool ClauseProcessor::processInReduction(
11161116
collectReductionSyms(clause, inReductionSyms);
11171117

11181118
ReductionProcessor rp;
1119-
rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
1120-
currentLocation, converter,
1121-
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
1122-
inReductionVars, inReduceVarByRef, inReductionDeclSymbols,
1123-
inReductionSyms);
1119+
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
1120+
currentLocation, converter,
1121+
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
1122+
inReductionVars, inReduceVarByRef, inReductionDeclSymbols,
1123+
inReductionSyms))
1124+
inReductionSyms.clear();
11241125

11251126
// Copy local lists into the output.
11261127
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
@@ -1461,10 +1462,12 @@ bool ClauseProcessor::processReduction(
14611462
}
14621463

14631464
ReductionProcessor rp;
1464-
rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
1465-
currentLocation, converter,
1466-
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
1467-
reductionVars, reduceVarByRef, reductionDeclSymbols, reductionSyms);
1465+
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
1466+
currentLocation, converter,
1467+
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
1468+
reductionVars, reduceVarByRef, reductionDeclSymbols,
1469+
reductionSyms))
1470+
reductionSyms.clear();
14681471
// Copy local lists into the output.
14691472
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
14701473
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
@@ -1486,11 +1489,12 @@ bool ClauseProcessor::processTaskReduction(
14861489
collectReductionSyms(clause, taskReductionSyms);
14871490

14881491
ReductionProcessor rp;
1489-
rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
1490-
currentLocation, converter,
1491-
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
1492-
taskReductionVars, taskReduceVarByRef, taskReductionDeclSymbols,
1493-
taskReductionSyms);
1492+
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
1493+
currentLocation, converter,
1494+
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
1495+
taskReductionVars, taskReduceVarByRef, taskReductionDeclSymbols,
1496+
taskReductionSyms))
1497+
taskReductionSyms.clear();
14941498
// Copy local lists into the output.
14951499
llvm::copy(taskReductionVars,
14961500
std::back_inserter(result.taskReductionVars));

flang/lib/Lower/Support/ReductionProcessor.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace lower {
3939
namespace omp {
4040

4141
// explicit template declarations
42-
template void ReductionProcessor::processReductionArguments<
42+
template bool ReductionProcessor::processReductionArguments<
4343
mlir::omp::DeclareReductionOp, omp::clause::ReductionOperatorList>(
4444
mlir::Location currentLocation, lower::AbstractConverter &converter,
4545
const omp::clause::ReductionOperatorList &redOperatorList,
@@ -48,7 +48,7 @@ template void ReductionProcessor::processReductionArguments<
4848
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
4949
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
5050

51-
template void ReductionProcessor::processReductionArguments<
51+
template bool ReductionProcessor::processReductionArguments<
5252
fir::DeclareReductionOp, llvm::SmallVector<fir::ReduceOperationEnum>>(
5353
mlir::Location currentLocation, lower::AbstractConverter &converter,
5454
const llvm::SmallVector<fir::ReduceOperationEnum> &redOperatorList,
@@ -607,7 +607,7 @@ static bool doReductionByRef(mlir::Value reductionVar) {
607607
}
608608

609609
template <typename OpType, typename RedOperatorListTy>
610-
void ReductionProcessor::processReductionArguments(
610+
bool ReductionProcessor::processReductionArguments(
611611
mlir::Location currentLocation, lower::AbstractConverter &converter,
612612
const RedOperatorListTy &redOperatorList,
613613
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
@@ -627,10 +627,10 @@ void ReductionProcessor::processReductionArguments(
627627
std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
628628
if (!ReductionProcessor::supportedIntrinsicProcReduction(
629629
*reductionIntrinsic)) {
630-
return;
630+
return false;
631631
}
632632
} else {
633-
return;
633+
return false;
634634
}
635635
}
636636
}
@@ -765,6 +765,8 @@ void ReductionProcessor::processReductionArguments(
765765

766766
if (isDoConcurrent)
767767
builder.restoreInsertionPoint(dcIP);
768+
769+
return true;
768770
}
769771

770772
const semantics::SourceName
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
! Tests reduction processor behavior when a reduction symbol is not supported.
2+
3+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
4+
5+
subroutine foo
6+
implicit none
7+
integer :: k, i
8+
9+
interface max
10+
function max(m1,m2)
11+
integer :: m1, m2
12+
end function
13+
end interface
14+
15+
!$omp do reduction (max: k)
16+
do i=1,10
17+
end do
18+
!$omp end do
19+
end
20+
21+
! Verify that unsupported reduction is ignored.
22+
! CHECK: omp.wsloop
23+
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}
24+
! CHECK-SAME: -> %{{[^[:space:]]+}} : !{{[^[:space:]]+}}) {
25+
! CHECK: }

0 commit comments

Comments
 (0)