Skip to content

Commit 1ecf5ee

Browse files
authored
[CINN] Add the TileTransposeTactic (#70942)
1 parent 6821f2e commit 1ecf5ee

15 files changed

+912
-73
lines changed

paddle/cinn/backends/codegen_gpu_dev.cc

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -95,12 +95,23 @@ std::vector<ir::stmt::StmtRef> CodeGenGpuDev::GenerateBufferAliasStmts(
9595
}
9696

9797
for (auto &t : unique_tensors) {
98-
auto data_type = t->type();
99-
auto data_ptr_type = data_type;
100-
data_ptr_type.set_cpp_handle();
98+
auto tensor_type = t->type();
99+
auto tensor_ptr_type = tensor_type;
100+
tensor_ptr_type.set_cpp_handle();
101+
102+
auto buffer_type = t->buffer->dtype;
103+
auto buffer_ptr_type = buffer_type;
104+
buffer_ptr_type.set_cpp_handle();
105+
106+
Expr t_var = Var(t->name, tensor_ptr_type);
107+
Expr buf_var = Var(t->buffer->name, buffer_ptr_type);
108+
109+
// A tensor and its buffer may have different types when multiple tensors
110+
// share the same buffer. In this case, add a Cast before aliasing.
111+
if (tensor_type != buffer_type) {
112+
buf_var = common::cast(buf_var, tensor_ptr_type);
113+
}
101114

102-
Var t_var(t->name, data_ptr_type);
103-
Var buf_var(t->buffer->name, data_ptr_type);
104115
buffer_alias.push_back(ir::stmt::Let(t_var, buf_var));
105116
}
106117

paddle/cinn/ir/group_schedule/config/group_tile_config.cc

Lines changed: 13 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -640,33 +640,20 @@ TileConfigMap BuildStaticReduceConfig(
640640
/* spatial_inner_num = */ 1,
641641
/* vectorize_factor = */ 1,
642642
NoneReduceMethod()};
643-
BucketInfo bucket_info__1024_1M{/* sp_lower_bound = */ 1024,
644-
/* sp_upper_bound = */ 1024 * 1024 - 1,
645-
/* rb_lower_bound = */ 1,
646-
/* rb_upper_bound = */ 1,
647-
/* sp_is_dynamic = */ true,
648-
/* rb_is_dynamic = */ false};
649-
TileConfig tile_config__1024_1M{/* warp_num = */ 32,
650-
/* tree_reduce_num = */ 1,
651-
/* grid_reduce_num = */ 1,
652-
/* spatial_inner_num = */ 4,
653-
/* vectorize_factor = */ 1,
654-
NoneReduceMethod()};
655-
BucketInfo bucket_info__1M_INF{/* sp_lower_bound = */ 1024 * 1024,
656-
/* sp_upper_bound = */ kMaxNumel,
657-
/* rb_lower_bound = */ 1,
658-
/* rb_upper_bound = */ 1,
659-
/* sp_is_dynamic = */ true,
660-
/* rb_is_dynamic = */ false};
661-
TileConfig tile_config__1M_INF{/* warp_num = */ 32,
662-
/* tree_reduce_num = */ 1,
663-
/* grid_reduce_num = */ 1,
664-
/* spatial_inner_num = */ 4,
665-
/* vectorize_factor = */ 1,
666-
NoneReduceMethod()};
643+
BucketInfo bucket_info__1024_INF{/* sp_lower_bound = */ 1024,
644+
/* sp_upper_bound = */ kMaxNumel,
645+
/* rb_lower_bound = */ 1,
646+
/* rb_upper_bound = */ 1,
647+
/* sp_is_dynamic = */ true,
648+
/* rb_is_dynamic = */ false};
649+
TileConfig tile_config__1024_INF{/* warp_num = */ 32,
650+
/* tree_reduce_num = */ 1,
651+
/* grid_reduce_num = */ 1,
652+
/* spatial_inner_num = */ 4,
653+
/* vectorize_factor = */ 1,
654+
NoneReduceMethod()};
667655
return {{bucket_info__1_1023, tile_config__1_1023},
668-
{bucket_info__1024_1M, tile_config__1024_1M},
669-
{bucket_info__1M_INF, tile_config__1M_INF}};
656+
{bucket_info__1024_INF, tile_config__1024_INF}};
670657
} else if (base_info->reduce_numel <= 256) {
671658
BucketInfo bucket_info{/* sp_lower_bound = */ 1,
672659
/* sp_upper_bound = */ kMaxNumel,

paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
2222
#include "paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.h"
2323
#include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h"
24+
#include "paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.h"
2425
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
2526
#include "paddle/cinn/ir/op/ir_operators.h"
2627
#include "paddle/common/enforce.h"
@@ -36,6 +37,7 @@ void DynamicShapeGroupScheduler::Init() {
3637
InitBuckets();
3738
tactics_.emplace_back(CreateAlignIterSpaceTactic());
3839
tactics_.emplace_back(CreateTileBroadcastTactic());
40+
tactics_.emplace_back(CreateTileTransposeTactic());
3941
tactics_.emplace_back(CreateTileFirstGeneralTactic());
4042
tactics_.emplace_back(CreateComputeInlineTactic());
4143
tactics_.emplace_back(CreateComputeAtReductionTactic());

paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,4 +8,5 @@ gather_srcs(cinnapi_src SRCS compute_at_reduction_tactic.cc)
88
gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
99
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
1010
gather_srcs(cinnapi_src SRCS tile_broadcast_tactic.cc)
11+
gather_srcs(cinnapi_src SRCS tile_transpose_tactic.cc)
1112
gather_srcs(cinnapi_src SRCS tile_first_general_tactic.cc)

0 commit comments

Comments
 (0)