add lvdm #17

Merged: 4 commits, Aug 1, 2023
61 changes: 61 additions & 0 deletions ppdiffusers/examples/text_to_video_lvdm/README.md
@@ -0,0 +1,61 @@
## Latent Video Diffusion Model Training

This tutorial covers training [**LVDM (Latent Video Diffusion Model)**](https://arxiv.org/abs/2211.13221). Training here targets only the diffusion model (the UNet); it does not cover training the first-stage model.


## Prerequisites
### Install dependencies

Before running the training code, install ppdiffusers and its dependencies.


```bash
cd PaddleMIX/ppdiffusers
python setup.py install
pip install -r requirements.txt
```

### Data preparation
Prepare the diffusion model training data in a format compatible with `VideoFrameDataset` or `WebVidDataset`. For dataset-related configuration, see `DatasetArguments` in `lvdm/lvdm_args_short.py` or `lvdm/lvdm_args_text2video.py`. The datasets can be downloaded from [Sky Timelapse](https://github.com/weixiong-ur/mdgan) and [Webvid](https://github.com/m-bain/webvid).
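
If you want to sanity-check the data before launching training, the dataset classes exported by `lvdm/__init__.py` can be constructed directly. The sketch below is illustrative only: the constructor argument names (`data_root`, `resolution`, `video_length`) and the local path are assumptions, so check `DatasetArguments` for the actual fields.

```python
# Illustrative sketch only -- argument names and the path are assumptions;
# see DatasetArguments in lvdm/lvdm_args_short.py for the real fields.
from lvdm import VideoFrameDataset

dataset = VideoFrameDataset(
    data_root="./data/sky_timelapse",  # hypothetical path to extracted frames
    resolution=256,                    # spatial size of each frame
    video_length=16,                   # frames per training clip
)
sample = dataset[0]  # one training clip
print(type(sample))
```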


### Pretrained model preparation
A complete PPDiffusers pipeline contains several pretrained models, and only the diffusion model (UNet) is trained here, so the remaining pretrained weights, namely the Text-Encoder and the VAE, must be in place before training and inference can run. Developers who prefer to fine-tune an existing model rather than train from scratch can also prepare UNet weights and start from those. The following pretrained weights are currently provided (a loading sketch follows the list):
- Unconditional video generation EMA weights trained on Sky Timelapse, using a 3D VAE: ``westfish/lvdm_short_sky``
- Unconditional video generation non-EMA weights trained on Sky Timelapse, using a 3D VAE: ``westfish/lvdm_short_sky_no_ema``
- Text-conditional video generation non-EMA weights trained on Webvid, using a 2D VAE: ``westfish/lvdm_text2video_orig_webvid_2m``
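
The weights above can be pulled with the usual `from_pretrained` mechanism. The sketch below is an assumption about the pipeline class and output attribute names; verify against the LVDM pipelines shipped with ppdiffusers.

```python
# Minimal loading/inference sketch; LVDMTextToVideoPipeline and the output
# attribute name are assumptions -- check the ppdiffusers LVDM pipelines.
import paddle
from ppdiffusers import LVDMTextToVideoPipeline

pipe = LVDMTextToVideoPipeline.from_pretrained("westfish/lvdm_text2video_orig_webvid_2m")
with paddle.no_grad():
    output = pipe(prompt="a timelapse of clouds drifting across the sky")
frames = output.frames  # attribute name is an assumption
```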

## Model training
For the training parameters and their meanings, see `lvdm/lvdm_args_short.py` or `lvdm/lvdm_args_text2video.py`, which correspond to unconditional and text-conditional video generation respectively. Both define `ModelArguments`, `DatasetArguments`, and `TrainerArguments`, covering the pretrained-model, dataset, and Trainer settings respectively. You can train with the default parameters or adjust them as needed.
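
These argument groups are typically parsed together the way other ppdiffusers training examples do. The sketch below is not the training script itself; it assumes the three dataclasses are importable from `lvdm.lvdm_args_short` and are compatible with paddlenlp's `PdArgumentParser`.

```python
# Parsing sketch, not the actual entry point of train_lvdm_short.py;
# assumes the dataclasses live in lvdm/lvdm_args_short.py.
from paddlenlp.trainer import PdArgumentParser
from lvdm.lvdm_args_short import ModelArguments, DatasetArguments, TrainerArguments

parser = PdArgumentParser((ModelArguments, DatasetArguments, TrainerArguments))
model_args, data_args, trainer_args = parser.parse_args_into_dataclasses()
print(model_args, data_args, trainer_args)
```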


### Single-node, single-GPU training
```bash
# unconditional generation
python -u train_lvdm_short.py
```
```bash
# text to video generation
python -u train_lvdm_text2video.py
```

### Single-node, multi-GPU training
```bash
# unconditional generation
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train_lvdm_short.py
```
```bash
# text to video generation
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train_lvdm_text2video.py
```

During training, you can monitor progress in a browser with:
```bash
visualdl --logdir your_log_dir/runs --host 0.0.0.0 --port 8042
```
For complete training examples, see ``scripts/train_lvdm_short_sky.sh`` and ``scripts/train_lvdm_text2video_webvid.sh``.

For inference examples, see ``scripts/inference_lvdm_short.sh`` and ``scripts/inference_lvdm_text2video.sh``.

## References
https://github.com/YingqingHe/LVDM
18 changes: 18 additions & 0 deletions ppdiffusers/examples/text_to_video_lvdm/lvdm/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .lvdm_model import LatentVideoDiffusion
from .lvdm_trainer import LatentVideoDiffusionTrainer
from .frame_dataset import VideoFrameDataset
from .webvid_dataset import WebVidDataset
109 changes: 109 additions & 0 deletions ppdiffusers/examples/text_to_video_lvdm/lvdm/_functional_video.py
@@ -0,0 +1,109 @@
import paddle
import warnings


def _is_tensor_video_clip(clip):
if not paddle.is_tensor(x=clip):
raise TypeError('clip should be Tensor. Got %s' % type(clip))
if not clip.ndimension() == 4:
raise ValueError('clip should be 4D. Got %dD' % clip.dim())
return True


def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
"""
if len(clip.shape) != 4:
raise ValueError('clip should be a 4D tensor')
return clip[(...), i:i + h, j:j + w]


def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(
f'target size should be tuple (height, width), instead got {target_size}'
)
return paddle.nn.functional.interpolate(
x=clip, size=target_size, mode=interpolation_mode, align_corners=False)


def resized_crop(clip, i, j, h, w, size, interpolation_mode='bilinear'):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError('clip should be a 4D torch.tensor')
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip


def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
        raise ValueError('clip should be a 4D paddle.Tensor')
h, w = clip.shape[-2], clip.shape[-1]
th, tw = crop_size
if h < th or w < tw:
raise ValueError('height and width must be no smaller than crop_size')
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)


def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == 'uint8':
raise TypeError('clip tensor should have data type uint8. Got %s' %
str(clip.dtype))
return clip.astype(dtype='float32').transpose(perm=[3, 0, 1, 2]) / 255.0


def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (C, T, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError('clip should be a 4D torch.tensor')
if not inplace:
clip = clip.clone()
mean = paddle.to_tensor(data=mean, place=clip.place).astype(clip.dtype)
std = paddle.to_tensor(data=std, place=clip.place).astype(clip.dtype)
clip = clip.substract(mean[:, (None), (None), (None)]).divide(std[:, (
None), (None), (None)])
return clip


def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError('clip should be a 4D torch.tensor')
return clip.flip(axis=-1)
157 changes: 157 additions & 0 deletions ppdiffusers/examples/text_to_video_lvdm/lvdm/_transforms_video.py
@@ -0,0 +1,157 @@
import paddle
import numbers
import random
import warnings
from . import _functional_video as F


class RandomCropVideo(paddle.vision.transforms.RandomCrop):
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = int(size), int(size)
else:
self.size = size

def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: randomly cropped/resized video clip.
size is (C, T, OH, OW)
"""
i, j, h, w = self.get_params(clip, self.size)
return F.crop(clip, i, j, h, w)

def __repr__(self) -> str:
return f'{self.__class__.__name__}(size={self.size})'


class RandomResizedCropVideo(paddle.vision.transforms.RandomResizedCrop):
def __init__(self,
size,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation_mode='bilinear'):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(
f'size should be tuple (height, width), instead got {size}')
self.size = size
else:
self.size = size, size
self.interpolation_mode = interpolation_mode
self.scale = scale
self.ratio = ratio

def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: randomly cropped/resized video clip.
size is (C, T, H, W)
"""
i, j, h, w = self.get_params(clip, self.scale, self.ratio)
return F.resized_crop(clip, i, j, h, w, self.size,
self.interpolation_mode)

def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})'
)


class CenterCropVideo:
def __init__(self, crop_size):
if isinstance(crop_size, numbers.Number):
self.crop_size = int(crop_size), int(crop_size)
else:
self.crop_size = crop_size

def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: central cropping of video clip. Size is
(C, T, crop_size, crop_size)
"""
return F.center_crop(clip, self.crop_size)

def __repr__(self) -> str:
return f'{self.__class__.__name__}(crop_size={self.crop_size})'


class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""

def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace

def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
"""
return F.normalize(clip, self.mean, self.std, self.inplace)

def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})'
)


class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""

def __init__(self):
pass

def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
return F.to_tensor(clip)

def __repr__(self) -> str:
return self.__class__.__name__


class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""

def __init__(self, p=0.5):
self.p = p

def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (C, T, H, W)
Return:
clip (torch.tensor): Size is (C, T, H, W)
"""
if random.random() < self.p:
clip = F.hflip(clip)
return clip

def __repr__(self) -> str:
return f'{self.__class__.__name__}(p={self.p})'
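
For reference, here is a minimal usage sketch of these transforms on a random clip; the dummy data and the `Compose` wrapper are illustrative only, and it assumes the `lvdm` package in this example directory is importable.

```python
# Illustrative composition of the video transforms on a random uint8 clip.
import paddle
from lvdm._transforms_video import (ToTensorVideo, RandomHorizontalFlipVideo,
                                    CenterCropVideo, NormalizeVideo)

clip = paddle.randint(0, 256, shape=[16, 128, 128, 3]).astype('uint8')  # (T, H, W, C)

transform = paddle.vision.transforms.Compose([
    ToTensorVideo(),                   # (T, H, W, C) uint8 -> (C, T, H, W) float32 in [0, 1]
    RandomHorizontalFlipVideo(p=0.5),  # random horizontal flip
    CenterCropVideo(112),              # center crop to 112 x 112
    NormalizeVideo(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # roughly [-1, 1]
])
print(transform(clip).shape)  # [3, 16, 112, 112]
```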