Skip to content

Commit 158e4a0

Browse files
Author: puyuan
Commit message: tmp
1 parent: bc5003a · commit: 158e4a0

File tree

3 files changed: 26 additions (+26) and 18 deletions (-18).

lzero/policy/unizero_multitask.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,9 @@
1717
from .utils import configure_optimizers_nanogpt
1818
import sys
1919

20-
sys.path.append('/cpfs04/user/puyuan/code/LibMTL')
20+
# sys.path.append('/cpfs04/user/puyuan/code/LibMTL')
21+
sys.path.append('/fs-computility/niuyazhe/puyuan/code/LibMTL')
22+
2123
from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect
2224
from LibMTL.weighting.moco_generic import GenericMoCo, MoCoCfg
2325
from LibMTL.weighting.moco_fast import FastMoCo, MoCoCfg
@@ -634,11 +636,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr
634636
# rank = get_rank()
635637
# print(f'Rank {rank}: cfg.policy.task_id : {self._cfg.task_id}, self._cfg.batch_size {self._cfg.batch_size}')
636638

637-
target_reward = target_reward.view(self._cfg.batch_size[task_id], -1)
638-
target_value = target_value.view(self._cfg.batch_size[task_id], -1)
639+
cur_batch_size = target_reward.size(0) # run-time batch
640+
641+
target_reward = target_reward.view(cur_batch_size, -1)
642+
target_value = target_value.view(cur_batch_size, -1)
639643

640-
target_reward = target_reward.view(self._cfg.batch_size[task_id], -1)
641-
target_value = target_value.view(self._cfg.batch_size[task_id], -1)
644+
# target_reward = target_reward.view(self._cfg.batch_size[task_id], -1)
645+
# target_value = target_value.view(self._cfg.batch_size[task_id], -1)
642646

643647
# assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0)
644648

@@ -654,10 +658,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr
654658
batch_for_gpt = {}
655659
if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1:
656660
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(
657-
self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape)
661+
cur_batch_size, -1, self._cfg.model.observation_shape)
658662
elif len(self._cfg.model.observation_shape) == 3:
659663
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(
660-
self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape)
664+
cur_batch_size, -1, *self._cfg.model.observation_shape)
661665

662666
batch_for_gpt['actions'] = action_batch.squeeze(-1)
663667
batch_for_gpt['rewards'] = target_reward_categorical[:, :-1]

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -184,8 +184,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
184184
reanalyze_ratio=reanalyze_ratio,
185185
n_episode=n_episode,
186186
replay_buffer_size=int(5e5),
187-
eval_freq=int(1e4), # TODO:
188-
# eval_freq=int(2e4),
187+
# eval_freq=int(1e4), # TODO: 8games
188+
eval_freq=int(2e4), # TODO: 26games
189189
collector_env_num=collector_env_num,
190190
evaluator_env_num=evaluator_env_num,
191191
buffer_reanalyze_freq=buffer_reanalyze_freq,
@@ -204,12 +204,12 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
204204

205205

206206
# ========= TODO: global BENCHMARK_NAME =========
207-
exp_name_prefix = f'data_unizero_atari_mt_20250527/atari_{len(env_id_list)}games_orig_simnorm-kl_vit_moe8_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
207+
# exp_name_prefix = f'data_unizero_atari_mt_20250527/atari_{len(env_id_list)}games_orig_simnorm-kl_vit_moe8_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
208208

209209
# exp_name_prefix = f'data_unizero_atari_mt_20250522/atari_{len(env_id_list)}games_orig_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
210210
# exp_name_prefix = f'data_unizero_atari_mt_20250527/atari_{len(env_id_list)}games_orig_simnorm-kl_vit_moco-v2_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
211211

212-
exp_name_prefix = f'data_unizero_atari_mt_20250530/atari_{len(env_id_list)}games_orig_vit_ln-mse_moe8_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
212+
exp_name_prefix = f'data_unizero_atari_mt_20250601/atari_{len(env_id_list)}games_orig_vit_ln-mse_moe8_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
213213

214214
# exp_name_prefix = f'data_unizero_atari_mt_20250521/atari_{len(env_id_list)}games_orig_simnorm-kl_vit_moe8_taskembed128_tran-nlayer{num_layers}_rr1_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/'
215215

@@ -254,7 +254,7 @@ def create_env_manager():
254254
255255
=========== volce atari8 =========================
256256
cd /fs-computility/niuyazhe/puyuan/code/LightZero/
257-
python -m torch.distributed.launch --nproc_per_node=4 --master_port=29502 /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /fs-computility/niuyazhe/puyuan/code/LightZero/log/20250509/uz_mt_atari8_orig_vit_ln-mse_moe8_nlayer8_brf002_seed12.log
257+
python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /fs-computility/niuyazhe/puyuan/code/LightZero/log/20250509/uz_mt_atari26_orig_vit_ln-mse_moe8_nlayer8_brf002_seed12.log
258258
259259
260260
=========== cpfs atari8 =========================
@@ -306,7 +306,7 @@ def create_env_manager():
306306
import os
307307

308308

309-
num_games = 8 # 26 # 8
309+
num_games = 26 # 26 # 8
310310
num_layers = 8 # ==============TODO==============
311311
action_space_size = 18
312312
collector_env_num = 8
@@ -383,7 +383,7 @@ def create_env_manager():
383383

384384

385385
import torch.distributed as dist
386-
for seed in [1]:
386+
for seed in [1,2]:
387387
configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num,
388388
num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length,
389389
norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,

zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -139,10 +139,11 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
139139
lora_alpha=1,
140140
lora_dropout=0.0,
141141
lora_scale_init=1,
142-
# min_stage0_iters=10000,
143-
# max_stage_iters=20000,
142+
144143
min_stage0_iters=10000,
145144
max_stage_iters=5000,
145+
# min_stage0_iters=10,
146+
# max_stage_iters=20,
146147
),
147148
),
148149
use_task_exploitation_weight=False, # TODO
@@ -211,7 +212,7 @@ def generate_configs(env_id_list: List[str],
211212
configs = []
212213
# ========= TODO: global BENCHMARK_NAME =========
213214

214-
exp_name_prefix = f'data_suz_dmc_mt_balance_20250526/dmc_{len(env_id_list)}tasks_frameskip4_balance-stage-total-{curriculum_stage_num}_stage0-10k-5k_moe8_nlayer8_not-share-head_brf{buffer_reanalyze_freq}_seed{seed}/'
215+
exp_name_prefix = f'data_suz_dmc_mt_balance_20250601/dmc_{len(env_id_list)}tasks_frameskip4_balance-stage-total-{curriculum_stage_num}_stage0-10k-5k_moe8_nlayer8_not-share-head_brf{buffer_reanalyze_freq}_seed{seed}/'
215216

216217
# exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250409_moco/dmc_{len(env_id_list)}tasks_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/'
217218

@@ -266,6 +267,9 @@ def create_env_manager():
266267
Overview:
267268
This script should be executed with <nproc_per_node> GPUs.
268269
Run the following command to launch the script:
270+
cd /fs-computility/niuyazhe/puyuan/code/LightZero/
271+
python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py 2>&1 | tee /fs-computility/niuyazhe/puyuan/code/LightZero/log/20250509/uz_mt_dmc18_ln_balance_moe8_stage5_stage0-10k-5k_nlayer8.log
272+
269273
cd /cpfs04/user/puyuan/code/LightZero/
270274
python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 /cpfs04/user/puyuan/code/LightZero/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py 2>&1 | tee /cpfs04/user/puyuan/code/LightZero/log/20250522_cpfs/uz_mt_dmc18_ln_balance_moe8_stage5_stage0-5k-10k_nlayer8.log
271275
torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py
@@ -385,7 +389,7 @@ def create_env_manager():
385389
# evaluator_env_num = 2
386390
# num_simulations = 1
387391
# total_batch_size = 8
388-
# batch_size = [2 for _ in range(len(env_id_list))]
392+
# batch_size = [3 for _ in range(len(env_id_list))]
389393
# =======================================
390394

391395
seed = 0 # You can iterate over multiple seeds if needed

0 commit comments

Comments
 (0)