Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 30756f2

Browse files
Merge branch 'develop' of https://github.com/PaddlePaddle/CINN into develop
2 parents 6d7f058 + 26bedf3 commit 30756f2

Some content is hidden

Large commits have some of their content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+3013
-173
lines changed

cinn/auto_schedule/auto_tuner.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "cinn/auto_schedule/measure/simple_builder.h"
2626
#include "cinn/auto_schedule/measure/simple_runner.h"
2727
#include "cinn/auto_schedule/task/task_creator.h"
28+
#include "cinn/auto_schedule/task/task_registry.h"
2829
#include "cinn/auto_schedule/task/tune_task.h"
2930
#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
3031
#include "cinn/common/context.h"
@@ -49,11 +50,21 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler*
4950
const auto& dtype_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
5051
const auto& shape_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
5152

52-
op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target_);
53+
op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target_);
54+
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
5355
for (TuneTask& task : tasks_) {
5456
task.SetOpLowerer(op_lowerer_.get());
5557
task.TaskGraphToUnoptLoweredFunc();
5658
task.SerializeToString(shape_dict, dtype_dict);
59+
60+
// Register the initial ModuleExpr corresponding to the task
61+
std::vector<ir::Expr> exprs(task.lowered_funcs.size());
62+
std::transform(
63+
task.lowered_funcs.begin(), task.lowered_funcs.end(), exprs.begin(), [&](const ir::LoweredFunc& func) {
64+
return func->body;
65+
});
66+
task_registry->Regist(task.serialized_key, ir::ModuleExpr(exprs));
67+
5768
VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key;
5869
}
5970

cinn/auto_schedule/task/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ gather_srcs(cinnapi_src SRCS task_creator.cc task_optimizer.cc)
99

1010
cc_test(test_task_creator SRCS task_creator_test.cc DEPS cinncore)
1111
cc_test(test_tune_task SRCS tune_task_test.cc DEPS cinncore)
12+
cc_test(test_task_registry SRCS task_registry_test.cc DEPS cinncore)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include <gflags/gflags.h>

#include <mutex>
#include <string>
#include <utility>

#include "cinn/ir/ir_schedule.h"
#include "cinn/optim/ir_copy.h"
#include "cinn/utils/registry.h"
25+
26+
namespace cinn {
27+
28+
namespace auto_schedule {
29+
30+
struct InitialTaskInfo {
31+
std::string task_key;
32+
ir::ModuleExpr module_expr;
33+
34+
InitialTaskInfo(const std::string& task_key, const ir::ModuleExpr& module_expr)
35+
: task_key(task_key), module_expr(module_expr) {}
36+
};
37+
38+
// Global task registry, used to save the initial ModuleExpr of each task.
39+
class InitialTaskRegistry : public Registry<InitialTaskInfo> {
40+
public:
41+
static InitialTaskRegistry* Global() {
42+
static InitialTaskRegistry x;
43+
return &x;
44+
}
45+
46+
// Get the initial ModuleExpr of a task.
47+
inline const InitialTaskInfo* Get(const std::string& task_key) {
48+
const InitialTaskInfo* task_info = Registry<InitialTaskInfo>::Find(task_key);
49+
CHECK(task_info) << "InitialTaskInfo [" << task_key << "] is not registered";
50+
return task_info;
51+
}
52+
53+
// Check if the task info with task_key exists;
54+
inline const bool Has(const std::string& task_key) { return nullptr != Registry<InitialTaskInfo>::Find(task_key); }
55+
56+
// Regist the initial ModuleExpr of a task into the map
57+
inline void Regist(const std::string& task_key, const ir::ModuleExpr& module_expr) {
58+
std::lock_guard<std::mutex> guard(registering_mutex);
59+
if (fmap_.count(task_key) == 0) {
60+
InitialTaskInfo* task_info = new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
61+
__REGISTER__(task_key, task_info);
62+
}
63+
}
64+
65+
private:
66+
InitialTaskRegistry() = default;
67+
CINN_DISALLOW_COPY_AND_ASSIGN(InitialTaskRegistry);
68+
69+
// Regist the initial ModuleExpr of a task.
70+
inline InitialTaskInfo* __REGISTER__(const std::string& task_key, InitialTaskInfo* task_info) {
71+
fmap_[task_key] = task_info;
72+
const_list_.push_back(task_info);
73+
entry_list_.push_back(task_info);
74+
return task_info;
75+
}
76+
};
77+
78+
} // namespace auto_schedule
79+
} // namespace cinn
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "cinn/auto_schedule/task/task_registry.h"
16+
17+
#include <glog/logging.h>
18+
#include <gtest/gtest.h>
19+
20+
#include <cstdlib>
21+
22+
#include "cinn/auto_schedule/task/task_creator.h"
23+
#include "cinn/auto_schedule/task/tune_task.h"
24+
#include "cinn/frontend/net_builder.h"
25+
#include "cinn/hlir/framework/graph.h"
26+
#include "cinn/hlir/framework/graph_compiler.h"
27+
#include "cinn/hlir/framework/op_lowering.h"
28+
#include "cinn/utils/string.h"
29+
#include "cinn/utils/type_defs.h"
30+
31+
DECLARE_bool(auto_schedule_use_cost_model);
32+
DECLARE_bool(cinn_ir_schedule);
33+
34+
namespace cinn {
35+
namespace auto_schedule {
36+
37+
// Builds op-level tune tasks from `graph`, lowers each one and fills in its
// serialized key.
// NOTE(review): tasks keep a raw pointer to the local OpLowerer via
// SetOpLowerer, which dangles after this function returns — callers must not
// re-lower the returned tasks. TODO confirm against TuneTask's contract.
std::vector<TuneTask> CreateTasks(hlir::framework::Graph* graph, const common::Target& target) {
  TaskCreator creator;
  auto tasks = creator.CreateTuneTaskOpLevel(graph);

  const auto& dtypes = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
  const auto& shapes = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");

  auto lowerer = std::make_unique<hlir::framework::OpLowerer>(dtypes, shapes, target);
  for (auto& task : tasks) {
    task.SetOpLowerer(lowerer.get());
    task.TaskGraphToUnoptLoweredFunc();
    task.SerializeToString(shapes, dtypes);
    VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key;
  }

  return tasks;
}
56+
57+
// Builds a small graph that adds input "B" (broadcast along axis 1) to input
// "A", for the given target.
std::shared_ptr<hlir::framework::Graph> CreateAddProgram(const common::Target& target) {
  frontend::NetBuilder builder("test");

  auto lhs = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
  auto rhs = builder.CreateInput(Float(32), {64}, "B");
  builder.Add(lhs, rhs, 1);

  return std::make_shared<hlir::framework::Graph>(builder.Build(), target);
}
66+
67+
// Registers the initial ModuleExpr of every task, then verifies that lookups
// return exprs that print identically, and that Has() reports presence/absence.
TEST(TestTaskRegistry, basic) {
  FLAGS_auto_schedule_use_cost_model = true;
  FLAGS_cinn_ir_schedule             = true;

#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif
  std::shared_ptr<hlir::framework::Graph> graph = CreateAddProgram(target);
  std::vector<TuneTask> tasks                   = CreateTasks(graph.get(), target);

  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();

  // Register the initial ModuleExpr corresponding to each task.
  std::vector<ir::ModuleExpr> module_exprs;
  for (const TuneTask& task : tasks) {
    module_exprs.emplace_back(task.GetLoweredFuncBodyExprs());
    task_registry->Regist(task.serialized_key, module_exprs.back());
  }

  // size_t indices: avoids the signed/unsigned comparison with .size().
  for (size_t i = 0; i < tasks.size(); ++i) {
    const std::string& key = tasks[i].serialized_key;  // bind, don't copy
    VLOG(3) << "serialized_key = " << key;
    ir::ModuleExpr new_expr = task_registry->Get(key)->module_expr;

    ASSERT_EQ(new_expr.GetExprs().size(), module_exprs[i].GetExprs().size());
    for (size_t j = 0; j < new_expr.GetExprs().size(); ++j) {
      VLOG(3) << "expr " << j << " of task " << key << " : " << new_expr.GetExprs().at(j);
      ASSERT_EQ(utils::GetStreamCnt(new_expr.GetExprs().at(j)), utils::GetStreamCnt(module_exprs[i].GetExprs().at(j)));
    }
  }

  // ASSERT_TRUE/ASSERT_FALSE express boolean intent directly.
  ASSERT_TRUE(task_registry->Has(tasks[0].serialized_key));
  ASSERT_FALSE(task_registry->Has("not_exist"));
}
105+
106+
} // namespace auto_schedule
107+
} // namespace cinn

cinn/frontend/computation_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Program CreateTestProgram() {
4848
auto i = builder.Max(e, h);
4949
auto j = builder.Min(e, h);
5050
auto k = builder.Multiply(i, j);
51-
auto l = builder.ConstScalar<bool>(1, "condition");
51+
auto l = builder.Constant<bool>(1, "condition");
5252
auto m = builder.BroadcastTo(l, {B, M, N}, {0});
5353
auto n = builder.Select(m, j, k);
5454
auto o = builder.ReduceSum(n, {0, 1, 2});

cinn/frontend/decomposer/batch_norm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct BatchNormHelper {
5454

5555
template <typename T>
5656
Variable GetTensorFromScalar(T value, std::string name, const std::vector<int>& shape) {
57-
// return builder->BroadcastTo(builder->ConstScalar<T>(value, common::UniqName(name)), shape, {0});
57+
// return builder->BroadcastTo(builder->Constant<T>(value, common::UniqName(name)), shape, {0});
5858
return builder->FillConstant<T>(shape, value, common::UniqName(name));
5959
}
6060

cinn/frontend/net_builder.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ std::vector<Variable> NetBuilder::Split(const Variable& operand, const std::vect
220220
}
221221

222222
Variable NetBuilder::Concat(const std::vector<Variable>& input_vars, int axis) {
223+
CHECK(!input_vars.empty()) << "The inputs of concat op should not be empty! Please check.";
224+
if (input_vars.size() == 1UL) {
225+
return Identity(input_vars.front());
226+
}
223227
return CustomInstr("concat", input_vars, {{"axis", axis}}).front();
224228
}
225229

@@ -446,6 +450,24 @@ Variable NetBuilder::Sort(const Variable& operand, const int& axis, const bool&
446450
return instr.GetOutput(0);
447451
}
448452

453+
// Appends an "argmax" instruction over `x` along `axis` (optionally keeping
// the reduced dimension) and returns its single output.
Variable NetBuilder::Argmax(const Variable& x, const int& axis, const bool& keep_dim) {
  Instruction argmax_instr("argmax", {x});
  argmax_instr.SetAttr("axis", axis);
  argmax_instr.SetAttr("keep_dim", keep_dim);
  InferShape(argmax_instr);
  AppendInstruction(argmax_instr);
  return argmax_instr.GetOutput(0);
}
461+
462+
// Appends an "argmin" instruction over `x` along `axis` (optionally keeping
// the reduced dimension) and returns its single output.
Variable NetBuilder::Argmin(const Variable& x, const int& axis, const bool& keep_dim) {
  Instruction argmin_instr("argmin", {x});
  argmin_instr.SetAttr("axis", axis);
  argmin_instr.SetAttr("keep_dim", keep_dim);
  InferShape(argmin_instr);
  AppendInstruction(argmin_instr);
  return argmin_instr.GetOutput(0);
}
}
470+
449471
Variable NetBuilder::Conv2d(const Variable& a,
450472
const Variable& b,
451473
const std::vector<int>& strides,
@@ -502,6 +524,10 @@ Variable NetBuilder::Pool2d(const Variable& a,
502524
.front();
503525
}
504526

527+
// Repeats elements of `x` `repeats` times along `axis` via the "repeat" op.
Variable NetBuilder::Repeat(const Variable& x, int repeats, int axis) {
  auto outputs = CustomInstr("repeat", {x}, {{"repeats", repeats}, {"axis", axis}});
  return outputs.front();
}
530+
505531
std::vector<Variable> NetBuilder::BatchNorm(const Variable& a,
506532
const Variable& scale,
507533
const Variable& bias,
@@ -557,6 +583,14 @@ Variable NetBuilder::Arange(const float start, const float stop, const float ste
557583
return CustomInstr("arange", {}, {{"start", start}, {"stop", stop}, {"step", step}, {"dtype", dtype}}).front();
558584
}
559585

586+
// Reverses `operand` along the given axes by appending a "flip" instruction.
Variable NetBuilder::Flip(const Variable& operand, const std::vector<int>& axes) {
  Instruction flip_instr("flip", {operand});
  flip_instr.SetAttr("axes", axes);
  InferShape(flip_instr);
  AppendInstruction(flip_instr);
  return flip_instr.GetOutput(0);
}
593+
560594
// conv2d grad, output(grad_x, grad_w)
561595
std::vector<Variable> NetBuilder::Conv2dGrad(const Variable& dy,
562596
const Variable& x,

0 commit comments

Comments
 (0)