Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit d0ac73e

Browse files
authored
fix fill_constant unsupport str_value bug (#1228)
1 parent ed2c0ee commit d0ac73e

File tree

6 files changed

+140
-96
lines changed

6 files changed

+140
-96
lines changed

cinn/backends/compiler.cc

100755100644
Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,46 @@ using ir::Module;
3535

3636
static constexpr int DebugLogMaxLen = 30000;
3737

38+
class SourceCodePrint {
39+
public:
40+
static SourceCodePrint* GetInstance() {
41+
static SourceCodePrint print;
42+
return &print;
43+
}
44+
45+
void write(const std::string& source_code) {
46+
if (of.is_open()) {
47+
VLOG(4) << "Write to " << FLAGS_cinn_source_code_save_path;
48+
of << source_code << std::endl;
49+
} else if (!FLAGS_cinn_source_code_save_path.empty()) {
50+
LOG(WARNING) << "Failed to open " << FLAGS_cinn_source_code_save_path << ", source code will print.";
51+
if (source_code.size() > DebugLogMaxLen) {
52+
LOG(INFO) << "[CUDA] source code-0:\n" << source_code.substr(0, DebugLogMaxLen);
53+
for (int i = 1; i * DebugLogMaxLen < source_code.size(); ++i) {
54+
LOG(INFO) << "[CUDA] source code-" << i << ":\n" << source_code.substr(DebugLogMaxLen * i, DebugLogMaxLen);
55+
}
56+
} else {
57+
LOG(INFO) << "[CUDA] source code:\n" << source_code;
58+
}
59+
}
60+
}
61+
62+
private:
63+
SourceCodePrint() {
64+
if (!FLAGS_cinn_source_code_save_path.empty()) {
65+
of.open(FLAGS_cinn_source_code_save_path, std::ios_base::out);
66+
}
67+
}
68+
69+
~SourceCodePrint() {
70+
if (of.is_open()) {
71+
of.close();
72+
}
73+
};
74+
75+
std::ofstream of;
76+
};
77+
3878
void Compiler::Build(const Module& module, const std::string& code) {
3979
if (target_.arch == Target::Arch::NVGPU) {
4080
CompileCudaModule(module, code);
@@ -83,22 +123,7 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code)
83123
CodeGenCUDA_Dev codegen(target_);
84124
auto source_code = codegen.Compile(device_module);
85125
if (!code.empty()) source_code = code;
86-
if (FLAGS_cinn_source_code_save_path.empty()) {
87-
if (source_code.size() > DebugLogMaxLen) {
88-
VLOG(3) << "[CUDA] source code-0:\n" << source_code.substr(0, DebugLogMaxLen);
89-
for (int i = 1; i * DebugLogMaxLen < source_code.size(); ++i) {
90-
VLOG(3) << "[CUDA] source code-" << i << ":\n" << source_code.substr(DebugLogMaxLen * i, DebugLogMaxLen);
91-
}
92-
} else {
93-
VLOG(3) << "[CUDA] source code:\n" << source_code;
94-
}
95-
} else {
96-
VLOG(4) << "Write to " << FLAGS_cinn_source_code_save_path;
97-
std::ofstream of(FLAGS_cinn_source_code_save_path, std::ofstream::out);
98-
CHECK(of.is_open()) << "Failed to open " << FLAGS_cinn_source_code_save_path;
99-
of << source_code << std::endl;
100-
of.close();
101-
}
126+
SourceCodePrint::GetInstance()->write(source_code);
102127
using runtime::cuda::CUDAModule;
103128

104129
backends::nvrtc::Compiler compiler;

cinn/frontend/op_mappers/paddle/constant.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,19 @@ void FillConstantOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCont
6060

6161
auto shape = utils::ToShapeType(utils::GetAttrOrDefault<std::vector<int64_t>>(op_desc, "shape"));
6262
auto value = utils::GetAttrOrDefault<float>(op_desc, "value", 0.0f);
63+
auto str_value = utils::GetAttrOrDefault<std::string>(op_desc, "str_value", "");
6364
auto force_cpu = utils::GetAttrOrDefault<bool>(op_desc, "force_cpu", false);
6465

6566
auto dtype_id = utils::GetAttrOrDefault<int>(op_desc, "dtype", static_cast<int>(paddle::cpp::VarDescAPI::Type::FP32));
6667
auto dtype_pd = static_cast<paddle::cpp::VarDescAPI::Type>(dtype_id);
6768
auto dtype_cinn = utils::CppVarType2CommonType(dtype_pd);
6869
auto dtype = common::Type2Str(dtype_cinn);
6970

71+
if (!str_value.empty()) {
72+
size_t end_pos = 0;
73+
value = std::stof(str_value, &end_pos);
74+
}
75+
7076
VLOG(4) << "fill constant (" << value << ") with shape (" << cinn::utils::Join(shape, ",") << ") and dtype [" << dtype
7177
<< "]";
7278

cinn/hlir/framework/graph.cc

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ void Graph::VisualizeGroupedGraph(const std::vector<std::vector<Node*>>& groups,
149149

150150
auto& shape_dict = HasAttr("infershape") ? GetAttrs<absl::flat_hash_map<std::string, shape_t>>("infershape")
151151
: absl::flat_hash_map<std::string, shape_t>{};
152+
auto& dtype_dict = HasAttr("inferdtype") ? GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype")
153+
: absl::flat_hash_map<std::string, common::Type>{};
152154

153155
std::unordered_map<std::string, int> recompute_nodes;
154156
FindRecomputeNodes(groups, &recompute_nodes);
@@ -166,8 +168,15 @@ void Graph::VisualizeGroupedGraph(const std::vector<std::vector<Node*>>& groups,
166168

167169
std::unordered_map<std::string, std::string> outnode2dot_id;
168170
for (auto* node : group) {
169-
AddGroupNode(
170-
node, dot_cluster_id, fetch_var_ids, shape_dict, &recompute_nodes, &outnode2dot_id, &nodedatas_set, &dot);
171+
AddGroupNode(node,
172+
dot_cluster_id,
173+
fetch_var_ids,
174+
shape_dict,
175+
dtype_dict,
176+
&recompute_nodes,
177+
&outnode2dot_id,
178+
&nodedatas_set,
179+
&dot);
171180
}
172181
group_id++;
173182
}
@@ -182,6 +191,8 @@ void Graph::VisualizeGroups(const std::vector<std::vector<Node*>>& groups,
182191
const std::unordered_set<std::string>& fetch_var_ids) {
183192
auto& shape_dict = HasAttr("infershape") ? GetAttrs<absl::flat_hash_map<std::string, shape_t>>("infershape")
184193
: absl::flat_hash_map<std::string, shape_t>{};
194+
auto& dtype_dict = HasAttr("inferdtype") ? GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype")
195+
: absl::flat_hash_map<std::string, common::Type>{};
185196

186197
std::unordered_map<std::string, int> recompute_nodes;
187198
FindRecomputeNodes(groups, &recompute_nodes);
@@ -197,7 +208,15 @@ void Graph::VisualizeGroups(const std::vector<std::vector<Node*>>& groups,
197208

198209
std::unordered_map<std::string, std::string> outnode2dot_id;
199210
for (auto* node : group) {
200-
AddGroupNode(node, dot_cluster_id, fetch_var_ids, shape_dict, &recompute_nodes, &outnode2dot_id, nullptr, &dot);
211+
AddGroupNode(node,
212+
dot_cluster_id,
213+
fetch_var_ids,
214+
shape_dict,
215+
dtype_dict,
216+
&recompute_nodes,
217+
&outnode2dot_id,
218+
nullptr,
219+
&dot);
201220
nodes_set.insert(node);
202221
}
203222

cinn/hlir/framework/visualize_helper.cc

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,38 @@ namespace cinn {
3232
namespace hlir {
3333
namespace framework {
3434

35+
std::string Attribute2String(const utils::Attribute& attr) {
36+
std::stringstream ss;
37+
if (absl::get_if<bool>(&attr)) {
38+
ss << std::boolalpha << absl::get<bool>(attr);
39+
} else if (absl::get_if<float>(&attr)) {
40+
ss << absl::get<float>(attr) << "f";
41+
} else if (absl::get_if<double>(&attr)) {
42+
ss << absl::get<double>(attr);
43+
} else if (absl::get_if<int>(&attr)) {
44+
ss << absl::get<int>(attr);
45+
} else if (absl::get_if<int64_t>(&attr)) {
46+
ss << absl::get<int64_t>(attr);
47+
} else if (absl::get_if<std::string>(&attr)) {
48+
ss << absl::get<std::string>(attr);
49+
} else if (absl::get_if<std::vector<bool>>(&attr)) {
50+
ss << "[" + cinn::utils::Join(absl::get<std::vector<bool>>(attr), ", ") + "]";
51+
} else if (absl::get_if<std::vector<int>>(&attr)) {
52+
ss << "[" + cinn::utils::Join(absl::get<std::vector<int>>(attr), ", ") + "]";
53+
} else if (absl::get_if<std::vector<int64_t>>(&attr)) {
54+
ss << "[" + cinn::utils::Join(absl::get<std::vector<int64_t>>(attr), ", ") + "]";
55+
} else if (absl::get_if<std::vector<float>>(&attr)) {
56+
ss << "[" + cinn::utils::Join(absl::get<std::vector<float>>(attr), ", ") + "]";
57+
} else if (absl::get_if<std::vector<double>>(&attr)) {
58+
ss << "[" + cinn::utils::Join(absl::get<std::vector<double>>(attr), ", ") + "]";
59+
} else if (absl::get_if<std::vector<std::string>>(&attr)) {
60+
ss << "[" + cinn::utils::Join(absl::get<std::vector<std::string>>(attr), ", ") + "]";
61+
} else {
62+
LOG(FATAL) << "Unkown attribute data type! Please check.";
63+
}
64+
return ss.str();
65+
}
66+
3567
bool MakeDirectory(const std::string& dirname, mode_t mode) {
3668
auto len = dirname.length();
3769
std::vector<char> dir_path(len + 1, '\0');
@@ -96,22 +128,27 @@ std::string GetFilePathForGroup(const std::vector<std::vector<Node*>>& groups,
96128

97129
std::string GenNodeDataLabel(const NodeData* node,
98130
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
131+
const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
99132
const std::string dot_nodedata_id) {
133+
std::stringstream ss;
134+
ss << dot_nodedata_id;
100135
if (shape_dict.count(node->id())) {
101136
shape_t node_shape = shape_dict.at(node->id());
102-
std::stringstream ss;
103-
ss << dot_nodedata_id << "\\n{";
137+
ss << "\\n[";
104138
for (size_t i = 0; i < node_shape.size(); ++i) {
105139
if (i > 0) {
106140
ss << "x";
107141
}
108142
ss << node_shape[i];
109143
}
110-
ss << "}";
111-
return ss.str();
112-
} else {
113-
return dot_nodedata_id;
144+
ss << "]";
114145
}
146+
if (dtype_dict.count(node->id())) {
147+
ss << "\\n";
148+
ss << common::Type2Str(dtype_dict.at(node->id()));
149+
}
150+
151+
return ss.str();
115152
}
116153

117154
void Summary(const std::vector<std::vector<Node*>>& groups, const std::string& viz_path) {
@@ -225,40 +262,8 @@ std::string DebugString(const Node* node) {
225262
}
226263
ss << ", id=" << node->id() << ", ";
227264

228-
auto get_attr_value = [](const utils::Attribute& attr) -> std::string {
229-
std::stringstream ss;
230-
if (absl::get_if<bool>(&attr)) {
231-
ss << std::boolalpha << absl::get<bool>(attr);
232-
} else if (absl::get_if<float>(&attr)) {
233-
ss << std::scientific << absl::get<float>(attr);
234-
} else if (absl::get_if<double>(&attr)) {
235-
ss << std::scientific << absl::get<double>(attr);
236-
} else if (absl::get_if<int>(&attr)) {
237-
ss << absl::get<int>(attr);
238-
} else if (absl::get_if<int64_t>(&attr)) {
239-
ss << absl::get<int64_t>(attr);
240-
} else if (absl::get_if<std::string>(&attr)) {
241-
ss << absl::get<std::string>(attr);
242-
} else if (absl::get_if<std::vector<bool>>(&attr)) {
243-
ss << "[" + cinn::utils::Join(absl::get<std::vector<bool>>(attr), ", ") + "]";
244-
} else if (absl::get_if<std::vector<int>>(&attr)) {
245-
ss << "[" + cinn::utils::Join(absl::get<std::vector<int>>(attr), ", ") + "]";
246-
} else if (absl::get_if<std::vector<int64_t>>(&attr)) {
247-
ss << "[" + cinn::utils::Join(absl::get<std::vector<int64_t>>(attr), ", ") + "]";
248-
} else if (absl::get_if<std::vector<float>>(&attr)) {
249-
ss << "[" + cinn::utils::Join(absl::get<std::vector<float>>(attr), ", ") + "]";
250-
} else if (absl::get_if<std::vector<double>>(&attr)) {
251-
ss << "[" + cinn::utils::Join(absl::get<std::vector<double>>(attr), ", ") + "]";
252-
} else if (absl::get_if<std::vector<std::string>>(&attr)) {
253-
ss << "[" + cinn::utils::Join(absl::get<std::vector<std::string>>(attr), ", ") + "]";
254-
} else {
255-
LOG(FATAL) << "Unkown attribute data type! Please check.";
256-
}
257-
return ss.str();
258-
};
259-
260265
for (const auto& attr_pair : node->attrs.attr_store) {
261-
ss << attr_pair.first << "=" << get_attr_value(attr_pair.second) << ", ";
266+
ss << attr_pair.first << "=" << Attribute2String(attr_pair.second) << ", ";
262267
}
263268
ss << "}";
264269
return ss.str();
@@ -283,6 +288,7 @@ void AddGroupNode(const Node* node,
283288
const std::string& dot_cluster_id,
284289
const std::unordered_set<std::string>& fetch_var_ids,
285290
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
291+
const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
286292
std::unordered_map<std::string, int>* recompute_nodes,
287293
std::unordered_map<std::string, std::string>* outnode2dot_id,
288294
std::unordered_set<std::string>* nodedatas_set,
@@ -301,7 +307,7 @@ void AddGroupNode(const Node* node,
301307
}
302308
std::string dot_innode_id = outnode2dot_id->at(innode->id());
303309
if (!nodedatas_set || !nodedatas_set->count(dot_innode_id)) {
304-
std::string label = GenNodeDataLabel(innode, shape_dict, dot_innode_id);
310+
std::string label = GenNodeDataLabel(innode, shape_dict, dtype_dict, dot_innode_id);
305311
dot->AddNode(dot_innode_id, GetGroupVarAttrs(false), label, dot_cluster_id, true);
306312
if (nodedatas_set) {
307313
nodedatas_set->insert(dot_innode_id);
@@ -318,7 +324,7 @@ void AddGroupNode(const Node* node,
318324
(*outnode2dot_id)[outnode->id()] = dot_outnode_id;
319325
if (!nodedatas_set || !nodedatas_set->count(dot_outnode_id)) {
320326
bool is_fetched = fetch_var_ids.count(outnode->id());
321-
std::string label = GenNodeDataLabel(outnode, shape_dict, dot_outnode_id);
327+
std::string label = GenNodeDataLabel(outnode, shape_dict, dtype_dict, dot_outnode_id);
322328
dot->AddNode(dot_outnode_id, GetGroupVarAttrs(is_fetched), label, dot_cluster_id, true);
323329
if (nodedatas_set) {
324330
nodedatas_set->insert(dot_outnode_id);

cinn/hlir/framework/visualize_helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ namespace cinn {
3232
namespace hlir {
3333
namespace framework {
3434

35+
std::string Attribute2String(const utils::Attribute& attr);
36+
3537
inline void WriteToFile(const std::string& filepath, const std::string& content) {
3638
VLOG(4) << "Write to " << filepath;
3739
std::ofstream of(filepath);
@@ -105,6 +107,7 @@ std::string GetFilePathForGroup(const std::vector<std::vector<Node*>>& groups,
105107

106108
std::string GenNodeDataLabel(const NodeData* node,
107109
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
110+
const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
108111
const std::string dot_nodedata_id);
109112

110113
void Summary(const std::vector<std::vector<Node*>>& groups, const std::string& viz_path);
@@ -118,6 +121,7 @@ void AddGroupNode(const Node* node,
118121
const std::string& dot_cluster_id,
119122
const std::unordered_set<std::string>& fetch_var_ids,
120123
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
124+
const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
121125
std::unordered_map<std::string, int>* recompute_nodes,
122126
std::unordered_map<std::string, std::string>* outnode2dot_id,
123127
std::unordered_set<std::string>* nodedatas_set,

0 commit comments

Comments
 (0)