Skip to content

Commit 0152cd9

Browse files
authored
feat: training code of hallo (#101)
* add training code and corresponding config yaml files * add some auxiliary functions in utils * fix a parameter bug in motion module * fix mask size issue in stage2 dataset module "talk_video.py"
1 parent cfd1815 commit 0152cd9

File tree

8 files changed

+2114
-4
lines changed

8 files changed

+2114
-4
lines changed

configs/train/stage1.yaml

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
data:
2+
train_bs: 8
3+
train_width: 512
4+
train_height: 512
5+
meta_paths:
6+
- "./data/HDTF_meta.json"
7+
# Margin of frame indexes between ref and tgt images
8+
sample_margin: 30
9+
10+
solver:
11+
gradient_accumulation_steps: 1
12+
mixed_precision: "no"
13+
enable_xformers_memory_efficient_attention: True
14+
gradient_checkpointing: False
15+
max_train_steps: 30000
16+
max_grad_norm: 1.0
17+
# lr
18+
learning_rate: 1.0e-5
19+
scale_lr: False
20+
lr_warmup_steps: 1
21+
lr_scheduler: "constant"
22+
23+
# optimizer
24+
use_8bit_adam: False
25+
adam_beta1: 0.9
26+
adam_beta2: 0.999
27+
adam_weight_decay: 1.0e-2
28+
adam_epsilon: 1.0e-8
29+
30+
val:
31+
validation_steps: 500
32+
33+
noise_scheduler_kwargs:
34+
num_train_timesteps: 1000
35+
beta_start: 0.00085
36+
beta_end: 0.012
37+
beta_schedule: "scaled_linear"
38+
steps_offset: 1
39+
clip_sample: false
40+
41+
base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
42+
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
43+
face_analysis_model_path: "./pretrained_models/face_analysis"
44+
45+
weight_dtype: "fp16" # [fp16, fp32]
46+
uncond_ratio: 0.1
47+
noise_offset: 0.05
48+
snr_gamma: 5.0
49+
enable_zero_snr: True
50+
face_locator_pretrained: False
51+
52+
seed: 42
53+
resume_from_checkpoint: "latest"
54+
checkpointing_steps: 500
55+
exp_name: "stage1"
56+
output_dir: "./exp_output"
57+
58+
ref_image_paths:
59+
- "examples/reference_images/1.jpg"
60+
61+
mask_image_paths:
62+
- "examples/masks/1.png"
63+

configs/train/stage2.yaml

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
data:
2+
train_bs: 4
3+
val_bs: 1
4+
train_width: 512
5+
train_height: 512
6+
fps: 25
7+
sample_rate: 16000
8+
n_motion_frames: 2
9+
n_sample_frames: 14
10+
audio_margin: 2
11+
train_meta_paths:
12+
- "./data/hdtf_split_stage2.json"
13+
14+
wav2vec_config:
15+
audio_type: "vocals" # [audio, vocals]
16+
model_scale: "base" # [base, large]
17+
features: "all" # [last, avg, all]
18+
model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
19+
audio_separator:
20+
model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
21+
face_expand_ratio: 1.2
22+
23+
solver:
24+
gradient_accumulation_steps: 1
25+
mixed_precision: "no"
26+
enable_xformers_memory_efficient_attention: True
27+
gradient_checkpointing: True
28+
max_train_steps: 30000
29+
max_grad_norm: 1.0
30+
# lr
31+
learning_rate: 1e-5
32+
scale_lr: False
33+
lr_warmup_steps: 1
34+
lr_scheduler: "constant"
35+
36+
# optimizer
37+
use_8bit_adam: True
38+
adam_beta1: 0.9
39+
adam_beta2: 0.999
40+
adam_weight_decay: 1.0e-2
41+
adam_epsilon: 1.0e-8
42+
43+
val:
44+
validation_steps: 1000
45+
46+
noise_scheduler_kwargs:
47+
num_train_timesteps: 1000
48+
beta_start: 0.00085
49+
beta_end: 0.012
50+
beta_schedule: "linear"
51+
steps_offset: 1
52+
clip_sample: false
53+
54+
unet_additional_kwargs:
55+
use_inflated_groupnorm: true
56+
unet_use_cross_frame_attention: false
57+
unet_use_temporal_attention: false
58+
use_motion_module: true
59+
use_audio_module: true
60+
motion_module_resolutions:
61+
- 1
62+
- 2
63+
- 4
64+
- 8
65+
motion_module_mid_block: true
66+
motion_module_decoder_only: false
67+
motion_module_type: Vanilla
68+
motion_module_kwargs:
69+
num_attention_heads: 8
70+
num_transformer_block: 1
71+
attention_block_types:
72+
- Temporal_Self
73+
- Temporal_Self
74+
temporal_position_encoding: true
75+
temporal_position_encoding_max_len: 32
76+
temporal_attention_dim_div: 1
77+
audio_attention_dim: 768
78+
stack_enable_blocks_name:
79+
- "up"
80+
- "down"
81+
- "mid"
82+
stack_enable_blocks_depth: [0,1,2,3]
83+
84+
trainable_para:
85+
- audio_modules
86+
- motion_modules
87+
88+
base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
89+
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
90+
face_analysis_model_path: "./pretrained_models/face_analysis"
91+
mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"
92+
93+
weight_dtype: "fp16" # [fp16, fp32]
94+
uncond_img_ratio: 0.05
95+
uncond_audio_ratio: 0.05
96+
uncond_ia_ratio: 0.05
97+
start_ratio: 0.05
98+
noise_offset: 0.05
99+
snr_gamma: 5.0
100+
enable_zero_snr: True
101+
stage1_ckpt_dir: "./pretrained_models/hallo/stage1"
102+
103+
single_inference_times: 10
104+
inference_steps: 40
105+
cfg_scale: 3.5
106+
107+
seed: 42
108+
resume_from_checkpoint: "latest"
109+
checkpointing_steps: 500
110+
exp_name: "stage2_test"
111+
output_dir: "./exp_output"
112+
113+
ref_img_path:
114+
- "examples/reference_images/1.jpg"
115+
116+
audio_path:
117+
- "examples/driving_audios/1.wav"
118+
119+

examples/masks/1.png

7.32 KB
Loading

hallo/datasets/talk_video.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,25 +145,29 @@ def __init__(
145145
)
146146
self.attn_transform_64 = transforms.Compose(
147147
[
148-
transforms.Resize((64,64)),
148+
transforms.Resize(
149+
(self.img_size[0] // 8, self.img_size[0] // 8)),
149150
transforms.ToTensor(),
150151
]
151152
)
152153
self.attn_transform_32 = transforms.Compose(
153154
[
154-
transforms.Resize((32, 32)),
155+
transforms.Resize(
156+
(self.img_size[0] // 16, self.img_size[0] // 16)),
155157
transforms.ToTensor(),
156158
]
157159
)
158160
self.attn_transform_16 = transforms.Compose(
159161
[
160-
transforms.Resize((16, 16)),
162+
transforms.Resize(
163+
(self.img_size[0] // 32, self.img_size[0] // 32)),
161164
transforms.ToTensor(),
162165
]
163166
)
164167
self.attn_transform_8 = transforms.Compose(
165168
[
166-
transforms.Resize((8, 8)),
169+
transforms.Resize(
170+
(self.img_size[0] // 64, self.img_size[0] // 64)),
167171
transforms.ToTensor(),
168172
]
169173
)

hallo/models/motion_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ def extra_repr(self):
507507
def set_use_memory_efficient_attention_xformers(
508508
self,
509509
use_memory_efficient_attention_xformers: bool,
510+
attention_op = None,
510511
):
511512
"""
512513
Sets the use of memory-efficient attention xformers for the VersatileAttention class.

hallo/utils/util.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import subprocess
6868
import sys
6969
from pathlib import Path
70+
from typing import List
7071

7172
import av
7273
import cv2
@@ -614,3 +615,150 @@ def get_face_region(image_path: str, detector):
614615
except Exception as e:
615616
print(f"Error processing image {image_path}: {e}")
616617
return None, None
618+
619+
620+
def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
    """
    Save the model's state_dict to ``{save_dir}/{prefix}-{ckpt_num}.pth``.

    If `total_limit` is a positive integer, the lowest-numbered checkpoints of
    this `prefix` are removed first so that, after saving, at most
    `total_limit` of them remain. Only files named exactly
    ``{prefix}-<digits>.pth`` are treated as checkpoints; any other file in
    `save_dir` (including files of other prefixes that merely start with the
    same characters) is left untouched.

    Args:
        model (nn.Module): The model whose state_dict is to be saved.
        save_dir (str): The directory where the checkpoint will be saved.
        prefix (str): The prefix for the checkpoint file name. It may itself
            contain "-" characters.
        ckpt_num (int): The checkpoint number to be saved.
        total_limit (int, optional): The maximum number of checkpoints to keep.
            Defaults to -1. Any non-positive value, or None, disables removal.

    Raises:
        FileNotFoundError: If the save directory does not exist.
        ValueError: If the checkpoint number is negative.
        OSError: If there is an error saving the checkpoint. Errors while
            removing old checkpoints are only printed, not raised.
    """
    if not osp.exists(save_dir):
        raise FileNotFoundError(
            f"The save directory {save_dir} does not exist.")

    if ckpt_num < 0:
        raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")

    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    if total_limit is not None and total_limit > 0:
        # Match on the exact "<prefix>-<number>.pth" layout instead of a bare
        # startswith(): unrelated files must neither break the numeric sort
        # nor become candidates for deletion.
        head, tail = f"{prefix}-", ".pth"
        numbered = []
        for name in os.listdir(save_dir):
            if not (name.startswith(head) and name.endswith(tail)):
                continue
            step = name[len(head):-len(tail)]
            if step.isdecimal():
                numbered.append((int(step), name))
        checkpoints = [name for _, name in sorted(numbered)]

        if len(checkpoints) >= total_limit:
            # "+ 1" makes room for the checkpoint that is about to be written.
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            print(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            print(
                f"Removing checkpoints: {', '.join(removing_checkpoints)}"
            )

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint_path = osp.join(
                    save_dir, removing_checkpoint)
                try:
                    os.remove(removing_checkpoint_path)
                except OSError as e:
                    # Best effort: a stale file must not prevent saving.
                    print(
                        f"Error removing checkpoint {removing_checkpoint_path}: {e}")

    state_dict = model.state_dict()
    try:
        torch.save(state_dict, save_path)
        print(f"Checkpoint saved at {save_path}")
    except OSError as e:
        raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e
682+
683+
684+
def init_output_dir(dir_list: List[str]):
    """
    Make sure every output directory in `dir_list` exists.

    Each path is created together with any missing parent directories.
    Paths that already exist are left as they are.

    Args:
        dir_list (List[str]): Directory paths to create.
    """
    for directory in dir_list:
        # exist_ok=True keeps repeated calls from failing on existing paths.
        os.makedirs(directory, exist_ok=True)
695+
696+
697+
def load_checkpoint(cfg, save_dir, accelerator):
    """
    Load the most recent checkpoint and return the step to resume from.

    The directory searched is `save_dir` when `cfg.resume_from_checkpoint` is
    "latest"; otherwise `cfg.resume_from_checkpoint` itself is used as the
    directory. Inside it, entries whose name starts with "checkpoint" and whose
    second "-"-separated field is a number (e.g. "checkpoint-500") are
    candidates, and the one with the highest number is loaded via
    `accelerator.load_state`. Other entries that merely start with
    "checkpoint" (e.g. "checkpoints_backup", "checkpoint-tmp") are ignored
    rather than crashing the numeric sort.

    Args:
        cfg: The configuration object; only `resume_from_checkpoint` is read.
        save_dir (str): The directory where checkpoints are saved.
        accelerator: Object providing `load_state(path)` and `print(msg)`.

    Returns:
        int: The global step at which to resume training, or 0 when no
        checkpoint is found.

    Raises:
        FileNotFoundError: If the directory to search does not exist
            (raised by `os.listdir`).
    """
    if cfg.resume_from_checkpoint != "latest":
        resume_dir = cfg.resume_from_checkpoint
    else:
        resume_dir = save_dir

    # Collect (step, name) pairs for well-formed checkpoint entries only.
    candidates = []
    for name in os.listdir(resume_dir):
        if not name.startswith("checkpoint"):
            continue
        fields = name.split("-")
        if len(fields) > 1 and fields[1].isdecimal():
            candidates.append((int(fields[1]), name))

    if candidates:
        # Stable sort on the step number; the last entry is the most recent.
        candidates.sort(key=lambda item: item[0])
        global_step, path = candidates[-1]
        accelerator.load_state(os.path.join(resume_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
    else:
        accelerator.print(
            f"Could not find checkpoint under {resume_dir}, start training from scratch")
        global_step = 0

    return global_step
733+
734+
735+
def compute_snr(noise_scheduler, timesteps):
    """
    Compute the signal-to-noise ratio for the given timesteps.

    Follows the Min-SNR reference implementation:
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
    521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        noise_scheduler: Object exposing an `alphas_cumprod` tensor.
        timesteps: Integer tensor used to index `alphas_cumprod`.

    Returns:
        Float tensor shaped like `timesteps` holding (alpha / sigma) ** 2,
        where alpha = sqrt(alphas_cumprod[t]) and
        sigma = sqrt(1 - alphas_cumprod[t]).
    """
    cumulative = noise_scheduler.alphas_cumprod

    def _lookup(table):
        # Gather the per-timestep entries on the timesteps' device and
        # broadcast them to the timesteps' shape. Adapted from
        # https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
        # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = table.to(device=timesteps.device)[timesteps].float()
        while picked.dim() < timesteps.dim():
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    alpha = _lookup(cumulative**0.5)
    sigma = _lookup((1.0 - cumulative) ** 0.5)

    return (alpha / sigma) ** 2

0 commit comments

Comments
 (0)