Skip to content

Commit f89ad41

Browse files
authored
Merge pull request #1878 from JuliaRobotics/25Q3/enh/keepsol
fast-forward keep solve changes (first use is extracting DERelative odes)
2 parents a1b0fdb + 2fd5f6f commit f89ad41

File tree

10 files changed

+110
-30
lines changed

10 files changed

+110
-30
lines changed

IncrementalInference/ext/IncrInfrDiffEqFactorExt.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ function DERelative(
7575
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
7676
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
7777
problemType = ODEProblem, # DiscreteProblem,
78+
keepSolution::Bool = false
7879
)
7980
#
8081
datatuple = if 2 < length(Xi)
@@ -87,8 +88,21 @@ function DERelative(
8788
fproblem = problemType(f, state0, tspan, datatuple; dt)
8889
# backward time problem
8990
bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt)
91+
92+
_keepSolution = if keepSolution
93+
Vector{typeof((;t=Vector{Float64}(),u=Vector{Vector{Float64}}()))}()
94+
else
95+
nothing
96+
end
97+
9098
# build the IIF recognizable object
91-
return DERelative(domain, fproblem, bproblem, datatuple) #, getSample)
99+
return DERelative(
100+
domain,
101+
fproblem,
102+
bproblem,
103+
datatuple,
104+
_keepSolution
105+
)
92106
end
93107

94108
function DERelative(
@@ -103,6 +117,7 @@ function DERelative(
103117
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
104118
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
105119
problemType = DiscreteProblem,
120+
keepSolution::Bool = false
106121
)
107122
return DERelative(
108123
Xi,
@@ -114,6 +129,7 @@ function DERelative(
114129
state1,
115130
tspan,
116131
problemType,
132+
keepSolution,
117133
)
118134
end
119135
#
@@ -301,12 +317,28 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
301317
cf._legacyParams[1], M_
302318
end
303319

320+
# solve ODE for N many particles transiting this factor
321+
kS = cf.factor.keepSolution
322+
if !isnothing( kS )
323+
resize!(kS, N)
324+
# @info "kS" length(kS)
325+
for i in 1:N
326+
kS[i] = (;t=Vector{Float64}(),u=Vector{Vector{Float64}}())
327+
end
328+
end
304329
# solve likely elements
305330
for i = 1:N
306331
# TODO, does this respect hyporecipe ???
307332
idxArr = (k -> cf._legacyParams[k][i]).(1:length(cf._legacyParams))
308-
_solveFactorODE!(meas[i], prob, u0pts[i], _maketuplebeyond2args(idxArr...)...)
309-
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
333+
oderes = _solveFactorODE!(meas[i], prob, u0pts[i], _maketuplebeyond2args(idxArr...)...)
334+
if !isnothing( kS )
335+
resize!(kS[i].t, length(oderes.t))
336+
resize!(kS[i].u, length(oderes.u))
337+
for j = 1:length(oderes.t)
338+
kS[i].t[j] = oderes.t[j]
339+
kS[i].u[j] = oderes.u[j]
340+
end
341+
end
310342
end
311343

312344
# return meas, M

IncrementalInference/src/entities/ExtFactors.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ DevNotes
1919
- FIXME Lots of consolidation and standardization to do, see RoME.jl #244 regarding Manifolds.jl.
2020
- TODO does not yet handle case where a factor spans across two timezones.
2121
"""
22-
struct DERelative{T <: StateType, P, D} <: RelativeObservation
22+
struct DERelative{T <: StateType, P, D, O} <: RelativeObservation
2323
domain::Type{T}
2424
forwardProblem::P
2525
backwardProblem::P
2626
""" second element of this data tuple is additional variables that will be passed down as a parameter """
2727
data::D
28-
# specialSampler::Function
28+
keepSolution::O
2929
end

IncrementalInference/src/entities/FactorOperationalMemory.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ Base.@kwdef struct CommonConvWrapper{
2626
AM <: AbstractManifold,
2727
HR <: HypoRecipeCompute,
2828
MT,
29-
G
29+
G,
30+
KCF <: Union{Nothing, <: Channel{<: CalcFactor}}
3031
} <: FactorCache
3132
# Basic factor topological info
3233
""" Values consistent across all threads during approx convolution """
@@ -67,6 +68,8 @@ Base.@kwdef struct CommonConvWrapper{
6768
res::Vector{Float64} = zeros(manifold_dimension(manifold))
6869
""" experimental feature to embed gradient calcs with ccw """
6970
_gradients::G = nothing
71+
""" working memory to store residual from optimization routines """
72+
keepCalcFactor::KCF
7073
end
7174

7275

IncrementalInference/src/services/ApproxConv.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ function approxConvBelief(
1010
N::Int = length(measurement),
1111
nullSurplus::Real = 0,
1212
skipSolve::Bool = false,
13+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
1314
)
1415
#
1516
v_trg = getVariable(dfg, target)
@@ -25,7 +26,8 @@ function approxConvBelief(
2526
solveKey,
2627
N,
2728
skipSolve,
28-
nullSurplus
29+
nullSurplus,
30+
keepCalcFactor
2931
)
3032

3133
len = length(ipc)
@@ -85,6 +87,7 @@ function approxConvBelief(
8587
path::AbstractVector{Symbol} = Symbol[],
8688
skipSolve::Bool = false,
8789
nullSurplus::Real = 0,
90+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
8891
)
8992
#
9093
# @assert isVariable(dfg, target) "approxConv(dfg, from, target,...) where `target`=$target must be a variable in `dfg`"
@@ -137,6 +140,7 @@ function approxConvBelief(
137140
N,
138141
skipSolve,
139142
nullSurplus,
143+
keepCalcFactor,
140144
)
141145
if length(path) == 2
142146
return pts1Bel
@@ -155,7 +159,7 @@ function approxConvBelief(
155159
# this is a factor path[idx]
156160
fct = getFactor(dfg, path[idx])
157161
addFactor!(tfg, fct)
158-
ptsBel = approxConvBelief(tfg, fct, path[idx + 1]; solveKey, N, skipSolve)
162+
ptsBel = approxConvBelief(tfg, fct, path[idx + 1]; solveKey, N, skipSolve, keepCalcFactor)
159163
initVariable!(tfg, path[idx + 1], ptsBel)
160164
!setPPE ? nothing : setPPE!(tfg, path[idx + 1], solveKey, ppemethod)
161165
end
@@ -184,10 +188,11 @@ function calcProposalBelief(
184188
solveKey::Symbol = :default,
185189
nullSurplus::Real = 0,
186190
dbg::Bool = false,
191+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
187192
)
188193
#
189194
# assuming it is properly initialized TODO
190-
proposal = approxConvBelief(dfg, fct, target, measurement; solveKey, N, nullSurplus)
195+
proposal = approxConvBelief(dfg, fct, target, measurement; solveKey, N, nullSurplus, keepCalcFactor)
191196

192197
# return the proposal belief and inferdim, NOTE likely to be changed
193198
return proposal

IncrementalInference/src/services/CalcFactor.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ function _createCCW(
376376
_blockRecursion::Bool = false,
377377
attemptGradients::Bool = true,
378378
userCache::CT = nothing,
379+
keepCalcFactor::Bool = false,
379380
) where {T <: AbstractFactor, CT}
380381
#
381382
if length(Xi) !== 0
@@ -394,7 +395,7 @@ function _createCCW(
394395
fullvariables = tuple(Xi...) # convert(Vector{VariableCompute}, Xi)
395396
# create a temporary CalcFactor object for extracting the first sample
396397

397-
_cf = CalcFactorNormSq(
398+
_cf_nt = CalcFactorNormSq(
398399
usrfnc,
399400
1,
400401
_varValsAll,
@@ -408,12 +409,12 @@ function _createCCW(
408409
)
409410

410411
# get a measurement sample
411-
meas_single = sampleFactor(_cf, 1)[1]
412+
meas_single = sampleFactor(_cf_nt, 1)[1]
412413
elT = typeof(meas_single)
413414
#TODO preallocate measurement?
414415
measurement = Vector{elT}()
415416

416-
#FIXME chicken and egg problem for getting measurement type, so creating twice.
417+
# NOTE chicken and egg problem for getting measurement type, so creating twice.
417418
_cf = CalcFactorNormSq(
418419
usrfnc,
419420
1,
@@ -427,6 +428,11 @@ function _createCCW(
427428
nothing,
428429
)
429430

431+
keepCalcFactor_ = if keepCalcFactor
432+
Channel{CalcFactor}(1024)
433+
else
434+
nothing
435+
end
430436

431437
# partialDims are sensitive to both which solvefor variable index and whether the factor is partial
432438
partial = hasfield(T, :partial) # FIXME, use isPartial function instead
@@ -441,6 +447,7 @@ function _createCCW(
441447

442448
# as per struct CommonConvWrapper
443449
_gradients = if attemptGradients
450+
# FIXME update to proper AD tools
444451
attemptGradientPrep(
445452
varTypes,
446453
usrfnc,
@@ -477,6 +484,7 @@ function _createCCW(
477484
),
478485
measurement,
479486
_gradients,
487+
keepCalcFactor = keepCalcFactor_
480488
)
481489
end
482490

@@ -485,14 +493,15 @@ function updateMeasurement!(
485493
N::Int=1;
486494
measurement::AbstractVector = Vector{Tuple{}}(),
487495
needFreshMeasurements::Bool=true,
488-
_allowThreads::Bool = true
496+
_allowThreads::Bool = true,
497+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
489498
)
490499
# FIXME do not divert Mixture for sampling
491500

492501
# option to disable fresh samples or user provided
493502
if needFreshMeasurements
494503
# TODO this is only one thread, make this a for loop for multithreaded sampling
495-
sampleFactor!(ccwl, N; _allowThreads)
504+
sampleFactor!(ccwl, N; _allowThreads, keepCalcFactor)
496505
elseif 0 < length(measurement)
497506
resize!(ccwl.measurement, length(measurement))
498507
ccwl.measurement[:] = measurement
@@ -520,6 +529,7 @@ function _beforeSolveCCW!(
520529
measurement = Vector{Tuple{}}(),
521530
needFreshMeasurements::Bool = true,
522531
solveKey::Symbol = :default,
532+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
523533
) where {F <: AbstractFactor} # F might be Mixture
524534
#
525535
if length(variables) !== 0
@@ -570,7 +580,7 @@ function _beforeSolveCCW!(
570580
_setCCWDecisionDimsConv!(ccwl, xDim)
571581

572582
# FIXME do not divert Mixture for sampling
573-
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true)
583+
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true, keepCalcFactor)
574584

575585
# used in ccw functor for AbstractRelativeMinimize
576586
resize!(ccwl.res, _getZDim(ccwl))
@@ -591,6 +601,7 @@ function _beforeSolveCCW!(
591601
measurement = Vector{Tuple{}}(),
592602
needFreshMeasurements::Bool = true,
593603
solveKey::Symbol = :default,
604+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
594605
) where {F <: AbstractFactor} # F might be Mixture
595606
# FIXME, NEEDS TO BE CLEANED UP AND WORK ON MANIFOLDS PROPER
596607

@@ -606,7 +617,7 @@ function _beforeSolveCCW!(
606617

607618
# FIXME do not divert Mixture for sampling
608619
# update ccwl.measurement values
609-
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true)
620+
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true, keepCalcFactor)
610621

611622
return maxlen
612623
end

IncrementalInference/src/services/EvalFactor.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,15 @@ function evalPotentialSpecific(
335335
dbg::Bool = false,
336336
skipSolve::Bool = false,
337337
_slack = nothing,
338+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
338339
) where {T <: AbstractFactor}
339340
#
340341

341342
# Prep computation variables
342343
# add user desired measurement values if 0 < length
343344
# 2023Q2, ccwl.varValsAll always points at the variable.VND.val memory locations
344345
# remember when doing approxConv to make a deepcopy of the destination memory first.
345-
maxlen = _beforeSolveCCW!(ccwl, variables, sfidx, N; solveKey, needFreshMeasurements, measurement)
346+
maxlen = _beforeSolveCCW!(ccwl, variables, sfidx, N; solveKey, needFreshMeasurements, measurement, keepCalcFactor)
346347

347348
# Check which variables have been initialized
348349
isinit = map(x -> isInitialized(x, solveKey), variables)
@@ -414,11 +415,12 @@ function evalPotentialSpecific(
414415
dbg::Bool = false,
415416
skipSolve::Bool = false,
416417
_slack = nothing,
418+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
417419
) where {T <: AbstractFactor}
418420
#
419421

420422
# Prep computation variables
421-
maxlen = _beforeSolveCCW!(ccwl, variables, sfidx, N; solveKey, needFreshMeasurements, measurement)
423+
maxlen = _beforeSolveCCW!(ccwl, variables, sfidx, N; solveKey, needFreshMeasurements, measurement, keepCalcFactor)
422424

423425
# # FIXME, NEEDS TO BE CLEANED UP AND WORK ON MANIFOLDS PROPER
424426
fnc = ccwl.usrfnc!
@@ -582,6 +584,7 @@ function evalFactor(
582584
dbg::Bool = false,
583585
skipSolve::Bool = false,
584586
_slack = nothing,
587+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
585588
)
586589
#
587590
return evalPotentialSpecific(
@@ -598,6 +601,7 @@ function evalFactor(
598601
nullSurplus,
599602
skipSolve,
600603
_slack,
604+
keepCalcFactor,
601605
)
602606
#
603607
end

IncrementalInference/src/services/FactorGraph.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,7 @@ function getDefaultFactorData(
723723
solveInProgress = 0,
724724
inflation::Real = getSolverParams(dfg).inflation,
725725
_blockRecursion::Bool = false,
726+
keepCalcFactor::Bool = false,
726727
) where {T <: AbstractFactor}
727728
#
728729

@@ -741,6 +742,7 @@ function getDefaultFactorData(
741742
attemptGradients = getSolverParams(dfg).attemptGradients,
742743
_blockRecursion,
743744
userCache,
745+
keepCalcFactor,
744746
)
745747

746748
state = DFG.FactorState(
@@ -829,6 +831,7 @@ function DFG.addFactor!(
829831
inflation::Real = getSolverParams(dfg).inflation,
830832
namestring::Symbol = assembleFactorName(dfg, Xi),
831833
_blockRecursion::Bool = !getSolverParams(dfg).attemptGradients,
834+
keepCalcFactor::Bool = false,
832835
)
833836
#
834837

@@ -847,6 +850,7 @@ function DFG.addFactor!(
847850
# threadmodel,
848851
inflation,
849852
_blockRecursion,
853+
keepCalcFactor,
850854
)
851855
#
852856
newFactor = FactorCompute(

IncrementalInference/src/services/NumericalCalculations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ function _solveLambdaNumeric(
157157
# TODO find good way for a solve to store diagnostics about number of failed converges etc.
158158
@warn "Optim did not converge (maxlog=10):" r maxlog=10
159159
end
160+
161+
# FIXME, how to use this exp when either Manifolds or LieGroups is used?
160162
return exp(M, ϵ, hat(M, ϵ, r.minimizer))
161163
end
162164

0 commit comments

Comments
 (0)