
Commit 7701372

Add npu support (#3717)
1 parent f0abdb3 commit 7701372

3 files changed (+149, -13 lines)


paddleseg/models/layers/layer_libs.py

Lines changed: 119 additions & 2 deletions
@@ -18,16 +18,133 @@
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddleseg.models import layers
+import paddle.distributed as dist
+import math
+
+
+class _AllReduce(paddle.autograd.PyLayer):
+
+    @staticmethod
+    def forward(ctx, input):
+        input_list = [
+            paddle.zeros_like(input) for k in range(dist.get_world_size())
+        ]
+        # Use allgather instead of allreduce since I don't trust in-place operations ..
+        dist.all_gather(input_list, input, sync_op=True)
+        inputs = paddle.stack(input_list, axis=0)
+        return paddle.sum(inputs, axis=0)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dist.all_reduce(grad_output, sync_op=True)
+        return grad_output
+
+
+def differentiable_all_reduce(input):
+    """
+    Differentiable counterpart of `dist.all_reduce`.
+    """
+    if (not dist.is_available() or not dist.is_initialized()
+            or dist.get_world_size() == 1):
+        return input
+    return _AllReduce.apply(input)
+
+
+class NaiveSyncBatchNorm(nn.BatchNorm2D):
+
+    def __init__(self, *args, stats_mode="", **kwargs):
+        super().__init__(*args, **kwargs)
+        assert stats_mode in ["", "N"]
+        self._stats_mode = stats_mode
+
+    def forward(self, input):
+        if dist.get_world_size() == 1 or not self.training:
+            return super().forward(input)
+
+        B, C = input.shape[0], input.shape[1]
+
+        mean = paddle.mean(input, axis=[0, 2, 3])
+        meansqr = paddle.mean(input * input, axis=[0, 2, 3])
+
+        if self._stats_mode == "":
+            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
+            vec = paddle.concat([mean, meansqr], axis=0)
+            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
+            mean, meansqr = paddle.split(vec, [C, C])
+            momentum = 1 - self._momentum  # NOTE: paddle has a reversed momentum definition
+        else:
+            if B == 0:
+                vec = paddle.zeros([2 * C + 1], dtype=mean.dtype)
+                vec = vec + input.sum()  # make sure there is gradient w.r.t input
+            else:
+                vec = paddle.concat(
+                    [
+                        mean,
+                        meansqr,
+                        paddle.ones([1], dtype=mean.dtype),
+                    ],
+                    axis=0,
+                )
+            vec = differentiable_all_reduce(vec * B)
+
+            total_batch = vec[-1].detach()
+            momentum = total_batch.clip(max=1) * (
+                1 - self._momentum)  # no update if total_batch is 0
+            mean, meansqr, _ = paddle.split(
+                vec / total_batch.clip(min=1),
+                [C, C, int(vec.shape[0] - 2 * C)])  # avoid div-by-zero
+
+        var = meansqr - mean * mean
+        invstd = paddle.rsqrt(var + self._epsilon)
+        scale = self.weight * invstd
+        bias = self.bias - mean * scale
+        scale = scale.reshape([1, -1, 1, 1])
+        bias = bias.reshape([1, -1, 1, 1])
+
+        tmp_mean = self._mean + momentum * (mean.detach() - self._mean)
+        self._mean.set_value(tmp_mean)
+        tmp_variance = self._variance + (momentum *
+                                         (var.detach() - self._variance))
+        self._variance.set_value(tmp_variance)
+        ret = input * scale + bias
+        return ret
+
+    @classmethod
+    def convert_sync_batchnorm(cls, layer):
+        layer_output = layer
+        if isinstance(layer, nn.BatchNorm2D):
+
+            layer_output = NaiveSyncBatchNorm(layer._num_features,
+                                              layer._momentum, layer._epsilon,
+                                              layer._weight_attr,
+                                              layer._bias_attr,
+                                              layer._data_format, layer._name)
+
+            if (layer._weight_attr is not False
+                    and layer._bias_attr is not False):
+                with paddle.no_grad():
+                    layer_output.weight = layer.weight
+                    layer_output.bias = layer.bias
+            layer_output._mean = layer._mean
+            layer_output._variance = layer._variance
+
+        for name, sublayer in layer.named_children():
+            layer_output.add_sublayer(name,
+                                      cls.convert_sync_batchnorm(sublayer))
+        del layer
+        return layer_output


 def SyncBatchNorm(*args, **kwargs):
     """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead"""
     if paddle.get_device() == 'cpu' or os.environ.get(
-            'PADDLESEG_EXPORT_STAGE') or 'xpu' in paddle.get_device(
-    ) or 'npu' in paddle.get_device():
+            'PADDLESEG_EXPORT_STAGE') or 'xpu' in paddle.get_device():
         return nn.BatchNorm2D(*args, **kwargs)
     elif paddle.distributed.ParallelEnv().nranks == 1:
         return nn.BatchNorm2D(*args, **kwargs)
+    elif 'npu' in paddle.get_device():
+        return NaiveSyncBatchNorm(*args, **kwargs)
     else:
         return nn.SyncBatchNorm(*args, **kwargs)
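
The NaiveSyncBatchNorm added above synchronizes statistics by all-reducing the per-card E[x] and E[x²] and recovering the variance as E[x²] − E[x]². A minimal single-process sketch of that identity, with one batch split into chunks standing in for the per-card inputs (the shapes and tolerances below are illustrative, not part of the commit):

import paddle

world_size = 4
x = paddle.rand([8 * world_size, 3, 16, 16])       # global batch, NCHW
chunks = paddle.split(x, world_size, axis=0)       # stand-ins for per-card inputs

# Per-card statistics, then the "all-reduce and average" step from stats_mode="".
means = [paddle.mean(c, axis=[0, 2, 3]) for c in chunks]
meansqrs = [paddle.mean(c * c, axis=[0, 2, 3]) for c in chunks]
mean = sum(means) / world_size
meansqr = sum(meansqrs) / world_size
var = meansqr - mean * mean                        # E[x^2] - E[x]^2

# Reference statistics computed directly on the full batch.
ref_mean = paddle.mean(x, axis=[0, 2, 3])
ref_var = paddle.var(x, axis=[0, 2, 3], unbiased=False)
print(paddle.allclose(mean, ref_mean, atol=1e-6))  # Tensor(True)
print(paddle.allclose(var, ref_var, atol=1e-5))    # Tensor(True)

Averaging the per-card means is exact here only because every chunk has the same batch size, which is the assumption behind stats_mode=""; the "N"-weighted branch handles uneven (including zero) per-card batches. Paddle's BatchNorm also treats momentum as the weight on the existing running statistic (running = momentum * running + (1 - momentum) * batch), which is why the code flips it with momentum = 1 - self._momentum.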

paddleseg/models/layers/pyramid_pool.py

Lines changed: 15 additions & 2 deletions
@@ -19,6 +19,15 @@
 from paddleseg.models import layers


+class CustomAvgPool2D(nn.Layer):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return paddle.mean(x, axis=[2, 3], keepdim=True)
+
+
 class ASPPModule(nn.Layer):
     """
     Atrous Spatial Pyramid Pooling.
@@ -64,9 +73,13 @@ def __init__(self,
         out_size = len(self.aspp_blocks)

         if image_pooling:
+            # avg pool with output_size=(1, 1) is extremely slow in backward,
+            # so we replace it with an equivalent mean operation.
+            pool_layer = CustomAvgPool2D() if "npu" in paddle.get_device(
+            ) else nn.AdaptiveAvgPool2D(output_size=(1, 1),
+                                        data_format=data_format)
             self.global_avg_pool = nn.Sequential(
-                nn.AdaptiveAvgPool2D(output_size=(1, 1),
-                                     data_format=data_format),
+                pool_layer,
                 layers.ConvBNReLU(in_channels,
                                   out_channels,
                                   kernel_size=1,
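
For NCHW inputs, the mean over the spatial axes is numerically the same as adaptive average pooling to a 1×1 output, which is what makes the swap safe. A quick sanity check (a sketch; the feature-map shape is arbitrary):

import paddle
import paddle.nn as nn

x = paddle.rand([2, 256, 33, 33])                  # NCHW feature map
ref = nn.AdaptiveAvgPool2D(output_size=(1, 1))(x)  # original global average pooling
alt = paddle.mean(x, axis=[2, 3], keepdim=True)    # what CustomAvgPool2D.forward computes
print(ref.shape, alt.shape)                        # both [2, 256, 1, 1]
print(paddle.allclose(ref, alt, atol=1e-6))        # Tensor(True)

One caveat: CustomAvgPool2D always reduces axes 2 and 3, so unlike nn.AdaptiveAvgPool2D it does not honor the data_format argument; since it is only selected on NPU devices, it implicitly assumes the default NCHW layout.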

paddleseg/utils/utils.py

Lines changed: 15 additions & 9 deletions
@@ -25,6 +25,7 @@

 from paddleseg.utils import logger, seg_env, get_sys_env
 from paddleseg.utils.download import download_file_and_uncompress
+from paddleseg.models.layers.layer_libs import NaiveSyncBatchNorm


 def set_seed(seed=None):
@@ -83,6 +84,9 @@ def convert_sync_batchnorm(model, device):
             and env_info['GPUs used'] and paddle.distributed.ParallelEnv().nranks > 1:
         model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
         logger.info("Convert bn to sync_bn")
+    elif device == "npu" and paddle.distributed.ParallelEnv().nranks > 1:
+        model = NaiveSyncBatchNorm.convert_sync_batchnorm(model)
+        logger.info("Convert bn to sync_bn in NPU Device")
     return model


@@ -97,7 +101,7 @@ def set_cv2_num_threads(num_workers):


 @contextlib.contextmanager
-def generate_tempdir(directory: str=None, **kwargs):
+def generate_tempdir(directory: str = None, **kwargs):
     '''Generate a temporary directory'''
     directory = seg_env.TMP_HOME if not directory else directory
     with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
@@ -160,23 +164,24 @@ def load_pretrained_model(model, pretrained_model):
             for k in keys:
                 if k not in para_state_dict:
                     logger.warning("{} is not in pretrained model".format(k))
-                elif list(para_state_dict[k].shape) != list(model_state_dict[k]
-                                                            .shape):
+                elif list(para_state_dict[k].shape) != list(
+                        model_state_dict[k].shape):
                     logger.warning(
                         "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
-                        .format(k, para_state_dict[k].shape, model_state_dict[k]
-                                .shape))
+                        .format(k, para_state_dict[k].shape,
+                                model_state_dict[k].shape))
                 else:
                     model_state_dict[k] = para_state_dict[k]
                     num_params_loaded += 1
             model.set_dict(model_state_dict)
             logger.info("There are {}/{} variables loaded into {}.".format(
-                num_params_loaded,
-                len(model_state_dict), model.__class__.__name__))
+                num_params_loaded, len(model_state_dict),
+                model.__class__.__name__))

         else:
-            raise ValueError('The pretrained model directory is not Found: {}'.
-                             format(pretrained_model))
+            raise ValueError(
+                'The pretrained model directory is not Found: {}'.format(
+                    pretrained_model))
     else:
         logger.info(
             'No pretrained model to load, {} will be trained from scratch.'.
@@ -251,6 +256,7 @@ def get_image_list(image_path):


 class NoAliasDumper(yaml.SafeDumper):
+
     def ignore_aliases(self, data):
         return True
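
Taken together, multi-card NPU training now swaps BatchNorm2D layers to the gather-based NaiveSyncBatchNorm rather than nn.SyncBatchNorm. A hedged usage sketch of the new branch (the toy model below is illustrative only; in PaddleSeg this is driven by utils.convert_sync_batchnorm(model, device)):

import paddle
import paddle.nn as nn
from paddleseg.models.layers.layer_libs import NaiveSyncBatchNorm

model = nn.Sequential(
    nn.Conv2D(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2D(16),
    nn.ReLU())

# Mirrors the new elif branch in convert_sync_batchnorm: on an NPU device with
# more than one card, swap every BatchNorm2D for NaiveSyncBatchNorm.
if "npu" in paddle.get_device() and paddle.distributed.ParallelEnv().nranks > 1:
    model = NaiveSyncBatchNorm.convert_sync_batchnorm(model)

print(type(model[1]).__name__)  # "NaiveSyncBatchNorm" once the branch is taken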
