Skip to content

Commit 7924078

Browse files
committed
some fixes in in-place sparse mul!'s, removed comments
1 parent 774f07c commit 7924078

File tree

7 files changed

+52
-175
lines changed

7 files changed

+52
-175
lines changed

src/PartitionedArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using LinearAlgebra
66
using Printf
77
using CircularArrays
88
using StaticArrays
9-
import Base: +,-,*,/
9+
using Base
10+
import Base: +,-,*,/,copy
1011
import MPI
1112
import IterativeSolvers
1213
import Distances

src/sequential_implementations.jl

Lines changed: 28 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function *(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
1+
function Base.:*(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
22
p,q = size(A)
33
r,s = size(B)
44
if q != r && throw(DimensionMismatch("A has dimensions ($(p),$(q)) but B has dimensions ($(p),$(q))"));end
@@ -8,7 +8,7 @@ function *(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,
88
SparseMatrixCSR{Bi}(p, s, Ccsc.colptr, rowvals(Ccsc), nonzeros(Ccsc))
99
end
1010

11-
function *(At::Transpose{Tv, SparseMatrixCSR{Bi,Tv,Ti}},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
11+
function Base.:*(At::Transpose{Tv, SparseMatrixCSR{Bi,Tv,Ti}},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
1212
p,q = size(At)
1313
r,s = size(B)
1414
if q != r && throw(DimensionMismatch("At has dimensions ($(p),$(q)) but B has dimensions ($(p),$(q))"));end
@@ -19,7 +19,7 @@ function *(At::Transpose{Tv, SparseMatrixCSR{Bi,Tv,Ti}},B::SparseMatrixCSR{Bi,Tv
1919
SparseMatrixCSR{Bi}(p, s, Ccsc.colptr, rowvals(Ccsc), nonzeros(Ccsc))
2020
end
2121

22-
function *(A::SparseMatrixCSR{Bi,Tv,Ti},Bt::Transpose{Tv, SparseMatrixCSR{Bi,Tv,Ti}}) where {Bi,Tv,Ti}
22+
function Base.:*(A::SparseMatrixCSR{Bi,Tv,Ti},Bt::Transpose{Tv, SparseMatrixCSR{Bi,Tv,Ti}}) where {Bi,Tv,Ti}
2323
p,q = size(A)
2424
r,s = size(Bt)
2525
if q != r && throw(DimensionMismatch("A has dimensions ($(p),$(q)) but B has dimensions ($(p),$(q))"));end
@@ -30,7 +30,7 @@ function *(A::SparseMatrixCSR{Bi,Tv,Ti},Bt::Transpose{Tv, SparseMatrixCSR{Bi,Tv,
3030
SparseMatrixCSR{Bi}(p, s, Ccsc.colptr, rowvals(Ccsc), nonzeros(Ccsc))
3131
end
3232

33-
function *(At::Transpose{Tv,SparseMatrixCSR{Bi,Tv,Ti}},Bt::Transpose{Tv, SparseMatrixCSR{Bi,Tv,Ti}}) where {Bi,Tv,Ti}
33+
function Base.:*(At::Transpose{Tv,SparseMatrixCSR{Bi,Tv,Ti}},Bt::Transpose{Tv, SparseMatrixCSR{Bi,Tv,Ti}}) where {Bi,Tv,Ti}
3434
p,q = size(At)
3535
r,s = size(Bt)
3636
if q != r && throw(DimensionMismatch("A has dimensions ($(p),$(q)) but B has dimensions ($(p),$(q))"));end
@@ -42,12 +42,12 @@ function *(At::Transpose{Tv,SparseMatrixCSR{Bi,Tv,Ti}},Bt::Transpose{Tv, SparseM
4242
SparseMatrixCSR{Bi}(p, s, Ccsc.colptr, rowvals(Ccsc), nonzeros(Ccsc))
4343
end
4444

45-
function *(x::Number,A::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
45+
# Scalar-times-matrix for CSR storage.
# Builds a new matrix with the same size and index base `Bi` as `A`; the row
# pointer and column index arrays are copied and every stored value is
# multiplied by `x` (stored entries are kept even if the product is zero).
function Base.:*(x::Number, A::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
    nrows, ncols = size(A)
    scaled = map(a -> x * a, A.nzval)
    return SparseMatrixCSR{Bi}(nrows, ncols, copy(A.rowptr), copy(A.colval), scaled)
end
48-
function *(A::SparseMatrixCSR,x::Number) *(x,A) end
48+
# Matrix-times-scalar for CSR storage; forwards to the `x * A` method.
Base.:*(A::SparseMatrixCSR, x::Number) = x * A
4949

50-
function /(A::SparseMatrixCSR{Bi,Tv,Ti},x::Number) where {Bi,Tv,Ti}
50+
# Matrix-divided-by-scalar for CSR storage.
# Builds a new matrix with the same size and index base `Bi` as `A`; the row
# pointer and column index arrays are copied and every stored value is
# divided by `x`.
function Base.:/(A::SparseMatrixCSR{Bi,Tv,Ti}, x::Number) where {Bi,Tv,Ti}
    nrows, ncols = size(A)
    quotients = map(a -> a / x, A.nzval)
    return SparseMatrixCSR{Bi}(nrows, ncols, copy(A.rowptr), copy(A.colval), quotients)
end
5353

@@ -87,7 +87,9 @@ function LinearAlgebra.mul!(C::SparseMatrixCSC{Tv,Ti},
8787
end
8888
for ip in nzrange(C,j)
8989
i = IC[ip]
90-
VC[ip] = x[i]
90+
if xb[i] == j
91+
VC[ip] = x[i]
92+
end
9193
end
9294
end
9395
C
@@ -132,7 +134,9 @@ function LinearAlgebra.mul!(C::SparseMatrixCSC{Tv,Ti},
132134
end
133135
for ip in nzrange(C,j)
134136
i = IC[ip]
135-
VC[ip] += α*x[i]
137+
if xb[i] == j
138+
VC[ip] += α*x[i]
139+
end
136140
end
137141
end
138142
C
@@ -482,12 +486,6 @@ function LinearAlgebra.mul!(C::SparseMatrixCSC{Tv,Ti},
482486
C
483487
end
484488

485-
# function LinearAlgebra.mul!(C::SparseMatrixCSC{Tv,Ti},At::Transpose{Tv,SparseMatrixCSC{Tv,Ti}},Bt::Transpose{Tv,SparseMatrixCSC{Tv,Ti}},α::Number,β::Number) where {Tv,Ti}
486-
# mul!(C,Bt.parent,At.parent,α,β)
487-
# C
488-
# end
489-
490-
491489
function LinearAlgebra.mul!(C::SparseMatrixCSR{Bi,Tv,Ti},
492490
A::SparseMatrixCSR{Bi,Tv,Ti},
493491
B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
@@ -553,7 +551,7 @@ function LinearAlgebra.mul!(C::SparseMatrixCSR{Bi,Tv,Ti},
553551
end
554552

555553
# Alternative to lazy csr to csc for matrix addition that does not drop structural zeros.
556-
function +(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
554+
function Base.:+(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
557555
if size(A) == size(B) || throw(DimensionMismatch("Size of B $(size(B)) must match size of A $(size(A))"));end
558556
p,q = size(A)
559557
nnz_C_upperbound = nnz(A) + nnz(B)
@@ -611,7 +609,7 @@ function +(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,
611609
end
612610

613611
# Alternative to lazy csr to csc for matrix subtraction that does not drop structural zeros. Subtracts B from A, i.e. A - B.
614-
function -(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
612+
function Base.:-(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
615613
if size(A) == size(B) || throw(DimensionMismatch("Size of B $(size(B)) must match size of A $(size(A))"));end
616614
nnz_C_upperbound = nnz(A) + nnz(B)
617615
p,r = size(A)
@@ -668,7 +666,7 @@ function -(A::SparseMatrixCSR{Bi,Tv,Ti},B::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,
668666
SparseMatrixCSR{Bi}(p,r,IC,JC,VC) # A += B
669667
end
670668

671-
function -(A::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
669+
# Unary minus for CSR storage.
# Builds a new matrix with the same size and index base `Bi` as `A`; the row
# pointer and column index arrays are copied and every stored value is negated.
function Base.:-(A::SparseMatrixCSR{Bi,Tv,Ti}) where {Bi,Tv,Ti}
    nrows, ncols = size(A)
    negated = map(-, A.nzval)
    return SparseMatrixCSR{Bi}(nrows, ncols, copy(A.rowptr), copy(A.colval), negated)
end
674672

@@ -731,7 +729,7 @@ function +(A::SparseMatrixCSC{Tv,Ti},B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
731729
end
732730

733731
# Alternative to lazy csr to csc for matrix subtraction that does not drop structural zeros. Subtracts B from A, i.e. A - B.
734-
function -(A::SparseMatrixCSC{Tv,Ti},B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
732+
function Base.:-(A::SparseMatrixCSC{Tv,Ti},B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
735733
if size(A) == size(B) || throw(DimensionMismatch("Size of B $(size(B)) must match size of A $(size(A))"));end
736734
p,q = size(A)
737735
nnz_C_upperbound = nnz(A) + nnz(B)
@@ -788,7 +786,7 @@ function -(A::SparseMatrixCSC{Tv,Ti},B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
788786
SparseMatrixCSC{Tv,Ti}(p,q,JC,IC,VC)
789787
end
790788

791-
function -(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
789+
# Unary minus for CSC storage.
# Builds a new matrix of the same size; the column pointer and row index
# arrays are copied and every stored value is negated, so the stored-entry
# pattern of `A` (including explicitly stored zeros) is carried over as is.
function Base.:-(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
    nrows, ncols = size(A)
    negated = map(-, A.nzval)
    return SparseMatrixCSC{Tv,Ti}(nrows, ncols, copy(A.colptr), copy(A.rowval), negated)
end
794792

@@ -1083,8 +1081,7 @@ function RAP!(C::SparseMatrixCSR{Bi,Tv,Ti}, Plt::Transpose{Tv,SparseMatrixCSR{Bi
10831081
v = nonzeros(Pl)[kp]
10841082
for jp in nzrange(C,k)
10851083
j = JC[jp]
1086-
if xb[j] == i
1087-
# println("HUHHH")
1084+
if xb[j] == i
10881085
VC[jp] += v*x[j]
10891086
end
10901087
end
@@ -1183,15 +1180,12 @@ function RAP!(C::SparseMatrixCSR{Bi,Tv,Ti}, Plt::Transpose{Tv,SparseMatrixCSR{Bi
11831180
end
11841181
end
11851182
end
1186-
# @show((i,xb))
1187-
# @show((x))
11881183
for kp in nzrange(Pl, i)
11891184
k = colvals(Pl)[kp] # rowvals when transposed conceptually
11901185
v = nonzeros(Pl)[kp]
11911186
for jp in nzrange(C,k)
11921187
j = JC[jp]
11931188
if xb[j] == i
1194-
# println("EEEEEEEEEEEEEEEEEEHHH")
11951189
VC[jp] += α*v*x[j]
11961190
end
11971191
end
@@ -1312,7 +1306,9 @@ function RAP(Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::SparseMatrixCSR{Bi,Tv,Ti}, Pr::Sp
13121306
end
13131307
for ind in nzrange(C,i)
13141308
j = JC[ind]
1315-
VC[ind] = xC[j]
1309+
if xbC[j] == i
1310+
VC[ind] = xC[j]
1311+
end
13161312
end
13171313
end
13181314
end
@@ -1381,7 +1377,9 @@ function RAP!(C::SparseMatrixCSR{Bi,Tv,Ti},Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::Spa
13811377
end
13821378
for ind in nzrange(C,i)
13831379
j = JC[ind]
1384-
VC[ind] = xC[j]
1380+
if xbC[j] == i
1381+
VC[ind] = xC[j]
1382+
end
13851383
end
13861384
end
13871385
C
@@ -1441,7 +1439,9 @@ function RAP!(C::SparseMatrixCSR{Bi,Tv,Ti},Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::Spa
14411439
end
14421440
for ind in nzrange(C,i)
14431441
j = JC[ind]
1444-
VC[ind] += α*xC[j]
1442+
if xbC[j] == i
1443+
VC[ind] += α*xC[j]
1444+
end
14451445
end
14461446
end
14471447
C
@@ -1457,142 +1457,6 @@ function RAP(Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::SparseMatrixCSR{Bi,Tv,Ti}, Prt::T
14571457
RAP(Pl,A,halfperm(Prt.parent))
14581458
end
14591459

1460-
# # Not worth it, complexity of N^2, very slow for small problems
1461-
# function RAP(Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::SparseMatrixCSR{Bi,Tv,Ti}, Prt::Transpose{Tv,SparseMatrixCSR{Bi,Tv,Ti}}) where {Bi,Tv,Ti}
1462-
# p,q = size(Pl)
1463-
# m,r = size(A)
1464-
# n,s = size(Prt)
1465-
# if q == m || throw(DimensionMismatch("Invalid dimensions for R*A: ($p,$q)*($m,$r),"));end
1466-
# if r == n || throw(DimensionMismatch("Invalid dimensions: RA*P: ($p,$r)*($n,$s)"));end
1467-
# # find max row length of transposed matrix
1468-
# function find_row_lengths!(At,xb)
1469-
# foreach(colvals(At.parent)) do i
1470-
# xb[i] += 1
1471-
# end
1472-
# xb
1473-
# end
1474-
# function RAP_symbolic!(Pl,A,Prt)
1475-
# Pr = Prt.parent
1476-
# JPl = colvals(Pl)
1477-
# JA = colvals(A)
1478-
# IPr = colvals(Pr) # colvals can be interpreted as rowvals when Pr is virtually transposed.
1479-
# xb = zeros(Ti, q)
1480-
# x = similar(xb, Tv) # sparse accumulator
1481-
# max_rPl = find_max_row_length(Pl)
1482-
# max_rA = find_max_row_length(A)
1483-
# find_row_lengths!(Prt, xb)
1484-
1485-
# max_rRA = max_rPl*max_rA
1486-
# JRA = Vector{Ti}(undef,max_rRA)
1487-
# JRA_permutation = collect(Ti, 1:max_rRA)
1488-
# nnz_C_upperbound = sum(l->max_rRA*l,xb)#p*max_rPl*max_rA*max_rPl
1489-
# # @show(nnz_C_upperbound)
1490-
# xb .= 0
1491-
# IC = Vector{Ti}(undef,p+1)
1492-
# JC = Vector{Ti}(undef, nnz_C_upperbound)
1493-
# nnz_C = 1
1494-
# IC[1] = nnz_C
1495-
# lp = Ref{Ti}()
1496-
# for i in 1:p
1497-
# lp[] = 0
1498-
# # local column pointer, refresh every row, start at 0 to allow empty rows
1499-
# # loop over columns "j" in row i of A
1500-
# for jp in nzrange(Pl, i)
1501-
# j = JPl[jp]
1502-
# # loop over columns "k" in row j of B
1503-
# for kp in nzrange(A, j)
1504-
# k = JA[kp]
1505-
# # since C is constructed rowwise, xb tracks if a column index is present in a new row in C.
1506-
# if xb[k] != i
1507-
# lp[] += 1
1508-
# JRA[lp[]] = k
1509-
# xb[k] = i
1510-
# end
1511-
# end
1512-
# end
1513-
# sort!(JRA_permutation,alg=QuickSort,by=i -> i <= lp[] ? JRA[i] : typemax(i))
1514-
# j_min = JRA[JRA_permutation[1]]
1515-
# j_max = JRA[JRA_permutation[lp[]]]
1516-
# for j in 1:size(Prt,2)
1517-
# ip_range = nzrange(Pr,j)
1518-
# ip = ip_range.start
1519-
# ip_stop = ip_range.stop
1520-
# i_min = IPr[ip]
1521-
# i_max = IPr[ip_stop]
1522-
# if i_min > j_max || j_min > i_max # no intersection
1523-
# continue
1524-
# end
1525-
# while ip <= ip_stop
1526-
# iPr = IPr[ip]
1527-
# if xb[iPr] == i
1528-
# JC[nnz_C] = j
1529-
# nnz_C += 1
1530-
# break
1531-
# end
1532-
# ip +=1
1533-
# end
1534-
# end
1535-
# IC[i+1] = nnz_C
1536-
# end
1537-
# nnz_C -= 1
1538-
# resize!(JC,nnz_C)
1539-
# VC = zeros(Tv,nnz_C)
1540-
# cache = (xb,x,JRA)
1541-
# SparseMatrixCSR{Bi}(p,s,IC,JC,VC), cache # values not yet initialized
1542-
# end
1543-
1544-
# function RAP_numeric!(C,Pl,A,Prt,cache)
1545-
# Pr = Prt.parent
1546-
# JPl = colvals(Pl)
1547-
# VPl = nonzeros(Pl)
1548-
# JA = colvals(A)
1549-
# VA = nonzeros(A)
1550-
# IPr = colvals(Pr) # colvals can be interpreted as rowvals when Pr is virtually transposed.
1551-
# VPr = nonzeros(Pr)
1552-
# JC = colvals(C)
1553-
# VC = nonzeros(C)
1554-
# (xb,x,_) = cache
1555-
# for i in 1:p
1556-
# # loop over columns "j" in row i of A
1557-
# for jp in nzrange(Pl, i)
1558-
# j = JPl[jp]
1559-
# # loop over columns "k" in row j of B
1560-
# for kp in nzrange(A, j)
1561-
# k = JA[kp]
1562-
# # since C is constructed rowwise, xb tracks if a column index is present in a new row in C.
1563-
# if xb[k] != i
1564-
# xb[k] = i
1565-
# x[k] = VPl[jp] * VA[kp]
1566-
# else
1567-
# x[k] += VPl[jp] * VA[kp]
1568-
# end
1569-
# end
1570-
# end
1571-
# for col_ptr in nzrange(C,i)
1572-
# Pr_col = JC[col_ptr]
1573-
# Pr_rows_range = nzrange(Pr,Pr_col) # column l of Pr^T
1574-
# for ip in Pr_rows_range
1575-
# Pr_row = IPr[ip]
1576-
# if xb[Pr_row] == i
1577-
# VC[col_ptr] += x[Pr_row]*VPr[ip]
1578-
# end
1579-
# end
1580-
# end
1581-
# end
1582-
# end
1583-
# function _RAP(Pl,A,Prt)
1584-
# C,cache = RAP_symbolic!(Pl,A,Prt)
1585-
# # @code_warntype RAP_symbolic!(Pl,A,Prt)
1586-
# xb = cache[1]
1587-
# xb.=0
1588-
# RAP_numeric!(C,Pl,A,Prt,cache)
1589-
# # @code_warntype RAP_numeric!(C,Pl,A,Prt,cache)
1590-
1591-
# C,cache
1592-
# end
1593-
# _RAP(Pl,A,Prt)
1594-
# end
1595-
15961460
function RAP!(C::SparseMatrixCSR{Bi,Tv,Ti}, Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::SparseMatrixCSR{Bi,Tv,Ti}, Prt::Transpose{Tv,SparseMatrixCSR{Bi,Tv,Ti}},cache) where {Bi,Tv,Ti}
15971461
p,q = size(Pl)
15981462
m,r = size(A)
@@ -1659,8 +1523,6 @@ function RAP!(C::SparseMatrixCSR{Bi,Tv,Ti}, Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::Sp
16591523
VC .*= β
16601524
# some cache items are present with the regular RAP product in mind, which is how the allocating version is performed
16611525
(xb,_,x,_,_) = cache
1662-
# yes = 0
1663-
# no = 0
16641526
for i in 1:p
16651527
# loop over columns "j" in row i of A
16661528
for jp in nzrange(Pl, i)
@@ -1685,16 +1547,12 @@ function RAP!(C::SparseMatrixCSR{Bi,Tv,Ti}, Pl::SparseMatrixCSR{Bi,Tv,Ti}, A::Sp
16851547
iPr = IPr[ip]
16861548
if xb[iPr] == i
16871549
v += x[iPr]*VPr[ip]
1688-
# yes += 1
1689-
# else
1690-
# no += 1
16911550
end
16921551
end
16931552

16941553
VC[jpPr] += α*v
16951554
end
16961555
end
1697-
# @show((yes,no,yes/(yes+no),length(nonzeros(C))/p))
16981556
C
16991557
end
17001558

src/sparse_helpers.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,22 @@ end
177177

178178
# CSR-vs-CSC comparison: swaps the operands and forwards the call.
# NOTE(review): the callee is spelled `approx_equivalent_equivalent`; this looks
# like a typo for `approx_equivalent(B, A)` — confirm that the doubled name is
# actually defined before relying on this method.
approx_equivalent(A::SparseMatrixCSR, B::SparseMatrixCSC) = approx_equivalent_equivalent(B, A)
179179

180-
function copy(At::Transpose{Tv,SparseMatrixCSR{Bi,Tv,Ti}}) where {Bi,Tv,Ti}
180+
# Materialize a lazily transposed CSR matrix as a new CSR matrix.
# The parent's (p x q) CSR arrays are handed to a (q x p) SparseMatrixCSC,
# that CSC matrix is transposed and copied, and the arrays of the result are
# wrapped again as a (q x p) CSR matrix with the parent's index base `Bi`.
# NOTE(review): the CSR arrays are passed to SparseMatrixCSC unchanged, so this
# presumably assumes Bi == 1 — confirm for zero-based matrices.
function Base.copy(At::Transpose{Tv,SparseMatrixCSR{Bi,Tv,Ti}}) where {Bi,Tv,Ti}
    parent_csr = At.parent
    nrows, ncols = size(parent_csr)
    as_csc = SparseMatrixCSC{Tv,Ti}(
        ncols, nrows, parent_csr.rowptr, parent_csr.colval, parent_csr.nzval
    )
    # Materialize the SparseMatrixCSC transpose.
    flipped = copy(transpose(as_csc))
    return SparseMatrixCSR{Bi}(
        ncols, nrows, flipped.colptr, rowvals(flipped), nonzeros(flipped)
    )
end
187187

188+
# Create an empty (no stored entries) m x n CSR matrix whose index and value
# element types are taken from `A`.
# The index base `Bi` of `A` is preserved: every row pointer entry is set to
# `Bi`, which equals the previous all-ones row pointer when Bi == 1. Before,
# the result was always built as `SparseMatrixCSR{1}`, unlike the one-argument
# `similar`, which keeps `Bi`.
function Base.similar(A::SparseMatrixCSR{Bi}, m::Integer, n::Integer) where Bi
    Tptr = eltype(A.rowptr)
    rowptr = fill(convert(Tptr, Bi), m + 1)
    return SparseMatrixCSR{Bi}(m, n, rowptr, eltype(A.colval)[], eltype(A.nzval)[])
end
191+
192+
# Create a CSR matrix with the same size, index base and stored-entry pattern
# as `A`. The row pointer and column indices are copied; the value array is
# obtained with `similar`, so its contents are not initialized here.
function Base.similar(A::SparseMatrixCSR{Bi}) where Bi
    nrows, ncols = size(A)
    rowptr = copy(A.rowptr)
    colidx = copy(colvals(A))
    vals = similar(nonzeros(A))
    return SparseMatrixCSR{Bi}(nrows, ncols, rowptr, colidx, vals)
end
195+
188196
# Return the row pointer array of a CSR matrix (the array itself, not a copy).
pointer_array(A::SparseMatrixCSR) = A.rowptr

test/p_sparse_matrix_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ function p_sparse_matrix_tests(distribute)
377377
end
378378
A_seq = centralize(A)
379379
spmm!(B,Z,A,cacheB)
380+
380381
@test centralize(B) Z_seq*(A_seq)
381382
B = transpose(Z)*A
382383
@test centralize(B) transpose(Z_seq)*A_seq

0 commit comments

Comments
 (0)