Skip to content

Commit bf395f1

Browse files
authored
Merge branch 'develop' into 104
2 parents f9653a6 + a0700da commit bf395f1

File tree

366 files changed

+9519
-3043
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

366 files changed

+9519
-3043
lines changed

cmake/cinn.cmake

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ cinn_cc_library(
164164
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
165165
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
166166
if(NOT CINN_ONLY)
167-
target_link_libraries(cinnapi pd_op_dialect phi)
168-
add_dependencies(cinnapi pd_op_dialect phi)
167+
target_link_libraries(cinnapi op_dialect_vjp phi)
168+
add_dependencies(cinnapi op_dialect_vjp phi)
169169
endif()
170170

171171
target_link_libraries(cinnapi ${PYTHON_LIBRARIES})
@@ -222,8 +222,8 @@ function(gen_cinncore LINKTYPE)
222222
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
223223
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
224224
if(NOT CINN_ONLY)
225-
target_link_libraries(${CINNCORE_TARGET} pd_op_dialect phi)
226-
add_dependencies(${CINNCORE_TARGET} pd_op_dialect phi)
225+
target_link_libraries(${CINNCORE_TARGET} op_dialect_vjp phi)
226+
add_dependencies(${CINNCORE_TARGET} op_dialect_vjp phi)
227227
endif()
228228

229229
add_dependencies(${CINNCORE_TARGET} pybind)

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ set(XPU_XPTI_LIB_NAME "libxpti.so")
2626
if(NOT DEFINED XPU_BASE_DATE)
2727
set(XPU_BASE_DATE "20231103")
2828
endif()
29-
set(XPU_XCCL_BASE_VERSION "1.0.53.6")
29+
set(XPU_XCCL_BASE_VERSION "1.1.6.1")
3030
if(NOT DEFINED XPU_XFT_BASE_VERSION)
3131
set(XPU_XFT_BASE_VERSION "20230602")
3232
endif()

paddle/cinn/backends/nvrtc/nvrtc_util.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
PD_DECLARE_string(cinn_nvcc_cmd_path);
3434
PD_DECLARE_bool(nvrtc_compile_to_cubin);
35+
PD_DECLARE_bool(cinn_nvrtc_cubin_with_fmad);
3536

3637
namespace cinn {
3738
namespace backends {
@@ -106,6 +107,9 @@ std::string Compiler::CompileCudaSource(const std::string& code,
106107
}
107108
if (compile_to_cubin_) {
108109
compile_options.push_back("-arch=sm_" + cc);
110+
std::string enable_fmad =
111+
FLAGS_cinn_nvrtc_cubin_with_fmad ? "true" : "false";
112+
compile_options.push_back("--fmad=" + enable_fmad);
109113
} else {
110114
compile_options.push_back("-arch=compute_" + cc);
111115
}

paddle/cinn/common/context.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,22 @@ struct NameGenerator {
5252
mutable std::mutex mutex_;
5353
};
5454

55+
struct PrettyNamer {
56+
const std::string& GetOrNew(const size_t hash_key,
57+
const std::string& name_hint) {
58+
if (pretty_names_.find(hash_key) == pretty_names_.end()) {
59+
pretty_names_[hash_key] = name_generator_.New(name_hint);
60+
}
61+
return pretty_names_.at(hash_key);
62+
}
63+
64+
NameGenerator& GetNameGenerator() { return name_generator_; }
65+
66+
private:
67+
absl::flat_hash_map<size_t, std::string> pretty_names_;
68+
NameGenerator name_generator_;
69+
};
70+
5571
class Context {
5672
public:
5773
static Context& Global();
@@ -61,10 +77,15 @@ class Context {
6177
* @param name_hint The prefix.
6278
*/
6379
std::string NewName(const std::string& name_hint) {
64-
return name_generator_.New(name_hint);
80+
return pretty_namer_.GetNameGenerator().New(name_hint);
6581
}
6682

67-
void ResetNameId() { name_generator_.ResetID(); }
83+
std::string PrettyUniqName(const size_t hash_key,
84+
const std::string& name_hint) {
85+
return pretty_namer_.GetOrNew(hash_key, name_hint);
86+
}
87+
88+
void ResetNameId() { pretty_namer_.GetNameGenerator().ResetID(); }
6889

6990
const std::vector<std::string>& runtime_include_dir();
7091

@@ -82,7 +103,7 @@ class Context {
82103
private:
83104
Context() = default;
84105

85-
NameGenerator name_generator_;
106+
PrettyNamer pretty_namer_;
86107
std::vector<std::string> runtime_include_dir_;
87108
mutable std::mutex mutex_;
88109

paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ if(NOT CINN_ONLY)
6262
manual_op.cc
6363
op_attribute.cc
6464
DEPS
65-
pd_op_dialect)
65+
op_dialect_vjp)
6666

6767
target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_BINARY_DIR})
6868
endif()

paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ if(NOT CINN_ONLY)
77
cinn_group_lowering_pass.cc
88
tensor_node.cc
99
DEPS
10-
pd_op_dialect
10+
op_dialect_vjp
1111
pir_compiler
1212
cinn_runtime_dialect)
1313

@@ -18,7 +18,7 @@ if(NOT CINN_ONLY)
1818
DEPS
1919
drr
2020
cinn_op_dialect
21-
pd_op_dialect)
21+
op_dialect_vjp)
2222

2323
cinn_cc_library(
2424
add_broadcast_to_elementwise_pass
@@ -27,5 +27,5 @@ if(NOT CINN_ONLY)
2727
DEPS
2828
pir
2929
cinn_op_dialect
30-
pd_op_dialect)
30+
op_dialect_vjp)
3131
endif()

paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ bool IsSameDim(const phi::DDim& first, const std::vector<int64_t>& second) {
8181
return false;
8282
}
8383

84+
std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
85+
const std::vector<int64_t>& out_shape) {
86+
std::vector<int64_t> broadcast_axes(in_shape.size(), 0);
87+
auto in_shape_size = in_shape.size();
88+
if (in_shape_size >= 1) {
89+
for (int i = 1; i <= in_shape_size; ++i) {
90+
broadcast_axes[in_shape_size - i] = out_shape.size() - i;
91+
}
92+
}
93+
94+
return broadcast_axes;
95+
}
96+
8497
bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
8598
auto x_dims = op->operand_source(0)
8699
.type()
@@ -93,21 +106,21 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
93106

94107
if (x_dims != y_dims) {
95108
auto output_shape = GetOutputShape(x_dims, y_dims);
96-
std::vector<int64_t> vec_dims;
97-
for (int64_t i = 0; i < output_shape.size(); ++i) {
98-
vec_dims.push_back(i);
99-
}
100109
if (!IsSameDim(x_dims, output_shape)) {
101110
// add broadcast to input 0
102111
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
103-
op->operand_source(0), vec_dims, output_shape);
112+
op->operand_source(0),
113+
GetBroadcastAxis(x_dims, output_shape),
114+
output_shape);
104115

105116
op->operand(0).set_source(new_transpose_op->result(0));
106117
}
107118

108119
if (!IsSameDim(y_dims, output_shape)) {
109120
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
110-
op->operand_source(1), vec_dims, output_shape);
121+
op->operand_source(1),
122+
GetBroadcastAxis(y_dims, output_shape),
123+
output_shape);
111124

112125
op->operand(1).set_source(new_transpose_op->result(0));
113126
}

paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ OpPatternKind GetOpKind(const std::string& op_name) {
5858
}
5959

6060
phi::DDim GetFirstInputShape(const ::pir::Operation* op) {
61+
if (op->num_operands() == 0) {
62+
return phi::DDim({});
63+
}
6164
auto in = op->operand_source(0);
6265

6366
return in.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();

paddle/cinn/hlir/framework/CMakeLists.txt

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@ gather_srcs(
2222
op_lowering_impl.cc
2323
accuracy_checker.cc
2424
visualize_helper.cc
25-
compile_error.cc
26-
group_scheduler.cc)
25+
compile_error.cc)
2726

28-
# TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could
27+
# TODO(Aurelius84): pir_compiler depends on op_dialect_vjp and could
2928
# not found under CINN_ONLY mode
3029
if(NOT CINN_ONLY)
31-
cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi pd_op_dialect)
30+
cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp)
3231
endif()
3332

3433
if(WITH_CUDA)
@@ -44,8 +43,6 @@ endif()
4443
if(WITH_CUDA)
4544
cinn_cc_test(test_hlir_framework_op_lowering SRCS op_lowering_test.cc DEPS
4645
cinncore decomposer_test_helper)
47-
cinn_cc_test(test_group_scheduler SRCS group_scheduler_test.cc DEPS cinncore
48-
decomposer_test_helper)
4946
endif()
5047
cinn_cc_test(test_hlir_framework_tensor SRCS tensor_test.cc DEPS cinncore)
5148
cinn_cc_test(test_hlir_framework_scope SRCS scope_test.cc DEPS cinncore)

paddle/cinn/hlir/framework/op_lowering_impl.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
1818
#include "paddle/cinn/hlir/framework/compile_error.h"
1919
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
20-
#include "paddle/cinn/hlir/framework/group_scheduler.h"
2120
#include "paddle/cinn/hlir/framework/op_lowering_util.h"
2221
#include "paddle/cinn/hlir/op/external_api_registry.h"
22+
#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h"
2323
#include "paddle/cinn/ir/schedule/ir_schedule.h"
2424
#include "paddle/cinn/optim/transform_gpu_forloop.h"
2525
#include "paddle/cinn/runtime/flags.h"
@@ -470,8 +470,20 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
470470
const GroupPtr& group,
471471
const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
472472
if (FLAGS_cinn_new_group_scheduler) {
473-
GroupScheduler group_scheduler(&ir_sch, group, target_);
474-
group_scheduler();
473+
std::unordered_set<std::string> output_tensor_names;
474+
std::transform(
475+
group->output_nodes.begin(),
476+
group->output_nodes.end(),
477+
std::inserter(output_tensor_names, output_tensor_names.begin()),
478+
[](const Node* node) {
479+
NodeData* node_data =
480+
(*node->outlinks().begin())->sink()->safe_as<NodeData>();
481+
CHECK(node_data);
482+
return node_data->id();
483+
});
484+
ir::StaticShapeGroupScheduler group_scheduler(
485+
&ir_sch, output_tensor_names, target_);
486+
group_scheduler.Schedule();
475487
return ir_sch.GetModule().GetExprs().at(0);
476488
}
477489
// topological order.

paddle/cinn/hlir/framework/pir/op_lowering_util.cc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -810,18 +810,22 @@ void LoopAssignReduceWithoutLast(ir::IRSchedule& ir_sch, // NOLINT
810810
}
811811

812812
std::vector<int> GetReducerDimAttr(::pir::Operation* reduce_op) {
813-
VLOG(3) << "GetReducerDimAttr from " << reduce_op->name();
814-
auto* source_op = reduce_op->operand_source(/*dim_idx=*/1)
815-
.dyn_cast<::pir::OpResult>()
816-
.owner();
817-
CHECK(source_op->isa<paddle::dialect::FullIntArrayOp>());
813+
int rank = reduce_op->operand_source(0)
814+
.type()
815+
.dyn_cast<::pir::DenseTensorType>()
816+
.dims()
817+
.size();
818+
819+
auto attr = reduce_op->attributes().at("dim");
820+
auto attr_vec = attr.dyn_cast<::pir::ArrayAttribute>().AsVector();
821+
818822
std::vector<int> dim;
819-
auto dim_attr = source_op->attributes()
820-
.at("value")
821-
.dyn_cast<::pir::ArrayAttribute>()
822-
.AsVector();
823-
for (auto& attr : dim_attr) {
824-
dim.push_back(attr.dyn_cast<::pir::Int64Attribute>().data());
823+
for (auto vec_element : attr_vec) {
824+
auto axis = vec_element.dyn_cast<::pir::Int64Attribute>().data();
825+
if (axis < 0) {
826+
axis += rank;
827+
}
828+
dim.push_back(axis);
825829
}
826830
return dim;
827831
}

paddle/cinn/hlir/framework/pir/op_mapper.cc

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
// limitations under the License.
1414

1515
#include "paddle/cinn/hlir/framework/pir/op_mapper.h"
16+
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
17+
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
1618
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
1719
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
1820

@@ -25,15 +27,41 @@ namespace {
2527

2628
void AppendAttrForReduceOp(const ::pir::Operation& op,
2729
utils::AttributeMap& attrs) { // NOLINT
28-
auto* source_op =
29-
op.operand_source(/*dim_idx=*/1).dyn_cast<::pir::OpResult>().owner();
30-
CHECK(source_op->isa<paddle::dialect::FullIntArrayOp>());
31-
auto dim_val =
32-
paddle::dialect::GetInt64Vector(source_op->attributes().at("value"));
33-
std::vector<int> dim(dim_val.begin(), dim_val.end());
30+
auto attr = op.attributes().at("dim");
31+
auto attr_vec = attr.dyn_cast<::pir::ArrayAttribute>().AsVector();
32+
33+
std::vector<int> dim;
34+
for (auto vec_element : attr_vec) {
35+
dim.push_back(vec_element.dyn_cast<::pir::Int64Attribute>().data());
36+
}
37+
3438
attrs["dim"] = dim;
3539
}
3640

41+
void AppendAttrForBoadcastToOp(const ::pir::Operation& op,
42+
utils::AttributeMap& attrs) { // NOLINT
43+
auto axes_attr = op.attributes().at("broadcast_axes");
44+
auto attr_vec = axes_attr.dyn_cast<::pir::ArrayAttribute>().AsVector();
45+
46+
std::vector<int> axis;
47+
for (auto vec_element : attr_vec) {
48+
axis.push_back(vec_element.dyn_cast<::pir::Int64Attribute>().data());
49+
}
50+
51+
attrs["broadcast_axes"] = axis;
52+
53+
auto out_shape_attr = op.attributes().at("out_shape");
54+
auto out_shape_attr_vec =
55+
out_shape_attr.dyn_cast<::pir::ArrayAttribute>().AsVector();
56+
57+
std::vector<int> out_shape;
58+
for (auto vec_element : out_shape_attr_vec) {
59+
out_shape.push_back(vec_element.dyn_cast<::pir::Int64Attribute>().data());
60+
}
61+
62+
attrs["out_shape"] = out_shape;
63+
}
64+
3765
} // namespace
3866

3967
#define REGISTER_OPERAND_RULE(OP, args...) \
@@ -42,18 +70,17 @@ void AppendAttrForReduceOp(const ::pir::Operation& op,
4270
};
4371

4472
#define REGISTER_ATTR_RULE(OP, func) \
45-
attr_funcs_[paddle::dialect::OP::name()] = func;
73+
attr_funcs_[cinn::dialect::OP::name()] = func;
4674

4775
void OpMapper::RegisterMapRules() {
4876
// max(x, dim) -> reduce_max(x)
4977
REGISTER_OPERAND_RULE(MaxOp, 0);
5078
REGISTER_OPERAND_RULE(SumOp, 0);
5179
REGISTER_OPERAND_RULE(MinOp, 0);
5280
REGISTER_OPERAND_RULE(ProdOp, 0);
53-
REGISTER_ATTR_RULE(MaxOp, AppendAttrForReduceOp);
54-
REGISTER_ATTR_RULE(SumOp, AppendAttrForReduceOp);
55-
REGISTER_ATTR_RULE(MinOp, AppendAttrForReduceOp);
56-
REGISTER_ATTR_RULE(ProdOp, AppendAttrForReduceOp);
81+
REGISTER_ATTR_RULE(ReduceMaxOp, AppendAttrForReduceOp);
82+
REGISTER_ATTR_RULE(ReduceSumOp, AppendAttrForReduceOp);
83+
REGISTER_ATTR_RULE(BroadcastOp, AppendAttrForBoadcastToOp);
5784
}
5885

5986
} // namespace pir

0 commit comments

Comments
 (0)