
Commit 9a07f12

Fix
1 parent f1b794d commit 9a07f12


3 files changed: +42 −0 lines


paddle/phi/kernels/impl/renorm_grad_kernel_impl.h

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@ void RenormGradKernel(const Context& dev_ctx,
   auto dimension_each = input_dims[dim];
   dx->Resize(x.dims());
   dev_ctx.template Alloc<T>(dx);
+  if (dx && dx->numel() == 0) {
+    return;
+  }
   phi::funcs::RenormGradFunc(dev_ctx,
                              x_data,
                              dout_data,
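Note: the zero-size guard sits after Resize and Alloc, so dx still gets its (empty) shape and buffer; when dx->numel() == 0 there is nothing to reduce, and the kernel returns before phi::funcs::RenormGradFunc is invoked.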

paddle/phi/kernels/impl/renorm_kernel_impl.h

Lines changed: 4 additions & 0 deletions
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#pragma once

 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
@@ -28,6 +29,9 @@ void RenormKernel(const Context& dev_ctx,
                   DenseTensor* out) {
   out->Resize(x.dims());
   dev_ctx.template Alloc<T>(out);
+  if (out && out->numel() == 0) {
+    return;
+  }
   auto x_ptr = x.template data<T>();
   auto numel = x.numel();
   int dim = axis;
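The forward kernel gets the same treatment (plus a previously missing #pragma once include guard): allocate the output, then bail out early when it holds zero elements. For context, a minimal dygraph sketch of the zero-size case these guards cover, assuming a Paddle build that includes this change (argument order mirrors the new test: p, axis, max_norm):

import numpy as np
import paddle

# Sketch only; assumes a Paddle build with this fix.
# A zero-sized axis means numel() == 0, so both kernels return right after
# allocating the empty output instead of running the renorm functors.
x = paddle.to_tensor(np.random.random([2, 0, 3]).astype('float64'),
                     stop_gradient=False)
y = paddle.renorm(x, 1.0, 2, 2.05)  # p=1.0, axis=2, max_norm=2.05
z = paddle.mean(y)
z.backward()
print(y.shape, x.grad.shape)  # both stay [2, 0, 3]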

test/legacy_test/test_renorm_op.py

Lines changed: 35 additions & 0 deletions
@@ -122,6 +122,41 @@ def test_dygraph_api(self):
             np.testing.assert_allclose(expected, np.array(y), rtol=1e-05)


+class TestRenormAPI_ZeroSize(unittest.TestCase):
+    def input_data(self):
+        self.shape = [2, 0, 3]
+        self.data_x = np.random.random(self.shape).astype('float64')
+        self.p = 1.0
+        self.dim = 2
+        self.max_norm = 2.05
+
+    def test_renorm_api(self):
+        paddle.enable_static()
+        self.input_data()
+
+        # case 1:
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            x = paddle.static.data(name="x", shape=self.shape, dtype='float64')
+            z = paddle.renorm(x, self.p, self.dim, self.max_norm)
+            exe = base.Executor(base.CPUPlace())
+            (res,) = exe.run(
+                feed={"x": self.data_x}, fetch_list=[z], return_numpy=False
+            )
+            np.testing.assert_allclose(np.array(res).shape, self.shape)
+
+    def test_dygraph_api(self):
+        self.input_data()
+        with base.dygraph.guard(base.CPUPlace()):
+            x = paddle.to_tensor(self.data_x, stop_gradient=False)
+            y = paddle.renorm(x, 1.0, 2, 2.05)
+            np.testing.assert_allclose(np.array(y).shape, self.shape)
+            z = paddle.mean(y)
+            z.backward(retain_graph=True)
+            np.testing.assert_allclose(x.grad.shape, x.shape)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
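The new TestRenormAPI_ZeroSize case exercises both execution paths: the static-graph run asserts the fetched result keeps the empty shape [2, 0, 3], and the dygraph run additionally checks that backward yields a gradient with the same shape as the input. Since the file ends in unittest.main(), it can be run directly, e.g. python test/legacy_test/test_renorm_op.py.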
