@@ -2,34 +2,39 @@ module QRupdate
2
2
3
3
using LinearAlgebra
4
4
5
- export qraddcol, qraddcol!, qraddrow, qrdelcol, qrdelcol!, csne
5
+ export qraddcol, qraddcol!, qraddrow, qrdelcol, qrdelcol!, csne, csne!
6
+
7
+ function swapcols! (M:: Matrix{T} ,i:: Int ,j:: Int ) where {T}
8
+ Base. permutecols!! (M, replace (axes (M,2 ), i=> j, j=> i))
9
+ end
10
+
11
+ ORTHO_TOL = 1e-6 # err = |unew|^2 / |uold|^2 < ORTHO_TOL
12
+ ORTHO_MAX_IT = 1
13
+ verbose = true
6
14
7
15
"""
8
16
Triangular solve `Rx = b`, where `R` is upper triangular of size `realSize x realSize`. The storage of `R` is described in the documentation for `qraddcol!`.
9
17
"""
10
- function solveR! (R:: AbstractMatrix{T} , b:: Vector{T} , sol:: Vector{T} , realSize:: Integer ) where {T}
11
- @inbounds sol[realSize] = b[realSize] / R[realSize, realSize]
12
- for i in (realSize- 1 ): - 1 : 1
13
- @inbounds sol[i] = b[i]
14
- for j in realSize: - 1 : (i+ 1 )
15
- @inbounds sol[i] = sol[i] - R[i,j] * sol[j]
16
- end
17
- @inbounds sol[i] = sol[i] / R[i,i]
18
- end
18
+ function solveR! (R:: AbstractMatrix{T} , b:: AbstractVector{T} , sol:: AbstractVector{T} ) where {T}
19
+ # Note : R is upper triangular
20
+ # Note 2: I tried implementing a forward substitution algorithm by
21
+ # hand but in the end it was a bit slower and less accurate for smaller
22
+ # matrix sizes, so I left the abstract implementation. The Upper/Lower functions
23
+ # provide a view only, so they do not incur further memory allocations. I
24
+ # verified this with BenchmarkTools.
25
+ # N Barnafi 06/11/24
26
+ sol .= b
27
+ ldiv! (UpperTriangular (R), sol)
19
28
end
20
29
21
30
"""
22
31
Triangular solve `R'x = b`, where `R` is upper triangular of size `realSize x realSize`. The storage of `R` is described in the documentation for `qraddcol!`.
23
32
"""
24
- function solveRT! (R:: AbstractMatrix{T} , b:: Vector{T} , sol:: Vector{T} , realSize:: Integer ) where {T}
25
- @inbounds sol[1 ] = b[1 ] / R[1 , 1 ]
26
- for i in 2 : realSize
27
- @inbounds sol[i] = b[i]
28
- for j in 1 : (i- 1 )
29
- @inbounds sol[i] = sol[i] - R[j,i] * sol[j]
30
- end
31
- @inbounds sol[i] = sol[i] / R[i,i]
32
- end
33
+ function solveRT! (R:: AbstractMatrix{T} , b:: AbstractVector{T} , sol:: AbstractVector{T} ) where {T}
34
+ # Note: R is upper triangular.
35
+ # Note 2: We solve for the conjugate transpose.
36
+ sol .= b
37
+ ldiv! (LowerTriangular (R' ), sol)
33
38
end
34
39
35
40
@@ -70,7 +75,7 @@ function qraddcol(A::AbstractMatrix{T}, Rin::AbstractMatrix{T}, a::Vector{T}, β
70
75
end
71
76
72
77
if n == 0
73
- return reshape ([anorm], 1 , 1 )
78
+ return reshape ([convert (T, anorm) ], 1 , 1 )
74
79
end
75
80
76
81
R = UpperTriangular (Rin)
@@ -127,94 +132,97 @@ R = [0 0 0 R = [r11 0 0 R = [r11 r12 0
127
132
0 0 0 0 0 0 0 r22 0
128
133
0 0 0] 0 0 0] 0 0 0]
129
134
"""
130
- function qraddcol! (A:: AbstractMatrix{T} , R:: AbstractMatrix{T} , a:: Vector{T} , N:: Int64 , β:: T = zero (T)) where {T}
135
+ # function qraddcol!(A::AT, R::RT, a::aT, N::Int64, work::wT, work2::w2T, u::uT, z::zT, r::rT) where {AT,RT,aT,wT,w2T,uT,zT,rT,T}
136
+ function qraddcol! (A:: AbstractMatrix{T} , R:: AbstractMatrix{T} , a:: AbstractVector{T} , N:: Int64 , work:: AbstractVector{T} , work2:: AbstractVector{T} , u:: AbstractVector{T} , z:: AbstractVector{T} , r:: AbstractVector{T} ) where {T}
137
+ # c,u,z,du,dz are R^n. Only r is R^m
138
+ # c -> work; du -> work2. dz is redundant
131
139
132
- m = size (A, 1 )
133
-
134
- # First add vector to A
135
- for i in 1 : m
136
- @inbounds A[i,N+ 1 ] = a[i]
140
+ # @timeit "get views" begin
141
+ m, n = size (A)
142
+ @assert size (work,1 ) == n " Expected " * string (n)* " , actual size: " * string (size (work))
143
+ @assert size (work2,1 ) == n
144
+ @assert size (u,1 ) == n
145
+ @assert size (z,1 ) == n
146
+ @assert size (r,1 ) == m
147
+
148
+ if N < n
149
+ cols = 1 : N
150
+ Atr = view (A, :, cols) # truncated
151
+ Rtr = view (R, cols, cols)
152
+ work_tr = view (work, cols)
153
+ work2_tr = view (work2, cols)
154
+ u_tr = view (u, cols)
155
+ z_tr = view (z, cols)
156
+ else
157
+ Atr = A
158
+ Rtr = R
159
+ work_tr = work
160
+ work2_tr = work2
161
+ u_tr = u
162
+ z_tr = z
137
163
end
164
+ # end #timeit get views
138
165
139
- anorm = norm (a)
166
+ # @timeit "norms" begin
167
+ anorm = norm (a)
140
168
anorm2 = anorm^ 2
141
- β2 = β^ 2
142
- if β != 0
143
- anorm2 = anorm2 + β2
144
- anorm = sqrt (anorm2)
145
- end
146
169
147
170
if N == 0
148
- # return reshape([ anorm], 1, 1 )
171
+ anorm = sqrt (anorm2 )
149
172
R[1 ,1 ] = anorm
173
+ view (A,:,N+ 1 ) .= a
150
174
return
151
175
end
176
+ # end #timeit norms
177
+
178
+ # work := c = A'a
179
+ mul! (work_tr, Atr' , a)
180
+ solveRT! (Rtr, work_tr, u_tr) # u = R'\c = R'\work
181
+ solveR! (Rtr, u_tr, z_tr) # z = R\u
182
+ copy! (r, a)
183
+ mul! (r, Atr, z_tr, - 1 , 1 ) # r = a - A*z
184
+ γ = norm (r)
185
+ mul! (work_tr, Atr' , r) # r := c = A'r
186
+ err = norm (work_tr) / sqrt (anorm2)
187
+
188
+ # Iterative refinement
189
+ if err < ORTHO_TOL
190
+ view (R,1 : N,N+ 1 ) .= view (u, 1 : N)
191
+ R[N+ 1 ,N+ 1 ] = γ
192
+ view (A,:,N+ 1 ) .= a
193
+ return
194
+ else
152
195
153
- c = zeros (T, N)
154
- u = zeros (T, N)
155
- du = zeros (T, N)
156
-
157
- for i in 1 : N # c = A'a
158
- for j in 1 : m
159
- @inbounds c[i] += A[j,i] * a[j]
160
- end
161
- end
162
- solveRT! (R, c, u, N) # u = R'\c
163
- unorm2 = norm (u)^ 2
164
- d2 = anorm2 - unorm2
196
+ i = 0
197
+ while err > ORTHO_TOL && i < ORTHO_MAX_IT
165
198
166
- z = zeros (T, N)
167
- dz = zeros (T, N)
168
- r = zeros (T, m)
199
+ solveRT! (Rtr, work_tr, work2_tr) # work2 := du = R'\c
200
+ solveR! (Rtr, work2_tr, work_tr) # work := dz = R\du
201
+ axpy! (1.0 , work_tr, z_tr) # z += dz # Refine z
202
+ # @timeit "residual 2" begin
169
203
170
- if d2 > anorm2
171
- γ = sqrt (d2)
172
- else
173
- solveR! (R, u, z, N) # z = R\u
174
- # mul!(r, A, z, -1, 1) #r = a - A*z
175
- for i in 1 : m
176
- @inbounds r[i] = a[i]
177
- for j in 1 : N
178
- @inbounds r[i] -= A[i,j] * z[j]
179
- end
180
- end
181
- # mul!(c, A', r) #c = A'r
182
- c[:] .= zero (T)
183
- for i in 1 : N
184
- for j in 1 : m
185
- @inbounds c[i] += A[j,i] * r[j]
186
- end
187
- end
204
+ copy! (r, a)
205
+ mul! (r, Atr, z_tr, - 1.0 , 1.0 ) # r = a - A*z
206
+ γ = norm (r)
207
+ work .= 0.0
208
+ mul! (work_tr, Atr' , r) # work := c = A'r
188
209
189
- if ! iszero (β)
190
- axpy! (- β2, z, c) # c = c - β2*z
191
- end
192
- solveRT! (R, c, du, N) # du = R'\c
193
- solveR! (R, du, dz, N) # dz = R\du
194
- axpy! (1 , dz, z) # z += dz # Refine z
195
- # u = R*z # Original: Bjork's version.
196
- axpy! (1 , du, u) # u += du # Modification: Refine u
197
- # r = a - A*z
198
- for i in 1 : m
199
- @inbounds r[i] = a[i]
200
- for j in 1 : N
201
- @inbounds r[i] -= A[i,j] * z[j]
202
- end
203
- end
210
+ err = norm (work_tr) / sqrt (anorm2)
211
+ # verbose && println(" *** Reorthogonalize ",string(i)," . Error:", err)
212
+ verbose && print (" *" )
213
+ i += 1
204
214
205
- γ = norm (r) # Safe computation (we know gamma >= 0).
206
- if ! iszero (β)
207
- γ = sqrt (γ^ 2 + β2* norm (z)^ 2 + β2)
208
- end
209
- end
210
215
211
- # Concatenate new row and column to R:
212
- # [ R u
213
- # zeros(1,n) γ ]
214
- for i in 1 : N
215
- @inbounds R[i, N+ 1 ] = u[i]
216
- end
216
+ # if !iszero(β)
217
+ # γ = sqrt(γ^2 + β2*norm(z)^2 + β2)
218
+ # end
219
+ end # while
220
+ end # if
221
+
222
+ axpy! (1 , work2_tr, u_tr)
223
+ view (R,1 : N,N+ 1 ) .= view (u, 1 : N)
217
224
R[N+ 1 ,N+ 1 ] = γ
225
+ view (A,:,N+ 1 ) .= a
218
226
end
219
227
220
228
"""
@@ -283,37 +291,40 @@ with a column of zeros. This is useful to avoid copying the matrix.
283
291
"""
284
292
function qrdelcol! (A:: AbstractMatrix{T} , R:: AbstractMatrix{T} , k:: Integer ) where {T}
285
293
294
+ # Note that R is n x n
286
295
m, n = size (A)
296
+ mR,nR = size (R)
287
297
288
298
# Shift columns. This is apparently faster than copying views.
289
- for j in (k+ 1 ): n, i in 1 : m
290
- @inbounds R[i,j- 1 ] = R[i, j]
291
- @inbounds A[i,j- 1 ] = A[i, j]
292
- end
293
- for i in 1 : m
294
- @inbounds R[i,n] = zero (T)
295
- @inbounds A[i,n] = zero (T)
299
+ @inbounds for j in (k+ 1 ): n
300
+ R[:,j- 1 ] .= @view R[:, j]
301
+ A[:,j- 1 ] .= @view A[:, j]
296
302
end
303
+ A[:,n] .= zero (T)
304
+ R[:,n] .= zero (T)
297
305
298
- for j in k: (n - 1 ) # Forward sweep to reduce k-th row to zeros
299
- @inbounds G, y = givens (R[j+ 1 ,j], R[k,j], 1 , 2 )
300
- @inbounds R[j+ 1 ,j] = y
306
+ @inbounds for j in k: (nR - 1 ) # Forward sweep to reduce k-th row to zeros
307
+ G, y = givens (R[j+ 1 ,j], R[k,j], 1 , 2 )
308
+ R[j+ 1 ,j] = y
301
309
if j< n && ! iszero (G. s)
302
310
for i in j+ 1 : n
303
- @inbounds tmp = G. c* R[j+ 1 ,i] + G. s* R[k,i]
304
- @inbounds R[k,i] = G. c* R[k,i] - conj (G. s)* R[j+ 1 ,i]
305
- @inbounds R[j+ 1 ,i] = tmp
311
+ tmp = G. c* R[j+ 1 ,i] + G. s* R[k,i]
312
+ R[k,i] = G. c* R[k,i] - conj (G. s)* R[j+ 1 ,i]
313
+ R[j+ 1 ,i] = tmp
306
314
end
307
315
end
308
316
end
317
+ # end # timeit givens downdate
309
318
310
319
# Shift k-th row. We skipped the removed column.
311
- for j in k: (n- 1 )
320
+ @inbounds for j in k: (n- 1 )
312
321
for i in k: j
313
- @inbounds R[i,j] = R[i+ 1 , j]
322
+ R[i,j] = R[i+ 1 , j]
314
323
end
315
- @inbounds R[j+ 1 ,j] = zero (T)
324
+ R[j+ 1 ,j] = zero (T)
316
325
end
326
+ # end # timeit shift row
327
+ # end #timeit all
317
328
end
318
329
319
330
"""
@@ -347,4 +358,61 @@ function csne(Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}) where
347
358
return (x, r)
348
359
end
349
360
361
+ function csne! (R:: RT , A:: AT , b:: bT , sol:: solT , work:: wT , work2:: w2T , u:: uT , r:: rT , N:: Int ) where {RT,AT,bT,solT,wT,w2T,uT,rT}
362
+ # c,u,sol,du are R^n. Only r is R^m
363
+ # c -> work; du -> work2. dsol is redundant.
364
+
365
+ m, n = size (A)
366
+ @assert size (sol,1 ) == n
367
+ @assert size (work,1 ) == n
368
+ @assert size (work2,1 ) == n
369
+ @assert size (u,1 ) == n
370
+ @assert size (r,1 ) == m
371
+ if N < n
372
+ cols = 1 : N
373
+ Atr = view (A, :, cols) # truncated
374
+ Rtr = view (R, cols, cols)
375
+ work_tr = view (work, cols)
376
+ work2_tr = view (work2, cols)
377
+ u_tr = view (u, cols)
378
+ sol_tr = view (sol, cols)
379
+ else
380
+ Atr = A
381
+ Rtr = R
382
+ work_tr = work
383
+ work2_tr = work2
384
+ u_tr = u
385
+ sol_tr = sol
386
+ end
387
+ bnorm = norm (b)
388
+ bnorm2 = bnorm^ 2
389
+ # work := c = A'b
390
+ mul! (work_tr, Atr' , b)
391
+ solveRT! (Rtr, work_tr, u_tr)
392
+
393
+ solveR! (Rtr, u_tr, sol_tr) # z = R\u
394
+ copy! (r, b)
395
+ mul! (r, Atr, sol_tr, - 1 , 1 ) # r = b - A*z
396
+ mul! (work_tr, Atr' , r) # work := c = A'r
397
+ err = norm (work_tr) / bnorm
398
+
399
+ i = 0
400
+ while err > ORTHO_TOL && i < ORTHO_MAX_IT
401
+
402
+ solveRT! (Rtr, work_tr, work2_tr) # work2 := du = R'\c
403
+ solveR! (Rtr, work2_tr, work_tr) # work := dz = R\du
404
+ axpy! (1.0 , work_tr, sol_tr) # z += dz # Refine z
405
+
406
+ copy! (r, b)
407
+ mul! (r, Atr, sol_tr, - 1.0 , 1.0 ) # r = b - A*z
408
+ mul! (work_tr, Atr' , r) # work := c = A'r
409
+
410
+ err = norm (work_tr) / bnorm
411
+ # verbose && println(" *** Reorthogonalize ",string(i), " CSNE. Error:", err)
412
+ verbose && print (" *" )
413
+ i += 1
414
+
415
+ end
416
+ end
417
+
350
418
end # module
0 commit comments