|
67 | 67 | import subprocess
|
68 | 68 | import sys
|
69 | 69 | from pathlib import Path
|
| 70 | +from typing import List |
70 | 71 |
|
71 | 72 | import av
|
72 | 73 | import cv2
|
@@ -614,3 +615,150 @@ def get_face_region(image_path: str, detector):
|
614 | 615 | except Exception as e:
|
615 | 616 | print(f"Error processing image {image_path}: {e}")
|
616 | 617 | return None, None
|
| 618 | + |
| 619 | + |
def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
    """
    Save the model's state_dict to a checkpoint file.

    If `total_limit` is positive, this function removes the oldest checkpoints
    first so that after saving, at most `total_limit` checkpoints remain.

    Args:
        model (torch.nn.Module): The model whose state_dict is to be saved.
        save_dir (str): The directory where the checkpoint will be saved.
        prefix (str): The prefix for the checkpoint file name.
        ckpt_num (int): The checkpoint number to be saved.
        total_limit (int, optional): The maximum number of checkpoints to keep.
            Defaults to -1, in which case no checkpoints are removed.

    Raises:
        FileNotFoundError: If the save directory does not exist.
        ValueError: If the checkpoint number is negative.
        OSError: If there is an error saving the checkpoint.
    """
    if not osp.exists(save_dir):
        raise FileNotFoundError(
            f"The save directory {save_dir} does not exist.")

    if ckpt_num < 0:
        raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")

    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    if total_limit > 0:
        # Collect existing checkpoints for this prefix, oldest first
        # (sorted by the numeric suffix, not lexicographically).
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.startswith(prefix)]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            # Remove enough so that, counting the checkpoint about to be
            # written, exactly `total_limit` remain.
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            print(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            print(
                f"Removing checkpoints: {', '.join(removing_checkpoints)}"
            )

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint_path = osp.join(
                    save_dir, removing_checkpoint)
                try:
                    os.remove(removing_checkpoint_path)
                except OSError as e:
                    # Best-effort cleanup: a failed removal should not
                    # prevent the new checkpoint from being saved.
                    print(
                        f"Error removing checkpoint {removing_checkpoint_path}: {e}")

    state_dict = model.state_dict()
    try:
        torch.save(state_dict, save_path)
        print(f"Checkpoint saved at {save_path}")
    except OSError as e:
        raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e
| 682 | + |
| 683 | + |
def init_output_dir(dir_list: List[str]):
    """
    Create every directory in *dir_list*.

    Missing parent directories are created as well; directories that already
    exist are left untouched.

    Args:
        dir_list (List[str]): Directory paths to create.
    """
    for target in dir_list:
        os.makedirs(target, exist_ok=True)
| 695 | + |
| 696 | + |
def load_checkpoint(cfg, save_dir, accelerator):
    """
    Load the most recent checkpoint from the specified directory.

    When `cfg.resume_from_checkpoint` is the string "latest", the most recent
    "checkpoint-*" entry inside `save_dir` is loaded; otherwise the directory
    named by `cfg.resume_from_checkpoint` is searched instead. If no
    checkpoint entry is found, training starts from scratch at step 0.

    Args:
        cfg: The configuration object containing training parameters.
        save_dir (str): The directory where checkpoints are saved.
        accelerator: The accelerator object for distributed training.

    Returns:
        int: The global step at which to resume training.
    """
    resume_dir = (
        save_dir
        if cfg.resume_from_checkpoint == "latest"
        else cfg.resume_from_checkpoint
    )

    # Candidate entries follow the "checkpoint-<step>" naming scheme.
    candidates = [d for d in os.listdir(resume_dir) if d.startswith("checkpoint")]

    if not candidates:
        accelerator.print(
            f"Could not find checkpoint under {resume_dir}, start training from scratch")
        return 0

    # Sort numerically by step so "checkpoint-100" outranks "checkpoint-20".
    candidates.sort(key=lambda name: int(name.split("-")[1]))
    latest = candidates[-1]
    accelerator.load_state(os.path.join(resume_dir, latest))
    accelerator.print(f"Resuming from checkpoint {latest}")
    return int(latest.split("-")[1])
| 733 | + |
| 734 | + |
def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
    521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod

    # Gather the per-timestep sqrt(alpha_bar) and sqrt(1 - alpha_bar) on the
    # same device as `timesteps`, then broadcast to the timesteps' shape.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
    # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alpha = (alphas_cumprod ** 0.5).to(device=timesteps.device)[timesteps].float()
    while sqrt_alpha.dim() < timesteps.dim():
        sqrt_alpha = sqrt_alpha[..., None]
    alpha = sqrt_alpha.expand(timesteps.shape)

    sqrt_one_minus_alpha = ((1.0 - alphas_cumprod) ** 0.5).to(
        device=timesteps.device
    )[timesteps].float()
    while sqrt_one_minus_alpha.dim() < timesteps.dim():
        sqrt_one_minus_alpha = sqrt_one_minus_alpha[..., None]
    sigma = sqrt_one_minus_alpha.expand(timesteps.shape)

    # SNR = (alpha_bar / (1 - alpha_bar)) since alpha and sigma are the
    # respective square roots.
    snr = (alpha / sigma) ** 2
    return snr
0 commit comments