Skip to content

Commit 31707b7

Browse files
authored
Bug fixes and complex support for in-place factorization (#12)
Fix dimensions problem and improve matrix shift for memory efficiency.
1 parent a995738 commit 31707b7

File tree

4 files changed

+271
-140
lines changed

4 files changed

+271
-140
lines changed

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
name = "QRupdate"
22
uuid = "29d2ae51-7246-454b-9a65-19ff7b2849a5"
3-
authors = ["Michael P. Friedlander <michael@friedlander.io>",
4-
"Michael A. Saunders <saunders@stanford.edu>"]
3+
authors = ["Michael P. Friedlander <michael@friedlander.io>", "Michael A. Saunders <saunders@stanford.edu>"]
54
version = "1.0.0"
65

76
[deps]
@@ -14,4 +13,4 @@ julia = "1.7"
1413
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1514

1615
[targets]
17-
test = ["Test"]
16+
test = ["Test"]

src/QRupdate.jl

Lines changed: 176 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,39 @@ module QRupdate
22

33
using LinearAlgebra
44

5-
export qraddcol, qraddcol!, qraddrow, qrdelcol, qrdelcol!, csne
5+
export qraddcol, qraddcol!, qraddrow, qrdelcol, qrdelcol!, csne, csne!
6+
7+
function swapcols!(M::Matrix{T},i::Int,j::Int) where {T}
8+
Base.permutecols!!(M, replace(axes(M,2), i=>j, j=>i))
9+
end
10+
11+
ORTHO_TOL = 1e-6 # err = |unew|^2 / |uold|^2 < ORTHO_TOL
12+
ORTHO_MAX_IT = 1
13+
verbose = true
614

715
"""
816
Triangular solve `Rx = b`, where `R` is upper triangular of size `realSize x realSize`. The storage of `R` is described in the documentation for `qraddcol!`.
917
"""
10-
function solveR!(R::AbstractMatrix{T}, b::Vector{T}, sol::Vector{T}, realSize::Integer) where {T}
11-
@inbounds sol[realSize] = b[realSize] / R[realSize, realSize]
12-
for i in (realSize-1):-1:1
13-
@inbounds sol[i] = b[i]
14-
for j in realSize:-1:(i+1)
15-
@inbounds sol[i] = sol[i] - R[i,j] * sol[j]
16-
end
17-
@inbounds sol[i] = sol[i] / R[i,i]
18-
end
18+
function solveR!(R::AbstractMatrix{T}, b::AbstractVector{T}, sol::AbstractVector{T}) where {T}
19+
# Note : R is upper triangular
20+
# Note 2: I tried implementing a forward substitution algorithm by
21+
# hand but in the end in was a bit slower and less accurate for smaller
22+
# matrix sizes, so I left the abstract implementation. The Upper/Lower functions
23+
# provide a view only, so they do not incur in further memory allocations. I
24+
# verified this with BenchmarkTooks.
25+
# N Barnafi 06/11/24
26+
sol .= b
27+
ldiv!(UpperTriangular(R), sol)
1928
end
2029

2130
"""
2231
Triangular solve `R'x = b`, where `R` is upper triangular of size `realSize x realSize`. The storage of `R` is described in the documentation for `qraddcol!`.
2332
"""
24-
function solveRT!(R::AbstractMatrix{T}, b::Vector{T}, sol::Vector{T}, realSize::Integer) where {T}
25-
@inbounds sol[1] = b[1] / R[1, 1]
26-
for i in 2:realSize
27-
@inbounds sol[i] = b[i]
28-
for j in 1:(i-1)
29-
@inbounds sol[i] = sol[i] - R[j,i] * sol[j]
30-
end
31-
@inbounds sol[i] = sol[i] / R[i,i]
32-
end
33+
function solveRT!(R::AbstractMatrix{T}, b::AbstractVector{T}, sol::AbstractVector{T}) where {T}
34+
# Note: R is upper triangular.
35+
# Note 2: We solve for the conjugate transpose.
36+
sol .= b
37+
ldiv!(LowerTriangular(R'), sol)
3338
end
3439

3540

@@ -70,7 +75,7 @@ function qraddcol(A::AbstractMatrix{T}, Rin::AbstractMatrix{T}, a::Vector{T}, β
7075
end
7176

7277
if n == 0
73-
return reshape([anorm], 1, 1)
78+
return reshape([convert(T, anorm)], 1, 1)
7479
end
7580

7681
R = UpperTriangular(Rin)
@@ -127,94 +132,97 @@ R = [0 0 0 R = [r11 0 0 R = [r11 r12 0
127132
0 0 0 0 0 0 0 r22 0
128133
0 0 0] 0 0 0] 0 0 0]
129134
"""
130-
function qraddcol!(A::AbstractMatrix{T}, R::AbstractMatrix{T}, a::Vector{T}, N::Int64, β::T = zero(T)) where {T}
135+
#function qraddcol!(A::AT, R::RT, a::aT, N::Int64, work::wT, work2::w2T, u::uT, z::zT, r::rT) where {AT,RT,aT,wT,w2T,uT,zT,rT,T}
136+
function qraddcol!(A::AbstractMatrix{T}, R::AbstractMatrix{T}, a::AbstractVector{T}, N::Int64, work::AbstractVector{T}, work2::AbstractVector{T}, u::AbstractVector{T}, z::AbstractVector{T}, r::AbstractVector{T}) where {T}
137+
#c,u,z,du,dz are R^n. Only r is R^m
138+
#c -> work; du -> work2. dz is redundant
131139

132-
m = size(A, 1)
133-
134-
# First add vector to A
135-
for i in 1:m
136-
@inbounds A[i,N+1] = a[i]
140+
#@timeit "get views" begin
141+
m, n = size(A)
142+
@assert size(work,1) == n "Expected "*string(n)*", actual size: " * string(size(work))
143+
@assert size(work2,1) == n
144+
@assert size(u,1) == n
145+
@assert size(z,1) == n
146+
@assert size(r,1) == m
147+
148+
if N < n
149+
cols = 1:N
150+
Atr = view(A, :, cols) #truncated
151+
Rtr = view(R, cols, cols)
152+
work_tr = view(work, cols)
153+
work2_tr = view(work2, cols)
154+
u_tr = view(u, cols)
155+
z_tr = view(z, cols)
156+
else
157+
Atr = A
158+
Rtr = R
159+
work_tr = work
160+
work2_tr = work2
161+
u_tr = u
162+
z_tr = z
137163
end
164+
#end #timeit get views
138165

139-
anorm = norm(a)
166+
#@timeit "norms" begin
167+
anorm = norm(a)
140168
anorm2 = anorm^2
141-
β2 = β^2
142-
if β != 0
143-
anorm2 = anorm2 + β2
144-
anorm = sqrt(anorm2)
145-
end
146169

147170
if N == 0
148-
#return reshape([anorm], 1, 1)
171+
anorm = sqrt(anorm2)
149172
R[1,1] = anorm
173+
view(A,:,N+1) .= a
150174
return
151175
end
176+
#end #timeit norms
177+
178+
# work := c = A'a
179+
mul!(work_tr, Atr', a)
180+
solveRT!(Rtr, work_tr, u_tr) #u = R'\c = R'\work
181+
solveR!(Rtr, u_tr, z_tr) #z = R\u
182+
copy!(r, a)
183+
mul!(r, Atr, z_tr, -1, 1) #r = a - A*z
184+
γ = norm(r)
185+
mul!(work_tr, Atr', r) # r := c = A'r
186+
err = norm(work_tr) / sqrt(anorm2)
187+
188+
# Iterative refinement
189+
if err < ORTHO_TOL
190+
view(R,1:N,N+1) .= view(u, 1:N)
191+
R[N+1,N+1] = γ
192+
view(A,:,N+1) .= a
193+
return
194+
else
152195

153-
c = zeros(T, N)
154-
u = zeros(T, N)
155-
du = zeros(T, N)
156-
157-
for i in 1:N #c = A'a
158-
for j in 1:m
159-
@inbounds c[i] += A[j,i] * a[j]
160-
end
161-
end
162-
solveRT!(R, c, u, N) #u = R'\c
163-
unorm2 = norm(u)^2
164-
d2 = anorm2 - unorm2
196+
i = 0
197+
while err > ORTHO_TOL && i < ORTHO_MAX_IT
165198

166-
z = zeros(T, N)
167-
dz = zeros(T, N)
168-
r = zeros(T, m)
199+
solveRT!(Rtr, work_tr, work2_tr) # work2 := du = R'\c
200+
solveR!(Rtr, work2_tr, work_tr) # work := dz = R\du
201+
axpy!(1.0, work_tr, z_tr) #z += dz # Refine z
202+
#@timeit "residual 2" begin
169203

170-
if d2 > anorm2
171-
γ = sqrt(d2)
172-
else
173-
solveR!(R, u, z, N) #z = R\u
174-
#mul!(r, A, z, -1, 1) #r = a - A*z
175-
for i in 1:m
176-
@inbounds r[i] = a[i]
177-
for j in 1:N
178-
@inbounds r[i] -= A[i,j] * z[j]
179-
end
180-
end
181-
#mul!(c, A', r) #c = A'r
182-
c[:] .= zero(T)
183-
for i in 1:N
184-
for j in 1:m
185-
@inbounds c[i] += A[j,i] * r[j]
186-
end
187-
end
204+
copy!(r, a)
205+
mul!(r, Atr, z_tr, -1.0, 1.0) #r = a - A*z
206+
γ = norm(r)
207+
work .= 0.0
208+
mul!(work_tr, Atr', r) # work := c = A'r
188209

189-
if !iszero(β)
190-
axpy!(-β2, z, c) #c = c - β2*z
191-
end
192-
solveRT!(R, c, du, N) #du = R'\c
193-
solveR!(R, du, dz, N) #dz = R\du
194-
axpy!(1, dz, z) #z += dz # Refine z
195-
# u = R*z # Original: Bjork's version.
196-
axpy!(1, du, u) #u += du # Modification: Refine u
197-
#r = a - A*z
198-
for i in 1:m
199-
@inbounds r[i] = a[i]
200-
for j in 1:N
201-
@inbounds r[i] -= A[i,j] * z[j]
202-
end
203-
end
210+
err = norm(work_tr) / sqrt(anorm2)
211+
#verbose && println(" *** Reorthogonalize ",string(i)," . Error:", err)
212+
verbose && print("*")
213+
i += 1
204214

205-
γ = norm(r) # Safe computation (we know gamma >= 0).
206-
if !iszero(β)
207-
γ = sqrt^2 + β2*norm(z)^2 + β2)
208-
end
209-
end
210215

211-
# Concatenate new row and column to R:
212-
# [ R u
213-
# zeros(1,n) γ ]
214-
for i in 1:N
215-
@inbounds R[i, N+1] = u[i]
216-
end
216+
#if !iszero(β)
217+
#γ = sqrt(γ^2 + β2*norm(z)^2 + β2)
218+
#end
219+
end # while
220+
end # if
221+
222+
axpy!(1, work2_tr, u_tr)
223+
view(R,1:N,N+1) .= view(u, 1:N)
217224
R[N+1,N+1] = γ
225+
view(A,:,N+1) .= a
218226
end
219227

220228
"""
@@ -283,37 +291,40 @@ with a column of zeros. This is useful to avoid copying the matrix.
283291
"""
284292
function qrdelcol!(A::AbstractMatrix{T}, R::AbstractMatrix{T}, k::Integer) where {T}
285293

294+
# Note that R is n x n
286295
m, n = size(A)
296+
mR,nR = size(R)
287297

288298
# Shift columns. This is apparently faster than copying views.
289-
for j in (k+1):n, i in 1:m
290-
@inbounds R[i,j-1] = R[i, j]
291-
@inbounds A[i,j-1] = A[i, j]
292-
end
293-
for i in 1:m
294-
@inbounds R[i,n] = zero(T)
295-
@inbounds A[i,n] = zero(T)
299+
@inbounds for j in (k+1):n
300+
R[:,j-1] .= @view R[:, j]
301+
A[:,j-1] .= @view A[:, j]
296302
end
303+
A[:,n] .= zero(T)
304+
R[:,n] .= zero(T)
297305

298-
for j in k:(n-1) # Forward sweep to reduce k-th row to zeros
299-
@inbounds G, y = givens(R[j+1,j], R[k,j], 1, 2)
300-
@inbounds R[j+1,j] = y
306+
@inbounds for j in k:(nR-1) # Forward sweep to reduce k-th row to zeros
307+
G, y = givens(R[j+1,j], R[k,j], 1, 2)
308+
R[j+1,j] = y
301309
if j<n && !iszero(G.s)
302310
for i in j+1:n
303-
@inbounds tmp = G.c*R[j+1,i] + G.s*R[k,i]
304-
@inbounds R[k,i] = G.c*R[k,i] - conj(G.s)*R[j+1,i]
305-
@inbounds R[j+1,i] = tmp
311+
tmp = G.c*R[j+1,i] + G.s*R[k,i]
312+
R[k,i] = G.c*R[k,i] - conj(G.s)*R[j+1,i]
313+
R[j+1,i] = tmp
306314
end
307315
end
308316
end
317+
#end # timeit givens downdate
309318

310319
# Shift k-th row. We skipped the removed column.
311-
for j in k:(n-1)
320+
@inbounds for j in k:(n-1)
312321
for i in k:j
313-
@inbounds R[i,j] = R[i+1, j]
322+
R[i,j] = R[i+1, j]
314323
end
315-
@inbounds R[j+1,j] = zero(T)
324+
R[j+1,j] = zero(T)
316325
end
326+
#end # timeit shift row
327+
#end #timeit all
317328
end
318329

319330
"""
@@ -347,4 +358,61 @@ function csne(Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}) where
347358
return (x, r)
348359
end
349360

361+
function csne!(R::RT, A::AT, b::bT, sol::solT, work::wT, work2::w2T, u::uT, r::rT, N::Int) where {RT,AT,bT,solT,wT,w2T,uT,rT}
362+
#c,u,sol,du are R^n. Only r is R^m
363+
#c -> work; du -> work2. dsol is redundant.
364+
365+
m, n = size(A)
366+
@assert size(sol,1) == n
367+
@assert size(work,1) == n
368+
@assert size(work2,1) == n
369+
@assert size(u,1) == n
370+
@assert size(r,1) == m
371+
if N < n
372+
cols = 1:N
373+
Atr = view(A, :, cols) #truncated
374+
Rtr = view(R, cols, cols)
375+
work_tr = view(work, cols)
376+
work2_tr = view(work2, cols)
377+
u_tr = view(u, cols)
378+
sol_tr = view(sol, cols)
379+
else
380+
Atr = A
381+
Rtr = R
382+
work_tr = work
383+
work2_tr = work2
384+
u_tr = u
385+
sol_tr = sol
386+
end
387+
bnorm = norm(b)
388+
bnorm2 = bnorm^2
389+
# work := c = A'b
390+
mul!(work_tr, Atr', b)
391+
solveRT!(Rtr, work_tr, u_tr)
392+
393+
solveR!(Rtr, u_tr, sol_tr) #z = R\u
394+
copy!(r, b)
395+
mul!(r, Atr, sol_tr, -1, 1) #r = b - A*z
396+
mul!(work_tr, Atr', r) # r := c = A'r
397+
err = norm(work_tr) / bnorm
398+
399+
i = 0
400+
while err > ORTHO_TOL && i < ORTHO_MAX_IT
401+
402+
solveRT!(Rtr, work_tr, work2_tr) # work2 := du = R'\c
403+
solveR!(Rtr, work2_tr, work_tr) # work := dz = R\du
404+
axpy!(1.0, work_tr, sol_tr) #z += dz # Refine z
405+
406+
copy!(r, b)
407+
mul!(r, Atr, sol_tr, -1.0, 1.0) #r = b - A*z
408+
mul!(work_tr, Atr', r) # work := c = A'r
409+
410+
err = norm(work_tr) / bnorm
411+
#verbose && println(" *** Reorthogonalize ",string(i), " CSNE. Error:", err)
412+
verbose && print("*")
413+
i += 1
414+
415+
end
416+
end
417+
350418
end # module

test/execution-time-test.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using QRupdate
2+
using LinearAlgebra
3+
using BenchmarkTools
4+
5+
for mm in [1000,2000,4000,10000, 20000,50000,100000]
6+
#reset_timer()
7+
m, n = mm, 100
8+
A = randn(m, n)
9+
R = qr(A).R
10+
Rin = deepcopy(R)
11+
Ain = deepcopy(A)
12+
13+
actual_size = n
14+
i = 20
15+
println("====== R ", i)
16+
@btime $R = qrdelcol($R, $i)
17+
println("====== Rin ", i)
18+
@btime qrdelcol!($Ain, $Rin, $i)
19+
end

0 commit comments

Comments
 (0)