Skip to content

Commit 4de84dd

Browse files
committed
Bug fix
1 parent 2d35396 commit 4de84dd

File tree

7 files changed

+123
-77
lines changed

7 files changed

+123
-77
lines changed

DataLoader.py

Lines changed: 11 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
import numpy as np
99
from PIL import Image, ImageDraw
1010
import cv2
11-
import matplotlib.pyplot as plt
1211
import time
1312
import utils
14-
import matplotlib.pyplot as plt
1513

1614

1715
class myJAAD(torch.utils.data.Dataset):
1816
def __init__(self, args):
17+
print('Loading', args.dtype, 'data ...')
1918

2019
if(args.from_file):
2120
sequence_centric = pd.read_csv(args.file)
@@ -30,16 +29,12 @@ def __init__(self, args):
3029

3130
else:
3231
#read data
32+
print('Reading data files ...')
3333
df = pd.DataFrame()
3434
new_index=0
3535
for file in glob.glob(os.path.join(args.jaad_dataset,args.dtype,"*")):
3636
temp = pd.read_csv(file)
3737
if not temp.empty:
38-
#drop unnecessary columns
39-
temp = temp.drop(columns=['type', 'occlusion', 'nod', 'slow_down', 'speed_up', 'WALKING', 'walking',
40-
'standing', 'looking', 'handwave', 'clear_path', 'CLEAR_PATH','STANDING',
41-
'standing_pred', 'looking_pred', 'walking_pred','keypoints', 'crossing_pred'])
42-
4338
temp['file'] = [file for t in range(temp.shape[0])]
4439

4540
#assign unique ID to each
@@ -51,8 +46,8 @@ def __init__(self, args):
5146
temp = temp.sort_values(['ID', 'frame'], axis=0)
5247

5348
df = df.append(temp, ignore_index=True)
54-
print('reading files complete')
5549

50+
print('Processing data ...')
5651
#create sequence column
5752
df.insert(0, 'sequence', df.ID)
5853

@@ -119,18 +114,12 @@ def __init__(self, args):
119114

120115
sequence_centric = data.copy()
121116

122-
if args.sample:
123-
if args.trainOrVal == 'train':
124-
self.data = sequence_centric.loc[:args.n_train_sequences].copy().reset_index(drop=True)
125-
elif args.trainOrVal == 'val':
126-
self.data = sequence_centric.loc[args.n_train_sequences:].copy().reset_index(drop=True)
127-
128-
else:
129-
self.data = sequence_centric.copy().reset_index(drop=True)
117+
118+
self.data = sequence_centric.copy().reset_index(drop=True)
130119

131120
self.args = args
132121
self.dtype = args.dtype
133-
print(self.dtype, " set loaded")
122+
print(args.dtype, "set loaded")
134123
print('*'*30)
135124

136125

@@ -195,32 +184,9 @@ def scene_transforms(self, scene):
195184

196185

197186
def data_loader(args):
198-
if args.dtype == 'train':
199-
train_set = myJAAD(args)
200-
train_loader = torch.utils.data.DataLoader(
201-
train_set, batch_size=args.batch_size, shuffle=args.loader_shuffle,
202-
pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
203-
204-
args.trainOrVal = 'val'
205-
206-
val_set = myJAAD(args)
207-
val_loader = torch.utils.data.DataLoader(
208-
val_set, batch_size=args.batch_size, shuffle=args.loader_shuffle,
209-
pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
210-
211-
return train_loader, val_loader
212-
213-
elif args.dtype == 'val':
214-
215-
#rgs.file = args.val_file
216-
#rgs.dtype = 'val'
217-
#rgs.trainOrVal = 'test'
218-
#rgs.sample = False
219-
220-
test_set = myJAAD(args)
221-
222-
test_loader = torch.utils.data.DataLoader(
223-
test_set, batch_size=args.batch_size, shuffle=args.loader_shuffle,
224-
pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
187+
dataset = myJAAD(args)
188+
dataloader = torch.utils.data.DataLoader(
189+
dataset, batch_size=args.batch_size, shuffle=args.loader_shuffle,
190+
pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
225191

226-
return test_loader
192+
return dataloader

__pycache__/DataLoader.cpython-36.pyc

6.52 KB
Binary file not shown.

__pycache__/network.cpython-36.pyc

2.26 KB
Binary file not shown.

__pycache__/utils.cpython-36.pyc

3.8 KB
Binary file not shown.

prepare_data.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import os
2+
import sys
3+
import argparse
4+
import numpy as np
5+
import pandas as pd
6+
7+
parser = argparse.ArgumentParser()
8+
parser.add_argument('jaad_path', type=str, help='Path to zhr cloned JAAD repository')
9+
parser.add_argument('train_ratio', type=float, help='Ratio of train video')
10+
parser.add_argument('val_ratio', type=float, help='Ratio of val video')
11+
parser.add_argument('test_ratio', type=float, help='Ratio of test video')
12+
13+
args = parser.parse_args()
14+
15+
data_path = args.jaad_path
16+
sys.path.insert(1, data_path+'/')
17+
18+
import jaad_data
19+
20+
if not os.path.isdir(os.path.join(data_path, 'processed_annotations')):
21+
os.mkdir(os.path.join(data_path, 'processed_annotations'))
22+
23+
if not os.path.isdir(os.path.join(data_path, 'processed_annotations', 'train')):
24+
os.mkdir(os.path.join(data_path, 'processed_annotations', 'train'))
25+
26+
if not os.path.isdir(os.path.join(data_path, 'processed_annotations', 'val')):
27+
os.mkdir(os.path.join(data_path, 'processed_annotations', 'val'))
28+
29+
if not os.path.isdir(os.path.join(data_path, 'processed_annotations', 'test')):
30+
os.mkdir(os.path.join(data_path, 'processed_annotations', 'test'))
31+
32+
jaad = jaad_data.JAAD(data_path=data_path)
33+
dataset = jaad.generate_database()
34+
35+
n_train_video = int(args.train_ratio * 346)
36+
n_val_video = int(args.val_ratio * 346)
37+
n_test_video = int(args.test_ratio * 346)
38+
39+
videos = list(dataset.keys())
40+
train_videos = videos[:n_train_video]
41+
val_videos = videos[n_train_video:n_train_video+n_val_video]
42+
test_videos = videos[n_train_video+n_val_video:]
43+
44+
45+
for video in dataset:
46+
print('Processing', video, '...')
47+
vid = dataset[video]
48+
data = np.empty((0,8))
49+
for ped in vid['ped_annotations']:
50+
if vid['ped_annotations'][ped]['behavior']:
51+
frames = np.array(vid['ped_annotations'][ped]['frames']).reshape(-1,1)
52+
ids = np.repeat(vid['ped_annotations'][ped]['old_id'], frames.shape[0]).reshape(-1,1)
53+
bbox = np.array(vid['ped_annotations'][ped]['bbox'])
54+
x = bbox[:,0].reshape(-1,1)
55+
y = bbox[:,1].reshape(-1,1)
56+
w = np.abs(bbox[:,0] - bbox[:,2]).reshape(-1,1)
57+
h = np.abs(bbox[:,1] - bbox[:,3]).reshape(-1,1)
58+
scenefolderpath = np.repeat(os.path.join(data_path, 'scene', video.replace('video_', '')), frames.shape[0]).reshape(-1,1)
59+
60+
cross = np.array(vid['ped_annotations'][ped]['behavior']['cross']).reshape(-1,1)
61+
62+
ped_data = np.hstack((frames, ids, x, y, w, h, scenefolderpath, cross))
63+
data = np.vstack((data, ped_data))
64+
data_to_write = pd.DataFrame({'frame': data[:,0].reshape(-1),
65+
'ID': data[:,1].reshape(-1),
66+
'x': data[:,2].reshape(-1),
67+
'y': data[:,3].reshape(-1),
68+
'w': data[:,4].reshape(-1),
69+
'h': data[:,5].reshape(-1),
70+
'scenefolderpath': data[:,6].reshape(-1),
71+
'crossing_true': data[:,7].reshape(-1)})
72+
data_to_write['filename'] = data_to_write.frame
73+
data_to_write.filename = data_to_write.filename.apply(lambda x: '%04d'%int(x)+'.png')
74+
75+
if video in train_videos:
76+
data_to_write.to_csv(os.path.join(data_path, 'processed_annotations', 'train', video+'.csv'), index=False)
77+
elif video in val_videos:
78+
data_to_write.to_csv(os.path.join(data_path, 'processed_annotations', 'val', video+'.csv'), index=False)
79+
elif video in test_videos:
80+
data_to_write.to_csv(os.path.join(data_path, 'processed_annotations', 'test', video+'.csv'), index=False)
81+
82+
83+
84+
85+
86+
87+
88+
89+
90+
91+
92+

test.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torchvision
88
import torchvision.transforms as transforms
99

10-
import matplotlib.pyplot as plt
1110
import numpy as np
1211
from sklearn.metrics import recall_score, accuracy_score, average_precision_score, precision_score
1312

@@ -18,12 +17,13 @@
1817

1918
class args():
2019
def __init__(self):
21-
self.jaad_dataset = '/data/haziq-data/jaad/annotations' #folder containing parsed jaad annotations (used when first time loading data)
22-
self.dtype = 'val'
20+
self.jaad_dataset = '/data/smailait-data/JAAD/processed_annotations' #folder containing parsed jaad annotations (used when first time loading data)
21+
self.dtype = 'test'
2322
self.from_file = False #read dataset from csv file or reprocess data
24-
self.file = '/data/smail-data/jaad_val_16_16.csv'
25-
self.save_path = '/data/smail-data/jaad_val_16_16.csv'
26-
self.model_path = '/data/smail-data/multitask_pv_lstm_trained.pkl'
23+
self.save = True
24+
self.file = '/data/smailait-data/jaad_test_16_16.csv'
25+
self.save_path = '/data/smailait-data/jaad_test_16_16.csv'
26+
self.model_path = '/data/smailait-data/models/multitask_pv_lstm_trained.pkl'
2727
self.loader_workers = 10
2828
self.loader_shuffle = True
2929
self.pin_memory = False
@@ -33,12 +33,9 @@ def __init__(self):
3333
self.n_epochs = 100
3434
self.hidden_size = 512
3535
self.hardtanh_limit = 100
36-
self.sample = False
37-
self.n_train_sequences = 40000
38-
self.trainOrVal = 'test'
39-
self.citywalks = False
4036
self.input = 16
4137
self.output = 16
38+
self.stride = 16
4239
self.skip = 1
4340
self.task = 'bounding_box-intention'
4441
self.use_scenes = False

train.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
import torchvision
88
import torchvision.transforms as transforms
9-
10-
import matplotlib.pyplot as plt
9+
1110
import numpy as np
1211
from sklearn.metrics import recall_score, accuracy_score, average_precision_score, precision_score
1312

@@ -17,12 +16,13 @@
1716

1817
class args():
1918
def __init__(self):
20-
self.jaad_dataset = '../../../../data/haziq-data/jaad/annotations' #folder containing parsed jaad annotations (used when first time loading data)
19+
self.jaad_dataset = '/data/smailait-data/JAAD/processed_annotations' #folder containing parsed jaad annotations (used when first time loading data)
2120
self.dtype = 'train'
2221
self.from_file = False #read dataset from csv file or reprocess data
23-
self.file = '/data/smail-data/jaad_train_16_16.csv'
24-
self.save_path = '/data/smail-data/jaad_train_16_16.csv'
25-
self.model_path = '/data/smail-data/multitask_pv_lstm_trained.pkl'
22+
self.save = True
23+
self.file = '/data/smailait-data/jaad_train_16_16.csv'
24+
self.save_path = '/data/smailait-data/jaad_train_16_16.csv'
25+
self.model_path = '/data/smailait-data/models/multitask_pv_lstm_trained.pkl'
2626
self.loader_workers = 10
2727
self.loader_shuffle = True
2828
self.pin_memory = False
@@ -32,12 +32,9 @@ def __init__(self):
3232
self.n_epochs = 100
3333
self.hidden_size = 512
3434
self.hardtanh_limit = 100
35-
self.sample = False
36-
self.n_train_sequences = 40000
37-
self.trainOrVal = 'train'
38-
self.citywalks = False
3935
self.input = 16
4036
self.output = 16
37+
self.stride = 16
4138
self.skip = 1
4239
self.task = 'bounding_box-intention'
4340
self.use_scenes = False
@@ -46,7 +43,11 @@ def __init__(self):
4643
args = args()
4744

4845
net = network.PV_LSTM(args).to(args.device)
49-
train, val = DataLoader.data_loader(args)
46+
train = DataLoader.data_loader(args)
47+
args.dtype = 'val'
48+
args.save_path = args.save_path.replace('train', 'val')
49+
args.file = args.file.replace('train', 'val')
50+
val = DataLoader.data_loader(args)
5051

5152
optimizer = optim.Adam(net.parameters(), lr=args.lr)
5253
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=15,
@@ -178,17 +179,7 @@ def __init__(self):
178179
'| fde: %.4f'% fde, '| aiou: %.4f'% aiou, '| fiou: %.4f'% fiou, '| state_acc: %.4f'% avg_acc,
179180
'| rec: %.4f'% avg_rec, '| pre: %.4f'% avg_pre, '| intention_acc: %.4f'% intent_acc,
180181
'| t:%.4f'%(time.time()-start))
181-
182-
print('='*100)
183-
plt.figure(figsize=(10,8))
184-
plt.plot(list(range(len(train_s_scores))), train_s_scores, label = 'BB Training loss')
185-
plt.plot(list(range(len(val_s_scores))), val_s_scores, label = 'BB Validation loss')
186-
plt.plot(list(range(len(train_c_scores))), train_c_scores, label = 'Intention Training loss')
187-
plt.plot(list(range(len(val_c_scores))), val_c_scores, label = 'Intention Validation loss')
188-
plt.xlabel('epoch')
189-
plt.ylabel('Mean square error loss')
190-
plt.legend()
191-
plt.show()
182+
192183
print('='*100)
193184
print('Saving ...')
194185
torch.save(net.state_dict(), args.model_path)

0 commit comments

Comments
 (0)