|
18 | 18 | import paddle.nn as nn
|
19 | 19 | import paddle.nn.functional as F
|
20 | 20 | from paddleseg.models import layers
|
| 21 | +import paddle.distributed as dist |
| 22 | +import math |
| 23 | + |
| 24 | + |
| 25 | +class _AllReduce(paddle.autograd.PyLayer): |
| 26 | + |
| 27 | + @staticmethod |
| 28 | + def forward(ctx, input): |
| 29 | + input_list = [ |
| 30 | + paddle.zeros_like(input) for k in range(dist.get_world_size()) |
| 31 | + ] |
| 32 | + # Use allgather instead of allreduce since I don't trust in-place operations .. |
| 33 | + dist.all_gather(input_list, input, sync_op=True) |
| 34 | + inputs = paddle.stack(input_list, axis=0) |
| 35 | + return paddle.sum(inputs, axis=0) |
| 36 | + |
| 37 | + @staticmethod |
| 38 | + def backward(ctx, grad_output): |
| 39 | + dist.all_reduce(grad_output, sync_op=True) |
| 40 | + return grad_output |
| 41 | + |
| 42 | + |
| 43 | +def differentiable_all_reduce(input): |
| 44 | + """ |
| 45 | + Differentiable counterpart of `dist.all_reduce`. |
| 46 | + """ |
| 47 | + if (not dist.is_available() or not dist.is_initialized() |
| 48 | + or dist.get_world_size() == 1): |
| 49 | + return input |
| 50 | + return _AllReduce.apply(input) |
| 51 | + |
| 52 | + |
| 53 | +class NaiveSyncBatchNorm(nn.BatchNorm2D): |
| 54 | + |
| 55 | + def __init__(self, *args, stats_mode="", **kwargs): |
| 56 | + super().__init__(*args, **kwargs) |
| 57 | + assert stats_mode in ["", "N"] |
| 58 | + self._stats_mode = stats_mode |
| 59 | + |
| 60 | + def forward(self, input): |
| 61 | + if dist.get_world_size() == 1 or not self.training: |
| 62 | + return super().forward(input) |
| 63 | + |
| 64 | + B, C = input.shape[0], input.shape[1] |
| 65 | + |
| 66 | + mean = paddle.mean(input, axis=[0, 2, 3]) |
| 67 | + meansqr = paddle.mean(input * input, axis=[0, 2, 3]) |
| 68 | + |
| 69 | + if self._stats_mode == "": |
| 70 | + assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.' |
| 71 | + vec = paddle.concat([mean, meansqr], axis=0) |
| 72 | + vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size()) |
| 73 | + mean, meansqr = paddle.split(vec, [C, C]) |
| 74 | + momentum = 1 - self._momentum # NOTE: paddle has reverse momentum defination |
| 75 | + else: |
| 76 | + if B == 0: |
| 77 | + vec = paddle.zeros([2 * C + 1], dtype=mean.dtype) |
| 78 | + vec = vec + input.sum( |
| 79 | + ) # make sure there is gradient w.r.t input |
| 80 | + else: |
| 81 | + vec = paddle.concat( |
| 82 | + [ |
| 83 | + mean, |
| 84 | + meansqr, |
| 85 | + paddle.ones([1], dtype=mean.dtype), |
| 86 | + ], |
| 87 | + axis=0, |
| 88 | + ) |
| 89 | + vec = differentiable_all_reduce(vec * B) |
| 90 | + |
| 91 | + total_batch = vec[-1].detach() |
| 92 | + momentum = total_batch.clip(max=1) * ( |
| 93 | + 1 - self._momentum) # no update if total_batch is 0 |
| 94 | + mean, meansqr, _ = paddle.split( |
| 95 | + vec / total_batch.clip(min=1), |
| 96 | + [C, C, int(vec.shape[0] - 2 * C)]) # avoid div-by-zero |
| 97 | + |
| 98 | + var = meansqr - mean * mean |
| 99 | + invstd = paddle.rsqrt(var + self._epsilon) |
| 100 | + scale = self.weight * invstd |
| 101 | + bias = self.bias - mean * scale |
| 102 | + scale = scale.reshape([1, -1, 1, 1]) |
| 103 | + bias = bias.reshape([1, -1, 1, 1]) |
| 104 | + |
| 105 | + tmp_mean = self._mean + momentum * (mean.detach() - self._mean) |
| 106 | + self._mean.set_value(tmp_mean) |
| 107 | + tmp_variance = self._variance + (momentum * |
| 108 | + (var.detach() - self._variance)) |
| 109 | + self._variance.set_value(tmp_variance) |
| 110 | + ret = input * scale + bias |
| 111 | + return ret |
| 112 | + |
| 113 | + @classmethod |
| 114 | + def convert_sync_batchnorm(cls, layer): |
| 115 | + layer_output = layer |
| 116 | + if isinstance(layer, nn.BatchNorm2D): |
| 117 | + |
| 118 | + layer_output = NaiveSyncBatchNorm(layer._num_features, |
| 119 | + layer._momentum, layer._epsilon, |
| 120 | + layer._weight_attr, |
| 121 | + layer._bias_attr, |
| 122 | + layer._data_format, layer._name) |
| 123 | + |
| 124 | + if (layer._weight_attr is not False |
| 125 | + and layer._bias_attr is not False): |
| 126 | + with paddle.no_grad(): |
| 127 | + layer_output.weight = layer.weight |
| 128 | + layer_output.bias = layer.bias |
| 129 | + layer_output._mean = layer._mean |
| 130 | + layer_output._variance = layer._variance |
| 131 | + |
| 132 | + for name, sublayer in layer.named_children(): |
| 133 | + layer_output.add_sublayer(name, |
| 134 | + cls.convert_sync_batchnorm(sublayer)) |
| 135 | + del layer |
| 136 | + return layer_output |
21 | 137 |
|
22 | 138 |
|
23 | 139 | def SyncBatchNorm(*args, **kwargs):
|
24 | 140 | """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead"""
|
25 | 141 | if paddle.get_device() == 'cpu' or os.environ.get(
|
26 |
| - 'PADDLESEG_EXPORT_STAGE') or 'xpu' in paddle.get_device( |
27 |
| - ) or 'npu' in paddle.get_device(): |
| 142 | + 'PADDLESEG_EXPORT_STAGE') or 'xpu' in paddle.get_device(): |
28 | 143 | return nn.BatchNorm2D(*args, **kwargs)
|
29 | 144 | elif paddle.distributed.ParallelEnv().nranks == 1:
|
30 | 145 | return nn.BatchNorm2D(*args, **kwargs)
|
| 146 | + elif 'npu' in paddle.get_device(): |
| 147 | + return NaiveSyncBatchNorm(*args, **kwargs) |
31 | 148 | else:
|
32 | 149 | return nn.SyncBatchNorm(*args, **kwargs)
|
33 | 150 |
|
|
0 commit comments