
Commit 53bc81a

feat: data preprocessing code of hallo (#103)
* feat: data preprocessing code of hallo
* add data preprocessing
* add utils functions of data preprocessing
* add image processor and audio processor of data preprocessing
* fix: train config and data processing param adjustment
* add model weight postprocess after stage1
* make data processing param easier to understand
1 parent 0152cd9 commit 53bc81a

File tree

8 files changed, 860 additions and 17 deletions


configs/train/stage2.yaml

Lines changed: 2 additions & 2 deletions

@@ -98,7 +98,7 @@ start_ratio: 0.05
 noise_offset: 0.05
 snr_gamma: 5.0
 enable_zero_snr: True
-stage1_ckpt_dir: "./pretrained_models/hallo/stage1"
+stage1_ckpt_dir: "./exp_output/stage1/"

 single_inference_times: 10
 inference_steps: 40
@@ -107,7 +107,7 @@ cfg_scale: 3.5
 seed: 42
 resume_from_checkpoint: "latest"
 checkpointing_steps: 500
-exp_name: "stage2_test"
+exp_name: "stage2"
 output_dir: "./exp_output"

 ref_img_path:
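
The two configuration changes point stage2 training at the stage1 experiment output (./exp_output/stage1/) rather than the released pretrained weights, matching the "model weight postprocess after stage1" step in the commit message. A minimal sketch of how a training script could pick up this setting, assuming the YAML is loaded with OmegaConf; the per-module .pth layout is an illustrative assumption, not a documented format:

import os

from omegaconf import OmegaConf

# load the stage2 training config and read the (now experiment-local) stage1 dir
cfg = OmegaConf.load("configs/train/stage2.yaml")
stage1_dir = cfg.stage1_ckpt_dir  # "./exp_output/stage1/" after this commit

# enumerate whichever stage1 weight files were postprocessed into that directory
for name in sorted(os.listdir(stage1_dir)):
    if name.endswith(".pth"):  # assumed extension for the exported module weights
        print("stage1 weight file:", os.path.join(stage1_dir, name))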

hallo/datasets/audio_processor.py

Lines changed: 3 additions & 2 deletions

@@ -73,7 +73,7 @@ def __init__(
         self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)


-    def preprocess(self, wav_file: str, clip_length: int):
+    def preprocess(self, wav_file: str, clip_length: int = -1):
         """
         Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
         The separated vocal track is then converted into wav2vec2 features for further processing or analysis.
@@ -109,7 +109,8 @@ def preprocess(self, wav_file: str, clip_length: int):
         audio_length = seq_len

         audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
-        if seq_len % clip_length != 0:
+
+        if clip_length > 0 and seq_len % clip_length != 0:
             audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
             seq_len += clip_length - seq_len % clip_length
         audio_feature = audio_feature.unsqueeze(0)
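
With the new default clip_length=-1, the padding branch is skipped entirely, which suits data preprocessing where the full feature sequence is wanted; training code can still pass a positive clip length to right-pad features to the next multiple. A self-contained sketch of that padding rule, with sample_rate=16000, fps=25, and the tensor shape assumed for illustration:

import torch

sample_rate, fps, clip_length = 16000, 25, 16  # assumed illustrative values
samples_per_frame = sample_rate // fps         # 640 audio samples per video frame
seq_len = 70                                   # feature frames extracted from the clip
audio_feature = torch.zeros(seq_len * samples_per_frame)

# same guard as the diff: clip_length <= 0 (the new default) skips padding
if clip_length > 0 and seq_len % clip_length != 0:
    pad = (clip_length - seq_len % clip_length) * samples_per_frame
    audio_feature = torch.nn.functional.pad(audio_feature, (0, pad), 'constant', 0.0)
    seq_len += clip_length - seq_len % clip_length

assert seq_len == 80 and audio_feature.numel() == 80 * samples_per_frame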

hallo/datasets/image_processor.py

Lines changed: 138 additions & 1 deletion

@@ -1,3 +1,4 @@
+# pylint: disable=W0718
 """
 This module is responsible for processing images, particularly for face-related tasks.
 It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
@@ -8,13 +9,15 @@
 from typing import List

 import cv2
+import mediapipe as mp
 import numpy as np
 import torch
 from insightface.app import FaceAnalysis
 from PIL import Image
 from torchvision import transforms

-from ..utils.util import get_mask
+from ..utils.util import (blur_mask, get_landmark_overframes, get_mask,
+                          get_union_face_mask, get_union_lip_mask)

 MEAN = 0.5
 STD = 0.5
@@ -207,3 +210,137 @@ def __enter__(self):

     def __exit__(self, _exc_type, _exc_val, _exc_tb):
         self.close()
+
+
+class ImageProcessorForDataProcessing():
+    """
+    ImageProcessorForDataProcessing is a class responsible for processing images, particularly for face-related tasks.
+    It takes in an image and performs various operations such as augmentation, face detection,
+    face embedding extraction, and rendering a face mask. The processed images are then used for
+    further analysis or recognition purposes.
+
+    Attributes:
+        img_size (int): The size of the image to be processed.
+        face_analysis_model_path (str): The path to the face analysis model.
+
+    Methods:
+        preprocess(source_image_path):
+            Preprocesses the input image by performing augmentation, face detection,
+            face embedding extraction, and rendering a face mask.
+
+        close():
+            Closes the ImageProcessor and releases any resources being used.
+
+        _augmentation(images, transform, state=None):
+            Applies image augmentation to the input images using the given transform and state.
+
+        __enter__():
+            Enters a runtime context and returns the ImageProcessor object.
+
+        __exit__(_exc_type, _exc_val, _exc_tb):
+            Exits a runtime context and handles any exceptions that occurred during the processing.
+    """
+    def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None:
+        if step == 2:
+            self.face_analysis = FaceAnalysis(
+                name="",
+                root=face_analysis_model_path,
+                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+            )
+            self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
+            self.landmarker = None
+        else:
+            BaseOptions = mp.tasks.BaseOptions
+            FaceLandmarker = mp.tasks.vision.FaceLandmarker
+            FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
+            VisionRunningMode = mp.tasks.vision.RunningMode
+            # Create a face landmarker instance in image mode:
+            options = FaceLandmarkerOptions(
+                base_options=BaseOptions(model_asset_path=landmark_model_path),
+                running_mode=VisionRunningMode.IMAGE,
+            )
+            self.landmarker = FaceLandmarker.create_from_options(options)
+            self.face_analysis = None
+
+    def preprocess(self, source_image_path: str):
+        """
+        Apply preprocessing to the source images to prepare them for face analysis.
+
+        Parameters:
+            source_image_path (str): The path to the directory of source frames.
+
+        Returns:
+            tuple: (face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask)
+        """
+        # 1. get face embedding
+        face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None
+        if self.face_analysis:
+            for frame in sorted(os.listdir(source_image_path)):
+                try:
+                    source_image = Image.open(
+                        os.path.join(source_image_path, frame))
+                    ref_image_pil = source_image.convert("RGB")
+                    # 2.1 detect face
+                    faces = self.face_analysis.get(cv2.cvtColor(
+                        np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
+                    # use the largest detected face
+                    face = sorted(faces, key=lambda x: (
+                        x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
+                    # 2.2 face embedding
+                    face_emb = face["embedding"]
+                    if face_emb is not None:
+                        break
+                except Exception as _:
+                    continue
+
+        if self.landmarker:
+            # 3.1 get landmarks
+            landmarks, height, width = get_landmark_overframes(
+                self.landmarker, source_image_path)
+            assert len(landmarks) == len(os.listdir(source_image_path))
+
+            # 3.2 render face and lip masks
+            face_mask = get_union_face_mask(landmarks, height, width)
+            lip_mask = get_union_lip_mask(landmarks, height, width)
+
+            # 4. gaussian blur
+            blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51))
+            blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31))
+
+            # 5. separate masks
+            sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask)
+            sep_pose_mask = 255.0 - blur_face_mask
+            sep_lip_mask = blur_lip_mask
+
+        return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask
+
+    def close(self):
+        """
+        Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
+
+        Args:
+            self: The ImageProcessor instance.
+
+        Returns:
+            None.
+        """
+        for _, model in self.face_analysis.models.items():
+            if hasattr(model, "Dispose"):
+                model.Dispose()
+
+    def _augmentation(self, images, transform, state=None):
+        if state is not None:
+            torch.set_rng_state(state)
+        if isinstance(images, List):
+            transformed_images = [transform(img) for img in images]
+            ret_tensor = torch.stack(transformed_images, dim=0)  # (f, c, h, w)
+        else:
+            ret_tensor = transform(images)  # (c, h, w)
+        return ret_tensor
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        self.close()
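
The two __init__ branches split data preprocessing into an InsightFace pass (step == 2, face embedding only) and a MediaPipe pass (otherwise, landmarks and masks). A hedged usage sketch; the model paths and frame directory below are illustrative placeholders, and only the constructor and preprocess signatures shown in the diff are relied on:

from hallo.datasets.image_processor import ImageProcessorForDataProcessing

FACE_ANALYSIS_DIR = "./pretrained_models/face_analysis"    # assumed path
LANDMARK_TASK = "./pretrained_models/face_landmarker.task" # assumed path
FRAMES_DIR = "./cache/clip_0000/frames"                    # assumed path

# landmark pass (step != 2): MediaPipe landmarks -> face/pose/lip masks
mask_processor = ImageProcessorForDataProcessing(
    face_analysis_model_path=None,
    landmark_model_path=LANDMARK_TASK,
    step=1,
)
face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = mask_processor.preprocess(FRAMES_DIR)

# embedding pass (step == 2): InsightFace -> embedding of the largest detected face
emb_processor = ImageProcessorForDataProcessing(
    face_analysis_model_path=FACE_ANALYSIS_DIR,
    landmark_model_path=None,
    step=2,
)
_, face_emb, _, _, _ = emb_processor.preprocess(FRAMES_DIR)
emb_processor.close()  # close() disposes the InsightFace models; the landmark pass holds none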
