Skip to content

Commit ace727c

Browse files
Zhang TingSuperjomn
Zhang Ting
andauthored
add decomposer registry (PaddlePaddle#458)
* add decomposer registry Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
1 parent 96b621b commit ace727c

File tree

8 files changed

+126
-0
lines changed

8 files changed

+126
-0
lines changed

cinn/common/target.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <glog/logging.h>
44

5+
#include <sstream>
6+
57
#include "cinn/runtime/cinn_runtime.h"
68

79
namespace cinn {
@@ -49,6 +51,12 @@ int Target::get_target_bits() const {
4951
return -1;
5052
}
5153

54+
std::string Target::arch_str() const {
55+
std::ostringstream oss;
56+
oss << arch;
57+
return oss.str();
58+
}
59+
5260
std::ostream &operator<<(std::ostream &os, const Target &target) {
5361
os << "Target<";
5462
switch (target.os) {

cinn/common/target.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <ostream>
4+
#include <string>
45
#include <vector>
56

67
namespace cinn {
@@ -69,6 +70,8 @@ struct Target {
6970

7071
std::vector<Lib> get_target_libs() const;
7172

73+
std::string arch_str() const;
74+
7275
bool operator==(const Target& other) const;
7376
bool operator!=(const Target& other) const { return !(*this == other); }
7477
friend std::ostream& operator<<(std::ostream& os, const Target& target);

cinn/frontend/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,8 @@ else()
2626
endif()
2727

2828
cc_test(test_net_builder SRCS net_builder_test.cc DEPS cinncore)
29+
cc_test(test_decomposer_registry
30+
SRCS decomposer_registry_test.cc DEPS cinncore)
2931

3032
add_subdirectory(paddle)
33+
add_subdirectory(decomposer)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
core_gather_headers()
2+
3+
core_gather_srcs(SRCS
4+
activation.cc
5+
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "cinn/frontend/decomposer_registry.h"
2+
3+
namespace cinn {
4+
namespace frontend {
5+
namespace decomposer {
6+
7+
void relu(const Instruction& instr, const DecomposerContext& context) { LOG(FATAL) << "not implemented"; }
8+
9+
} // namespace decomposer
10+
} // namespace frontend
11+
} // namespace cinn
12+
13+
CINN_REGISTER_HELPER(activation) {
14+
CINN_DECOMPOSER_REGISTER(relu, ::cinn::common::DefaultHostTarget()).set_body(cinn::frontend::decomposer::relu);
15+
16+
return true;
17+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include "cinn/common/macros.h"
4+
5+
CINN_USE_REGISTER(activation)

cinn/frontend/decomposer_registry.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#pragma once
2+
3+
#include <functional>
4+
#include <string>
5+
#include <unordered_map>
6+
7+
#include "cinn/common/target.h"
8+
#include "cinn/frontend/syntax.h"
9+
10+
namespace cinn {
11+
namespace frontend {
12+
13+
class Decomposer;
14+
15+
class DecomposerContext {
16+
public:
17+
explicit DecomposerContext(Program* prog) : program(prog) {}
18+
19+
Program* program{nullptr};
20+
};
21+
22+
class InstrDecomposerRegistry : public Registry<Decomposer> {
23+
public:
24+
static InstrDecomposerRegistry* Global() {
25+
static InstrDecomposerRegistry x;
26+
return &x;
27+
}
28+
29+
inline const Decomposer* Get(const std::string& op_name, const common::Target& target) {
30+
const Decomposer* decomposer = Find(op_name, target);
31+
CHECK(decomposer) << "Decomposer for [" << op_name << ", " << target << "] is not registered";
32+
return decomposer;
33+
}
34+
35+
inline const Decomposer* Find(const std::string& name, const common::Target& target) {
36+
return Registry<Decomposer>::Find(name + "_" + target.arch_str());
37+
}
38+
39+
inline Decomposer& __REGISTER__(const std::string& name, const common::Target& target) {
40+
return Registry<Decomposer>::__REGISTER__(name + "_" + target.arch_str());
41+
}
42+
43+
private:
44+
InstrDecomposerRegistry() = default;
45+
CINN_DISALLOW_COPY_AND_ASSIGN(InstrDecomposerRegistry);
46+
};
47+
48+
class Decomposer {
49+
public:
50+
using DecomposerKernel = std::function<void(const Instruction& instr, const DecomposerContext&)>;
51+
52+
Decomposer& set_body(const DecomposerKernel& kernel) {
53+
kernel_ = kernel;
54+
return *this;
55+
}
56+
57+
void Run(const Instruction& instr, const DecomposerContext& context) { kernel_(instr, context); }
58+
59+
std::string name;
60+
61+
private:
62+
DecomposerKernel kernel_;
63+
};
64+
65+
#define CINN_DECOMPOSER_REGISTER(name, target) \
66+
static ::cinn::frontend::Decomposer& CINN_STR_CONCAT(__make_Decomposer_name, __COUNTER__) = \
67+
::cinn::frontend::InstrDecomposerRegistry::Global()->__REGISTER__(#name, target)
68+
69+
} // namespace frontend
70+
} // namespace cinn
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include "cinn/frontend/decomposer_registry.h"
2+
3+
#include <gtest/gtest.h>
4+
5+
#include "cinn/frontend/decomposer/use_decomposer.h"
6+
7+
namespace cinn::frontend {
8+
9+
TEST(InstrDecomposerRegistry, basic) {
10+
common::Target target = common::DefaultHostTarget();
11+
ASSERT_EQ(InstrDecomposerRegistry::Global()->Find("conv", target), nullptr);
12+
ASSERT_NE(InstrDecomposerRegistry::Global()->Find("relu", target), nullptr);
13+
}
14+
15+
} // namespace cinn::frontend

0 commit comments

Comments
 (0)