-
Notifications
You must be signed in to change notification settings - Fork 205
【Hackathon 8th No.23】Improved Training of Wasserstein GANs 论文复现 #1146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
Thanks for your contribution! |
Update wgangp_toy.py
Update functions.py
所有模型和训练日志 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
麻烦修改如下问题~
toy: 结果基本符合预期
mnist: 网盘里似乎只有预测结果的图片而没有target,实际运行发现,按照当前参数训练后再评估,mse loss约0.47,生成的图片与target相比都不相同
cifar10:网盘中无预测结果,实际运行时发现在eval的最后一个iter会报错
output_dim: 3072 | ||
label_num: 10 | ||
use_label: true | ||
batch_size: 64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果训练或者评估用的batch_size与这个不同会报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的谢谢老师,已完成此项修改。
if noise is None: | ||
noise = paddle.randn([self.batch_size, 128], dtype=paddle.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看了一下似乎只有生成noise这里使用到了batch_size,但实际上这个值可以通过labels.shape[0]得到,如果允许在config文件中定义,反而有可能出现二者不一致的情况,因此这里是否可以直接改为noise = paddle.randn([labels.shape[0], 128], dtype=paddle.float32)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的谢谢老师,已完成此项修改。
inceptionscore: | ||
eps: 1e-16 | ||
splits: 5 | ||
batch_size: 64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当这个值改为1时会报错
另外,即使不更改这个值,运行到eval最后一个iter时也会报错:
File "/PaddleScience/examples/wgangp/wgamgp_cifar10_model.py", line 166, in forward
outputs = outputs * label_scale + label_offset
ValueError: (InvalidArgument) Broadcast dimension mismatch. Operands could not be broadcast together with the shape of X = [64, 128, 4, 4] and the shape of Y = [16, 128, 1, 1]. Received [64] in X is not equal to [16] in Y at i:0.
[Hint: Expected x_dims_array[i] == y_dims_array[i] || (x_dims_array[i] <= 1 && x_dims_array[i] != 0) || (y_dims_array[i] <= 1 && y_dims_array[i] != 0) == true, but received x_dims_array[i] == y_dims_array[i] || (x_dims_array[i] <= 1 && x_dims_array[i] != 0) || (y_dims_array[i] <= 1 && y_dims_array[i] != 0):0 != true:1.] (at ../paddle/phi/kernels/funcs/common_shape.h:71)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
老师修改后未复现出这个错误,应该是没有了。
DATA: | ||
input_keys: [ "labels" ] | ||
label_keys: [ "real_data" ] | ||
data_path: D:\Data\CIFAR-10\raw\cifar-10-python.tar.gz |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为相对地址如:data_path: "./data/cifar-10-python.tar.gz"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的谢谢老师,已完成此项修改。
pretrained_gen_model_path: | ||
pretrained_dis_model_path: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为:
pretrained_gen_model_path: null
pretrained_dis_model_path: null
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的谢谢老师,已完成此项修改。
for i in range( | ||
cfg["EVAL"]["batch_size"] | ||
if cfg["EVAL"]["batch_size"] < cfg.VIS.num | ||
else cfg.VIS.num | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同理可以把batch_size设置为1并删除这层循环吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不能把batch_size变为1,变为1后生成的图片质量大幅变差,我没有找到原因,所以我认为应该无法把batch_size变为1
def forward(self, x): | ||
y = paddle.randn([self.batch_size, 128]) | ||
y = self.linear1(y) | ||
y = self.relu1(y) | ||
y = paddle.reshape(y, [-1, 4 * self.dim, 4, 4]) | ||
y = self.conv2d_transpose1(y) | ||
y = self.relu2(y) | ||
y = y[:, :, :7, :7] | ||
y = self.conv2d_transpose2(y) | ||
y = self.relu3(y) | ||
y = self.conv2d_transpose3(y) | ||
y = self.sigmoid(y) | ||
return paddle.reshape(y, [-1, self.output_dim]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个forward中没有用到x吗?这里的self.batch_size是否也是x.shape[0]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,谢谢老师,已完成此项修改。
examples/wgangp/wgangp_toy_model.py
Outdated
self.batch_size = batch_size | ||
|
||
def forward(self, x): | ||
y = self.generator(self.batch_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的self.batch_size同理
tensorflow由于通常是静态图模式,需要placeholder,因此可能需要在模型init阶段就确定batch_size,但是paddle和torch都常用动态图模式,就不需要这么做。而且事实上,一般在模型开发过程中会进行反复调优,batch_size作为一个超参数,也会对模型效果、显存占用等方面造成影响,因此也会进行调整,如果在init阶段就确定它,不仅调试麻烦,也可能会导致无法使用之前训练的模型做pretain
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,谢谢老师,已完成此项修改。
@@ -54,6 +54,7 @@ nav: | |||
- NeuralOperator: zh/examples/neuraloperator.md | |||
- Brusselator3D: zh/examples/brusselator3d.md | |||
- Transformer4SR: zh/examples/transformer4sr.md | |||
- Wgan_gp: zh/examples/wgan_gp.md |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
放在地球科学分类吧
另外代码改动会影响文档内容,可能需要先合入代码再合入文档,可以麻烦把文档相关的文件另提一个pr吗
@@ -0,0 +1,619 @@ | |||
import math |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件名错了打成m了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,谢谢老师,已完成此项修改。
对于相同目的的反复修改,在git push之前在本地修改的时候,可以使用git commit --amend,而不是git commit -m "新的信息",这样可以把修改改在上一个commit上,然后再push就会更新commit而不是创建新的 |
PR types
New Features
PR changes
Others
Describe
Add wgangp