# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from collections import OrderedDict

import numpy as np

import paddle
from paddle.distributed.fleet.utils import mix_precision_utils
from paddle.nn import Linear, ReLU

logging.basicConfig(level="INFO", format="%(message)s")


class MLP(paddle.nn.Layer):
    def __init__(self, linear_size=1000):
        super().__init__()

        self._linear1 = Linear(linear_size, linear_size)
        self._linear2 = Linear(linear_size, linear_size)
        self._linear3 = Linear(linear_size, 10)
        self._relu = ReLU()

    def forward(self, inputs):
        y = self._linear1(inputs)
        y = self._linear2(y)
        y = self._linear3(y)
        y = self._relu(y)
        return y


class RandomDataset(paddle.io.Dataset):
    def __init__(self, num_samples=200, linear_size=1000):
        self.num_samples = num_samples
        self.linear_size = linear_size
        self.samples = []
        for _ in range(num_samples):
            img = np.random.rand(self.linear_size).astype('float32')
            self.samples.append(img)

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return self.num_samples


def create_optimizer(model, use_pure_bf16, use_main_grad):
    if use_main_grad:
        # fp32 main gradients only make sense on top of pure bf16 training.
        assert use_pure_bf16, "use_main_grad requires use_pure_bf16"
        # Wrap the layer so every parameter accumulates an fp32 main_grad.
        model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=0.00001,
        weight_decay=0.00001,
        grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
        multi_precision=use_pure_bf16,
    )
    if use_main_grad:
        # Let the optimizer step on the fp32 main gradients instead of the
        # raw bf16 gradients.
        optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)

    return optimizer


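# Illustrative usage (an assumption about how this factory is driven; only the
# flag combinations below pass the assertion in create_optimizer):
#
#   opt_fp32 = create_optimizer(model, use_pure_bf16=False, use_main_grad=False)
#   opt_bf16 = create_optimizer(model, use_pure_bf16=True, use_main_grad=True)
#
# Passing use_main_grad=True with use_pure_bf16=False trips the assertion,
# since fp32 main gradients presuppose bf16 parameters.

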
def save_model_parameters(model):
    # Collect the model's parameters into a name -> tensor OrderedDict.
    param_dict = OrderedDict()
    for param in model.parameters():
        param_dict[param.name] = param
    return param_dict


def _extract_linear_order(param_names):
    # Param names from model.state_dict() look like ["_linear1.weight", "_linear1.bias"];
    # master-weight names from optimizer.state_dict() look like ["linear_6.w_0", "linear_6.b_0"].
    # Either way, the first integer in the name identifies the linear layer.
    param_order = []
    for name in param_names:
        param_id = re.findall(r"\d+", name)
        assert len(param_id) >= 1
        param_order.append(int(param_id[0]))
    # Deduplicate; callers sort the result before relying on its order.
    return list(set(param_order))


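# Worked example (names follow the conventions noted above):
#
#   _extract_linear_order(["_linear1.weight", "_linear1.bias", "_linear2.weight"])
#   # -> the set {1, 2} as a list
#   _extract_linear_order(["linear_6.w_0", "linear_6.b_0", "linear_7.w_0"])
#   # -> the set {6, 7} as a list

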
def _extract_param_order_dict(model_param_dict_o1, model_param_dict_o2):
    param_names_o1 = list(model_param_dict_o1.keys())
    param_order_o1 = _extract_linear_order(param_names_o1)
    param_order_o1.sort()

    param_names_o2 = list(model_param_dict_o2.keys())
    param_order_o2 = _extract_linear_order(param_names_o2)
    param_order_o2.sort()

    assert len(param_order_o1) == len(param_order_o2)

    # Zip the two sorted orders positionally: o2 layer id -> o1 layer id.
    param_order_dict = {}
    for order_o2, order_o1 in zip(param_order_o2, param_order_o1):
        param_order_dict[order_o2] = order_o1

    logging.info(f"-- param_names_o1: {param_names_o1}")
    logging.info(f"-- param_names_o2: {param_names_o2}")
    logging.info(f"-- param_order_dict: {param_order_dict}")
    return param_order_dict


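# Worked example: with o1 orders [1, 2, 3] (from names like "_linear1.weight")
# and o2 orders [6, 7, 8] (from names like "linear_6.w_0"), the positional zip
# yields {6: 1, 7: 2, 8: 3}. This assumes both models create their linear
# layers in the same order.

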
def compare_state_dict(
    model_param_dict_o1, model_param_dict_o2, optimizer_state_dict_o2
):
    master_weights = optimizer_state_dict_o2.get("master_weights", None)
    assert master_weights is not None
    master_weights_names = list(master_weights.keys())

    param_names = list(model_param_dict_o1.keys())
    param_order_dict = _extract_param_order_dict(
        model_param_dict_o1, model_param_dict_o2
    )
    param_master_pair = []

    # We assume the order of params in param_names and master_weights_names
    # is the same.
    for master_weight_name in master_weights_names:
        master_weight_id = re.findall(r"\d+", master_weight_name)[0]
        param_id = param_order_dict[int(master_weight_id)]
        for param_name in param_names:
            if (
                master_weight_name.endswith("w_0")
                and param_name.endswith("weight")
            ) or (
                master_weight_name.endswith("b_0")
                and param_name.endswith("bias")
            ):
                # param_id is an int, so build the prefix with an f-string
                # (plain "+" concatenation would raise a TypeError).
                name_prefix = f"linear{param_id}"
                if name_prefix in param_name:
                    param_master_pair.append([param_name, master_weight_name])

    logging.info(f"-- master_weights_names: {master_weights_names}")
    for param_name, master_weight_name in param_master_pair:
        logging.info(f"-- compare {param_name} with {master_weight_name}")
        param_o1 = model_param_dict_o1[param_name]
        master_param_o2 = master_weights[master_weight_name]
        np.testing.assert_array_equal(param_o1.numpy(), master_param_o2.numpy())
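

# The helpers above are pieced together by a driver along the following lines.
# This is a minimal, illustrative sketch, not part of the original utilities:
# the AMP plumbing (paddle.amp.decorate / paddle.amp.auto_cast), the manual
# batching, and the step count are assumptions about how the test is wired up.
def _demo_train(use_pure_bf16=True, use_main_grad=True, num_steps=4):
    paddle.seed(2023)
    model = MLP()
    optimizer = create_optimizer(model, use_pure_bf16, use_main_grad)
    if use_pure_bf16:
        # Cast parameters to bf16; multi_precision keeps fp32 master weights.
        model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16")

    data = RandomDataset(num_samples=10 * num_steps)
    for step in range(num_steps):
        # Stack ten samples into one fp32 batch of shape [10, 1000].
        batch = np.stack(data.samples[step * 10 : (step + 1) * 10])
        img = paddle.to_tensor(batch)
        with paddle.amp.auto_cast(
            enable=use_pure_bf16, level="O2", dtype="bfloat16"
        ):
            loss = model(img).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

    # A pure-bf16 run's fp32 master weights can then be checked against the
    # parameters of an equivalent fp32 (o1) run via compare_state_dict().
    return save_model_parameters(model), optimizer.state_dict()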