
Commit bbc913b

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into Value-Shape-Unittest
2 parents: 2d94bff + c02a3df


51 files changed: +854 -533 lines


paddle/cinn/ir/schedule/impl/base.cc (+31 -2)
@@ -286,15 +286,44 @@ Expr DyScheduleImpl::SampleCategorical(
     utils::LinearRandomEngine::StateType* rand_seed,
     const std::vector<int>& candidates,
     const std::vector<float>& probs) {
-  CINN_NOT_IMPLEMENTED;
+  // check two sizes
+  CHECK_EQ(candidates.size(), probs.size())
+      << "candidates and probs must have same size.";
+  int seed_idx = utils::SampleDiscreteFromDistribution(probs, rand_seed);
+  auto result = candidates[seed_idx];
+  Expr result_expr(result);
+  return result_expr;
 }
 
 std::vector<Expr> DyScheduleImpl::SamplePerfectTile(
     utils::LinearRandomEngine::StateType* rand_seed,
     const Expr& loop,
     int n,
     int max_innermost_factor) {
-  CINN_NOT_IMPLEMENTED;
+  CHECK(loop.As<ir::For>())
+      << "Expr param of SamplePerfectTile should be a For loop";
+  CHECK_GE(n, 2) << "The number of tile factors should be at least 2";
+  CHECK_GE(max_innermost_factor, 1)
+      << "The max innermost factor should be at least 1";
+  CHECK(cinn::common::is_zero(loop.As<ir::For>()->min))
+      << "The For loop should start from 0";
+  int loop_extent = GetLoopExtent(loop);
+  std::vector<int> innermost_factors;
+  for (int i = max_innermost_factor; i >= 1; --i) {
+    if (loop_extent % i == 0) {
+      innermost_factors.push_back(i);
+    }
+  }
+  CHECK(!innermost_factors.empty()) << "No innermost factor found";
+  int innermost_factor = innermost_factors[utils::SampleUniformInt(
+      0, innermost_factors.size(), rand_seed)];
+  auto result = SampleTile(rand_seed, n - 1, loop_extent / innermost_factor);
+  std::vector<Expr> result_expr;
+  for (auto& factor : result) {
+    result_expr.push_back(Expr(factor));
+  }
+  result_expr.push_back(Expr(innermost_factor));
+  return result_expr;
 }
 
 Expr DyScheduleImpl::AddUnitLoop(const Expr& block) const {
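
Note on the new SamplePerfectTile: it first draws the innermost tile factor uniformly from the divisors of the loop extent that do not exceed max_innermost_factor, then lets SampleTile split the remaining extent into the other n-1 factors. Below is a minimal standalone sketch of that idea using only the C++ standard library; sample_perfect_tile and its remainder-splitting loop are illustrative stand-ins, not CINN's actual SampleTile logic.

#include <iostream>
#include <random>
#include <vector>

// Illustrative re-implementation: pick an innermost factor dividing the loop
// extent, then split the remaining extent into n-1 random factors, mirroring
// the structure of SamplePerfectTile above.
std::vector<int> sample_perfect_tile(std::mt19937& rng,
                                     int loop_extent,
                                     int n,
                                     int max_innermost_factor) {
  std::vector<int> innermost_candidates;
  for (int i = max_innermost_factor; i >= 1; --i) {
    if (loop_extent % i == 0) innermost_candidates.push_back(i);
  }
  std::uniform_int_distribution<size_t> pick(0, innermost_candidates.size() - 1);
  int innermost = innermost_candidates[pick(rng)];
  int rest = loop_extent / innermost;
  std::vector<int> factors;
  // Split `rest` into n-1 factors by repeatedly drawing random divisors.
  for (int k = 0; k < n - 2; ++k) {
    std::vector<int> divisors;
    for (int d = 1; d <= rest; ++d)
      if (rest % d == 0) divisors.push_back(d);
    std::uniform_int_distribution<size_t> pd(0, divisors.size() - 1);
    int f = divisors[pd(rng)];
    factors.push_back(f);
    rest /= f;
  }
  factors.push_back(rest);       // last of the n-1 outer factors
  factors.push_back(innermost);  // innermost factor goes last
  return factors;
}

int main() {
  std::mt19937 rng(42);
  for (int f : sample_perfect_tile(rng, 1024, 3, 16)) std::cout << f << ' ';
  std::cout << '\n';  // e.g. "64 4 4": the product equals the loop extent
}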

paddle/fluid/framework/ir/transfer_layout_elim_pass.cc (+13 -1)
@@ -120,7 +120,10 @@ void TransferLayoutElimPass::PutTranferlayoutAfterOp(
   auto *new_transfer_layout_node =
       graph->CreateOpNode(&new_transfer_layout_desc);
 
-  for (auto other_op : var2->outputs) {
+  // must use a tmp variable var_out, because var2->outputs will be changed in
+  // loop.
+  auto var_out = var2->outputs;
+  for (auto other_op : var_out) {
     IR_NODE_UNLINK(var2, other_op);
     other_op->Op()->RenameInput(var2->Name(), var2_dot_name);
     IR_NODE_LINK_TO(var2_dot, other_op);
@@ -244,6 +247,9 @@ void TransferLayoutElimPass::ApplyImpl(ir::Graph *graph) const {
     return "";
   };
 
+  int move_down_count = 0;
+  int elim_count = 0;
+
   while (true) {
     auto op_node_sorted = framework::ir::TopologyVarientSort(
         *graph, static_cast<framework::ir::SortKind>(0));
@@ -309,6 +315,7 @@ void TransferLayoutElimPass::ApplyImpl(ir::Graph *graph) const {
           }
           op_node->Op()->SetAttr("axis", modify_axis);
           modify = true;
+          move_down_count++;
           break;
         }
         if (is_pool_like_op) {
@@ -318,21 +325,26 @@ void TransferLayoutElimPass::ApplyImpl(ir::Graph *graph) const {
           transfer_format(
               op_node->Op()->GetAttrIfExists<std::string>("data_format")));
           modify = true;
+          move_down_count++;
           break;
         }
         if (is_act_like_op) {
           PutTranferlayoutAfterOp(op_node, graph, nullptr);
           modify = true;
+          move_down_count++;
           break;
         }
         if (is_elim_op) {
           ElimTwoTranferlayout(op_node, graph, &modify);
+          elim_count++;
           break;
         }
       }
     }
     if (!modify) break;
   }
+  LOG(INFO) << "move down " << move_down_count << " transfer_layout";
+  LOG(INFO) << "eliminate " << elim_count << " pair of transfer_layout";
 }
 
 }  // namespace ir
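
Note on the var_out copy in PutTranferlayoutAfterOp: ranging directly over var2->outputs while IR_NODE_UNLINK / IR_NODE_LINK_TO mutate that list would invalidate the iteration, so the loop snapshots it first. A minimal sketch of the hazard and the fix, in plain C++ (container and values are illustrative):

#include <iostream>
#include <vector>

int main() {
  std::vector<int> outputs = {1, 2, 3};

  // Unsafe: growing `outputs` while range-iterating it may reallocate the
  // buffer and invalidate the hidden iterators -> undefined behavior.
  // for (int v : outputs) outputs.push_back(v * 10);

  // Safe, as in the pass above: snapshot first, then mutate the original
  // freely while iterating the snapshot.
  std::vector<int> snapshot = outputs;
  for (int v : snapshot) outputs.push_back(v * 10);

  for (int v : outputs) std::cout << v << ' ';  // 1 2 3 10 20 30
  std::cout << '\n';
}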

paddle/fluid/imperative/amp_auto_cast.cc (+9 -1)
@@ -17,6 +17,7 @@
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/eager/api/utils/global_utils.h"
 #include "paddle/fluid/eager/eager_tensor.h"
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
@@ -66,7 +67,14 @@ OpSupportedInfos(const std::string& place,
   std::unordered_set<std::string> all_ops;
   const auto& op_info = framework::OpInfoMap::Instance().map();
   for (const auto& item : op_info) {
-    all_ops.emplace(item.first);
+    const std::string op_type = item.first;
+    // The dtype of custom op is RAW(runtime decided type), skip it since we
+    // cannot determine its supported dtype here.
+    if (egr::Controller::Instance().GetOpMetaInfoMap().count(op_type)) {
+      VLOG(6) << "Skip custom op " << op_type << " for checking amp supported!";
+      continue;
+    }
+    all_ops.emplace(op_type);
   }
 
   std::unordered_set<std::string> supported_ops;
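
Note on the change above: it filters custom ops out of the AMP support scan by consulting the custom-op registry before inserting into all_ops. A minimal sketch of the same skip-on-registry pattern in plain C++ (the maps and their contents are made up for illustration):

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

int main() {
  // Stand-in for framework::OpInfoMap: every registered op.
  std::unordered_map<std::string, int> op_info = {
      {"conv2d", 0}, {"relu", 0}, {"my_custom_op", 0}};
  // Stand-in for the custom-op meta-info map: ops whose dtype is decided at
  // runtime (RAW) and therefore cannot be classified for AMP here.
  std::unordered_set<std::string> custom_ops = {"my_custom_op"};

  std::unordered_set<std::string> all_ops;
  for (const auto& item : op_info) {
    const std::string& op_type = item.first;
    if (custom_ops.count(op_type)) {
      continue;  // skip custom ops, mirroring the VLOG(6) branch above
    }
    all_ops.emplace(op_type);
  }
  std::cout << all_ops.size() << '\n';  // 2: conv2d and relu
}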

paddle/fluid/inference/tensorrt/op_teller.cc (+6 -1, mode 100755 -> 100644)
@@ -2302,15 +2302,20 @@ struct SimpleOpTypeSetTeller : public Teller {
       if (!with_dynamic_shape) {
         if (tile_inputs.find("repeat_times_tensor") != tile_inputs.end()) {
           if (!desc.Input("repeat_times_tensor").empty()) {
+            VLOG(3) << "Tile op: repeat_times_tensor is not empty.";
             return false;
           }
         }
         if (tile_inputs.find("RepeatTimes") != tile_inputs.end()) {
           if (!desc.Input("RepeatTimes").empty()) {
+            VLOG(3) << "Tile op: RepeatTimes is not empty.";
             return false;
           }
         }
-        if (!desc.HasAttr("repeat_times")) return false;
+        if (!desc.HasAttr("repeat_times")) {
+          VLOG(3) << "Tile op:`repeat_times` is not set.";
+          return false;
+        }
       }
     }
 #endif

paddle/fluid/pir/dialect/op_generator/op_build_gen.py (+10 -2)
@@ -395,7 +395,11 @@ def GenBuildOutputs(
   {name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
   {name}.SetFromTensor(true);
 }} else if ({name}_.type().isa<paddle::dialect::DenseTensorType>()) {{
-  size_t {name}_size = common::product({name}_.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
+  common::DDim {name}_dim = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
+  size_t {name}_size = common::product({name}_dim);
+  if (common::contain_unknown_dim({name}_dim)) {{
+    {name}_size = 1;
+  }}
   {name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
   {name}.SetFromTensor(true);
 }} else {{
@@ -412,7 +416,11 @@ def GenBuildOutputs(
   size_t {name}_size = {name}_.type().dyn_cast<pir::VectorType>().size();
   {name} = std::vector<int64_t>({name}_size, -1);
 }} else if ({name}_.type().isa<paddle::dialect::DenseTensorType>()) {{
-  size_t {name}_size = common::product({name}_.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
+  common::DDim {name}_dim = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
+  size_t {name}_size = common::product({name}_dim);
+  if (common::contain_unknown_dim({name}_dim)) {{
+    {name}_size = 1;
+  }}
   {name} = std::vector<int64_t>({name}_size, -1);
 }} else {{
   PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType"));
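
Note on the contain_unknown_dim guard in the generated code: taking the product of a shape that still contains an unknown dimension (encoded as -1) yields a negative value, which wraps around when stored in a size_t; the guard falls back to a single placeholder entry instead. A minimal sketch of the failure mode, where product and contain_unknown_dim are illustrative stand-ins for the common:: helpers:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-ins for common::product / common::contain_unknown_dim.
int64_t product(const std::vector<int64_t>& dims) {
  int64_t p = 1;
  for (int64_t d : dims) p *= d;
  return p;
}
bool contain_unknown_dim(const std::vector<int64_t>& dims) {
  for (int64_t d : dims)
    if (d < 0) return true;
  return false;
}

int main() {
  std::vector<int64_t> dims = {-1, 3};  // first dim not yet inferred
  size_t size = product(dims);          // -3 wraps to a huge unsigned value
  if (contain_unknown_dim(dims)) {
    size = 1;  // fall back to one placeholder entry, as in the guard above
  }
  std::cout << size << '\n';  // 1
}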

paddle/phi/infermeta/spmd_rules/concat.cc (+1 -1)
@@ -102,7 +102,7 @@ SpmdInfo ConcatInferSpmdReverse(const std::vector<DistMetaTensor>& x,
                                 const DistMetaTensor& output,
                                 int axis) {
   auto out_dist_attr = output.dist_attr();
-  out_dist_attr = UnShardTensorDim(out_dist_attr, axis);
+  out_dist_attr = UnShardTensorDims(out_dist_attr, {axis});
   auto n_inputs = x.size();
   TensorDistAttr input_attr = CopyTensorDistAttrForOutput(out_dist_attr);
   const auto& input_dim_mapping = out_dist_attr.dims_mapping();

paddle/phi/infermeta/spmd_rules/rules.h (+4)
@@ -102,6 +102,10 @@ PD_REGISTER_SPMD_RULE(
     unsqueeze,
     PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd),
     PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse));
+PD_REGISTER_SPMD_RULE(
+    unsqueeze2,
+    PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd),
+    PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse));
 
 // elementwise unary rule
 PD_REGISTER_SPMD_RULE(

paddle/phi/infermeta/spmd_rules/unsqueeze.cc (+2 -1)
@@ -224,7 +224,8 @@ SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,
           << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
   VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";
 
-  return {{x_dist_attr}, {out_dist_attr_dst}};
+  return {{x_dist_attr},
+          {out_dist_attr_dst, CreateUnsqueezeXshape(x_dist_attr)}};
 }
 
 SpmdInfo UnsqueezeGradInferSpmd(const DistMetaTensor& xshape,

paddle/pir/dialect/shape/utils/dim_expr.cc (+74)
@@ -34,4 +34,78 @@ DimExpr DimExpr::operator/(const DimExpr& other) const {
   return Mul<DimExpr>(std::vector{*this, reciprocal});
 }
 
+namespace {
+
+bool DimExprEqual(std::int64_t lhs, std::int64_t rhs) { return lhs == rhs; }
+
+bool DimExprEqual(const std::string& lhs, const std::string& rhs) {
+  return lhs == rhs;
+}
+
+bool DimExprEqual(const Negative<DimExpr>& lhs, const Negative<DimExpr>& rhs) {
+  return lhs->data == rhs->data;
+}
+
+bool DimExprEqual(const Reciprocal<DimExpr>& lhs,
+                  const Reciprocal<DimExpr>& rhs) {
+  return lhs->data == rhs->data;
+}
+
+template <template <typename> class Op>
+bool DimExprEqual(const Op<DimExpr>& lhs, const Op<DimExpr>& rhs) {
+  if (lhs->size() != rhs->size()) {
+    return false;
+  }
+  for (std::size_t i = 0; i < lhs->size(); ++i) {
+    if (lhs->at(i) != rhs->at(i)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool DimExprEqual(const Add<DimExpr>& lhs, const Add<DimExpr>& rhs) {
+  return DimExprEqual<Add>(lhs, rhs);
+}
+
+bool DimExprEqual(const Mul<DimExpr>& lhs, const Mul<DimExpr>& rhs) {
+  return DimExprEqual<Mul>(lhs, rhs);
+}
+
+bool DimExprEqual(const Max<DimExpr>& lhs, const Max<DimExpr>& rhs) {
+  return DimExprEqual<Max>(lhs, rhs);
+}
+
+bool DimExprEqual(const Min<DimExpr>& lhs, const Min<DimExpr>& rhs) {
+  return DimExprEqual<Min>(lhs, rhs);
+}
+
+bool DimExprEqual(const Broadcast<DimExpr>& lhs,
+                  const Broadcast<DimExpr>& rhs) {
+  return DimExprEqual<Broadcast>(lhs, rhs);
+}
+
+}  // namespace
+
+bool DimExpr::operator==(const DimExpr& other) const {
+  if (this == &other) {
+    return true;
+  }
+  return std::visit(
+      [](const auto& lhs, const auto& rhs) {
+        if constexpr (std::is_same_v<std::decay_t<decltype(lhs)>,
+                                     std::decay_t<decltype(rhs)>>) {
+          return DimExprEqual(lhs, rhs);
+        } else {
+          return false;
+        }
+      },
+      this->variant(),
+      other.variant());
+}
+
+bool DimExpr::operator!=(const DimExpr& other) const {
+  return !(*this == other);
+}
+
 }  // namespace symbol
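
Note on the new operator==: it double-dispatches with std::visit, and equality can only hold when both operands carry the same variant alternative, which is checked at compile time via if constexpr. A self-contained sketch of the same pattern on a toy two-alternative variant (types and names are illustrative, not DimExpr's real alternatives):

#include <cstdint>
#include <iostream>
#include <string>
#include <type_traits>
#include <variant>

// Toy stand-in for DimExpr's variant base: a dim is either a known constant
// or a named symbol.
using Dim = std::variant<std::int64_t, std::string>;

bool DimEqual(std::int64_t lhs, std::int64_t rhs) { return lhs == rhs; }
bool DimEqual(const std::string& lhs, const std::string& rhs) {
  return lhs == rhs;
}

bool Equal(const Dim& a, const Dim& b) {
  return std::visit(
      [](const auto& lhs, const auto& rhs) {
        // Only compare when both sides hold the same alternative; mixed
        // alternatives (e.g. 4 vs "S0") are never equal.
        if constexpr (std::is_same_v<std::decay_t<decltype(lhs)>,
                                     std::decay_t<decltype(rhs)>>) {
          return DimEqual(lhs, rhs);
        } else {
          return false;
        }
      },
      a, b);
}

int main() {
  std::cout << Equal(Dim{std::int64_t{4}}, Dim{std::int64_t{4}})      // 1
            << Equal(Dim{std::string{"S0"}}, Dim{std::string{"S0"}})  // 1
            << Equal(Dim{std::int64_t{4}}, Dim{std::string{"S0"}})    // 0
            << '\n';
}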

paddle/pir/dialect/shape/utils/dim_expr.h (+6)
@@ -128,10 +128,16 @@ class DimExpr : public DimExprBase {
     return std::get<T>(*this);
   }
 
+  const DimExprBase& variant() const {
+    return static_cast<const DimExprBase&>(*this);
+  }
+
   DimExpr operator+(const DimExpr& other) const;
   DimExpr operator-(const DimExpr& other) const;
   DimExpr operator*(const DimExpr& other) const;
   DimExpr operator/(const DimExpr& other) const;
+  bool operator==(const DimExpr& other) const;
+  bool operator!=(const DimExpr& other) const;
 };
 
 // DimExprConstraint = Equal DimExpr

paddle/pir/dialect/shape/utils/dim_expr_builder.cc (+4 -2)
@@ -17,6 +17,8 @@
 namespace symbol {
 
 using BroadcastDimExpr = Broadcast<DimExpr>;
+using MinDimExpr = Min<DimExpr>;
+using MaxDimExpr = Max<DimExpr>;
 
 DimExpr DimExprBuilder::ConstSize(std::int64_t dim) { SYMBOL_NOT_IMPLEMENTED; }
 
@@ -41,11 +43,11 @@ DimExpr DimExprBuilder::Div(const DimExpr& lhs, const DimExpr& rhs) {
 }
 
 DimExpr DimExprBuilder::Max(const DimExpr& lhs, const DimExpr& rhs) {
-  SYMBOL_NOT_IMPLEMENTED;
+  return MaxDimExpr(std::vector{lhs, rhs});
 }
 
 DimExpr DimExprBuilder::Min(const DimExpr& lhs, const DimExpr& rhs) {
-  SYMBOL_NOT_IMPLEMENTED;
+  return MinDimExpr(std::vector{lhs, rhs});
 }
 
 DimExpr DimExprBuilder::Broadcast(const DimExpr& lhs, const DimExpr& rhs) {

pyproject.toml (+33 -2)
@@ -108,8 +108,6 @@ combine-as-imports = true
 known-first-party = ["paddle"]
 
 [tool.ruff.lint.per-file-ignores]
-# Ignore unused imports in __init__.py
-"__init__.py" = ["F401", "I001"]
 # These files need tabs for testing.
 "test/dygraph_to_static/test_legacy_error.py" = ["E101", "W191"]
 # Ignore compare with True in sot unittest
@@ -124,3 +122,36 @@ known-first-party = ["paddle"]
 "test/dygraph_to_static/test_loop.py" = ["C416", "F821"]
 # Ignore unnecessary lambda in dy2st unittest test_lambda
 "test/dygraph_to_static/test_lambda.py" = ["PLC3002"]
+
+# temp ignore unused imports in all distributed files
+"python/paddle/distributed/transpiler/__init__.py" = ["F401"]
+"python/paddle/incubate/distributed/fleet/parameter_server/distribute_transpiler/__init__.py" = ["F401", "I001"]
+"python/paddle/distributed/fleet/runtime/__init__.py" = ["F401"]
+"python/paddle/distributed/transpiler/details/__init__.py" = ["F401", "I001"]
+"python/paddle/incubate/distributed/models/moe/gate/__init__.py" = ["F401", "I001"]
+"python/paddle/incubate/distributed/models/moe/__init__.py" = ["F401", "I001"]
+"python/paddle/incubate/distributed/utils/io/__init__.py" = ["F401", "I001"]
+"python/paddle/distributed/fleet/elastic/__init__.py" = ["F401", "I001"]
+
+# temp ignore isort
+"python/paddle/__init__.py" = ["I001"]
+"python/paddle/amp/__init__.py" = ["I001"]
+"python/paddle/audio/__init__.py" = ["I001"]
+"python/paddle/audio/features/__init__.py" = ["I001"]
+"python/paddle/audio/functional/__init__.py" = ["I001"]
+"python/paddle/base/__init__.py" = ["I001"]
+"python/paddle/distributed/__init__.py" = ["I001"]
+"python/paddle/distributed/communication/stream/__init__.py" = ["I001"]
+"python/paddle/device/cuda/__init__.py" = ["I001"]
+"python/paddle/distributed/launch/context/__init__.py" = ["I001"]
+"python/paddle/distributed/launch/controllers/__init__.py" = ["I001"]
+"python/paddle/distributed/passes/__init__.py" = ["I001"]
+"python/paddle/distributed/rpc/__init__.py" = ["I001"]
+"python/paddle/distribution/__init__.py" = ["I001"]
+"python/paddle/framework/__init__.py" = ["I001"]
+"python/paddle/incubate/distributed/fleet/__init__.py" = ["I001"]
+"python/paddle/incubate/distributed/fleet/parameter_server/pslib/__init__.py" = ["I001"]
+"python/paddle/io/dataloader/__init__.py" = ["I001"]
+"python/paddle/jit/__init__.py" = ["I001"]
+"python/paddle/pir/__init__.py" = ["I001"]
+"python/paddle/tensor/__init__.py" = ["I001"]

python/cinn/auto_schedule/cost_model/__init__.py (+1 -2)
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .cost_model import CostModel
-from .cost_model import CostModelType
+from .cost_model import CostModel, CostModelType
 from .xgb_cost_model import XgbCostModel
 
 __all__ = [
