Skip to content

Commit 7a9d18c

Browse files
committed
[compiler] support shape reification for callOp
1 parent 173af1c commit 7a9d18c

File tree

14 files changed

+286
-26
lines changed

14 files changed

+286
-26
lines changed

compiler/include/byteir/Dialect/mhlo/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h"
3737
#include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h"
3838
#include "byteir/Dialect/mhlo/Transforms/RewriteWithConstraint.h"
39-
#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
39+
#include "byteir/Transforms/ShapeReification.h"
4040
#include "byteir/Dialect/mhlo/Transforms/StaticShapeInference.h"
4141
#include "byteir/Dialect/mhlo/Transforms/UnfuseBatchNorm.h"
4242

compiler/include/byteir/Dialect/mhlo/Passes.td

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -305,25 +305,6 @@ def RewriteWithConstraint : Pass<"rewrite-with-constraint", "mlir::func::FuncOp
305305
let constructor = "mlir::createRewriteWithConstraintPass()";
306306
}
307307

308-
//===----------------------------------------------------------------------===//
309-
// ShapeReification
310-
//===----------------------------------------------------------------------===//
311-
312-
def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
313-
let summary = "Iteratively reify all shape computations.";
314-
let description = [{
315-
If an operation has a shape reification implementation, that is to say, we
316-
know how to express the outputs' shape by its inputs' shape symbolically,
317-
then a tensor.dim or shape.shape_of on this type of operation could be
318-
reified. And shape reification procedure could be handled recursively.
319-
}];
320-
let constructor = "mlir::createByteIRShapeReificationPass()";
321-
let dependentDialects = [
322-
"mlir::shape::ShapeDialect",
323-
"mlir::tensor::TensorDialect"
324-
];
325-
}
326-
327308
//===----------------------------------------------------------------------===//
328309
// Static Shape Inference
329310
//===----------------------------------------------------------------------===//

compiler/include/byteir/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "byteir/Transforms/SetArgShape.h"
3737
#include "byteir/Transforms/SetSpace.h"
3838
#include "byteir/Transforms/TryCatchModulePipeline.h"
39+
#include "byteir/Transforms/ShapeReification.h"
3940

4041
namespace mlir {
4142

compiler/include/byteir/Transforms/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,4 +425,23 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
425425
];
426426
}
427427

428+
//===----------------------------------------------------------------------===//
429+
// ShapeReification
430+
//===----------------------------------------------------------------------===//
431+
432+
def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
433+
let summary = "Iteratively reify all shape computations.";
434+
let description = [{
435+
If an operation has a shape reification implementation, that is to say, we
436+
know how to express the outputs' shape by its inputs' shape symbolically,
437+
then a tensor.dim or shape.shape_of on this type of operation could be
438+
reified. And shape reification procedure could be handled recursively.
439+
}];
440+
let constructor = "mlir::createByteIRShapeReificationPass()";
441+
let dependentDialects = [
442+
"mlir::shape::ShapeDialect",
443+
"mlir::tensor::TensorDialect"
444+
];
445+
}
446+
428447
#endif // BYTEIR_TRANSFORMS_PASSES

compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h renamed to compiler/include/byteir/Transforms/ShapeReification.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//===- ShapeReification.h -------------------------------------*--- C++ -*-===//
22
//
3-
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
// Licensed under the Apache License, Version 2.0 (the "License");
55
// you may not use this file except in compliance with the License.
66
// You may obtain a copy of the License at

compiler/lib/Analysis/SymbolicShape.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
//===----------------------------------------------------------------------===//
1717

1818
#include "byteir/Analysis/SymbolicShape.h"
19-
#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
19+
#include "byteir/Transforms/ShapeReification.h"
2020
#include "mlir/Dialect/Shape/IR/Shape.h"
2121
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2222
#include "mlir/IR/IRMapping.h"

compiler/lib/Dialect/mhlo/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ add_mlir_dialect_library(ByteIRMhloPasses
105105
Transforms/ReduceFusion.cpp
106106
Transforms/ReshapeGather.cpp
107107
Transforms/RewriteWithConstraint.cpp
108-
Transforms/ShapeReification.cpp
109108
Transforms/StaticShapeInference.cpp
110109
Transforms/TrivialFusion.cpp
111110
Transforms/UnfuseBatchNorm.cpp

compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,19 @@
1717

1818
#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
1919
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
20+
#include "byteir/Transforms/ShapeReification.h"
2021
#include "mhlo/IR/hlo_ops.h"
2122
#include "mlir/Dialect/Arith/IR/Arith.h"
2223
#include "mlir/Dialect/Func/IR/FuncOps.h"
2324
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2425
#include "mlir/IR/BuiltinTypes.h"
26+
#include "mlir/Pass/PassManager.h"
27+
#include "mlir/Transforms/Passes.h"
28+
#include "mlir/Transforms/TopologicalSortUtils.h"
2529
#include "llvm/ADT/StringMap.h"
2630
#include "llvm/Support/Debug.h"
31+
#include <queue>
32+
#include <string>
2733

2834
using namespace mlir;
2935

@@ -177,6 +183,168 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
177183
return nullptr;
178184
}
179185

186+
namespace {
187+
188+
SmallVector<Operation *> collectAllOpsForReturn(Operation *retOp) {
189+
llvm::DenseSet<Operation *> visitedOp;
190+
std::queue<Operation *> opQueue;
191+
192+
opQueue.push(retOp);
193+
while (!opQueue.empty()) {
194+
auto frontOp = opQueue.front();
195+
opQueue.pop();
196+
if (visitedOp.find(frontOp) != visitedOp.end()) {
197+
continue;
198+
}
199+
visitedOp.insert(frontOp);
200+
for (Value operand : frontOp->getOperands()) {
201+
if (!operand.getDefiningOp()) {
202+
continue;
203+
}
204+
if (Operation *defOp = operand.getDefiningOp()) {
205+
opQueue.push(defOp);
206+
}
207+
}
208+
}
209+
visitedOp.erase(retOp);
210+
return SmallVector<Operation *>(visitedOp.begin(), visitedOp.end());
211+
}
212+
213+
bool deduceFromFuncArgShape(Value value) {
214+
if (value.isa<BlockArgument>()) {
215+
return false;
216+
}
217+
218+
auto defOp = value.getDefiningOp();
219+
if (!defOp) {
220+
return false;
221+
}
222+
223+
if (isa<arith::ConstantIndexOp, arith::ConstantOp>(defOp)) {
224+
return true;
225+
}
226+
227+
if (isa<tensor::DimOp, shape::ShapeOfOp>(defOp)) {
228+
auto operand = defOp->getOperand(0);
229+
if (operand.isa<BlockArgument>()) {
230+
return true;
231+
}
232+
return false;
233+
}
234+
235+
for (Value &&operand : defOp->getOperands()) {
236+
if (!deduceFromFuncArgShape(operand)) {
237+
return false;
238+
}
239+
}
240+
return true;
241+
}
242+
243+
LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
244+
SmallVectorImpl<Value> &reifications) {
245+
OpBuilder::InsertionGuard guard(builder);
246+
auto callOp = dyn_cast<func::CallOp>(op);
247+
if (!callOp) {
248+
return failure();
249+
}
250+
251+
ModuleOp moduleOp = op->getParentRegion()->getParentOfType<ModuleOp>();
252+
// auxiliary builder used to create operations in the shape function.
253+
// The original builder may be a rewriter, used to create operations in a specific
254+
// pattern.
255+
OpBuilder auxiliaryBuilder(moduleOp);
256+
StringRef funcName = callOp.getCallee();
257+
auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
258+
259+
// clone funcOp, newFuncOp used for deduce function shape
260+
std::string newFuncName = funcName.str() + "_Shape";
261+
auxiliaryBuilder.setInsertionPointToStart(moduleOp.getBody());
262+
auto newFuncOp = auxiliaryBuilder.create<func::FuncOp>(
263+
funcOp->getLoc(), newFuncName, funcOp.getFunctionType());
264+
newFuncOp.setPrivate();
265+
IRMapping emptyBvm;
266+
funcOp.cloneInto(newFuncOp, emptyBvm);
267+
268+
// replace the operands of returnOp with corresponding shape
269+
func::ReturnOp retOp = *newFuncOp.getOps<func::ReturnOp>().begin();
270+
if (!retOp) {
271+
newFuncOp->erase();
272+
return failure();
273+
}
274+
275+
SmallVector<Type> allResultTypes;
276+
SmallVector<Value> allResults;
277+
278+
auxiliaryBuilder.setInsertionPoint(retOp);
279+
for (Value &&retTensor : retOp.getOperands()) {
280+
auto retShape =
281+
auxiliaryBuilder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
282+
allResultTypes.emplace_back(retShape.getType());
283+
allResults.emplace_back(retShape);
284+
}
285+
286+
// return the shape of original tensor returned by function
287+
auto newRetOp =
288+
auxiliaryBuilder.create<func::ReturnOp>(retOp.getLoc(), allResults);
289+
auto newFuncType = auxiliaryBuilder.getFunctionType(
290+
newFuncOp.getArgumentTypes(), allResultTypes);
291+
newFuncOp.setFunctionType(newFuncType);
292+
retOp->erase();
293+
294+
// reify newFunc to get the shape computation for current callOp
295+
{
296+
PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
297+
pm.addPass(createCanonicalizerPass());
298+
pm.addPass(createCSEPass());
299+
pm.addPass(createByteIRShapeReificationPass());
300+
pm.addPass(createCanonicalizerPass());
301+
pm.addPass(createCSEPass());
302+
303+
if (mlir::failed(pm.run(newFuncOp))) {
304+
newFuncOp->erase();
305+
return failure();
306+
}
307+
}
308+
309+
// collect all shape computation ops
310+
SmallVector<Operation *> reificationOps = collectAllOpsForReturn(newRetOp);
311+
312+
// check that each returned value depends only on the shapes of the function arguments.
313+
for (Value &&ret : newRetOp.getOperands()) {
314+
if (!deduceFromFuncArgShape(ret)) {
315+
newFuncOp->erase();
316+
return failure();
317+
}
318+
}
319+
320+
// mapping the shape computation ops and collect reifications
321+
{
322+
mlir::computeTopologicalSorting(reificationOps);
323+
324+
IRMapping bvm;
325+
size_t numArg = newFuncOp.getNumArguments();
326+
for (size_t i = 0; i < numArg; ++i) {
327+
bvm.map(newFuncOp.getArgument(i), callOp.getOperand(i));
328+
}
329+
330+
builder.setInsertionPoint(callOp);
331+
332+
for (Operation *oldOp : reificationOps) {
333+
auto newOp = builder.clone(*oldOp, bvm);
334+
}
335+
336+
for (Value &&ret : newRetOp.getOperands()) {
337+
reifications.push_back(bvm.lookup(ret));
338+
}
339+
}
340+
341+
// remove newFuncOp
342+
newFuncOp->erase();
343+
return success();
344+
}
345+
346+
} // namespace
347+
180348
LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
181349
SmallVectorImpl<Value> &reifications) {
182350
if (!op)
@@ -207,6 +375,16 @@ LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
207375
}
208376
if (failed(inferFunc(op, builder, op->getOperands(), reifications)))
209377
return failure();
378+
} else if (auto callOp = dyn_cast<func::CallOp>(op)) {
379+
if (failed(reifyCallOp(builder, op, reifications))) {
380+
return failure();
381+
}
382+
} else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op)) {
383+
for (OpResult &&result : op->getOpResults()) {
384+
auto tiedOperand = dpsOp.getTiedOpOperand(result);
385+
reifications.push_back(
386+
builder.create<shape::ShapeOfOp>(op->getLoc(), tiedOperand->get()));
387+
}
210388
} else {
211389
// Return failure if op doesn't have InferShapedTypeOpInterface and not
212390
// registered.

compiler/lib/Pipelines/ByreTensorOpt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
4747
createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
4848
pm.addNestedPass<func::FuncOp>(
4949
createConvertHloToByreTensorPass(appendArgTypes));
50+
pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
5051
pm.addPass(createCanonicalizerPass());
5152
}
5253
} // namespace

compiler/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_library(ByteIRTransforms
1717
RewriteOpToStdCall.cpp
1818
SetArgShape.cpp
1919
SetSpace.cpp
20+
ShapeReification.cpp
2021
Utils.cpp
2122

2223
ADDITIONAL_HEADER_DIRS

compiler/lib/Transforms/PassDetail.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ namespace scf {
5151
class SCFDialect;
5252
} // namespace scf
5353

54+
namespace shape {
55+
class ShapeDialect;
56+
} // namespace shape
57+
58+
namespace tensor {
59+
class TensorDialect;
60+
} // namespace tensor
61+
5462
#define GEN_PASS_CLASSES
5563
#include "byteir/Transforms/Passes.h.inc"
5664

compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp renamed to compiler/lib/Transforms/ShapeReification.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//===- ShapeReification.cpp -----------------------------------*--- C++ -*-===//
22
//
3-
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
// Licensed under the Apache License, Version 2.0 (the "License");
55
// you may not use this file except in compliance with the License.
66
// You may obtain a copy of the License at
@@ -15,7 +15,7 @@
1515
//
1616
//===----------------------------------------------------------------------===//
1717

18-
#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
18+
#include "byteir/Transforms/ShapeReification.h"
1919

2020
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
2121
#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"

compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,20 @@ func.func @test_normal_function_call(%arg0 : tensor<4xf32>) -> tensor<4xf32> att
2020
}
2121
// CHECK-LABEL: test_normal_function_call
2222
// CHECK: call @some_func
23+
24+
25+
// -----
26+
27+
func.func private @Unknown0(%arg0: tensor<?x20xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_elementwise_fusion__, byre_compute_name = "Unknown0"} {
28+
%0 = mhlo.add %arg0, %arg1 : tensor<?x20xf32>
29+
return %0 : tensor<?x20xf32>
30+
}
31+
32+
func.func @forward(%arg0: tensor<?x20xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__placeholder__byre.entry_point} {
33+
%1 = call @Unknown0(%arg1, %arg0) : (tensor<?x20xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>
34+
return %1 : tensor<?x20xf32>
35+
}
36+
37+
// CHECK-LABEL: func.func @forward
38+
// CHECK: tensor.empty
39+
// CHECK-NEXT: byre.compute_on_tensor @Unknown0

0 commit comments

Comments
 (0)