@@ -157,19 +157,22 @@ class LayoutRematerialization {
157
157
getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
158
158
SetVector<Value> &slice,
159
159
DenseMap<Value, Attribute> &layout,
160
- std::function<bool (Operation *)> stopPropagation);
160
+ std::function<bool (Operation *)> stopPropagation,
161
+ bool includeForOp = false );
161
162
162
163
LogicalResult getRematerializableSlice (
163
164
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
164
165
DenseMap<Value, Attribute> &layout,
165
- std::function<bool (Operation *)> stopPropagation = nullptr);
166
+ std::function<bool (Operation *)> stopPropagation = nullptr,
167
+ bool includeForOp = false);
166
168
167
169
private:
168
170
void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
169
171
// Existing tuples of (value, layout) that need to be updated when recreating
170
172
// scf ops. This prevents keeping track of Values that have been deleted when
171
- // rewriting slices.
172
- DenseMap<Value, Attribute> mappedValues;
173
+ // rewriting slices. The Value may be mapped to different attributes in remove
174
+ // layout.
175
+ DenseMap<Value, SmallVector<Attribute>> mappedValues;
173
176
// map of the values remat based on encoding.
174
177
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
175
178
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -183,7 +186,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
183
186
Value newV) {
184
187
LDBG (" addRematValue " << old << " encoding " << encoding << " " << newV);
185
188
rematMapping[{old, encoding}] = newV;
186
- mappedValues[old] = encoding;
189
+ if (mappedValues.contains (old)) {
190
+ mappedValues[old].push_back (encoding);
191
+ } else {
192
+ mappedValues[old] = {encoding};
193
+ }
187
194
}
188
195
189
196
// Remove unneeded values now that we are done with the rematMapping.
@@ -988,22 +995,28 @@ void LayoutRematerialization::updateRematMapping(
988
995
for (auto [old, newV] : values) {
989
996
auto it = mappedValues.find (old);
990
997
if (it != mappedValues.end ()) {
991
- Attribute encoding = it->second ;
992
- auto rematIt = rematMapping.find ({old, it->second });
993
- assert (rematIt != rematMapping.end ());
994
- Value replacedValue = rematIt->second ;
995
- rematMapping.erase (rematIt);
996
- mappedValues.erase (it);
997
- // Loop through the replacement value to find the new version of remat
998
- // value. This should be okay as the number of values should be small.
999
- for (auto [before, after] : values) {
1000
- if (before == replacedValue) {
1001
- replacedValue = after;
1002
- break ;
998
+ SmallVector<Attribute> encodings = it->second ;
999
+ for (auto encoding : encodings) {
1000
+ auto rematIt = rematMapping.find ({old, encoding});
1001
+ assert (rematIt != rematMapping.end ());
1002
+ Value replacedValue = rematIt->second ;
1003
+ rematMapping.erase (rematIt);
1004
+ // Loop through the replacement value to find the new version of remat
1005
+ // value. This should be okay as the number of values should be small.
1006
+ for (auto [before, after] : values) {
1007
+ if (before == replacedValue) {
1008
+ replacedValue = after;
1009
+ break ;
1010
+ }
1003
1011
}
1012
+ rematMapping[{newV, encoding}] = replacedValue;
1013
+ }
1014
+ mappedValues.erase (it);
1015
+ if (mappedValues.contains (newV)) {
1016
+ mappedValues[newV].append (encodings);
1017
+ } else {
1018
+ mappedValues[newV] = std::move (encodings);
1004
1019
}
1005
- rematMapping[{newV, encoding}] = replacedValue;
1006
- mappedValues[newV] = encoding;
1007
1020
}
1008
1021
}
1009
1022
}
@@ -1078,6 +1091,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1078
1091
deadOps.push_back (forOp.getOperation ());
1079
1092
Block &loopBody = *newForOp.getBody ();
1080
1093
for (auto m : argMapping) {
1094
+ mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
1081
1095
mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
1082
1096
int numIndVars = newForOp.getNumInductionVars ();
1083
1097
mapping.map (loopBody.getArgument (m.first + numIndVars),
@@ -1188,8 +1202,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1188
1202
builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
1189
1203
}
1190
1204
1191
- for (Operation *op : deadOps)
1192
- opToDelete.insert (op);
1205
+ for (Operation *op : deadOps) {
1206
+ if (!isa<scf::ForOp>(op))
1207
+ opToDelete.insert (op);
1208
+ else
1209
+ op->erase ();
1210
+ }
1193
1211
}
1194
1212
1195
1213
void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1202,7 +1220,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1202
1220
LogicalResult LayoutRematerialization::getConvertBackwardSlice (
1203
1221
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
1204
1222
DenseMap<Value, Attribute> &layout,
1205
- std::function<bool (Operation *)> stopPropagation) {
1223
+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
1206
1224
// Allow re-using existing conversions for a value. Check dominance of any
1207
1225
// reusable materializations against the root value. This is sufficient
1208
1226
// because the conversions are processed in post-order.
@@ -1231,15 +1249,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
1231
1249
};
1232
1250
1233
1251
return ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1234
- stopPropagation, getExistingConversion);
1252
+ stopPropagation, getExistingConversion,
1253
+ includeForOp);
1235
1254
}
1236
1255
1237
1256
LogicalResult LayoutRematerialization::getRematerializableSlice (
1238
1257
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
1239
1258
DenseMap<Value, Attribute> &layout,
1240
- std::function<bool (Operation *)> stopPropagation) {
1241
- LogicalResult result = getConvertBackwardSlice (root, rootEncoding, slice,
1242
- layout, stopPropagation);
1259
+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
1260
+ LogicalResult result = getConvertBackwardSlice (
1261
+ root, rootEncoding, slice, layout, stopPropagation, includeForOp );
1243
1262
if (result.failed () || slice.empty ())
1244
1263
return failure ();
1245
1264
@@ -1362,8 +1381,9 @@ void LayoutRematerialization::backwardRematerialization(
1362
1381
// rematerialized.
1363
1382
SetVector<Value> slice;
1364
1383
DenseMap<Value, Attribute> layout;
1365
- LogicalResult result = getRematerializableSlice (
1366
- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1384
+ LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1385
+ targetType.getEncoding (),
1386
+ slice, layout, nullptr , true );
1367
1387
if (result.failed ()) {
1368
1388
LDBG (" getRematerializableSlice failed" );
1369
1389
return ;
0 commit comments