Skip to content

Commit 8e0e6a9

Browse files
authored
【 Paddle Tensor 规范化第二期 】paddle.linalg.cholesky适配0-size Tensor (#70790)
* fix
* update
* fix
* fix
* fix
1 parent 968c8b8 commit 8e0e6a9

File tree

4 files changed

+31
-21
lines changed

4 files changed

+31
-21
lines changed

paddle/phi/kernels/cpu/cholesky_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,11 @@ void CholeskyKernel(const Context& dev_ctx,
3333
using OutputMatrixMap = Eigen::Map<EigenMatrix>;
3434

3535
auto& dims = x.dims();
36+
if (x.numel() == 0) {
37+
out->Resize(dims);
38+
dev_ctx.template Alloc<T>(out);
39+
return;
40+
}
3641
int batch_count = 1;
3742
for (int i = 0; i < dims.size() - 2; i++) {
3843
batch_count *= static_cast<int>(dims[i]);

paddle/phi/kernels/gpu/cholesky_kernel.cu

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,11 @@ void CholeskyKernel(const Context& dev_ctx,
115115
bool upper,
116116
DenseTensor* out) {
117117
auto& dims = x.dims();
118+
if (x.numel() == 0) {
119+
out->Resize(dims);
120+
dev_ctx.template Alloc<T>(out);
121+
return;
122+
}
118123
int batch_count = 1;
119124
for (int i = 0; i < dims.size() - 2; i++) {
120125
batch_count *= dims[i];

paddle/phi/kernels/impl/cholesky_grad_kernel_impl.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,11 @@ void CholeskyGradKernel(const Context& dev_ctx,
245245
auto* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
246246

247247
auto& dims = out.dims();
248+
if (out.numel() == 0) {
249+
x_grad->Resize(dims);
250+
dev_ctx.template Alloc<T>(x_grad);
251+
return;
252+
}
248253
int batch_count = 1;
249254
for (int i = 0; i < dims.size() - 2; i++) {
250255
batch_count *= dims[i];

test/legacy_test/test_cholesky_op.py

Lines changed: 16 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
import unittest
1716

1817
import numpy as np
@@ -64,14 +63,7 @@ def test_check_output(self):
6463
self.check_output(check_pir=True)
6564

6665
def test_check_grad(self):
67-
places = []
68-
if (
69-
os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
70-
in ['1', 'true', 'on']
71-
or not core.is_compiled_with_cuda()
72-
or core.is_compiled_with_rocm()
73-
):
74-
places.append(base.CPUPlace())
66+
places = [base.CPUPlace()]
7567
if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
7668
places.append(base.CUDAPlace(0))
7769
for p in places:
@@ -161,6 +153,11 @@ def init_config(self):
161153
self._input_shape = (32, 32)
162154

163155

156+
class TestCholeskyOpZeroSize(TestCholeskyOp):
157+
def init_config(self):
158+
self._input_shape = (0, 0)
159+
160+
164161
class TestDygraph(unittest.TestCase):
165162
def test_dygraph(self):
166163
if core.is_compiled_with_rocm():
@@ -176,27 +173,20 @@ def test_dygraph(self):
176173

177174
class TestCholeskySingularAPI(unittest.TestCase):
178175
def setUp(self):
179-
self.places = []
180-
if (
181-
os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
182-
in ['1', 'true', 'on']
183-
or not core.is_compiled_with_cuda()
184-
or core.is_compiled_with_rocm()
185-
):
186-
self.places.append(base.CPUPlace())
176+
self.places = [base.CPUPlace()]
187177
if core.is_compiled_with_cuda() and (not core.is_compiled_with_rocm()):
188178
self.places.append(base.CUDAPlace(0))
189179

190-
def check_static_result(self, place, with_out=False):
180+
def check_static_result(self, place, input_shape, with_out=False):
191181
with paddle.static.program_guard(
192182
paddle.static.Program(), paddle.static.Program()
193183
):
194184
input = paddle.static.data(
195-
name="input", shape=[4, 4], dtype="float64"
185+
name="input", shape=input_shape, dtype="float64"
196186
)
197187
result = paddle.cholesky(input)
198188

199-
input_np = np.zeros([4, 4]).astype("float64")
189+
input_np = np.zeros(input_shape).astype("float64")
200190

201191
exe = base.Executor(place)
202192
try:
@@ -211,7 +201,9 @@ def check_static_result(self, place, with_out=False):
211201

212202
def test_static(self):
213203
for place in self.places:
214-
self.check_static_result(place=place)
204+
self.check_static_result(place=place, input_shape=[4, 4])
205+
self.check_static_result(place=place, input_shape=[0, 0])
206+
self.check_static_result(place=place, input_shape=[5, 0, 0])
215207

216208
def test_dygraph(self):
217209
for place in self.places:
@@ -222,9 +214,12 @@ def test_dygraph(self):
222214
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
223215
]
224216
).astype("float64")
217+
input_np_zero = np.zeros((0, 3, 3), dtype="float64")
225218
input = paddle.to_tensor(input_np)
219+
input_zero = paddle.to_tensor(input_np_zero)
226220
try:
227221
result = paddle.cholesky(input)
222+
result_zero = paddle.cholesky(input_zero)
228223
except RuntimeError as ex:
229224
print("The mat is singular")
230225
except ValueError as ex:

0 commit comments

Comments
 (0)