
Commit f295dda

Fix video manager and rectangle artifacts
1 parent 5f6e324 commit f295dda

9 files changed: +51 -13 lines changed

README.md

Lines changed: 2 additions & 0 deletions
@@ -103,6 +103,7 @@ Config files contain two main parts:
 - *att_image* - target image; attributes of the person in this image will be mixed with the person's identity from the source image. Here you can also specify a folder with multiple images - identity translation will be applied to all images in the folder.
 - *specific_id_image* - a specific person in the *att_image* you would like to replace, leaving others untouched (if there are any other people).
 - *att_video* - the same as *att_image*
+- *clean_work_dir* - whether or not to remove the temp folder with images (for video configs only).

 - **pipeline**
@@ -117,6 +118,7 @@ Config files contain two main parts:
 - *face_alignment_type* - affects reference face key point coordinates. **Possible values are "ffhq" and "none". Try both of them to see which one works better for your data.**
 - *erode_mask_value* - a non-zero value. It's used for the post-processing mask size attenuation. You might want to play with this parameter.
 - *smooth_mask_value* - an odd non-zero value. It's used for smoothing the edges of the post-processing mask. Usually set to *erode_mask_value* + 1.
+- *sigma_scale_value* - controls the amount of blur added to the post-processing mask. Valid values are in the range [0.01...1.0]. Tune it if you see artifacts (rectangles) around swapped faces.
 - *face_detector_threshold* - values in range [0.0...1.0]. A higher value reduces the probability of false-positive detections but increases the probability of false negatives.
 - *specific_latent_match_threshold* - values in range [0.0...inf]. Usually takes small values around 0.05.
 - *enhance_output* - whether or not to apply the GFPGAN model as a post-processing step.
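
For a feel of the numbers: the sigma that *sigma_scale_value* rescales is derived from *smooth_mask_value* via OpenCV's kernel-size rule of thumb (see the src/simswap.py diff below). A minimal sketch, assuming the default smooth_mask_value of 41:

def mask_blur_sigma(smooth_mask_value: int, sigma_scale_value: float) -> float:
    # OpenCV's kernel-size -> sigma rule, then rescaled by sigma_scale_value
    sigma = 0.3 * ((smooth_mask_value - 1) * 0.5 - 1) + 0.8
    return sigma * sigma_scale_value

print(mask_blur_sigma(41, 1.0))   # 6.5   - default, strongest blur
print(mask_blur_sigma(41, 0.25))  # 1.625 - much tighter blur if rectangles appear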

app.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def __init__(self, config: DictConfig):
         self.att_video: Optional[VideoDataManager] = None
         if att_video_path and att_video_path.is_file():
             self.att_video: Optional[VideoDataManager] = VideoDataManager(
-                src_data=att_video_path, output_dir=output_dir
+                src_data=att_video_path, output_dir=output_dir, clean_work_dir=config.data.clean_work_dir
             )

         assert not (self.att_video and self.att_image), "Only one attribute source can be used!"

app_web.py

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,8 @@ def get_np_image(file):
     label="smooth_mask_value", min_value=1, max_value=61, step=2, value=41
 )

+sigma_scale_value = st.slider(label="sigma_scale_value", min_value=0.01, max_value=1.0, step=0.01, value=1.0)
+
 specific_latent_match_threshold = st.slider(
     label="specific_latent_match_threshold",
     min_value=0.0,
@@ -75,6 +77,7 @@ def get_np_image(file):
 model.set_face_alignment_type(face_alignment_type)
 model.set_erode_mask_value(erode_mask_value)
 model.set_smooth_mask_value(smooth_mask_value)
+model.set_sigma_scale_value(sigma_scale_value)
 model.set_specific_latent_match_threshold(specific_latent_match_threshold)
 model.enhance_output = True if enhance_output == "yes" else False

@@ -126,6 +129,7 @@ def load_model(config):
         + " face_alignment_type"
         + " erode_mask_value"
         + " smooth_mask_value"
+        + " sigma_scale_value"
         + " face_detector_threshold"
         + " specific_latent_match_threshold"
         + " enhance_output",
@@ -144,6 +148,7 @@ def load_model(config):
         face_alignment_type="none",
         erode_mask_value=40,
         smooth_mask_value=41,
+        sigma_scale_value=1.0,
         face_detector_threshold=0.6,
         specific_latent_match_threshold=0.05,
         enhance_output=True
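
The slider range (0.01-1.0, step 0.01) matches the validation in SimSwap.set_sigma_scale_value. A self-contained sketch of the slider-to-setter flow, using a dummy model object rather than the real SimSwap instance:

import streamlit as st

class DummyModel:
    # Stand-in for the SimSwap model wired up in app_web.py.
    def set_sigma_scale_value(self, value: float) -> None:
        self.sigma_scale_value = value

model = DummyModel()
sigma_scale_value = st.slider(
    label="sigma_scale_value", min_value=0.01, max_value=1.0, step=0.01, value=1.0
)
model.set_sigma_scale_value(sigma_scale_value)
st.write(f"Post-processing mask blur will be scaled by {model.sigma_scale_value:.2f}")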

configs/run_image.yaml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ pipeline:
   face_alignment_type: "none" #"ffhq"
   erode_mask_value: 40
   smooth_mask_value: 41
+  sigma_scale_value: 1.0
   face_detector_threshold: 0.6
   specific_latent_match_threshold: 0.05
   enhance_output: True

configs/run_image_specific.yaml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ pipeline:
   face_alignment_type: "none" #"ffhq"
   erode_mask_value: 40
   smooth_mask_value: 41
+  sigma_scale_value: 1.0
   face_detector_threshold: 0.6
   specific_latent_match_threshold: 0.05
   enhance_output: True

configs/run_video.yaml

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@ data:
   specific_id_image: "none"
   att_video: "${hydra:runtime.cwd}/demo_file/multi_people_1080p.mp4"
   output_dir: ${hydra:runtime.cwd}/output
+  clean_work_dir: True

 pipeline:
   face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
@@ -18,6 +19,7 @@ pipeline:
   face_alignment_type: "none" #"ffhq"
   erode_mask_value: 40
   smooth_mask_value: 41
+  sigma_scale_value: 1.0
   face_detector_threshold: 0.6
   specific_latent_match_threshold: 0.05
   enhance_output: True
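
A quick sketch of inspecting the two new keys with OmegaConf outside of Hydra; the config path is assumed to be resolved from the repo root, and the ${hydra:...} interpolations stay untouched because only plain keys are read:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/run_video.yaml")  # path assumed relative to the repo root
print(cfg.data.clean_work_dir)         # True -> temp frame folder is removed after the video is written
print(cfg.pipeline.sigma_scale_value)  # 1.0  -> full-strength mask blur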

configs/run_video_specific.yaml

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@ data:
   specific_id_image: "${hydra:runtime.cwd}/demo_file/specific1.png"
   att_video: "${hydra:runtime.cwd}/demo_file/multi_people_1080p.mp4"
   output_dir: ${hydra:runtime.cwd}/output
+  clean_work_dir: True

 pipeline:
   face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
@@ -18,6 +19,7 @@ pipeline:
   face_alignment_type: "none" #"ffhq"
   erode_mask_value: 40
   smooth_mask_value: 41
+  sigma_scale_value: 1.0
   face_detector_threshold: 0.6
   specific_latent_match_threshold: 0.05
   enhance_output: True
Lines changed: 27 additions & 11 deletions
@@ -1,33 +1,38 @@
 from src.DataManager.base import BaseDataManager
 from src.DataManager.utils import imwrite_rgb

+import cv2
 import numpy as np
 from pathlib import Path
-from typing import Optional
+import shutil
+from typing import Optional, Union

 from moviepy.editor import AudioFileClip, VideoFileClip
 from moviepy.video.io.ImageSequenceClip import ImageSequenceClip


 class VideoDataManager(BaseDataManager):
-    def __init__(self, src_data: Path, output_dir: Path):
-        self.video_handle: Optional[VideoFileClip] = None
+    def __init__(self, src_data: Path, output_dir: Path, clean_work_dir: bool = False):
+        self.video_handle: Optional[cv2.VideoCapture] = None
         self.audio_handle: Optional[AudioFileClip] = None

         self.output_dir = output_dir
         self.output_img_dir = output_dir / "img"
         self.output_dir.mkdir(exist_ok=True)
         self.output_img_dir.mkdir(exist_ok=True)
         self.video_name = None
+        self.clean_work_dir = clean_work_dir

         if src_data.is_file():
             self.video_name = "swap_" + src_data.name

-            self.audio_handle = AudioFileClip(str(src_data))
-            self.video_handle = VideoFileClip(str(src_data))
-            self.fps = self.video_handle.reader.fps
-            self.frame_count = self.video_handle.reader.nframes
-            self.data_iterator = zip(range(self.frame_count), self.video_handle.iter_frames())
+            if VideoFileClip(str(src_data)).audio is not None:
+                self.audio_handle = AudioFileClip(str(src_data))
+
+            self.video_handle = cv2.VideoCapture(str(src_data))
+
+            self.frame_count = int(self.video_handle.get(cv2.CAP_PROP_FRAME_COUNT))
+            self.fps = self.video_handle.get(cv2.CAP_PROP_FPS)

         self.last_idx = -1

@@ -37,7 +42,14 @@ def __len__(self):
         return self.frame_count

     def get(self) -> np.ndarray:
-        self.last_idx, img = next(self.data_iterator)
+        img: Union[None, np.ndarray] = None
+
+        while img is None and self.last_idx < self.frame_count:
+            status, img = self.video_handle.read()
+            self.last_idx += 1
+
+        if img is not None:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         return img

     def save(self, img: np.ndarray):
@@ -51,6 +63,10 @@ def _close(self):
         image_filenames = [str(x) for x in sorted(self.output_img_dir.glob("*.jpg"))]
         clip = ImageSequenceClip(image_filenames, fps=self.fps)

-        clip = clip.set_audio(self.audio_handle)
+        if self.audio_handle is not None:
+            clip = clip.set_audio(self.audio_handle)
+
+        clip.write_videofile(str(self.output_dir / self.video_name))

-        clip.write_videofile(str(self.output_dir / self.video_name), audio_codec="aac")
+        if self.clean_work_dir:
+            shutil.rmtree(self.output_img_dir, ignore_errors=True)
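
Frame decoding now goes through cv2.VideoCapture instead of moviepy's frame iterator. Below is a stripped-down sketch of the same read loop as a standalone generator; the function name is hypothetical, while the property reads and the BGR-to-RGB conversion mirror the new __init__() and get():

import cv2
import numpy as np
from typing import Iterator

def iter_rgb_frames(video_path: str) -> Iterator[np.ndarray]:
    # Open the video and report the same metadata VideoDataManager now stores.
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    print(f"{frame_count} frames @ {fps:.2f} fps")
    try:
        while True:
            status, img = cap.read()
            if not status or img is None:
                break
            # OpenCV returns BGR; the pipeline works in RGB.
            yield cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()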

src/simswap.py

Lines changed: 10 additions & 1 deletion
@@ -35,6 +35,7 @@ def __init__(
         self.face_alignment_type: Union[FaceAlignmentType, None] = None
         self.erode_mask_value: Union[int, None] = None
         self.smooth_mask_value: Union[int, None] = None
+        self.sigma_scale_value: Union[float, None] = None
         self.face_detector_threshold: Union[float, None] = None
         self.specific_latent_match_threshold: Union[float, None] = None
         self.device = torch.device(config.device)
@@ -122,6 +123,7 @@ def set_parameters(self, config) -> None:
         self.set_specific_latent_match_threshold(config.specific_latent_match_threshold)
         self.set_erode_mask_value(config.erode_mask_value)
         self.set_smooth_mask_value(config.smooth_mask_value)
+        self.set_sigma_scale_value(config.sigma_scale_value)

     def set_crop_size(self, crop_size: int) -> None:
         if crop_size < 0:
@@ -174,6 +176,12 @@ def set_smooth_mask_value(self, smooth_mask_value: int) -> None:

         self.smooth_mask_value = smooth_mask_value

+    def set_sigma_scale_value(self, sigma_scale_value: float) -> None:
+        if sigma_scale_value < 0 or sigma_scale_value > 1.0:
+            raise ValueError("Invalid sigma_scale_value! Must be within the 0...1 range.")
+
+        self.sigma_scale_value = sigma_scale_value
+
     def run_detect_align(self, image: np.ndarray, for_id: bool = False) -> Tuple[Union[Iterable[np.ndarray], None],
                                                                                  Union[Iterable[np.ndarray], None],
                                                                                  np.ndarray]:
@@ -336,7 +344,8 @@ def __call__(self, att_image: np.ndarray) -> np.ndarray:
         kernel_size = (self.smooth_mask_value, self.smooth_mask_value)
         # https://docs.opencv.org/4.x/d4/d86/group__imgproc__filter.html#gaabe8c836e97159a9193fb0b11ac52cf1
         # https://docs.opencv.org/4.x/d4/d86/group__imgproc__filter.html#gac05a120c1ae92a6060dd0db190a61afa
-        sigma = 2 * 0.3 * ((kernel_size[0] - 1) * 0.5 - 1) + 0.8
+        sigma = 0.3 * ((kernel_size[0] - 1) * 0.5 - 1) + 0.8
+        sigma *= self.sigma_scale_value
         img_mask = kornia.filters.gaussian_blur2d(img_mask, kernel_size, (sigma, sigma), border_type='constant',
                                                   separable=True)
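
A minimal, self-contained sketch of the rescaled blur on a dummy mask tensor, assuming the same kernel size and constant border as the call above (the mask shape and values are illustrative only):

import torch
import kornia

smooth_mask_value = 41   # odd kernel size, as in the configs
sigma_scale_value = 0.5  # lower -> less blur, tighter mask edges
sigma = 0.3 * ((smooth_mask_value - 1) * 0.5 - 1) + 0.8
sigma *= sigma_scale_value

img_mask = torch.zeros(1, 1, 256, 256)
img_mask[:, :, 64:192, 64:192] = 1.0  # dummy square face mask
blurred = kornia.filters.gaussian_blur2d(
    img_mask, (smooth_mask_value, smooth_mask_value), (sigma, sigma),
    border_type="constant", separable=True,
)
print(blurred.shape)  # torch.Size([1, 1, 256, 256])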
