Skip to content

Commit 846a539

Browse files
committed
fix
1 parent c36eb75 commit 846a539

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

python/paddle/nn/layer/norm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,10 @@ def __init__(self,
563563
self._bias_attr = bias_attr
564564
self._use_global_stats = use_global_stats
565565

566-
self._dtype = 'float32'
566+
if get_default_dtype() == 'float16':
567+
self._dtype = 'float32'
568+
else:
569+
self._dtype = get_default_dtype()
567570

568571
param_shape = [num_features]
569572

0 commit comments

Comments
 (0)