Skip to content

Commit c4def24

Browse files
committed
Removed incorrect handling of ignored faces
1 parent c01a7e0 commit c4def24

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

src/simswap.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -284,16 +284,16 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
284284
)
285285
align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)
286286

287-
n, c, h, w = align_att_img_batch.shape
288-
img_white = torch.zeros((n, 1, h, w), dtype=align_att_img_batch.dtype, device=self.device) + 255.0
289-
290-
inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)
291-
292287
# Get face masks for the attribute image
293288
face_mask, ignore_mask_ids = self.bise_net.get_mask(
294289
align_att_img_batch_for_parsing_model, self.crop_size
295290
)
296291

292+
n, c, h, w = align_att_img_batch.shape
293+
img_white = torch.zeros((n, 1, h, w), dtype=align_att_img_batch.dtype, device=self.device) + 255.0
294+
295+
inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)
296+
297297
soft_face_mask, _ = self.smooth_mask(face_mask)
298298

299299
# Only take face area from the swapped image
@@ -306,10 +306,6 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
306306

307307
att_image = self.to_tensor(att_image).to(self.device, non_blocking=True)
308308

309-
if torch.sum(ignore_mask_ids.int()) > 0:
310-
img_white = img_white[ignore_mask_ids, ...]
311-
inv_att_transforms = inv_att_transforms[ignore_mask_ids, ...]
312-
313309
# to avoid OOM apply erosion on low res masks
314310
img_white = F.pad(img_white, (self.erode_mask_value, self.erode_mask_value, self.erode_mask_value, self.erode_mask_value))
315311

0 commit comments

Comments
 (0)