Skip to content

Commit 8eb890a

Browse files
Visualization updated, GIF added
1 parent 45e6bdb commit 8eb890a

File tree

2 files changed

+129
-9
lines changed

2 files changed

+129
-9
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from pathlib import Path
2+
import argparse
3+
from trajnetplusplustools.reader import Reader
4+
from trajnetplusplustools import show
5+
import numpy as np
6+
7+
import matplotlib.pyplot as plt
8+
import mpl_toolkits.mplot3d.axes3d as p3
9+
import matplotlib.animation as animation
10+
11+
12+
def add_gt_observation_to_prediction(gt_observation, model_prediction):
13+
obs_length = len(gt_observation[0]) - len(model_prediction[0])
14+
full_predicted_paths = [gt_observation[ped_id][obs_length-3:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
15+
return full_predicted_paths
16+
17+
18+
def update_lines(num, dataLines, lines):
19+
for line, data in zip(lines, dataLines):
20+
# NOTE: there is no .set_data() for 3 dim data...
21+
line.set_data(data[0:2, :num])
22+
line.set_3d_properties(data[2, :num])
23+
return lines
24+
25+
def animate_and_save_scene(scene, dataset_file, scene_id):
26+
27+
# Attaching 3D axis to the figure
28+
fig = plt.figure()
29+
ax = p3.Axes3D(fig)
30+
31+
z_zeros = np.zeros((scene.shape[0], scene.shape[1], 1))
32+
data = np.concatenate([scene, z_zeros], axis=-1)
33+
data = data.transpose(0, 2, 1)
34+
35+
# NOTE: Can't pass empty arrays into 3d version of plot()
36+
lines = [ax.plot(dat[0, 0:1], dat[1, 0:1], dat[2, 0:1])[0] for dat in data]
37+
38+
# Setting the axes properties
39+
ax.set_xlim3d([-5.0, 5.0])
40+
ax.set_xlabel('X')
41+
42+
ax.set_ylim3d([-5.0, 5.0])
43+
ax.set_ylabel('Y')
44+
45+
ax.set_zlim3d([0.0, 0.3])
46+
ax.set_zlabel('Z')
47+
48+
ax.set_title('3D Test')
49+
50+
# Creating the Animation object
51+
line_ani = animation.FuncAnimation(fig, update_lines, 25, fargs=(data, lines),
52+
interval=50, blit=False)
53+
line_ani.save(f'{dataset_file}/scene{scene_id}_animation.gif', writer='imagemagick', fps=5)
54+
plt.close()
55+
56+
def main():
57+
parser = argparse.ArgumentParser()
58+
parser.add_argument('dataset_files', nargs='+',
59+
help='Provide the ground-truth file followed by model prediction file')
60+
parser.add_argument('--viz_folder', default='./visualizations',
61+
help='base folder to store visualizations')
62+
parser.add_argument('--n', type=int, default=1,
63+
help='sample n trajectories')
64+
parser.add_argument('--id', type=int, nargs='*',
65+
help='plot a particular scene')
66+
parser.add_argument('--random', default=True, action='store_true',
67+
help='randomize scenes')
68+
args = parser.parse_args()
69+
70+
assert len(args.dataset_files) > 2, "Please provide only one prediction file"
71+
# Determine and construct appropriate folders to save visualization
72+
dataset_name = args.dataset_files[1].split('/')[1]
73+
model_name = args.dataset_files[1].split('/')[-2]
74+
folder_name = f"{args.viz_folder}/{dataset_name}/{model_name}"
75+
Path(folder_name).mkdir(parents=True, exist_ok=True)
76+
77+
78+
## Read Scenes
79+
reader = Reader(args.dataset_files[1], scene_type='paths')
80+
if args.id:
81+
scenes = reader.scenes(ids=args.id, randomize=args.random)
82+
elif args.n:
83+
scenes = reader.scenes(limit=args.n, randomize=args.random)
84+
else:
85+
scenes = reader.scenes(randomize=args.random)
86+
87+
reader_gt = Reader(args.dataset_files[0], scene_type='paths')
88+
## Visualize different scenes as GIF
89+
for scene_id, paths in scenes:
90+
print("Scene ID: ", scene_id)
91+
_, paths_gt = reader_gt.scene(scene_id)
92+
full_predicted_paths = add_gt_observation_to_prediction(paths_gt, paths)
93+
scene = Reader.paths_to_xy(full_predicted_paths)
94+
scene = scene.transpose(1, 0, 2)
95+
animate_and_save_scene(scene, folder_name, scene_id)
96+
97+
98+
if __name__ == '__main__':
99+
main()

evaluator/visualize_predictions.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
1+
from pathlib import Path
12
import argparse
23
from trajnetplusplustools.reader import Reader
34
from trajnetplusplustools import show
45

56

7+
def add_gt_observation_to_prediction(gt_observation, model_prediction):
8+
obs_length = len(gt_observation[0]) - len(model_prediction[0])
9+
full_predicted_paths = [gt_observation[ped_id][:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
10+
return full_predicted_paths
11+
612
def main():
713
parser = argparse.ArgumentParser()
814
parser.add_argument('dataset_files', nargs='+',
9-
help='Trajnet dataset file(s).')
15+
help='Provide the ground-truth file followed by model prediction file(s)')
1016
parser.add_argument('--n', type=int, default=15,
1117
help='sample n trajectories')
1218
parser.add_argument('--id', type=int, nargs='*',
1319
help='plot a particular scene')
20+
parser.add_argument('--viz_folder', default='./visualizations',
21+
help='base folder to store visualizations')
1422
parser.add_argument('-o', '--output', default=None,
1523
help='specify output prefix')
1624
parser.add_argument('--random', default=True, action='store_true',
@@ -19,9 +27,12 @@ def main():
1927
help='labels of models')
2028
args = parser.parse_args()
2129

22-
## TODO Configure Writing images
23-
# if args.output is None:
24-
# args.output = args.dataset_file
30+
# Determine and construct appropriate folders to save visualization
31+
dataset_name = args.dataset_files[0].split('/')[1]
32+
model_name = args.dataset_files[1].split('/')[-2]
33+
folder_name = f"{args.viz_folder}/{dataset_name}/{model_name}"
34+
Path(folder_name).mkdir(parents=True, exist_ok=True)
35+
single_model = len(args.dataset_files) == 2
2536

2637
## Read GT Scenes
2738
reader = Reader(args.dataset_files[0], scene_type='paths')
@@ -53,13 +64,23 @@ def main():
5364
pred_paths[label_dict[name]] = predicted_paths[0]
5465
pred_neigh_paths[label_dict[name]] = predicted_paths[1:]
5566

56-
output_filename = None
57-
if args.output is not None:
58-
output_filename = '{}.scene{}.png'.format(args.output, scene_id)
67+
# Visualize prediction(s) overlayed on GT scene
68+
output_filename = f"{folder_name}/single_scene{scene_id}.png" if single_model else \
69+
f"{folder_name}/multiple_scene{scene_id}.png"
5970
with show.predicted_paths(paths, pred_paths, output_file=output_filename):
6071
pass
61-
# with show.predicted_paths(paths, pred_paths, pred_neigh_paths):
62-
# pass
72+
73+
# Used when visualizing only a single model
74+
if single_model:
75+
# Visualize GT scene
76+
gt_filename = f"{folder_name}/gt_scene{scene_id}.png"
77+
with show.paths(paths, output_file=gt_filename):
78+
pass
79+
# Visualize Model Prediction scene
80+
pred_filename = f"{folder_name}/pred_scene{scene_id}.png"
81+
full_predicted_paths = add_gt_observation_to_prediction(paths, predicted_paths)
82+
with show.paths(full_predicted_paths, output_file=pred_filename):
83+
pass
6384

6485
if __name__ == '__main__':
6586
main()

0 commit comments

Comments
 (0)