|
3 | 3 | import numpy as np
|
4 | 4 | from . import core
|
5 | 5 | import proto.framework_pb2 as framework_pb2
|
| 6 | +import contextlib |
6 | 7 |
|
7 | 8 | __all__ = [
|
8 | 9 | 'Block', 'Variable', 'Program', 'Operator', 'default_startup_program',
|
9 |
| - 'default_main_program' |
| 10 | + 'default_main_program', 'program_guard', 'switch_startup_program', |
| 11 | + 'switch_main_program' |
10 | 12 | ]
|
11 | 13 |
|
12 | 14 |
|
@@ -659,8 +661,83 @@ def __init__(self, block, shape, dtype, **kwargs):
|
659 | 661 |
|
660 | 662 |
|
661 | 663 | def default_startup_program():
|
| 664 | + """ |
| 665 | + Get default startup program. In startup program, Paddle will initialize |
| 666 | + parameters, initialize nccl handle, etc. |
| 667 | + |
| 668 | + Returns: |
| 669 | + Program: startup program |
| 670 | + """ |
662 | 671 | return _startup_program_
|
663 | 672 |
|
664 | 673 |
|
665 | 674 | def default_main_program():
|
| 675 | + """ |
| 676 | + Get default main program. The main program is used for training or testing. |
| 677 | + |
| 678 | + Returns: |
| 679 | + Program: main program |
| 680 | + """ |
666 | 681 | return _main_program_
|
| 682 | + |
| 683 | + |
| 684 | +def switch_main_program(program): |
| 685 | + """ |
| 686 | + Switch the main program to a new program. |
| 687 | + |
| 688 | + Args: |
| 689 | + program(Program): The new main program |
| 690 | +
|
| 691 | + Returns: |
| 692 | + Program: The previous main program |
| 693 | + """ |
| 694 | + global _main_program_ |
| 695 | + prev_program = _main_program_ |
| 696 | + _main_program_ = program |
| 697 | + return prev_program |
| 698 | + |
| 699 | + |
| 700 | +def switch_startup_program(program): |
| 701 | + """ |
| 702 | + Switch the startup program to a new program |
| 703 | + Args: |
| 704 | + program(Program): The new startup program |
| 705 | +
|
| 706 | + Returns: |
| 707 | + Program: The previous startup program |
| 708 | + """ |
| 709 | + global _startup_program_ |
| 710 | + prev_program = _startup_program_ |
| 711 | + _startup_program_ = program |
| 712 | + return prev_program |
| 713 | + |
| 714 | + |
| 715 | +@contextlib.contextmanager |
| 716 | +def program_guard(main_program, startup_program=None): |
| 717 | + """ |
| 718 | + Switch program with `with` statement |
| 719 | + |
| 720 | + Examples: |
| 721 | + >>> with program_guard(Program()): |
| 722 | + >>> data = fluid.layers.data(...) |
| 723 | + >>> hidden = fluid.layers.fc(...) |
| 724 | + |
| 725 | + Args: |
| 726 | + main_program(Program): New main program inside `with` statement |
| 727 | + startup_program(Program): New startup program inside `with` statement. |
| 728 | + None means do not change startup program. |
| 729 | +
|
| 730 | + Returns: |
| 731 | + None |
| 732 | + """ |
| 733 | + if not isinstance(main_program, Program): |
| 734 | + raise TypeError("main_program should be Program") |
| 735 | + main_program = switch_main_program(main_program) |
| 736 | + if startup_program is not None: |
| 737 | + if not isinstance(startup_program, Program): |
| 738 | + raise TypeError("startup_program should be Program") |
| 739 | + startup_program = switch_startup_program(startup_program) |
| 740 | + yield |
| 741 | + switch_main_program(main_program) |
| 742 | + if startup_program is not None: |
| 743 | + switch_startup_program(startup_program) |
0 commit comments