14
14
15
15
import argparse
16
16
import logging
17
+ import math
17
18
import os
18
19
import pathlib
19
20
import sys
@@ -1130,6 +1131,21 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):
1130
1131
return mutable_attribute_grad_semantics
1131
1132
1132
1133
1134
+ def split_ops (op_info_items : dict , cc_file , split_nums ):
1135
+ op_list = list (op_info_items .keys ())
1136
+ ops_max_size = math .ceil (len (op_list ) / split_nums )
1137
+ split_op_info_items = []
1138
+ for i in range (split_nums ):
1139
+ split_op_info_items .append ({})
1140
+ for i , op_name in enumerate (op_list ):
1141
+ list_idx = math .ceil ((i + 1 ) / ops_max_size ) - 1
1142
+ split_op_info_items [list_idx ][op_name ] = op_info_items [op_name ]
1143
+ split_cc_files = []
1144
+ for i in range (split_nums ):
1145
+ split_cc_files .append (cc_file .replace (".cc" , f"{ i + 1 } .cc" ))
1146
+ return split_op_info_items , split_cc_files
1147
+
1148
+
1133
1149
def GenOneDnnExtraAttrsDefaultValue (onednn_extra_args ):
1134
1150
INTARRAY_STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), phi::IntArray({attr}));
1135
1151
"""
@@ -2080,6 +2096,8 @@ def OpGenerator(
2080
2096
op_info_file ,
2081
2097
op_def_cc_file ,
2082
2098
op_vjp_cc_file ,
2099
+ op_cc_split_num ,
2100
+ bwd_op_cc_split_num ,
2083
2101
onednn_yaml_file ,
2084
2102
ops_onednn_extra_yaml_file ,
2085
2103
):
@@ -2126,9 +2144,11 @@ def OpGenerator(
2126
2144
2127
2145
op_infos = []
2128
2146
all_op_info_items = {}
2147
+ new_op_def_cc_file = []
2129
2148
first_file = True
2130
2149
onednn_only_op_list = []
2131
- for yaml_file in op_yaml_files :
2150
+ for idx in range (len (op_yaml_files )):
2151
+ yaml_file = op_yaml_files [idx ]
2132
2152
op_yaml_items = []
2133
2153
with open (yaml_file , "r" ) as f :
2134
2154
ops = yaml .safe_load (f )
@@ -2194,13 +2214,37 @@ def OpGenerator(
2194
2214
key_suffix = '_sp' if item .is_sparse_op else ''
2195
2215
op_info_items [op ['name' ] + key_suffix ] = item
2196
2216
all_op_info_items [op ['name' ] + key_suffix ] = item
2197
- op_infos .append (op_info_items )
2217
+
2218
+ if dialect_name != "onednn_op" :
2219
+ cc_file = op_def_cc_file [idx ]
2220
+ if (
2221
+ yaml_file .split ('/' )[- 1 ] == "ops.parsed.yaml"
2222
+ and op_cc_split_num is not None
2223
+ ):
2224
+ split_op_info_items , split_cc_files = split_ops (
2225
+ op_info_items , cc_file , op_cc_split_num
2226
+ )
2227
+ op_infos .extend (split_op_info_items )
2228
+ new_op_def_cc_file .extend (split_cc_files )
2229
+ elif (
2230
+ yaml_file .split ('/' )[- 1 ] == "backward.parsed.yaml"
2231
+ and bwd_op_cc_split_num is not None
2232
+ ):
2233
+ split_op_info_items , split_cc_files = split_ops (
2234
+ op_info_items , cc_file , bwd_op_cc_split_num
2235
+ )
2236
+ op_infos .extend (split_op_info_items )
2237
+ new_op_def_cc_file .extend (split_cc_files )
2238
+ else :
2239
+ op_infos .append (op_info_items )
2240
+ new_op_def_cc_file .append (cc_file )
2198
2241
2199
2242
if first_file :
2200
2243
first_file = False
2201
2244
2202
2245
if dialect_name == "onednn_op" :
2203
2246
op_infos = [all_op_info_items ]
2247
+ new_op_def_cc_file = op_def_cc_file
2204
2248
# (3) auto code gen
2205
2249
op_list_strs = []
2206
2250
declare_type_id_strs = []
@@ -2329,7 +2373,7 @@ def OpGenerator(
2329
2373
f .write (op_info_str )
2330
2374
2331
2375
# (6) write to files for xx_op.cc.tmp
2332
- for id in range (len (op_def_cc_file )):
2376
+ for id in range (len (new_op_def_cc_file )):
2333
2377
source_file_str = source_file_strs [id ]
2334
2378
for name in reversed (namespaces ):
2335
2379
source_file_str = NAMESPACE_GARD_TEMPLATE .format (
@@ -2349,7 +2393,7 @@ def OpGenerator(
2349
2393
input = source_file_str ,
2350
2394
define_type_id = define_type_id_strs [id ],
2351
2395
)
2352
- with open (op_def_cc_file [id ], 'w' ) as f :
2396
+ with open (new_op_def_cc_file [id ], 'w' ) as f :
2353
2397
f .write (source_file_str )
2354
2398
2355
2399
# (6) write to files for xx_vjp_op.cc.tmp
@@ -2381,6 +2425,8 @@ def ParseArguments():
2381
2425
parser .add_argument ('--op_info_file' , type = str )
2382
2426
parser .add_argument ('--op_def_cc_file' , type = str )
2383
2427
parser .add_argument ('--op_vjp_cc_file' , type = str )
2428
+ parser .add_argument ('--op_cc_split_num' , type = int )
2429
+ parser .add_argument ('--bwd_op_cc_split_num' , type = int )
2384
2430
parser .add_argument ('--onednn_yaml_file' , type = str )
2385
2431
parser .add_argument ('--ops_onednn_extra_yaml_file' , type = str )
2386
2432
parser .add_argument ('--with_distributed' , type = strtobool )
@@ -2403,6 +2449,8 @@ def ParseArguments():
2403
2449
op_info_file = args .op_info_file
2404
2450
op_def_cc_files = args .op_def_cc_file .split ("," )
2405
2451
op_vjp_cc_file = args .op_vjp_cc_file
2452
+ op_cc_split_num = args .op_cc_split_num
2453
+ bwd_op_cc_split_num = args .bwd_op_cc_split_num
2406
2454
onednn_yaml_file = args .onednn_yaml_file
2407
2455
ops_onednn_extra_yaml_file = args .ops_onednn_extra_yaml_file
2408
2456
@@ -2417,6 +2465,8 @@ def ParseArguments():
2417
2465
op_info_file ,
2418
2466
op_def_cc_files ,
2419
2467
op_vjp_cc_file ,
2468
+ op_cc_split_num ,
2469
+ bwd_op_cc_split_num ,
2420
2470
onednn_yaml_file ,
2421
2471
ops_onednn_extra_yaml_file ,
2422
2472
)
0 commit comments