Skip to content

Commit e3ac35e

Browse files
committed
Fix transition articacts
* Clip final mask to 0...1 * Added parameter smooth_mask_value * Added parameter validation
1 parent f6dce04 commit e3ac35e

File tree

9 files changed

+100
-59
lines changed

9 files changed

+100
-59
lines changed

app.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,6 @@ def __init__(self, config: DictConfig):
2222
att_video_path = Path(config.data.att_video)
2323
output_dir = Path(config.data.output_dir)
2424

25-
device = config.pipeline.device
26-
27-
crop_size = config.pipeline.crop_size
28-
use_mask = True
29-
3025
assert id_image_path.exists(), f"Can't find {id_image_path} file!"
3126

3227
self.id_image: Optional[np.ndarray] = imread_rgb(id_image_path)
@@ -49,10 +44,7 @@ def __init__(self, config: DictConfig):
4944

5045
self.model = SimSwap(config=config.pipeline,
5146
id_image=self.id_image,
52-
specific_image=self.specific_id_image,
53-
use_mask=use_mask,
54-
crop_size=crop_size,
55-
device=device)
47+
specific_image=self.specific_id_image)
5648

5749
def run(self):
5850
for _ in tqdm(range(len(self.data_manager))):

configs/run_image.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ pipeline:
1313
device: "cuda"
1414
crop_size: 224
1515
# it seems that the official 224 checkpoint works better with 'none' face alignment type
16-
checkpoint_type: "official_224" #"none
16+
checkpoint_type: "official_224" #"none"
1717
face_alignment_type: "none" #"ffhq"
18-
erosion_kernel_size: 4
18+
erode_mask_value: 40
19+
smooth_mask_value: 41
1920
face_detector_threshold: 0.6
2021
specific_latent_match_threshold: 0.05
2122

configs/run_image_specific.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ pipeline:
1313
device: "cuda"
1414
crop_size: 224
1515
# it seems that the official 224 checkpoint works better with 'none' face alignment type
16-
checkpoint_type: "official_224" #"none
16+
checkpoint_type: "official_224" #"none"
1717
face_alignment_type: "none" #"ffhq"
18-
erosion_kernel_size: 4
18+
erode_mask_value: 40
19+
smooth_mask_value: 41
1920
face_detector_threshold: 0.6
2021
specific_latent_match_threshold: 0.05
2122

configs/run_video.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ pipeline:
1313
device: "cuda"
1414
crop_size: 224
1515
# it seems that the official 224 checkpoint works better with 'none' face alignment type
16-
checkpoint_type: "official_224" #"none
16+
checkpoint_type: "official_224" #"none"
1717
face_alignment_type: "none" #"ffhq"
18-
erosion_kernel_size: 4
18+
erode_mask_value: 40
19+
smooth_mask_value: 41
1920
face_detector_threshold: 0.6
2021
specific_latent_match_threshold: 0.05
2122

configs/run_video_specific.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ pipeline:
1313
device: "cuda"
1414
crop_size: 224
1515
# it seems that the official 224 checkpoint works better with 'none' face alignment type
16-
checkpoint_type: "official_224" #"none
16+
checkpoint_type: "official_224" #"none"
1717
face_alignment_type: "none" #"ffhq"
18-
erosion_kernel_size: 4
18+
erode_mask_value: 40
19+
smooth_mask_value: 41
1920
face_detector_threshold: 0.6
2021
specific_latent_match_threshold: 0.05
2122

src/Generator/fs_networks_fix.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(self, input_nc: int,
9494
latent_size: int,
9595
n_blocks: int=6,
9696
deep: bool = False,
97-
checkpoint_type: str = 'none',
97+
use_last_act: bool = True,
9898
norm_layer: torch.nn.Module = nn.BatchNorm2d,
9999
padding_type: str = 'reflect'):
100100
assert (n_blocks >= 0)
@@ -103,7 +103,7 @@ def __init__(self, input_nc: int,
103103
activation = nn.ReLU(True)
104104

105105
self.deep = deep
106-
self.checkpoint_type = checkpoint_type
106+
self.use_last_act = use_last_act
107107

108108
self.to_tensor_normalize = transforms.Compose([
109109
transforms.ToTensor(),
@@ -159,7 +159,7 @@ def __init__(self, input_nc: int,
159159
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
160160
nn.BatchNorm2d(64), activation
161161
)
162-
if self.checkpoint_type == "official_224":
162+
if self.use_last_act:
163163
self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0),
164164
torch.nn.Tanh())
165165
else:
@@ -173,7 +173,7 @@ def to(self, device):
173173
return self
174174

175175
def forward(self, x: Iterable[np.ndarray], dlatents: torch.Tensor):
176-
if self.checkpoint_type == "official_224":
176+
if self.use_last_act:
177177
x = [self.to_tensor(_) for _ in x]
178178
else:
179179
x = [self.to_tensor_normalize(_) for _ in x]
@@ -202,7 +202,7 @@ def forward(self, x: Iterable[np.ndarray], dlatents: torch.Tensor):
202202
x = self.up1(x)
203203
x = self.last_layer(x)
204204

205-
if self.checkpoint_type == "official_224":
205+
if self.use_last_act:
206206
x = (x + 1) / 2
207207
else:
208208
x = x * self.imagenet_std + self.imagenet_mean

src/Misc/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from enum import Enum
2+
3+
4+
class CheckpointType(Enum):
5+
OFFICIAL_224 = "official_224"
6+
UNOFFICIAL = "none"
7+
8+
9+
class FaceAlignmentType(Enum):
10+
FFHQ = "ffhq"
11+
DEFAULT = "none"

src/Misc/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
import numpy as np
3+
4+
5+
def tensor2img_denorm(tensor):
6+
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
7+
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
8+
tensor = std * tensor.detach().cpu() + mean
9+
img = tensor.numpy()
10+
img = img.transpose(0, 2, 3, 1)[0]
11+
img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
12+
return img
13+
14+
15+
def tensor2img(tensor):
16+
tensor = tensor.detach().cpu().numpy()
17+
img = tensor.transpose(0, 2, 3, 1)[0]
18+
img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
19+
return img

src/simswap.py

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import torch
44
import torch.nn.functional as F
5+
from enum import Enum
56
from typing import Optional, Iterable, Tuple, Union
67
from pathlib import Path
78
from torchvision import transforms
@@ -16,47 +17,62 @@
1617
from src.PostProcess.utils import SoftErosion
1718
from src.Generator.fs_networks_fix import Generator_Adain_Upsample as Generator_Adain_Upsample_224
1819
from src.Generator.fs_networks_512 import Generator_Adain_Upsample as Generator_Adain_Upsample_512
19-
20-
21-
def tensor2img_denorm(tensor):
22-
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
23-
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
24-
tensor = std * tensor.detach().cpu() + mean
25-
img = tensor.numpy()
26-
img = img.transpose(0, 2, 3, 1)[0]
27-
img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
28-
return img
29-
30-
31-
def tensor2img(tensor):
32-
tensor = tensor.detach().cpu().numpy()
33-
img = tensor.transpose(0, 2, 3, 1)[0]
34-
img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
35-
return img
20+
from src.Misc.types import CheckpointType, FaceAlignmentType
21+
from src.Misc.utils import tensor2img, tensor2img_denorm
3622

3723

3824
class SimSwap:
3925
def __init__(self,
4026
config: DictConfig,
4127
id_image: np.ndarray,
42-
specific_image: Optional[np.ndarray] = None,
43-
use_mask: bool = True,
44-
crop_size: int = 224,
45-
device: str = 'cpu'):
28+
specific_image: Optional[np.ndarray] = None):
4629

4730
self.id_image: np.ndarray = id_image
4831
self.id_latent = None
4932
self.specific_id_image: Optional[np.ndarray] = specific_image
5033
self.specific_latent = None
5134

52-
self.use_mask: bool = use_mask
53-
self.crop_size: int = crop_size
54-
self.checkpoint_type: str = config.checkpoint_type
55-
self.face_alignment_type: str = config.face_alignment_type
56-
self.erosion_kernel_size: int = config.erosion_kernel_size
35+
self.use_mask: bool = True
36+
self.crop_size: int = config.crop_size
37+
self.checkpoint_type: CheckpointType = CheckpointType(config.checkpoint_type)
38+
self.face_alignment_type: FaceAlignmentType = FaceAlignmentType(config.face_alignment_type)
39+
self.erode_mask_value: int = config.erode_mask_value
40+
self.smooth_mask_value: int = config.smooth_mask_value
5741
self.face_detector_threshold: float = config.face_detector_threshold
5842
self.specific_latent_match_th: float = config.specific_latent_match_threshold
59-
self.device = torch.device(device)
43+
self.device = torch.device(config.device)
44+
45+
if self.crop_size < 0:
46+
raise f'Invalid crop_size! Must be a positive value.'
47+
48+
if self.checkpoint_type not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL):
49+
raise f'Invalid checkpoint_type! Must be one of the predefined values.'
50+
51+
if self.face_alignment_type not in (FaceAlignmentType.FFHQ, FaceAlignmentType.DEFAULT):
52+
raise f'Invalid face_alignment_type! Must be one of the predefined values.'
53+
54+
self.use_erosion = True
55+
if self.erode_mask_value == 0:
56+
self.use_erosion = False
57+
58+
if self.erode_mask_value < 0:
59+
raise f'Invalid erode_mask_value! Must be a positive value.'
60+
61+
self.use_blur = True
62+
if self.smooth_mask_value == 0:
63+
self.use_erosion = False
64+
elif self.smooth_mask_value > 0:
65+
# Make sure it's odd
66+
self.smooth_mask_value += 1 if self.smooth_mask_value % 2 == 0 else 0
67+
68+
if self.smooth_mask_value < 0:
69+
raise f"Invalid smooth_mask_value! Must be a positive value."
70+
71+
if self.face_detector_threshold < 0.0 or self.face_detector_threshold > 1.0:
72+
raise f"Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]."
73+
74+
if self.specific_latent_match_th < 0.0:
75+
raise f"Invalid specific_latent_match_th! Must be a positive value."
6076

6177
# For BiSeNet and for official_224 SimSwap
6278
self.to_tensor_normalize = transforms.Compose([
@@ -69,7 +85,7 @@ def __init__(self,
6985

7086
self.face_detector = FaceDetector(
7187
Path(config.face_detector_weights),
72-
det_thresh=self.face_detector_threshold, det_size=(640, 640), mode="ffhq", device=device)
88+
det_thresh=self.face_detector_threshold, det_size=(640, 640), mode="ffhq", device=self.device.__str__())
7389

7490
self.face_id_net = FaceId(Path(config.face_id_weights)).to(self.device)
7591

@@ -81,7 +97,7 @@ def __init__(self,
8197

8298
self.simswap_net = Generator_Adain_Upsample_224(input_nc=3, output_nc=3, latent_size=512, n_blocks=9,
8399
deep=True if self.crop_size == 512 else False,
84-
checkpoint_type=self.checkpoint_type)
100+
use_last_act=True if self.checkpoint_type == CheckpointType.OFFICIAL_224 else False)
85101

86102
# if crop_size == 224:
87103
# self.simswap_net = Generator_Adain_Upsample_224(input_nc=3, output_nc=3, latent_size=512, n_blocks=9,
@@ -207,21 +223,20 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
207223
# Get np.ndarray with range [0...255]
208224
img_mask = tensor2img(img_mask / 255.0)
209225

210-
kernel = np.ones((self.erosion_kernel_size, self.erosion_kernel_size), dtype=np.uint8)
211-
img_mask = cv2.erode(img_mask, kernel, iterations=1)
212-
213-
delta = 1 if self.erosion_kernel_size % 2 == 0 else 0
214-
kernel_size = (self.erosion_kernel_size + delta, self.erosion_kernel_size + delta)
226+
if self.use_erosion:
227+
kernel = np.ones((self.erode_mask_value, self.erode_mask_value), dtype=np.uint8)
228+
img_mask = cv2.erode(img_mask, kernel, iterations=1)
215229

216-
img_mask = cv2.GaussianBlur(img_mask, kernel_size, 0)
230+
if self.use_blur:
231+
img_mask = cv2.GaussianBlur(img_mask, (self.smooth_mask_value, self.smooth_mask_value), 0)
217232

218233
# Collect all swapped crops
219234
target_image = torch.sum(target_image, dim=0, keepdim=True)
220235
target_image = tensor2img(target_image)
221236

222-
img_mask = img_mask // 255
237+
img_mask = np.clip(img_mask / 255, 0.0, 1.0)
223238

224-
result = img_mask * target_image + (1 - img_mask) * att_image
239+
result = (img_mask * target_image + (1 - img_mask) * att_image).astype(np.uint8)
225240

226241
# # torch postprocessing
227242
# # faster but Erosion with 40x40 kernel requires too much memory and causes OOM.

0 commit comments

Comments
 (0)