from lzero.model.common import SimNorm
import logging

+class LearnableScale(nn.Module):
+    """
+    A learnable, bounded scalar parameter:
+        s = s_max * sigmoid(ŝ) ∈ (0, s_max)
+    """
+    def __init__(self, init=1.0, s_max=1.5):
+        super().__init__()
+        # Invert the sigmoid so that the initial forward value equals `init`
+        inv_sig = math.log(init / (s_max - init + 1e-9))
+        self.logit = nn.Parameter(torch.tensor(inv_sig))
+        self.logit.requires_grad = True  # TODO
+        self.s_max = s_max
+
+    def forward(self):
+        return self.s_max * torch.sigmoid(self.logit)

##############################################
# CurriculumLoRALinear implementation
##############################################
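A quick sanity check of the bounded-scale recipe introduced above (an illustrative sketch, not part of the patch; the variable names are made up): the inverse-sigmoid initialization makes the first forward value equal to `init`, the output stays inside (0, s_max), and gradients still reach the unconstrained logit.

import math
import torch

# Same recipe as LearnableScale: store a raw logit, squash it into (0, s_max).
init, s_max = 1.0, 1.5
logit = torch.tensor(math.log(init / (s_max - init + 1e-9)), requires_grad=True)

scale = s_max * torch.sigmoid(logit)
print(float(scale))              # ~1.0, i.e. the requested init value
(scale - 0.5).pow(2).backward()  # any loss on the scale ...
print(logit.grad is not None)    # ... flows back to the logit: True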
@@ -74,7 +89,9 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
        # Initialize the LoRA adapters; they exist only when r > 0 and curriculum_stage_num > 1
        self.adapters = nn.ModuleList()
-        self.adapter_scales = nn.ParameterList()
+        # self.adapter_scales = nn.ParameterList()
+        self.adapter_scales = nn.ModuleList()
+
        if r > 0 and (curriculum_stage_num - 1) > 0:
            for i in range(curriculum_stage_num - 1):
                adapter = nn.ParameterDict({
@@ -83,9 +100,15 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
                })
                self.adapters.append(adapter)

-                self.adapter_scales.append(  # ← newly added
-                    nn.Parameter(torch.tensor(lora_scale_init, dtype=torch.float32))
-                )
+                self.adapter_scales.append(LearnableScale(lora_scale_init, s_max=1.5))
+
+                # self.adapter_scales.append(  # ← newly added
+                #     nn.Parameter(torch.tensor(lora_scale_init, dtype=torch.float32))
+                # )
+
+                # --- CurriculumLoRALinear.__init__() ------------
+                # for p in self.adapter_scales:
+                #     p.requires_grad = True  # set them all to True uniformly so none is missed
        else:
            self.adapters = None

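Since `LearnableScale` is an `nn.Module` rather than a bare `nn.Parameter`, it cannot live in an `nn.ParameterList`; an `nn.ModuleList` registers each scale as a submodule, so its logit still shows up in `named_parameters()` and in the state dict. A minimal sketch of that registration, using a hypothetical `BoundedScale` stand-in for `LearnableScale`:

import torch
import torch.nn as nn

class BoundedScale(nn.Module):
    """Hypothetical stand-in for LearnableScale: one sigmoid-bounded scalar."""
    def __init__(self, s_max=1.5):
        super().__init__()
        self.logit = nn.Parameter(torch.zeros(()))
        self.s_max = s_max

    def forward(self):
        return self.s_max * torch.sigmoid(self.logit)

scales = nn.ModuleList([BoundedScale() for _ in range(3)])  # e.g. curriculum_stage_num - 1 == 3
for name, p in scales.named_parameters():
    print(name, p.requires_grad)
# "0.logit True", "1.logit True", "2.logit True" -> the optimizer sees every scale automatically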
@@ -121,17 +144,21 @@ def set_curriculum_stage(self, stage: int):
            for idx, adapter in enumerate(self.adapters):
                adapter['lora_A'].requires_grad = False
                adapter['lora_B'].requires_grad = False
-                self.adapter_scales[idx].requires_grad = True  # ← newly added
+                # self.adapter_scales[idx].requires_grad = True  # ← newly added
            logging.info(f"[CurriculumLoRALinear {module_id}] Stage 0: the base layer is trainable and all adapters are frozen.")
            logging.info(f"[self.adapter_scales:] {self.adapter_scales}")
+            logging.info(f"self.adapter_scales[0].item(): {self.adapter_scales[0]().item()}")
+
        else:
            # Stage greater than 0: freeze the base layer
            self.weight.requires_grad = False
            if self.bias is not None:
                self.bias.requires_grad = False
            for idx, adapter in enumerate(self.adapters):
-                self.adapter_scales[idx].requires_grad = True  # ← newly added
+                # self.adapter_scales[idx].requires_grad = True  # ← newly added
                logging.info(f"[self.adapter_scales:] {self.adapter_scales}")
+                logging.info(f"self.adapter_scales[0].item(): {self.adapter_scales[0]().item()}")
+
                if idx == stage - 1:
                    adapter['lora_A'].requires_grad = True
                    adapter['lora_B'].requires_grad = True
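After calling `set_curriculum_stage`, a small inspection helper such as the hypothetical one below can confirm the intended freeze pattern (base weight/bias frozen for stage > 0, only the current adapter's lora_A/lora_B trainable, scale logits always trainable since they are never frozen above). This is a debugging sketch, not part of the patch.

import torch.nn as nn

def summarize_trainable(module: nn.Module) -> None:
    """Group parameter names by their requires_grad flag (debug aid)."""
    trainable, frozen = [], []
    for name, p in module.named_parameters():
        (trainable if p.requires_grad else frozen).append(name)
    print("trainable:", trainable)
    print("frozen:   ", frozen)

# Hypothetical usage, assuming `layer` is a CurriculumLoRALinear instance:
# layer.set_curriculum_stage(1)
# summarize_trainable(layer)  # expect: weight/bias frozen, adapters.0.* trainable, adapter_scales.*.logit trainable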
@@ -154,9 +181,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            adapter = self.adapters[idx]
            out = F.linear(self.lora_dropout(x), adapter['lora_A'])
            out = F.linear(out, adapter['lora_B'])
-            scale = self.adapter_scales[idx]  # TODO
+            scale = self.adapter_scales[idx]()  # TODO: the scales of all adapters take part in training
            if idx == self.curriculum_stage - 1:
-                adapter_out = adapter_out + self.scaling * out * scale  # the current adapter takes part in the update
+                adapter_out = adapter_out + self.scaling * out * scale  # only the current adapter takes part in the update
            else:
                adapter_out = adapter_out + self.scaling * out.detach() * scale
        return baseline_out + adapter_out
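The `detach()` branch above is what lets earlier adapters keep contributing to the output while only their scale remains trainable. A toy check of that gradient pattern (illustrative only, with stand-in tensors for the frozen adapter and its scale):

import torch

w = torch.randn(4, requires_grad=True)         # stands in for an earlier adapter's weights
scale = torch.tensor(0.8, requires_grad=True)  # stands in for that adapter's learnable scale
x = torch.randn(4)

out = w * x                                    # earlier-adapter output
y = (out.detach() * scale).sum()               # mirrors: self.scaling * out.detach() * scale
y.backward()

print(w.grad)      # None -> no gradient reaches the old adapter's weights
print(scale.grad)  # a tensor -> the scale itself is still updated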