Skip to content

Commit a45c2de

Browse files
committed
Bugfix in consistent!
1 parent bec638d commit a45c2de

File tree

2 files changed

+53
-27
lines changed

2 files changed

+53
-27
lines changed

src/p_sparse_matrix.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1668,7 +1668,7 @@ function psparse_consistent_impl(
16681668
n_ghost_cols = ghost_length(cols_co)
16691669
TA = typeof(A.blocks.ghost_own)
16701670
own_own = A.blocks.own_own
1671-
own_ghost = compresscoo(TA,I2,J2,V2,n_own_rows,n_ghost_cols)
1671+
own_ghost = compresscoo(TA,I2,J2,V2,n_own_rows,n_ghost_cols) # TODO this can be improved
16721672
ghost_own = compresscoo(TA,I_rcv_own,J_rcv_own,V_rcv_own,n_ghost_rows,n_own_cols)
16731673
ghost_ghost = compresscoo(TA,I_rcv_ghost,J_rcv_ghost,V_rcv_ghost,n_ghost_rows,n_ghost_cols)
16741674
K_own = precompute_nzindex(ghost_own,I_rcv_own,J_rcv_own)
@@ -1743,13 +1743,22 @@ function psparse_consistent_impl!(B,A,::Type{<:AbstractSplitMatrix},cache)
17431743
setcoofast!(B.blocks.ghost_ghost,V_rcv_ghost,K_ghost)
17441744
B
17451745
end
1746+
map(own_own_values(B),own_own_values(A)) do b,a
1747+
msg = "consistent!(B,A,cache) can only be called if B was obtained as B,cache = consistent(A)|>fetch"
1748+
@assert a === b msg
1749+
end
17461750
map(setup_snd,partition(A),cache)
17471751
parts_snd = map(i->i.parts_snd,cache)
17481752
parts_rcv = map(i->i.parts_rcv,cache)
17491753
V_snd = map(i->i.V_snd,cache)
17501754
V_rcv = map(i->i.V_rcv,cache)
17511755
graph = ExchangeGraph(parts_snd,parts_rcv)
17521756
t = exchange!(V_rcv,V_snd,graph)
1757+
map(own_ghost_values(B),own_ghost_values(A)) do b,a
1758+
if nonzeros(b) !== nonzeros(a)
1759+
copy!(nonzeros(b),nonzeros(a))
1760+
end
1761+
end
17531762
@async begin
17541763
wait(t)
17551764
map(setup_rcv,partition(B),cache)

test/p_sparse_matrix_tests.jl

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -327,35 +327,52 @@ function p_sparse_matrix_tests(distribute)
327327
d = dense_diag(A)
328328
dense_diag!(d,A)
329329

330-
nodes_per_dir = (5,5)
330+
nodes_per_dir = (2,2)
331331
parts_per_dir = (2,2)
332332
A = PartitionedArrays.laplace_matrix(nodes_per_dir,parts_per_dir,parts)
333-
334-
B = A*A
335333
A_seq = centralize(A)
336-
@test centralize(B) A_seq*A_seq
337-
338-
B = spmm(A,A)
339-
@test centralize(B) A_seq*A_seq
340-
B,cacheB = spmm(A,A;reuse=true)
341-
spmm!(B,A,A,cacheB)
342-
@test centralize(B) A_seq*A_seq
343-
344-
B = transpose(A)*A
345-
@test centralize(B) transpose(A_seq)*A_seq
346-
347-
B = spmtm(A,A)
348-
B,cacheB = spmtm(A,A;reuse=true)
349-
@test centralize(B) transpose(A_seq)*A_seq
350-
spmtm!(B,A,A,cacheB)
351-
@test centralize(B) transpose(A_seq)*A_seq
352-
353-
C = rap(transpose(A),A,A)
354-
@test centralize(C) transpose(A_seq)*A_seq*A_seq
355-
C,cacheC = rap(transpose(A),A,A;reuse=true)
356-
@test centralize(C) transpose(A_seq)*A_seq*A_seq
357-
rap!(C,transpose(A),A,A,cacheC)
358-
@test centralize(C) transpose(A_seq)*A_seq*A_seq
334+
Z = 2*A
335+
Z_seq = centralize(Z)
336+
337+
B = Z*A
338+
@test centralize(B) Z_seq*A_seq
339+
340+
B = spmm(Z,A)
341+
@test centralize(B) Z_seq*A_seq
342+
B,cacheB = spmm(Z,A;reuse=true)
343+
map(partition(A)) do A
344+
nonzeros(A.blocks.own_own) .*= 4
345+
nonzeros(A.blocks.own_ghost) .*= 4
346+
end
347+
A_seq = centralize(A)
348+
spmm!(B,Z,A,cacheB)
349+
@test centralize(B) Z_seq*(A_seq)
350+
351+
B = transpose(Z)*A
352+
@test centralize(B) transpose(Z_seq)*A_seq
353+
354+
B = spmtm(Z,A)
355+
B,cacheB = spmtm(Z,A;reuse=true)
356+
@test centralize(B) transpose(Z_seq)*A_seq
357+
map(partition(A)) do A
358+
nonzeros(A.blocks.own_own) .*= 4
359+
nonzeros(A.blocks.own_ghost) .*= 4
360+
end
361+
A_seq = centralize(A)
362+
spmtm!(B,Z,A,cacheB)
363+
@test centralize(B) transpose(Z_seq)*A_seq
364+
365+
C = rap(transpose(A),Z,A)
366+
@test centralize(C) transpose(A_seq)*Z_seq*A_seq
367+
C,cacheC = rap(transpose(A),Z,A;reuse=true)
368+
@test centralize(C) transpose(A_seq)*Z_seq*A_seq
369+
map(partition(A)) do A
370+
nonzeros(A.blocks.own_own) .*= 4
371+
nonzeros(A.blocks.own_ghost) .*= 4
372+
end
373+
A_seq = centralize(A)
374+
rap!(C,transpose(A),Z,A,cacheC)
375+
@test centralize(C) transpose(A_seq)*Z_seq*A_seq
359376

360377
r = pzeros(partition(axes(A,2)))
361378
x = pones(partition(axes(A,1)))

0 commit comments

Comments
 (0)