From b2791a8e07fad139f7eb2d46a160f14ac08456a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20R=C4=85czy?= Date: Wed, 2 Oct 2024 20:47:16 -0700 Subject: [PATCH] process_replay: in-place modification for message migration (#33695) * Inplace modification of lr * Replace the original function * Add comment * Change the return type * Fix carParams retrieval * Remove the newline * Include carState migration * Remove TODO * Comment * List instead of gen * Fix deletion * Delete camera state if not valid * Update ref commit * Remove sorting at the end * Use migrate_all in ui report * Allow more control in what to migrate * Add type annotations * Static analysis * Improve type annot * Fix linter issues * Remove f-string * Migrate carState too in test_ui * Fix peripheralState migration * Sort at the end * Fix regen issue * Fix comments --- selfdrive/test/process_replay/migration.py | 406 +++++++++++---------- selfdrive/test/process_replay/ref_commit | 2 +- selfdrive/ui/tests/test_ui/run.py | 4 +- 3 files changed, 217 insertions(+), 195 deletions(-) diff --git a/selfdrive/test/process_replay/migration.py b/selfdrive/test/process_replay/migration.py index c0696f2181b3d2..9241a3d4934f42 100644 --- a/selfdrive/test/process_replay/migration.py +++ b/selfdrive/test/process_replay/migration.py @@ -1,4 +1,7 @@ from collections import defaultdict +from collections.abc import Callable +import functools +import capnp from cereal import messaging, car from opendbc.car.fingerprints import MIGRATION @@ -7,70 +10,115 @@ from openpilot.selfdrive.modeld.fill_model_msg import fill_xyz_poly, fill_lane_line_meta from openpilot.selfdrive.test.process_replay.vision_meta import meta_from_encode_index from openpilot.system.manager.process_config import managed_processes +from openpilot.tools.lib.logreader import LogIterable from panda import Panda - -# TODO: message migration should happen in-place -def migrate_all(lr, manager_states=False, panda_states=False, camera_states=False): - msgs = migrate_sensorEvents(lr) - msgs = migrate_carParams(msgs) - msgs = migrate_gpsLocation(msgs) - msgs = migrate_deviceState(msgs) - msgs = migrate_carOutput(msgs) - msgs = migrate_controlsState(msgs) - msgs = migrate_liveLocationKalman(msgs) - msgs = migrate_liveTracks(msgs) - msgs = migrate_driverAssistance(msgs) - msgs = migrate_drivingModelData(msgs) +MessageWithIndex = tuple[int, capnp.lib.capnp._DynamicStructReader] +MigrationOps = tuple[list[tuple[int, capnp.lib.capnp._DynamicStructReader]], list[capnp.lib.capnp._DynamicStructReader], list[int]] +MigrationFunc = Callable[[list[MessageWithIndex]], MigrationOps] + + +## rules for migration functions +## 1. must use the decorator @migration(inputs=[...], product="...") and MigrationFunc signature +## 2. it only gets the messages that are in the inputs list +## 3. product is the message type created by the migration function, and the function will be skipped if product type already exists in lr +## 4. it must return a list of operations to be applied to the logreader (replace, add, delete) +## 5. all migration functions must be independent of each other +def migrate_all(lr: LogIterable, manager_states: bool = False, panda_states: bool = False, camera_states: bool = False): + migrations = [ + migrate_sensorEvents, + migrate_carParams, + migrate_gpsLocation, + migrate_deviceState, + migrate_carOutput, + migrate_controlsState, + migrate_carState, + migrate_liveLocationKalman, + migrate_liveTracks, + migrate_driverAssistance, + migrate_drivingModelData, + ] if manager_states: - msgs = migrate_managerState(msgs) + migrations.append(migrate_managerState) if panda_states: - msgs = migrate_pandaStates(msgs) - msgs = migrate_peripheralState(msgs) + migrations.extend([migrate_pandaStates, migrate_peripheralState]) if camera_states: - msgs = migrate_cameraStates(msgs) - - return msgs - - -def migrate_driverAssistance(lr): - all_msgs = [] - for msg in lr: - all_msgs.append(msg) - if msg.which() == 'longitudinalPlan': - all_msgs.append(messaging.new_message('driverAssistance', valid=True, logMonoTime=msg.logMonoTime).as_reader()) - if msg.which() == 'driverAssistance': - return lr - return all_msgs - - -def migrate_drivingModelData(lr): - all_msgs = [] - for msg in lr: - all_msgs.append(msg) - if msg.which() == "modelV2": - dmd = messaging.new_message('drivingModelData', valid=msg.valid, logMonoTime=msg.logMonoTime) - for field in ["frameId", "frameIdExtra", "frameDropPerc", "modelExecutionTime", "action"]: - setattr(dmd.drivingModelData, field, getattr(msg.modelV2, field)) - for meta_field in ["laneChangeState", "laneChangeState"]: - setattr(dmd.drivingModelData.meta, meta_field, getattr(msg.modelV2.meta, meta_field)) - if len(msg.modelV2.laneLines) and len(msg.modelV2.laneLineProbs): - fill_lane_line_meta(dmd.drivingModelData.laneLineMeta, msg.modelV2.laneLines, msg.modelV2.laneLineProbs) - if all(len(a) for a in [msg.modelV2.position.x, msg.modelV2.position.y, msg.modelV2.position.z]): - fill_xyz_poly(dmd.drivingModelData.path, ModelConstants.POLY_PATH_DEGREE, msg.modelV2.position.x, msg.modelV2.position.y, msg.modelV2.position.z) - all_msgs.append(dmd.as_reader()) - elif msg.which() == "drivingModelData": - return lr - return all_msgs - - -def migrate_liveTracks(lr): - all_msgs = [] - for msg in lr: - if msg.which() != "liveTracksDEPRECATED": - all_msgs.append(msg) + migrations.append(migrate_cameraStates) + + return migrate(lr, migrations) + + +def migrate(lr: LogIterable, migration_funcs: list[MigrationFunc]): + lr = list(lr) + grouped = defaultdict(list) + for i, msg in enumerate(lr): + grouped[msg.which()].append(i) + + replace_ops, add_ops, del_ops = [], [], [] + for migration in migration_funcs: + assert hasattr(migration, "inputs") and hasattr(migration, "product"), "Migration functions must use @migration decorator" + if migration.product in grouped: # skip if product already exists continue + sorted_indices = sorted(ii for i in migration.inputs for ii in grouped[i]) + msg_gen = [(i, lr[i]) for i in sorted_indices] + r_ops, a_ops, d_ops = migration(msg_gen) + replace_ops.extend(r_ops) + add_ops.extend(a_ops) + del_ops.extend(d_ops) + + for index, msg in replace_ops: + lr[index] = msg + for index in sorted(del_ops, reverse=True): + del lr[index] + for msg in add_ops: + lr.append(msg) + lr = sorted(lr, key=lambda x: x.logMonoTime) + + return lr + + +def migration(inputs: list[str], product: str|None=None): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + wrapper.inputs = inputs + wrapper.product = product + return wrapper + return decorator + + +@migration(inputs=["longitudinalPlan"], product="driverAssistance") +def migrate_driverAssistance(msgs): + add_ops = [] + for _, msg in msgs: + new_msg = messaging.new_message('driverAssistance', valid=True, logMonoTime=msg.logMonoTime) + add_ops.append(new_msg.as_reader()) + return [], add_ops, [] + + +@migration(inputs=["modelV2"], product="drivingModelData") +def migrate_drivingModelData(msgs): + add_ops = [] + for _, msg in msgs: + dmd = messaging.new_message('drivingModelData', valid=msg.valid, logMonoTime=msg.logMonoTime) + for field in ["frameId", "frameIdExtra", "frameDropPerc", "modelExecutionTime", "action"]: + setattr(dmd.drivingModelData, field, getattr(msg.modelV2, field)) + for meta_field in ["laneChangeState", "laneChangeState"]: + setattr(dmd.drivingModelData.meta, meta_field, getattr(msg.modelV2.meta, meta_field)) + if len(msg.modelV2.laneLines) and len(msg.modelV2.laneLineProbs): + fill_lane_line_meta(dmd.drivingModelData.laneLineMeta, msg.modelV2.laneLines, msg.modelV2.laneLineProbs) + if all(len(a) for a in [msg.modelV2.position.x, msg.modelV2.position.y, msg.modelV2.position.z]): + fill_xyz_poly(dmd.drivingModelData.path, ModelConstants.POLY_PATH_DEGREE, msg.modelV2.position.x, msg.modelV2.position.y, msg.modelV2.position.z) + add_ops.append( dmd.as_reader()) + return [], add_ops, [] + + +@migration(inputs=["liveTracksDEPRECATED"], product="liveTracks") +def migrate_liveTracks(msgs): + ops = [] + for index, msg in msgs: new_msg = messaging.new_message('liveTracks') new_msg.valid = msg.valid new_msg.logMonoTime = msg.logMonoTime @@ -88,22 +136,14 @@ def migrate_liveTracks(lr): pts.append(pt) new_msg.liveTracks.points = pts - all_msgs.append(new_msg.as_reader()) + ops.append((index, new_msg.as_reader())) + return ops, [], [] - return all_msgs - - -def migrate_liveLocationKalman(lr): - # migration needed only for routes before livePose - if any(msg.which() == 'livePose' for msg in lr): - return lr - - all_msgs = [] - for msg in lr: - if msg.which() != 'liveLocationKalmanDEPRECATED': - all_msgs.append(msg) - continue +@migration(inputs=["liveLocationKalmanDEPRECATED"], product="livePose") +def migrate_liveLocationKalman(msgs): + ops = [] + for index, msg in msgs: m = messaging.new_message('livePose') m.valid = msg.valid m.logMonoTime = msg.logMonoTime @@ -114,102 +154,93 @@ def migrate_liveLocationKalman(lr): lp_field.valid = llk_field.valid for flag in ["inputsOK", "posenetOK", "sensorsOK"]: setattr(m.livePose, flag, getattr(msg.liveLocationKalmanDEPRECATED, flag)) + ops.append((index, m.as_reader())) + return ops, [], [] - all_msgs.append(m.as_reader()) - return all_msgs - - -def migrate_controlsState(lr): - ret = [] +@migration(inputs=["controlsState"], product="selfdriveState") +def migrate_controlsState(msgs): + add_ops = [] + for _, msg in msgs: + m = messaging.new_message('selfdriveState') + m.valid = msg.valid + m.logMonoTime = msg.logMonoTime + ss = m.selfdriveState + for field in ("enabled", "active", "state", "engageable", "alertText1", "alertText2", + "alertStatus", "alertSize", "alertType", "experimentalMode", + "personality"): + setattr(ss, field, getattr(msg.controlsState, field+"DEPRECATED")) + add_ops.append(m.as_reader()) + return [], add_ops, [] + + +@migration(inputs=["carState", "controlsState"]) +def migrate_carState(msgs): + ops = [] last_cs = None - for msg in lr: + for index, msg in msgs: if msg.which() == 'controlsState': last_cs = msg - - m = messaging.new_message('selfdriveState') - m.valid = msg.valid - m.logMonoTime = msg.logMonoTime - ss = m.selfdriveState - for field in ("enabled", "active", "state", "engageable", "alertText1", "alertText2", - "alertStatus", "alertSize", "alertType", "experimentalMode", - "personality"): - setattr(ss, field, getattr(msg.controlsState, field+"DEPRECATED")) - ret.append(m.as_reader()) elif msg.which() == 'carState' and last_cs is not None: if last_cs.controlsState.vCruiseDEPRECATED - msg.carState.vCruise > 0.1: msg = msg.as_builder() msg.carState.vCruise = last_cs.controlsState.vCruiseDEPRECATED msg.carState.vCruiseCluster = last_cs.controlsState.vCruiseClusterDEPRECATED - msg = msg.as_reader() + ops.append((index, msg.as_reader())) + return ops, [], [] - ret.append(msg) - return ret - - -def migrate_managerState(lr): - all_msgs = [] - for msg in lr: - if msg.which() != "managerState": - all_msgs.append(msg) - continue +@migration(inputs=["managerState"]) +def migrate_managerState(msgs): + ops = [] + for index, msg in msgs: new_msg = msg.as_builder() new_msg.managerState.processes = [{'name': name, 'running': True} for name in managed_processes] - all_msgs.append(new_msg.as_reader()) - - return all_msgs + ops.append((index, new_msg.as_reader())) + return ops, [], [] -def migrate_gpsLocation(lr): - all_msgs = [] - for msg in lr: - if msg.which() in ('gpsLocation', 'gpsLocationExternal'): - new_msg = msg.as_builder() - g = getattr(new_msg, new_msg.which()) - # hasFix is a newer field - if not g.hasFix and g.flags == 1: - g.hasFix = True - all_msgs.append(new_msg.as_reader()) - else: - all_msgs.append(msg) - return all_msgs +@migration(inputs=["gpsLocation", "gpsLocationExternal"]) +def migrate_gpsLocation(msgs): + ops = [] + for index, msg in msgs: + new_msg = msg.as_builder() + g = getattr(new_msg, new_msg.which()) + # hasFix is a newer field + if not g.hasFix and g.flags == 1: + g.hasFix = True + ops.append((index, new_msg.as_reader())) + return ops, [], [] -def migrate_deviceState(lr): - all_msgs = [] +@migration(inputs=["deviceState", "initData"]) +def migrate_deviceState(msgs): + ops = [] dt = None - for msg in lr: + for i, msg in msgs: if msg.which() == 'initData': dt = msg.initData.deviceType if msg.which() == 'deviceState': n = msg.as_builder() n.deviceState.deviceType = dt - all_msgs.append(n.as_reader()) - else: - all_msgs.append(msg) - return all_msgs + ops.append((i, n.as_reader())) + return ops, [], [] -def migrate_carOutput(lr): - # migration needed only for routes before carOutput - if any(msg.which() == 'carOutput' for msg in lr): - return lr +@migration(inputs=["carControl"], product="carOutput") +def migrate_carOutput(msgs): + add_ops = [] + for _, msg in msgs: + co = messaging.new_message('carOutput') + co.valid = msg.valid + co.logMonoTime = msg.logMonoTime + co.carOutput.actuatorsOutput = msg.carControl.actuatorsOutputDEPRECATED + add_ops.append(co.as_reader()) + return [], add_ops, [] - all_msgs = [] - for msg in lr: - if msg.which() == 'carControl': - co = messaging.new_message('carOutput') - co.valid = msg.valid - co.logMonoTime = msg.logMonoTime - co.carOutput.actuatorsOutput = msg.carControl.actuatorsOutputDEPRECATED - all_msgs.append(co.as_reader()) - all_msgs.append(msg) - return all_msgs - -def migrate_pandaStates(lr): - all_msgs = [] +@migration(inputs=["pandaStates", "pandaStateDEPRECATED", "carParams"]) +def migrate_pandaStates(msgs): # TODO: safety param migration should be handled automatically safety_param_migration = { "TOYOTA_PRIUS": EPS_SCALE["TOYOTA_PRIUS"] | Panda.FLAG_TOYOTA_STOCK_LONGITUDINAL, @@ -217,11 +248,12 @@ def migrate_pandaStates(lr): "KIA_EV6": Panda.FLAG_HYUNDAI_EV_GAS | Panda.FLAG_HYUNDAI_CANFD_HDA2, } - # Migrate safety param base on carState - CP = next((m.carParams for m in lr if m.which() == 'carParams'), None) + # Migrate safety param base on carParams + CP = next((m.carParams for _, m in msgs if m.which() == 'carParams'), None) assert CP is not None, "carParams message not found" - if CP.carFingerprint in safety_param_migration: - safety_param = safety_param_migration[CP.carFingerprint] + fingerprint = MIGRATION.get(CP.carFingerprint, CP.carFingerprint) + if fingerprint in safety_param_migration: + safety_param = safety_param_migration[fingerprint] elif len(CP.safetyConfigs): safety_param = CP.safetyConfigs[0].safetyParam if CP.safetyConfigs[0].safetyParamDEPRECATED != 0: @@ -229,49 +261,45 @@ def migrate_pandaStates(lr): else: safety_param = CP.safetyParamDEPRECATED - for msg in lr: + ops = [] + for index, msg in msgs: if msg.which() == 'pandaStateDEPRECATED': new_msg = messaging.new_message('pandaStates', 1) new_msg.valid = msg.valid new_msg.logMonoTime = msg.logMonoTime new_msg.pandaStates[0] = msg.pandaStateDEPRECATED new_msg.pandaStates[0].safetyParam = safety_param - all_msgs.append(new_msg.as_reader()) + ops.append((index, new_msg.as_reader())) elif msg.which() == 'pandaStates': new_msg = msg.as_builder() new_msg.pandaStates[-1].safetyParam = safety_param - all_msgs.append(new_msg.as_reader()) - else: - all_msgs.append(msg) + ops.append((index, new_msg.as_reader())) + return ops, [], [] - return all_msgs +@migration(inputs=["pandaStates", "pandaStateDEPRECATED"], product="peripheralState") +def migrate_peripheralState(msgs): + add_ops = [] -def migrate_peripheralState(lr): - if any(msg.which() == "peripheralState" for msg in lr): - return lr - - all_msg = [] - for msg in lr: - all_msg.append(msg) - if msg.which() not in ["pandaStates", "pandaStateDEPRECATED"]: + which = "pandaStates" if any(msg.which() == "pandaStates" for _, msg in msgs) else "pandaStateDEPRECATED" + for _, msg in msgs: + if msg.which() != which: continue - new_msg = messaging.new_message("peripheralState") new_msg.valid = msg.valid new_msg.logMonoTime = msg.logMonoTime - all_msg.append(new_msg.as_reader()) + add_ops.append(new_msg.as_reader()) + return [], add_ops, [] - return all_msg - -def migrate_cameraStates(lr): - all_msgs = [] +@migration(inputs=["roadEncodeIdx", "wideRoadEncodeIdx", "driverEncodeIdx", "roadCameraState", "wideRoadCameraState", "driverCameraState"]) +def migrate_cameraStates(msgs): + add_ops, del_ops = [], [] frame_to_encode_id = defaultdict(dict) # just for encodeId fallback mechanism min_frame_id = defaultdict(lambda: float('inf')) - for msg in lr: + for _, msg in msgs: if msg.which() not in ["roadEncodeIdx", "wideRoadEncodeIdx", "driverEncodeIdx"]: continue @@ -281,9 +309,8 @@ def migrate_cameraStates(lr): assert encode_index.segmentId < 1200, f"Encoder index segmentId greater that 1200: {msg.which()} {encode_index.segmentId}" frame_to_encode_id[meta.camera_state][encode_index.frameId] = encode_index.segmentId - for msg in lr: + for index, msg in msgs: if msg.which() not in ["roadCameraState", "wideRoadCameraState", "driverCameraState"]: - all_msgs.append(msg) continue camera_state = getattr(msg, msg.which()) @@ -293,6 +320,7 @@ def migrate_cameraStates(lr): if encode_id is None: print(f"Missing encoded frame for camera feed {msg.which()} with frameId: {camera_state.frameId}") if len(frame_to_encode_id[msg.which()]) != 0: + del_ops.append(index) continue # fallback mechanism for logs without encodeIdx (e.g. logs from before 2022 with dcamera recording disabled) @@ -313,33 +341,27 @@ def migrate_cameraStates(lr): new_msg.logMonoTime = msg.logMonoTime new_msg.valid = msg.valid - all_msgs.append(new_msg.as_reader()) - - return all_msgs + del_ops.append(index) + add_ops.append(new_msg.as_reader()) + return [], add_ops, del_ops -def migrate_carParams(lr): - all_msgs = [] - for msg in lr: - if msg.which() == 'carParams': - CP = msg.as_builder() - CP.carParams.carFingerprint = MIGRATION.get(CP.carParams.carFingerprint, CP.carParams.carFingerprint) - for car_fw in CP.carParams.carFw: - car_fw.brand = CP.carParams.carName - CP.logMonoTime = msg.logMonoTime - msg = CP.as_reader() - all_msgs.append(msg) +@migration(inputs=["carParams"]) +def migrate_carParams(msgs): + ops = [] + for index, msg in msgs: + CP = msg.as_builder() + CP.carParams.carFingerprint = MIGRATION.get(CP.carParams.carFingerprint, CP.carParams.carFingerprint) + for car_fw in CP.carParams.carFw: + car_fw.brand = CP.carParams.carName + ops.append((index, CP.as_reader())) + return ops, [], [] - return all_msgs - - -def migrate_sensorEvents(lr): - all_msgs = [] - for msg in lr: - if msg.which() != 'sensorEventsDEPRECATED': - all_msgs.append(msg) - continue +@migration(inputs=["sensorEventsDEPRECATED"], product="sensorEvents") +def migrate_sensorEvents(msgs): + add_ops, del_ops = [], [] + for index, msg in msgs: # migrate to split sensor events for evt in msg.sensorEventsDEPRECATED: # build new message for each sensor type @@ -367,6 +389,6 @@ def migrate_sensorEvents(lr): m_dat.timestamp = evt.timestamp setattr(m_dat, evt.which(), getattr(evt, evt.which())) - all_msgs.append(m.as_reader()) - - return all_msgs + add_ops.append(m.as_reader()) + del_ops.append(index) + return [], add_ops, del_ops diff --git a/selfdrive/test/process_replay/ref_commit b/selfdrive/test/process_replay/ref_commit index 30daa6e2652641..32609b0a8f6f7d 100644 --- a/selfdrive/test/process_replay/ref_commit +++ b/selfdrive/test/process_replay/ref_commit @@ -1 +1 @@ -1b4480f5ebf5a4003dde22f19bbcf989294bc724 \ No newline at end of file +f9f1fd736c6bbef3aa2d3aea8e4f8e1c892234de \ No newline at end of file diff --git a/selfdrive/ui/tests/test_ui/run.py b/selfdrive/ui/tests/test_ui/run.py index 8c9db5a3f3b948..a918a573649d4b 100644 --- a/selfdrive/ui/tests/test_ui/run.py +++ b/selfdrive/ui/tests/test_ui/run.py @@ -19,7 +19,7 @@ from openpilot.common.transformations.camera import CameraConfig, DEVICE_CAMERAS from openpilot.selfdrive.selfdrived.alertmanager import set_offroad_alert from openpilot.selfdrive.test.helpers import with_processes -from openpilot.selfdrive.test.process_replay.migration import migrate_controlsState +from openpilot.selfdrive.test.process_replay.migration import migrate, migrate_controlsState, migrate_carState from openpilot.tools.lib.logreader import LogReader from openpilot.tools.lib.framereader import FrameReader from openpilot.tools.lib.route import Route @@ -263,7 +263,7 @@ def create_screenshots(): segnum = 2 lr = LogReader(route.qlog_paths()[segnum]) DATA['carParams'] = next((event.as_builder() for event in lr if event.which() == 'carParams'), None) - for event in migrate_controlsState(lr): + for event in migrate(lr, [migrate_controlsState, migrate_carState]): if event.which() in DATA: DATA[event.which()] = event.as_builder()