Skip to content

Commit 43622a2

Browse files
authored
[RM FLUID] trainer_pass&heter_trainer_pass (#50610)
* [RM FLUID] trainer_pass&heter_trainer_pass * [RM FLUID] rm distributed_strategy
1 parent 47306c5 commit 43622a2

File tree

15 files changed

+63
-1587
lines changed

15 files changed

+63
-1587
lines changed

python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _get_distributed_strategy(self):
9696
return strategy
9797

9898
def _build_trainer_programs(self, compiled_config):
99-
from paddle.fluid.incubate.fleet.parameter_server.ir import (
99+
from paddle.incubate.fleet.parameter_server.ir import (
100100
trainer_pass as worker,
101101
)
102102

@@ -106,7 +106,7 @@ def _build_trainer_programs(self, compiled_config):
106106
use_ps_gpu = self.user_defined_strategy.a_sync_configs["use_ps_gpu"]
107107

108108
if not compiled_config.is_geo_mode():
109-
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
109+
from paddle.incubate.fleet.parameter_server.ir.public import (
110110
_add_lr_decay_table_pass,
111111
)
112112

@@ -150,7 +150,7 @@ def _build_trainer_programs(self, compiled_config):
150150
compiled_config.set_origin_ps_startup_program(_startup)
151151
# for heter program
152152
if self.role_maker._is_heter_parameter_server_mode:
153-
from paddle.fluid.incubate.fleet.parameter_server.ir import (
153+
from paddle.incubate.fleet.parameter_server.ir import (
154154
heter_trainer_pass as heter_worker,
155155
)
156156

@@ -191,13 +191,13 @@ def _build_pserver_programs(self, compiled_config):
191191
_main = paddle.static.Program()
192192
_startup = paddle.static.Program()
193193

194-
from paddle.fluid.incubate.fleet.parameter_server.ir import (
194+
from paddle.incubate.fleet.parameter_server.ir import (
195195
pserver_pass as server,
196196
)
197197

198198
if not compiled_config.is_geo_mode():
199199

200-
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
200+
from paddle.incubate.fleet.parameter_server.ir.public import (
201201
_get_optimize_ops,
202202
)
203203

@@ -209,7 +209,7 @@ def _build_pserver_programs(self, compiled_config):
209209
if len(ops) == 0:
210210
return _main, _startup
211211

212-
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
212+
from paddle.incubate.fleet.parameter_server.ir.public import (
213213
_add_lr_decay_table_pass,
214214
)
215215

@@ -299,9 +299,7 @@ def get_sys_free_mem():
299299

300300
free = get_sys_free_mem()
301301

302-
from paddle.fluid.incubate.fleet.parameter_server.ir import (
303-
vars_metatools,
304-
)
302+
from paddle.incubate.fleet.parameter_server.ir import vars_metatools
305303

306304
processed_var_names = set(["@EMPTY@"])
307305
param_memory_size = 0
@@ -371,9 +369,7 @@ def minimize_impl(
371369

372370
_origin_main_program = loss.block.program
373371
_origin_startup_program = startup_program
374-
from paddle.fluid.incubate.fleet.parameter_server.ir import (
375-
public as public,
376-
)
372+
from paddle.incubate.fleet.parameter_server.ir import public as public
377373

378374
compiled_config = public.CompileTimeStrategy(
379375
_origin_main_program,
@@ -409,14 +405,14 @@ def minimize_impl(
409405
}
410406
else:
411407
loss.block.program = main_program
412-
fluid.framework.switch_startup_program(startup_program)
408+
paddle.framework.switch_startup_program(startup_program)
413409

414410
elif self.role_maker._is_server():
415411
main_program, startup_program = self._build_pserver_programs(
416412
compiled_config
417413
)
418414
loss.block.program = main_program
419-
fluid.framework.switch_startup_program(startup_program)
415+
paddle.framework.switch_startup_program(startup_program)
420416
return None, None
421417

422418
def _disable_strategy(self, dist_strategy):

python/paddle/fluid/device_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _gen_worker_desc(self, trainer_desc):
123123
hogwild.stat_var_names.extend([i])
124124
downpour.stat_var_names.extend([i])
125125

126-
from paddle.fluid.incubate.fleet.parameter_server import version
126+
from paddle.incubate.fleet.parameter_server import version
127127

128128
if (
129129
version.is_transpiler()
@@ -271,7 +271,7 @@ def _gen_worker_desc(self, trainer_desc):
271271
for i in opt_info["stat_var_names"]:
272272
downpour.stat_var_names.extend([i])
273273

274-
from paddle.fluid.incubate.fleet.parameter_server import version
274+
from paddle.incubate.fleet.parameter_server import version
275275

276276
if (
277277
version.is_transpiler()

0 commit comments

Comments
 (0)