|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import paddle
from paddle.distributed.fleet.base import role_maker

paddle.enable_static()


class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
    def setUp(self):
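        # Fake a two-trainer, two-pserver PaddleCloud environment so that
        # role_maker.PaddleCloudRoleMaker() can resolve cluster roles from
        # environment variables without a real cluster or launcher.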
        os.environ["PADDLE_PSERVER_NUMS"] = "2"
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_PORT"] = "36001"
        os.environ["PADDLE_TRAINER_ID"] = "0"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = (
            "127.0.0.1:36001,127.0.0.2:36001"
        )

    def test_a_sync_optimizer_trainer(self):
        os.environ["TRAINING_ROLE"] = "TRAINER"
        from paddle.distributed import fleet

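        # Switch to fresh main/startup programs so graph state from other
        # tests in this process does not leak into this one.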
        main_program = paddle.base.Program()
        startup_program = paddle.base.Program()

        paddle.base.framework.switch_main_program(main_program)
        paddle.base.framework.switch_startup_program(startup_program)

        fleet.init(role_maker.PaddleCloudRoleMaker())
        input_x = paddle.static.data(name="x", shape=[-1, 32], dtype='float32')
        input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')

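        # A small two-hidden-layer MLP classifier; it exists only to give
        # the distributed optimizer a loss and a program graph to rewrite.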
        fc_1 = paddle.static.nn.fc(x=input_x, size=64, activation='tanh')
        fc_2 = paddle.static.nn.fc(x=fc_1, size=64, activation='tanh')
        prediction = paddle.static.nn.fc(x=[fc_2], size=2, activation='softmax')
        cost = paddle.nn.functional.cross_entropy(
            input=prediction, label=input_y, reduction='none', use_softmax=False
        )
        avg_cost = paddle.mean(x=cost)

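        # Asynchronous parameter-server mode; a k_steps value greater than
        # zero requests geo-SGD style communication, here with a 100-step
        # sync interval and the launch barrier disabled for a local test.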
        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.a_sync = True
        strategy.a_sync_configs = {"k_steps": 100, "launch_barrier": False}

        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        optimizer.minimize(avg_cost)

    def test_a_sync_optimizer_pserver(self):
        os.environ["TRAINING_ROLE"] = "PSERVER"
        from paddle.distributed import fleet

        main_program = paddle.base.Program()
        startup_program = paddle.base.Program()

        paddle.base.framework.switch_main_program(main_program)
        paddle.base.framework.switch_startup_program(startup_program)

        fleet.init(role_maker.PaddleCloudRoleMaker())
        input_x = paddle.static.data(name="x", shape=[-1, 32], dtype='float32')
        input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')

        fc_1 = paddle.static.nn.fc(x=input_x, size=64, activation='tanh')
        fc_2 = paddle.static.nn.fc(x=fc_1, size=64, activation='tanh')
        prediction = paddle.static.nn.fc(x=[fc_2], size=2, activation='softmax')
        cost = paddle.nn.functional.cross_entropy(
            input=prediction, label=input_y, reduction='none', use_softmax=False
        )
        avg_cost = paddle.mean(x=cost)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.a_sync = True
        strategy.a_sync_configs = {"k_steps": 100, "launch_barrier": False}

        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

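        # On the PSERVER role, minimize() rewrites the default main program
        # into a server program, so its first op must be listen_and_serv.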
        prog = paddle.base.default_main_program()
        self.assertEqual(prog.global_block().ops[0].type, "listen_and_serv")


if __name__ == "__main__":
    unittest.main()