@@ -16,6 +16,8 @@
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator
 from .common import register_distributed_operator_impl
+from .common import copy_distributed_attr_for_var
+from .common import copy_distributed_attr_for_dist_op
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
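Note on the two new imports: copy_distributed_attr_for_var stamps a freshly created intermediate variable with the sharding description of an existing one, and copy_distributed_attr_for_dist_op annotates a newly appended operator with the serial op's distributed attribute, which is why the append_op return values are captured in the hunks below. A minimal conceptual sketch of the per-variable copy (FakeDistAttr and copy_dist_attr_for_var are hypothetical names, not the actual .common implementation):

```python
# Conceptual sketch only; NOT the actual .common implementation.
class FakeDistAttr:
    def __init__(self):
        # var name -> per-axis mesh dimension (-1 means replicated)
        self.dims_mapping = {}

def copy_dist_attr_for_var(dist_attr, dst_name, src_name):
    # The fresh intermediate var shards exactly like its source var.
    dist_attr.dims_mapping[dst_name] = list(dist_attr.dims_mapping[src_name])

attr = FakeDistAttr()
attr.dims_mapping["X"] = [-1, 0]      # dim 1 sharded along mesh axis 0
copy_dist_attr_for_var(attr, "c_identity.tmp", "X")
assert attr.dims_mapping["c_identity.tmp"] == [-1, 0]
```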
@@ -223,13 +225,16 @@ def static_handle(dst_block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=X_var.stop_gradient)
+            # copy X_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          X_var)
 
             check_variable_and_dtype(
                 X_var, 'tensor',
                 ['float16', 'float32', 'float64', 'int32', 'int64'],
                 '_c_identity')
 
-            dst_block.append_op(
+            c_identity_op = dst_block.append_op(
                 type='c_identity',
                 inputs={'X': [X_var]},
                 outputs={'Out': intermediate_var_0},
@@ -250,12 +255,18 @@ def static_handle(dst_block,
                 'alpha': 1,
             }
             inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
-            dst_block.append_op(
+            matmul_op = dst_block.append_op(
                 type='matmul',
                 inputs=inputs,
                 outputs={'Out': Out_var},
                 attrs=attrs)
 
+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(matmul_op, dst_block,
+                                              op_dist_attr)
+
             if in_dygraph_mode():
                 raise NotImplementedError(
                     "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -369,13 +380,17 @@ def static_handle(dst_block,
                 persistable=False,
                 is_data=False,
                 need_check_feed=Out_var.desc.need_check_feed())
-            dst_block.append_op(
+            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          Out_var)
+
+            matmul_op = dst_block.append_op(
                 type='matmul',
                 inputs=inputs,
                 outputs={'Out': intermediate_var_0},
                 attrs=attrs)
 
-            dst_block.append_op(
+            c_allreduce_sum_op = dst_block.append_op(
                 type='c_allreduce_sum',
                 inputs={'X': intermediate_var_0},
                 outputs={'Out': Out_var},
@@ -385,6 +400,12 @@ def static_handle(dst_block,
                     'use_model_parallel': True
                 })
 
+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(matmul_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
+                                              op_dist_attr)
+
             if in_dygraph_mode():
                 raise NotImplementedError(
                     "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -540,15 +561,12 @@ def static_handle(dst_block,
             Out_var = dst_block.var(output_name_mapping['Out'][0])
 
             # TODO infer logic comm presentation
-            from ..process import new_process_group
-            from ..transpiler import _get_comm_group
             model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
             )._get_model_parallel_info()
-            group_ranks = _get_comm_group(process_mesh.topology,
-                                          model_parallel_axis,
-                                          process_mesh.process_group, rank_id)
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          model_parallel_axis, rank_id)
             group = new_process_group(group_ranks)
-            # print("@@@@@@@@@@@@@@@@@@@@@ 5", group)
 
             intermediate_var_0 = dst_block.create_var(
                 name=unique_name.generate_with_ignorable_key(".".join(
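The local imports of new_process_group and _get_comm_group are dropped here (presumably provided once at module scope), and the _get_comm_group call is fixed to the (process_group, topology, axis, rank_id) order. As a rough illustration of what that call computes (get_comm_group below is a stand-in, not the real helper in ..transpiler): the ranks lying along the model-parallel axis of the mesh that share rank_id's coordinates on every other axis.

```python
import numpy as np

# Illustrative stand-in for _get_comm_group with the new argument order.
def get_comm_group(process_group, topology, axis, rank_id):
    mesh = np.array(process_group).reshape(topology)
    coord = [int(c[0]) for c in np.where(mesh == rank_id)]
    coord[axis] = slice(None)          # vary only the requested mesh axis
    return mesh[tuple(coord)].tolist()

# 2 x 4 mesh, model-parallel axis 1: rank 6 communicates with its row.
assert get_comm_group(list(range(8)), [2, 4], 1, 6) == [4, 5, 6, 7]
```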
@@ -558,13 +576,16 @@ def static_handle(dst_block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=X_var.stop_gradient)
+            # copy X_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          X_var)
 
             check_variable_and_dtype(
                 X_var, 'tensor',
                 ['float16', 'float32', 'float64', 'int32', 'int64'],
                 '_c_identity')
 
-            dst_block.append_op(
+            c_identity_op = dst_block.append_op(
                 type='c_identity',
                 inputs={'X': [X_var]},
                 outputs={'Out': intermediate_var_0},
@@ -581,12 +602,18 @@ def static_handle(dst_block,
                 ['float16', 'float32', 'float64'], 'linear')
             attrs = {'trans_x': False, 'trans_y': False}
             inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
-            dst_block.append_op(
+            matmul_v2_op = dst_block.append_op(
                 type='matmul_v2',
                 inputs=inputs,
                 outputs={'Out': Out_var},
                 attrs=attrs)
 
+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
+                                              op_dist_attr)
+
             if in_dygraph_mode():
                 raise NotImplementedError(
                     "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -675,15 +702,12 @@ def static_handle(dst_block,
             Out_var = dst_block.var(output_name_mapping['Out'][0])
 
             # TODO infer logic comm presentation
-            from ..process import new_process_group
-            from ..transpiler import _get_comm_group
             model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
             )._get_model_parallel_info()
-            group_ranks = _get_comm_group(process_mesh.topology,
-                                          model_parallel_axis,
-                                          process_mesh.process_group, rank_id)
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          model_parallel_axis, rank_id)
             group = new_process_group(group_ranks)
-            # print("@@@@@@@@@@@@@@@@@@@@@ 4", group)
 
             check_variable_and_dtype(
                 X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
@@ -699,13 +723,17 @@ def static_handle(dst_block,
                 persistable=False,
                 is_data=False,
                 need_check_feed=Out_var.desc.need_check_feed())
-            dst_block.append_op(
+            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          Out_var)
+
+            matmul_v2_op = dst_block.append_op(
                 type='matmul_v2',
                 inputs=inputs,
                 outputs={'Out': intermediate_var_0},
                 attrs=attrs)
 
-            dst_block.append_op(
+            c_allreduce_sum_op = dst_block.append_op(
                 type='c_allreduce_sum',
                 inputs={'X': intermediate_var_0},
                 outputs={'Out': Out_var},
@@ -715,6 +743,12 @@ def static_handle(dst_block,
                     'use_model_parallel': True
                 })
 
+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
+                                              op_dist_attr)
+
             if in_dygraph_mode():
                 raise NotImplementedError(
                     "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(