27
27
import numpy as np # 计算均值
28
28
from collections import defaultdict # 保存所有任务最近一次评估分数
29
29
30
- # ====== UniZero-MT 需要用到的基准分数(与 26 个 Atari100k 任务 id 一一对应)======
31
- # 原始的 RANDOM_SCORES 和 HUMAN_SCORES
32
-
33
-
34
- global BENCHMARK_NAME
35
- BENCHMARK_NAME = "atari"
36
- # BENCHMARK_NAME = "dmc" # TODO
37
- if BENCHMARK_NAME == "atari" :
38
- RANDOM_SCORES = np .array ([
39
- 227.8 , 5.8 , 222.4 , 210.0 , 14.2 , 2360.0 , 0.1 , 1.7 , 811.0 , 10780.5 ,
40
- 152.1 , 0.0 , 65.2 , 257.6 , 1027.0 , 29.0 , 52.0 , 1598.0 , 258.5 , 307.3 ,
41
- - 20.7 , 24.9 , 163.9 , 11.5 , 68.4 , 533.4
42
- ])
43
- HUMAN_SCORES = np .array ([
44
- 7127.7 , 1719.5 , 742.0 , 8503.3 , 753.1 , 37187.5 , 12.1 , 30.5 , 7387.8 , 35829.4 ,
45
- 1971.0 , 29.6 , 4334.7 , 2412.5 , 30826.4 , 302.8 , 3035.0 , 2665.5 , 22736.3 , 6951.6 ,
46
- 14.6 , 69571.3 , 13455.0 , 7845.0 , 42054.7 , 11693.2
47
- ])
48
- elif BENCHMARK_NAME == "dmc" :
49
- RANDOM_SCORES = np .array ([0 ]* 26 )
50
- HUMAN_SCORES = np .array ([1000 ]* 26 )
51
-
52
- # 新顺序对应的原始索引列表
53
- # 新顺序: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner,
54
- # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack,
55
- # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster,
56
- # PrivateEye, UpNDown, Qbert, Breakout]
57
- # 映射为原始数组中的索引(注意:索引均从0开始)
58
- new_order = [
59
- 20 , # Pong
60
- 19 , # MsPacman
61
- 24 , # Seaquest
62
- 6 , # Boxing
63
- 0 , # Alien
64
- 8 , # ChopperCommand
65
- 14 , # Hero
66
- 23 , # RoadRunner
67
- 1 , # Amidar
68
- 2 , # Assault
69
- 3 , # Asterix
70
- 4 , # BankHeist
71
- 5 , # BattleZone
72
- 9 , # CrazyClimber
73
- 10 , # DemonAttack
74
- 11 , # Freeway
75
- 12 , # Frostbite
76
- 13 , # Gopher
77
- 15 , # Jamesbond
78
- 16 , # Kangaroo
79
- 17 , # Krull
80
- 18 , # KungFuMaster
81
- 21 , # PrivateEye
82
- 25 , # UpNDown
83
- 22 , # Qbert
84
- 7 # Breakout
85
- ]
86
-
87
- # 根据 new_order 生成新的数组
88
- new_RANDOM_SCORES = RANDOM_SCORES [new_order ]
89
- new_HUMAN_SCORES = HUMAN_SCORES [new_order ]
90
-
91
- # 查看重排后的结果
92
- print ("重排后的 RANDOM_SCORES:" )
93
- print (new_RANDOM_SCORES )
94
- print ("\n 重排后的 HUMAN_SCORES:" )
95
- print (new_HUMAN_SCORES )
96
30
97
31
# 保存最近一次评估回报:{task_id: eval_episode_return_mean}
98
32
from collections import defaultdict
@@ -354,6 +288,7 @@ def train_unizero_multitask_balance_segment_ddp(
354
288
model_path : Optional [str ] = None ,
355
289
max_train_iter : Optional [int ] = int (1e10 ),
356
290
max_env_step : Optional [int ] = int (1e10 ),
291
+ benchmark_name : str = "atari"
357
292
) -> 'Policy' :
358
293
"""
359
294
Overview:
@@ -378,6 +313,73 @@ def train_unizero_multitask_balance_segment_ddp(
378
313
Returns:
379
314
- policy (:obj:`Policy`): 收敛的策略。
380
315
"""
316
+
317
+ # ---------------------------------------------------------------
318
+ # ====== UniZero-MT 需要用到的基准分数(与 26 个 Atari100k 任务 id 一一对应)======
319
+ # 原始的 RANDOM_SCORES 和 HUMAN_SCORES
320
+ if benchmark_name == "atari" :
321
+ RANDOM_SCORES = np .array ([
322
+ 227.8 , 5.8 , 222.4 , 210.0 , 14.2 , 2360.0 , 0.1 , 1.7 , 811.0 , 10780.5 ,
323
+ 152.1 , 0.0 , 65.2 , 257.6 , 1027.0 , 29.0 , 52.0 , 1598.0 , 258.5 , 307.3 ,
324
+ - 20.7 , 24.9 , 163.9 , 11.5 , 68.4 , 533.4
325
+ ])
326
+ HUMAN_SCORES = np .array ([
327
+ 7127.7 , 1719.5 , 742.0 , 8503.3 , 753.1 , 37187.5 , 12.1 , 30.5 , 7387.8 , 35829.4 ,
328
+ 1971.0 , 29.6 , 4334.7 , 2412.5 , 30826.4 , 302.8 , 3035.0 , 2665.5 , 22736.3 , 6951.6 ,
329
+ 14.6 , 69571.3 , 13455.0 , 7845.0 , 42054.7 , 11693.2
330
+ ])
331
+ elif benchmark_name == "dmc" :
332
+ # RANDOM_SCORES = np.array([0]*26)
333
+ # HUMAN_SCORES = np.array([1000]*26)
334
+ RANDOM_SCORES = np .zeros (26 )
335
+ HUMAN_SCORES = np .ones (26 ) * 1000
336
+ else :
337
+ raise ValueError (f"Unsupported BENCHMARK_NAME: { BENCHMARK_NAME } " )
338
+
339
+ # 新顺序对应的原始索引列表
340
+ # 新顺序: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner,
341
+ # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack,
342
+ # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster,
343
+ # PrivateEye, UpNDown, Qbert, Breakout]
344
+ # 映射为原始数组中的索引(注意:索引均从0开始)
345
+ new_order = [
346
+ 20 , # Pong
347
+ 19 , # MsPacman
348
+ 24 , # Seaquest
349
+ 6 , # Boxing
350
+ 0 , # Alien
351
+ 8 , # ChopperCommand
352
+ 14 , # Hero
353
+ 23 , # RoadRunner
354
+ 1 , # Amidar
355
+ 2 , # Assault
356
+ 3 , # Asterix
357
+ 4 , # BankHeist
358
+ 5 , # BattleZone
359
+ 9 , # CrazyClimber
360
+ 10 , # DemonAttack
361
+ 11 , # Freeway
362
+ 12 , # Frostbite
363
+ 13 , # Gopher
364
+ 15 , # Jamesbond
365
+ 16 , # Kangaroo
366
+ 17 , # Krull
367
+ 18 , # KungFuMaster
368
+ 21 , # PrivateEye
369
+ 25 , # UpNDown
370
+ 22 , # Qbert
371
+ 7 # Breakout
372
+ ]
373
+ # 根据 new_order 生成新的数组
374
+ new_RANDOM_SCORES = RANDOM_SCORES [new_order ]
375
+ new_HUMAN_SCORES = HUMAN_SCORES [new_order ]
376
+ # 查看重排后的结果
377
+ print ("重排后的 RANDOM_SCORES:" )
378
+ print (new_RANDOM_SCORES )
379
+ print ("\n 重排后的 HUMAN_SCORES:" )
380
+ print (new_HUMAN_SCORES )
381
+ # ---------------------------------------------------------------
382
+
381
383
# 初始化温度调度器
382
384
initial_temperature = 10.0
383
385
final_temperature = 1.0
@@ -552,7 +554,8 @@ def train_unizero_multitask_balance_segment_ddp(
552
554
# TODO: ============
553
555
# cfg.policy.target_return = 10
554
556
# ==================== 如果任务已解决,则不参与后续评估和采集 TODO: ddp ====================
555
- if task_id in solved_task_pool :
557
+ # if task_id in solved_task_pool:
558
+ if cfg .policy .task_id in solved_task_pool :
556
559
continue
557
560
558
561
# 记录缓冲区内存使用情况
@@ -601,8 +604,10 @@ def train_unizero_multitask_balance_segment_ddp(
601
604
602
605
# 如果达到目标奖励,将任务移入 solved_task_pool
603
606
if eval_mean_reward >= cfg .policy .target_return :
604
- print (f"任务 { task_id } 达到了目标奖励 { cfg .policy .target_return } , 移入 solved_task_pool." )
605
- solved_task_pool .add (task_id )
607
+ cur_task_id = cfg .policy .task_id
608
+ print (f"任务 { cur_task_id } 达到了目标奖励 { cfg .policy .target_return } , 移入 solved_task_pool." )
609
+ solved_task_pool .add (cur_task_id )
610
+
606
611
607
612
except Exception as e :
608
613
print (f"提取评估奖励时发生错误: { e } " )
0 commit comments