
Add GaussianNLLLoss API. #50843


Merged · 32 commits · Apr 13, 2023
Changes from 1 commit
bec2c4a
Add GaussianNLLLoss API.
Atlantisming Feb 23, 2023
13d3880
Change `rtol` and `atol`. Check `var` in dynamic graph
Atlantisming Feb 26, 2023
5079f74
Merge branch 'PaddlePaddle:develop' into GsNLLLoss_branch
Atlantisming Feb 27, 2023
75d858c
remove assertTrue
Atlantisming Feb 27, 2023
0110b07
update unittest
Atlantisming Feb 28, 2023
380faeb
update unittest for ci-coverage. add broadcast with same dim.
Atlantisming Mar 2, 2023
8c66074
Supply static err print.
Atlantisming Mar 3, 2023
335c7a8
Repair note and example.
Atlantisming Mar 3, 2023
7ac1556
Split unittest.
Atlantisming Mar 6, 2023
dd50740
empty commit.
Atlantisming Mar 9, 2023
5da754f
for standard commit.
Atlantisming Mar 10, 2023
0784a34
for standard commit.
Atlantisming Mar 10, 2023
86ba005
Add int dynamic graph test.
Atlantisming Mar 11, 2023
90b5616
Repair parameter names.
Atlantisming Mar 14, 2023
f11f97a
Repair unittest parameter names.
Atlantisming Mar 16, 2023
5fa70b8
Repair unittest parameter names
Atlantisming Mar 16, 2023
3601625
Repair unittest parameter names
Atlantisming Mar 17, 2023
8fd7e30
Repair unittest parameter names
Atlantisming Mar 17, 2023
16f21aa
Merge remote-tracking branch 'origin/GsNLLLoss_branch' into GsNLLLoss…
Atlantisming Mar 20, 2023
2cb2432
add square in code-block
Atlantisming Mar 24, 2023
e2d74a5
fit few notes.
Atlantisming Mar 24, 2023
2854f3d
fit few notes.
Atlantisming Mar 30, 2023
a547070
fit few notes.
Atlantisming Mar 31, 2023
2dc4a7b
fit few notes.
Atlantisming Apr 4, 2023
d8b7316
add few interpretations.
Atlantisming Apr 7, 2023
1d99e85
add few interpretations.
Atlantisming Apr 7, 2023
9c0e135
add few interpretations.
Atlantisming Apr 10, 2023
bb2b36e
fix import.
Atlantisming Apr 11, 2023
c36fd88
fix space.
Atlantisming Apr 11, 2023
70960a1
empty commit for ci.
Atlantisming Apr 11, 2023
1b8a851
Merge remote-tracking branch 'origin/GsNLLLoss_branch' into GsNLLLoss…
Atlantisming Apr 12, 2023
4c299ac
Merge branch 'PaddlePaddle:develop' into GsNLLLoss_branch
Atlantisming Apr 12, 2023
166 changes: 166 additions & 0 deletions python/paddle/fluid/tests/unittests/test_gaussian_nll_loss.py
@@ -0,0 +1,166 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.fluid.core as core
import paddle.nn.functional as F

np.random.seed(10)


def ref_gaussian_nll_loss(
input, target, var, full=False, eps=1e-6, reduction='none'
):
if var.shape != input.shape:
if input.shape[:-1] == var.shape:
var = np.expand_dims(var, -1)
elif input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1:
pass
else:
raise ValueError("var is of incorrect size")
if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
raise ValueError(reduction + " is not valid")

if np.any(var < 0):
raise ValueError("var has negative entry/entries")

var = var.copy()
var = np.clip(var, a_min=eps, a_max=None)

loss = 0.5 * (np.log(var) + (input - target) ** 2 / var)
if full:
loss += 0.5 * np.log(2 * np.pi)

if reduction == 'none':
return loss
elif reduction == 'sum':
return [np.sum(loss)]
elif reduction == 'mean':
return [np.mean(loss)]


class TestGaussianNLLLossAPI(unittest.TestCase):
GGBond8488 (Contributor) — Mar 6, 2023:
Everything else looks fine now. Please split the different scenarios of this unit test into separate test cases (put them into separate classes), so that a failing case can be located directly later on.

Atlantisming (Author):
Done.
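A minimal sketch of the suggested split (the base class and names here are illustrative, not necessarily the ones ultimately committed): each scenario gets its own TestCase, so a CI failure names the offending configuration directly.

import unittest

import numpy as np


class GaussianNLLLossCaseBase(unittest.TestCase):
    # Shared fixture; subclasses pin one scenario via dtype/shape.
    dtype = np.float32
    shape = [10, 2]

    def setUp(self):
        self.input_np = np.random.random(self.shape).astype(self.dtype)
        self.target_np = np.random.random(self.shape).astype(self.dtype)
        self.var_np = np.ones(self.shape).astype(self.dtype)


class TestGaussianNLLLossFloat64(GaussianNLLLossCaseBase):
    dtype = np.float64

    def test_dynamic_case(self):
        pass  # exercise only the float64 scenario here


class TestGaussianNLLLossBroadcast(GaussianNLLLossCaseBase):
    shape = [10, 2, 3]

    def setUp(self):
        super().setUp()
        # var uses the broadcast shape with the last dimension dropped
        self.var_np = np.ones(self.shape[:-1]).astype(self.dtype)

    def test_dynamic_case(self):
        pass  # exercise only the broadcast scenario here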

# test paddle.nn.functional.gaussian_nll_loss and paddle.nn.GaussianNLLLoss

def setUp(self, type=None):
self.shape = [10, 2]
if type == 'float64':
self.input_np = np.random.random(self.shape).astype(np.float64)
self.target_np = np.random.random(self.shape).astype(np.float64)
self.var_np = np.ones(self.shape).astype(np.float64)
elif type == 'broadcast':
self.shape = [10, 2, 3]
self.broadcast_shape = [10, 2]
self.input_np = np.random.random(self.shape).astype(np.float32)
self.target_np = np.random.random(self.shape).astype(np.float32)
self.var_np = np.ones(self.broadcast_shape).astype(np.float32)
else:
self.input_np = np.random.random(self.shape).astype(np.float32)
self.target_np = np.random.random(self.shape).astype(np.float32)
self.var_np = np.ones(self.shape).astype(np.float32)

self.place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def test_dynamic_case(self, type=None, full=False, reduction='none'):
self.setUp(type)
out_ref = ref_gaussian_nll_loss(
self.input_np,
self.target_np,
self.var_np,
full=full,
reduction=reduction,
)
paddle.disable_static(self.place)

input_x = paddle.to_tensor(self.input_np)
target = paddle.to_tensor(self.target_np)
var = paddle.to_tensor(self.var_np)
out1 = F.gaussian_nll_loss(
input_x, target, var, full=full, reduction=reduction
)
gaussian_nll_loss = paddle.nn.GaussianNLLLoss(full, reduction=reduction)
out2 = gaussian_nll_loss(input_x, target, var)

for r in [out1, out2]:
self.assertEqual(
np.allclose(out_ref, r.numpy(), rtol=1e-8, atol=1e-7), True
)
paddle.enable_static()

def test_static_case(self, type=None, full=False, reduction='none'):
self.setUp(type)
out_ref = ref_gaussian_nll_loss(
self.input_np,
self.target_np,
self.var_np,
full=full,
reduction=reduction,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
if type == 'float64':
input_x = paddle.static.data('Input_x', self.shape, type)
target = paddle.static.data('Target', self.shape, type)
var = paddle.static.data('Var', self.shape, type)
elif type == 'broadcast':
input_x = paddle.static.data('Input_x', self.shape)
target = paddle.static.data('Target', self.shape)
var = paddle.static.data('Var', self.broadcast_shape)
else:
input_x = paddle.static.data('Input_x', self.shape, 'float32')
target = paddle.static.data('Target', self.shape, 'float32')
var = paddle.static.data('Var', self.shape, 'float32')
out1 = F.gaussian_nll_loss(
input_x, target, var, full=full, reduction=reduction
)
gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
full, reduction=reduction
)
out2 = gaussian_nll_loss(input_x, target, var)

exe = paddle.static.Executor(self.place)
res = exe.run(
feed={
'Input_x': self.input_np,
'Target': self.target_np,
'Var': self.var_np,
},
fetch_list=[out1, out2],
)
for r in res:
self.assertEqual(
np.allclose(out_ref, r, rtol=1e-8, atol=1e-7), True
)

def test_api(self):
self.test_dynamic_case('float64')
self.test_dynamic_case('broadcast')
self.test_dynamic_case()
self.test_dynamic_case(full=True, reduction='mean')
self.test_static_case(full=True, reduction='mean')
self.test_static_case()
self.test_static_case('broadcast')
self.test_static_case('float64')


if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions python/paddle/nn/__init__.py
@@ -114,6 +114,8 @@
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.loss import TripletMarginLoss
from .layer.loss import SoftMarginLoss
from .layer.loss import GaussianNLLLoss

from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
@@ -332,4 +334,5 @@ def weight_norm(*args):
'TripletMarginWithDistanceLoss',
'TripletMarginLoss',
'SoftMarginLoss',
'GaussianNLLLoss',
]
3 changes: 3 additions & 0 deletions python/paddle/nn/functional/__init__.py
@@ -98,6 +98,8 @@
from .loss import triplet_margin_with_distance_loss
from .loss import triplet_margin_loss
from .loss import soft_margin_loss
from .loss import gaussian_nll_loss

from .norm import batch_norm # noqa: F401
from .norm import instance_norm # noqa: F401
from .norm import layer_norm # noqa: F401
@@ -246,4 +248,5 @@
'triplet_margin_loss',
'multi_margin_loss',
'soft_margin_loss',
'gaussian_nll_loss',
]
135 changes: 135 additions & 0 deletions python/paddle/nn/functional/loss.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

# TODO: define loss functions of neural network
import paddle
import paddle.fluid as fluid
@@ -3884,3 +3886,136 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
return paddle.mean(out, name=name)
else:
return out


def gaussian_nll_loss(
input, target, var, full=False, eps=1e-6, reduction='mean', name=None
Contributor:
The new API development and submission process includes API design and naming specifications. target, var, and eps are not recommended; it is recommended to rename them to label, variance, and epsilon to keep consistent with other APIs.

Contributor:
This also needs to be modified in the RFC.

Atlantisming (Author):
Thanks, noted.
):
Contributor:
Align the docstring descriptions with the API parameters (var, eps).

Atlantisming (Author):
Sorry, I don't quite understand this part.

Contributor:
See #50843 (comment).
r"""Gaussian negative log likelihood loss.

The targets are treated as samples from Gaussian distributions with
expectations and variances predicted by the neural network. For a
``target`` tensor modelled as having a Gaussian distribution with a tensor
of expectations ``input`` and a tensor of positive variances ``var``, the loss is:

.. math::
\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
\ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
{\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

where :attr:`eps` is used for stability. By default, the constant term of
the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.

Args:
input(Tensor): input tensor, expectation of the Gaussian distribution, available dtype is float32, float64.
target(Tensor): target tensor, sample from the Gaussian distribution, available dtype is float32, float64.
var(Tensor): tensor of positive variance(s), one for each of the expectations
in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
Contributor:
Left-align this line with the previous one; otherwise it will be parsed incorrectly.
full (bool, optional): include the constant term in the loss
calculation. Default: ``False``.
eps (float, optional): value used to clamp ``var`` (see note below), for
stability. Default: 1e-6.
reduction (str, optional): specifies the reduction to apply to the
output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
will be applied, ``'mean'``: the output is the average of all batch
member losses, ``'sum'``: the output is the sum of all batch member
losses. Default: ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Shape:
- Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
dimensions
- Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
but with one dimension equal to 1 (to allow for broadcasting)
- Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
with one dimension equal to 1, or same shape as the input but with one fewer
dimension (to allow for broadcasting)
- Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
shape as the input

Examples::
.. code-block:: python
Contributor:
Leave a blank line after the code-block directive; otherwise parsing will fail.
import paddle
import paddle.nn.functional as F

input = paddle.randn([5, 2], dtype=paddle.float32)
target = paddle.randn([5, 2], dtype=paddle.float32)
var = paddle.ones([5, 2], dtype=paddle.float32)

loss = F.multi_label_soft_margin_loss(input, target, var, reduction='none')
print(loss)

loss = F.multi_label_soft_margin_loss(input, target, var, reduction='mean')
Contributor:
The example code is wrong.

Atlantisming (Author):
Sorry, I'm still not familiar with the git workflow — the earlier example got lost. I'll add it back now.

Atlantisming (Author):
Done.
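The corrected snippet presumably targets gaussian_nll_loss itself instead of multi_label_soft_margin_loss; a sketch consistent with the signature defined above:

import paddle
import paddle.nn.functional as F

input = paddle.randn([5, 2], dtype=paddle.float32)
target = paddle.randn([5, 2], dtype=paddle.float32)
var = paddle.ones([5, 2], dtype=paddle.float32)

# Per-element losses, shape [5, 2]
loss = F.gaussian_nll_loss(input, target, var, reduction='none')
print(loss)

# Scalar average over all elements
loss = F.gaussian_nll_loss(input, target, var, reduction='mean')
print(loss)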

print(loss)


Note:
The clamping of ``var`` is ignored with respect to autograd, and so the
gradients are unaffected by it.
"""

# Check var shape
# If var.shape == input.shape, the case is heteroscedastic and no further checks are needed.
# Otherwise:
if var.shape != input.shape:
# If var is one dimension short of input, but the shapes match otherwise, then this is a homoscedastic case.
# e.g. input.shape = (10, 2, 3), var.shape = (10, 2)
# -> unsqueeze var so that var.shape = (10, 2, 1)
# this is done so that broadcasting can happen in the loss calculation
if input.shape[:-1] == var.shape:
var = paddle.unsqueeze(var, -1)
# This checks if the shapes match up to the final dimension, and the final dimension of var is of shape 1.
# This is also a homoscedastic case.
# e.g. input.shape = (10, 2, 3), var.shape = (10, 2, 1)
elif (
input.shape[:-1] == var.shape[:-1] and var.shape[-1] == 1
): # Homoscedastic case
pass
# If none of the above pass, then the shape of var is incorrect.
else:
raise ValueError("var is of incorrect shape")

# Check validity of reduction mode
if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
raise ValueError(reduction + " is not valid")

# Entries of var must be non-negative
# print(paddle.any(var < 0))
Contributor:
Delete this commented-out debug line.

# if paddle.any(var < 0):
Atlantisming (Author):
In static graph mode, the check on var here returns a LoDTensor.

Contributor:
That is a fairly old concept, but it should not affect this check.

Atlantisming (Author):
Maybe I just don't understand static graphs well enough — in static graph mode, could paddle.any(var < 0) be outputting node information? When I test in static graph mode, this check enters the branch and raises an Error, while the same code passes the test in dynamic graph mode.

Atlantisming (Author) — Feb 24, 2023:
The code that reproduces the error:

if paddle.any(var < 0):
    print('var', var)
    print(paddle.any(var < 0))
    raise ValueError("var has negative entry/entries")

The output:

var var Var : LOD_TENSOR.shape(10, 2).dtype(float32).stop_gradient(True)
E.E
var any_1.tmp_0 : LOD_TENSOR.shape(1,).dtype(bool).stop_gradient(False)
var var Var : LOD_TENSOR.shape(10, 2).dtype(float32).stop_gradient(True)
var any_3.tmp_0 : LOD_TENSOR.shape(1,).dtype(bool).stop_gradient(False)

Atlantisming (Author) — Feb 27, 2023:
> That is a fairly old concept, but it should not affect this check.

I tried the cond() function and found that during graph construction both the true_func() and false_func() passed to cond() are called, raising the errors inside them.
I then looked for other loss functions that use raise ValueError: triplet_margin_with_distance_loss in python\paddle\nn\functional\loss.py (line 3526) checks parameter values inside a node, so as you said it should not affect the check here.
However, checking with Print() shows that the data inside the node is correct:

print(paddle.any(var < 0))
var_res = paddle.static.Print(paddle.any(var < 0))
# if paddle.any(var < 0):
#     raise ValueError("var has negative entry/entries")
================================================
Variable: any_1.tmp_0
  - lod: {}
  - place: Place(cpu)
  - shape: [1]
  - layout: NCHW
  - dtype: bool
  - data: [0]
Variable: any_3.tmp_0
  - lod: {}
  - place: Place(cpu)
  - shape: [1]
  - layout: NCHW
  - dtype: bool
  - data: [0]

Process finished with exit code 0
var any_0.tmp_0 : LOD_TENSOR.shape(1,).dtype(bool).stop_gradient(False)
var any_2.tmp_0 : LOD_TENSOR.shape(1,).dtype(bool).stop_gradient(False)
I0227 16:33:39.522938 20040 interpretercore.cc:273] New Executor is Running.

Ran 1 test in 0.204s

OK

Without the check, the test passes normally. With the check added, execution still enters the branch and raises an error:

Error
Traceback (most recent call last):
  File "D:\PyWorkspace\Paddle\python\paddle\fluid\tests\unittests\test_gaussian_nll_loss.py", line 130, in test_static_case
    out1,var_res = F.gaussian_nll_loss(
  File "D:\Anaconda\envs\paddle_devcpu\lib\site-packages\paddle\nn\functional\loss.py", line 4003, in gaussian_nll_loss
    raise ValueError("var has negative entry/entries")
ValueError: var has negative entry/entries

Atlantisming (Author):
If the statement is changed to

if not paddle.all(var > 0):
    raise ValueError("var has negative entry/entries")

it also passes the test... but might it still be checking the graph node rather than the data?

GGBond8488 (Contributor) — Feb 27, 2023:
Your judgment is correct: in static graph mode there is no way to get var's data during the graph-construction phase, so this check fails under static graphs. There are two possible solutions:

  1. Add a C++-level kernel that implements the computation and the data check; kernels run during the computation phase, where the actual data is available.
  2. Use the Assert OP from https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/static/nn/control_flow.py#L43 to perform the numeric check and report the failure.
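A hedged sketch of option 2, assuming Assert is importable from the linked module with a (cond, data, summarize) signature: the negativity check becomes an op in the program, so it runs at execution time, when the data actually exists.

import paddle
from paddle.static.nn.control_flow import Assert

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    var = paddle.static.data('Var', [10, 2], 'float32')
    # Inserted into the program here, evaluated only when the executor runs.
    Assert(paddle.all(var >= 0), data=[var], summarize=20)

The trade-off is that the failure then surfaces as an executor-time error rather than a Python ValueError at call time.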

Contributor:
(Quoting the previous reply.) With cond here, control flow builds the graph for every branch, which is why both true_fn and false_fn end up raising. Print() is likewise an op that takes part in graph construction; it only executes and prints during the computation phase.
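A small sketch of that behavior (a hypothetical minimal program): both branch callables are invoked while the program is being built, so a bare Python raise inside either one fires during construction, never at run time.

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data('X', [1], 'float32')

    def true_fn():
        # Called at graph-construction time, so this raises immediately,
        # no matter what the executor would later feed for X.
        raise ValueError("raised while building the program")

    def false_fn():
        return x

    out = paddle.static.nn.cond(x[0] > 0, true_fn, false_fn)  # raises here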

# raise ValueError("var has negative entry/entries")

if not in_dygraph_mode():
check_variable_and_dtype(
Contributor:
Type checks are needed in both static and dynamic graph modes. Also, please confirm: are int32 and int64 really unsupported here?

Atlantisming (Author):
That is because paddle.square only supports float32 and float64. After seeing that other loss functions do the same, I used only these two dtypes.

Atlantisming (Author):
Looking at other APIs, check_variable_and_dtype seems to be used to check nodes and dtypes in static graph mode; a conditional branch checking dtypes has already been added for dynamic graph mode.

Contributor:
Hmm — try to write code that is unified across static and dynamic graphs, branching only where truly graph-specific. check_variable_and_dtype automatically skips dynamic graph mode, and the dynamic-graph dtype check should not be needed either: execution will raise on that line automatically.

Atlantisming (Author):
The parameter names have been changed, and check_variable_and_dtype has been moved out of the static-graph branch. The dynamic-graph check has not been removed, though, because the error raised when the log function checks the dtype during loss computation does not pass the unit test.
input, 'Input', ['float32', 'float64'], 'gaussian_nll_loss'
)
check_variable_and_dtype(
target,
'Target',
['float32', 'float64'],
'gaussian_nll_loss',
)
check_variable_and_dtype(
var,
'Var',
['float32', 'float64'],
'gaussian_nll_loss',
)

# Clamp for stability
var = var.clone()
with paddle.no_grad():
var = paddle.clip(var, min=eps)
# Calculate the loss
loss = 0.5 * (paddle.log(var) + paddle.square(input - target) / var)
if full:
loss += 0.5 * math.log(2 * math.pi)

if reduction == 'mean':
return loss.mean()
elif reduction == 'sum':
return loss.sum()
else:
return loss
2 changes: 2 additions & 0 deletions python/paddle/nn/layer/__init__.py
@@ -84,6 +84,8 @@
from .loss import TripletMarginLoss
from .loss import SoftMarginLoss
from .loss import MultiMarginLoss
from .loss import GaussianNLLLoss

from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401