
Commit e85c449

Author: puyuan (committed)

fix(pu): fix task_id bug in balance pipeline, and polish benchmark_name option

1 parent: 39ee55e

5 files changed (+98 / -93 lines)

lzero/entry/train_unizero_multitask_balance_segment_ddp.py

Lines changed: 74 additions & 69 deletions
@@ -27,72 +27,6 @@
 import numpy as np  # for computing means
 from collections import defaultdict  # keeps the most recent evaluation score of every task
 
-# ====== Benchmark scores used by UniZero-MT (one-to-one with the 26 Atari100k task ids) ======
-# Original RANDOM_SCORES and HUMAN_SCORES
-
-
-global BENCHMARK_NAME
-BENCHMARK_NAME = "atari"
-# BENCHMARK_NAME = "dmc"  # TODO
-if BENCHMARK_NAME == "atari":
-    RANDOM_SCORES = np.array([
-        227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5,
-        152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3,
-        -20.7, 24.9, 163.9, 11.5, 68.4, 533.4
-    ])
-    HUMAN_SCORES = np.array([
-        7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4,
-        1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6,
-        14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2
-    ])
-elif BENCHMARK_NAME == "dmc":
-    RANDOM_SCORES = np.array([0]*26)
-    HUMAN_SCORES = np.array([1000]*26)
-
-# Index list that maps the new task ordering back to the original arrays
-# New order: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner,
-#             Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack,
-#             Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster,
-#             PrivateEye, UpNDown, Qbert, Breakout]
-# Indices into the original arrays (all 0-based)
-new_order = [
-    20,  # Pong
-    19,  # MsPacman
-    24,  # Seaquest
-    6,   # Boxing
-    0,   # Alien
-    8,   # ChopperCommand
-    14,  # Hero
-    23,  # RoadRunner
-    1,   # Amidar
-    2,   # Assault
-    3,   # Asterix
-    4,   # BankHeist
-    5,   # BattleZone
-    9,   # CrazyClimber
-    10,  # DemonAttack
-    11,  # Freeway
-    12,  # Frostbite
-    13,  # Gopher
-    15,  # Jamesbond
-    16,  # Kangaroo
-    17,  # Krull
-    18,  # KungFuMaster
-    21,  # PrivateEye
-    25,  # UpNDown
-    22,  # Qbert
-    7    # Breakout
-]
-
-# Build the reordered arrays according to new_order
-new_RANDOM_SCORES = RANDOM_SCORES[new_order]
-new_HUMAN_SCORES = HUMAN_SCORES[new_order]
-
-# Inspect the reordered results
-print("Reordered RANDOM_SCORES:")
-print(new_RANDOM_SCORES)
-print("\nReordered HUMAN_SCORES:")
-print(new_HUMAN_SCORES)
 
 # Keep the most recent evaluation return: {task_id: eval_episode_return_mean}
 from collections import defaultdict
@@ -354,6 +288,7 @@ def train_unizero_multitask_balance_segment_ddp(
         model_path: Optional[str] = None,
         max_train_iter: Optional[int] = int(1e10),
         max_env_step: Optional[int] = int(1e10),
+        benchmark_name: str = "atari"
 ) -> 'Policy':
     """
     Overview:
@@ -378,6 +313,73 @@ def train_unizero_multitask_balance_segment_ddp(
     Returns:
         - policy (:obj:`Policy`): The converged policy.
     """
+
+    # ---------------------------------------------------------------
+    # ====== Benchmark scores used by UniZero-MT (one-to-one with the 26 Atari100k task ids) ======
+    # Original RANDOM_SCORES and HUMAN_SCORES
+    if benchmark_name == "atari":
+        RANDOM_SCORES = np.array([
+            227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5,
+            152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3,
+            -20.7, 24.9, 163.9, 11.5, 68.4, 533.4
+        ])
+        HUMAN_SCORES = np.array([
+            7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4,
+            1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6,
+            14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2
+        ])
+    elif benchmark_name == "dmc":
+        # RANDOM_SCORES = np.array([0]*26)
+        # HUMAN_SCORES = np.array([1000]*26)
+        RANDOM_SCORES = np.zeros(26)
+        HUMAN_SCORES = np.ones(26) * 1000
+    else:
+        raise ValueError(f"Unsupported benchmark_name: {benchmark_name}")
+
+    # Index list that maps the new task ordering back to the original arrays
+    # New order: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner,
+    #             Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack,
+    #             Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster,
+    #             PrivateEye, UpNDown, Qbert, Breakout]
+    # Indices into the original arrays (all 0-based)
+    new_order = [
+        20,  # Pong
+        19,  # MsPacman
+        24,  # Seaquest
+        6,   # Boxing
+        0,   # Alien
+        8,   # ChopperCommand
+        14,  # Hero
+        23,  # RoadRunner
+        1,   # Amidar
+        2,   # Assault
+        3,   # Asterix
+        4,   # BankHeist
+        5,   # BattleZone
+        9,   # CrazyClimber
+        10,  # DemonAttack
+        11,  # Freeway
+        12,  # Frostbite
+        13,  # Gopher
+        15,  # Jamesbond
+        16,  # Kangaroo
+        17,  # Krull
+        18,  # KungFuMaster
+        21,  # PrivateEye
+        25,  # UpNDown
+        22,  # Qbert
+        7    # Breakout
+    ]
+    # Build the reordered arrays according to new_order
+    new_RANDOM_SCORES = RANDOM_SCORES[new_order]
+    new_HUMAN_SCORES = HUMAN_SCORES[new_order]
+    # Inspect the reordered results
+    print("Reordered RANDOM_SCORES:")
+    print(new_RANDOM_SCORES)
+    print("\nReordered HUMAN_SCORES:")
+    print(new_HUMAN_SCORES)
+    # ---------------------------------------------------------------
+
     # Initialize the temperature scheduler
     initial_temperature = 10.0
     final_temperature = 1.0
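For reference, score tables like these are normally consumed by computing human-normalized returns per task. Below is a minimal standalone sketch of that computation, assuming evaluation results arrive as a {task_id: mean_return} mapping whose task ids index the reordered arrays; the helper name and the example numbers are illustrative and are not part of this commit.

    import numpy as np

    def human_normalized(task_returns: dict, random_scores: np.ndarray, human_scores: np.ndarray) -> dict:
        # (raw_return - random) / (human - random), guarding against a zero denominator
        out = {}
        for task_id, raw_return in task_returns.items():
            denom = max(human_scores[task_id] - random_scores[task_id], 1e-8)
            out[task_id] = (raw_return - random_scores[task_id]) / denom
        return out

    # Hypothetical usage with two evaluation results (task 0 = Pong, task 25 = Breakout in the new order):
    # scores = human_normalized({0: 5.0, 25: 30.0}, new_RANDOM_SCORES, new_HUMAN_SCORES)
    # mean_hns = float(np.mean(list(scores.values())))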
@@ -552,7 +554,8 @@ def train_unizero_multitask_balance_segment_ddp(
         # TODO: ============
         # cfg.policy.target_return = 10
         # ==================== If the task is already solved, it no longer takes part in evaluation or collection  TODO: ddp ====================
-        if task_id in solved_task_pool:
+        # if task_id in solved_task_pool:
+        if cfg.policy.task_id in solved_task_pool:
             continue
 
         # Log the replay buffer's memory usage
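The change above, together with the matching change in the next hunk, keys solved_task_pool by cfg.policy.task_id, the global task id carried in each task's config, rather than by the local loop variable, which can differ once every DDP rank iterates only over its own slice of tasks. A minimal standalone sketch of the intended bookkeeping, using hypothetical configs and evaluation results rather than code from this repository:

    solved_task_pool = set()
    target_return = 500

    # Hypothetical slice of task configs assigned to one DDP rank; each config
    # carries the global task id under policy.task_id.
    cfgs_on_this_rank = [
        {"policy": {"task_id": 3}},
        {"policy": {"task_id": 17}},
    ]
    eval_mean_reward = {3: 620.0, 17: 140.0}  # hypothetical evaluation returns

    for cfg in cfgs_on_this_rank:
        cur_task_id = cfg["policy"]["task_id"]
        if cur_task_id in solved_task_pool:
            continue  # solved tasks no longer collect or evaluate
        if eval_mean_reward[cur_task_id] >= target_return:
            solved_task_pool.add(cur_task_id)  # key by the global id, not the local loop index

    print(solved_task_pool)  # -> {3}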
@@ -601,8 +604,10 @@ def train_unizero_multitask_balance_segment_ddp(
 
             # If the target return is reached, move the task into solved_task_pool
             if eval_mean_reward >= cfg.policy.target_return:
-                print(f"Task {task_id} reached the target return {cfg.policy.target_return}; moving it into solved_task_pool.")
-                solved_task_pool.add(task_id)
+                cur_task_id = cfg.policy.task_id
+                print(f"Task {cur_task_id} reached the target return {cfg.policy.target_return}; moving it into solved_task_pool.")
+                solved_task_pool.add(cur_task_id)
+
 
         except Exception as e:
             print(f"Error while extracting the evaluation return: {e}")

lzero/entry/train_unizero_multitask_segment_ddp.py

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@
 
 
 global BENCHMARK_NAME
-BENCHMARK_NAME = "atari"
-# BENCHMARK_NAME = "dmc"  # TODO
+# BENCHMARK_NAME = "atari"
+BENCHMARK_NAME = "dmc"  # TODO
 if BENCHMARK_NAME == "atari":
     RANDOM_SCORES = np.array([
         227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5,

zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py

Lines changed: 6 additions & 7 deletions
@@ -190,7 +190,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
             n_episode=n_episode,
             replay_buffer_size=int(5e5),
             # eval_freq=int(1e4),
-            eval_freq=int(1.5e4),
+            eval_freq=int(1e4),
             # eval_freq=int(2),
             collector_env_num=collector_env_num,
             evaluator_env_num=evaluator_env_num,
@@ -208,7 +208,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
     # ===== only for debug =====
     # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-encoder-ps8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
     # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_no-encoder-scale_cnn-encoder_moe8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
-    exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
+    exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250514/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
 
     for task_id, env_id in enumerate(env_id_list):
         config = create_config(
@@ -404,11 +404,10 @@ def get_atari_target_return_dict(ratio=1.0):
 ]
 
 global curriculum_stage_num
-
+# TODO ==============
 curriculum_stage_num=3
-# curriculum_stage_num=5
-curriculum_stage_num=9
-
+curriculum_stage_num=5
+# curriculum_stage_num=9
 
 action_space_size = 18
 collector_env_num = 8
@@ -461,6 +460,6 @@ def get_atari_target_return_dict(ratio=1.0):
                                     num_segments, total_batch_size)
 
     with DDPContext():
-        train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step)
+        train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="atari")
         # ======== TODO: only for debug ========
         # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step)  # train only on the first two tasks

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 11 additions & 8 deletions
@@ -64,8 +64,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
     policy=dict(
         multi_gpu=True,  # Very important for ddp
         only_use_moco_stats=False,
-        # use_moco=False,  # ==============TODO==============
-        use_moco=True,  # ==============TODO==============
+        use_moco=False,  # ==============TODO==============
+        # use_moco=True,  # ==============TODO==============
         learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
         grad_correct_params=dict(
             MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0,
@@ -99,8 +99,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
 
         analysis_dormant_ratio_weight_rank=True,
         # analysis_dormant_ratio_weight_rank=False,  # TODO
-        analysis_dormant_ratio_interval=100,
-        # analysis_dormant_ratio_interval=1000,
+        # analysis_dormant_ratio_interval=100,
+        analysis_dormant_ratio_interval=1000,
         # analysis_dormant_ratio_interval=20,
 
         continuous_action_space=False,
@@ -123,6 +123,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
             # num_heads=24,
 
             num_layers=8,
+
             # num_layers=12,  # todo
             num_heads=24,
 
@@ -134,8 +135,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
             obs_type='image',
             env_num=8,
             task_num=len(env_id_list),
-            # encoder_type='vit',
-            encoder_type='resnet',
+            encoder_type='vit',
+            # encoder_type='resnet',
 
             use_normal_head=True,
             use_softmoe_head=False,
@@ -197,7 +198,9 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
     configs = []
     # ===== only for debug =====
     # ========= TODO: global BENCHMARK_NAME =========
-    exp_name_prefix = f'data_lz/data_unizero_atari_mt_20250508/atari_{len(env_id_list)}games_orig-ln_moco_tran-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
+    exp_name_prefix = f'data_lz/data_unizero_atari_mt_20250508/atari_{len(env_id_list)}games_orig_vit-ln_tran-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
+
+    # exp_name_prefix = f'data_lz/data_unizero_atari_mt_20250508/atari_{len(env_id_list)}games_orig-ln_moco_tran-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
 
     # exp_name_prefix = f'data_lz/data_unizero_atari_mt_20250508/atari_{len(env_id_list)}games_orig_simnorm_tran-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
 
@@ -251,7 +254,7 @@ def create_env_manager():
     Overview:
         This script should be executed with <nproc_per_node> GPUs.
         Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 ./zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee ./log/20250509/uz_mt_atari8_orig-ln_moco.log
+        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 ./zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee ./log/20250509/uz_mt_atari8_orig_vit-ln.log
         python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 ./zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee ./log/uz_mt_atari8_orig-simnorm.log
 
 
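Note that torch.distributed.launch is deprecated in recent PyTorch releases; the same run can usually be started with torchrun instead, assuming the entry point reads the local rank from the environment rather than from a --local_rank argument, e.g.:

    torchrun --nproc_per_node=8 --master_port=29502 ./zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee ./log/20250509/uz_mt_atari8_orig_vit-ln.log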

zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py

Lines changed: 5 additions & 7 deletions
@@ -205,7 +205,7 @@ def generate_configs(env_id_list: List[str],
     configs = []
     # ========= TODO: global BENCHMARK_NAME =========
 
-    exp_name_prefix = f'data_lz/data_suz_dmc_mt_balance_20250509/dmc_{len(env_id_list)}tasks_frameskip8_balance-stage-total-{curriculum_stage_num}_moe8_nlayer4_not-share-head_brf{buffer_reanalyze_freq}_seed{seed}/'
+    exp_name_prefix = f'data_lz/data_suz_dmc_mt_balance_20250514/dmc_{len(env_id_list)}tasks_frameskip8_balance-stage-total-{curriculum_stage_num}_moe8_nlayer4_not-share-head_brf{buffer_reanalyze_freq}_seed{seed}/'
 
     # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250409_moco/dmc_{len(env_id_list)}tasks_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/'
 
@@ -269,8 +269,6 @@ def create_env_manager():
     import os
     from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map
 
-    global BENCHMARK_NAME
-    BENCHMARK_NAME='dmc'
 
     global curriculum_stage_num
 
@@ -303,9 +301,9 @@ def create_env_manager():
     ]
 
     target_return_dict = {
-        'acrobot-swingup': 500,
-        'cartpole-balance': 950,
-        'cartpole-balance_sparse': 950,
+        'acrobot-swingup': 500,  # 0
+        'cartpole-balance': 950,  # 1
+        'cartpole-balance_sparse': 950,  # 2
         'cartpole-swingup': 800,  # 3
         'cartpole-swingup_sparse': 750,  # 4
         'cheetah-run': 650,  # 5
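The added trailing comments tie each entry to its task index. A minimal sketch of how such a dict can be mapped onto per-task configs, assuming target returns are assigned in env_id_list order; the loop below is illustrative and not code from this config file.

    env_id_list = ['acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse']
    target_return_dict = {
        'acrobot-swingup': 500,          # 0
        'cartpole-balance': 950,         # 1
        'cartpole-balance_sparse': 950,  # 2
    }

    # One hypothetical config per task; the entry later compares eval returns against policy.target_return.
    configs = [{"policy": {"task_id": i, "env_id": env_id}} for i, env_id in enumerate(env_id_list)]
    for cfg in configs:
        cfg["policy"]["target_return"] = target_return_dict[cfg["policy"]["env_id"]]

    print([(c["policy"]["task_id"], c["policy"]["target_return"]) for c in configs])
    # -> [(0, 500), (1, 950), (2, 950)]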
@@ -405,6 +403,6 @@ def create_env_manager():
     )
 
     with DDPContext():
-        train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step)
+        train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="dmc")
         # If you only want to train a subset of the tasks, modify configs accordingly, e.g.:
         # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step)
