Commit 25473a8

add ema for sd (#3755)
1 parent 22ae267 commit 25473a8

File tree

  • ppdiffusers/ppdiffusers/models

1 file changed: +104 -0 lines changed
ppdiffusers/ppdiffusers/models/ema.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn


class LitEma(nn.Layer):
    """
    Exponential Moving Average (EMA) of model parameters.

    Parameters:
        model: The model whose parameters the EMA tracks.
        decay: The exponential decay. Default 0.9999.
        use_num_updates: Whether to use the number of updates when computing
            averages.
    """

    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay',
                             paddle.to_tensor(decay, dtype=paddle.float32))
        self.register_buffer(
            'num_updates',
            paddle.to_tensor(0, dtype=paddle.int64)
            if use_num_updates else paddle.to_tensor(-1, dtype=paddle.int64))

        for name, p in model.named_parameters():
            if not p.stop_gradient:
                # remove '.' because the character is not allowed in buffer names
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach())

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay,
                        (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with paddle.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if not m_param[key].stop_gradient:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname].scale_(decay)
                    shadow_params[sname].add_(m_param[key] * one_minus_decay)
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if not m_param[key].stop_gradient:
                m_param[key].copy_(shadow_params[self.m_name2s_name[key]], True)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `paddle.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `paddle.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.copy_(c_param, True)

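The commit adds only the module itself; how it gets wired into a training loop is not part of this diff. A minimal sketch of typical usage, assuming the module is importable from ppdiffusers.models.ema and using a toy nn.Linear in place of the real diffusion model, might look like this:

    import paddle
    from paddle import nn
    from ppdiffusers.models.ema import LitEma

    # Toy stand-in for the real model; any nn.Layer with trainable parameters works.
    model = nn.Linear(4, 4)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=1e-3)
    ema = LitEma(model, decay=0.9999)

    for step in range(10):
        x = paddle.randn([8, 4])
        loss = ((model(x) - x) ** 2).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        ema(model)  # update the shadow weights after every optimizer step

    # Swap in the EMA weights for evaluation or saving, then put the raw weights back.
    ema.store(model.parameters())
    ema.copy_to(model)
    # ... run validation or save the checkpoint here ...
    ema.restore(model.parameters())

The store / copy_to / restore sequence matches the pattern described in the restore docstring: stash the live weights, evaluate with the EMA weights, then continue training from the stashed weights.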