Skip to content

Commit f237649

Browse files
authored
Reapply "Online lateral lag learning" (#34975)
* Online lateral lag learning (#34974) This reverts commit b4cc9e6. * pad to the best size for fft * Fix static analysis * Add typing * Fix typing * MAX_LAG * Calculate cross correlation regardless if the points are valid * Back to lagd * Add lagd to process_config * Lagd in test onroad * Move lag estimator for lagd * Remove duplicate entry from test_onroad * Update process replay * pre-fill the data * Update cpu usage * 25sec window * Change the meaning of lateralDelayEstimate * No newline * Fix typing * Prefill * Update ref commit * Add a unit test * Fix static issues * Time limit * Or timeout * Use mocker * Update estimate every time * empty test * DT const * enable RIVIAN again * Update ref commit * Update that again * Improve the tests * Fix static * Add masking test * Increase timeout * Add liveDelay to selfdrived * Add liveDelay to selfdrived in process_replay * Fix block_avg restore after num_blocks * regen most * Update bolt * Update ref commit * Change the key name * Add assert * True weighted average
1 parent 5d1816e commit f237649

File tree

11 files changed

+542
-21
lines changed

11 files changed

+542
-21
lines changed

cereal/services.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(self, should_log: bool, frequency: float, decimation: Optional[int]
3636
"errorLogMessage": (True, 0., 1),
3737
"liveCalibration": (True, 4., 4),
3838
"liveTorqueParameters": (True, 4., 1),
39+
"liveDelay": (True, 4., 1),
3940
"androidLog": (True, 0.),
4041
"carState": (True, 100., 10),
4142
"carControl": (True, 100., 10),

common/params_keys.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ inline static std::unordered_map<std::string, uint32_t> keys = {
7171
{"LastPowerDropDetected", CLEAR_ON_MANAGER_START},
7272
{"LastUpdateException", CLEAR_ON_MANAGER_START},
7373
{"LastUpdateTime", PERSISTENT},
74+
{"LiveDelay", PERSISTENT},
7475
{"LiveParameters", PERSISTENT},
7576
{"LiveParametersV2", PERSISTENT},
7677
{"LiveTorqueParameters", PERSISTENT | DONT_LOG},

selfdrive/locationd/helpers.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,48 @@
11
import numpy as np
22
from typing import Any
3+
from functools import cache
34

45
from cereal import log
56
from openpilot.common.transformations.orientation import rot_from_euler, euler_from_rot
67

78

9+
@cache
10+
def fft_next_good_size(n: int) -> int:
11+
"""
12+
smallest composite of 2, 3, 5, 7, 11 that is >= n
13+
inspired by pocketfft
14+
"""
15+
if n <= 6:
16+
return n
17+
best, f2 = 2 * n, 1
18+
while f2 < best:
19+
f23 = f2
20+
while f23 < best:
21+
f235 = f23
22+
while f235 < best:
23+
f2357 = f235
24+
while f2357 < best:
25+
f235711 = f2357
26+
while f235711 < best:
27+
best = f235711 if f235711 >= n else best
28+
f235711 *= 11
29+
f2357 *= 7
30+
f235 *= 5
31+
f23 *= 3
32+
f2 *= 2
33+
return best
34+
35+
36+
def parabolic_peak_interp(R, max_index):
37+
if max_index == 0 or max_index == len(R) - 1:
38+
return max_index
39+
40+
y_m1, y_0, y_p1 = R[max_index - 1], R[max_index], R[max_index + 1]
41+
offset = 0.5 * (y_p1 - y_m1) / (2 * y_0 - y_p1 - y_m1)
42+
43+
return max_index + offset
44+
45+
846
def rotate_cov(rot_matrix, cov_in):
947
return rot_matrix @ cov_in @ rot_matrix.T
1048

selfdrive/locationd/lagd.py

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
#!/usr/bin/env python3
2+
import os
3+
import numpy as np
4+
import capnp
5+
from collections import deque
6+
from functools import partial
7+
8+
import cereal.messaging as messaging
9+
from cereal import car, log
10+
from cereal.services import SERVICE_LIST
11+
from openpilot.common.params import Params
12+
from openpilot.common.realtime import config_realtime_process
13+
from openpilot.common.swaglog import cloudlog
14+
from openpilot.selfdrive.locationd.helpers import PoseCalibrator, Pose, fft_next_good_size, parabolic_peak_interp
15+
16+
BLOCK_SIZE = 100
17+
BLOCK_NUM = 50
18+
BLOCK_NUM_NEEDED = 5
19+
MOVING_WINDOW_SEC = 300.0
20+
MIN_OKAY_WINDOW_SEC = 25.0
21+
MIN_RECOVERY_BUFFER_SEC = 2.0
22+
MIN_VEGO = 15.0
23+
MIN_ABS_YAW_RATE = np.radians(1.0)
24+
MIN_NCC = 0.95
25+
MAX_LAG = 1.0
26+
27+
28+
def masked_normalized_cross_correlation(expected_sig: np.ndarray, actual_sig: np.ndarray, mask: np.ndarray, n: int):
29+
"""
30+
References:
31+
D. Padfield. "Masked FFT registration". In Proc. Computer Vision and
32+
Pattern Recognition, pp. 2918-2925 (2010).
33+
:DOI:`10.1109/CVPR.2010.5540032`
34+
"""
35+
36+
eps = np.finfo(np.float64).eps
37+
expected_sig = np.asarray(expected_sig, dtype=np.float64)
38+
actual_sig = np.asarray(actual_sig, dtype=np.float64)
39+
40+
expected_sig[~mask] = 0.0
41+
actual_sig[~mask] = 0.0
42+
43+
rotated_expected_sig = expected_sig[::-1]
44+
rotated_mask = mask[::-1]
45+
46+
fft = partial(np.fft.fft, n=n)
47+
48+
actual_sig_fft = fft(actual_sig)
49+
rotated_expected_sig_fft = fft(rotated_expected_sig)
50+
actual_mask_fft = fft(mask.astype(np.float64))
51+
rotated_mask_fft = fft(rotated_mask.astype(np.float64))
52+
53+
number_overlap_masked_samples = np.fft.ifft(rotated_mask_fft * actual_mask_fft).real
54+
number_overlap_masked_samples[:] = np.round(number_overlap_masked_samples)
55+
number_overlap_masked_samples[:] = np.fmax(number_overlap_masked_samples, eps)
56+
masked_correlated_actual_fft = np.fft.ifft(rotated_mask_fft * actual_sig_fft).real
57+
masked_correlated_expected_fft = np.fft.ifft(actual_mask_fft * rotated_expected_sig_fft).real
58+
59+
numerator = np.fft.ifft(rotated_expected_sig_fft * actual_sig_fft).real
60+
numerator -= masked_correlated_actual_fft * masked_correlated_expected_fft / number_overlap_masked_samples
61+
62+
actual_squared_fft = fft(actual_sig ** 2)
63+
actual_sig_denom = np.fft.ifft(rotated_mask_fft * actual_squared_fft).real
64+
actual_sig_denom -= masked_correlated_actual_fft ** 2 / number_overlap_masked_samples
65+
actual_sig_denom[:] = np.fmax(actual_sig_denom, 0.0)
66+
67+
rotated_expected_squared_fft = fft(rotated_expected_sig ** 2)
68+
expected_sig_denom = np.fft.ifft(actual_mask_fft * rotated_expected_squared_fft).real
69+
expected_sig_denom -= masked_correlated_expected_fft ** 2 / number_overlap_masked_samples
70+
expected_sig_denom[:] = np.fmax(expected_sig_denom, 0.0)
71+
72+
denom = np.sqrt(actual_sig_denom * expected_sig_denom)
73+
74+
# zero-out samples with very small denominators
75+
tol = 1e3 * eps * np.max(np.abs(denom), keepdims=True)
76+
nonzero_indices = denom > tol
77+
78+
ncc = np.zeros_like(denom, dtype=np.float64)
79+
ncc[nonzero_indices] = numerator[nonzero_indices] / denom[nonzero_indices]
80+
np.clip(ncc, -1, 1, out=ncc)
81+
82+
return ncc
83+
84+
85+
class Points:
86+
def __init__(self, num_points: int):
87+
self.times = deque[float]([0.0] * num_points, maxlen=num_points)
88+
self.okay = deque[bool]([False] * num_points, maxlen=num_points)
89+
self.desired = deque[float]([0.0] * num_points, maxlen=num_points)
90+
self.actual = deque[float]([0.0] * num_points, maxlen=num_points)
91+
92+
@property
93+
def num_points(self):
94+
return len(self.desired)
95+
96+
@property
97+
def num_okay(self):
98+
return np.count_nonzero(self.okay)
99+
100+
def update(self, t: float, desired: float, actual: float, okay: bool):
101+
self.times.append(t)
102+
self.okay.append(okay)
103+
self.desired.append(desired)
104+
self.actual.append(actual)
105+
106+
def get(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
107+
return np.array(self.times), np.array(self.desired), np.array(self.actual), np.array(self.okay)
108+
109+
110+
class BlockAverage:
111+
def __init__(self, num_blocks: int, block_size: int, valid_blocks: int, initial_value: float):
112+
self.num_blocks = num_blocks
113+
self.block_size = block_size
114+
self.block_idx = valid_blocks % num_blocks
115+
self.idx = 0
116+
117+
self.values = np.tile(initial_value, (num_blocks, 1))
118+
self.valid_blocks = valid_blocks
119+
120+
def update(self, value: float):
121+
self.values[self.block_idx] = (self.idx * self.values[self.block_idx] + value) / (self.idx + 1)
122+
self.idx = (self.idx + 1) % self.block_size
123+
if self.idx == 0:
124+
self.block_idx = (self.block_idx + 1) % self.num_blocks
125+
self.valid_blocks = min(self.valid_blocks + 1, self.num_blocks)
126+
127+
def get(self) -> tuple[float, float]:
128+
valid_block_idx = [i for i in range(self.valid_blocks) if i != self.block_idx]
129+
valid_and_current_idx = valid_block_idx + ([self.block_idx] if self.idx > 0 else [])
130+
131+
valid_mean = float(np.mean(self.values[valid_block_idx], axis=0).item()) if len(valid_block_idx) > 0 else float('nan')
132+
current_mean = float(np.mean(self.values[valid_and_current_idx], axis=0).item()) if len(valid_and_current_idx) > 0 else float('nan')
133+
return valid_mean, current_mean
134+
135+
136+
class LateralLagEstimator:
137+
inputs = {"carControl", "carState", "controlsState", "liveCalibration", "livePose"}
138+
139+
def __init__(self, CP: car.CarParams, dt: float,
140+
block_count: int = BLOCK_NUM, min_valid_block_count: int = BLOCK_NUM_NEEDED, block_size: int = BLOCK_SIZE,
141+
window_sec: float = MOVING_WINDOW_SEC, okay_window_sec: float = MIN_OKAY_WINDOW_SEC, min_recovery_buffer_sec: float = MIN_RECOVERY_BUFFER_SEC,
142+
min_vego: float = MIN_VEGO, min_yr: float = MIN_ABS_YAW_RATE, min_ncc: float = MIN_NCC):
143+
self.dt = dt
144+
self.window_sec = window_sec
145+
self.okay_window_sec = okay_window_sec
146+
self.min_recovery_buffer_sec = min_recovery_buffer_sec
147+
self.initial_lag = CP.steerActuatorDelay + 0.2
148+
self.block_size = block_size
149+
self.block_count = block_count
150+
self.min_valid_block_count = min_valid_block_count
151+
self.min_vego = min_vego
152+
self.min_yr = min_yr
153+
self.min_ncc = min_ncc
154+
155+
self.t = 0.0
156+
self.lat_active = False
157+
self.steering_pressed = False
158+
self.steering_saturated = False
159+
self.desired_curvature = 0.0
160+
self.v_ego = 0.0
161+
self.yaw_rate = 0.0
162+
163+
self.last_lat_inactive_t = 0.0
164+
self.last_steering_pressed_t = 0.0
165+
self.last_steering_saturated_t = 0.0
166+
self.last_estimate_t = 0.0
167+
168+
self.calibrator = PoseCalibrator()
169+
170+
self.reset(self.initial_lag, 0)
171+
172+
def reset(self, initial_lag: float, valid_blocks: int):
173+
window_len = int(self.window_sec / self.dt)
174+
self.points = Points(window_len)
175+
self.block_avg = BlockAverage(self.block_count, self.block_size, valid_blocks, initial_lag)
176+
177+
def get_msg(self, valid: bool, debug: bool = False) -> capnp._DynamicStructBuilder:
178+
msg = messaging.new_message('liveDelay')
179+
180+
msg.valid = valid
181+
182+
liveDelay = msg.liveDelay
183+
184+
valid_mean_lag, current_mean_lag = self.block_avg.get()
185+
if self.block_avg.valid_blocks >= self.min_valid_block_count and not np.isnan(valid_mean_lag):
186+
liveDelay.status = log.LiveDelayData.Status.estimated
187+
liveDelay.lateralDelay = valid_mean_lag
188+
else:
189+
liveDelay.status = log.LiveDelayData.Status.unestimated
190+
liveDelay.lateralDelay = self.initial_lag
191+
if not np.isnan(current_mean_lag):
192+
liveDelay.lateralDelayEstimate = current_mean_lag
193+
else:
194+
liveDelay.lateralDelayEstimate = self.initial_lag
195+
liveDelay.validBlocks = self.block_avg.valid_blocks
196+
if debug:
197+
liveDelay.points = self.block_avg.values.flatten().tolist()
198+
199+
return msg
200+
201+
def handle_log(self, t: float, which: str, msg: capnp._DynamicStructReader):
202+
if which == "carControl":
203+
self.lat_active = msg.latActive
204+
elif which == "carState":
205+
self.steering_pressed = msg.steeringPressed
206+
self.v_ego = msg.vEgo
207+
elif which == "controlsState":
208+
self.steering_saturated = getattr(msg.lateralControlState, msg.lateralControlState.which()).saturated
209+
self.desired_curvature = msg.desiredCurvature
210+
elif which == "liveCalibration":
211+
self.calibrator.feed_live_calib(msg)
212+
elif which == "livePose":
213+
device_pose = Pose.from_live_pose(msg)
214+
calibrated_pose = self.calibrator.build_calibrated_pose(device_pose)
215+
self.yaw_rate = calibrated_pose.angular_velocity.z
216+
self.t = t
217+
218+
def points_enough(self):
219+
return self.points.num_points >= int(self.okay_window_sec / self.dt)
220+
221+
def points_valid(self):
222+
return self.points.num_okay >= int(self.okay_window_sec / self.dt)
223+
224+
def update_points(self):
225+
if not self.lat_active:
226+
self.last_lat_inactive_t = self.t
227+
if self.steering_pressed:
228+
self.last_steering_pressed_t = self.t
229+
if self.steering_saturated:
230+
self.last_steering_saturated_t = self.t
231+
232+
la_desired = self.desired_curvature * self.v_ego * self.v_ego
233+
la_actual_pose = self.yaw_rate * self.v_ego
234+
235+
fast = self.v_ego > self.min_vego
236+
turning = np.abs(self.yaw_rate) >= self.min_yr
237+
has_recovered = all( # wait for recovery after !lat_active, steering_pressed, steering_saturated
238+
self.t - last_t >= self.min_recovery_buffer_sec
239+
for last_t in [self.last_lat_inactive_t, self.last_steering_pressed_t, self.last_steering_saturated_t]
240+
)
241+
okay = self.lat_active and not self.steering_pressed and not self.steering_saturated and fast and turning and has_recovered
242+
243+
self.points.update(self.t, la_desired, la_actual_pose, okay)
244+
245+
def update_estimate(self):
246+
if not self.points_enough():
247+
return
248+
249+
times, desired, actual, okay = self.points.get()
250+
# check if there are any new valid data points since the last update
251+
is_valid = self.points_valid()
252+
if self.last_estimate_t != 0 and times[0] <= self.last_estimate_t:
253+
new_values_start_idx = next(-i for i, t in enumerate(reversed(times)) if t <= self.last_estimate_t)
254+
is_valid = is_valid and not (new_values_start_idx == 0 or not np.any(okay[new_values_start_idx:]))
255+
256+
delay, corr = self.actuator_delay(desired, actual, okay, self.dt, MAX_LAG)
257+
if corr < self.min_ncc or not is_valid:
258+
return
259+
260+
self.block_avg.update(delay)
261+
self.last_estimate_t = self.t
262+
263+
def actuator_delay(self, expected_sig: np.ndarray, actual_sig: np.ndarray, mask: np.ndarray, dt: float, max_lag: float) -> tuple[float, float]:
264+
assert len(expected_sig) == len(actual_sig)
265+
max_lag_samples = int(max_lag / dt)
266+
padded_size = fft_next_good_size(len(expected_sig) + max_lag_samples)
267+
268+
ncc = masked_normalized_cross_correlation(expected_sig, actual_sig, mask, padded_size)
269+
270+
# only consider lags from 0 to max_lag
271+
roi_ncc = ncc[len(expected_sig) - 1: len(expected_sig) - 1 + max_lag_samples]
272+
273+
max_corr_index = np.argmax(roi_ncc)
274+
corr = roi_ncc[max_corr_index]
275+
lag = parabolic_peak_interp(roi_ncc, max_corr_index) * dt
276+
277+
return lag, corr
278+
279+
280+
def retrieve_initial_lag(params_reader: Params, CP: car.CarParams):
281+
last_lag_data = params_reader.get("LiveDelay")
282+
last_carparams_data = params_reader.get("CarParamsPrevRoute")
283+
284+
if last_lag_data is not None:
285+
try:
286+
with log.Event.from_bytes(last_lag_data) as last_lag_msg, car.CarParams.from_bytes(last_carparams_data) as last_CP:
287+
ld = last_lag_msg.liveDelay
288+
if last_CP.carFingerprint != CP.carFingerprint:
289+
raise Exception("Car model mismatch")
290+
291+
lag, valid_blocks = ld.lateralDelayEstimate, ld.validBlocks
292+
assert valid_blocks <= BLOCK_NUM, "Invalid number of valid blocks"
293+
return lag, valid_blocks
294+
except Exception as e:
295+
cloudlog.error(f"Failed to retrieve initial lag: {e}")
296+
297+
return None
298+
299+
300+
def main():
301+
config_realtime_process([0, 1, 2, 3], 5)
302+
303+
DEBUG = bool(int(os.getenv("DEBUG", "0")))
304+
305+
pm = messaging.PubMaster(['liveDelay'])
306+
sm = messaging.SubMaster(['livePose', 'liveCalibration', 'carState', 'controlsState', 'carControl'], poll='livePose')
307+
308+
params_reader = Params()
309+
CP = messaging.log_from_bytes(params_reader.get("CarParams", block=True), car.CarParams)
310+
311+
lag_learner = LateralLagEstimator(CP, 1. / SERVICE_LIST['livePose'].frequency)
312+
if (initial_lag_params := retrieve_initial_lag(params_reader, CP)) is not None:
313+
lag, valid_blocks = initial_lag_params
314+
lag_learner.reset(lag, valid_blocks)
315+
316+
while True:
317+
sm.update()
318+
if sm.all_checks():
319+
for which in sorted(sm.updated.keys(), key=lambda x: sm.logMonoTime[x]):
320+
if sm.updated[which]:
321+
t = sm.logMonoTime[which] * 1e-9
322+
lag_learner.handle_log(t, which, sm[which])
323+
lag_learner.update_points()
324+
325+
# 4Hz driven by livePose
326+
if sm.frame % 5 == 0:
327+
lag_learner.update_estimate()
328+
lag_msg = lag_learner.get_msg(sm.all_checks(), DEBUG)
329+
lag_msg_dat = lag_msg.to_bytes()
330+
pm.send('liveDelay', lag_msg_dat)
331+
332+
if sm.frame % 1200 == 0: # cache every 60 seconds
333+
params_reader.put_nonblocking("LiveDelay", lag_msg_dat)

0 commit comments

Comments
 (0)