Files
Nim-v0/env.py

53 lines
3.5 KiB
Python
Raw Normal View History

2001-01-01 00:00:00 +00:00
import re
from typing import Any, Dict, Optional, Tuple, List
import textarena as ta
class NimEnv(ta.Env):
    """Two-player Nim environment.

    On each turn a player removes one or more objects from exactly one pile.
    The player who removes the last object(s) is declared the winner.
    Actions are given as ``[pile quantity]``, e.g. ``[0 3]``, where ``pile`` is
    a zero-based pile index.
    """

    # Matches an action of the form "[pile_index quantity_to_remove]", e.g. "[1 3]".
    # Compiled once at class-definition time instead of on every move.
    _ACTION_PATTERN = re.compile(r"\[\s*(\d+)\s+(\d+)\s*\]")

    def __init__(self, piles: Optional[List[int]] = None):
        """
        Args:
            piles (Optional[List[int]]): Initial sizes of the piles (e.g. [3, 5, 7]).
                If None, defaults to [3, 4, 5].
        """
        super().__init__()
        # Keep a private copy so that later mutation of the caller's sequence
        # cannot change the starting position of subsequent games.
        self.initial_piles = list(piles) if piles is not None else [3, 4, 5]

    def reset(self, num_players: int, seed: Optional[int] = None):
        """Start a new game from the configured initial piles.

        Creates a fresh ``ta.TwoPlayerState``, stores a copy of the initial
        piles under ``game_state["piles"]`` and emits the starting board as a
        GAME_BOARD observation.

        Args:
            num_players (int): Forwarded to ``ta.TwoPlayerState``.
            seed (Optional[int]): Forwarded to ``ta.TwoPlayerState``.
        """
        self.state = ta.TwoPlayerState(num_players=num_players, seed=seed)
        self.state.reset(game_state={"piles": list(self.initial_piles)}, player_prompt_function=self._prompt)
        self._observe_board()

    def _prompt(self, player_id: int, game_state: Dict[str, Any]) -> str:
        """Return the rules/instructions text shown to ``player_id``.

        ``game_state`` is accepted to satisfy the prompt-function signature but
        is not used.
        """
        return (
            f"Welcome to Nim, Player {player_id}!\nRules:\n- On your turn, remove at least one object from exactly one pile.\n"
            "- Remove objects with the format '[pile quantity]', e.g. '[0 3]'.\n- Whoever takes the last object(s) wins!"
        )

    def step(self, action: str) -> Tuple[bool, ta.Info]:
        """Process one action from the current player.

        Args:
            action (str): Raw text from the player; the first ``[pile quantity]``
                occurrence in it is used as the move.

        Returns:
            Tuple[bool, ta.Info]: Whatever ``self.state.step()`` returns.
        """
        self.state.add_observation(
            from_id=self.state.current_player_id,
            message=action,
            observation_type=ta.ObservationType.PLAYER_ACTION,
        )
        # Execute the move (or mark it invalid if the format is incorrect/illegal).
        move_applied = self._execute_move(action)
        # Publish the board after the attempt, whether or not it changed.
        self._observe_board()
        # Only a move that was actually applied can end the game. Checking after
        # a rejected move would hand the win to the offending player whenever
        # the piles are already all zero (e.g. empty or all-zero initial piles).
        if move_applied:
            self._check_game_over()
        # Proceed to the next turn (or finalize if done).
        return self.state.step()

    def _execute_move(self, action: str) -> bool:
        """Parse, validate and apply a move for the current player.

        On any parse or validation failure the state is flagged via
        ``set_invalid_move`` and the piles are left untouched.

        Returns:
            bool: True if objects were removed from a pile, False if the move
            was rejected.
        """
        match = self._ACTION_PATTERN.search(action.strip())
        if not match:
            self.state.set_invalid_move(reason="No valid move format found. Use '[pile quantity]'.")
            return False

        try:
            # Extract pile index and quantity to remove.
            pile_index, quantity = map(int, match.groups())
        except ValueError:
            # The pattern only captures digits, so this is a defensive guard:
            # int() can still reject pathologically long digit strings on
            # interpreters that cap int/str conversion length.
            self.state.set_invalid_move(reason="Action must be two integers: '[pile quantity]'.")
            return False

        piles = self.state.game_state["piles"]

        # Validate the move.
        if not (0 <= pile_index < len(piles)):
            self.state.set_invalid_move(reason=f"Pile index {pile_index} is out of range.")
            return False
        if quantity <= 0:
            self.state.set_invalid_move(reason="Must remove at least 1 object.")
            return False
        if piles[pile_index] < quantity:
            self.state.set_invalid_move(
                reason=f"Cannot remove {quantity} from pile {pile_index} (only {piles[pile_index]} left)."
            )
            return False

        # Perform the removal and announce it.
        piles[pile_index] -= quantity
        self.state.add_observation(
            message=f"Player {self.state.current_player_id} removes {quantity} from pile {pile_index}.",
            observation_type=ta.ObservationType.GAME_ACTION_DESCRIPTION,
        )
        return True

    def _check_game_over(self) -> None:
        """Declare the current player the winner once every pile is empty."""
        if all(pile == 0 for pile in self.state.game_state["piles"]):
            self.state.set_winner(
                player_id=self.state.current_player_id,
                reason=f"Player {self.state.current_player_id} took the last object(s)!",
            )

    def _observe_board(self) -> None:
        """Emit the current pile sizes as a GAME_BOARD observation."""
        self.state.add_observation(
            message="Current Pile:\n" + self._render_piles(),
            observation_type=ta.ObservationType.GAME_BOARD,
        )

    def _render_piles(self) -> str:
        """Return one line per pile in the form `` pile <index>: <amount>``."""
        return "\n".join(
            f" pile {i}: {amt}" for i, amt in enumerate(self.state.game_state["piles"])
        )