1
1
# 在 PaddleScience 中复现 WGAN-GP 模型
2
2
3
- | 任务名称 | 在 PaddleScience 中复现 WGAN-GP 模型 |
4
- | --- | --- |
5
- | 提交作者 | robinbg |
6
- | 提交时间 | 2025-04-04 |
7
- | 版本号 | V1.0 |
8
- | 依赖飞桨版本 | develop |
3
+ | 任务名称 | 在 PaddleScience 中复现 WGAN-GP 模型 |
4
+ | --- | ------------------------------------------- |
5
+ | 提交作者 | robinbg、XvLingWYY |
6
+ | 提交时间 | 2025-04-04 |
7
+ | 版本号 | V1.0 |
8
+ | 依赖飞桨版本 | develop |
9
9
| 文件名 | 20250404_add_wgan_gp_for_paddlescience.md |
10
10
11
11
# 一、概述
@@ -75,44 +75,106 @@ WGAN-GP 的核心在于其损失函数和梯度惩罚项的计算。以下是主
75
75
76
76
### 1.1 损失函数
77
77
``` python
78
- # 生成器损失
79
- def generator_loss (fake_output ):
80
- return - paddle.mean(fake_output)
81
-
82
- # 判别器损失(包含梯度惩罚)
83
- def discriminator_loss (real_output , fake_output , gradient_penalty ):
84
- return paddle.mean(fake_output) - paddle.mean(real_output) + LAMBDA * gradient_penalty
78
+ # CIFAR10实验中生成器损失
79
+ class Cifar10GenFuncs :
80
+ """
81
+ Loss function for cifar10 generator
82
+ Args
83
+ discriminator_model: discriminator model
84
+ acgan_scale_g: scale of acgan loss for generator
85
+
86
+ """
87
+
88
+ def __init__ (
89
+ self ,
90
+ discriminator_model ,
91
+ acgan_scale_g = 0.1 ,
92
+ ):
93
+ self .crossEntropyLoss = paddle.nn.CrossEntropyLoss()
94
+ self .acgan_scale_g = acgan_scale_g
95
+ self .discriminator_model = discriminator_model
96
+
97
+ def loss (self , output_dict : Dict, * args ):
98
+ fake_image = output_dict[" fake_data" ]
99
+ labels = output_dict[" labels" ]
100
+ outputs = self .discriminator_model({" data" : fake_image, " labels" : labels})
101
+ disc_fake, disc_fake_acgan = outputs[" disc_fake" ], outputs[" disc_acgan" ]
102
+ gen_cost = - paddle.mean(disc_fake)
103
+ if disc_fake_acgan is not None :
104
+ gen_acgan_cost = self .crossEntropyLoss(disc_fake_acgan, labels)
105
+ gen_cost += self .acgan_scale_g * gen_acgan_cost
106
+ return {" loss_g" : gen_cost}
107
+
108
+ # CIFAR10实验中判别器损失
109
+ class Cifar10DisFuncs :
110
+ """
111
+ Loss function for cifar10 discriminator
112
+ Args
113
+ discriminator_model: discriminator model
114
+ acgan_scale: scale of acgan loss for discriminator
115
+
116
+ """
117
+
118
+ def __init__ (self , discriminator_model , acgan_scale ):
119
+ self .crossEntropyLoss = paddle.nn.CrossEntropyLoss()
120
+ self .acgan_scale = acgan_scale
121
+ self .discriminator_model = discriminator_model
122
+
123
+ def loss (self , output_dict : Dict, label_dict : Dict, * args ):
124
+ fake_image = output_dict[" fake_data" ]
125
+ real_image = label_dict[" real_data" ]
126
+ labels = output_dict[" labels" ]
127
+ disc_fake = self .discriminator_model({" data" : fake_image, " labels" : labels})[
128
+ " disc_fake"
129
+ ]
130
+ out = self .discriminator_model({" data" : real_image, " labels" : labels})
131
+ disc_real, disc_real_acgan = out[" disc_fake" ], out[" disc_acgan" ]
132
+ gradient_penalty = self .compute_gradient_penalty(real_image, fake_image, labels)
133
+ disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
134
+ disc_wgan = disc_cost + gradient_penalty
135
+ if disc_real_acgan is not None :
136
+ disc_acgan_cost = self .crossEntropyLoss(disc_real_acgan, labels)
137
+ disc_acgan = disc_acgan_cost.sum()
138
+ disc_cost = disc_wgan + (self .acgan_scale * disc_acgan)
139
+ else :
140
+ disc_cost = disc_wgan
141
+ return {" loss_d" : disc_cost}
142
+
143
+ def compute_gradient_penalty (self , real_data , fake_data , labels ):
144
+ differences = fake_data - real_data
145
+ alpha = paddle.rand([fake_data.shape[0 ], 1 ])
146
+ interpolates = real_data + (alpha * differences)
147
+ gradients = paddle.grad(
148
+ outputs = self .discriminator_model({" data" : interpolates, " labels" : labels})[
149
+ " disc_fake"
150
+ ],
151
+ inputs = interpolates,
152
+ create_graph = True ,
153
+ retain_graph = False ,
154
+ )[0 ]
155
+ slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis = 1 ))
156
+ gradient_penalty = 10 * paddle.mean((slopes - 1.0 ) ** 2 )
157
+ return gradient_penalty
85
158
```
86
159
87
160
### 1.2 梯度惩罚计算
88
161
``` python
89
- def gradient_penalty (discriminator , real_samples , fake_samples ):
90
- # 生成随机插值系数
91
- alpha = paddle.rand(shape = [real_samples.shape[0 ], 1 , 1 , 1 ])
92
-
93
- # 创建真实样本和生成样本之间的插值
94
- interpolates = real_samples + alpha * (fake_samples - real_samples)
95
- interpolates.stop_gradient = False
96
-
97
- # 计算判别器对插值样本的输出
98
- disc_interpolates = discriminator(interpolates)
99
-
100
- # 计算梯度
101
- gradients = paddle.grad(
102
- outputs = disc_interpolates,
103
- inputs = interpolates,
104
- grad_outputs = paddle.ones_like(disc_interpolates),
105
- create_graph = True ,
106
- retain_graph = True
107
- )[0 ]
108
-
109
- # 计算梯度范数
110
- gradients_norm = paddle.sqrt(paddle.sum(paddle.square(gradients), axis = [1 , 2 , 3 ]))
111
-
112
- # 计算梯度惩罚
113
- gradient_penalty = paddle.mean(paddle.square(gradients_norm - 1.0 ))
114
-
115
- return gradient_penalty
162
+ # CIFAR-10 判别器中的梯度惩罚计算
163
+ def compute_gradient_penalty (self , real_data , fake_data , labels ):
164
+ differences = fake_data - real_data
165
+ alpha = paddle.rand([fake_data.shape[0 ], 1 ])
166
+ interpolates = real_data + (alpha * differences)
167
+ gradients = paddle.grad(
168
+ outputs = self .discriminator_model({" data" : interpolates, " labels" : labels})[
169
+ " disc_fake"
170
+ ],
171
+ inputs = interpolates,
172
+ create_graph = True ,
173
+ retain_graph = False ,
174
+ )[0 ]
175
+ slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis = 1 ))
176
+ gradient_penalty = 10 * paddle.mean((slopes - 1.0 ) ** 2 )
177
+ return gradient_penalty
116
178
```
117
179
118
180
## 2. 网络架构
@@ -135,34 +197,12 @@ WGAN-GP 的训练流程与标准 GAN 有所不同,主要区别在于:
135
197
136
198
``` python
137
199
# 训练循环示例
138
- for iteration in range (ITERATIONS ):
139
- # 训练判别器
140
- for _ in range (CRITIC_ITERS ):
141
- real_data = next (data_iterator)
142
- noise = paddle.randn([BATCH_SIZE , NOISE_DIM ])
143
-
144
- # 计算判别器损失
145
- fake_data = generator(noise)
146
- real_output = discriminator(real_data)
147
- fake_output = discriminator(fake_data)
148
- gp = gradient_penalty(discriminator, real_data, fake_data)
149
- d_loss = discriminator_loss(real_output, fake_output, gp)
150
-
151
- # 更新判别器参数
152
- d_optimizer.clear_grad()
153
- d_loss.backward()
154
- d_optimizer.step()
155
-
156
- # 训练生成器
157
- noise = paddle.randn([BATCH_SIZE , NOISE_DIM ])
158
- fake_data = generator(noise)
159
- fake_output = discriminator(fake_data)
160
- g_loss = generator_loss(fake_output)
161
-
162
- # 更新生成器参数
163
- g_optimizer.clear_grad()
164
- g_loss.backward()
165
- g_optimizer.step()
200
+ for i in range (cfg.TRAIN .epochs):
201
+ logger.message(f " \n Epoch: { i + 1 } \n " )
202
+ optimizer_discriminator.clear_grad()
203
+ solver_discriminator.train()
204
+ optimizer_generator.clear_grad()
205
+ solver_generator.train()
166
206
```
167
207
168
208
## 4. 评估指标
@@ -171,11 +211,8 @@ for iteration in range(ITERATIONS):
171
211
### 4.1 Inception Score (IS)
172
212
用于评估生成图像的质量和多样性。
173
213
174
- ### 4.2 Fréchet Inception Distance (FID)
175
- 测量生成图像分布与真实图像分布之间的距离。
176
-
177
- ### 4.3 生成样本可视化
178
- 定期保存生成的样本,用于直观评估模型性能。
214
+ ### 4.2 生成样本可视化
215
+ 保存生成的样本,用于直观评估模型性能。
179
216
180
217
## 5. 与 PaddleScience 集成
181
218
我们将设计一个模块化的实现,便于与 PaddleScience 集成:
@@ -184,50 +221,142 @@ for iteration in range(ITERATIONS):
184
221
```
185
222
PaddleScience/
186
223
└── examples/
187
- └── wgan_gp/
188
- ├── __init__.py
189
- ├── utils/
190
- │ ├── __init__.py
191
- │ ├── losses.py # 损失函数
192
- │ ├── metrics.py # 评估指标
193
- │ └── visualization.py # 可视化工具
194
- ├── models/
195
- │ ├── __init__.py
196
- │ ├── base_gan.py # GAN 基类
197
- │ ├── wgan.py # WGAN 实现
198
- │ └── wgan_gp.py # WGAN-GP 实现
199
- └── cases/
200
- ├── wgan_gp_toy.py # 玩具数据集示例
201
- ├── wgan_gp_mnist.py # MNIST 示例
202
- └── wgan_gp_cifar.py # CIFAR-10 示例
224
+ └── wgangp/
225
+ ├── conf
226
+ │ ├── wgangp_cifar10.yaml # CIFAR-10 配置文件
227
+ │ ├── wgangp_mnist.yaml # MNIST 配置文件
228
+ │ └── wgangp_toy.yaml # 玩具数据集配置文件
229
+ ├── functions.py # 损失函数、评估指标、可视化工具
230
+ ├── wgangp_cifar10.py # CIFAR-10 示例
231
+ ├── wgangp_cifar10_model.py # CIFAR-10实验模型
232
+ ├── wgangp_mnist.py # MNIST 示例
233
+ ├── wgangp_mnist_model.py # MNIST实验模型
234
+ ├── wgangp_toy.py # 玩具数据集示例
235
+ └── wgangp_toy_model.py # 玩具数据集实验模型
203
236
```
204
237
205
238
### 5.2 接口设计
206
239
提供简洁统一的接口,方便用户使用:
207
240
208
241
``` python
209
242
# 示例用法
210
- from models.wgan_gp import WGAN_GP
211
-
212
- # 创建模型
213
- model = WGAN_GP(
214
- generator = generator_network,
215
- discriminator = discriminator_network,
216
- lambda_gp = 10.0 ,
217
- critic_iters = 5
218
- )
219
-
220
- # 训练模型
221
- model.train(
222
- train_data = dataset,
223
- batch_size = 64 ,
224
- iterations = 100000 ,
225
- g_learning_rate = 1e-4 ,
226
- d_learning_rate = 1e-4
227
- )
228
-
229
- # 生成样本
230
- samples = model.generate(num_samples = 100 )
243
+ import os
244
+ import paddle
245
+ from functions import Cifar10DisFuncs
246
+ from functions import Cifar10GenFuncs
247
+ from functions import load_cifar10
248
+ from omegaconf import DictConfig
249
+ from wgangp_cifar10_model import WganGpCifar10Discriminator
250
+ from wgangp_cifar10_model import WganGpCifar10Generator
251
+
252
+ def train (cfg : DictConfig):
253
+ # set model
254
+ generator_model = WganGpCifar10Generator(** cfg[" MODEL" ][" gen_net" ])
255
+ discriminator_model = WganGpCifar10Discriminator(** cfg[" MODEL" ][" dis_net" ])
256
+ if cfg.TRAIN .pretrained_dis_model_path and os.path.exists(
257
+ cfg.TRAIN .pretrained_dis_model_path
258
+ ):
259
+ discriminator_model.load_dict(paddle.load(cfg.TRAIN .pretrained_dis_model_path))
260
+
261
+ # set Loss
262
+ generator_funcs = Cifar10GenFuncs(
263
+ ** cfg[" LOSS" ][" gen" ], discriminator_model = discriminator_model
264
+ )
265
+ discriminator_funcs = Cifar10DisFuncs(
266
+ ** cfg[" LOSS" ][" dis" ], discriminator_model = discriminator_model
267
+ )
268
+
269
+ # set dataloader
270
+ inputs, labels = load_cifar10(** cfg[" DATA" ])
271
+ dataloader_cfg = {
272
+ " dataset" : {
273
+ " name" : cfg[" EVAL" ][" dataset" ][" name" ],
274
+ " input" : inputs,
275
+ " label" : labels,
276
+ },
277
+ " sampler" : {
278
+ ** cfg[" TRAIN" ][" sampler" ],
279
+ },
280
+ " batch_size" : cfg[" TRAIN" ][" batch_size" ],
281
+ " use_shared_memory" : cfg[" TRAIN" ][" use_shared_memory" ],
282
+ " num_workers" : cfg[" TRAIN" ][" num_workers" ],
283
+ " drop_last" : cfg[" TRAIN" ][" drop_last" ],
284
+ }
285
+
286
+ # set constraint
287
+ constraint_generator = ppsci.constraint.SupervisedConstraint(
288
+ dataloader_cfg = dataloader_cfg,
289
+ loss = ppsci.loss.FunctionalLoss(generator_funcs.loss),
290
+ output_expr = {" labels" : lambda out : out[" labels" ]},
291
+ name = " constraint_generator" ,
292
+ )
293
+ constraint_generator_dict = {constraint_generator.name: constraint_generator}
294
+
295
+ constraint_discriminator = ppsci.constraint.SupervisedConstraint(
296
+ dataloader_cfg = dataloader_cfg,
297
+ loss = ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
298
+ output_expr = {" labels" : lambda out : out[" labels" ]},
299
+ name = " constraint_discriminator" ,
300
+ )
301
+ constraint_discriminator_dict = {
302
+ constraint_discriminator.name: constraint_discriminator
303
+ }
304
+
305
+ # set optimizer
306
+ lr_scheduler_generator = Linear(** cfg[" TRAIN" ][" lr_scheduler_gen" ])()
307
+ lr_scheduler_discriminator = Linear(** cfg[" TRAIN" ][" lr_scheduler_dis" ])()
308
+
309
+ optimizer_generator = ppsci.optimizer.Adam(
310
+ learning_rate = lr_scheduler_generator,
311
+ beta1 = cfg[" TRAIN" ][" optimizer" ][" beta1" ],
312
+ beta2 = cfg[" TRAIN" ][" optimizer" ][" beta2" ],
313
+ )
314
+ optimizer_discriminator = ppsci.optimizer.Adam(
315
+ learning_rate = lr_scheduler_discriminator,
316
+ beta1 = cfg[" TRAIN" ][" optimizer" ][" beta1" ],
317
+ beta2 = cfg[" TRAIN" ][" optimizer" ][" beta2" ],
318
+ )
319
+ optimizer_generator = optimizer_generator(generator_model)
320
+ optimizer_discriminator = optimizer_discriminator(discriminator_model)
321
+
322
+ # initialize solver
323
+ solver_generator = ppsci.solver.Solver(
324
+ model = generator_model,
325
+ output_dir = os.path.join(cfg.output_dir, " generator" ),
326
+ constraint = constraint_generator_dict,
327
+ optimizer = optimizer_generator,
328
+ epochs = cfg.TRAIN .epochs_gen,
329
+ iters_per_epoch = cfg.TRAIN .iters_per_epoch_gen,
330
+ pretrained_model_path = cfg.TRAIN .pretrained_gen_model_path,
331
+ )
332
+ solver_discriminator = ppsci.solver.Solver(
333
+ model = discriminator_model,
334
+ output_dir = os.path.join(cfg.output_dir, " discriminator" ),
335
+ constraint = constraint_discriminator_dict,
336
+ optimizer = optimizer_discriminator,
337
+ epochs = cfg.TRAIN .epochs_dis,
338
+ iters_per_epoch = cfg.TRAIN .iters_per_epoch_dis,
339
+ pretrained_model_path = cfg.TRAIN .pretrained_dis_model_path,
340
+ )
341
+
342
+ # train
343
+ for i in range (cfg.TRAIN .epochs):
344
+ logger.message(f " \n Epoch: { i + 1 } \n " )
345
+ optimizer_discriminator.clear_grad()
346
+ solver_discriminator.train()
347
+ optimizer_generator.clear_grad()
348
+ solver_generator.train()
349
+
350
+ # save model weight
351
+ paddle.save(
352
+ generator_model.state_dict(),
353
+ os.path.join(cfg.output_dir, " model_generator.pdparams" ),
354
+ )
355
+ paddle.save(
356
+ discriminator_model.state_dict(),
357
+ os.path.join(cfg.output_dir, " model_discriminator.pdparams" ),
358
+ )
359
+
231
360
```
232
361
233
362
# 六、测试验收的考量
0 commit comments