@@ -332,9 +332,13 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
332332 @eval begin
333333 function CuSparseMatrixCSC {$elty} (csr:: CuSparseMatrixCSR{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
334334 m,n = size (csr)
335- colPtr = (index == ' O' ) ? CUDA. ones (Cint, n+ 1 ) : CUDA. zeros (Cint, n+ 1 )
336- rowVal = CUDA. zeros (Cint, nnz (csr))
337- nzVal = CUDA. zeros ($ elty, nnz (csr))
335+ colPtr = CuArray {Cint} (undef, n+ 1 )
336+ rowVal = CuArray {Cint} (undef, nnz (csr))
337+ nzVal = CuArray {$elty} (undef, nnz (csr))
338+ if version () <= v " 12.6-"
339+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
340+ colPtr .= (index == ' O' ? 1 : 0 )
341+ end
338342 function bufferSize ()
339343 out = Ref {Csize_t} (1 )
340344 cusparseCsr2cscEx2_bufferSize (handle (), m, n, nnz (csr), nonzeros (csr),
@@ -352,9 +356,13 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
352356
353357 function CuSparseMatrixCSR {$elty} (csc:: CuSparseMatrixCSC{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
354358 m,n = size (csc)
355- rowPtr = (index == ' O' ) ? CUDA. ones (Cint, m+ 1 ) : CUDA. zeros (Cint, m+ 1 )
356- colVal = CUDA. zeros (Cint,nnz (csc))
357- nzVal = CUDA. zeros ($ elty,nnz (csc))
359+ rowPtr = CuArray {Cint} (undef, m+ 1 )
360+ colVal = CuArray {Cint} (undef, nnz (csc))
361+ nzVal = CuArray {$elty} (undef, nnz (csc))
362+ if version () <= v " 12.6-"
363+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
364+ rowPtr .= (index == ' O' ? 1 : 0 )
365+ end
358366 function bufferSize ()
359367 out = Ref {Csize_t} (1 )
360368 cusparseCsr2cscEx2_bufferSize (handle (), n, m, nnz (csc), nonzeros (csc),
@@ -379,9 +387,13 @@ for (elty, welty) in ((:Float16, :Float32),
379387 @eval begin
380388 function CuSparseMatrixCSC {$elty} (csr:: CuSparseMatrixCSR{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
381389 m,n = size (csr)
382- colPtr = (index == ' O' ) ? CUDA. ones (Cint, n+ 1 ) : CUDA. zeros (Cint, n+ 1 )
383- rowVal = CUDA. zeros (Cint, nnz (csr))
384- nzVal = CUDA. zeros ($ elty, nnz (csr))
390+ colPtr = CuArray {Cint} (undef, n+ 1 )
391+ rowVal = CuArray {Cint} (undef, nnz (csr))
392+ nzVal = CuArray {$elty} (undef, nnz (csr))
393+ if version () <= v " 12.6-"
394+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
395+ colPtr .= (index == ' O' ? 1 : 0 )
396+ end
385397 if $ elty == Float16 # broken for ComplexF16?
386398 function bufferSize ()
387399 out = Ref {Csize_t} (1 )
@@ -405,9 +417,13 @@ for (elty, welty) in ((:Float16, :Float32),
405417
406418 function CuSparseMatrixCSR {$elty} (csc:: CuSparseMatrixCSC{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
407419 m,n = size (csc)
408- rowPtr = (index == ' O' ) ? CUDA. ones (Cint, m+ 1 ) : CUDA. zeros (Cint, m+ 1 )
409- colVal = CUDA. zeros (Cint,nnz (csc))
410- nzVal = CUDA. zeros ($ elty,nnz (csc))
420+ rowPtr = CuArray {Cint} (undef, m+ 1 )
421+ colVal = CuArray {Cint} (undef, nnz (csc))
422+ nzVal = CuArray {$elty} (undef, nnz (csc))
423+ if version () <= v " 12.6-"
424+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
425+ rowPtr .= (index == ' O' ? 1 : 0 )
426+ end
411427 if $ elty == Float16 # broken for ComplexF16?
412428 function bufferSize ()
413429 out = Ref {Csize_t} (1 )
@@ -523,9 +539,9 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32),
523539 nb = cld (n, bsr. blockDim)
524540 cudesca = CuMatrixDescriptor (' G' , ' L' , ' N' , index)
525541 cudescc = CuMatrixDescriptor (' G' , ' L' , ' N' , indc)
526- csrRowPtr = CUDA . zeros ( Cint, m + 1 )
527- csrColInd = CUDA . zeros ( Cint, nnz (bsr))
528- csrNzVal = CUDA . zeros ( $ elty, nnz (bsr))
542+ csrRowPtr = CuArray { Cint} (undef , m + 1 )
543+ csrColInd = CuArray { Cint} (undef , nnz (bsr))
544+ csrNzVal = CuArray { $elty} (undef , nnz (bsr))
529545 $ fname (handle (), bsr. dir, mb, nb,
530546 cudesca, nonzeros (bsr), bsr. rowPtr, bsr. colVal,
531547 bsr. blockDim, cudescc, csrNzVal, csrRowPtr,
0 commit comments