|
| 1 | +from pathlib import Path |
| 2 | +import argparse |
| 3 | +from trajnetplusplustools.reader import Reader |
| 4 | +from trajnetplusplustools import show |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +import mpl_toolkits.mplot3d.axes3d as p3 |
| 9 | +import matplotlib.animation as animation |
| 10 | + |
| 11 | + |
| 12 | +def add_gt_observation_to_prediction(gt_observation, model_prediction): |
| 13 | + obs_length = len(gt_observation[0]) - len(model_prediction[0]) |
| 14 | + full_predicted_paths = [gt_observation[ped_id][obs_length-3:obs_length] + pred for ped_id, pred in enumerate(model_prediction)] |
| 15 | + return full_predicted_paths |
| 16 | + |
| 17 | + |
| 18 | +def update_lines(num, dataLines, lines): |
| 19 | + for line, data in zip(lines, dataLines): |
| 20 | + # NOTE: there is no .set_data() for 3 dim data... |
| 21 | + line.set_data(data[0:2, :num]) |
| 22 | + line.set_3d_properties(data[2, :num]) |
| 23 | + return lines |
| 24 | + |
| 25 | +def animate_and_save_scene(scene, dataset_file, scene_id): |
| 26 | + |
| 27 | + # Attaching 3D axis to the figure |
| 28 | + fig = plt.figure() |
| 29 | + ax = p3.Axes3D(fig) |
| 30 | + |
| 31 | + z_zeros = np.zeros((scene.shape[0], scene.shape[1], 1)) |
| 32 | + data = np.concatenate([scene, z_zeros], axis=-1) |
| 33 | + data = data.transpose(0, 2, 1) |
| 34 | + |
| 35 | + # NOTE: Can't pass empty arrays into 3d version of plot() |
| 36 | + lines = [ax.plot(dat[0, 0:1], dat[1, 0:1], dat[2, 0:1])[0] for dat in data] |
| 37 | + |
| 38 | + # Setting the axes properties |
| 39 | + ax.set_xlim3d([-5.0, 5.0]) |
| 40 | + ax.set_xlabel('X') |
| 41 | + |
| 42 | + ax.set_ylim3d([-5.0, 5.0]) |
| 43 | + ax.set_ylabel('Y') |
| 44 | + |
| 45 | + ax.set_zlim3d([0.0, 0.3]) |
| 46 | + ax.set_zlabel('Z') |
| 47 | + |
| 48 | + ax.set_title('3D Test') |
| 49 | + |
| 50 | + # Creating the Animation object |
| 51 | + line_ani = animation.FuncAnimation(fig, update_lines, 25, fargs=(data, lines), |
| 52 | + interval=50, blit=False) |
| 53 | + line_ani.save(f'{dataset_file}/scene{scene_id}_animation.gif', writer='imagemagick', fps=5) |
| 54 | + plt.close() |
| 55 | + |
| 56 | +def main(): |
| 57 | + parser = argparse.ArgumentParser() |
| 58 | + parser.add_argument('dataset_files', nargs='+', |
| 59 | + help='Provide the ground-truth file followed by model prediction file') |
| 60 | + parser.add_argument('--viz_folder', default='./visualizations', |
| 61 | + help='base folder to store visualizations') |
| 62 | + parser.add_argument('--n', type=int, default=1, |
| 63 | + help='sample n trajectories') |
| 64 | + parser.add_argument('--id', type=int, nargs='*', |
| 65 | + help='plot a particular scene') |
| 66 | + parser.add_argument('--random', default=True, action='store_true', |
| 67 | + help='randomize scenes') |
| 68 | + args = parser.parse_args() |
| 69 | + |
| 70 | + assert len(args.dataset_files) > 2, "Please provide only one prediction file" |
| 71 | + # Determine and construct appropriate folders to save visualization |
| 72 | + dataset_name = args.dataset_files[1].split('/')[1] |
| 73 | + model_name = args.dataset_files[1].split('/')[-2] |
| 74 | + folder_name = f"{args.viz_folder}/{dataset_name}/{model_name}" |
| 75 | + Path(folder_name).mkdir(parents=True, exist_ok=True) |
| 76 | + |
| 77 | + |
| 78 | + ## Read Scenes |
| 79 | + reader = Reader(args.dataset_files[1], scene_type='paths') |
| 80 | + if args.id: |
| 81 | + scenes = reader.scenes(ids=args.id, randomize=args.random) |
| 82 | + elif args.n: |
| 83 | + scenes = reader.scenes(limit=args.n, randomize=args.random) |
| 84 | + else: |
| 85 | + scenes = reader.scenes(randomize=args.random) |
| 86 | + |
| 87 | + reader_gt = Reader(args.dataset_files[0], scene_type='paths') |
| 88 | + ## Visualize different scenes as GIF |
| 89 | + for scene_id, paths in scenes: |
| 90 | + print("Scene ID: ", scene_id) |
| 91 | + _, paths_gt = reader_gt.scene(scene_id) |
| 92 | + full_predicted_paths = add_gt_observation_to_prediction(paths_gt, paths) |
| 93 | + scene = Reader.paths_to_xy(full_predicted_paths) |
| 94 | + scene = scene.transpose(1, 0, 2) |
| 95 | + animate_and_save_scene(scene, folder_name, scene_id) |
| 96 | + |
| 97 | + |
| 98 | +if __name__ == '__main__': |
| 99 | + main() |
0 commit comments