Skip to content

Commit aeacfc2

Browse files
authored
Merge pull request #103 from termi-official/do/linsolvejl
LinearSolve.jl coarse solve integration
2 parents e567b00 + 2a585dd commit aeacfc2

File tree

5 files changed

+51
-9
lines changed

5 files changed

+51
-9
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.5.1"
55
[deps]
66
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
89
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
910
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1011
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -19,8 +20,9 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1920
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
2021
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
2122
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
23+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2224
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2325
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2426

2527
[targets]
26-
test = ["DelimitedFiles", "FileIO", "IterativeSolvers", "JLD2", "Random", "Test"]
28+
test = ["DelimitedFiles", "FileIO", "IterativeSolvers", "JLD2", "LinearSolve", "Random", "Test"]

src/AlgebraicMultigrid.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,13 @@ module AlgebraicMultigrid
22

33
using Reexport
44
using LinearAlgebra
5+
using LinearSolve
56
using SparseArrays, Printf
6-
using Base.Threads
77
@reexport import CommonSolve: solve, solve!, init
88
using Reexport
99

1010
using LinearAlgebra: rmul!
1111

12-
# const mul! = A_mul_B!
13-
14-
const MT = false
15-
const AMG = AlgebraicMultigrid
16-
1712
include("utils.jl")
1813
export approximate_spectral_radius
1914

src/multilevel.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ function coarse_b!(m::MultiLevelWorkspace{TX, bs}, n) where {TX, bs}
5252
end
5353

5454
abstract type CoarseSolver end
55+
56+
"""
57+
Pinv{T} <: CoarseSolver
58+
59+
Moore-Penrose pseudo inverse coarse solver. Calls `pinv`
60+
"""
5561
struct Pinv{T} <: CoarseSolver
5662
pinvA::Matrix{T}
5763
Pinv{T}(A) where T = new{T}(pinv(Matrix(A)))
@@ -61,6 +67,43 @@ Base.show(io::IO, p::Pinv) = print(io, "Pinv")
6167

6268
(p::Pinv)(x, b) = mul!(x, p.pinvA, b)
6369

70+
# This one is used internally.
71+
"""
72+
LinearSolveWrapperInternal <: CoarseSolver
73+
74+
Helper to allow the usage of LinearSolve.jl solvers for the coarse-level solve. Constructed via `LinearSolveWrapper`.
75+
"""
76+
struct LinearSolveWrapperInternal{LC <: LinearSolve.LinearCache} <: CoarseSolver
77+
linsolve::LC
78+
function LinearSolveWrapperInternal(A, alg::LinearSolve.SciMLLinearSolveAlgorithm)
79+
rhs_tmp = zeros(eltype(A), size(A,1))
80+
u_tmp = zeros(eltype(A), size(A,2))
81+
linprob = LinearProblem(A, rhs_tmp; u0 = u_tmp, alias_A = false, alias_b = false)
82+
linsolve = init(linprob, alg)
83+
new{typeof(linsolve)}(linsolve)
84+
end
85+
end
86+
87+
function (p::LinearSolveWrapperInternal{LC})(x, b) where {LC <: LinearSolve.LinearCache}
88+
for i 1:size(b, 2)
89+
# Update right hand side
90+
p.linsolve.b = b[:, i]
91+
# Solve for x and update
92+
x[:, i] = solve!(p.linsolve).u
93+
end
94+
end
95+
96+
# This one simplifies passing of LinearSolve.jl algorithms into AlgebraicMultigrid.jl as coarse solvers.
97+
"""
98+
LinearSolveWrapper <: CoarseSolver
99+
100+
Helper to allow the usage of LinearSolve.jl solvers for the coarse-level solve.
101+
"""
102+
struct LinearSolveWrapper <: CoarseSolver
103+
alg::LinearSolve.SciMLLinearSolveAlgorithm
104+
end
105+
(p::LinearSolveWrapper)(A::AbstractMatrix) = LinearSolveWrapperInternal(A, p.alg)
106+
64107
Base.length(ml::MultiLevel) = length(ml.levels) + 1
65108

66109
function Base.show(io::IO, ml::MultiLevel)

src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ end
9999
find_breakdown(::Type{Float64}) = eps(Float64) * 10^6
100100
find_breakdown(::Type{Float32}) = eps(Float64) * 10^3
101101

102-
using Base.Threads
103102
#=function mul!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
104103
A.n == size(B, 1) || throw(DimensionMismatch())
105104
A.m == size(C, 1) || throw(DimensionMismatch())

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using SparseArrays, DelimitedFiles, Random
22
using Test, LinearAlgebra
3-
using IterativeSolvers, AlgebraicMultigrid
3+
using IterativeSolvers, LinearSolve, AlgebraicMultigrid
44
import AlgebraicMultigrid: Pinv, Classical
55
using JLD2
66
using FileIO
@@ -128,6 +128,9 @@ ml = ruge_stuben(A, presmoother = fsmoother,
128128
x = AlgebraicMultigrid._solve(ml, A * ones(1000))
129129
@test sum(abs2, x - ones(1000)) < 1e-8
130130

131+
ml = ruge_stuben(A, coarse_solver=AlgebraicMultigrid.LinearSolveWrapper(UMFPACKFactorization()))
132+
x = AlgebraicMultigrid._solve(ml, A * ones(1000))
133+
@test sum(abs2, x - ones(1000)) < 1e-7
131134

132135
A = include("randlap.jl")
133136

0 commit comments

Comments
 (0)