diff --git a/env.py b/env.py new file mode 100644 index 0000000..6242a5c --- /dev/null +++ b/env.py @@ -0,0 +1,53 @@ +import re +from typing import Any, Dict, Optional, Tuple, List + +import textarena as ta + +class NimEnv(ta.Env): + def __init__(self, piles: List[int] = None): + """ + Args: + piles (List[int]): Initial sizes of the piles (e.g. [3, 5, 7]). If None, defaults to [3,4,5]. + """ + super().__init__() + self.initial_piles = piles if piles is not None else [3, 4, 5] + + def reset(self, num_players: int, seed: Optional[int] = None): + self.state = ta.TwoPlayerState(num_players=num_players, seed=seed) + self.state.reset(game_state={"piles": self.initial_piles.copy()}, player_prompt_function=self._prompt) + self.state.add_observation(message="Current Pile:\n" + self._render_piles(), observation_type=ta.ObservationType.GAME_BOARD) + + def _prompt(self, player_id: int, game_state: Dict[str, Any]) -> str: + return ( + f"Welcome to Nim, Player {player_id}!\nRules:\n- On your turn, remove at least one object from exactly one pile.\n" + "- Remove objects with the format '[pile quantity]', e.g. '[0 3]'.\n- Whoever takes the last object(s) wins!" + ) + + def step(self, action: str) -> Tuple[bool, ta.Info]: + self.state.add_observation(from_id=self.state.current_player_id, message=action, observation_type=ta.ObservationType.PLAYER_ACTION) + self._execute_move(action) # Execute the move (or mark invalid if the format is incorrect/illegal) + self.state.add_observation(message="Current Pile:\n" + self._render_piles(), observation_type=ta.ObservationType.GAME_BOARD) # After the current player moves, send the updated board to the opponent. + self._check_game_over() # Check if the game is over + return self.state.step() # Proceed to the next turn (or finalize if done) + + def _execute_move(self, action: str) -> None: + match = re.compile(r"\[\s*(\d+)\s+(\d+)\s*\]").search(action.strip()) # We'll look for actions in the format [pile_index quantity_to_remove], e.g. [1 3]. + if not match: self.state.set_invalid_move(reason="No valid move format found. Use '[pile quantity]'."); return + try: pile_index, quantity = map(int, match.groups()) # Extract pile index and quantity to remove + except ValueError: self.state.set_invalid_move(reason="Action must be two integers: '[pile quantity]'."); return + # Validate the move + if not (0 <= pile_index < len(self.state.game_state["piles"])): self.state.set_invalid_move(reason=f"Pile index {pile_index} is out of range."); return + if quantity <= 0: self.state.set_invalid_move(reason="Must remove at least 1 object."); return + if self.state.game_state["piles"][pile_index] < quantity: self.state.set_invalid_move(reason=f"Cannot remove {quantity} from pile {pile_index} (only {self.state.game_state['piles'][pile_index]} left)."); return + self.state.game_state["piles"][pile_index] -= quantity # Perform the removal + self.state.add_observation(message=f"Player {self.state.current_player_id} removes {quantity} from pile {pile_index}.", observation_type=ta.ObservationType.GAME_ACTION_DESCRIPTION) # Announce the move + + def _check_game_over(self) -> None: + if all(pile == 0 for pile in self.state.game_state["piles"]): + self.state.set_winner(player_id=self.state.current_player_id, reason=f"Player {self.state.current_player_id} took the last object(s)!") + + def _render_piles(self) -> str: + lines = [] + for i, amt in enumerate(self.state.game_state["piles"]): + lines.append(f" pile {i}: {amt}") + return "\n".join(lines) \ No newline at end of file