Skip to content

Commit 88b3ce1

Browse files
authored
Merge pull request #201 from JustinPrent/add-nonblocking-reduction
Non-blocking allReduce implementation for dot products
2 parents 95604fd + fef4cc3 commit 88b3ce1

File tree

5 files changed

+75
-0
lines changed

5 files changed

+75
-0
lines changed

src/PartitionedArrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ export exchange
5555
export exchange!
5656
export allocate_exchange
5757
export find_rcv_ids_gather_scatter
58+
export setup_non_blocking_reduction
59+
export non_blocking_reduction
5860
include("primitives.jl")
5961

6062
export DebugArray
@@ -144,6 +146,8 @@ export SplitVector
144146
export split_vector
145147
export split_vector_blocks
146148
export pvector_from_split_blocks
149+
export setup_non_blocking_dot
150+
export non_blocking_dot
147151
include("p_vector.jl")
148152

149153
export SplitMatrix

src/mpi_array.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,38 @@ function reduction_impl(op,a::MPIArray,destination;init=nothing)
500500
MPIArray(b_item,comm,size(a))
501501
end
502502

503+
function setup_non_blocking_reduction_impl(a::MPIArray, ::Type{T}) where T
504+
request = MPI.UnsafeRequest() # Single reduction request
505+
buffer = Ref{T}()
506+
return (request = request, recvbuf = buffer)
507+
end
508+
509+
function non_blocking_reduction_impl(op, a::MPIArray, setup, destination=:all; init=nothing)
510+
@assert destination === :all
511+
T = eltype(a)
512+
comm = a.comm
513+
opr = MPI.Op(op, T)
514+
515+
sendbuf = Ref(a.item)
516+
recvbuf = setup.recvbuf
517+
request = setup.request
518+
rbuf = MPI.RBuffer(sendbuf, recvbuf)
519+
520+
state = (sendbuf, recvbuf, request)
521+
522+
GC.@preserve state MPI.API.MPI_Iallreduce(rbuf.senddata, rbuf.recvdata, rbuf.count, rbuf.datatype, opr, comm, request)
523+
524+
525+
@fake_async begin
526+
GC.@preserve state MPI.Wait(request)
527+
b_item = recvbuf[]
528+
if init !== nothing
529+
b_item = op(b_item,init)
530+
end
531+
MPIArray(b_item,comm,size(a))
532+
end
533+
end
534+
503535
function Base.reduce(op,a::MPIArray;kwargs...)
504536
r = reduction(op,a;destination=:all,kwargs...)
505537
r.item

src/p_vector.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,22 @@ function LinearAlgebra.dot(a::PVector,b::PVector)
11911191
sum(c)
11921192
end
11931193

1194+
function setup_non_blocking_dot(a::PVector, b::PVector)
1195+
partials = map(own_values(a), own_values(b)) do mya, myb
1196+
zero(eltype(mya)) + zero(eltype(myb))
1197+
end
1198+
setup_non_blocking_reduction(partials)
1199+
end
1200+
1201+
function non_blocking_dot(a::PVector, b::PVector, setup)
1202+
partials = map(dot, own_values(a), own_values(b))
1203+
t = non_blocking_reduction(+, partials, setup, destination=:all, init=zero(eltype(a)) + zero(eltype(b)))
1204+
@fake_async begin
1205+
getany(fetch(t))
1206+
end
1207+
end
1208+
1209+
11941210
function LinearAlgebra.rmul!(a::PVector,v::Number)
11951211
map(partition(a)) do l
11961212
rmul!(l,v)

src/primitives.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,24 @@ end
708708
# b
709709
#end
710710

711+
function non_blocking_reduction(op,a,setup= setup_non_blocking_reduction(a);destination=MAIN,kwargs...)
712+
non_blocking_reduction_impl(op,a,setup,destination;kwargs...)
713+
end
714+
715+
function setup_non_blocking_reduction(a)
716+
setup_non_blocking_reduction_impl(a,eltype(a))
717+
end
718+
719+
function setup_non_blocking_reduction_impl(a::AbstractArray, ::Type{T}) where T
720+
return nothing
721+
end
722+
723+
function non_blocking_reduction_impl(op, a::AbstractArray, setup, destination=:all; init=nothing)
724+
@fake_async begin
725+
reduction_impl(op, a, destination; init=init)
726+
end
727+
end
728+
711729
"""
712730
struct ExchangeGraph{A}
713731

test/p_vector_tests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ function p_vector_tests(distribute)
9090
@test sqrt(aa) norm(a)
9191
@test euclidean(a,a) + 1 1
9292

93+
# Quick Test non_blocking_dot
94+
setup = setup_non_blocking_dot(a,b)
95+
t = non_blocking_dot(a,b,setup)
96+
@test fetch(t) ab
97+
9398
n = 10
9499
parts = rank
95100
row_partition = map(parts) do part

0 commit comments

Comments
 (0)