
Commit aa7623e

【Hackathon 8th No.23】RFC: reproduction of the paper "Improved Training of Wasserstein GANs" (#1114)

* Update 20250404_add_wgan_gp_for_paddlescience.md
* Add files via upload
* Add files via upload

1 parent b706297 commit aa7623e

File tree

1 file changed: +239 -110 lines changed

rfcs/Science/20250404_add_wgan_gp_for_paddlescience.md

Lines changed: 239 additions & 110 deletions
@@ -1,11 +1,11 @@
 # Reproducing the WGAN-GP Model in PaddleScience
 
-| Task name | Reproducing the WGAN-GP model in PaddleScience |
-| --- | --- |
-| Author | robinbg |
-| Submitted | 2025-04-04 |
-| Version | V1.0 |
-| PaddlePaddle dependency | develop |
+| Task name | Reproducing the WGAN-GP model in PaddleScience |
+| --- |-------------------------------------------|
+| Authors | robinbg, XvLingWYY |
+| Submitted | 2025-04-04 |
+| Version | V1.0 |
+| PaddlePaddle dependency | develop |
 | Filename | 20250404_add_wgan_gp_for_paddlescience.md |
 
 # I. Overview
@@ -75,44 +75,106 @@ The core of WGAN-GP lies in its loss function and the computation of the gradient-penalty term.
 
 ### 1.1 Loss functions
 ```python
-# Generator loss
-def generator_loss(fake_output):
-    return -paddle.mean(fake_output)
-
-# Discriminator loss (with gradient penalty)
-def discriminator_loss(real_output, fake_output, gradient_penalty):
-    return paddle.mean(fake_output) - paddle.mean(real_output) + LAMBDA * gradient_penalty
+# Generator loss for the CIFAR-10 experiment
+class Cifar10GenFuncs:
+    """
+    Loss functions for the cifar10 generator.
+
+    Args:
+        discriminator_model: discriminator model
+        acgan_scale_g: scale of the ACGAN loss for the generator
+    """
+
+    def __init__(
+        self,
+        discriminator_model,
+        acgan_scale_g=0.1,
+    ):
+        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
+        self.acgan_scale_g = acgan_scale_g
+        self.discriminator_model = discriminator_model
+
+    def loss(self, output_dict: Dict, *args):
+        fake_image = output_dict["fake_data"]
+        labels = output_dict["labels"]
+        outputs = self.discriminator_model({"data": fake_image, "labels": labels})
+        disc_fake, disc_fake_acgan = outputs["disc_fake"], outputs["disc_acgan"]
+        gen_cost = -paddle.mean(disc_fake)
+        if disc_fake_acgan is not None:
+            gen_acgan_cost = self.crossEntropyLoss(disc_fake_acgan, labels)
+            gen_cost += self.acgan_scale_g * gen_acgan_cost
+        return {"loss_g": gen_cost}
+
+# Discriminator loss for the CIFAR-10 experiment
+class Cifar10DisFuncs:
+    """
+    Loss functions for the cifar10 discriminator.
+
+    Args:
+        discriminator_model: discriminator model
+        acgan_scale: scale of the ACGAN loss for the discriminator
+    """
+
+    def __init__(self, discriminator_model, acgan_scale):
+        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
+        self.acgan_scale = acgan_scale
+        self.discriminator_model = discriminator_model
+
+    def loss(self, output_dict: Dict, label_dict: Dict, *args):
+        fake_image = output_dict["fake_data"]
+        real_image = label_dict["real_data"]
+        labels = output_dict["labels"]
+        disc_fake = self.discriminator_model({"data": fake_image, "labels": labels})[
+            "disc_fake"
+        ]
+        out = self.discriminator_model({"data": real_image, "labels": labels})
+        disc_real, disc_real_acgan = out["disc_fake"], out["disc_acgan"]
+        gradient_penalty = self.compute_gradient_penalty(real_image, fake_image, labels)
+        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
+        disc_wgan = disc_cost + gradient_penalty
+        if disc_real_acgan is not None:
+            disc_acgan_cost = self.crossEntropyLoss(disc_real_acgan, labels)
+            disc_acgan = disc_acgan_cost.sum()
+            disc_cost = disc_wgan + (self.acgan_scale * disc_acgan)
+        else:
+            disc_cost = disc_wgan
+        return {"loss_d": disc_cost}
+
+    def compute_gradient_penalty(self, real_data, fake_data, labels):
+        differences = fake_data - real_data
+        alpha = paddle.rand([fake_data.shape[0], 1])
+        interpolates = real_data + (alpha * differences)
+        gradients = paddle.grad(
+            outputs=self.discriminator_model({"data": interpolates, "labels": labels})[
+                "disc_fake"
+            ],
+            inputs=interpolates,
+            create_graph=True,
+            retain_graph=False,
+        )[0]
+        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
+        gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2)
+        return gradient_penalty
 ```
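Since the `loss` methods above are plain callables over PaddleScience's output/label dictionaries, they can be handed straight to `ppsci.loss.FunctionalLoss`, as §5.2 below does. A minimal wiring sketch, assuming an already constructed discriminator and an illustrative `acgan_scale` value:

```python
# Minimal sketch: wrap the loss callables above for use in constraints.
# discriminator_model and acgan_scale=1.0 are assumed/illustrative here.
import ppsci
from functions import Cifar10DisFuncs, Cifar10GenFuncs


def make_losses(discriminator_model, acgan_scale=1.0):
    gen_funcs = Cifar10GenFuncs(discriminator_model=discriminator_model)
    dis_funcs = Cifar10DisFuncs(
        discriminator_model=discriminator_model, acgan_scale=acgan_scale
    )
    # FunctionalLoss wraps a callable(output_dict, label_dict, ...) that
    # returns a dict of named loss terms, as both .loss methods above do.
    loss_g = ppsci.loss.FunctionalLoss(gen_funcs.loss)
    loss_d = ppsci.loss.FunctionalLoss(dis_funcs.loss)
    return loss_g, loss_d
```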
 
 ### 1.2 Gradient-penalty computation
 ```python
-def gradient_penalty(discriminator, real_samples, fake_samples):
-    # Sample a random interpolation coefficient
-    alpha = paddle.rand(shape=[real_samples.shape[0], 1, 1, 1])
-
-    # Interpolate between real and generated samples
-    interpolates = real_samples + alpha * (fake_samples - real_samples)
-    interpolates.stop_gradient = False
-
-    # Discriminator output on the interpolated samples
-    disc_interpolates = discriminator(interpolates)
-
-    # Compute the gradients
-    gradients = paddle.grad(
-        outputs=disc_interpolates,
-        inputs=interpolates,
-        grad_outputs=paddle.ones_like(disc_interpolates),
-        create_graph=True,
-        retain_graph=True
-    )[0]
-
-    # Gradient norm
-    gradients_norm = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=[1, 2, 3]))
-
-    # Gradient penalty
-    gradient_penalty = paddle.mean(paddle.square(gradients_norm - 1.0))
-
-    return gradient_penalty
+# Gradient-penalty computation in the CIFAR-10 discriminator
+def compute_gradient_penalty(self, real_data, fake_data, labels):
+    differences = fake_data - real_data
+    alpha = paddle.rand([fake_data.shape[0], 1])
+    interpolates = real_data + (alpha * differences)
+    gradients = paddle.grad(
+        outputs=self.discriminator_model({"data": interpolates, "labels": labels})[
+            "disc_fake"
+        ],
+        inputs=interpolates,
+        create_graph=True,
+        retain_graph=False,
+    )[0]
+    slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
+    gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2)
+    return gradient_penalty
 ```
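For reference, both snippets implement the WGAN-GP objective of Gulrajani et al. (2017), with the penalty coefficient $\lambda = 10$ hard-coded:

$$
L_D = \mathbb{E}_{\tilde{x}\sim p_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x\sim p_r}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}\sim p_{\hat{x}}}\Big[\big(\lVert\nabla_{\hat{x}}D(\hat{x})\rVert_2 - 1\big)^2\Big],
\qquad
L_G = -\mathbb{E}_{\tilde{x}\sim p_g}\big[D(\tilde{x})\big]
$$

where $\hat{x} = x + \alpha(\tilde{x} - x)$ with $\alpha \sim U[0,1]$ is exactly the `interpolates` tensor above; the cross-entropy terms in the CIFAR-10 losses add an auxiliary ACGAN classification objective on top of this.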
 
 ## 2. Network architecture
@@ -135,34 +197,12 @@ WGAN-GP's training procedure differs from that of a standard GAN; the main differences are:
 
 ```python
 # Example training loop
-for iteration in range(ITERATIONS):
-    # Train the discriminator
-    for _ in range(CRITIC_ITERS):
-        real_data = next(data_iterator)
-        noise = paddle.randn([BATCH_SIZE, NOISE_DIM])
-
-        # Discriminator loss
-        fake_data = generator(noise)
-        real_output = discriminator(real_data)
-        fake_output = discriminator(fake_data)
-        gp = gradient_penalty(discriminator, real_data, fake_data)
-        d_loss = discriminator_loss(real_output, fake_output, gp)
-
-        # Update the discriminator parameters
-        d_optimizer.clear_grad()
-        d_loss.backward()
-        d_optimizer.step()
-
-    # Train the generator
-    noise = paddle.randn([BATCH_SIZE, NOISE_DIM])
-    fake_data = generator(noise)
-    fake_output = discriminator(fake_data)
-    g_loss = generator_loss(fake_output)
-
-    # Update the generator parameters
-    g_optimizer.clear_grad()
-    g_loss.backward()
-    g_optimizer.step()
+for i in range(cfg.TRAIN.epochs):
+    logger.message(f"\nEpoch: {i + 1}\n")
+    optimizer_discriminator.clear_grad()
+    solver_discriminator.train()
+    optimizer_generator.clear_grad()
+    solver_generator.train()
 ```
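The paper's $n_{\text{critic}} = 5$ discriminator updates per generator update are expressed here through per-epoch iteration counts rather than an inner loop. A hypothetical configuration fragment follows; the key names mirror the `cfg.TRAIN` fields used above, while the values are purely illustrative:

```python
# Hypothetical cfg.TRAIN block (illustrative values, not from the commit).
# Each pass of the loop above trains the discriminator for
# epochs_dis * iters_per_epoch_dis steps and the generator for
# epochs_gen * iters_per_epoch_gen steps: 500 vs. 100 below, i.e. a 5:1 ratio.
TRAIN = dict(
    epochs=200,               # outer alternation count
    epochs_dis=1,             # epochs per solver_discriminator.train() call
    iters_per_epoch_dis=500,  # discriminator iterations per epoch
    epochs_gen=1,             # epochs per solver_generator.train() call
    iters_per_epoch_gen=100,  # generator iterations per epoch
)
```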
 
 ## 4. Evaluation metrics
@@ -171,11 +211,8 @@ for iteration in range(ITERATIONS):
 ### 4.1 Inception Score (IS)
 Used to evaluate the quality and diversity of generated images.
 
-### 4.2 Fréchet Inception Distance (FID)
-Measures the distance between the distributions of generated and real images.
-
-### 4.3 Visualization of generated samples
-Periodically save generated samples to visually assess model performance.
+### 4.2 Visualization of generated samples
+Save generated samples to visually assess model performance.
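For reference, the Inception Score in §4.1 for a generator $g$ is defined as

$$
\mathrm{IS}(g) = \exp\Big(\mathbb{E}_{x\sim p_g}\big[D_{\mathrm{KL}}\big(p(y\mid x)\,\Vert\,p(y)\big)\big]\Big)
$$

where $p(y\mid x)$ is the class posterior of a pretrained Inception network on a generated sample and $p(y)$ its marginal over generated samples; higher is better.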
 
 ## 5. Integration with PaddleScience
 We will design a modular implementation that integrates easily with PaddleScience:
@@ -184,50 +221,142 @@ for iteration in range(ITERATIONS):
 ```
 PaddleScience/
 └── examples/
-    └── wgan_gp/
-        ├── __init__.py
-        ├── utils/
-        │   ├── __init__.py
-        │   ├── losses.py            # loss functions
-        │   ├── metrics.py           # evaluation metrics
-        │   └── visualization.py     # visualization tools
-        ├── models/
-        │   ├── __init__.py
-        │   ├── base_gan.py          # GAN base class
-        │   ├── wgan.py              # WGAN implementation
-        │   └── wgan_gp.py           # WGAN-GP implementation
-        └── cases/
-            ├── wgan_gp_toy.py       # toy-dataset example
-            ├── wgan_gp_mnist.py     # MNIST example
-            └── wgan_gp_cifar.py     # CIFAR-10 example
+    └── wgangp/
+        ├── conf/
+        │   ├── wgangp_cifar10.yaml  # CIFAR-10 config
+        │   ├── wgangp_mnist.yaml    # MNIST config
+        │   └── wgangp_toy.yaml      # toy-dataset config
+        ├── functions.py             # loss functions, evaluation metrics, visualization tools
+        ├── wgangp_cifar10.py        # CIFAR-10 example
+        ├── wgangp_cifar10_model.py  # CIFAR-10 experiment models
+        ├── wgangp_mnist.py          # MNIST example
+        ├── wgangp_mnist_model.py    # MNIST experiment models
+        ├── wgangp_toy.py            # toy-dataset example
+        └── wgangp_toy_model.py      # toy-dataset experiment models
 ```
 
 ### 5.2 Interface design
 Provide a clean, unified interface that is easy to use:
 
 ```python
 # Example usage
-from models.wgan_gp import WGAN_GP
-
-# Create the model
-model = WGAN_GP(
-    generator=generator_network,
-    discriminator=discriminator_network,
-    lambda_gp=10.0,
-    critic_iters=5
-)
-
-# Train the model
-model.train(
-    train_data=dataset,
-    batch_size=64,
-    iterations=100000,
-    g_learning_rate=1e-4,
-    d_learning_rate=1e-4
-)
-
-# Generate samples
-samples = model.generate(num_samples=100)
+import os
+
+import paddle
+import ppsci
+from functions import Cifar10DisFuncs
+from functions import Cifar10GenFuncs
+from functions import load_cifar10
+from omegaconf import DictConfig
+from ppsci.optimizer.lr_scheduler import Linear
+from ppsci.utils import logger
+from wgangp_cifar10_model import WganGpCifar10Discriminator
+from wgangp_cifar10_model import WganGpCifar10Generator
+
+
+def train(cfg: DictConfig):
+    # set models
+    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
+    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
+    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
+        cfg.TRAIN.pretrained_dis_model_path
+    ):
+        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))
+
+    # set loss functions
+    generator_funcs = Cifar10GenFuncs(
+        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
+    )
+    discriminator_funcs = Cifar10DisFuncs(
+        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
+    )
+
+    # set dataloader
+    inputs, labels = load_cifar10(**cfg["DATA"])
+    dataloader_cfg = {
+        "dataset": {
+            "name": cfg["EVAL"]["dataset"]["name"],
+            "input": inputs,
+            "label": labels,
+        },
+        "sampler": {
+            **cfg["TRAIN"]["sampler"],
+        },
+        "batch_size": cfg["TRAIN"]["batch_size"],
+        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
+        "num_workers": cfg["TRAIN"]["num_workers"],
+        "drop_last": cfg["TRAIN"]["drop_last"],
+    }
+
+    # set constraints
+    constraint_generator = ppsci.constraint.SupervisedConstraint(
+        dataloader_cfg=dataloader_cfg,
+        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
+        output_expr={"labels": lambda out: out["labels"]},
+        name="constraint_generator",
+    )
+    constraint_generator_dict = {constraint_generator.name: constraint_generator}
+
+    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
+        dataloader_cfg=dataloader_cfg,
+        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
+        output_expr={"labels": lambda out: out["labels"]},
+        name="constraint_discriminator",
+    )
+    constraint_discriminator_dict = {
+        constraint_discriminator.name: constraint_discriminator
+    }
+
+    # set optimizers
+    lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
+    lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()
+
+    optimizer_generator = ppsci.optimizer.Adam(
+        learning_rate=lr_scheduler_generator,
+        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
+        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
+    )
+    optimizer_discriminator = ppsci.optimizer.Adam(
+        learning_rate=lr_scheduler_discriminator,
+        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
+        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
+    )
+    optimizer_generator = optimizer_generator(generator_model)
+    optimizer_discriminator = optimizer_discriminator(discriminator_model)
+
+    # initialize solvers
+    solver_generator = ppsci.solver.Solver(
+        model=generator_model,
+        output_dir=os.path.join(cfg.output_dir, "generator"),
+        constraint=constraint_generator_dict,
+        optimizer=optimizer_generator,
+        epochs=cfg.TRAIN.epochs_gen,
+        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
+        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
+    )
+    solver_discriminator = ppsci.solver.Solver(
+        model=discriminator_model,
+        output_dir=os.path.join(cfg.output_dir, "discriminator"),
+        constraint=constraint_discriminator_dict,
+        optimizer=optimizer_discriminator,
+        epochs=cfg.TRAIN.epochs_dis,
+        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
+        pretrained_model_path=cfg.TRAIN.pretrained_dis_model_path,
+    )
+
+    # train
+    for i in range(cfg.TRAIN.epochs):
+        logger.message(f"\nEpoch: {i + 1}\n")
+        optimizer_discriminator.clear_grad()
+        solver_discriminator.train()
+        optimizer_generator.clear_grad()
+        solver_generator.train()
+
+    # save model weights
+    paddle.save(
+        generator_model.state_dict(),
+        os.path.join(cfg.output_dir, "model_generator.pdparams"),
+    )
+    paddle.save(
+        discriminator_model.state_dict(),
+        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
+    )
```
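A minimal driver for the `train` function above, following the entry-point pattern PaddleScience examples generally use; the Hydra decorator arguments and the `cfg.mode` switch are assumptions based on the `conf/wgangp_cifar10.yaml` file in the 5.1 tree, not code from this commit:

```python
# Hypothetical entry point (e.g. the tail of wgangp_cifar10.py). Assumes the
# train() function above and a Hydra config at ./conf/wgangp_cifar10.yaml.
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_cifar10.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    else:
        raise ValueError(f"cfg.mode should be 'train', but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
```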
 
 # VI. Testing and acceptance considerations
