Add env.py from Openverse builder
This commit is contained in:
249
env.py
Normal file
249
env.py
Normal file
@@ -0,0 +1,249 @@
|
||||
```python
|
||||
import re
|
||||
import random
|
||||
from typing import Any, Dict, Optional, Tuple, List
|
||||
import textarena as ta
|
||||
|
||||
|
||||
class HoneyHeistBattleEnv(ta.Env):
|
||||
"""
|
||||
Environment for "Honey Heist: Battle of the Bears".
|
||||
Two rival bears compete for honey in a turn-based deterministic environment.
|
||||
"""
|
||||
|
||||
def __init__(self, max_turns: Optional[int] = 20):
|
||||
self.max_turns = max_turns
|
||||
# regex patterns for validation
|
||||
self.patterns = {
|
||||
"forage": re.compile(r"^\[Forage:(1|2|3)\]$"),
|
||||
"steal": re.compile(r"^\[Steal:(1|2|3)\]$"),
|
||||
"defend": re.compile(r"^\[Defend\]$"),
|
||||
}
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Helper for extracting boxed content
|
||||
# --------------------------------------------------------------------- #
|
||||
def _extract_answer_content(self, action: str) -> str:
|
||||
"""Extract content inside \boxed{} or return full stripped string."""
|
||||
match = re.search(r'\\boxed\{([^}]*)\}', action, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return action.strip()
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Reset / Initialization
|
||||
# --------------------------------------------------------------------- #
|
||||
def reset(self, num_players: int, seed: Optional[int] = None):
|
||||
"""
|
||||
Resets the environment to an initial state.
|
||||
"""
|
||||
if num_players != 2:
|
||||
raise ValueError("Honey Heist: Battle of the Bears requires exactly 2 players.")
|
||||
|
||||
# Initialize TwoPlayerState
|
||||
self.state = ta.TwoPlayerState(num_players=num_players, seed=seed, max_turns=self.max_turns)
|
||||
|
||||
rng = random.Random(seed)
|
||||
hive_honey = rng.randint(15, 20)
|
||||
|
||||
game_state = {
|
||||
"turn_number": 1,
|
||||
"current_player": "BearA",
|
||||
"hive_honey": hive_honey,
|
||||
"max_turns": self.max_turns,
|
||||
"players": {
|
||||
"BearA": {"stored_honey": 0, "last_action": None, "defending": False, "score": 0},
|
||||
"BearB": {"stored_honey": 0, "last_action": None, "defending": False, "score": 0},
|
||||
},
|
||||
"history": [],
|
||||
"winner": None,
|
||||
"draw": False,
|
||||
"seed": seed if seed is not None else 0,
|
||||
}
|
||||
|
||||
role_mapping = {0: "BearA", 1: "BearB"}
|
||||
self.state.reset(game_state=game_state, player_prompt_function=self._generate_player_prompt, role_mapping=role_mapping)
|
||||
|
||||
self.state.add_observation("Welcome to Honey Heist: Battle of the Bears!", ta.ObservationType.GAME_MESSAGE)
|
||||
self.state.add_observation(
|
||||
f"The hive contains {hive_honey} units of honey. BearA goes first.",
|
||||
ta.ObservationType.GAME_MESSAGE
|
||||
)
|
||||
|
||||
return self.state
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Player Prompt Generator
|
||||
# --------------------------------------------------------------------- #
|
||||
def _generate_player_prompt(self, player_id: int, game_state: Dict[str, Any]) -> str:
|
||||
"""Produce role-appropriate prompt for each bear."""
|
||||
role = "BearA" if player_id == 0 else "BearB"
|
||||
rival = "BearB" if role == "BearA" else "BearA"
|
||||
hive_honey = game_state["hive_honey"]
|
||||
player_honey = game_state["players"][role]["stored_honey"]
|
||||
rival_honey = game_state["players"][rival]["stored_honey"]
|
||||
turn_number = game_state["turn_number"]
|
||||
max_turns = game_state["max_turns"]
|
||||
|
||||
prompt = f"""
|
||||
You are a hungry bear competing for the last honey in the forest.
|
||||
|
||||
- Your goal: End the game with more honey than your rival.
|
||||
- Each turn, choose ONE of the following actions:
|
||||
[Forage:X] Gather X units (1–3) from the hive.
|
||||
[Defend] Protect your honey from theft this turn.
|
||||
[Steal:X] Steal X units (1–3) from your rival if they do not defend.
|
||||
|
||||
Game facts:
|
||||
- Hive honey remaining: {hive_honey}
|
||||
- Your stored honey: {player_honey}
|
||||
- Rival stored honey: {rival_honey}
|
||||
- Turn {turn_number} / {max_turns}
|
||||
|
||||
Format rule:
|
||||
State your reasoning briefly, then put your final action in the following format:
|
||||
|
||||
"Put your final answer within \\boxed{{}} at the end of your response."
|
||||
|
||||
Example valid response:
|
||||
I think foraging is safe early on.
|
||||
\\boxed{{[Forage:3]}}
|
||||
|
||||
Example invalid response:
|
||||
\\boxed{{Forage 3}} <-- Must include brackets and colon.
|
||||
"""
|
||||
return prompt.strip()
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Step - main game logic
|
||||
# --------------------------------------------------------------------- #
|
||||
def step(self, action: str) -> Tuple[bool, ta.Info]:
|
||||
"""
|
||||
Perform a single environment step for the current player.
|
||||
"""
|
||||
player_id = self.state.current_player_id
|
||||
player_role = "BearA" if player_id == 0 else "BearB"
|
||||
opponent_role = "BearB" if player_role == "BearA" else "BearA"
|
||||
|
||||
game_state = self.state.game_state
|
||||
|
||||
# record action to transcript
|
||||
self.state.add_observation(action, ta.ObservationType.PLAYER_ACTION, from_id=player_id, to_id=-1)
|
||||
|
||||
# extract and validate
|
||||
token = self._extract_answer_content(action)
|
||||
if game_state["winner"] or game_state["draw"]:
|
||||
self.state.set_invalid_move("Game is already over.")
|
||||
return self.state.step()
|
||||
|
||||
valid_action = False
|
||||
reason = None
|
||||
if self.patterns["forage"].match(token):
|
||||
valid_action = True
|
||||
act_type = "forage"
|
||||
qty = int(re.findall(r"\d+", token)[0])
|
||||
elif self.patterns["steal"].match(token):
|
||||
valid_action = True
|
||||
act_type = "steal"
|
||||
qty = int(re.findall(r"\d+", token)[0])
|
||||
elif self.patterns["defend"].match(token):
|
||||
valid_action = True
|
||||
act_type = "defend"
|
||||
qty = 0
|
||||
else:
|
||||
reason = "Invalid format, must use [Forage:X], [Steal:X], or [Defend]."
|
||||
|
||||
if not valid_action:
|
||||
self.state.set_invalid_move(reason)
|
||||
return self.state.step()
|
||||
|
||||
# handle action
|
||||
player_data = game_state["players"][player_role]
|
||||
opp_data = game_state["players"][opponent_role]
|
||||
|
||||
if act_type == "forage":
|
||||
if game_state["hive_honey"] < qty:
|
||||
self.state.set_invalid_move("Not enough honey in hive.")
|
||||
return self.state.step()
|
||||
# update hive and player
|
||||
game_state["hive_honey"] -= qty
|
||||
player_data["stored_honey"] += qty
|
||||
player_data["score"] = player_data["stored_honey"]
|
||||
|
||||
elif act_type == "defend":
|
||||
player_data["defending"] = True
|
||||
|
||||
elif act_type == "steal":
|
||||
if opp_data["stored_honey"] < qty:
|
||||
self.state.set_invalid_move("Opponent has insufficient honey.")
|
||||
return self.state.step()
|
||||
if opp_data["defending"]:
|
||||
transfer = 0 # blocked
|
||||
else:
|
||||
transfer = qty
|
||||
opp_data["stored_honey"] -= transfer
|
||||
player_data["stored_honey"] += transfer
|
||||
player_data["score"] = player_data["stored_honey"]
|
||||
opp_data["score"] = opp_data["stored_honey"]
|
||||
|
||||
# update metadata
|
||||
player_data["last_action"] = token
|
||||
entry = {"turn": game_state["turn_number"], "actor": player_role, "action": token}
|
||||
game_state["history"].append(entry)
|
||||
game_state["current_player"] = opponent_role
|
||||
|
||||
# Next player's defending status reset check
|
||||
# Every full round (both bears act), clear defending flags
|
||||
if player_role == "BearB":
|
||||
game_state["players"]["BearA"]["defending"] = False
|
||||
game_state["players"]["BearB"]["defending"] = False
|
||||
|
||||
# increment turn number
|
||||
game_state["turn_number"] += 1
|
||||
|
||||
# Check terminal conditions
|
||||
done = False
|
||||
reason_end = ""
|
||||
if game_state["hive_honey"] <= 0:
|
||||
done = True
|
||||
reason_end = "Hive honey depleted."
|
||||
elif game_state["turn_number"] > game_state["max_turns"]:
|
||||
done = True
|
||||
reason_end = "Maximum turns reached."
|
||||
else:
|
||||
total_honey = (
|
||||
game_state["hive_honey"]
|
||||
+ game_state["players"]["BearA"]["stored_honey"]
|
||||
+ game_state["players"]["BearB"]["stored_honey"]
|
||||
)
|
||||
if total_honey <= 0:
|
||||
done = True
|
||||
reason_end = "All honey depleted."
|
||||
|
||||
if done:
|
||||
a_honey = game_state["players"]["BearA"]["stored_honey"]
|
||||
b_honey = game_state["players"]["BearB"]["stored_honey"]
|
||||
|
||||
if a_honey > b_honey:
|
||||
game_state["winner"] = "BearA"
|
||||
self.state.set_winner(player_id=0, reason=reason_end + " BearA has more honey.")
|
||||
elif b_honey > a_honey:
|
||||
game_state["winner"] = "BearB"
|
||||
self.state.set_winner(player_id=1, reason=reason_end + " BearB has more honey.")
|
||||
else:
|
||||
game_state["draw"] = True
|
||||
self.state.set_draw(reason=reason_end + " Equal honey scores.")
|
||||
|
||||
return self.state.step()
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Standard required Env accessors
|
||||
# --------------------------------------------------------------------- #
|
||||
def get_observation(self) -> Tuple[int, List]:
|
||||
"""Return current player's observations."""
|
||||
return self.state.current_player_id, self.state.observations
|
||||
|
||||
def close(self) -> Tuple[Dict, Dict]:
|
||||
"""Return rewards and game_info at close."""
|
||||
return self.state.rewards, self.state.game_info
|
||||
```
|
||||
Reference in New Issue
Block a user