Commit fc5fb2a

add dist_attr for dist op and var (#35585)
* add dist_attr for dist op
* add unitest
* update inputname
* update function name
* add unitest
* update CMakeLists.txt for CI
* fix dis_matmul
* fix compile error
* update matmul to matmul_v2
1 parent 09eaa7d commit fc5fb2a

File tree: 6 files changed, +354 −58 lines changed

paddle/fluid/operators/searchsorted_op.h

Lines changed: 28 additions & 4 deletions

@@ -30,13 +30,37 @@ using Tensor = framework::Tensor;
 template <typename T1, typename T2, typename OutType>
 class GpuAndCpuSearchSortedCompute {
  public:
-  static HOSTDEVICE bool IsNan(float x) { return ::isnan(x); }
-  static HOSTDEVICE bool IsNan(double x) { return ::isnan(x); }
+  static HOSTDEVICE bool IsNan(float x) {
+#ifdef __NVCC__
+    return ::isnan(x);
+#else
+    return std::isnan(x);
+#endif
+  }
+  static HOSTDEVICE bool IsNan(double x) {
+#ifdef __NVCC__
+    return ::isnan(x);
+#else
+    return std::isnan(x);
+#endif
+  }
   static HOSTDEVICE bool IsNan(int x) { return false; }
   static HOSTDEVICE bool IsNan(int64_t x) { return false; }

-  static HOSTDEVICE bool IsInf(float x) { return ::isinf(x); }
-  static HOSTDEVICE bool IsInf(double x) { return ::isinf(x); }
+  static HOSTDEVICE bool IsInf(float x) {
+#ifdef __NVCC__
+    return ::isinf(x);
+#else
+    return std::isinf(x);
+#endif
+  }
+  static HOSTDEVICE bool IsInf(double x) {
+#ifdef __NVCC__
+    return ::isinf(x);
+#else
+    return std::isinf(x);
+#endif
+  }
   static HOSTDEVICE bool IsInf(int x) { return false; }
   static HOSTDEVICE bool IsInf(int64_t x) { return false; }

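This functor backs the searchsorted op on both CPU and GPU: under NVCC the device-side ::isnan/::isinf builtins are kept, while host-only builds now fall back to std::isnan/std::isinf, which is the compile-error fix mentioned in the commit message. A minimal sketch of the call path being guarded, assuming the paddle.searchsorted Python API is exposed in this build (the tensor values are illustrative only):

    import paddle

    # a sorted 1-D sequence and query values; the NaN and Inf queries
    # exercise the IsNan/IsInf branches of the functor above
    sorted_sequence = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
    values = paddle.to_tensor([float("nan"), float("inf"), 4.0])

    out = paddle.searchsorted(sorted_sequence, values)
    print(out)  # 4.0 is inserted at index 2; NaN/Inf take the non-finite branch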

python/paddle/distributed/auto_parallel/operators/common.py

Lines changed: 47 additions & 0 deletions

@@ -114,3 +114,50 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
         best_compatible_impl, idx = None, -1

     return best_compatible_impl, idx
+
+
+def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var):
+    """
+    copy src var's dist_attr to dst var
+    """
+    import copy
+
+    auto_paralle_context = src_op_dist_attr.get_owner_context()
+    dist_attr = copy.deepcopy(
+        auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
+    dist_attr._owner_tensor = var
+    dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
+        src_var)._owner_context
+    auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
+
+
+def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
+    """
+    copy src op's dist_attr to dst dist op
+    """
+    from ..attribute import OperatorDistributedAttribute
+
+    auto_paralle_context = src_op_dist_attr.get_owner_context()
+    op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
+    auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
+                                                             op_dist_attr)
+    auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
+                                                             op_dist_attr)
+
+    op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
+    op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
+
+    for input_varname in dist_op.desc.input_arg_names():
+        input_var = dst_block.var(input_varname)
+        tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
+            input_var)
+        tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
+        op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
+
+    for output_varname in dist_op.desc.output_arg_names():
+        output_var = dst_block.var(output_varname)
+        tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
+            output_var)
+        tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
+        op_dist_attr.set_output_dims_mapping(output_varname,
+                                             tensor_dims_mapping)
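
These two helpers are consumed by the dist op implementations below (dist_embedding.py and dist_matmul.py): when a serial op is lowered into a small group of collective ops, the intermediate variable reuses the dist_attr of the variable it mirrors, and every newly appended op inherits the serial op's process mesh, impl index, and per-tensor dims_mapping. A condensed sketch of that call pattern, assuming dst_block, op_dist_attr, X_var, Weight_var, Out_var, and intermediate_var_0 are set up as in the column-parallel matmul implementation below (op attrs such as ring_id are omitted for brevity):

    from paddle.distributed.auto_parallel.operators.common import (
        copy_distributed_attr_for_var, copy_distributed_attr_for_dist_op)

    # the intermediate output mirrors X_var, so it reuses X_var's tensor dist_attr
    copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var)

    c_identity_op = dst_block.append_op(
        type='c_identity',
        inputs={'X': [X_var]},
        outputs={'Out': intermediate_var_0})
    matmul_op = dst_block.append_op(
        type='matmul',
        inputs={'X': [intermediate_var_0], 'Y': [Weight_var]},
        outputs={'Out': Out_var})

    # each appended dist op copies the serial op's dist_attr
    copy_distributed_attr_for_dist_op(c_identity_op, dst_block, op_dist_attr)
    copy_distributed_attr_for_dist_op(matmul_op, dst_block, op_dist_attr)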

python/paddle/distributed/auto_parallel/operators/dist_embedding.py

Lines changed: 13 additions & 2 deletions

@@ -16,6 +16,8 @@
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator
 from .common import register_distributed_operator_impl
+from .common import copy_distributed_attr_for_var
+from .common import copy_distributed_attr_for_dist_op
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
@@ -173,21 +175,24 @@ def static_handle(dst_block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=Out_var.stop_gradient)
+            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          Out_var)

             check_variable_and_dtype(
                 Out_var, 'tensor',
                 ['float16', 'float32', 'float64', 'int32', 'int64'],
                 'c_allreduce_sum')

-            dst_block.append_op(
+            c_embedding_op = dst_block.append_op(
                 type='c_embedding',
                 inputs={'Ids': [Ids_var],
                         'W': [Weight_var]},
                 outputs={'Out': [intermediate_var_0]},
                 attrs={"start_index": relative_idx})

             # use_model_parallel
-            dst_block.append_op(
+            c_allreduce_sum_op = dst_block.append_op(
                 type='c_allreduce_sum',
                 inputs={'X': [intermediate_var_0]},
                 outputs={'Out': [Out_var]},
@@ -197,6 +202,12 @@ def static_handle(dst_block,
                     'use_model_parallel': True,
                 })

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(c_embedding_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(

python/paddle/distributed/auto_parallel/operators/dist_matmul.py

Lines changed: 54 additions & 20 deletions

@@ -16,6 +16,8 @@
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator
 from .common import register_distributed_operator_impl
+from .common import copy_distributed_attr_for_var
+from .common import copy_distributed_attr_for_dist_op
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
@@ -223,13 +225,16 @@ def static_handle(dst_block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=X_var.stop_gradient)
+            # copy X_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          X_var)

             check_variable_and_dtype(
                 X_var, 'tensor',
                 ['float16', 'float32', 'float64', 'int32', 'int64'],
                 '_c_identity')

-            dst_block.append_op(
+            c_identity_op = dst_block.append_op(
                 type='c_identity',
                 inputs={'X': [X_var]},
                 outputs={'Out': intermediate_var_0},
@@ -250,12 +255,18 @@ def static_handle(dst_block,
                 'alpha': 1,
             }
             inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
-            dst_block.append_op(
+            matmul_op = dst_block.append_op(
                 type='matmul',
                 inputs=inputs,
                 outputs={'Out': Out_var},
                 attrs=attrs)

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(matmul_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -369,13 +380,17 @@ def static_handle(dst_block,
                 persistable=False,
                 is_data=False,
                 need_check_feed=Out_var.desc.need_check_feed())
-            dst_block.append_op(
+            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          Out_var)
+
+            matmul_op = dst_block.append_op(
                 type='matmul',
                 inputs=inputs,
                 outputs={'Out': intermediate_var_0},
                 attrs=attrs)

-            dst_block.append_op(
+            c_allreduce_sum_op = dst_block.append_op(
                 type='c_allreduce_sum',
                 inputs={'X': intermediate_var_0},
                 outputs={'Out': Out_var},
@@ -385,6 +400,12 @@ def static_handle(dst_block,
                     'use_model_parallel': True
                 })

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(matmul_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -540,15 +561,12 @@ def static_handle(dst_block,
             Out_var = dst_block.var(output_name_mapping['Out'][0])

             # TODO infer logic comm presentation
-            from ..process import new_process_group
-            from ..transpiler import _get_comm_group
             model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
             )._get_model_parallel_info()
-            group_ranks = _get_comm_group(process_mesh.topology,
-                                          model_parallel_axis,
-                                          process_mesh.process_group, rank_id)
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          model_parallel_axis, rank_id)
             group = new_process_group(group_ranks)
-            # print("@@@@@@@@@@@@@@@@@@@@@ 5", group)

             intermediate_var_0 = dst_block.create_var(
                 name=unique_name.generate_with_ignorable_key(".".join(
@@ -558,13 +576,16 @@ def static_handle(dst_block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=X_var.stop_gradient)
+            # copy X_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          X_var)

             check_variable_and_dtype(
                 X_var, 'tensor',
                 ['float16', 'float32', 'float64', 'int32', 'int64'],
                 '_c_identity')

-            dst_block.append_op(
+            c_identity_op = dst_block.append_op(
                 type='c_identity',
                 inputs={'X': [X_var]},
                 outputs={'Out': intermediate_var_0},
@@ -581,12 +602,18 @@ def static_handle(dst_block,
                 ['float16', 'float32', 'float64'], 'linear')
             attrs = {'trans_x': False, 'trans_y': False}
             inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
-            dst_block.append_op(
+            matmul_v2_op = dst_block.append_op(
                 type='matmul_v2',
                 inputs=inputs,
                 outputs={'Out': Out_var},
                 attrs=attrs)

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -675,15 +702,12 @@ def static_handle(dst_block,
             Out_var = dst_block.var(output_name_mapping['Out'][0])

             # TODO infer logic comm presentation
-            from ..process import new_process_group
-            from ..transpiler import _get_comm_group
             model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
             )._get_model_parallel_info()
-            group_ranks = _get_comm_group(process_mesh.topology,
-                                          model_parallel_axis,
-                                          process_mesh.process_group, rank_id)
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          model_parallel_axis, rank_id)
             group = new_process_group(group_ranks)
-            # print("@@@@@@@@@@@@@@@@@@@@@ 4", group)

             check_variable_and_dtype(
                 X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
@@ -699,13 +723,17 @@ def static_handle(dst_block,
                 persistable=False,
                 is_data=False,
                 need_check_feed=Out_var.desc.need_check_feed())
-            dst_block.append_op(
+            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          Out_var)
+
+            matmul_v2_op = dst_block.append_op(
                 type='matmul_v2',
                 inputs=inputs,
                 outputs={'Out': intermediate_var_0},
                 attrs=attrs)

-            dst_block.append_op(
+            c_allreduce_sum_op = dst_block.append_op(
                 type='c_allreduce_sum',
                 inputs={'X': intermediate_var_0},
                 outputs={'Out': Out_var},
@@ -715,6 +743,12 @@ def static_handle(dst_block,
                     'use_model_parallel': True
                 })

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -582,6 +582,8 @@ if(WITH_DISTRIBUTE)
         py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
         py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS})
         py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS})
+        py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS})
+        py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
     endif(NOT WIN32)
 endif(NOT APPLE)
 if(WITH_DGC)
