
Commit 76611cf

polish(xjy): standardize decode text related code for jericho (#366)
1 parent 2204abc commit 76611cf

12 files changed: +88 -333 lines changed


lzero/entry/eval_muzero.py

Lines changed: 0 additions & 2 deletions
@@ -60,10 +60,8 @@ def eval_muzero(
 
     # load pretrained model
     if model_path is not None:
-        # print(policy._learn_model.representation_network.pretrained_model.encoder.layer[0].attention.output.LayerNorm.weight)
         logging.info(f"Loading pretrained model from {model_path}...")
         policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
-        # policy.eval_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
         logging.info("Pretrained model loaded successfully!")
     else:
         logging.warning("model_path is None!!!")

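For reference, a minimal sketch of the checkpoint-loading pattern this hunk cleans up, using a toy module and a throwaway checkpoint path in place of LightZero's policy and config (both hypothetical, not part of this commit):

import logging
from typing import Optional

import torch
import torch.nn as nn

def load_pretrained(model: nn.Module, model_path: Optional[str], device: str) -> None:
    if model_path is not None:
        logging.info(f"Loading pretrained model from {model_path}...")
        # map_location lets a checkpoint trained on GPU be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        logging.info("Pretrained model loaded successfully!")
    else:
        logging.warning("model_path is None!!!")

net = nn.Linear(4, 2)
torch.save(net.state_dict(), "/tmp/toy_ckpt.pth")
load_pretrained(net, "/tmp/toy_ckpt.pth", device="cpu")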
lzero/mcts/buffer/game_buffer.py

Lines changed: 6 additions & 3 deletions
@@ -151,13 +151,16 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
             # Indices exceeding `game_segment_length` are padded with the next segment and are not updated
             # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within
             # [0, game_segment_length - num_unroll_steps] to avoid padded data.
-
+
             if self._cfg.action_type == 'varied_action_space':
-                # TODO: Consider increasing `self._cfg.game_segment_length` to ensure sampling efficiency.
+                # For multi-environment training (e.g., Jericho), each environment may have a different discrete action space size.
+                # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
+                # we avoid sampling from the last `num_unroll_steps` steps of the game segment.
                 if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
                     pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
             else:
-                # NOTE: Sample the init position from the whole segment, but not from the padded part
+                # For environments with a fixed action space (e.g., Atari),
+                # we can safely sample from the entire game segment range.
                 if pos_in_game_segment >= self._cfg.game_segment_length:
                     pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
 

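Read on its own, the new comments describe a simple clipping rule for the sampled start position. Below is a standalone sketch of that rule with made-up segment and unroll lengths; the helper name is illustrative, not part of the buffer API:

import numpy as np

def clip_start_position(pos: int, game_segment_length: int, num_unroll_steps: int,
                        varied_action_space: bool) -> int:
    # Re-sample the start position so that `num_unroll_steps` transitions can be
    # unrolled from it without running into padded data at the end of the segment.
    if varied_action_space:
        # e.g., Jericho: action space size differs per environment, so stay clear
        # of the last `num_unroll_steps` positions.
        if pos >= game_segment_length - num_unroll_steps:
            pos = np.random.choice(game_segment_length - num_unroll_steps, 1).item()
    else:
        # e.g., Atari: fixed action space, any in-segment position is acceptable.
        if pos >= game_segment_length:
            pos = np.random.choice(game_segment_length, 1).item()
    return pos

# With a 400-step segment and 10 unroll steps, a start position of 395 is re-drawn from [0, 390).
print(clip_start_position(395, game_segment_length=400, num_unroll_steps=10, varied_action_space=True))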
lzero/mcts/tree_search/mcts_ctree.py

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,8 @@ def search(
 
        # preparation some constant
        batch_size = roots.num
+
+        # Store the latent state of each possible action at the MCTS root for each environment.
        first_action_latent_map = {env_id: {} for env_id in range(batch_size)}  # {env_id: {action: latent_state}}
 
        pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor

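The new comment documents the per-environment map from root actions to latent states. A small sketch of how such a map could be filled on the first expansion below the root follows; the `record_root_child` helper and the toy latent values are illustrative only, not the tree-search API:

from typing import Any, Dict

batch_size = 2  # number of parallel environments at the MCTS roots
# {env_id: {action: latent_state}}, mirroring the structure added in this hunk.
first_action_latent_map: Dict[int, Dict[int, Any]] = {env_id: {} for env_id in range(batch_size)}

def record_root_child(env_id: int, depth: int, action: int, latent_state: Any) -> None:
    # Cache the latent state only for actions taken directly below the root (depth 0),
    # keeping the first latent state seen for each such action.
    if depth == 0 and action not in first_action_latent_map[env_id]:
        first_action_latent_map[env_id][action] = latent_state

record_root_child(env_id=0, depth=0, action=3, latent_state=[0.1, 0.2])
record_root_child(env_id=0, depth=1, action=5, latent_state=[0.3, 0.4])  # ignored: not at the root
print(first_action_latent_map)  # {0: {3: [0.1, 0.2]}, 1: {}}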
lzero/model/common.py

Lines changed: 0 additions & 1 deletion
@@ -729,7 +729,6 @@ def __init__(
            # last_linear_layer_init_zero=True is beneficial for convergence speed.
            last_linear_layer_init_zero=True,
        )
-        # self.sim_norm = SimNorm(simnorm_dim=group_size)
 
        # # Select the normalization method based on the final_norm_option_in_encoder parameter.
        if final_norm_option_in_encoder.lower() == "simnorm":

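The surrounding context selects the encoder's final normalization from the `final_norm_option_in_encoder` string. Below is a minimal sketch of that dispatch, assuming a SimNorm implementation like the one referenced in the removed comment and a LayerNorm fallback; the exact options in `common.py` may differ:

import torch
import torch.nn as nn

class SimNorm(nn.Module):
    # Simplicial normalization: softmax over groups of size `simnorm_dim` (sketch only).
    def __init__(self, simnorm_dim: int):
        super().__init__()
        self.dim = simnorm_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.view(*shape[:-1], -1, self.dim)
        return x.softmax(dim=-1).view(*shape)

def build_final_norm(option: str, embedding_dim: int, group_size: int) -> nn.Module:
    option = option.lower()
    if option == "simnorm":
        return SimNorm(simnorm_dim=group_size)
    elif option == "layernorm":
        return nn.LayerNorm(embedding_dim)
    raise ValueError(f"Unsupported final_norm_option_in_encoder: {option}")

norm = build_final_norm("simnorm", embedding_dim=768, group_size=8)
print(norm(torch.randn(2, 768)).shape)  # torch.Size([2, 768])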
lzero/model/unizero_world_models/tokenizer.py

Lines changed: 4 additions & 6 deletions
@@ -188,25 +188,23 @@ def decode_to_plain_text_for_decoder(
            List[List[int]]: List of decoded strings, one per input in batch.
        """
 
-        # 设置 decoder_network projection_layer 为评估模式,关闭 dropout 等训练行为
+        # Set decoder_network and projection_layer to evaluation mode to disable dropout and other training-specific behaviors.
        self.decoder_network.eval()
        self.projection_layer.eval()
 
-        # 如果 embeddings 不是 Tensor,则转换为 torch.Tensor
+        # If embeddings is not a Tensor, convert it to a torch.Tensor.
        if not isinstance(embeddings, torch.Tensor):
            embeddings = torch.tensor(embeddings, dtype=torch.float32)
 
-        # 尝试从 decoder_network 获取设备信息,如果没有则从模型参数中获取
+        # Attempt to retrieve the device information from decoder_network; if unavailable, fall back to the model's parameters.
        try:
            device = self.decoder_network.device
        except AttributeError:
            device = next(self.decoder_network.parameters()).device
 
-        # 将 embeddings 移动到正确的设备上
        embeddings = embeddings.to(device)
 
-        with torch.no_grad():  # 在推理过程中关闭梯度计算,节约显存和计算
-
+        with torch.no_grad():
            if embeddings.dim() == 2:
                embeddings = embeddings.unsqueeze(1)

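The same defensive pattern (eval mode, tensor coercion, device fallback, no-grad inference) can be shown end to end with a toy decoder standing in for the project's `decoder_network`; everything below is a self-contained sketch, not the actual tokenizer API:

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    def __init__(self, embed_dim: int = 8, vocab_size: int = 16):
        super().__init__()
        self.proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

def decode_embeddings(decoder: nn.Module, embeddings) -> torch.Tensor:
    decoder.eval()  # disable dropout and other training-only behavior
    if not isinstance(embeddings, torch.Tensor):
        embeddings = torch.tensor(embeddings, dtype=torch.float32)
    # Prefer an explicit `.device` attribute; otherwise fall back to the parameters' device.
    device = getattr(decoder, "device", next(decoder.parameters()).device)
    embeddings = embeddings.to(device)
    with torch.no_grad():
        if embeddings.dim() == 2:  # (batch, embed_dim) -> (batch, 1, embed_dim)
            embeddings = embeddings.unsqueeze(1)
        return decoder(embeddings).argmax(dim=-1)  # greedy token ids per position

token_ids = decode_embeddings(ToyDecoder(), [[0.0] * 8, [1.0] * 8])
print(token_ids.shape)  # torch.Size([2, 1])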
lzero/model/unizero_world_models/tokenizer_bkp20250428.py

Lines changed: 0 additions & 244 deletions
This file was deleted.

lzero/model/unizero_world_models/world_model.py

Lines changed: 19 additions & 4 deletions
@@ -97,17 +97,19 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
 
        # print(self.tokenizer.encoder.pretrained_model.encoder.layer[0].attention.output.LayerNorm.weight)
 
-        # 首先,构建需要跳过初始化的模块集合
+        # First, build the set of modules to skip during re-initialization
        skip_modules = set(self.tokenizer.encoder.pretrained_model.modules())
        skip_modules.update(self.tokenizer.decoder_network.modules())
 
        def custom_init(module):
-            # 如果当前 module 属于跳过初始化的模块,则直接返回
+            # If the current module is part of the skip list, return without reinitializing
            if module in skip_modules:
                return
-            # 否则使用指定的初始化方法
+            # Otherwise, apply the specified initialization method
            init_weights(module, norm_type=self.config.norm_type)
-        # 递归地对模型中所有子模块应用 custom_init 函数
+
+        # Recursively apply `custom_init` to all submodules of the model
+        # NOTE: This step is crucial — without skipping, pretrained modules (e.g., encoder/decoder) would be unintentionally re-initialized
        self.apply(custom_init)
 
        # Apply weight initialization, the order is important

@@ -1414,6 +1416,19 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
        else:
            dormant_ratio_world_model = torch.tensor(0.)
 
+        # ========== for visualization ==========
+        # Uncomment the lines below for visualization
+        # predict_policy = outputs.logits_policy
+        # predict_policy = F.softmax(outputs.logits_policy, dim=-1)
+        # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1)
+        # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1)
+        # import pdb; pdb.set_trace()
+        # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613')
+
+        # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode')
+        # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode')
+        # ========== for visualization ==========
+
        # For training stability, use target_tokenizer to compute the true next latent state representations
        with torch.no_grad():
            target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'])

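As a side note, the skip-set trick in `__init__` generalizes to any model that wraps pretrained submodules: collect the pretrained modules first, then re-initialize everything else via `nn.Module.apply`. A minimal sketch with a toy initializer standing in for `init_weights` (helper and class names here are illustrative):

import torch.nn as nn

class WrapperModel(nn.Module):
    def __init__(self, pretrained: nn.Module):
        super().__init__()
        self.pretrained = pretrained   # must keep its pretrained weights
        self.head = nn.Linear(16, 4)   # newly added part that should be re-initialized

    def reinit_except_pretrained(self) -> None:
        # Build the set of modules to skip: every submodule of the pretrained part.
        skip_modules = set(self.pretrained.modules())

        def custom_init(module: nn.Module) -> None:
            if module in skip_modules:
                return  # leave pretrained weights untouched
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

        # nn.Module.apply visits this module and all submodules recursively.
        self.apply(custom_init)

model = WrapperModel(pretrained=nn.Linear(16, 16))
model.reinit_except_pretrained()  # only `head` is re-initialized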