Skip to content

Commit f29c0ca

Browse files
authored
[AMP OP&Test] Fix the rtol setting rules in bfloat16 forward (#51875)
1 parent 75fb2ed commit f29c0ca

File tree

2 files changed

+8
-18
lines changed

2 files changed

+8
-18
lines changed

python/paddle/fluid/tests/unittests/eager_op_test.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,11 +1727,9 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
17271727
judge whether convert current output and expect to uint16.
17281728
return True | False
17291729
"""
1730-
if actual_np.dtype == np.uint16 and expect_np.dtype in [
1731-
np.float32,
1732-
np.float64,
1733-
]:
1734-
actual_np = convert_uint16_to_float(actual_np)
1730+
if actual_np.dtype == np.uint16:
1731+
if expect_np.dtype in [np.float32, np.float64]:
1732+
actual_np = convert_uint16_to_float(actual_np)
17351733
self.rtol = 1.0e-2
17361734
elif actual_np.dtype == np.float16:
17371735
self.rtol = 1.0e-3
@@ -1828,10 +1826,7 @@ def _compare_numpy(self, name, actual_np, expect_np):
18281826
)
18291827

18301828
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
1831-
if actual_np.dtype == np.uint16 and expect_np.dtype in [
1832-
np.float32,
1833-
np.float64,
1834-
]:
1829+
if actual_np.dtype == np.uint16:
18351830
self.rtol = 1.0e-2
18361831
elif actual_np.dtype == np.float16:
18371832
self.rtol = 1.0e-3

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,11 +1723,9 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
17231723
judge whether convert current output and expect to uint16.
17241724
return True | False
17251725
"""
1726-
if actual_np.dtype == np.uint16 and expect_np.dtype in [
1727-
np.float32,
1728-
np.float64,
1729-
]:
1730-
actual_np = convert_uint16_to_float(actual_np)
1726+
if actual_np.dtype == np.uint16:
1727+
if expect_np.dtype in [np.float32, np.float64]:
1728+
actual_np = convert_uint16_to_float(actual_np)
17311729
self.rtol = 1.0e-2
17321730
elif actual_np.dtype == np.float16:
17331731
self.rtol = 1.0e-3
@@ -1787,10 +1785,7 @@ def find_expect_value(self, name):
17871785
return imperative_expect, imperative_expect_t
17881786

17891787
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
1790-
if actual_np.dtype == np.uint16 and expect_np.dtype in [
1791-
np.float32,
1792-
np.float64,
1793-
]:
1788+
if actual_np.dtype == np.uint16:
17941789
self.rtol = 1.0e-2
17951790
elif actual_np.dtype == np.float16:
17961791
self.rtol = 1.0e-3

0 commit comments

Comments
 (0)