|
15 | 15 | from __future__ import print_function
|
16 | 16 | import warnings
|
17 | 17 | import paddle
|
| 18 | +from paddle.fluid.framework import dygraph_only |
18 | 19 | from paddle.fluid import compiler
|
19 | 20 | from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
|
20 | 21 | from .strategy_compiler import StrategyCompiler
|
|
23 | 24 | from .runtime_factory import RuntimeFactory
|
24 | 25 | from .util_factory import UtilFactory
|
25 | 26 | from paddle.fluid.wrapped_decorator import wrap_decorator
|
| 27 | +from paddle.fluid.dygraph import parallel_helper |
26 | 28 |
|
27 | 29 |
|
28 | 30 | def _inited_runtime_handler_(func):
|
@@ -178,6 +180,12 @@ def init(self, role_maker=None, is_collective=False):
|
178 | 180 | "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
|
179 | 181 | format(type(role_maker)))
|
180 | 182 | self.strategy_compiler = StrategyCompiler()
|
| 183 | + if paddle.fluid.framework.in_dygraph_mode(): |
| 184 | + if parallel_helper._is_parallel_ctx_initialized(): |
| 185 | + warnings.warn( |
| 186 | + "The dygraph parallel environment has been initialized.") |
| 187 | + else: |
| 188 | + paddle.distributed.init_parallel_env() |
181 | 189 | return None
|
182 | 190 |
|
183 | 191 | def is_first_worker(self):
|
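Note: with the change above, calling `fleet.init(is_collective=True)` in dygraph mode now also initializes the parallel context via `paddle.distributed.init_parallel_env()`, unless it is already initialized, in which case only a warning is emitted. A minimal usage sketch (assuming the script is started with a multi-process launcher such as `paddle.distributed.spawn` or `python -m paddle.distributed.launch`):

```python
# Minimal sketch of the new dygraph init path (assumes a multi-process
# launch, e.g. paddle.distributed.spawn or paddle.distributed.launch).
import paddle
from paddle.distributed import fleet

paddle.disable_static()           # switch to dygraph (imperative) mode
fleet.init(is_collective=True)    # also calls init_parallel_env() if the
                                  # parallel context is not yet initialized
```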
@@ -587,12 +595,344 @@ def distributed_optimizer(self, optimizer, strategy=None):
|
587 | 595 |
|
588 | 596 | """
|
589 | 597 | self.user_defined_optimizer = optimizer
|
| 598 | + if paddle.fluid.framework.in_dygraph_mode(): |
| 599 | + return self |
| 600 | + |
590 | 601 | if strategy == None:
|
591 | 602 | strategy = DistributedStrategy()
|
592 | 603 | self.user_defined_strategy = strategy
|
593 | 604 | self.valid_strategy = None
|
594 | 605 | return self
|
595 | 606 |
|
| 607 | + @dygraph_only |
| 608 | + def distributed_model(self, model): |
| 609 | + """ |
| 610 | + Return a distributed data parallel model (Layer) for dygraph training. |
| 611 | + Only works in dygraph mode. |
| 612 | +
|
| 613 | + Examples: |
| 614 | + .. code-block:: python |
| 615 | + import paddle |
| 616 | + import paddle.nn as nn |
| 617 | + from paddle.distributed import fleet |
| 618 | +
|
| 619 | + class LinearNet(nn.Layer): |
| 620 | + def __init__(self): |
| 621 | + super(LinearNet, self).__init__() |
| 622 | + self._linear1 = nn.Linear(10, 10) |
| 623 | + self._linear2 = nn.Linear(10, 1) |
| 624 | +
|
| 625 | + def forward(self, x): |
| 626 | + return self._linear2(self._linear1(x)) |
| 627 | +
|
| 628 | + def train(): |
| 629 | + # 1. enable dynamic mode |
| 630 | + paddle.disable_static() |
| 631 | +
|
| 632 | + # 2. initialize fleet environment |
| 633 | + fleet.init(is_collective=True) |
| 634 | +
|
| 635 | + # 3. create layer & optimizer |
| 636 | + layer = LinearNet() |
| 637 | + loss_fn = nn.MSELoss() |
| 638 | + adam = paddle.optimizer.Adam( |
| 639 | + learning_rate=0.001, parameters=layer.parameters()) |
| 640 | +
|
| 641 | + # 4. get data_parallel model using fleet |
| 642 | + adam = fleet.distributed_optimizer(adam) |
| 643 | + dp_layer = fleet.distributed_model(layer) |
| 644 | +
|
| 645 | + # 5. run layer |
| 646 | + inputs = paddle.randn([10, 10], 'float32') |
| 647 | + outputs = dp_layer(inputs) |
| 648 | + labels = paddle.randn([10, 1], 'float32') |
| 649 | + loss = loss_fn(outputs, labels) |
| 650 | +
|
| 651 | + print("loss:", loss.numpy()) |
| 652 | +
|
| 653 | + loss = dp_layer.scale_loss(loss) |
| 654 | + loss.backward() |
| 655 | + dp_layer.apply_collective_grads() |
| 656 | +
|
| 657 | + adam.step() |
| 658 | + adam.clear_grad() |
| 659 | +
|
| 660 | + if __name__ == '__main__': |
| 661 | + paddle.distributed.spawn(train) |
| 662 | + """ |
| 663 | + assert model is not None |
| 664 | + self.model = paddle.DataParallel(model) |
| 665 | + return self.model |
| 666 | + |
| 667 | + @dygraph_only |
| 668 | + def state_dict(self): |
| 669 | + """ |
| 670 | + Get the state dict of the wrapped optimizer. |
| 671 | + Only works in dygraph mode. |
| 672 | +
|
| 673 | + Returns: |
| 674 | + state_dict(dict): a dict containing all the Tensors used by the optimizer |
| 675 | +
|
| 676 | + Examples: |
| 677 | + .. code-block:: python |
| 678 | + import numpy as np |
| 679 | + import paddle |
| 680 | + from paddle.distributed import fleet |
| 681 | +
|
| 682 | + paddle.disable_static() |
| 683 | + fleet.init(is_collective=True) |
| 684 | +
|
| 685 | + value = np.arange(26).reshape(2, 13).astype("float32") |
| 686 | + a = paddle.fluid.dygraph.to_variable(value) |
| 687 | +
|
| 688 | + layer = paddle.nn.Linear(13, 5) |
| 689 | + adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) |
| 690 | +
|
| 691 | + adam = fleet.distributed_optimizer(adam) |
| 692 | + dp_layer = fleet.distributed_model(layer) |
| 693 | + state_dict = adam.state_dict() |
| 694 | + """ |
| 695 | + # imitate target optimizer retrieval |
| 696 | + return self.user_defined_optimizer.state_dict() |
| 697 | + |
| 698 | + @dygraph_only |
| 699 | + def set_state_dict(self, state_dict): |
| 700 | + """ |
| 701 | + Load the optimizer state dict. |
| 702 | + Only works in dygraph mode. |
| 703 | +
|
| 704 | + Args: |
| 705 | + state_dict(dict): a dict containing all the Tensors needed by the optimizer |
| 706 | +
|
| 707 | + Returns: None |
| 708 | +
|
| 709 | + Examples: |
| 710 | + .. code-block:: python |
| 711 | + import numpy as np |
| 712 | + import paddle |
| 713 | + from paddle.distributed import fleet |
| 714 | +
|
| 715 | + paddle.disable_static() |
| 716 | + fleet.init(is_collective=True) |
| 717 | +
|
| 718 | + value = np.arange(26).reshape(2, 13).astype("float32") |
| 719 | + a = paddle.fluid.dygraph.to_variable(value) |
| 720 | +
|
| 721 | + layer = paddle.nn.Linear(13, 5) |
| 722 | + adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) |
| 723 | +
|
| 724 | + adam = fleet.distributed_optimizer(adam) |
| 725 | + dp_layer = fleet.distributed_model(layer) |
| 726 | + state_dict = adam.state_dict() |
| 727 | + paddle.framework.save(state_dict, "paddle_dy") |
| 728 | + para_state_dict, opti_state_dict = paddle.framework.load("paddle_dy") |
| 729 | + adam.set_state_dict(opti_state_dict) |
| 730 | + """ |
| 731 | + # imitate target optimizer retrieval |
| 732 | + return self.user_defined_optimizer.set_state_dict(state_dict) |
| 733 | + |
| 734 | + @dygraph_only |
| 735 | + def set_lr(self, value): |
| 736 | + """ |
| 737 | + Manually set the learning rate of the wrapped optimizer. |
| 738 | + Only works in dygraph mode. |
| 739 | + |
| 740 | + Args: |
| 741 | + value (float|Tensor): the new learning rate value |
| 742 | +
|
| 743 | + Returns: None |
| 744 | +
|
| 745 | + Examples: |
| 746 | + .. code-block:: python |
| 747 | + import numpy as np |
| 748 | + import paddle |
| 749 | + from paddle.distributed import fleet |
| 750 | +
|
| 751 | + paddle.disable_static() |
| 752 | + fleet.init(is_collective=True) |
| 753 | +
|
| 754 | + value = np.arange(26).reshape(2, 13).astype("float32") |
| 755 | + a = paddle.fluid.dygraph.to_variable(value) |
| 756 | +
|
| 757 | + layer = paddle.nn.Linear(13, 5) |
| 758 | + adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) |
| 759 | +
|
| 760 | + adam = fleet.distributed_optimizer(adam) |
| 761 | + dp_layer = fleet.distributed_model(layer) |
| 762 | +
|
| 763 | + lr_list = [0.2, 0.3, 0.4, 0.5, 0.6] |
| 764 | + for i in range(5): |
| 765 | + adam.set_lr(lr_list[i]) |
| 766 | + lr = adam.get_lr() |
| 767 | + print("current lr is {}".format(lr)) |
| 768 | + # Print: |
| 769 | + # current lr is 0.2 |
| 770 | + # current lr is 0.3 |
| 771 | + # current lr is 0.4 |
| 772 | + # current lr is 0.5 |
| 773 | + # current lr is 0.6 |
| 774 | + """ |
| 775 | + # imitate target optimizer retrieval |
| 776 | + return self.user_defined_optimizer.set_lr(value) |
| 777 | + |
| 778 | + @dygraph_only |
| 779 | + def get_lr(self): |
| 780 | + """ |
| 781 | + Get the learning rate of the current step. |
| 782 | + Only works in dygraph mode. |
| 783 | +
|
| 784 | + Returns: |
| 785 | + float: The learning rate of the current step. |
| 786 | +
|
| 787 | + Examples: |
| 788 | + .. code-block:: python |
| 789 | + import numpy as np |
| 790 | + import paddle |
| 791 | + from paddle.distributed import fleet |
| 792 | +
|
| 793 | + paddle.disable_static() |
| 794 | + fleet.init(is_collective=True) |
| 795 | +
|
| 796 | + value = np.arange(26).reshape(2, 13).astype("float32") |
| 797 | + a = paddle.fluid.dygraph.to_variable(value) |
| 798 | +
|
| 799 | + layer = paddle.nn.Linear(13, 5) |
| 800 | + adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) |
| 801 | +
|
| 802 | + adam = fleet.distributed_optimizer(adam) |
| 803 | + dp_layer = fleet.distributed_model(layer) |
| 804 | +
|
| 805 | + lr = adam.get_lr() |
| 806 | + print(lr) # 0.01 |
| 807 | + """ |
| 808 | + # imitate target optimizer retrieval |
| 809 | + return self.user_defined_optimizer.get_lr() |
| 810 | + |
| 811 | + @dygraph_only |
| 812 | + def step(self): |
| 813 | + """ |
| 814 | + Execute one optimization step (apply the accumulated gradients). |
| 815 | + Only works in dygraph mode. |
| 816 | +
|
| 817 | + Returns: None |
| 818 | +
|
| 819 | + Examples: |
| 820 | + .. code-block:: python |
| 821 | +
|
| 822 | + import paddle |
| 823 | + import paddle.nn as nn |
| 824 | + from paddle.distributed import fleet |
| 825 | +
|
| 826 | + class LinearNet(nn.Layer): |
| 827 | + def __init__(self): |
| 828 | + super(LinearNet, self).__init__() |
| 829 | + self._linear1 = nn.Linear(10, 10) |
| 830 | + self._linear2 = nn.Linear(10, 1) |
| 831 | +
|
| 832 | + def forward(self, x): |
| 833 | + return self._linear2(self._linear1(x)) |
| 834 | +
|
| 835 | + def train(): |
| 836 | + # 1. enable dynamic mode |
| 837 | + paddle.disable_static() |
| 838 | +
|
| 839 | + # 2. initialize fleet environment |
| 840 | + fleet.init(is_collective=True) |
| 841 | +
|
| 842 | + # 3. create layer & optimizer |
| 843 | + layer = LinearNet() |
| 844 | + loss_fn = nn.MSELoss() |
| 845 | + adam = paddle.optimizer.Adam( |
| 846 | + learning_rate=0.001, parameters=layer.parameters()) |
| 847 | +
|
| 848 | + # 4. get data_parallel model using fleet |
| 849 | + adam = fleet.distributed_optimizer(adam) |
| 850 | + dp_layer = fleet.distributed_model(layer) |
| 851 | +
|
| 852 | + # 5. run layer |
| 853 | + inputs = paddle.randn([10, 10], 'float32') |
| 854 | + outputs = dp_layer(inputs) |
| 855 | + labels = paddle.randn([10, 1], 'float32') |
| 856 | + loss = loss_fn(outputs, labels) |
| 857 | +
|
| 858 | + print("loss:", loss.numpy()) |
| 859 | +
|
| 860 | + loss = dp_layer.scale_loss(loss) |
| 861 | + loss.backward() |
| 862 | + dp_layer.apply_collective_grads() |
| 863 | +
|
| 864 | + adam.step() |
| 865 | + adam.clear_grad() |
| 866 | +
|
| 867 | + if __name__ == '__main__': |
| 868 | + paddle.distributed.spawn(train) |
| 869 | +
|
| 870 | + """ |
| 871 | + # imitate target optimizer retrieval |
| 872 | + return self.user_defined_optimizer.step() |
| 873 | + |
| 874 | + @dygraph_only |
| 875 | + def clear_grad(self): |
| 876 | + """ |
| 877 | + Clear the gradients of all the parameters optimized by the wrapped optimizer. |
| 878 | + Only works in dygraph mode. |
| 879 | + |
| 880 | + Returns: None |
| 881 | +
|
| 882 | + Examples: |
| 883 | + .. code-block:: python |
| 884 | +
|
| 885 | + import paddle |
| 886 | + import paddle.nn as nn |
| 887 | + from paddle.distributed import fleet |
| 888 | +
|
| 889 | + class LinearNet(nn.Layer): |
| 890 | + def __init__(self): |
| 891 | + super(LinearNet, self).__init__() |
| 892 | + self._linear1 = nn.Linear(10, 10) |
| 893 | + self._linear2 = nn.Linear(10, 1) |
| 894 | +
|
| 895 | + def forward(self, x): |
| 896 | + return self._linear2(self._linear1(x)) |
| 897 | +
|
| 898 | + def train(): |
| 899 | + # 1. enable dynamic mode |
| 900 | + paddle.disable_static() |
| 901 | +
|
| 902 | + # 2. initialize fleet environment |
| 903 | + fleet.init(is_collective=True) |
| 904 | +
|
| 905 | + # 3. create layer & optimizer |
| 906 | + layer = LinearNet() |
| 907 | + loss_fn = nn.MSELoss() |
| 908 | + adam = paddle.optimizer.Adam( |
| 909 | + learning_rate=0.001, parameters=layer.parameters()) |
| 910 | +
|
| 911 | + # 4. get data_parallel model using fleet |
| 912 | + adam = fleet.distributed_optimizer(adam) |
| 913 | + dp_layer = fleet.distributed_model(layer) |
| 914 | +
|
| 915 | + # 5. run layer |
| 916 | + inputs = paddle.randn([10, 10], 'float32') |
| 917 | + outputs = dp_layer(inputs) |
| 918 | + labels = paddle.randn([10, 1], 'float32') |
| 919 | + loss = loss_fn(outputs, labels) |
| 920 | +
|
| 921 | + print("loss:", loss.numpy()) |
| 922 | +
|
| 923 | + loss = dp_layer.scale_loss(loss) |
| 924 | + loss.backward() |
| 925 | + dp_layer.apply_collective_grads() |
| 926 | +
|
| 927 | + adam.step() |
| 928 | + adam.clear_grad() |
| 929 | +
|
| 930 | + if __name__ == '__main__': |
| 931 | + paddle.distributed.spawn(train) |
| 932 | + """ |
| 933 | + # imitate target optimizer retrieval |
| 934 | + return self.user_defined_optimizer.clear_grad() |
| 935 | + |
596 | 936 | def minimize(self,
|
597 | 937 | loss,
|
598 | 938 | startup_program=None,
|
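The methods above all follow the same delegation pattern flagged by the `# imitate target optimizer retrieval` comments: in dygraph mode the Fleet object acts as a thin proxy that forwards optimizer calls to the optimizer passed into `distributed_optimizer()`. An illustrative sketch of that pattern (a hypothetical class, not the Paddle source):

```python
# Illustrative sketch of the proxy/delegation pattern used above
# (hypothetical class and names, not Paddle source code).
class _FleetOptimizerProxy(object):
    def __init__(self, user_defined_optimizer):
        # the optimizer handed to distributed_optimizer()
        self.user_defined_optimizer = user_defined_optimizer

    def step(self):
        # forward the call unchanged to the wrapped optimizer
        return self.user_defined_optimizer.step()

    def get_lr(self):
        return self.user_defined_optimizer.get_lr()
```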
@@ -642,6 +982,11 @@ def minimize(self,
|
642 | 982 | # for more examples, please reference https://github.com/PaddlePaddle/FleetX
|
643 | 983 |
|
644 | 984 | """
|
| 985 | + if paddle.fluid.framework.in_dygraph_mode(): |
| 986 | + # imitate target optimizer retrieval |
| 987 | + target_opt = self.user_defined_optimizer |
| 988 | + return target_opt.minimize(loss) |
| 989 | + |
645 | 990 | context = {}
|
646 | 991 | # cache original feed forward program
|
647 | 992 | self.origin_main_program = loss.block.program
|
|
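With the dygraph branch added above, `minimize()` on the Fleet-wrapped optimizer simply forwards to the user-defined optimizer's `minimize(loss)`. A hedged end-to-end sketch, adapted from the docstring examples earlier in this diff but using `minimize()` in place of `step()` (run under `paddle.distributed.spawn` or a similar launcher):

```python
# Sketch adapted from the docstring examples above; in dygraph mode
# adam.minimize(loss) forwards to the wrapped optimizer's minimize().
import paddle
import paddle.nn as nn
from paddle.distributed import fleet


def train():
    paddle.disable_static()              # 1. enable dynamic mode
    fleet.init(is_collective=True)       # 2. initialize fleet environment

    layer = nn.Linear(10, 1)             # 3. create layer & optimizer
    loss_fn = nn.MSELoss()
    adam = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=layer.parameters())

    adam = fleet.distributed_optimizer(adam)    # 4. wrap with fleet
    dp_layer = fleet.distributed_model(layer)

    inputs = paddle.randn([10, 10], 'float32')  # 5. run one step
    loss = loss_fn(dp_layer(inputs), paddle.randn([10, 1], 'float32'))

    loss = dp_layer.scale_loss(loss)
    loss.backward()
    dp_layer.apply_collective_grads()

    adam.minimize(loss)                  # forwarded to the inner optimizer
    adam.clear_grad()


if __name__ == '__main__':
    paddle.distributed.spawn(train)
```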