Skip to content

Commit ac596a3

Browse files
authored
Feature/switch program (#5932)
* Unify fluid submodules to fluid module Change books just use `import fluid`, not submodules * Remove g_main_program/g_startup_program Use default_main_program/default_startup_program instead * Typo * Add API for switch default program * Two functions: switch_main_program/switch_startup_program * A guard: program_guard. Users can use the `with` statement change default programs * Change unittests in `test_layers` * Fix CI * Fix CI * Fix CI
1 parent 35453df commit ac596a3

File tree

2 files changed

+188
-162
lines changed

2 files changed

+188
-162
lines changed

python/paddle/v2/fluid/framework.py

+78-1
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import numpy as np
44
from . import core
55
import proto.framework_pb2 as framework_pb2
6+
import contextlib
67

78
__all__ = [
89
'Block', 'Variable', 'Program', 'Operator', 'default_startup_program',
9-
'default_main_program'
10+
'default_main_program', 'program_guard', 'switch_startup_program',
11+
'switch_main_program'
1012
]
1113

1214

@@ -659,8 +661,83 @@ def __init__(self, block, shape, dtype, **kwargs):
659661

660662

661663
def default_startup_program():
664+
"""
665+
Get default startup program. In startup program, Paddle will initialize
666+
parameters, initialize nccl handle, etc.
667+
668+
Returns:
669+
Program: startup program
670+
"""
662671
return _startup_program_
663672

664673

665674
def default_main_program():
675+
"""
676+
Get default main program. The main program is used for training or testing.
677+
678+
Returns:
679+
Program: main program
680+
"""
666681
return _main_program_
682+
683+
684+
def switch_main_program(program):
685+
"""
686+
Switch the main program to a new program.
687+
688+
Args:
689+
program(Program): The new main program
690+
691+
Returns:
692+
Program: The previous main program
693+
"""
694+
global _main_program_
695+
prev_program = _main_program_
696+
_main_program_ = program
697+
return prev_program
698+
699+
700+
def switch_startup_program(program):
701+
"""
702+
Switch the startup program to a new program
703+
Args:
704+
program(Program): The new startup program
705+
706+
Returns:
707+
Program: The previous startup program
708+
"""
709+
global _startup_program_
710+
prev_program = _startup_program_
711+
_startup_program_ = program
712+
return prev_program
713+
714+
715+
@contextlib.contextmanager
716+
def program_guard(main_program, startup_program=None):
717+
"""
718+
Switch program with `with` statement
719+
720+
Examples:
721+
>>> with program_guard(Program()):
722+
>>> data = fluid.layers.data(...)
723+
>>> hidden = fluid.layers.fc(...)
724+
725+
Args:
726+
main_program(Program): New main program inside `with` statement
727+
startup_program(Program): New startup program inside `with` statement.
728+
None means do not change startup program.
729+
730+
Returns:
731+
None
732+
"""
733+
if not isinstance(main_program, Program):
734+
raise TypeError("main_program should be Program")
735+
main_program = switch_main_program(main_program)
736+
if startup_program is not None:
737+
if not isinstance(startup_program, Program):
738+
raise TypeError("startup_program should be Program")
739+
startup_program = switch_startup_program(startup_program)
740+
yield
741+
switch_main_program(main_program)
742+
if startup_program is not None:
743+
switch_startup_program(startup_program)

0 commit comments

Comments
 (0)