Skip to content

Commit 617655f

Browse files
authored
skip some op cache check; refine warning message (#71693)
1 parent 59e2669 commit 617655f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc

+8-2
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,8 @@ void InferSymExprForOp(Operation* op,
344344
static const std::set<std::string> skip_cache_check_op_set = {
345345
// new symbol
346346
"pd_op.data",
347+
"pd_op.arange",
348+
"pd_op.masked_select",
347349
// unneeded to cache
348350
"cinn_op.generate_shape",
349351
};
@@ -357,7 +359,8 @@ void CacheForwardOpSymbolicShape(
357359
[&](const InferSymbolicShapeCacheValue& infer_result,
358360
const InferSymbolicShapeCacheValue& cache_result) {
359361
if (infer_result.size() != cache_result.size()) {
360-
LOG(WARNING) << "cached shape is not consistent with real shape";
362+
LOG(WARNING) << "cached shape is not consistent with real shape for "
363+
<< op->name() << "[id:" << op->id() << "]";
361364
} else {
362365
for (uint32_t i = 0; i < cache_result.size(); ++i) {
363366
if (infer_result[i] != cache_result[i]) {
@@ -368,7 +371,10 @@ void CacheForwardOpSymbolicShape(
368371
skip_cache_check_op_set.end()) {
369372
continue;
370373
}
371-
LOG(WARNING) << "cached shape is not consistent with real shape";
374+
LOG(WARNING)
375+
<< "cached shape is not consistent with real shape for "
376+
<< op->name() << "[id:" << op->id()
377+
<< "] with result index: " << i;
372378
VLOG(3) << "InferSymbolicShapeCacheKey is: "
373379
<< op_infer_cache_key;
374380
VLOG(3) << "cached shape is: " << cache_result[i];

0 commit comments

Comments
 (0)