Skip to content

Commit 5828b55

Browse files
committed
Temporarily enhance the remove-layout implementation to reduce duplicated values with different layouts in scf.for.
Signed-off-by: Lu,Chengjun <chengjun.lu@intel.com>
1 parent 92cff48 commit 5828b55

File tree

3 files changed

+75
-34
lines changed

3 files changed

+75
-34
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ LogicalResult getConvertBackwardSlice(
5050
DenseMap<Value, Attribute> &layout,
5151
std::function<bool(Operation *)> stopPropagation = nullptr,
5252
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
53-
nullptr);
53+
nullptr,
54+
bool includeForOp = false);
5455

5556
LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
5657
ArrayRef<Type> paramTypes,

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,22 @@ class LayoutRematerialization {
157157
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
158158
SetVector<Value> &slice,
159159
DenseMap<Value, Attribute> &layout,
160-
std::function<bool(Operation *)> stopPropagation);
160+
std::function<bool(Operation *)> stopPropagation,
161+
bool includeForOp = false);
161162

162163
LogicalResult getRematerializableSlice(
163164
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
164165
DenseMap<Value, Attribute> &layout,
165-
std::function<bool(Operation *)> stopPropagation = nullptr);
166+
std::function<bool(Operation *)> stopPropagation = nullptr,
167+
bool includeForOp = false);
166168

167169
private:
168170
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
169171
// Existing tuples of (value, layout) that need to be updated when recreating
170172
// scf ops. This prevents keeping track of Values that have been deleted when
171-
// rewriting slices.
172-
DenseMap<Value, Attribute> mappedValues;
173+
// rewriting slices. The Value may be mapped to different attributes in remove
174+
// layout.
175+
DenseMap<Value, SmallVector<Attribute>> mappedValues;
173176
// map of the values remat based on encoding.
174177
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
175178
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -183,7 +186,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
183186
Value newV) {
184187
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
185188
rematMapping[{old, encoding}] = newV;
186-
mappedValues[old] = encoding;
189+
if (mappedValues.contains(old)) {
190+
mappedValues[old].push_back(encoding);
191+
} else {
192+
mappedValues[old] = {encoding};
193+
}
187194
}
188195

189196
// Remove unneeded values now that we are done with the rematMapping.
@@ -988,22 +995,28 @@ void LayoutRematerialization::updateRematMapping(
988995
for (auto [old, newV] : values) {
989996
auto it = mappedValues.find(old);
990997
if (it != mappedValues.end()) {
991-
Attribute encoding = it->second;
992-
auto rematIt = rematMapping.find({old, it->second});
993-
assert(rematIt != rematMapping.end());
994-
Value replacedValue = rematIt->second;
995-
rematMapping.erase(rematIt);
996-
mappedValues.erase(it);
997-
// Loop through the replacement value to find the new version of remat
998-
// value. This should be okay as the number of values should be small.
999-
for (auto [before, after] : values) {
1000-
if (before == replacedValue) {
1001-
replacedValue = after;
1002-
break;
998+
SmallVector<Attribute> encodings = it->second;
999+
for (auto encoding : encodings) {
1000+
auto rematIt = rematMapping.find({old, encoding});
1001+
assert(rematIt != rematMapping.end());
1002+
Value replacedValue = rematIt->second;
1003+
rematMapping.erase(rematIt);
1004+
// Loop through the replacement value to find the new version of remat
1005+
// value. This should be okay as the number of values should be small.
1006+
for (auto [before, after] : values) {
1007+
if (before == replacedValue) {
1008+
replacedValue = after;
1009+
break;
1010+
}
10031011
}
1012+
rematMapping[{newV, encoding}] = replacedValue;
1013+
}
1014+
mappedValues.erase(it);
1015+
if (mappedValues.contains(newV)) {
1016+
mappedValues[newV].append(encodings);
1017+
} else {
1018+
mappedValues[newV] = std::move(encodings);
10041019
}
1005-
rematMapping[{newV, encoding}] = replacedValue;
1006-
mappedValues[newV] = encoding;
10071020
}
10081021
}
10091022
}
@@ -1078,6 +1091,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10781091
deadOps.push_back(forOp.getOperation());
10791092
Block &loopBody = *newForOp.getBody();
10801093
for (auto m : argMapping) {
1094+
mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
10811095
mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
10821096
int numIndVars = newForOp.getNumInductionVars();
10831097
mapping.map(loopBody.getArgument(m.first + numIndVars),
@@ -1188,8 +1202,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11881202
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
11891203
}
11901204

1191-
for (Operation *op : deadOps)
1192-
opToDelete.insert(op);
1205+
for (Operation *op : deadOps) {
1206+
if (!isa<scf::ForOp>(op))
1207+
opToDelete.insert(op);
1208+
else
1209+
op->erase();
1210+
}
11931211
}
11941212

11951213
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
@@ -1202,7 +1220,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
12021220
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12031221
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12041222
DenseMap<Value, Attribute> &layout,
1205-
std::function<bool(Operation *)> stopPropagation) {
1223+
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
12061224
// Allow re-using existing conversions for a value. Check dominance of any
12071225
// reusable materializations against the root value. This is sufficient
12081226
// because the conversions are processed in post-order.
@@ -1231,15 +1249,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12311249
};
12321250

12331251
return ttgi::getConvertBackwardSlice(root, slice, rootEncoding, layout,
1234-
stopPropagation, getExistingConversion);
1252+
stopPropagation, getExistingConversion,
1253+
includeForOp);
12351254
}
12361255

12371256
LogicalResult LayoutRematerialization::getRematerializableSlice(
12381257
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12391258
DenseMap<Value, Attribute> &layout,
1240-
std::function<bool(Operation *)> stopPropagation) {
1241-
LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
1242-
layout, stopPropagation);
1259+
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
1260+
LogicalResult result = getConvertBackwardSlice(
1261+
root, rootEncoding, slice, layout, stopPropagation, includeForOp);
12431262
if (result.failed() || slice.empty())
12441263
return failure();
12451264

@@ -1362,8 +1381,9 @@ void LayoutRematerialization::backwardRematerialization(
13621381
// rematerialized.
13631382
SetVector<Value> slice;
13641383
DenseMap<Value, Attribute> layout;
1365-
LogicalResult result = getRematerializableSlice(
1366-
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
1384+
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
1385+
targetType.getEncoding(),
1386+
slice, layout, nullptr, true);
13671387
if (result.failed()) {
13681388
LDBG(" getRematerializableSlice failed");
13691389
return;

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ LogicalResult getConvertBackwardSlice(
182182
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
183183
DenseMap<Value, Attribute> &layout,
184184
std::function<bool(Operation *)> stopPropagation,
185-
std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
185+
std::function<Value(OpOperand &, Attribute)> getExistingConversion,
186+
bool includeForOp) {
186187
DenseSet<std::pair<OpOperand *, Attribute>> seen;
187188
SmallVector<std::pair<OpOperand *, Attribute>> queue;
188189

@@ -197,6 +198,12 @@ LogicalResult getConvertBackwardSlice(
197198

198199
auto updateLayout = [&](Value value, Attribute encoding) {
199200
assert(isTensorOrTensorPointerType(value.getType()));
201+
auto tensorType = getRankedTensorType(value.getType());
202+
auto originEncoding = tensorType.getEncoding();
203+
if (originEncoding == encoding) {
204+
return success();
205+
}
206+
200207
slice.insert(value);
201208
Attribute &existing = layout[value];
202209
if (existing && existing != encoding)
@@ -211,10 +218,7 @@ LogicalResult getConvertBackwardSlice(
211218
queue.pop_back();
212219
if (!isTensorOrTensorPointerType(currentValue.getType()))
213220
continue;
214-
// Skip propagating through for op results for now.
215-
// TODO: enable this based on needs.
216-
if (currentValue.getDefiningOp<scf::ForOp>())
217-
return failure();
221+
218222
if (failed(updateLayout(currentValue, encoding)))
219223
return failure();
220224

@@ -226,6 +230,22 @@ LogicalResult getConvertBackwardSlice(
226230
currentValue = existing;
227231
}
228232

233+
if (auto forOp = currentValue.getDefiningOp<scf::ForOp>()) {
234+
if (!includeForOp)
235+
return failure();
236+
if (stopPropagation && stopPropagation(forOp))
237+
continue;
238+
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
239+
int numIndVars = forOp.getNumInductionVars();
240+
Block &loopBody = *forOp.getBody();
241+
auto blockArg = loopBody.getArgument(argIdx + numIndVars);
242+
OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
243+
OpOperand &yieldOperand = loopBody.getTerminator()->getOpOperand(argIdx);
244+
enqueue(*initOperand, encoding);
245+
enqueue(yieldOperand, encoding);
246+
continue;
247+
}
248+
229249
if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
230250
if (stopPropagation && stopPropagation(ifOp))
231251
continue;

0 commit comments

Comments
 (0)