@@ -32,6 +32,38 @@ namespace cinn {
32
32
namespace hlir {
33
33
namespace framework {
34
34
35
+ std::string Attribute2String (const utils::Attribute& attr) {
36
+ std::stringstream ss;
37
+ if (absl::get_if<bool >(&attr)) {
38
+ ss << std::boolalpha << absl::get<bool >(attr);
39
+ } else if (absl::get_if<float >(&attr)) {
40
+ ss << absl::get<float >(attr) << " f" ;
41
+ } else if (absl::get_if<double >(&attr)) {
42
+ ss << absl::get<double >(attr);
43
+ } else if (absl::get_if<int >(&attr)) {
44
+ ss << absl::get<int >(attr);
45
+ } else if (absl::get_if<int64_t >(&attr)) {
46
+ ss << absl::get<int64_t >(attr);
47
+ } else if (absl::get_if<std::string>(&attr)) {
48
+ ss << absl::get<std::string>(attr);
49
+ } else if (absl::get_if<std::vector<bool >>(&attr)) {
50
+ ss << " [" + cinn::utils::Join (absl::get<std::vector<bool >>(attr), " , " ) + " ]" ;
51
+ } else if (absl::get_if<std::vector<int >>(&attr)) {
52
+ ss << " [" + cinn::utils::Join (absl::get<std::vector<int >>(attr), " , " ) + " ]" ;
53
+ } else if (absl::get_if<std::vector<int64_t >>(&attr)) {
54
+ ss << " [" + cinn::utils::Join (absl::get<std::vector<int64_t >>(attr), " , " ) + " ]" ;
55
+ } else if (absl::get_if<std::vector<float >>(&attr)) {
56
+ ss << " [" + cinn::utils::Join (absl::get<std::vector<float >>(attr), " , " ) + " ]" ;
57
+ } else if (absl::get_if<std::vector<double >>(&attr)) {
58
+ ss << " [" + cinn::utils::Join (absl::get<std::vector<double >>(attr), " , " ) + " ]" ;
59
+ } else if (absl::get_if<std::vector<std::string>>(&attr)) {
60
+ ss << " [" + cinn::utils::Join (absl::get<std::vector<std::string>>(attr), " , " ) + " ]" ;
61
+ } else {
62
+ LOG (FATAL) << " Unkown attribute data type! Please check." ;
63
+ }
64
+ return ss.str ();
65
+ }
66
+
35
67
bool MakeDirectory (const std::string& dirname, mode_t mode) {
36
68
auto len = dirname.length ();
37
69
std::vector<char > dir_path (len + 1 , ' \0 ' );
@@ -96,22 +128,27 @@ std::string GetFilePathForGroup(const std::vector<std::vector<Node*>>& groups,
96
128
97
129
std::string GenNodeDataLabel (const NodeData* node,
98
130
const absl::flat_hash_map<std::string, shape_t >& shape_dict,
131
+ const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
99
132
const std::string dot_nodedata_id) {
133
+ std::stringstream ss;
134
+ ss << dot_nodedata_id;
100
135
if (shape_dict.count (node->id ())) {
101
136
shape_t node_shape = shape_dict.at (node->id ());
102
- std::stringstream ss;
103
- ss << dot_nodedata_id << " \\ n{" ;
137
+ ss << " \\ n[" ;
104
138
for (size_t i = 0 ; i < node_shape.size (); ++i) {
105
139
if (i > 0 ) {
106
140
ss << " x" ;
107
141
}
108
142
ss << node_shape[i];
109
143
}
110
- ss << " }" ;
111
- return ss.str ();
112
- } else {
113
- return dot_nodedata_id;
144
+ ss << " ]" ;
114
145
}
146
+ if (dtype_dict.count (node->id ())) {
147
+ ss << " \\ n" ;
148
+ ss << common::Type2Str (dtype_dict.at (node->id ()));
149
+ }
150
+
151
+ return ss.str ();
115
152
}
116
153
117
154
void Summary (const std::vector<std::vector<Node*>>& groups, const std::string& viz_path) {
@@ -225,40 +262,8 @@ std::string DebugString(const Node* node) {
225
262
}
226
263
ss << " , id=" << node->id () << " , " ;
227
264
228
- auto get_attr_value = [](const utils::Attribute& attr) -> std::string {
229
- std::stringstream ss;
230
- if (absl::get_if<bool >(&attr)) {
231
- ss << std::boolalpha << absl::get<bool >(attr);
232
- } else if (absl::get_if<float >(&attr)) {
233
- ss << std::scientific << absl::get<float >(attr);
234
- } else if (absl::get_if<double >(&attr)) {
235
- ss << std::scientific << absl::get<double >(attr);
236
- } else if (absl::get_if<int >(&attr)) {
237
- ss << absl::get<int >(attr);
238
- } else if (absl::get_if<int64_t >(&attr)) {
239
- ss << absl::get<int64_t >(attr);
240
- } else if (absl::get_if<std::string>(&attr)) {
241
- ss << absl::get<std::string>(attr);
242
- } else if (absl::get_if<std::vector<bool >>(&attr)) {
243
- ss << " [" + cinn::utils::Join (absl::get<std::vector<bool >>(attr), " , " ) + " ]" ;
244
- } else if (absl::get_if<std::vector<int >>(&attr)) {
245
- ss << " [" + cinn::utils::Join (absl::get<std::vector<int >>(attr), " , " ) + " ]" ;
246
- } else if (absl::get_if<std::vector<int64_t >>(&attr)) {
247
- ss << " [" + cinn::utils::Join (absl::get<std::vector<int64_t >>(attr), " , " ) + " ]" ;
248
- } else if (absl::get_if<std::vector<float >>(&attr)) {
249
- ss << " [" + cinn::utils::Join (absl::get<std::vector<float >>(attr), " , " ) + " ]" ;
250
- } else if (absl::get_if<std::vector<double >>(&attr)) {
251
- ss << " [" + cinn::utils::Join (absl::get<std::vector<double >>(attr), " , " ) + " ]" ;
252
- } else if (absl::get_if<std::vector<std::string>>(&attr)) {
253
- ss << " [" + cinn::utils::Join (absl::get<std::vector<std::string>>(attr), " , " ) + " ]" ;
254
- } else {
255
- LOG (FATAL) << " Unkown attribute data type! Please check." ;
256
- }
257
- return ss.str ();
258
- };
259
-
260
265
for (const auto & attr_pair : node->attrs .attr_store ) {
261
- ss << attr_pair.first << " =" << get_attr_value (attr_pair.second ) << " , " ;
266
+ ss << attr_pair.first << " =" << Attribute2String (attr_pair.second ) << " , " ;
262
267
}
263
268
ss << " }" ;
264
269
return ss.str ();
@@ -283,6 +288,7 @@ void AddGroupNode(const Node* node,
283
288
const std::string& dot_cluster_id,
284
289
const std::unordered_set<std::string>& fetch_var_ids,
285
290
const absl::flat_hash_map<std::string, shape_t >& shape_dict,
291
+ const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
286
292
std::unordered_map<std::string, int >* recompute_nodes,
287
293
std::unordered_map<std::string, std::string>* outnode2dot_id,
288
294
std::unordered_set<std::string>* nodedatas_set,
@@ -301,7 +307,7 @@ void AddGroupNode(const Node* node,
301
307
}
302
308
std::string dot_innode_id = outnode2dot_id->at (innode->id ());
303
309
if (!nodedatas_set || !nodedatas_set->count (dot_innode_id)) {
304
- std::string label = GenNodeDataLabel (innode, shape_dict, dot_innode_id);
310
+ std::string label = GenNodeDataLabel (innode, shape_dict, dtype_dict, dot_innode_id);
305
311
dot->AddNode (dot_innode_id, GetGroupVarAttrs (false ), label, dot_cluster_id, true );
306
312
if (nodedatas_set) {
307
313
nodedatas_set->insert (dot_innode_id);
@@ -318,7 +324,7 @@ void AddGroupNode(const Node* node,
318
324
(*outnode2dot_id)[outnode->id ()] = dot_outnode_id;
319
325
if (!nodedatas_set || !nodedatas_set->count (dot_outnode_id)) {
320
326
bool is_fetched = fetch_var_ids.count (outnode->id ());
321
- std::string label = GenNodeDataLabel (outnode, shape_dict, dot_outnode_id);
327
+ std::string label = GenNodeDataLabel (outnode, shape_dict, dtype_dict, dot_outnode_id);
322
328
dot->AddNode (dot_outnode_id, GetGroupVarAttrs (is_fetched), label, dot_cluster_id, true );
323
329
if (nodedatas_set) {
324
330
nodedatas_set->insert (dot_outnode_id);
0 commit comments