
Commit 1d8cd18

fix fom error (#319)
* 1. Fix an error with 4-channel images in the FOM predictor. 2. Fix an error in FOM evaluation. 3. Fix the LapStyle VGG network.
1 parent 6094e44 commit 1d8cd18

File tree

5 files changed: +76 −62 lines


configs/firstorder_vox_256.yaml

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ log_config:
   visiual_interval: 10

 validate:
-  interval: 10
+  interval: 3000
   save_img: false

 snapshot_config:

ppgan/apps/first_order_predictor.py

Lines changed: 11 additions & 1 deletion
@@ -103,6 +103,16 @@ def __init__(self,
             self.cfg, self.weight_path)
         self.multi_person = multi_person

+    def read_img(self, path):
+        img = imageio.imread(path)
+        img = img.astype(np.float32)
+        if img.ndim == 2:
+            img = np.expand_dims(img, axis=2)
+        # some images have 4 channels
+        if img.shape[2] > 3:
+            img = img[:, :, :3]
+        return img
+
     def run(self, source_image, driving_video):
         def get_prediction(face_image):
             if self.find_best_frame or self.best_frame is not None:
@@ -138,7 +148,7 @@ def get_prediction(face_image):
                 adapt_movement_scale=self.adapt_scale)
             return predictions

-        source_image = imageio.imread(source_image)
+        source_image = self.read_img(source_image)
         reader = imageio.get_reader(driving_video)
         fps = reader.get_meta_data()['fps']
         driving_video = []
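
For reference, the new read_img helper only normalizes channel counts before the rest of the pipeline sees the frame. A minimal standalone sketch of that logic (normalize_channels and the test arrays below are illustrative, not part of the patch):

import numpy as np

def normalize_channels(img):
    # mirrors the channel handling added in read_img
    img = img.astype(np.float32)
    if img.ndim == 2:        # grayscale: add a trailing channel axis
        img = np.expand_dims(img, axis=2)
    if img.shape[2] > 3:     # RGBA or more: keep only the first three channels
        img = img[:, :, :3]
    return img

rgba = np.zeros((8, 8, 4), dtype=np.uint8)   # e.g. a PNG with an alpha channel
gray = np.zeros((8, 8), dtype=np.uint8)
print(normalize_channels(rgba).shape)        # (8, 8, 3)
print(normalize_channels(gray).shape)        # (8, 8, 1)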

ppgan/datasets/firstorder_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -251,7 +251,8 @@ def __getitem__(self, idx):
             out['driving'] = out['source']
             out['source'] = buf
         else:
-            video = np.stack(video_array, axis=0) / 255.0
+            video = np.stack(video_array, axis=0).astype(
+                np.float32) / 255.0
             out['video'] = video.transpose(3, 0, 1, 2)
         out['name'] = video_name
         return out
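
The cast above appears to be about dtype as much as anything: dividing a stacked uint8 array by 255.0 lets NumPy promote the result to float64, while the model expects float32 input. A small standalone sketch of that behavior (not PaddleGAN code):

import numpy as np

frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(4)]

implicit = np.stack(frames, axis=0) / 255.0                      # promoted to float64
explicit = np.stack(frames, axis=0).astype(np.float32) / 255.0   # stays float32

print(implicit.dtype, explicit.dtype)   # float64 float32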

ppgan/models/firstorder_model.py

Lines changed: 6 additions & 3 deletions
@@ -86,18 +86,20 @@ def setup_lr_schedulers(self, lr_cfg):
             "gen_lr": self.gen_lr,
             "dis_lr": self.dis_lr
         }
-
-    def setup_optimizers(self, lr_cfg, optimizer):
+
+    def setup_net_parallel(self):
         if isinstance(self.nets['Gen_Full'], paddle.DataParallel):
             self.nets['kp_detector'] = self.nets[
                 'Gen_Full']._layers.kp_extractor
             self.nets['generator'] = self.nets['Gen_Full']._layers.generator
             self.nets['discriminator'] = self.nets['Dis']._layers.discriminator
         else:
-
             self.nets['kp_detector'] = self.nets['Gen_Full'].kp_extractor
             self.nets['generator'] = self.nets['Gen_Full'].generator
             self.nets['discriminator'] = self.nets['Dis'].discriminator
+
+    def setup_optimizers(self, lr_cfg, optimizer):
+        self.setup_net_parallel()
         # init params
         init_weight(self.nets['kp_detector'])
         init_weight(self.nets['generator'])
@@ -163,6 +165,7 @@ def train_iter(self, optimizers=None):
         self.optimizers['optimizer_Dis'].step()

     def test_iter(self, metrics=None):
+        self.setup_net_parallel()
         self.nets['kp_detector'].eval()
         self.nets['generator'].eval()
         loss_list = []
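
The refactor factors the DataParallel unwrapping out of setup_optimizers into setup_net_parallel so that test_iter can call it too; this appears to be the FOM evaluation fix, since evaluation previously depended on setup_optimizers having resolved the wrapped sub-networks. A rough sketch of the unwrap pattern, assuming only what the diff shows (paddle.DataParallel exposing the wrapped module via ._layers); the unwrap helper itself is illustrative:

import paddle

def unwrap(net):
    # return the underlying Layer whether or not it is wrapped in DataParallel
    return net._layers if isinstance(net, paddle.DataParallel) else net

# hypothetical usage mirroring setup_net_parallel:
# gen_full = unwrap(self.nets['Gen_Full'])
# self.nets['kp_detector'] = gen_full.kp_extractor
# self.nets['generator'] = gen_full.generator
# self.nets['discriminator'] = unwrap(self.nets['Dis']).discriminator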

ppgan/models/generators/generater_lapstyle.py

Lines changed: 56 additions & 56 deletions
@@ -167,61 +167,6 @@ def forward(self, cF, sF):
         return out


-vgg = nn.Sequential(
-    nn.Conv2D(3, 3, (1, 1)),
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(3, 64, (3, 3)),
-    nn.ReLU(),  # relu1-1
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(64, 64, (3, 3)),
-    nn.ReLU(),  # relu1-2
-    nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(64, 128, (3, 3)),
-    nn.ReLU(),  # relu2-1
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(128, 128, (3, 3)),
-    nn.ReLU(),  # relu2-2
-    nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(128, 256, (3, 3)),
-    nn.ReLU(),  # relu3-1
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(256, 256, (3, 3)),
-    nn.ReLU(),  # relu3-2
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(256, 256, (3, 3)),
-    nn.ReLU(),  # relu3-3
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(256, 256, (3, 3)),
-    nn.ReLU(),  # relu3-4
-    nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(256, 512, (3, 3)),
-    nn.ReLU(),  # relu4-1, this is the last layer used
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(512, 512, (3, 3)),
-    nn.ReLU(),  # relu4-2
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(512, 512, (3, 3)),
-    nn.ReLU(),  # relu4-3
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(512, 512, (3, 3)),
-    nn.ReLU(),  # relu4-4
-    nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(512, 512, (3, 3)),
-    nn.ReLU(),  # relu5-1
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(512, 512, (3, 3)),
-    nn.ReLU(),  # relu5-2
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(512, 512, (3, 3)),
-    nn.ReLU(),  # relu5-3
-    nn.Pad2D([1, 1, 1, 1], mode='reflect'),
-    nn.Conv2D(512, 512, (3, 3)),
-    nn.ReLU()  # relu5-4
-)


 @GENERATORS.register()
@@ -233,7 +178,62 @@ class Encoder(nn.Layer):
     """
     def __init__(self):
         super(Encoder, self).__init__()
-        vgg_net = vgg
+        vgg_net = nn.Sequential(
+            nn.Conv2D(3, 3, (1, 1)),
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(3, 64, (3, 3)),
+            nn.ReLU(),  # relu1-1
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(64, 64, (3, 3)),
+            nn.ReLU(),  # relu1-2
+            nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(64, 128, (3, 3)),
+            nn.ReLU(),  # relu2-1
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(128, 128, (3, 3)),
+            nn.ReLU(),  # relu2-2
+            nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(128, 256, (3, 3)),
+            nn.ReLU(),  # relu3-1
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(256, 256, (3, 3)),
+            nn.ReLU(),  # relu3-2
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(256, 256, (3, 3)),
+            nn.ReLU(),  # relu3-3
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(256, 256, (3, 3)),
+            nn.ReLU(),  # relu3-4
+            nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(256, 512, (3, 3)),
+            nn.ReLU(),  # relu4-1, this is the last layer used
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(512, 512, (3, 3)),
+            nn.ReLU(),  # relu4-2
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(512, 512, (3, 3)),
+            nn.ReLU(),  # relu4-3
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(512, 512, (3, 3)),
+            nn.ReLU(),  # relu4-4
+            nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(512, 512, (3, 3)),
+            nn.ReLU(),  # relu5-1
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(512, 512, (3, 3)),
+            nn.ReLU(),  # relu5-2
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(512, 512, (3, 3)),
+            nn.ReLU(),  # relu5-3
+            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
+            nn.Conv2D(512, 512, (3, 3)),
+            nn.ReLU()  # relu5-4
+        )
+
         weight_path = get_path_from_url(
             'https://paddlegan.bj.bcebos.com/models/vgg_normalised.pdparams')
         vgg_net.set_dict(paddle.load(weight_path))
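
The LapStyle change moves the VGG definition from module scope into Encoder.__init__, so the backbone is only built when an Encoder is constructed and each instance owns its own layers rather than sharing one Sequential created at import time. A trimmed sketch of the pattern (EncoderSketch and the shortened layer list are illustrative, not the PaddleGAN class):

import paddle.nn as nn

class EncoderSketch(nn.Layer):
    def __init__(self):
        super().__init__()
        # built per instance; nothing is constructed at import time
        self.backbone = nn.Sequential(
            nn.Conv2D(3, 3, (1, 1)),
            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
            nn.Conv2D(3, 64, (3, 3)),
            nn.ReLU(),  # relu1-1
            # ... remaining VGG layers as in the diff ...
        )

    def forward(self, x):
        return self.backbone(x)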
