-
Notifications
You must be signed in to change notification settings - Fork 215
【Hackathon 8th No.12】在 PaddleScience 中实现 SOAP 优化器 #1102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thanks for your contribution! |
为验证 SOAP 实现正确性,基于 MLP demo + SOAP 对torch与paddle 的 loss 进行对比 前 20iter loss对比如下: λ beinggod-workstation /workspace/hackathon/SOAP/SOAP python soap_paddle.py
grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
W0313 05:29:07.938906 1188667 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.6, Runtime API Version: 11.8
W0313 05:29:07.939640 1188667 gpu_resources.cc:164] device: 0, cuDNN Version: 8.9.
pddle param name: weight, mean: 0.006312752142548561, sum: 103.42813110351562
torch param name: weight, mean: 0.006312752142548561, sum: 103.42813110351562
[ITERATION] 1/20 loss paddle: 12.870917, loss torch: 12.870915
[ITERATION] 2/20 loss paddle: 15.275839, loss torch: 15.275839
[ITERATION] 3/20 loss paddle: 9.223783, loss torch: 9.223782
[ITERATION] 4/20 loss paddle: 9.283824, loss torch: 9.283824
[ITERATION] 5/20 loss paddle: 7.028357, loss torch: 7.028356
[ITERATION] 6/20 loss paddle: 7.973010, loss torch: 7.973009
[ITERATION] 7/20 loss paddle: 11.567774, loss torch: 11.567776
[ITERATION] 8/20 loss paddle: 5.823763, loss torch: 5.823766
[ITERATION] 9/20 loss paddle: 12.174599, loss torch: 12.174601
[ITERATION] 10/20 loss paddle: 8.206469, loss torch: 8.206469
[ITERATION] 11/20 loss paddle: 7.991440, loss torch: 7.991440
[ITERATION] 12/20 loss paddle: 7.984601, loss torch: 7.984600
[ITERATION] 13/20 loss paddle: 9.944571, loss torch: 9.944571
[ITERATION] 14/20 loss paddle: 10.886424, loss torch: 10.886503
[ITERATION] 15/20 loss paddle: 8.474144, loss torch: 8.474236
[ITERATION] 16/20 loss paddle: 7.847643, loss torch: 7.847647
[ITERATION] 17/20 loss paddle: 10.080597, loss torch: 10.080593
[ITERATION] 18/20 loss paddle: 9.120567, loss torch: 9.120558
[ITERATION] 19/20 loss paddle: 7.173148, loss torch: 7.173133
[ITERATION] 20/20 loss paddle: 7.039429, loss torch: 7.039426
λ beinggod-workstation /workspace/hackathon/SOAP/SOAP loss误差在1e-5,可以认为二者实现是对齐的 测试代码: import paddle
import paddle.device
from itertools import chain
from collections import defaultdict
import paddle.optimizer
import paddle.utils
import torch
import numpy as np
from itertools import chain
import torch.utils
import torch.utils.data
# Parts of the code are modifications of Pypaddle's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_paddle/galore_projector.py
seed = 1234
paddle.device.set_device('gpu')
paddle.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
class SOAP_paddle(paddle.optimizer.Optimizer):
"""
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
Parameters:
params (`list|tuple`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.003):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
Adam's betas parameters (b1, b2).
shampoo_beta (`float`, *optional*, defaults to -1):
If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
eps (`float`, *optional*, defaults to 1e-08):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
precondition_frequency (`int`, *optional*, defaults to 10):
How often to update the preconditioner.
max_precond_dim (`int`, *optional*, defaults to 10000):
Maximum dimension of the preconditioner.
Set to 10000, so that we exclude most common vocab sizes while including layers.
merge_dims (`bool`, *optional*, defaults to `False`):
Whether or not to merge dimensions of the preconditioner.
precondition_1d (`bool`, *optional*, defaults to `False`):
Whether or not to precondition 1D gradients.
normalize_grads (`bool`, *optional*, defaults to `False`):
Whether or not to normalize gradients per layer.
Helps at large precondition_frequency (~100 in our experiments),
but hurts performance at small precondition_frequency (~10 in our experiments).
data_format (`str`, *optional*, defaults to `channels_first`):
Data format of the input for convolutional layers.
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias correction in Adam.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
"""
def __init__(
self,
params,
lr: float = 3e-3,
betas=(0.95, 0.95),
shampoo_beta: float= -1,
eps: float = 1e-8,
weight_decay: float = 0.01,
precondition_frequency: int=10,
max_precond_dim: int=10000, #
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
precondition_1d: bool = False,
normalize_grads: bool = False,
data_format: str = "channels_first",
correct_bias: bool = True,
name: str = None,
):
self._betas = betas
self._shampoo_beta = shampoo_beta
self._eps = eps
self._precondition_frequency = precondition_frequency
self._max_precond_dim = max_precond_dim
self._merge_dims = merge_dims
self._precondition_1d = precondition_1d
self._normalize_grads = normalize_grads
self._correct_bias = correct_bias
self._weight_decay = weight_decay
self.state = defaultdict(dict)
super().__init__(learning_rate=lr,
parameters=params,
weight_decay=weight_decay,
name=name)
if isinstance(self._parameter_list[0],dict):
raise TypeError(
"The parameter groups is not supported on SOAP optimizer."
)
self._data_format = data_format
def merge_dims(self, grad, max_precond_dim):
"""
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
"""
assert self._data_format in ["channels_first", "channels_last"]
if self._data_format == "channels_last" and grad.dim() == 4:
grad = grad.transpose(0, 3, 1, 2)
shape = grad.shape
new_shape = []
curr_shape = 1
for sh in shape:
temp_shape = curr_shape * sh
if temp_shape > max_precond_dim:
if curr_shape > 1:
new_shape.append(curr_shape)
curr_shape = sh
else:
new_shape.append(sh)
curr_shape = 1
else:
curr_shape = temp_shape
if curr_shape > 1 or len(new_shape)==0:
new_shape.append(curr_shape)
new_grad = grad.reshape(new_shape)
return new_grad
@paddle.base.framework.non_static_only
def step(self, closure = None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
with paddle.no_grad():
if closure is None:
loss = None
else:
closure = paddle.enable_grad()(closure)
loss = closure()
for p in self._parameter_list:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if "step" not in state:
state["step"] = 0
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = paddle.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = paddle.zeros_like(grad)
if 'Q' not in state:
self.init_preconditioner(
grad,
state,
precondition_frequency=self._precondition_frequency,
precondition_1d=self._precondition_1d,
shampoo_beta=(self._shampoo_beta if self._shampoo_beta >= 0 else self._betas[1]),
max_precond_dim=self._max_precond_dim,
merge_dims=self._merge_dims,
)
self.update_preconditioner(grad, state,
max_precond_dim=self._max_precond_dim,
merge_dims=self._merge_dims,
precondition_1d=self._precondition_1d)
continue # first step is skipped so that we never use the current gradients in the projection.
# Projecting gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
grad_projected = self.project(grad, state, merge_dims=self._merge_dims,
max_precond_dim=self._max_precond_dim)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = paddle.to_tensor(self._betas)
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.multiply_(beta1).add_((1.0 - beta1)*grad_projected)
exp_avg_sq.multiply_(beta2).add_((1.0-beta2)*grad_projected.square())
denom = exp_avg_sq.sqrt().add_(paddle.to_tensor(self._eps))
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
# exp_avg_projected = self.project(exp_avg, state, merge_dims=self._merge_dims"],
# max_precond_dim=self._max_precond_dim'])
exp_avg_projected = exp_avg
lr = self._learning_rate
step_size = lr
if self._correct_bias:
bias_correction1 = 1.0 - beta1 ** (state["step"])
bias_correction2 = 1.0 - beta2 ** (state["step"])
step_size = step_size * (bias_correction2 ** .5) / bias_correction1
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
# to the original space
norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=self._merge_dims,
max_precond_dim=self._max_precond_dim)
if self._normalize_grads:
norm_grad = norm_grad / (1e-30+paddle.mean(norm_grad**2)**0.5)
p.add_(-step_size * norm_grad)
# From AdamW code: Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if self._weight_decay > 0.0:
p.add_((-lr * self._weight_decay) * p)
# Update is done after the gradient step to avoid using current gradients in the projection.
self.update_preconditioner(grad, state,
max_precond_dim=self._max_precond_dim,
merge_dims=self._merge_dims,
precondition_1d=self._precondition_1d)
return loss
def init_preconditioner(self, grad, state, precondition_frequency=10,
shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
merge_dims=False):
"""
Initializes the preconditioner matrices (L and R in the paper).
"""
state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
if grad.dim() == 1:
if not precondition_1d or grad.shape[0] > max_precond_dim:
state['GG'].append([])
else:
state['GG'].append(paddle.zeros([grad.shape[0], grad.shape[0]]))
else:
if merge_dims:
grad = self.merge_dims(grad, max_precond_dim)
for sh in grad.shape:
if sh > max_precond_dim:
state['GG'].append([])
else:
state['GG'].append(paddle.zeros([sh, sh]))
state['Q'] = None # Will hold all the eigenbases of the preconditioner.
state['precondition_frequency'] = precondition_frequency
state['shampoo_beta'] = shampoo_beta
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient to the eigenbases of the preconditioner.
"""
original_shape = grad.shape
if merge_dims:
if grad.dim() == 4 and self._data_format == 'channels_last':
transposed_shape = grad.transpose(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state['Q']:
if len(mat) > 0:
grad = paddle.tensordot(
grad,
mat,
axes=[[0], [0]],
)
else:
transpose_order = list(range(1, len(grad.shape))) + [0]
grad = grad.transpose(transpose_order)
if merge_dims:
if self._data_format == 'channels_last' and len(original_shape) == 4:
grad = grad.reshape(transposed_shape).transpose(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def update_preconditioner(self, grad, state,
max_precond_dim=10000, merge_dims=False, precondition_1d=False):
"""
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
"""
if state["Q"] is not None:
state["exp_avg"] = self.project_back(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
if grad.dim() == 1:
if precondition_1d and grad.shape[0] <= max_precond_dim:
state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
else:
if merge_dims:
new_grad = self.merge_dims(grad, max_precond_dim)
for idx, sh in enumerate(new_grad.shape):
if sh <= max_precond_dim:
outer_product = paddle.tensordot(
new_grad,
new_grad,
axes=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
)
state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
else:
for idx, sh in enumerate(grad.shape):
if sh <= max_precond_dim:
outer_product = paddle.tensordot(
grad,
grad,
# Contracts across all dimensions except for k.
axes=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
)
state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
if state['Q'] is None:
state['Q'] = self.get_orthogonal_matrix(state['GG'])
if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
# state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims)
if state["step"] > 0:
state["exp_avg"] = self.project(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient back to the original space.
"""
original_shape = grad.shape
if merge_dims:
if self._data_format == 'channels_last' and grad.dim() == 4:
transposed_shape = grad.transpose(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state['Q']:
if len(mat) > 0:
grad = paddle.tensordot(
grad,
mat,
axes=[[0], [1]],
)
else:
transpose_order = list(range(1, len(grad.shape))) + [0]
grad = grad.transpose(transpose_order)
if merge_dims:
if self._data_format == 'channels_last' and len(original_shape) == 4:
grad = grad.reshape(transposed_shape).transpose(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def get_orthogonal_matrix(self, mat):
"""
Computes the eigenbases of the preconditioner using paddle.linalg.eigh decomposition.
"""
matrix = []
for m in mat:
if len(m) == 0:
matrix.append([])
continue
if m.data.dtype != paddle.float32:
float_data = False
original_type = m.data.dtype
original_device = m.data.place
matrix.append(m.data.to(paddle.float32))
else:
float_data = True
matrix.append(m.data)
final = []
for m in matrix:
if len(m) == 0:
final.append([])
continue
# try:
# _, Q = paddle.linalg.eigh(m+1e-30*paddle.eye(m.shape[0]))
# except:
# _, Q = paddle.linalg.eigh(m.to(paddle.float64)+1e-30*paddle.eye(m.shape[0]))
# Q = Q.to(m.dtype)
_, Q = paddle.linalg.eigh(m+1e-30*paddle.eye(m.shape[0]))
Q = paddle.flip(Q, [1])
if not float_data:
Q = Q.to(original_device, dtype=original_type)
final.append(Q)
return final
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
"""
Computes the eigenbases of the preconditioner using one round of power iteration
followed by paddle.linalg.qr decomposition.
"""
precond_list = state['GG']
orth_list = state['Q']
matrix = []
orth_matrix = []
for m,o in zip(precond_list, orth_list):
if len(m) == 0:
matrix.append([])
orth_matrix.append([])
continue
if m.data.dtype != paddle.float32:
float_data = False
original_type = m.data.dtype
original_device = m.data.place
matrix.append(m.data.to(paddle.float32))
orth_matrix.append(o.data.to(paddle.float32))
else:
float_data = True
matrix.append(m.data.to(paddle.float32))
orth_matrix.append(o.data.to(paddle.float32))
orig_shape = state['exp_avg_sq'].shape
if self._data_format == 'channels_last' and len(orig_shape) == 4:
transposed_shape = state['exp_avg_sq'].transpose(0, 3, 1, 2).shape
if merge_dims:
exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
else:
exp_avg_sq = state['exp_avg_sq']
final = []
for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
if len(m)==0:
final.append([])
continue
est_eig = paddle.diag(o.T @ m @ o)
sort_idx = paddle.argsort(est_eig, descending=True)
exp_avg_sq = exp_avg_sq.index_select(sort_idx, ind)
o = o[:,sort_idx]
power_iter = m @ o
Q, _ = paddle.linalg.qr(power_iter)
if not float_data:
Q = Q.to(original_device, dtype=original_type)
final.append(Q)
if merge_dims:
if self._data_format == 'channels_last' and len(orig_shape) == 4:
exp_avg_sq = exp_avg_sq.reshape(transposed_shape).transpose(0, 2, 3, 1)
else:
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
state['exp_avg_sq'] = exp_avg_sq
return final
class MLP_paddle(paddle.nn.Layer):
def __init__(self, in_features, out_features):
super().__init__()
self._internal_weight = paddle.randn([out_features, in_features])
self.weight = paddle.create_parameter(shape=self._internal_weight.shape,
dtype=self._internal_weight.dtype,
default_initializer=paddle.nn.initializer.Assign(self._internal_weight))
self.weight.stop_gradient = False
def forward(self, inp):
return paddle.matmul(inp, self.weight.T)
# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
class SOAP(torch.optim.Optimizer):
"""
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.003):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
Adam's betas parameters (b1, b2).
shampoo_beta (`float`, *optional*, defaults to -1):
If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
eps (`float`, *optional*, defaults to 1e-08):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
precondition_frequency (`int`, *optional*, defaults to 10):
How often to update the preconditioner.
max_precond_dim (`int`, *optional*, defaults to 10000):
Maximum dimension of the preconditioner.
Set to 10000, so that we exclude most common vocab sizes while including layers.
merge_dims (`bool`, *optional*, defaults to `False`):
Whether or not to merge dimensions of the preconditioner.
precondition_1d (`bool`, *optional*, defaults to `False`):
Whether or not to precondition 1D gradients.
normalize_grads (`bool`, *optional*, defaults to `False`):
Whether or not to normalize gradients per layer.
Helps at large precondition_frequency (~100 in our experiments),
but hurts performance at small precondition_frequency (~10 in our experiments).
data_format (`str`, *optional*, defaults to `channels_first`):
Data format of the input for convolutional layers.
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias correction in Adam.
"""
def __init__(
self,
params,
lr: float = 3e-3,
betas=(0.95, 0.95),
shampoo_beta: float= -1,
eps: float = 1e-8,
weight_decay: float = 0.01,
precondition_frequency: int=10,
max_precond_dim: int=10000, #
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
precondition_1d: bool = False,
normalize_grads: bool = False,
data_format: str = "channels_first",
correct_bias: bool = True,
):
defaults = {
"lr": lr,
"betas": betas,
"shampoo_beta": shampoo_beta,
"eps": eps,
"weight_decay": weight_decay,
"precondition_frequency": precondition_frequency,
"max_precond_dim": max_precond_dim,
"merge_dims": merge_dims,
"precondition_1d": precondition_1d,
"normalize_grads": normalize_grads,
"correct_bias": correct_bias,
}
super().__init__(params, defaults)
self._data_format = data_format
def merge_dims(self, grad, max_precond_dim):
"""
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
"""
assert self._data_format in ["channels_first", "channels_last"]
if self._data_format == "channels_last" and grad.dim() == 4:
grad = grad.permute(0, 3, 1, 2)
shape = grad.shape
new_shape = []
curr_shape = 1
for sh in shape:
temp_shape = curr_shape * sh
if temp_shape > max_precond_dim:
if curr_shape > 1:
new_shape.append(curr_shape)
curr_shape = sh
else:
new_shape.append(sh)
curr_shape = 1
else:
curr_shape = temp_shape
if curr_shape > 1 or len(new_shape)==0:
new_shape.append(curr_shape)
new_grad = grad.reshape(new_shape)
return new_grad
@torch.no_grad()
def step(self, closure = None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
if closure is None:
loss = None
else:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if "step" not in state:
state["step"] = 0
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
if 'Q' not in state:
self.init_preconditioner(
grad,
state,
precondition_frequency=group['precondition_frequency'],
precondition_1d=group['precondition_1d'],
shampoo_beta=(group['shampoo_beta'] if group['shampoo_beta'] >= 0 else group["betas"][1]),
max_precond_dim=group['max_precond_dim'],
merge_dims=group["merge_dims"],
)
self.update_preconditioner(grad, state,
max_precond_dim=group['max_precond_dim'],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"])
continue # first step is skipped so that we never use the current gradients in the projection.
# Projecting gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
grad_projected = self.project(grad, state, merge_dims=group["merge_dims"],
max_precond_dim=group['max_precond_dim'])
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))
denom = exp_avg_sq.sqrt().add_(group["eps"])
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
# exp_avg_projected = self.project(exp_avg, state, merge_dims=group["merge_dims"],
# max_precond_dim=group['max_precond_dim'])
exp_avg_projected = exp_avg
step_size = group["lr"]
if group["correct_bias"]:
bias_correction1 = 1.0 - beta1 ** (state["step"])
bias_correction2 = 1.0 - beta2 ** (state["step"])
step_size = step_size * (bias_correction2 ** .5) / bias_correction1
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
# to the original space
norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=group["merge_dims"],
max_precond_dim=group['max_precond_dim'])
if group["normalize_grads"]:
norm_grad = norm_grad / (1e-30+torch.mean(norm_grad**2)**0.5)
p.add_(norm_grad, alpha=-step_size)
# From AdamW code: Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
# Update is done after the gradient step to avoid using current gradients in the projection.
self.update_preconditioner(grad, state,
max_precond_dim=group['max_precond_dim'],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"])
return loss
def init_preconditioner(self, grad, state, precondition_frequency=10,
shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
merge_dims=False):
"""
Initializes the preconditioner matrices (L and R in the paper).
"""
state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
if grad.dim() == 1:
if not precondition_1d or grad.shape[0] > max_precond_dim:
state['GG'].append([])
else:
state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
else:
if merge_dims:
grad = self.merge_dims(grad, max_precond_dim)
for sh in grad.shape:
if sh > max_precond_dim:
state['GG'].append([])
else:
state['GG'].append(torch.zeros(sh, sh, device=grad.device))
state['Q'] = None # Will hold all the eigenbases of the preconditioner.
state['precondition_frequency'] = precondition_frequency
state['shampoo_beta'] = shampoo_beta
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient to the eigenbases of the preconditioner.
"""
original_shape = grad.shape
if merge_dims:
if grad.dim() == 4 and self._data_format == 'channels_last':
permuted_shape = grad.permute(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state['Q']:
if len(mat) > 0:
grad = torch.tensordot(
grad,
mat,
dims=[[0], [0]],
)
else:
permute_order = list(range(1, len(grad.shape))) + [0]
grad = grad.permute(permute_order)
if merge_dims:
if self._data_format == 'channels_last' and len(original_shape) == 4:
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def update_preconditioner(self, grad, state,
max_precond_dim=10000, merge_dims=False, precondition_1d=False):
"""
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
"""
if state["Q"] is not None:
state["exp_avg"] = self.project_back(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
if grad.dim() == 1:
if precondition_1d and grad.shape[0] <= max_precond_dim:
state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
else:
if merge_dims:
new_grad = self.merge_dims(grad, max_precond_dim)
for idx, sh in enumerate(new_grad.shape):
if sh <= max_precond_dim:
outer_product = torch.tensordot(
new_grad,
new_grad,
dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
)
state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
else:
for idx, sh in enumerate(grad.shape):
if sh <= max_precond_dim:
outer_product = torch.tensordot(
grad,
grad,
# Contracts across all dimensions except for k.
dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
)
state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
if state['Q'] is None:
state['Q'] = self.get_orthogonal_matrix(state['GG'])
if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
# state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims)
if state["step"] > 0:
state["exp_avg"] = self.project(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient back to the original space.
"""
original_shape = grad.shape
if merge_dims:
if self._data_format == 'channels_last' and grad.dim() == 4:
permuted_shape = grad.permute(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state['Q']:
if len(mat) > 0:
grad = torch.tensordot(
grad,
mat,
dims=[[0], [1]],
)
else:
permute_order = list(range(1, len(grad.shape))) + [0]
grad = grad.permute(permute_order)
if merge_dims:
if self._data_format == 'channels_last' and len(original_shape) == 4:
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def get_orthogonal_matrix(self, mat):
"""
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
"""
matrix = []
for m in mat:
if len(m) == 0:
matrix.append([])
continue
if m.data.dtype != torch.float:
float_data = False
original_type = m.data.dtype
original_device = m.data.device
matrix.append(m.data.float())
else:
float_data = True
matrix.append(m.data)
final = []
for m in matrix:
if len(m) == 0:
final.append([])
continue
try:
_, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device))
except:
_, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device))
Q = Q.to(m.dtype)
Q = torch.flip(Q, [1])
if not float_data:
Q = Q.to(original_device).type(original_type)
final.append(Q)
return final
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
"""
Computes the eigenbases of the preconditioner using one round of power iteration
followed by torch.linalg.qr decomposition.
"""
precond_list = state['GG']
orth_list = state['Q']
matrix = []
orth_matrix = []
for m,o in zip(precond_list, orth_list):
if len(m) == 0:
matrix.append([])
orth_matrix.append([])
continue
if m.data.dtype != torch.float:
float_data = False
original_type = m.data.dtype
original_device = m.data.device
matrix.append(m.data.float())
orth_matrix.append(o.data.float())
else:
float_data = True
matrix.append(m.data.float())
orth_matrix.append(o.data.float())
orig_shape = state['exp_avg_sq'].shape
if self._data_format == 'channels_last' and len(orig_shape) == 4:
permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape
if merge_dims:
exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
else:
exp_avg_sq = state['exp_avg_sq']
final = []
for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
if len(m)==0:
final.append([])
continue
est_eig = torch.diag(o.T @ m @ o)
sort_idx = torch.argsort(est_eig, descending=True)
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
o = o[:,sort_idx]
power_iter = m @ o
Q, _ = torch.linalg.qr(power_iter)
if not float_data:
Q = Q.to(original_device).type(original_type)
final.append(Q)
if merge_dims:
if self._data_format == 'channels_last' and len(orig_shape) == 4:
exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
state['exp_avg_sq'] = exp_avg_sq
return final
class MLP_torch(torch.nn.Module):
def __init__(self, in_features, out_features, device=torch.cuda.current_device()):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features],device=device))
def forward(self, inp):
# return torch.matmul(inp, self.weight.t())
return torch.matmul(inp, self.weight.t())
class SampleDataSet(torch.utils.data.Dataset):
def __init__(self, samples, hidden_state):
super().__init__()
assert hidden_state % 64 == 0 and hidden_state >= 64
self._data = torch.rand([samples, hidden_state])
self._label = torch.rand([samples, hidden_state//64])
def __len__(self):
return self._data.size(0)
def __getitem__(self, index):
return self._data[index], self._label[index]
if __name__ == "__main__":
samples = 20
hidden_state = 1024
batch_size = 1
sample_dataset = SampleDataSet(samples, hidden_state)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size)
model_torch = MLP_torch(hidden_state, hidden_state//64)
model_paddle = MLP_paddle(hidden_state, hidden_state//64)
for name, param in model_torch.named_parameters():
print(f"pddle param name: {name}, mean: {param.mean().item()}, sum: {param.sum().item()}")
for name, param in model_torch.named_parameters():
print(f"torch param name: {name}, mean: {param.mean().item()}, sum: {param.sum().item()}")
lr = 0.03
weight_decay=0
optimizer_paddle = SOAP_paddle(model_paddle.parameters(), lr, weight_decay=weight_decay)
criterion_paddle = paddle.nn.L1Loss(reduction='mean')
optimizer_torch = SOAP(model_torch.parameters(), lr,weight_decay=weight_decay)
criterion_torch = torch.nn.L1Loss(reduction='mean')
params_paddle = model_paddle.parameters()
params_torch = [param for param in model_torch.parameters()]
for param_paddle, param_torch in zip(params_paddle, params_torch):
param_paddle.set_value(param_torch.detach().cpu().numpy())
for param_paddle, param_torch in zip(params_paddle, params_torch):
# check param
np.testing.assert_allclose(param_torch.detach().cpu().numpy(), param_paddle.numpy(), atol=0)
stop = 20
for iter,(inp,label) in enumerate(sample_dataloader):
inp_numpy,label_numpy = inp.numpy(),label.numpy()
out_paddle = model_paddle(paddle.to_tensor(inp_numpy))
loss_paddle = criterion_paddle(out_paddle, paddle.to_tensor(label_numpy))
loss_paddle.backward()
out_torch = model_torch(torch.from_numpy(inp_numpy).to(torch.cuda.current_device()))
loss_torch = criterion_torch(out_torch, torch.from_numpy(label_numpy).to(torch.cuda.current_device()))
loss_torch.backward()
optimizer_paddle.step()
optimizer_torch.step()
state_paddle = optimizer_paddle.state
state_torch = optimizer_torch.state
optimizer_paddle.clear_grad()
optimizer_torch.zero_grad()
print(f"[ITERATION] {(iter+1)}/{len(sample_dataloader)} loss paddle: {loss_paddle.item():.6f}, loss torch: {loss_torch.item():.6f}")
if iter >= stop:
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里修改一下,to_tensor自带blocking,避免多次调用
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在allen_cahn.md文档开头,更新一下soap优化器的模型指标 L2Rel.u: 6.8e-6
,以及预训练模型url:https://paddle-org.bj.bcebos.com/paddlescience%2Fmodels%2FAllenCahn%2Fallen_cahn_piratenet_soap_pretrained.pdparams.pdparams
Done. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Describe
在 PaddleScience 中实现 SOAP 优化器
RFC:PaddlePaddle/community#1090
已知问题
param_group
实验数据
实验环境:
复现命令
best metric指标对比

实验结论
L2Rel.u 从1.2e-5 提升至约 7e-6
预测结果

附
原始训练日志
baseline_train.log
soap_train.log