
Commit c8e7cb8

Author: puyuan (committed)

polish(pu): add stable_adaptor_scale

1 parent 7dd6c04 · commit c8e7cb8

8 files changed: +877 −68 lines

lzero/entry/train_unizero_multitask_balance_segment_ddp.py

Lines changed: 26 additions & 1 deletion
@@ -32,6 +32,19 @@
 from collections import defaultdict
 GLOBAL_EVAL_RETURNS: dict[int, float] = defaultdict(lambda: None)

+def log_param_statistics(model, logger=logging):
+    n_tensors_total = sum(1 for _ in model.parameters())
+    n_tensors_train = sum(p.requires_grad for p in model.parameters())
+
+    n_elems_total = sum(p.numel() for p in model.parameters())
+    n_elems_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.info(
+        f'Trainable parameters: '
+        f'{n_tensors_train}/{n_tensors_total} tensors | '
+        f'{n_elems_train:,}/{n_elems_total:,} elements '
+        f'(~{n_elems_train/1e6:.2f} M / {n_elems_total/1e6:.2f} M)'
+    )

 def tasks_per_stage(unsolved: int, remain_lora: int) -> int:
     """
@@ -84,6 +97,10 @@ def step(self, solved_cnt: int, unsolved_cnt: int, train_iter: int):
         logging.info(f'[Curriculum] switch to stage {self.stage} '
                      f'(solved={solved_cnt}, unsolved={unsolved_cnt}, '
                      f'iter={train_iter})')
+
+        updated = sum(p.requires_grad for p in self.policy._learn_model.world_model.parameters())
+        logging.info(f'{updated}/{sum(1 for _ in self.policy._learn_model.world_model.parameters())} params will be optimized')
+        log_param_statistics(self.policy._learn_model.world_model)  # log once more to see how the numbers changed
         self.last_solved = solved_cnt
         self.last_switch_iter = train_iter
         return True
@@ -595,6 +612,9 @@ def train_unizero_multitask_balance_segment_ddp(
     # Initialize once (on rank 0 or on every rank)
     curr_ctrl = CurriculumController(cfg, policy)

+    updated = sum(p.requires_grad for p in policy._learn_model.world_model.parameters())
+    logging.info(f'{updated}/{sum(1 for _ in policy._learn_model.world_model.parameters())} params will be optimized')
+
     while True:
         last_curriculum_stage = cur_curriculum_stage

@@ -814,9 +834,14 @@ def train_unizero_multitask_balance_segment_ddp(
             for module_name, module in transformer.named_modules():
                 if isinstance(module, CurriculumLoRALinear) and module.adapters is not None:
                     for adapter_idx, scale_param in enumerate(module.adapter_scales):
+                        # tb_logger.add_scalar(
+                        #     f'UniZero-MT/adapter_scales/{module_name}/adapter_{adapter_idx}',
+                        #     scale_param.item(),
+                        #     global_step=learner.train_iter
+                        # )
                         tb_logger.add_scalar(
                             f'UniZero-MT/adapter_scales/{module_name}/adapter_{adapter_idx}',
-                            scale_param.item(),
+                            scale_param().item(),
                             global_step=learner.train_iter
                         )
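
The .item() call changes because each entry of module.adapter_scales is now a LearnableScale module (see transformer.py below) rather than a raw nn.Parameter, so the module must be called to obtain the bounded scalar. A minimal sketch, assuming the LearnableScale class from this commit:

    scale = LearnableScale(init=1.0, s_max=1.5)   # a module, not a Parameter
    print(scale.logit.item())                     # raw unbounded logit, ~0.69 for init=1.0
    print(scale().item())                         # bounded value s_max * sigmoid(logit), ~1.0
    # scale.item() would raise AttributeError: nn.Module has no .item()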

lzero/model/unizero_world_models/transformer.py

Lines changed: 35 additions & 8 deletions
@@ -23,6 +23,21 @@
 from lzero.model.common import SimNorm
 import logging

+class LearnableScale(nn.Module):
+    """
+    Learnable, bounded scalar parameter:
+        s = s_max * sigmoid(ŝ)  ∈ (0, s_max)
+    """
+    def __init__(self, init=1.0, s_max=1.5):
+        super().__init__()
+        # Invert the sigmoid to recover the initial logit from the desired init value
+        inv_sig = math.log(init / (s_max - init + 1e-9))
+        self.logit = nn.Parameter(torch.tensor(inv_sig))
+        self.logit.requires_grad = True  # TODO
+        self.s_max = s_max
+
+    def forward(self):
+        return self.s_max * torch.sigmoid(self.logit)
 ##############################################
 # CurriculumLoRALinear implementation
 ##############################################
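
A quick sanity check of the inverse-sigmoid initialization used above, as a standalone sketch (not part of the commit): the logit is chosen so that the forward pass reproduces init, and the output stays bounded for any logit value.

    import math
    import torch

    s_max, init = 1.5, 1.0
    logit = math.log(init / (s_max - init + 1e-9))  # ≈ 0.693 for init=1.0, s_max=1.5
    assert abs(s_max * torch.sigmoid(torch.tensor(logit)).item() - init) < 1e-6

    # The scale is bounded in [0, s_max] no matter how far the logit drifts during training.
    for x in (-100.0, 0.0, 100.0):
        s = s_max * torch.sigmoid(torch.tensor(x)).item()
        assert 0.0 <= s <= s_max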
@@ -74,7 +89,9 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,

         # Initialize the LoRA adapters; they only exist when r > 0 and curriculum_stage_num > 1
         self.adapters = nn.ModuleList()
-        self.adapter_scales = nn.ParameterList()
+        # self.adapter_scales = nn.ParameterList()
+        self.adapter_scales = nn.ModuleList()
+
         if r > 0 and (curriculum_stage_num - 1) > 0:
             for i in range(curriculum_stage_num - 1):
                 adapter = nn.ParameterDict({
@@ -83,9 +100,15 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 })
                 self.adapters.append(adapter)

-                self.adapter_scales.append(  # ← newly added
-                    nn.Parameter(torch.tensor(lora_scale_init, dtype=torch.float32))
-                )
+                self.adapter_scales.append(LearnableScale(lora_scale_init, s_max=1.5))
+
+                # self.adapter_scales.append(  # ← newly added
+                #     nn.Parameter(torch.tensor(lora_scale_init, dtype=torch.float32))
+                # )
+
+            # --- CurriculumLoRALinear.__init__() ------------
+            # for p in self.adapter_scales:
+            #     p.requires_grad = True  # set all to True so none are missed
         else:
             self.adapters = None
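
Because adapter_scales is now an nn.ModuleList of LearnableScale modules rather than an nn.ParameterList, each scale's logit is registered as a submodule parameter and shows up in parameters() automatically, so the optimizer and log_param_statistics see it. A minimal sketch, assuming the LearnableScale class above:

    import torch.nn as nn

    scales = nn.ModuleList([LearnableScale(1.0, s_max=1.5) for _ in range(3)])
    # Each LearnableScale contributes exactly one trainable logit parameter.
    print(sum(p.numel() for p in scales.parameters()))  # -> 3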

@@ -121,17 +144,21 @@ def set_curriculum_stage(self, stage: int):
             for idx, adapter in enumerate(self.adapters):
                 adapter['lora_A'].requires_grad = False
                 adapter['lora_B'].requires_grad = False
-                self.adapter_scales[idx].requires_grad = True  # ← newly added
+                # self.adapter_scales[idx].requires_grad = True  # ← newly added
             logging.info(f"[CurriculumLoRALinear {module_id}] Stage 0: base layer trainable, all adapters frozen.")
             logging.info(f"[self.adapter_scales:] {self.adapter_scales}")
+            logging.info(f"self.adapter_scales[0].item(): {self.adapter_scales[0]().item()}")
+
         else:
             # Stage > 0: freeze the base layer
             self.weight.requires_grad = False
             if self.bias is not None:
                 self.bias.requires_grad = False
             for idx, adapter in enumerate(self.adapters):
-                self.adapter_scales[idx].requires_grad = True  # ← newly added
+                # self.adapter_scales[idx].requires_grad = True  # ← newly added
                 logging.info(f"[self.adapter_scales:] {self.adapter_scales}")
+                logging.info(f"self.adapter_scales[0].item(): {self.adapter_scales[0]().item()}")
+
                 if idx == stage - 1:
                     adapter['lora_A'].requires_grad = True
                     adapter['lora_B'].requires_grad = True
@@ -154,9 +181,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             adapter = self.adapters[idx]
             out = F.linear(self.lora_dropout(x), adapter['lora_A'])
             out = F.linear(out, adapter['lora_B'])
-            scale = self.adapter_scales[idx]  # TODO
+            scale = self.adapter_scales[idx]()  # TODO: the scales of all adapters are trained
             if idx == self.curriculum_stage - 1:
-                adapter_out = adapter_out + self.scaling * out * scale  # the current adapter is updated
+                adapter_out = adapter_out + self.scaling * out * scale  # only the current adapter is updated
             else:
                 adapter_out = adapter_out + self.scaling * out.detach() * scale
         return baseline_out + adapter_out
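
In this forward pass, outputs of previously frozen adapters are detach()ed, so their LoRA matrices receive no gradient while the multiplicative scale (a LearnableScale output) still does. A standalone sketch of that gradient-flow pattern with toy tensors, not the actual module:

    import torch

    lora_param = torch.randn(4, requires_grad=True)   # stands in for a frozen adapter's LoRA weight
    scale_logit = torch.zeros(1, requires_grad=True)  # stands in for LearnableScale.logit

    frozen_out = lora_param * 2.0                     # "adapter output" from the frozen branch
    scale = 1.5 * torch.sigmoid(scale_logit)          # bounded scale, as in LearnableScale

    y = (frozen_out.detach() * scale).sum()           # mirrors: out.detach() * scale
    y.backward()

    print(scale_logit.grad)   # populated: the scale keeps learning
    print(lora_param.grad)    # None: detach() blocks gradients into the frozen adapter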
