【Hackathon 8th No.12】Implement the SOAP optimizer in PaddleScience #1102

Merged: 11 commits into PaddlePaddle:develop on Mar 14, 2025

Conversation

@BeingGod (Contributor) commented Mar 11, 2025

PR types

New features

PR changes

APIs

Describe

Implement the SOAP optimizer in PaddleScience

RFC: PaddlePaddle/community#1090

Known issues

  1. param_group is not supported (a minimal illustration follows this list)
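
As a minimal illustration (not code from this PR), assuming the SOAP_paddle class from the test code further below: a flat parameter list is accepted, while dict-style parameter groups raise a TypeError in the constructor.

import paddle

# Hypothetical two-layer model, used only for this illustration.
model = paddle.nn.Sequential(paddle.nn.Linear(16, 32), paddle.nn.Linear(32, 1))

# Supported: a flat iterable of parameters.
opt = SOAP_paddle(model.parameters(), lr=3e-3, weight_decay=0.01)

# Not supported yet: dict-style parameter groups (the constructor raises TypeError).
# opt = SOAP_paddle([{"params": model.parameters()}], lr=3e-3)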

Experimental data

Environment:

  1. Paddle: 3.0.0rc1
  2. GPU: 1080 Ti

Reproduction command

python allen_cahn_piratenet.py TRAIN.optim=soap TRAIN.lr_schedler.warmup_epoch=5

Best-metric comparison
(figure: result)

Conclusion
L2Rel.u improves from 1.2e-5 to about 7e-6 (a sketch of this relative L2 error metric follows).
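
For reference, a minimal sketch (not from this PR) of how a relative L2 error metric such as L2Rel.u is usually computed, assuming u_pred and u_ref hold the predicted and reference solution values:

import numpy as np

def l2_relative_error(u_pred: np.ndarray, u_ref: np.ndarray) -> float:
    # ||u_pred - u_ref||_2 / ||u_ref||_2
    return float(np.linalg.norm(u_pred - u_ref) / np.linalg.norm(u_ref))

u_ref = np.sin(np.linspace(0.0, 1.0, 100))       # placeholder reference values
u_pred = u_ref + 1e-6 * np.random.randn(100)     # placeholder prediction
print(l2_relative_error(u_pred, u_ref))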

Prediction results
(figure: ac)

Original training logs
baseline_train.log
soap_train.log


paddle-bot bot commented Mar 11, 2025

Thanks for your contribution!

@BeingGod changed the title from 【Hackathon 8th No.12】Implement the SOAP optimizer in PaddleScience to 【Hackathon 8th No.12】Implement the SOAP optimizer in PaddleScience [WIP] on Mar 11, 2025
@BeingGod (Contributor, Author) commented:

To verify the correctness of the SOAP implementation, the losses of the torch and paddle versions were compared on an MLP demo trained with SOAP.

The loss comparison for the first 20 iterations is as follows:

λ beinggod-workstation /workspace/hackathon/SOAP/SOAP python soap_paddle.py 
grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
W0313 05:29:07.938906 1188667 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.6, Runtime API Version: 11.8
W0313 05:29:07.939640 1188667 gpu_resources.cc:164] device: 0, cuDNN Version: 8.9.
pddle param name: weight, mean: 0.006312752142548561, sum: 103.42813110351562
torch param name: weight, mean: 0.006312752142548561, sum: 103.42813110351562
[ITERATION] 1/20 loss paddle: 12.870917, loss torch: 12.870915
[ITERATION] 2/20 loss paddle: 15.275839, loss torch: 15.275839
[ITERATION] 3/20 loss paddle: 9.223783, loss torch: 9.223782
[ITERATION] 4/20 loss paddle: 9.283824, loss torch: 9.283824
[ITERATION] 5/20 loss paddle: 7.028357, loss torch: 7.028356
[ITERATION] 6/20 loss paddle: 7.973010, loss torch: 7.973009
[ITERATION] 7/20 loss paddle: 11.567774, loss torch: 11.567776
[ITERATION] 8/20 loss paddle: 5.823763, loss torch: 5.823766
[ITERATION] 9/20 loss paddle: 12.174599, loss torch: 12.174601
[ITERATION] 10/20 loss paddle: 8.206469, loss torch: 8.206469
[ITERATION] 11/20 loss paddle: 7.991440, loss torch: 7.991440
[ITERATION] 12/20 loss paddle: 7.984601, loss torch: 7.984600
[ITERATION] 13/20 loss paddle: 9.944571, loss torch: 9.944571
[ITERATION] 14/20 loss paddle: 10.886424, loss torch: 10.886503
[ITERATION] 15/20 loss paddle: 8.474144, loss torch: 8.474236
[ITERATION] 16/20 loss paddle: 7.847643, loss torch: 7.847647
[ITERATION] 17/20 loss paddle: 10.080597, loss torch: 10.080593
[ITERATION] 18/20 loss paddle: 9.120567, loss torch: 9.120558
[ITERATION] 19/20 loss paddle: 7.173148, loss torch: 7.173133
[ITERATION] 20/20 loss paddle: 7.039429, loss torch: 7.039426
λ beinggod-workstation /workspace/hackathon/SOAP/SOAP 

The loss differences are on the order of 1e-5, so the two implementations can be considered aligned; a programmatic version of this check is sketched below.
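
A minimal programmatic version of this check (not part of the original test), using the first three iterations taken from the log above; in practice one would append loss_paddle.item() and loss_torch.item() inside the training loop:

import numpy as np

losses_paddle = [12.870917, 15.275839, 9.223783]
losses_torch = [12.870915, 15.275839, 9.223782]

# The largest per-iteration gap in the full log is below 1e-4.
np.testing.assert_allclose(np.array(losses_paddle), np.array(losses_torch), atol=1e-4)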

Test code:

import numpy as np
import paddle
import paddle.device
import paddle.optimizer
import torch
import torch.utils.data

from collections import defaultdict
from itertools import chain


# Parts of the code are modifications of PyTorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_paddle/galore_projector.py

seed = 1234
paddle.device.set_device('gpu')
paddle.seed(seed)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

class SOAP_paddle(paddle.optimizer.Optimizer):
    """
    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

    Parameters:
        params (`list|tuple`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.003):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
            Adam's betas parameters (b1, b2).
        shampoo_beta (`float`, *optional*, defaults to -1):
            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
        eps (`float`, *optional*, defaults to 1e-08):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
        precondition_frequency (`int`, *optional*, defaults to 10):
            How often to update the preconditioner.
        max_precond_dim (`int`, *optional*, defaults to 10000):
            Maximum dimension of the preconditioner.
            Set to 10000, so that we exclude most common vocab sizes while including layers.
        merge_dims (`bool`, *optional*, defaults to `False`):
            Whether or not to merge dimensions of the preconditioner.
        precondition_1d (`bool`, *optional*, defaults to `False`):
            Whether or not to precondition 1D gradients.
        normalize_grads (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize gradients per layer. 
            Helps at large precondition_frequency (~100 in our experiments), 
            but hurts performance at small precondition_frequency (~10 in our experiments).
        data_format (`str`, *optional*, defaults to `channels_first`):
            Data format of the input for convolutional layers.
            Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias correction in Adam.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.
    """

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.95, 0.95),
        shampoo_beta: float= -1,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int=10,
        max_precond_dim: int = 10000,
        merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        data_format: str = "channels_first",
        correct_bias: bool = True,
        name: str = None,
    ):
        self._betas = betas
        self._shampoo_beta = shampoo_beta
        self._eps = eps
        self._precondition_frequency = precondition_frequency
        self._max_precond_dim = max_precond_dim
        self._merge_dims = merge_dims
        self._precondition_1d = precondition_1d
        self._normalize_grads = normalize_grads
        self._correct_bias = correct_bias
        self._weight_decay = weight_decay
        
        self.state = defaultdict(dict)

        super().__init__(learning_rate=lr, 
                         parameters=params,
                         weight_decay=weight_decay,
                         name=name)

        if isinstance(self._parameter_list[0], dict):
            raise TypeError(
                "Parameter groups are not supported by the SOAP optimizer."
            )

        self._data_format = data_format

        
    def merge_dims(self, grad, max_precond_dim):
        """
        Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
        """
        assert self._data_format in ["channels_first", "channels_last"]
        if self._data_format == "channels_last" and grad.dim() == 4:
            grad = grad.transpose([0, 3, 1, 2])
        shape = grad.shape
        new_shape = []
        
        curr_shape = 1
        for sh in shape:
            temp_shape = curr_shape * sh
            if temp_shape > max_precond_dim:
                if curr_shape > 1:
                    new_shape.append(curr_shape)
                    curr_shape = sh
                else:
                    new_shape.append(sh)
                    curr_shape = 1
            else:
                curr_shape = temp_shape
        
        if curr_shape > 1 or len(new_shape)==0:
            new_shape.append(curr_shape)
        
        new_grad = grad.reshape(new_shape)
        return new_grad               
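
    # Worked example for merge_dims (illustrative, not executed):
    #   grad.shape == [256, 3, 3, 64] with max_precond_dim == 1024.
    #   Dimensions are folded greedily while the running product stays <= max_precond_dim:
    #   256 * 3 = 768 is kept, 768 * 3 = 2304 would exceed 1024, so 768 is emitted and
    #   the remaining 3 * 64 = 192 forms the second factor, giving a merged shape of [768, 192].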

    @paddle.base.framework.non_static_only
    def step(self, closure = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        with paddle.no_grad():
            if closure is None:
                loss = None
            else:
                closure = paddle.enable_grad()(closure)
                loss = closure()

            for p in self._parameter_list:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                
                if "step" not in state:
                    state["step"] = 0 
                    
                # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = paddle.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = paddle.zeros_like(grad)
                
                if 'Q' not in state:
                    self.init_preconditioner(
                        grad,
                        state,
                        precondition_frequency=self._precondition_frequency,
                        precondition_1d=self._precondition_1d,
                        shampoo_beta=(self._shampoo_beta if self._shampoo_beta >= 0 else self._betas[1]),
                        max_precond_dim=self._max_precond_dim,
                        merge_dims=self._merge_dims,
                    )
                    self.update_preconditioner(grad, state,
                                                max_precond_dim=self._max_precond_dim,
                                                merge_dims=self._merge_dims,
                                                precondition_1d=self._precondition_1d)
                    continue # first step is skipped so that we never use the current gradients in the projection.
                
                # Projecting gradients to the eigenbases of Shampoo's preconditioner 
                # i.e. projecting to the eigenbases of matrices in state['GG']
                grad_projected = self.project(grad, state, merge_dims=self._merge_dims, 
                                                max_precond_dim=self._max_precond_dim)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = paddle.to_tensor(self._betas)

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.multiply_(beta1).add_((1.0 - beta1)*grad_projected)
                exp_avg_sq.multiply_(beta2).add_((1.0-beta2)*grad_projected.square())

                denom = exp_avg_sq.sqrt().add_(paddle.to_tensor(self._eps))
                
                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner 
                # i.e. projecting to the eigenbases of matrices in state['GG']
                # exp_avg_projected = self.project(exp_avg, state, merge_dims=self._merge_dims,
                #                                  max_precond_dim=self._max_precond_dim)
                exp_avg_projected = exp_avg
                
                lr = self._learning_rate
                step_size = lr
                if self._correct_bias:
                    bias_correction1 = 1.0 - beta1 ** (state["step"])
                    bias_correction2 = 1.0 - beta2 ** (state["step"])
                    step_size = step_size * (bias_correction2 ** .5) / bias_correction1

                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
                # to the original space
                norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=self._merge_dims,
                                                    max_precond_dim=self._max_precond_dim)

                if self._normalize_grads:
                    norm_grad = norm_grad / (1e-30+paddle.mean(norm_grad**2)**0.5)
                
                p.add_(-step_size * norm_grad)
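                # In equation form (illustrative): with g the raw gradient and Q the
                # current eigenbasis, the update applied above amounts to
                #   m <- beta1 * m + (1 - beta1) * (Q^T g)
                #   v <- beta2 * v + (1 - beta2) * (Q^T g)^2
                #   p <- p - step_size * Q (m / (sqrt(v) + eps))
                # i.e. Adam runs in the rotated (eigenbasis) coordinates and the result
                # is rotated back before being applied to the parameter.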
                

                # From AdamW code: Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if self._weight_decay > 0.0:
                    p.add_((-lr * self._weight_decay) * p)
                    
                # Update is done after the gradient step to avoid using current gradients in the projection.
                self.update_preconditioner(grad, state, 
                                                max_precond_dim=self._max_precond_dim,
                                                merge_dims=self._merge_dims,
                                                precondition_1d=self._precondition_1d)
            
        return loss
    
    def init_preconditioner(self, grad, state, precondition_frequency=10, 
                            shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
                            merge_dims=False):
        """
        Initializes the preconditioner matrices (L and R in the paper).
        """
        state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precond_dim:
                state['GG'].append([])
            else:
                state['GG'].append(paddle.zeros([grad.shape[0], grad.shape[0]]))
        else:
            if merge_dims:
                grad = self.merge_dims(grad, max_precond_dim)

            for sh in grad.shape:
                if sh > max_precond_dim:
                    state['GG'].append([])
                else:
                    state['GG'].append(paddle.zeros([sh, sh]))
                    
        state['Q'] = None # Will hold all the eigenbases of the preconditioner.
        state['precondition_frequency'] = precondition_frequency
        state['shampoo_beta'] = shampoo_beta          
        
    def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient to the eigenbases of the preconditioner.
        """
        original_shape = grad.shape
        if merge_dims:
            if grad.dim() == 4 and self._data_format == 'channels_last':
                transposed_shape = grad.transpose([0, 3, 1, 2]).shape
            grad = self.merge_dims(grad, max_precond_dim)
        
        for mat in state['Q']:
            if len(mat) > 0:
                grad = paddle.tensordot(
                        grad,
                        mat,
                        axes=[[0], [0]],
                    )
            else:
                transpose_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.transpose(transpose_order)
        
        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(transposed_shape).transpose([0, 2, 3, 1])
            else:
                grad = grad.reshape(original_shape)
        return grad
        
    def update_preconditioner(self, grad, state, 
                              max_precond_dim=10000, merge_dims=False, precondition_1d=False):
        """
        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
        """
        if state["Q"] is not None:
            state["exp_avg"] = self.project_back(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precond_dim:
                state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
        else:
            if merge_dims:
                new_grad = self.merge_dims(grad, max_precond_dim)
                for idx, sh in enumerate(new_grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = paddle.tensordot(
                                new_grad,
                                new_grad,
                                axes=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
                            )
                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
            else:
                for idx, sh in enumerate(grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = paddle.tensordot(
                                grad,
                                grad,
                                # Contracts across all dimensions except for k.
                                axes=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
                            )
                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
        
        if state['Q'] is None:
            state['Q'] = self.get_orthogonal_matrix(state['GG'])
        if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
            state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
            # state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims)             

        if state["step"] > 0:
            state["exp_avg"] = self.project(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim) 

    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient back to the original space.
        """
        original_shape = grad.shape
        if merge_dims:
            if self._data_format == 'channels_last' and grad.dim() == 4:
                transposed_shape = grad.transpose([0, 3, 1, 2]).shape
            grad = self.merge_dims(grad, max_precond_dim)
        for mat in state['Q']:
            if len(mat) > 0:
                grad = paddle.tensordot(
                        grad,
                        mat,
                        axes=[[0], [1]],
                    )
            else:
                transpose_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.transpose(transpose_order)
                
        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(transposed_shape).transpose([0, 2, 3, 1])
            else:
                grad = grad.reshape(original_shape)
        return grad
        

    def get_orthogonal_matrix(self, mat):
        """
        Computes the eigenbases of the preconditioner using paddle.linalg.eigh decomposition.
        """
        matrix = []
        for m in mat:
            if len(m) == 0:
                matrix.append([])
                continue
            if m.data.dtype != paddle.float32:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.place
                matrix.append(m.data.to(paddle.float32))
            else:
                float_data = True
                matrix.append(m.data)
        
        final = []
        for m in matrix:
            if len(m) == 0:
                final.append([])
                continue
            # try:
            #     _, Q = paddle.linalg.eigh(m+1e-30*paddle.eye(m.shape[0]))
            # except:
            #     _, Q = paddle.linalg.eigh(m.to(paddle.float64)+1e-30*paddle.eye(m.shape[0]))
            #     Q = Q.to(m.dtype)
            _, Q = paddle.linalg.eigh(m+1e-30*paddle.eye(m.shape[0]))
            Q = paddle.flip(Q, [1])

            if not float_data:
                Q = Q.to(original_device, dtype=original_type)
            final.append(Q)
        return final
        

    def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
        """
        Computes the eigenbases of the preconditioner using one round of power iteration 
        followed by paddle.linalg.qr decomposition.
        """
        precond_list = state['GG']
        orth_list = state['Q']

        matrix = []
        orth_matrix = []
        for m,o in zip(precond_list, orth_list):
            if len(m) == 0:
                matrix.append([])
                orth_matrix.append([])
                continue
            if m.data.dtype != paddle.float32:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.place
                matrix.append(m.data.to(paddle.float32))
                orth_matrix.append(o.data.to(paddle.float32))
            else:
                float_data = True
                matrix.append(m.data.to(paddle.float32))
                orth_matrix.append(o.data.to(paddle.float32))
        
        orig_shape = state['exp_avg_sq'].shape
        if self._data_format == 'channels_last' and len(orig_shape) == 4:
            transposed_shape = state['exp_avg_sq'].transpose([0, 3, 1, 2]).shape
        if merge_dims:
            exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
        else:
            exp_avg_sq = state['exp_avg_sq']
            
        final = []
        for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
            if len(m)==0:
                final.append([])
                continue
            est_eig = paddle.diag(o.T @ m @ o)
            sort_idx = paddle.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(sort_idx, ind)
            o = o[:,sort_idx]
            power_iter = m @ o
            Q, _ = paddle.linalg.qr(power_iter)

            if not float_data:
                Q = Q.to(original_device, dtype=original_type)
            final.append(Q)
        
        if merge_dims:
            if self._data_format == 'channels_last' and len(orig_shape) == 4:
                exp_avg_sq = exp_avg_sq.reshape(transposed_shape).transpose([0, 2, 3, 1])
            else:
                exp_avg_sq = exp_avg_sq.reshape(orig_shape)
                
        state['exp_avg_sq'] = exp_avg_sq
        return final


class MLP_paddle(paddle.nn.Layer):


    def __init__(self, in_features, out_features):
        super().__init__()
        self._internal_weight = paddle.randn([out_features, in_features])
        self.weight = paddle.create_parameter(shape=self._internal_weight.shape,
                        dtype=self._internal_weight.dtype,
                        default_initializer=paddle.nn.initializer.Assign(self._internal_weight))
        self.weight.stop_gradient = False
    
    def forward(self, inp):
        return paddle.matmul(inp, self.weight.T)


# Parts of the code are modifications of PyTorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py


class SOAP(torch.optim.Optimizer):
    """
    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.003):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
            Adam's betas parameters (b1, b2).
        shampoo_beta (`float`, *optional*, defaults to -1):
            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
        eps (`float`, *optional*, defaults to 1e-08):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
        precondition_frequency (`int`, *optional*, defaults to 10):
            How often to update the preconditioner.
        max_precond_dim (`int`, *optional*, defaults to 10000):
            Maximum dimension of the preconditioner.
            Set to 10000, so that we exclude most common vocab sizes while including layers.
        merge_dims (`bool`, *optional*, defaults to `False`):
            Whether or not to merge dimensions of the preconditioner.
        precondition_1d (`bool`, *optional*, defaults to `False`):
            Whether or not to precondition 1D gradients.
        normalize_grads (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize gradients per layer. 
            Helps at large precondition_frequency (~100 in our experiments), 
            but hurts performance at small precondition_frequency (~10 in our experiments).
        data_format (`str`, *optional*, defaults to `channels_first`):
            Data format of the input for convolutional layers.
            Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias correction in Adam.
    """

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.95, 0.95),
        shampoo_beta: float= -1,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int=10,
        max_precond_dim: int = 10000,
        merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        data_format: str = "channels_first",
        correct_bias: bool = True,
    ):
        defaults = {
            "lr": lr,
            "betas": betas,
            "shampoo_beta": shampoo_beta,
            "eps": eps,
            "weight_decay": weight_decay,
            "precondition_frequency": precondition_frequency,
            "max_precond_dim": max_precond_dim,
            "merge_dims": merge_dims,
            "precondition_1d": precondition_1d,
            "normalize_grads": normalize_grads,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)
        self._data_format = data_format
        
    def merge_dims(self, grad, max_precond_dim):
        """
        Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
        """
        assert self._data_format in ["channels_first", "channels_last"]
        if self._data_format == "channels_last" and grad.dim() == 4:
            grad = grad.permute(0, 3, 1, 2)
        shape = grad.shape
        new_shape = []
        
        curr_shape = 1
        for sh in shape:
            temp_shape = curr_shape * sh
            if temp_shape > max_precond_dim:
                if curr_shape > 1:
                    new_shape.append(curr_shape)
                    curr_shape = sh
                else:
                    new_shape.append(sh)
                    curr_shape = 1
            else:
                curr_shape = temp_shape
        
        if curr_shape > 1 or len(new_shape)==0:
            new_shape.append(curr_shape)
        
        new_grad = grad.reshape(new_shape)
        return new_grad               

    @torch.no_grad()
    def step(self, closure = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        if closure is None:
            loss = None
        else:
            loss = closure()
            
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                
                if "step" not in state:
                    state["step"] = 0 
                    
                # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)
                
                if 'Q' not in state:
                    self.init_preconditioner(
                        grad,
                        state,
                        precondition_frequency=group['precondition_frequency'],
                        precondition_1d=group['precondition_1d'],
                        shampoo_beta=(group['shampoo_beta'] if group['shampoo_beta'] >= 0 else group["betas"][1]),
                        max_precond_dim=group['max_precond_dim'],
                        merge_dims=group["merge_dims"],
                    )
                    self.update_preconditioner(grad, state,
                                               max_precond_dim=group['max_precond_dim'],
                                               merge_dims=group["merge_dims"],
                                               precondition_1d=group["precondition_1d"])
                    continue # first step is skipped so that we never use the current gradients in the projection.
                
                # Projecting gradients to the eigenbases of Shampoo's preconditioner 
                # i.e. projecting to the eigenbases of matrices in state['GG']
                grad_projected = self.project(grad, state, merge_dims=group["merge_dims"], 
                                              max_precond_dim=group['max_precond_dim'])

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))

                denom = exp_avg_sq.sqrt().add_(group["eps"])
                
                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner 
                # i.e. projecting to the eigenbases of matrices in state['GG']
                # exp_avg_projected = self.project(exp_avg, state, merge_dims=group["merge_dims"],
                #                                  max_precond_dim=group['max_precond_dim'])
                exp_avg_projected = exp_avg
                
                step_size = group["lr"]
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** (state["step"])
                    bias_correction2 = 1.0 - beta2 ** (state["step"])
                    step_size = step_size * (bias_correction2 ** .5) / bias_correction1

                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
                # to the original space
                norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=group["merge_dims"],
                                                 max_precond_dim=group['max_precond_dim'])

                if group["normalize_grads"]:
                    norm_grad = norm_grad / (1e-30+torch.mean(norm_grad**2)**0.5)
                
                p.add_(norm_grad, alpha=-step_size)
                

                # From AdamW code: Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                    
                # Update is done after the gradient step to avoid using current gradients in the projection.
                self.update_preconditioner(grad, state, 
                                               max_precond_dim=group['max_precond_dim'],
                                               merge_dims=group["merge_dims"],
                                               precondition_1d=group["precondition_1d"])
        
        return loss
    
    def init_preconditioner(self, grad, state, precondition_frequency=10, 
                            shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
                            merge_dims=False):
        """
        Initializes the preconditioner matrices (L and R in the paper).
        """
        state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precond_dim:
                state['GG'].append([])
            else:
                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
        else:
            if merge_dims:
                grad = self.merge_dims(grad, max_precond_dim)

            for sh in grad.shape:
                if sh > max_precond_dim:
                    state['GG'].append([])
                else:
                    state['GG'].append(torch.zeros(sh, sh, device=grad.device))

        state['Q'] = None # Will hold all the eigenbases of the preconditioner.
        state['precondition_frequency'] = precondition_frequency
        state['shampoo_beta'] = shampoo_beta          
        
    def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient to the eigenbases of the preconditioner.
        """
        original_shape = grad.shape
        if merge_dims:
            if grad.dim() == 4 and self._data_format == 'channels_last':
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)

        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(
                        grad,
                        mat,
                        dims=[[0], [0]],
                    )
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)
        
        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad
        
    def update_preconditioner(self, grad, state, 
                              max_precond_dim=10000, merge_dims=False, precondition_1d=False):
        """
        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
        """
        if state["Q"] is not None:
            state["exp_avg"] = self.project_back(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precond_dim:
                state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
        else:
            if merge_dims:
                new_grad = self.merge_dims(grad, max_precond_dim)
                for idx, sh in enumerate(new_grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = torch.tensordot(
                                new_grad,
                                new_grad,
                                dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
                            )
                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
            else:
                for idx, sh in enumerate(grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = torch.tensordot(
                                grad,
                                grad,
                                # Contracts across all dimensions except for k.
                                dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
                            )
                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
                     
        if state['Q'] is None:
            state['Q'] = self.get_orthogonal_matrix(state['GG'])
        if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
            state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
            # state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims)             

        if state["step"] > 0:
            state["exp_avg"] = self.project(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim) 

    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient back to the original space.
        """
        original_shape = grad.shape
        if merge_dims:
            if self._data_format == 'channels_last' and grad.dim() == 4:
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)
        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(
                        grad,
                        mat,
                        dims=[[0], [1]],
                    )
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)
                
        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad
        

    def get_orthogonal_matrix(self, mat):
        """
        Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
        """
        matrix = []
        for m in mat:
            if len(m) == 0:
                matrix.append([])
                continue
            if m.data.dtype != torch.float:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.device
                matrix.append(m.data.float())
            else:
                float_data = True
                matrix.append(m.data)
        
        final = []
        for m in matrix:
            if len(m) == 0:
                final.append([])
                continue
            try:
                _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device))
            except:
                _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device))
                Q = Q.to(m.dtype)
            Q = torch.flip(Q, [1])

            if not float_data:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)
        return final
        

    def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
        """
        Computes the eigenbases of the preconditioner using one round of power iteration 
        followed by torch.linalg.qr decomposition.
        """
        precond_list = state['GG']
        orth_list = state['Q']

        matrix = []
        orth_matrix = []
        for m,o in zip(precond_list, orth_list):
            if len(m) == 0:
                matrix.append([])
                orth_matrix.append([])
                continue
            if m.data.dtype != torch.float:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.device
                matrix.append(m.data.float())
                orth_matrix.append(o.data.float())
            else:
                float_data = True
                matrix.append(m.data.float())
                orth_matrix.append(o.data.float())
        
        orig_shape = state['exp_avg_sq'].shape
        if self._data_format == 'channels_last' and len(orig_shape) == 4:
            permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape
        if merge_dims:
            exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
        else:
            exp_avg_sq = state['exp_avg_sq']
            
        final = []
        for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
            if len(m)==0:
                final.append([])
                continue
            est_eig = torch.diag(o.T @ m @ o)
            sort_idx = torch.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
            o = o[:,sort_idx]
            power_iter = m @ o
            Q, _ = torch.linalg.qr(power_iter)

            if not float_data:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)
        
        if merge_dims:
            if self._data_format == 'channels_last' and len(orig_shape) == 4:
                exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                exp_avg_sq = exp_avg_sq.reshape(orig_shape)
                
        state['exp_avg_sq'] = exp_avg_sq
        return final


class MLP_torch(torch.nn.Module):


    def __init__(self, in_features, out_features, device=torch.cuda.current_device()):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features],device=device))
    
    def forward(self, inp):
        return torch.matmul(inp, self.weight.t())

class SampleDataSet(torch.utils.data.Dataset):

    def __init__(self, samples, hidden_state):
        super().__init__()

        assert hidden_state % 64 == 0 and hidden_state >= 64

        self._data = torch.rand([samples, hidden_state])
        self._label = torch.rand([samples, hidden_state//64])

    def __len__(self):
        return self._data.size(0)

    def __getitem__(self, index):
        return self._data[index], self._label[index]
        

if __name__ == "__main__":
    samples = 20
    hidden_state = 1024

    batch_size = 1

    sample_dataset = SampleDataSet(samples, hidden_state)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size)

    model_torch = MLP_torch(hidden_state, hidden_state//64)
    model_paddle = MLP_paddle(hidden_state, hidden_state//64)

    # Print the initial parameter statistics of both models.
    for name, param in model_paddle.named_parameters():
        print(f"paddle param name: {name}, mean: {param.mean().item()}, sum: {param.sum().item()}")

    for name, param in model_torch.named_parameters():
        print(f"torch param name: {name}, mean: {param.mean().item()}, sum: {param.sum().item()}")
    
    lr = 0.03
    weight_decay=0
    optimizer_paddle = SOAP_paddle(model_paddle.parameters(), lr, weight_decay=weight_decay)
    criterion_paddle = paddle.nn.L1Loss(reduction='mean')

    optimizer_torch = SOAP(model_torch.parameters(), lr,weight_decay=weight_decay)
    criterion_torch = torch.nn.L1Loss(reduction='mean')
    
    params_paddle = model_paddle.parameters()
    params_torch = [param for param in model_torch.parameters()]

    for param_paddle, param_torch in zip(params_paddle, params_torch):
        param_paddle.set_value(param_torch.detach().cpu().numpy())

    for param_paddle, param_torch in zip(params_paddle, params_torch):
        # check param
        np.testing.assert_allclose(param_torch.detach().cpu().numpy(), param_paddle.numpy(), atol=0)
    
    stop = 20
    for iter,(inp,label) in enumerate(sample_dataloader):
        inp_numpy,label_numpy = inp.numpy(),label.numpy()
        
        out_paddle = model_paddle(paddle.to_tensor(inp_numpy))
        loss_paddle = criterion_paddle(out_paddle, paddle.to_tensor(label_numpy))
        loss_paddle.backward()

        out_torch = model_torch(torch.from_numpy(inp_numpy).to(torch.cuda.current_device()))
        loss_torch = criterion_torch(out_torch, torch.from_numpy(label_numpy).to(torch.cuda.current_device()))
        loss_torch.backward()


        optimizer_paddle.step()
        optimizer_torch.step()
        
        state_paddle = optimizer_paddle.state 
        state_torch = optimizer_torch.state 


        optimizer_paddle.clear_grad()
        optimizer_torch.zero_grad()

        print(f"[ITERATION] {(iter+1)}/{len(sample_dataloader)} loss paddle: {loss_paddle.item():.6f}, loss torch: {loss_torch.item():.6f}")

        if iter >= stop:
            break

@BeingGod changed the title from 【Hackathon 8th No.12】Implement the SOAP optimizer in PaddleScience [WIP] to 【Hackathon 8th No.12】Implement the SOAP optimizer in PaddleScience on Mar 13, 2025
@HydrogenSulfate (Collaborator) left a comment:

Please modify this part: to_tensor has blocking built in, so avoid calling it multiple times.
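
A minimal sketch of the kind of change being suggested, assuming the target is the repeated paddle.to_tensor calls inside step() of the test code above (the exact diff location is not visible in this conversation view):

# Before: converted on every optimization step.
#     beta1, beta2 = paddle.to_tensor(self._betas)
#     denom = exp_avg_sq.sqrt().add_(paddle.to_tensor(self._eps))
# After: convert once (e.g. in __init__) and reuse the tensors in step().
#     self._beta1, self._beta2 = paddle.to_tensor(self._betas)
#     self._eps_t = paddle.to_tensor(self._eps)
#     ...
#     denom = exp_avg_sq.sqrt().add_(self._eps_t)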

@HydrogenSulfate (Collaborator) left a comment:

At the beginning of the allen_cahn.md document, please update the model metric for the SOAP optimizer (L2Rel.u: 6.8e-6) and the pretrained model URL: https://paddle-org.bj.bcebos.com/paddlescience%2Fmodels%2FAllenCahn%2Fallen_cahn_piratenet_soap_pretrained.pdparams.pdparams

@BeingGod (Contributor, Author) replied, quoting the comment above:

"At the beginning of the allen_cahn.md document, please update the model metric for the SOAP optimizer (L2Rel.u: 6.8e-6) and the pretrained model URL: https://paddle-org.bj.bcebos.com/paddlescience%2Fmodels%2FAllenCahn%2Fallen_cahn_piratenet_soap_pretrained.pdparams.pdparams"

Done.

@HydrogenSulfate (Collaborator) left a comment:

LGTM

@HydrogenSulfate HydrogenSulfate merged commit 44cd7a0 into PaddlePaddle:develop Mar 14, 2025
3 of 4 checks passed