diff --git a/env.py b/env.py new file mode 100644 index 0000000..0b279ef --- /dev/null +++ b/env.py @@ -0,0 +1,86 @@ +import re +from typing import Any, Dict, Optional, Tuple + +import textarena as ta + +class ConnectFourEnv(ta.Env): + def __init__(self, is_open: bool=True, num_rows: int=6, num_cols: int=7): + """ + Args: + is_open (bool): If True, the game state is visible to the players. + num_rows (int): Number of rows in the game board. + num_cols (int): Number of columns in the game board. + """ + self.is_open = is_open + self.num_rows = num_rows + self.num_cols = num_cols + + def reset(self, num_players: int, seed: Optional[int] = None): + self.state = ta.TwoPlayerState(num_players=num_players, seed=seed) + game_state = {"board": [["." for _ in range(self.num_cols)] for _ in range(self.num_rows)]} + self.state.reset(game_state=game_state, player_prompt_function=self._generate_player_prompt) + self.state.add_observation(message=(f"Board state:\n{self._render_board()}" if self.is_open else "The game board is not visible to players."), observation_type=ta.ObservationType.GAME_BOARD) + + def _generate_player_prompt(self, player_id: int, game_state: Dict[int, Any]) -> str: + return ( + f"You are Player {player_id} in Connect Four.\nYour disc symbol: {'X' if player_id == 0 else 'O'}.\n" + f"The game board has {self.num_rows} rows and {self.num_cols} columns.\n" + f"Players take turns dropping their disc into one of the columns (0 to {self.num_cols - 1}).\n" + "The first to connect (their own) four discs vertically, horizontally, or diagonally wins.\n" + "On your turn, enter the column number in squared brackets to make your move.\nFor example: '[col 4]' or '[col 1]'." + ) + + def _render_board(self) -> str: + column_numbers = " ".join([str(c) for c in range(self.num_cols)]) + separator = "-" * (self.num_cols * 2 - 1) + board_rows = "\n".join([" ".join(row) for row in self.state.game_state["board"]]) + return f"{column_numbers}\n{separator}\n{board_rows}" + + def step(self, action: str) -> Tuple[bool, ta.Info]: + self.state.add_observation(from_id=self.state.current_player_id, message=action, observation_type=ta.ObservationType.PLAYER_ACTION) + is_valid, col, reason = self._validate_action(action=action) # check if the actions is valid + if not is_valid: self.state.set_invalid_move(reason=reason) + else: + row = self._get_available_row(col) # place the disc + player_symbol = "X" if self.state.current_player_id == 0 else "O" + self.state.add_observation(message=f"Player {self.state.current_player_id} dropped their disk ({player_symbol}) into column {col}.", observation_type=ta.ObservationType.GAME_ACTION_DESCRIPTION) + self.state.game_state["board"][row][col] = player_symbol # insert disc + if self._check_win(row, col): self.state.set_winner(player_id=self.state.current_player_id, reason=f"Player {self.state.current_player_id} wins by connecting four!") + elif self._check_draw(): self.state.set_draw(reason="Game ended in a draw.") + else: # update board state + if self.is_open: self.state.add_observation(message=f"Board state:\n{self._render_board()}", observation_type=ta.ObservationType.GAME_BOARD) + return self.state.step() + + def _validate_action(self, action: str) -> Tuple[bool, Optional[int], Optional[str]]: + match = re.compile(r'.*\[(?:col\s*)?(\d+)\].*', re.IGNORECASE).search(action) + if not match: return False, None, f"Player {self.state.current_player_id}, Invalid action format. Expected format: '[col x]'." + col = int(match.group(1)) + if not (0 <= col < self.num_cols): return False, None, f"Player {self.state.current_player_id}, Invalid action. Column {col} is out of bounds." + if self.state.game_state["board"][0][col] != ".": return False, None, f"Player {self.state.current_player_id}, Invalid action. Column {col} is full." + return True, col, None + + def _get_available_row(self, col: int) -> int: + for r in range(self.num_rows - 1, -1, -1): + if self.state.game_state["board"][r][col] == ".": + return r + raise Exception("The column should be validated before calling the _get_available_row function.") + + def _check_win(self, row: int, col:int) -> bool: + for direction in [((0, 1), (0, -1)), ((1, 0), (-1, 0)), ((1, 1), (-1, -1)), ((1, -1), (-1, 1)),]: + total = 1 # Count the disc just placed + for delta_row, delta_col in direction: + total += self._check_direction(self.state.game_state["board"], row, col, delta_row, delta_col, self.state.game_state["board"][row][col]) + if total >= 4: return True + return False + + def _check_direction(self, board, row, col, delta_row, delta_col, disc) -> int: + count = 0 + r, c = row + delta_row, col + delta_col + while 0 <= r < self.num_rows and 0 <= c < self.num_cols and board[r][c] == disc: + count += 1 + r += delta_row + c += delta_col + return count + + def _check_draw(self) -> bool: + return all(self.state.game_state["board"][0][c] != "." for c in range(self.num_cols))