22import numpy as np
33import torch
44import torch .nn .functional as F
5+ from enum import Enum
56from typing import Optional , Iterable , Tuple , Union
67from pathlib import Path
78from torchvision import transforms
1617from src .PostProcess .utils import SoftErosion
1718from src .Generator .fs_networks_fix import Generator_Adain_Upsample as Generator_Adain_Upsample_224
1819from src .Generator .fs_networks_512 import Generator_Adain_Upsample as Generator_Adain_Upsample_512
19-
20-
21- def tensor2img_denorm (tensor ):
22- std = torch .tensor ([0.229 , 0.224 , 0.225 ]).view (1 , 3 , 1 , 1 )
23- mean = torch .tensor ([0.485 , 0.456 , 0.406 ]).view (1 , 3 , 1 , 1 )
24- tensor = std * tensor .detach ().cpu () + mean
25- img = tensor .numpy ()
26- img = img .transpose (0 , 2 , 3 , 1 )[0 ]
27- img = np .clip (img * 255 , 0.0 , 255.0 ).astype (np .uint8 )
28- return img
29-
30-
31- def tensor2img (tensor ):
32- tensor = tensor .detach ().cpu ().numpy ()
33- img = tensor .transpose (0 , 2 , 3 , 1 )[0 ]
34- img = np .clip (img * 255 , 0.0 , 255.0 ).astype (np .uint8 )
35- return img
20+ from src .Misc .types import CheckpointType , FaceAlignmentType
21+ from src .Misc .utils import tensor2img , tensor2img_denorm
3622
3723
3824class SimSwap :
3925 def __init__ (self ,
4026 config : DictConfig ,
4127 id_image : np .ndarray ,
42- specific_image : Optional [np .ndarray ] = None ,
43- use_mask : bool = True ,
44- crop_size : int = 224 ,
45- device : str = 'cpu' ):
28+ specific_image : Optional [np .ndarray ] = None ):
4629
4730 self .id_image : np .ndarray = id_image
4831 self .id_latent = None
4932 self .specific_id_image : Optional [np .ndarray ] = specific_image
5033 self .specific_latent = None
5134
52- self .use_mask : bool = use_mask
53- self .crop_size : int = crop_size
54- self .checkpoint_type : str = config .checkpoint_type
55- self .face_alignment_type : str = config .face_alignment_type
56- self .erosion_kernel_size : int = config .erosion_kernel_size
35+ self .use_mask : bool = True
36+ self .crop_size : int = config .crop_size
37+ self .checkpoint_type : CheckpointType = CheckpointType (config .checkpoint_type )
38+ self .face_alignment_type : FaceAlignmentType = FaceAlignmentType (config .face_alignment_type )
39+ self .erode_mask_value : int = config .erode_mask_value
40+ self .smooth_mask_value : int = config .smooth_mask_value
5741 self .face_detector_threshold : float = config .face_detector_threshold
5842 self .specific_latent_match_th : float = config .specific_latent_match_threshold
59- self .device = torch .device (device )
43+ self .device = torch .device (config .device )
44+
45+ if self .crop_size < 0 :
46+ raise f'Invalid crop_size! Must be a positive value.'
47+
48+ if self .checkpoint_type not in (CheckpointType .OFFICIAL_224 , CheckpointType .UNOFFICIAL ):
49+ raise f'Invalid checkpoint_type! Must be one of the predefined values.'
50+
51+ if self .face_alignment_type not in (FaceAlignmentType .FFHQ , FaceAlignmentType .DEFAULT ):
52+ raise f'Invalid face_alignment_type! Must be one of the predefined values.'
53+
54+ self .use_erosion = True
55+ if self .erode_mask_value == 0 :
56+ self .use_erosion = False
57+
58+ if self .erode_mask_value < 0 :
59+ raise f'Invalid erode_mask_value! Must be a positive value.'
60+
61+ self .use_blur = True
62+ if self .smooth_mask_value == 0 :
63+ self .use_erosion = False
64+ elif self .smooth_mask_value > 0 :
65+ # Make sure it's odd
66+ self .smooth_mask_value += 1 if self .smooth_mask_value % 2 == 0 else 0
67+
68+ if self .smooth_mask_value < 0 :
69+ raise f"Invalid smooth_mask_value! Must be a positive value."
70+
71+ if self .face_detector_threshold < 0.0 or self .face_detector_threshold > 1.0 :
72+ raise f"Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]."
73+
74+ if self .specific_latent_match_th < 0.0 :
75+ raise f"Invalid specific_latent_match_th! Must be a positive value."
6076
6177 # For BiSeNet and for official_224 SimSwap
6278 self .to_tensor_normalize = transforms .Compose ([
@@ -69,7 +85,7 @@ def __init__(self,
6985
7086 self .face_detector = FaceDetector (
7187 Path (config .face_detector_weights ),
72- det_thresh = self .face_detector_threshold , det_size = (640 , 640 ), mode = "ffhq" , device = device )
88+ det_thresh = self .face_detector_threshold , det_size = (640 , 640 ), mode = "ffhq" , device = self . device . __str__ () )
7389
7490 self .face_id_net = FaceId (Path (config .face_id_weights )).to (self .device )
7591
@@ -81,7 +97,7 @@ def __init__(self,
8197
8298 self .simswap_net = Generator_Adain_Upsample_224 (input_nc = 3 , output_nc = 3 , latent_size = 512 , n_blocks = 9 ,
8399 deep = True if self .crop_size == 512 else False ,
84- checkpoint_type = self .checkpoint_type )
100+ use_last_act = True if self .checkpoint_type == CheckpointType . OFFICIAL_224 else False )
85101
86102 # if crop_size == 224:
87103 # self.simswap_net = Generator_Adain_Upsample_224(input_nc=3, output_nc=3, latent_size=512, n_blocks=9,
@@ -207,21 +223,20 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
207223 # Get np.ndarray with range [0...255]
208224 img_mask = tensor2img (img_mask / 255.0 )
209225
210- kernel = np .ones ((self .erosion_kernel_size , self .erosion_kernel_size ), dtype = np .uint8 )
211- img_mask = cv2 .erode (img_mask , kernel , iterations = 1 )
212-
213- delta = 1 if self .erosion_kernel_size % 2 == 0 else 0
214- kernel_size = (self .erosion_kernel_size + delta , self .erosion_kernel_size + delta )
226+ if self .use_erosion :
227+ kernel = np .ones ((self .erode_mask_value , self .erode_mask_value ), dtype = np .uint8 )
228+ img_mask = cv2 .erode (img_mask , kernel , iterations = 1 )
215229
216- img_mask = cv2 .GaussianBlur (img_mask , kernel_size , 0 )
230+ if self .use_blur :
231+ img_mask = cv2 .GaussianBlur (img_mask , (self .smooth_mask_value , self .smooth_mask_value ), 0 )
217232
218233 # Collect all swapped crops
219234 target_image = torch .sum (target_image , dim = 0 , keepdim = True )
220235 target_image = tensor2img (target_image )
221236
222- img_mask = img_mask // 255
237+ img_mask = np . clip ( img_mask / 255 , 0.0 , 1.0 )
223238
224- result = img_mask * target_image + (1 - img_mask ) * att_image
239+ result = ( img_mask * target_image + (1 - img_mask ) * att_image ). astype ( np . uint8 )
225240
226241 # # torch postprocessing
227242 # # faster but Erosion with 40x40 kernel requires too much memory and causes OOM.
0 commit comments