@@ -586,10 +586,10 @@ end
586586# # CSR to COO and vice-versa
587587
588588function CuSparseMatrixCSR {Tv} (coo:: CuSparseMatrixCOO{Tv} ; index:: SparseChar = ' O' ) where {Tv}
589- m,n = size (coo)
590- nnz (coo) == 0 && return CuSparseMatrixCSR {Tv} (CUDA. ones (Cint, m+ 1 ), coo. colInd, nonzeros (coo), size (coo))
589+ m, n = size (coo)
590+ csrRowPtr = (index == ' O' ) ? CUDA. ones (Cint, m + 1 ) : CUDA. zeros (Cint, m + 1 )
591+ nnz (coo) == 0 && return CuSparseMatrixCSR {Tv} (csrRowPtr, coo. colInd, nonzeros (coo), size (coo))
591592 coo = sort_coo (coo, ' R' )
592- csrRowPtr = CuVector {Cint} (undef, m+ 1 )
593593 cusparseXcoo2csr (handle (), coo. rowInd, nnz (coo), m, csrRowPtr, index)
594594 CuSparseMatrixCSR {Tv} (csrRowPtr, coo. colInd, nonzeros (coo), size (coo))
595595end
@@ -605,10 +605,10 @@ end
605605# ## CSC to COO and viceversa
606606
607607function CuSparseMatrixCSC {Tv} (coo:: CuSparseMatrixCOO{Tv} ; index:: SparseChar = ' O' ) where {Tv}
608- m,n = size (coo)
609- nnz (coo) == 0 && return CuSparseMatrixCSC {Tv} (CUDA. ones (Cint, n+ 1 ), coo. rowInd, nonzeros (coo), size (coo))
608+ m, n = size (coo)
609+ cscColPtr = (index == ' O' ) ? CUDA. ones (Cint, n + 1 ) : CUDA. zeros (Cint, n + 1 )
610+ nnz (coo) == 0 && return CuSparseMatrixCSC {Tv} (cscColPtr, coo. rowInd, nonzeros (coo), size (coo))
610611 coo = sort_coo (coo, ' C' )
611- cscColPtr = CuVector {Cint} (undef, n+ 1 )
612612 cusparseXcoo2csr (handle (), coo. colInd, nnz (coo), n, cscColPtr, index)
613613 CuSparseMatrixCSC {Tv} (cscColPtr, coo. rowInd, nonzeros (coo), size (coo))
614614end
0 commit comments