
Commit 5984144

fix dtype checking for softmax (#51929)
1 parent 2b98993 commit 5984144

2 files changed: +4 -4 lines changed

python/paddle/nn/functional/activation.py

Lines changed: 3 additions & 3 deletions
@@ -1110,15 +1110,15 @@ def softmax(x, axis=-1, dtype=None, name=None):
         use_cudnn = True
         if dtype is None:
             check_variable_and_dtype(
-                x, 'x', ['float16', 'float32', 'float64'], 'softmax'
+                x, 'x', ['float16', 'bfloat16', 'float32', 'float64'], 'softmax'
             )
         else:
             check_dtype(
                 dtype,
                 'dtype',
-                ['float32', 'float64'],
+                ['float16', 'bfloat16', 'float32', 'float64'],
                 'softmax',
-                'If dtype is not None, it only support float32 or float64.',
+                'If dtype is not None, it only support float16, bfloat16, float32 or float64.',
             )

         helper = LayerHelper("softmax", **locals())
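
With the relaxed checks, the static-graph validation in paddle.nn.functional.softmax accepts float16 and bfloat16 for both the input and the optional dtype cast. A minimal sketch of the intended behaviour, assuming a Paddle build and device with bfloat16 kernels (e.g. a recent GPU); this example is illustrative and not part of the commit:

# Sketch only: assumes this commit is applied and the device supports bfloat16.
import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 3], dtype='float32')

# Before the fix, dtype='bfloat16' was rejected by the check_dtype validation;
# after it, float16/bfloat16/float32/float64 are all accepted.
y = F.softmax(x, axis=-1, dtype='bfloat16')
print(y.dtype)  # expected: paddle.bfloat16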

python/paddle/nn/layer/activation.py

Lines changed: 1 addition & 1 deletion
@@ -1324,7 +1324,7 @@ def __init__(self, axis=-1, name=None):
         self._name = name

     def forward(self, x):
-        return F.softmax(x, self._axis, self._dtype, self._name)
+        return F.softmax(x, self._axis, name=self._name)

     def extra_repr(self):
         name_str = ', name={}'.format(self._name) if self._name else ''
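
On the layer side, forward() no longer passes the layer's internal _dtype into F.softmax, so paddle.nn.Softmax defers to the functional default and keeps the input's dtype. A minimal usage sketch, assuming float16 softmax kernels are available on the current device (e.g. GPU); illustrative only, not part of the commit:

# Sketch only: assumes float16 softmax kernels on the current device.
import paddle

m = paddle.nn.Softmax(axis=-1)
x = paddle.rand([2, 3]).astype('float16')

out = m(x)
print(out.dtype)  # expected: paddle.float16, with no implicit cast from the layer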
