Initial commit from Openverse builder
This commit is contained in:
53
env.py
Normal file
53
env.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import re
|
||||||
|
from typing import Any, Dict, Optional, Tuple, List
|
||||||
|
|
||||||
|
import textarena as ta
|
||||||
|
|
||||||
|
class NimEnv(ta.Env):
|
||||||
|
def __init__(self, piles: List[int] = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
piles (List[int]): Initial sizes of the piles (e.g. [3, 5, 7]). If None, defaults to [3,4,5].
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.initial_piles = piles if piles is not None else [3, 4, 5]
|
||||||
|
|
||||||
|
def reset(self, num_players: int, seed: Optional[int] = None):
|
||||||
|
self.state = ta.TwoPlayerState(num_players=num_players, seed=seed)
|
||||||
|
self.state.reset(game_state={"piles": self.initial_piles.copy()}, player_prompt_function=self._prompt)
|
||||||
|
self.state.add_observation(message="Current Pile:\n" + self._render_piles(), observation_type=ta.ObservationType.GAME_BOARD)
|
||||||
|
|
||||||
|
def _prompt(self, player_id: int, game_state: Dict[str, Any]) -> str:
|
||||||
|
return (
|
||||||
|
f"Welcome to Nim, Player {player_id}!\nRules:\n- On your turn, remove at least one object from exactly one pile.\n"
|
||||||
|
"- Remove objects with the format '[pile quantity]', e.g. '[0 3]'.\n- Whoever takes the last object(s) wins!"
|
||||||
|
)
|
||||||
|
|
||||||
|
def step(self, action: str) -> Tuple[bool, ta.Info]:
|
||||||
|
self.state.add_observation(from_id=self.state.current_player_id, message=action, observation_type=ta.ObservationType.PLAYER_ACTION)
|
||||||
|
self._execute_move(action) # Execute the move (or mark invalid if the format is incorrect/illegal)
|
||||||
|
self.state.add_observation(message="Current Pile:\n" + self._render_piles(), observation_type=ta.ObservationType.GAME_BOARD) # After the current player moves, send the updated board to the opponent.
|
||||||
|
self._check_game_over() # Check if the game is over
|
||||||
|
return self.state.step() # Proceed to the next turn (or finalize if done)
|
||||||
|
|
||||||
|
def _execute_move(self, action: str) -> None:
|
||||||
|
match = re.compile(r"\[\s*(\d+)\s+(\d+)\s*\]").search(action.strip()) # We'll look for actions in the format [pile_index quantity_to_remove], e.g. [1 3].
|
||||||
|
if not match: self.state.set_invalid_move(reason="No valid move format found. Use '[pile quantity]'."); return
|
||||||
|
try: pile_index, quantity = map(int, match.groups()) # Extract pile index and quantity to remove
|
||||||
|
except ValueError: self.state.set_invalid_move(reason="Action must be two integers: '[pile quantity]'."); return
|
||||||
|
# Validate the move
|
||||||
|
if not (0 <= pile_index < len(self.state.game_state["piles"])): self.state.set_invalid_move(reason=f"Pile index {pile_index} is out of range."); return
|
||||||
|
if quantity <= 0: self.state.set_invalid_move(reason="Must remove at least 1 object."); return
|
||||||
|
if self.state.game_state["piles"][pile_index] < quantity: self.state.set_invalid_move(reason=f"Cannot remove {quantity} from pile {pile_index} (only {self.state.game_state['piles'][pile_index]} left)."); return
|
||||||
|
self.state.game_state["piles"][pile_index] -= quantity # Perform the removal
|
||||||
|
self.state.add_observation(message=f"Player {self.state.current_player_id} removes {quantity} from pile {pile_index}.", observation_type=ta.ObservationType.GAME_ACTION_DESCRIPTION) # Announce the move
|
||||||
|
|
||||||
|
def _check_game_over(self) -> None:
|
||||||
|
if all(pile == 0 for pile in self.state.game_state["piles"]):
|
||||||
|
self.state.set_winner(player_id=self.state.current_player_id, reason=f"Player {self.state.current_player_id} took the last object(s)!")
|
||||||
|
|
||||||
|
def _render_piles(self) -> str:
|
||||||
|
lines = []
|
||||||
|
for i, amt in enumerate(self.state.game_state["piles"]):
|
||||||
|
lines.append(f" pile {i}: {amt}")
|
||||||
|
return "\n".join(lines)
|
||||||
Reference in New Issue
Block a user