@@ -284,16 +284,16 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
284284 )
285285 align_att_img_batch = align_att_img_batch .to (self .device , non_blocking = True )
286286
287- n , c , h , w = align_att_img_batch .shape
288- img_white = torch .zeros ((n , 1 , h , w ), dtype = align_att_img_batch .dtype , device = self .device ) + 255.0
289-
290- inv_att_transforms : torch .Tensor = inverse_transform_batch (att_transforms )
291-
292287 # Get face masks for the attribute image
293288 face_mask , ignore_mask_ids = self .bise_net .get_mask (
294289 align_att_img_batch_for_parsing_model , self .crop_size
295290 )
296291
292+ n , c , h , w = align_att_img_batch .shape
293+ img_white = torch .zeros ((n , 1 , h , w ), dtype = align_att_img_batch .dtype , device = self .device ) + 255.0
294+
295+ inv_att_transforms : torch .Tensor = inverse_transform_batch (att_transforms )
296+
297297 soft_face_mask , _ = self .smooth_mask (face_mask )
298298
299299 # Only take face area from the swapped image
@@ -306,10 +306,6 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
306306
307307 att_image = self .to_tensor (att_image ).to (self .device , non_blocking = True )
308308
309- if torch .sum (ignore_mask_ids .int ()) > 0 :
310- img_white = img_white [ignore_mask_ids , ...]
311- inv_att_transforms = inv_att_transforms [ignore_mask_ids , ...]
312-
313309 # to avoid OOM apply erosion on low res masks
314310 img_white = F .pad (img_white , (self .erode_mask_value , self .erode_mask_value , self .erode_mask_value , self .erode_mask_value ))
315311
0 commit comments