
Commit 067a1ae

Authored by puyuan1996, puyuan, xiongjyu, PaParaZz1, and KJLdefeated
feature(xjy): Enhance text-based games like Jericho with text decoding and configurable reconstruction loss mode (#355)
* v0.2.0
* polish(pu): add final_norm_option_in_encoder
* polish(pu): polish jericho configs
* tmp
* fix(pu): fix world model init bug when use pretrained_model
* tmp
* feature(xjy): add text regularization function
* feature(xjy): add decode text regularization option and related logs (#348)
* fix(xjy): fixed some bug and add a function to output the decoder's text
* fix(pu): fix _shift_right in decode loss
* fix(xjy): add decode text function and decode_loss_mode option of reconstruction loss for jericho (#363)
* Standardized the format and fixed existing bugs
* resolved game_buffer bug and polished formatting
* polish(xjy): standardize decode text related code for jericho (#366)
* polish(xjy): delete unnecessary comments and translate CN comments into EN
* fix(xjy): merged latest main branch (#368)
* v0.2.0
* style(pu): use actions/upload-artifact@v3
* fix(pu): fix Union import in game_segment
* style(pu): use actions/upload-artifact@v4
* test(nyz): only upload cov in macos
* fix(pu): fix reanalyze_ratio compatibility with rope embed (#342)
* fix(pu): fix release.yml
* fix(pu): fix release.yml (#343)
* fix(pu): fix release.yml
* fix(pu): fix release.yml
* fix(pu): fix release.yml
* fix(pu): fix release.yml
* fix(pu): fix release.yml
* fix(pu): use actions/download-artifact@v2
* fix(pu): use actions/download-artifact@v4
* release v0.2.0
* fix(lkj): fix typo in customize_envs.md
* fix(pu): adapt atari and dmc2gym env to support shared_memory (#345)
* fix(pu): fix atari and dmc2gym env to support shared_memory
* tmp
* fix(pu): fix frame_stack_num default cfg in atari env

---------

Co-authored-by: puyuan <puyuan1996@qq.com>

* delete unnecessary comments and translate CN comments into EN
* delete unnecessary comment

---------

Co-authored-by: 蒲源 <2402552459@qq.com>
Co-authored-by: PaParaZz1 <niuyazhe314@outlook.com>
Co-authored-by: 蒲源 <48008469+puyuan1996@users.noreply.github.com>
Co-authored-by: 林楷傑 <46377141+KJLdefeated@users.noreply.github.com>
Co-authored-by: puyuan <puyuan1996@qq.com>

* latest remove unnucessary comments
* fix(pu): fix compatibility
* polish(pu): polish readme and requirements

---------

Co-authored-by: puyuan <puyuan1996@qq.com>
Co-authored-by: xiongjyu <xiongjyu@gmail.com>
Co-authored-by: PaParaZz1 <niuyazhe314@outlook.com>
Co-authored-by: 林楷傑 <46377141+KJLdefeated@users.noreply.github.com>
1 parent 2e98102 commit 067a1ae

23 files changed: +532 -177 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 [![GitHub license](https://img.shields.io/github/license/opendilab/LightZero)](https://github.com/opendilab/LightZero/blob/master/LICENSE)
 [![discord badge](https://dcbadge.vercel.app/api/server/dkZS2JF56X?style=flat)](https://discord.gg/dkZS2JF56X)
 
-Updated on 2025.04.09 LightZero-v0.2.0
+Updated on 2025.06.03 LightZero-v0.2.0
 
 English | [简体中文(Simplified Chinese)](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Documentation](https://opendilab.github.io/LightZero) | [LightZero Paper](https://arxiv.org/abs/2310.08348) | [🔥UniZero Paper](https://arxiv.org/abs/2406.10667) | [🔥ReZero Paper](https://arxiv.org/abs/2404.16364)

README.zh.md

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 [![Contributors](https://img.shields.io/github/contributors/opendilab/LightZero)](https://github.com/opendilab/LightZero/graphs/contributors)
 [![GitHub license](https://img.shields.io/github/license/opendilab/LightZero)](https://github.com/opendilab/LightZero/blob/master/LICENSE)
 
-最近更新于 2025.04.09 LightZero-v0.2.0
+最近更新于 2025.06.03 LightZero-v0.2.0
 
 [English](https://github.com/opendilab/LightZero/blob/main/README.md) | 简体中文 | [文档](https://opendilab.github.io/LightZero) | [LightZero 论文](https://arxiv.org/abs/2310.08348) | [🔥UniZero 论文](https://arxiv.org/abs/2406.10667) | [🔥ReZero 论文](https://arxiv.org/abs/2404.16364)

lzero/entry/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from .eval_alphazero import eval_alphazero
 from .eval_muzero import eval_muzero
+
 from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
 from .train_alphazero import train_alphazero
 from .train_muzero import train_muzero

lzero/entry/eval_muzero.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 import os
 from functools import partial
 from typing import Optional, Tuple
+import logging
 
 import numpy as np
 import torch
@@ -51,7 +52,7 @@ def eval_muzero(
     # Create main components: env, policy
     env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
     evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
-
+    # print(f"cfg.seed:{cfg.seed}")
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
 

lzero/mcts/buffer/game_buffer.py

Lines changed: 12 additions & 8 deletions
@@ -151,14 +151,18 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
             # Indices exceeding `game_segment_length` are padded with the next segment and are not updated
             # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within
             # [0, game_segment_length - num_unroll_steps] to avoid padded data.
-
-            # TODO: Consider increasing `self._cfg.game_segment_length` to ensure sampling efficiency.
-            # if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
-            #     pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
-
-            # NOTE: Sample the init position from the whole segment, but not from the padded part
-            if pos_in_game_segment >= self._cfg.game_segment_length:
-                pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
+
+            if self._cfg.action_type == 'varied_action_space':
+                # For some environments (e.g., Jericho), the action space size may be different.
+                # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
+                # we avoid sampling from the last `num_unroll_steps` steps of the game segment.
+                if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
+                    pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
+            else:
+                # For environments with a fixed action space (e.g., Atari),
+                # we can safely sample from the entire game segment range.
+                if pos_in_game_segment >= self._cfg.game_segment_length:
+                    pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
 
             pos_in_game_segment_list.append(pos_in_game_segment)

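The branch added above splits the resampling rule by action-space type: Jericho-style environments with varied action spaces must leave room for a full unroll, while fixed-action-space environments only need the position to stay inside the segment. A minimal, self-contained sketch of that rule follows; the standalone function name clamp_sample_position is illustrative and not part of the buffer's actual API.

import numpy as np

def clamp_sample_position(pos_in_game_segment: int, game_segment_length: int,
                          num_unroll_steps: int, action_type: str) -> int:
    # Resampling rule from `_sample_orig_data` above, extracted for illustration.
    if action_type == 'varied_action_space':
        # Varied action spaces (e.g., Jericho): keep room for a full `num_unroll_steps` unroll.
        limit = game_segment_length - num_unroll_steps
    else:
        # Fixed action spaces (e.g., Atari): any position inside the segment is fine.
        limit = game_segment_length
    if pos_in_game_segment >= limit:
        pos_in_game_segment = np.random.choice(limit, 1).item()
    return pos_in_game_segment

# With a segment length of 400 and 10 unroll steps, a varied-action-space environment
# never starts an unroll beyond index 389.
print(clamp_sample_position(395, 400, 10, 'varied_action_space'))
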
lzero/mcts/tree_search/mcts_ctree.py

Lines changed: 15 additions & 2 deletions
@@ -75,7 +75,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m
     def search(
             self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
                                                                                      List[Any]], timestep: Union[int, List[Any]]
-    ) -> None:
+    ) -> dict:
         """
         Overview:
             Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel.
@@ -93,6 +93,10 @@
 
         # preparation some constant
        batch_size = roots.num
+
+        # Store the latent state of each possible action at the MCTS root for each environment.
+        first_action_latent_map = {env_id: {} for env_id in range(batch_size)}  # {env_id: {action: latent_state}}
+
         pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor
         # the data storage of latent states: storing the latent state of all the nodes in the search.
         latent_state_batch_in_search_path = [latent_state_roots]
@@ -156,8 +160,15 @@
                 network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
                 network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward))
 
+                for env_id in range(batch_size):
+                    depth = search_depth[env_id]
+                    action = last_actions[env_id].item()
+                    if depth == 1 and action not in first_action_latent_map[env_id]:
+                        first_action_latent_map[env_id][action] = network_output.latent_state[env_id]
+                    else:
+                        continue
+
                 latent_state_batch_in_search_path.append(network_output.latent_state)
-
                 # tolist() is to be compatible with cpp datatype.
                 reward_batch = network_output.reward.reshape(-1).tolist()
                 value_batch = network_output.value.reshape(-1).tolist()
@@ -173,6 +184,8 @@
                 current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch,
                 min_max_stats_lst, results, virtual_to_play_batch
             )
+
+        return first_action_latent_map
 
 
 class MuZeroMCTSCtree(object):

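`search` now returns `first_action_latent_map`, a per-environment dictionary holding the latent state produced the first time each root action is expanded (search depth 1). The sketch below isolates that bookkeeping with toy stand-ins for `batch_size`, `search_depth`, `last_actions`, and `network_output`; in the real loop these are maintained by the MCTS machinery, and the returned map is presumably what feeds the new text-decoding logs.

import numpy as np
from types import SimpleNamespace

# Toy stand-ins for the per-simulation quantities the real search loop maintains.
batch_size = 2
search_depth = [1, 3]                      # depth of the leaf reached in each environment
last_actions = [np.int64(4), np.int64(0)]  # action taken into that leaf
network_output = SimpleNamespace(latent_state=np.random.randn(batch_size, 8))

first_action_latent_map = {env_id: {} for env_id in range(batch_size)}  # {env_id: {action: latent_state}}

for env_id in range(batch_size):
    depth = search_depth[env_id]
    action = last_actions[env_id].item()
    # Record only the first expansion of each root action (depth 1); deeper leaves
    # and already-recorded actions are skipped.
    if depth == 1 and action not in first_action_latent_map[env_id]:
        first_action_latent_map[env_id][action] = network_output.latent_state[env_id]

print({env_id: list(actions) for env_id, actions in first_action_latent_map.items()})  # {0: [4], 1: []}
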
lzero/model/common.py

Lines changed: 42 additions & 22 deletions
@@ -364,12 +364,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class HFLanguageRepresentationNetwork(nn.Module):
     def __init__(self,
-                 model_path: str = 'google-bert/bert-base-uncased',
-                 embedding_size: int = 768,
-                 group_size: int = 8,
-                 norm_type: str = "simnorm",
-                 # norm_type: str = "layernorm", # TODO: Why does nan appear in the first step of training?
-                 tokenizer=None):
+                 model_path: str = 'google-bert/bert-base-uncased',
+                 embedding_size: int = 768,
+                 group_size: int = 8,
+                 final_norm_option_in_encoder: str = "layernorm",
+                 tokenizer=None):
         """
         Overview:
             This class defines a language representation network that utilizes a pretrained Hugging Face model.
@@ -379,7 +378,7 @@ def __init__(self,
             - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'.
             - embedding_size (int): The dimension of the output embeddings. Default is 768.
             - group_size (int): The group size for SimNorm when using normalization.
-            - norm_type (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
+            - final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
             - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model.
         """
         super().__init__()
@@ -389,12 +388,13 @@
 
         # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup.
         if get_rank() == 0:
-            self.model = AutoModel.from_pretrained(model_path)
+            self.pretrained_model = AutoModel.from_pretrained(model_path)
+
         if get_world_size() > 1:
             # Wait for rank 0 to finish loading the model.
             torch.distributed.barrier()
         if get_rank() != 0:
-            self.model = AutoModel.from_pretrained(model_path)
+            self.pretrained_model = AutoModel.from_pretrained(model_path)
 
         if tokenizer is None:
             # Only rank 0 downloads the tokenizer, and then other processes load it from cache.
@@ -409,15 +409,15 @@
 
         # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings).
         self.embedding_size = embedding_size
-        self.embed_proj_head = nn.Linear(self.model.config.hidden_size, self.embedding_size)
+        self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
 
-        # Select the normalization method based on the norm_type parameter.
-        if norm_type.lower() == "simnorm":
+        # # Select the normalization method based on the final_norm_option_in_encoder parameter.
+        if final_norm_option_in_encoder.lower() == "simnorm":
             self.norm = SimNorm(simnorm_dim=group_size)
-        elif norm_type.lower() == "layernorm":
+        elif final_norm_option_in_encoder.lower() == "layernorm":
             self.norm = nn.LayerNorm(embedding_size)
         else:
-            raise NotImplementedError(f"Normalization type '{norm_type}' is not implemented. "
+            raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
                                       f"Choose 'simnorm' or 'layernorm'.")
 
     def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
@@ -433,26 +433,27 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
         Returns:
             - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size].
         """
+
         # Construct the attention mask to exclude padding tokens.
         attention_mask = x != self.tokenizer.pad_token_id
 
         # Use no_grad context if specified to disable gradient computation.
         if no_grad:
             with torch.no_grad():
                 x = x.long()  # Ensure the input tensor is of type long.
-                outputs = self.model(x, attention_mask=attention_mask)
+                outputs = self.pretrained_model(x, attention_mask=attention_mask)
                 # Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
                 cls_embedding = outputs.last_hidden_state[:, 0, :]
         else:
             x = x.long()
-            outputs = self.model(x, attention_mask=attention_mask)
+            outputs = self.pretrained_model(x, attention_mask=attention_mask)
             cls_embedding = outputs.last_hidden_state[:, 0, :]
 
         # Apply linear projection to obtain the desired output dimension.
         cls_embedding = self.embed_proj_head(cls_embedding)
         # Normalize the embeddings using the selected normalization layer (SimNorm or LayerNorm) to ensure training stability.
         cls_embedding = self.norm(cls_embedding)
-
+
         return cls_embedding
 
 
@@ -468,6 +469,7 @@ def __init__(
             norm_type: str = 'BN',
             embedding_dim: int = 256,
             group_size: int = 8,
+            final_norm_option_in_encoder: str = 'LayerNorm',  # TODO
     ) -> None:
         """
         Overview:
@@ -486,6 +488,8 @@
             - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
             - embedding_dim (:obj:`int`): The dimension of the latent state.
             - group_size (:obj:`int`): The dimension for simplicial normalization.
+            - final_norm_option_in_encoder (:obj:`str`): The normalization option for the final layer, defaults to 'SimNorm'. \
+                Options are 'SimNorm' and 'LayerNorm'.
         """
         super().__init__()
         assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
@@ -530,7 +534,14 @@
         elif self.observation_shape[1] in [84, 96]:
             self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
 
-        self.sim_norm = SimNorm(simnorm_dim=group_size)
+        self.final_norm_option_in_encoder = final_norm_option_in_encoder
+        if self.final_norm_option_in_encoder == 'LayerNorm':
+            self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5)
+        elif self.final_norm_option_in_encoder == 'SimNorm':
+            self.final_norm = SimNorm(simnorm_dim=group_size)
+        else:
+            raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")
+
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -557,7 +568,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.view(-1, self.embedding_dim)
 
         # NOTE: very important for training stability.
-        x = self.sim_norm(x)
+        x = self.final_norm(x)
 
         return x
 
@@ -670,6 +681,7 @@ def __init__(
             activation: nn.Module = nn.GELU(approximate='tanh'),
             norm_type: Optional[str] = 'BN',
             group_size: int = 8,
+            final_norm_option_in_encoder: str = 'LayerNorm',  # TODO
     ) -> torch.Tensor:
         """
         Overview:
@@ -700,7 +712,15 @@
             # last_linear_layer_init_zero=True is beneficial for convergence speed.
             last_linear_layer_init_zero=True,
         )
-        self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+        # # Select the normalization method based on the final_norm_option_in_encoder parameter.
+        if final_norm_option_in_encoder.lower() == "simnorm":
+            self.norm = SimNorm(simnorm_dim=group_size)
+        elif final_norm_option_in_encoder.lower() == "layernorm":
+            self.norm = nn.LayerNorm(hidden_channels)
+        else:
+            raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
+                                      f"Choose 'simnorm' or 'layernorm'.")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -709,8 +729,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
         """
         x = self.fc_representation(x)
-        # TODO
-        x = self.sim_norm(x)
+        x = self.norm(x)
+
         return x
 
lzero/model/unizero_model.py

Lines changed: 23 additions & 2 deletions
@@ -4,12 +4,14 @@
 import torch.nn as nn
 from ding.utils import MODEL_REGISTRY, SequenceType
 from easydict import EasyDict
+from transformers import T5ForConditionalGeneration, T5Tokenizer
 
 from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
     VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
     HFLanguageRepresentationNetwork
 from .unizero_world_models.tokenizer import Tokenizer
 from .unizero_world_models.world_model import WorldModel
+from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size
 
 
 # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document.
@@ -64,6 +66,10 @@ def __init__(
             - analysis_sim_norm (:obj:`bool`): Whether to analyze the similarity of the norm.
         """
         super(UniZeroModel, self).__init__()
+        # Get current world size and rank for distributed setups.
+        self.world_size: int = get_world_size()
+        self.rank: int = get_rank()
+
         self.action_space_size = action_space_size
         self.activation = activation
         self.downsample = downsample
@@ -77,6 +83,7 @@
                 layer_num=2,
                 activation=self.activation,
                 group_size=world_model_cfg.group_size,
+                final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
             )
             # TODO: only for MemoryEnv now
             self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25)
@@ -89,8 +96,21 @@
             print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
             print('==' * 20)
         elif world_model_cfg.obs_type == 'text':
-            self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim)
-            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,)
+            self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
+            # print(self.representation_network.model.encoder.layer[0].attention.output.LayerNorm.weight)
+
+            if self.rank == 0:
+                self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
+                self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
+            if self.world_size > 1:
+                # Wait until rank 0 finishes loading the tokenizer
+                torch.distributed.barrier()
+            if self.rank != 0:
+                self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
+                self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+            projection = [self.representation_network.pretrained_model.config.hidden_size, self.decoder_network.config.d_model]
+            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, with_lpips=False, projection=projection)
             self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
             print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
             print('==' * 20)
@@ -107,6 +127,7 @@
                 norm_type=norm_type,
                 embedding_dim=world_model_cfg.embed_dim,
                 group_size=world_model_cfg.group_size,
+                final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
             )
 
         # ====== for analysis ======

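A condensed sketch of the new `obs_type == 'text'` branch with the rank-0-first distributed loading stripped out. The `Tokenizer` keyword arguments follow the diff; `embedding_size=768` and the BERT model path are assumptions standing in for `world_model_cfg.embed_dim` and `kwargs['encoder_url']`.

from transformers import T5ForConditionalGeneration, T5Tokenizer
from lzero.model.common import HFLanguageRepresentationNetwork
from lzero.model.unizero_world_models.tokenizer import Tokenizer

# Encoder: BERT-based language representation network with the configurable final norm.
representation_network = HFLanguageRepresentationNetwork(
    model_path='google-bert/bert-base-uncased',
    embedding_size=768,
    final_norm_option_in_encoder='layernorm',
)

# Decoder: T5-small and its tokenizer reconstruct text from latent states for the
# decode/reconstruction loss introduced by this commit.
decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Projection pair: BERT's hidden size (768) down to T5-small's d_model (512).
projection = [representation_network.pretrained_model.config.hidden_size,
              decoder_network.config.d_model]

tokenizer = Tokenizer(
    encoder=representation_network,
    decoder_network=decoder_network,
    decoder_network_tokenizer=decoder_network_tokenizer,
    with_lpips=False,
    projection=projection,
)

The world model is then built on top of this tokenizer exactly as in the diff: WorldModel(config=world_model_cfg, tokenizer=tokenizer).
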