Skip to content

Commit 6649e85

Browse files
author
puyuan
committed
polish(pu): polish chess config
1 parent 44c0287 commit 6649e85

File tree

3 files changed

+51
-52
lines changed

3 files changed

+51
-52
lines changed

zoo/board_games/chess/config/chess_alphazero_bot_mode_config.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,24 @@
33
# ==============================================================
44
# begin of the most frequently changed config specified by the user
55
# ==============================================================
6-
# collector_env_num = 8
7-
# n_episode = 8
8-
# evaluator_env_num = 5
9-
# num_simulations = 400
10-
# update_per_collect = 200
11-
# batch_size = 512
12-
# max_env_step = int(1e6)
13-
# mcts_ctree = False
6+
collector_env_num = 8
7+
n_episode = 8
8+
evaluator_env_num = 5
9+
num_simulations = 400
10+
update_per_collect = 200
11+
batch_size = 512
12+
max_env_step = int(1e6)
13+
mcts_ctree = False
1414

1515
# TODO: for debug
16-
collector_env_num = 2
17-
n_episode = 2
18-
evaluator_env_num = 2
19-
num_simulations = 4
20-
update_per_collect = 2
21-
batch_size = 2
22-
max_env_step = int(1e4)
23-
mcts_ctree = False
16+
# collector_env_num = 2
17+
# n_episode = 2
18+
# evaluator_env_num = 2
19+
# num_simulations = 4
20+
# update_per_collect = 2
21+
# batch_size = 2
22+
# max_env_step = int(1e4)
23+
# mcts_ctree = False
2424
# ==============================================================
2525
# end of the most frequently changed config specified by the user
2626
# ==============================================================
@@ -56,7 +56,7 @@
5656
model=dict(
5757
observation_shape=(8, 8, 20),
5858
action_space_size=int(8 * 8 * 73),
59-
# TODO: for debug
59+
# TODO: only for debug
6060
num_res_blocks=1,
6161
num_channels=1,
6262
value_head_hidden_channels=[16],

zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,22 @@
1010
update_per_collect = 200
1111
batch_size = 512
1212
max_env_step = int(1e6)
13-
mcts_ctree = True
14-
# mcts_ctree = False
15-
13+
mcts_ctree = False
1614

1715
# TODO: for debug
18-
collector_env_num = 2
19-
n_episode = 2
20-
evaluator_env_num = 2
21-
num_simulations = 4
22-
update_per_collect = 2
23-
batch_size = 2
24-
max_env_step = int(1e4)
16+
# collector_env_num = 2
17+
# n_episode = 2
18+
# evaluator_env_num = 2
19+
# num_simulations = 2
20+
# update_per_collect = 1
21+
# batch_size = 2
22+
# max_env_step = int(1e4)
2523
# mcts_ctree = False
2624
# ==============================================================
2725
# end of the most frequently changed config specified by the user
2826
# ==============================================================
2927
chess_alphazero_config = dict(
30-
exp_name='data_az_ctree/chess_sp-mode_alphazero_seed0',
28+
exp_name='data_az_ptree/chess_sp-mode_alphazero_seed0',
3129
env=dict(
3230
board_size=8,
3331
battle_mode='self_play_mode',
@@ -58,14 +56,14 @@
5856
observation_shape=(8, 8, 20),
5957
action_space_size=int(8 * 8 * 73),
6058
# TODO: for debug
61-
num_res_blocks=1,
62-
num_channels=1,
63-
value_head_hidden_channels=[16],
64-
policy_head_hidden_channels=[16],
65-
# num_res_blocks=8,
66-
# num_channels=256,
67-
# value_head_hidden_channels=[256, 256],
68-
# policy_head_hidden_channels=[256, 256],
59+
# num_res_blocks=1,
60+
# num_channels=1,
61+
# value_head_hidden_channels=[16],
62+
# policy_head_hidden_channels=[16],
63+
num_res_blocks=8,
64+
num_channels=256,
65+
value_head_hidden_channels=[256, 256],
66+
policy_head_hidden_channels=[256, 256],
6967
),
7068
cuda=True,
7169
board_size=8,

zoo/board_games/chess/envs/chess_lightzero_env.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
from ding.envs.env.base_env import BaseEnvTimestep
1111
from ding.utils.registry_factory import ENV_REGISTRY
1212
from gymnasium import spaces
13-
from pettingzoo.classic.chess import chess_utils
14-
1513
from zoo.board_games.chess.envs.chess_env import ChessEnv
14+
from pettingzoo.classic.chess import chess_utils as pz_cu
1615

1716

1817
@ENV_REGISTRY.register('chess_lightzero')
@@ -50,16 +49,15 @@ def __init__(self, cfg=None):
5049

5150
@property
5251
def legal_actions(self):
53-
return chess_utils.legal_moves(self.board)
52+
return pz_cu.legal_moves(self.board)
5453

5554
def observe(self, agent_index):
5655
try:
57-
observation = chess_utils.get_observation(self.board, agent_index).astype(float) # TODO
56+
observation = pz_cu.get_observation(self.board, agent_index).astype(float) # TODO
5857
except Exception as e:
59-
print('debug')
58+
print(f'debug: {e}')
6059
print(f"self.board:{self.board}")
6160

62-
6361
# TODO:
6462
# observation = np.dstack((observation[:, :, :7], self.board_history))
6563
# We need to swap the white 6 channels with black 6 channels
@@ -75,9 +73,12 @@ def observe(self, agent_index):
7573
# observation[..., 13 * i : 13 * i + 6] = tmp
7674

7775
action_mask = np.zeros(4672, dtype=np.int8)
78-
action_mask[chess_utils.legal_moves(self.board)] = 1
76+
action_mask[pz_cu.legal_moves(self.board)] = 1
7977
return {'observation': observation, 'action_mask': action_mask}
8078

79+
80+
81+
8182
def current_state(self):
8283
"""
8384
Overview:
@@ -103,7 +104,7 @@ def get_done_winner(self):
103104
if result == "*":
104105
winner = -1
105106
else:
106-
winner = chess_utils.result_to_int(result)
107+
winner = pz_cu.result_to_int(result)
107108

108109
if not done:
109110
winner = -1
@@ -143,7 +144,7 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False,
143144
self.board = chess.Board()
144145

145146
action_mask = np.zeros(4672, dtype=np.int8)
146-
action_mask[chess_utils.legal_moves(self.board)] = 1
147+
action_mask[pz_cu.legal_moves(self.board)] = 1
147148
# self.board_history = np.zeros((8, 8, 104), dtype=bool)
148149

149150
if self.battle_mode == 'play_with_bot_mode' or self.battle_mode == 'eval_mode':
@@ -265,10 +266,10 @@ def _player_step(self, action):
265266
current_agent = self.current_player_index
266267

267268
# TODO: Update board history
268-
# next_board = chess_utils.get_observation(self.board, current_agent)
269+
# next_board = pz_cu.get_observation(self.board, current_agent)
269270
# self.board_history = np.dstack((next_board[:, :, 7:], self.board_history[:, :, :-13]))
270271

271-
chosen_move = chess_utils.action_to_move(self.board, action, current_agent)
272+
chosen_move = pz_cu.action_to_move(self.board, action, current_agent)
272273
assert chosen_move in self.board.legal_moves
273274
self.board.push(chosen_move)
274275

@@ -277,7 +278,7 @@ def _player_step(self, action):
277278
if result == "*":
278279
reward = 0.
279280
else:
280-
reward = chess_utils.result_to_int(result)
281+
reward = pz_cu.result_to_int(result)
281282

282283
if self.current_player == 1:
283284
reward = -reward
@@ -287,7 +288,7 @@ def _player_step(self, action):
287288
info['eval_episode_return'] = reward
288289

289290
action_mask = np.zeros(4672, dtype=np.int8)
290-
action_mask[chess_utils.legal_moves(self.board)] = 1
291+
action_mask[pz_cu.legal_moves(self.board)] = 1
291292

292293
obs = {
293294
'observation': self.observe(self.current_player_index)['observation'],
@@ -318,14 +319,14 @@ def current_player(self, value):
318319
self._current_player = value
319320

320321
def random_action(self):
321-
action_list = chess_utils.legal_moves(self.board)
322+
action_list = pz_cu.legal_moves(self.board)
322323
return np.random.choice(action_list)
323324

324325
def simulate_action(self, action):
325-
if action not in chess_utils.legal_moves(self.board):
326+
if action not in pz_cu.legal_moves(self.board):
326327
raise ValueError("action {0} on board {1} is not legal".format(action, self.board.fen()))
327328
new_board = copy.deepcopy(self.board)
328-
new_board.push(chess_utils.action_to_move(self.board, action, self.current_player_index))
329+
new_board.push(pz_cu.action_to_move(self.board, action, self.current_player_index))
329330
if self.start_player_index == 0:
330331
start_player_index = 1
331332
else:

0 commit comments

Comments
 (0)