Commit 9507969

[Hackathon No.6] implement nan_to_num (#42469)
1 parent 13a5f18 commit 9507969

File tree

4 files changed: +255, -0 lines


python/paddle/__init__.py (+2)

@@ -230,6 +230,7 @@
 from .tensor.math import square  # noqa: F401
 from .tensor.math import stanh  # noqa: F401
 from .tensor.math import sum  # noqa: F401
+from .tensor.math import nan_to_num  # noqa: F401
 from .tensor.math import nansum  # noqa: F401
 from .tensor.math import nanmean  # noqa: F401
 from .tensor.math import count_nonzero  # noqa: F401

@@ -666,6 +667,7 @@
     'renorm',
     'take_along_axis',
     'put_along_axis',
+    'nan_to_num',
     'heaviside',
     'tril_indices',
     'index_add',
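With the import and the `__all__` entry above, the new API is reachable straight from the top-level `paddle` namespace. A minimal smoke check, assuming a Paddle build that includes this commit:

    import paddle

    x = paddle.to_tensor([1.0, float('nan'), float('inf')], dtype='float32')
    # NaN falls back to the default 0.0; +inf becomes the float32 max
    print(paddle.nan_to_num(x).numpy())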
New test file (+203)

@@ -0,0 +1,203 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import Optional
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+
+# from op_test import OpTest
+
+
+def np_nan_to_num(
+    x: np.ndarray,
+    nan: float = 0.0,
+    posinf: Optional[float] = None,
+    neginf: Optional[float] = None,
+) -> np.ndarray:
+    return np.nan_to_num(x, True, nan=nan, posinf=posinf, neginf=neginf)
+
+
+def np_nan_to_num_op(
+    x: np.ndarray,
+    nan: float,
+    replace_posinf_with_max: bool,
+    posinf: float,
+    replace_neginf_with_min: bool,
+    neginf: float,
+) -> np.ndarray:
+    if replace_posinf_with_max:
+        posinf = None
+    if replace_neginf_with_min:
+        neginf = None
+    return np.nan_to_num(x, True, nan=nan, posinf=posinf, neginf=neginf)
+
+
+def np_nan_to_num_grad(x: np.ndarray, dout: np.ndarray) -> np.ndarray:
+    dx = np.copy(dout)
+    dx[np.isnan(x) | (x == np.inf) | (x == -np.inf)] = 0
+    return dx
+
+
+class TestNanToNum(unittest.TestCase):
+    def setUp(self):
+        self.place = (
+            paddle.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static(self):
+        x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(
+            np.float32
+        )
+        out1_np = np_nan_to_num(x_np)
+        out2_np = np_nan_to_num(x_np, 1.0)
+        out3_np = np_nan_to_num(x_np, 1.0, 9.0)
+        out4_np = np_nan_to_num(x_np, 1.0, 9.0, -12.0)
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data('X', x_np.shape)
+            out1 = paddle.nan_to_num(x)
+            out2 = paddle.nan_to_num(x, 1.0)
+            out3 = paddle.nan_to_num(x, 1.0, 9.0)
+            out4 = paddle.nan_to_num(x, 1.0, 9.0, -12.0)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': x_np}, fetch_list=[out1, out2, out3, out4])
+
+        self.assertTrue(np.allclose(out1_np, res[0]))
+        self.assertTrue(np.allclose(out2_np, res[1]))
+        self.assertTrue(np.allclose(out3_np, res[2]))
+        self.assertTrue(np.allclose(out4_np, res[3]))
+
+    def test_dygraph(self):
+        paddle.disable_static(place=self.place)
+
+        with paddle.fluid.dygraph.guard():
+            # NOTE(tiancaishaonvjituizi): float64 input fails the test
+            x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(
+                np.float32
+                # np.float64
+            )
+            x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
+
+            out_tensor = paddle.nan_to_num(x_tensor)
+            out_np = np_nan_to_num(x_np)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, None, None)
+            out_np = np_nan_to_num(x_np, 1, None, None)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, 2.0, None)
+            out_np = np_nan_to_num(x_np, 1, 2, None)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, None, -10.0)
+            out_np = np_nan_to_num(x_np, 1, None, -10)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, 100.0, -10.0)
+            out_np = np_nan_to_num(x_np, 1, 100, -10)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+        paddle.enable_static()
+
+    def test_check_grad(self):
+        paddle.disable_static(place=self.place)
+        x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(
+            np.float32
+        )
+        x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
+
+        y = paddle.nan_to_num(x_tensor)
+        dx = paddle.grad(y, x_tensor)[0].numpy()
+
+        np_grad = np_nan_to_num_grad(x_np, np.ones_like(x_np))
+        self.assertTrue(np.allclose(np_grad, dx))
+
+        paddle.enable_static()
+
+
+# class BaseTestCases:
+#
+#     class BaseOpTest(OpTest):
+#
+#         def setUp(self):
+#             self.op_type = "nan_to_num"
+#             input = np.arange(100, dtype=np.float64)
+#             input[5] = np.nan
+#             input[29] = np.inf
+#             input[97] = -np.inf
+#             self.inputs = {'X': input}
+#             self.attrs = self._attrs()
+#             self.outputs = {
+#                 'Out': np_nan_to_num_op(self.inputs['X'], **self.attrs)
+#             }
+#             paddle.enable_static()
+#
+#         def test_check_output(self):
+#             self.check_output()
+#
+#         def test_check_grad(self):
+#             input = self.inputs['X']
+#             dout = np.ones_like(input) / input.size
+#             self.check_grad(
+#                 ['X'],
+#                 'Out',
+#                 user_defined_grads=[np_nan_to_num_grad(self.inputs['X'], dout)])
+#
+#         def _attrs(self):
+#             raise NotImplementedError()
+#
+#
+# class TestNanToNumOp1(BaseTestCases.BaseOpTest):
+#
+#     def _attrs(self):
+#         return {
+#             'nan': 0.0,
+#             'replace_posinf_with_max': True,
+#             'posinf': -1,
+#             'replace_neginf_with_min': True,
+#             'neginf': -10
+#         }
+#
+#
+# class TestNanToNumOp2(BaseTestCases.BaseOpTest):
+#
+#     def _attrs(self):
+#         return {
+#             'nan': 2.0,
+#             'replace_posinf_with_max': False,
+#             'posinf': -1,
+#             'replace_neginf_with_min': True,
+#             'neginf': -10
+#         }
+#
+#
+# class TestNanToNumOp3(BaseTestCases.BaseOpTest):
+#
+#     def _attrs(self):
+#         return {
+#             'nan': 0.0,
+#             'replace_posinf_with_max': False,
+#             'posinf': -1,
+#             'replace_neginf_with_min': False,
+#             'neginf': -10
+#         }
+
+if __name__ == "__main__":
+    unittest.main()
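The reference gradient in `np_nan_to_num_grad` encodes the key backward rule: every element that gets replaced (NaN, +inf, or -inf) is overwritten by a constant, so no gradient flows through it. A standalone NumPy check of that rule:

    import numpy as np

    x = np.array([1.0, np.nan, np.inf, -np.inf, 2.0], dtype=np.float32)
    dout = np.ones_like(x)
    dx = np.copy(dout)
    dx[np.isnan(x) | (x == np.inf) | (x == -np.inf)] = 0  # zero grad where replaced
    print(dx)  # [1. 0. 0. 0. 1.]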

python/paddle/tensor/__init__.py (+2)

@@ -169,6 +169,7 @@
 from .math import square  # noqa: F401
 from .math import stanh  # noqa: F401
 from .math import sum  # noqa: F401
+from .math import nan_to_num  # noqa: F401
 from .math import nansum  # noqa: F401
 from .math import nanmean  # noqa: F401
 from .math import count_nonzero  # noqa: F401

@@ -350,6 +351,7 @@
     'square',
     'stanh',
     'sum',
+    'nan_to_num',
     'nansum',
     'nanmean',
     'count_nonzero',

python/paddle/tensor/math.py (+48)

@@ -1364,6 +1364,54 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     return out
 
 
+def nan_to_num(x, nan=0.0, posinf=None, neginf=None, name=None):
+    """
+    Replaces NaN, positive infinity, and negative infinity values in the input tensor.
+
+    Args:
+        x (Tensor): An N-D Tensor, the data type is float32, float64.
+        nan (float, optional): The value used to replace NaNs. Default is 0.
+        posinf (float, optional): If a number, the value used to replace positive infinity values. If None, positive infinity values are replaced with the greatest finite value representable by the input's dtype. Default is None.
+        neginf (float, optional): If a number, the value used to replace negative infinity values. If None, negative infinity values are replaced with the lowest finite value representable by the input's dtype. Default is None.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: Result of the nan_to_num operation on the input Tensor ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([float('nan'), 0.3, float('+inf'), float('-inf')], dtype='float32')
+            out1 = paddle.nan_to_num(x)  # [0, 0.3, 3.4028235e+38, -3.4028235e+38]
+            out2 = paddle.nan_to_num(x, nan=1)  # [1, 0.3, 3.4028235e+38, -3.4028235e+38]
+            out3 = paddle.nan_to_num(x, posinf=5)  # [0, 0.3, 5, -3.4028235e+38]
+            out4 = paddle.nan_to_num(x, nan=10, neginf=-99)  # [10, 0.3, 3.4028235e+38, -99]
+    """
+    # NOTE(tiancaishaonvjituizi): it seems that paddle handles the dtype of python
+    # float numbers incorrectly, so we have to explicitly construct tensors here
+    posinf_value = paddle.full_like(x, float("+inf"))
+    neginf_value = paddle.full_like(x, float("-inf"))
+    nan = paddle.full_like(x, nan)
+    assert x.dtype in [paddle.float32, paddle.float64]
+    is_float32 = x.dtype == paddle.float32
+    if posinf is None:
+        posinf = (
+            np.finfo(np.float32).max if is_float32 else np.finfo(np.float64).max
+        )
+    posinf = paddle.full_like(x, posinf)
+    if neginf is None:
+        neginf = (
+            np.finfo(np.float32).min if is_float32 else np.finfo(np.float64).min
+        )
+    neginf = paddle.full_like(x, neginf)
+    x = paddle.where(paddle.isnan(x), nan, x)
+    x = paddle.where(x == posinf_value, posinf, x)
+    x = paddle.where(x == neginf_value, neginf, x)
+    return x
+
+
 def nansum(x, axis=None, dtype=None, keepdim=False, name=None):
     """
     Computes the sum of tensor elements over the given axis, treating Not a Numbers (NaNs) as zero.
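Because `nan_to_num` is built from `paddle.where` selects, every replaced position takes its value from a constant tensor, so autograd gives it a zero gradient; this matches the reference gradient checked in the test above. A quick sketch of that behavior, assuming a Paddle build that includes this commit:

    import numpy as np
    import paddle

    x = paddle.to_tensor(
        np.array([1.0, np.nan, np.inf, -np.inf], dtype=np.float32),
        stop_gradient=False,
    )
    y = paddle.nan_to_num(x)   # [1., 0., 3.4028235e+38, -3.4028235e+38]
    dx = paddle.grad(y, x)[0]  # grad_outputs defaults to ones, as in test_check_grad
    print(dx.numpy())          # [1., 0., 0., 0.]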
