Initial commit from Openverse builder

This commit is contained in:
2001-01-01 00:00:00 +00:00
parent 7b0c307a91
commit 5a43c21d4b

86
env.py Normal file
View File

@@ -0,0 +1,86 @@
import re
from typing import Any, Dict, Optional, Tuple
import textarena as ta
class ConnectFourEnv(ta.Env):
def __init__(self, is_open: bool=True, num_rows: int=6, num_cols: int=7):
"""
Args:
is_open (bool): If True, the game state is visible to the players.
num_rows (int): Number of rows in the game board.
num_cols (int): Number of columns in the game board.
"""
self.is_open = is_open
self.num_rows = num_rows
self.num_cols = num_cols
def reset(self, num_players: int, seed: Optional[int] = None):
self.state = ta.TwoPlayerState(num_players=num_players, seed=seed)
game_state = {"board": [["." for _ in range(self.num_cols)] for _ in range(self.num_rows)]}
self.state.reset(game_state=game_state, player_prompt_function=self._generate_player_prompt)
self.state.add_observation(message=(f"Board state:\n{self._render_board()}" if self.is_open else "The game board is not visible to players."), observation_type=ta.ObservationType.GAME_BOARD)
def _generate_player_prompt(self, player_id: int, game_state: Dict[int, Any]) -> str:
return (
f"You are Player {player_id} in Connect Four.\nYour disc symbol: {'X' if player_id == 0 else 'O'}.\n"
f"The game board has {self.num_rows} rows and {self.num_cols} columns.\n"
f"Players take turns dropping their disc into one of the columns (0 to {self.num_cols - 1}).\n"
"The first to connect (their own) four discs vertically, horizontally, or diagonally wins.\n"
"On your turn, enter the column number in squared brackets to make your move.\nFor example: '[col 4]' or '[col 1]'."
)
def _render_board(self) -> str:
column_numbers = " ".join([str(c) for c in range(self.num_cols)])
separator = "-" * (self.num_cols * 2 - 1)
board_rows = "\n".join([" ".join(row) for row in self.state.game_state["board"]])
return f"{column_numbers}\n{separator}\n{board_rows}"
def step(self, action: str) -> Tuple[bool, ta.Info]:
self.state.add_observation(from_id=self.state.current_player_id, message=action, observation_type=ta.ObservationType.PLAYER_ACTION)
is_valid, col, reason = self._validate_action(action=action) # check if the actions is valid
if not is_valid: self.state.set_invalid_move(reason=reason)
else:
row = self._get_available_row(col) # place the disc
player_symbol = "X" if self.state.current_player_id == 0 else "O"
self.state.add_observation(message=f"Player {self.state.current_player_id} dropped their disk ({player_symbol}) into column {col}.", observation_type=ta.ObservationType.GAME_ACTION_DESCRIPTION)
self.state.game_state["board"][row][col] = player_symbol # insert disc
if self._check_win(row, col): self.state.set_winner(player_id=self.state.current_player_id, reason=f"Player {self.state.current_player_id} wins by connecting four!")
elif self._check_draw(): self.state.set_draw(reason="Game ended in a draw.")
else: # update board state
if self.is_open: self.state.add_observation(message=f"Board state:\n{self._render_board()}", observation_type=ta.ObservationType.GAME_BOARD)
return self.state.step()
def _validate_action(self, action: str) -> Tuple[bool, Optional[int], Optional[str]]:
match = re.compile(r'.*\[(?:col\s*)?(\d+)\].*', re.IGNORECASE).search(action)
if not match: return False, None, f"Player {self.state.current_player_id}, Invalid action format. Expected format: '[col x]'."
col = int(match.group(1))
if not (0 <= col < self.num_cols): return False, None, f"Player {self.state.current_player_id}, Invalid action. Column {col} is out of bounds."
if self.state.game_state["board"][0][col] != ".": return False, None, f"Player {self.state.current_player_id}, Invalid action. Column {col} is full."
return True, col, None
def _get_available_row(self, col: int) -> int:
for r in range(self.num_rows - 1, -1, -1):
if self.state.game_state["board"][r][col] == ".":
return r
raise Exception("The column should be validated before calling the _get_available_row function.")
def _check_win(self, row: int, col:int) -> bool:
for direction in [((0, 1), (0, -1)), ((1, 0), (-1, 0)), ((1, 1), (-1, -1)), ((1, -1), (-1, 1)),]:
total = 1 # Count the disc just placed
for delta_row, delta_col in direction:
total += self._check_direction(self.state.game_state["board"], row, col, delta_row, delta_col, self.state.game_state["board"][row][col])
if total >= 4: return True
return False
def _check_direction(self, board, row, col, delta_row, delta_col, disc) -> int:
count = 0
r, c = row + delta_row, col + delta_col
while 0 <= r < self.num_rows and 0 <= c < self.num_cols and board[r][c] == disc:
count += 1
r += delta_row
c += delta_col
return count
def _check_draw(self) -> bool:
return all(self.state.game_state["board"][0][c] != "." for c in range(self.num_cols))