10
10
from ding .envs .env .base_env import BaseEnvTimestep
11
11
from ding .utils .registry_factory import ENV_REGISTRY
12
12
from gymnasium import spaces
13
- from pettingzoo .classic .chess import chess_utils
14
-
15
13
from zoo .board_games .chess .envs .chess_env import ChessEnv
14
+ from pettingzoo .classic .chess import chess_utils as pz_cu
16
15
17
16
18
17
@ENV_REGISTRY .register ('chess_lightzero' )
@@ -50,16 +49,15 @@ def __init__(self, cfg=None):
50
49
51
50
@property
52
51
def legal_actions (self ):
53
- return chess_utils .legal_moves (self .board )
52
+ return pz_cu .legal_moves (self .board )
54
53
55
54
def observe (self , agent_index ):
56
55
try :
57
- observation = chess_utils .get_observation (self .board , agent_index ).astype (float ) # TODO
56
+ observation = pz_cu .get_observation (self .board , agent_index ).astype (float ) # TODO
58
57
except Exception as e :
59
- print ('debug' )
58
+ print (f 'debug: { e } ' )
60
59
print (f"self.board:{ self .board } " )
61
60
62
-
63
61
# TODO:
64
62
# observation = np.dstack((observation[:, :, :7], self.board_history))
65
63
# We need to swap the white 6 channels with black 6 channels
@@ -75,9 +73,12 @@ def observe(self, agent_index):
75
73
# observation[..., 13 * i : 13 * i + 6] = tmp
76
74
77
75
action_mask = np .zeros (4672 , dtype = np .int8 )
78
- action_mask [chess_utils .legal_moves (self .board )] = 1
76
+ action_mask [pz_cu .legal_moves (self .board )] = 1
79
77
return {'observation' : observation , 'action_mask' : action_mask }
80
78
79
+
80
+
81
+
81
82
def current_state (self ):
82
83
"""
83
84
Overview:
@@ -103,7 +104,7 @@ def get_done_winner(self):
103
104
if result == "*" :
104
105
winner = - 1
105
106
else :
106
- winner = chess_utils .result_to_int (result )
107
+ winner = pz_cu .result_to_int (result )
107
108
108
109
if not done :
109
110
winner = - 1
@@ -143,7 +144,7 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False,
143
144
self .board = chess .Board ()
144
145
145
146
action_mask = np .zeros (4672 , dtype = np .int8 )
146
- action_mask [chess_utils .legal_moves (self .board )] = 1
147
+ action_mask [pz_cu .legal_moves (self .board )] = 1
147
148
# self.board_history = np.zeros((8, 8, 104), dtype=bool)
148
149
149
150
if self .battle_mode == 'play_with_bot_mode' or self .battle_mode == 'eval_mode' :
@@ -265,10 +266,10 @@ def _player_step(self, action):
265
266
current_agent = self .current_player_index
266
267
267
268
# TODO: Update board history
268
- # next_board = chess_utils .get_observation(self.board, current_agent)
269
+ # next_board = pz_cu .get_observation(self.board, current_agent)
269
270
# self.board_history = np.dstack((next_board[:, :, 7:], self.board_history[:, :, :-13]))
270
271
271
- chosen_move = chess_utils .action_to_move (self .board , action , current_agent )
272
+ chosen_move = pz_cu .action_to_move (self .board , action , current_agent )
272
273
assert chosen_move in self .board .legal_moves
273
274
self .board .push (chosen_move )
274
275
@@ -277,7 +278,7 @@ def _player_step(self, action):
277
278
if result == "*" :
278
279
reward = 0.
279
280
else :
280
- reward = chess_utils .result_to_int (result )
281
+ reward = pz_cu .result_to_int (result )
281
282
282
283
if self .current_player == 1 :
283
284
reward = - reward
@@ -287,7 +288,7 @@ def _player_step(self, action):
287
288
info ['eval_episode_return' ] = reward
288
289
289
290
action_mask = np .zeros (4672 , dtype = np .int8 )
290
- action_mask [chess_utils .legal_moves (self .board )] = 1
291
+ action_mask [pz_cu .legal_moves (self .board )] = 1
291
292
292
293
obs = {
293
294
'observation' : self .observe (self .current_player_index )['observation' ],
@@ -318,14 +319,14 @@ def current_player(self, value):
318
319
self ._current_player = value
319
320
320
321
def random_action (self ):
321
- action_list = chess_utils .legal_moves (self .board )
322
+ action_list = pz_cu .legal_moves (self .board )
322
323
return np .random .choice (action_list )
323
324
324
325
def simulate_action (self , action ):
325
- if action not in chess_utils .legal_moves (self .board ):
326
+ if action not in pz_cu .legal_moves (self .board ):
326
327
raise ValueError ("action {0} on board {1} is not legal" .format (action , self .board .fen ()))
327
328
new_board = copy .deepcopy (self .board )
328
- new_board .push (chess_utils .action_to_move (self .board , action , self .current_player_index ))
329
+ new_board .push (pz_cu .action_to_move (self .board , action , self .current_player_index ))
329
330
if self .start_player_index == 0 :
330
331
start_player_index = 1
331
332
else :
0 commit comments