From a114450d5aa6596aab326ac9dcae34210365ee2c Mon Sep 17 00:00:00 2001 From: hannw Date: Wed, 1 Oct 2025 23:17:16 +0000 Subject: [PATCH 1/2] Port just game logic --- .../envs/werewolf/GAME_RULE.md | 75 +++ kaggle_environments/envs/werewolf/__init__.py | 0 .../envs/werewolf/game/__init__.py | 0 .../envs/werewolf/game/actions.py | 268 ++++++++ .../envs/werewolf/game/base.py | 115 ++++ .../envs/werewolf/game/consts.py | 156 +++++ .../envs/werewolf/game/engine.py | 582 ++++++++++++++++++ .../game/night_elimination_manager.py | 101 +++ .../envs/werewolf/game/protocols/__init__.py | 4 + .../envs/werewolf/game/protocols/base.py | 242 ++++++++ .../envs/werewolf/game/protocols/bid.py | 248 ++++++++ .../envs/werewolf/game/protocols/chat.py | 465 ++++++++++++++ .../envs/werewolf/game/protocols/factory.py | 59 ++ .../envs/werewolf/game/protocols/vote.py | 471 ++++++++++++++ .../envs/werewolf/game/records.py | 334 ++++++++++ .../envs/werewolf/game/roles.py | 326 ++++++++++ .../envs/werewolf/game/states.py | 214 +++++++ .../envs/werewolf/game/test_actions.py | 45 ++ pyproject.toml | 10 +- 19 files changed, 3714 insertions(+), 1 deletion(-) create mode 100644 kaggle_environments/envs/werewolf/GAME_RULE.md create mode 100644 kaggle_environments/envs/werewolf/__init__.py create mode 100644 kaggle_environments/envs/werewolf/game/__init__.py create mode 100644 kaggle_environments/envs/werewolf/game/actions.py create mode 100644 kaggle_environments/envs/werewolf/game/base.py create mode 100644 kaggle_environments/envs/werewolf/game/consts.py create mode 100644 kaggle_environments/envs/werewolf/game/engine.py create mode 100644 kaggle_environments/envs/werewolf/game/night_elimination_manager.py create mode 100644 kaggle_environments/envs/werewolf/game/protocols/__init__.py create mode 100644 kaggle_environments/envs/werewolf/game/protocols/base.py create mode 100644 kaggle_environments/envs/werewolf/game/protocols/bid.py create mode 100644 
kaggle_environments/envs/werewolf/game/protocols/chat.py create mode 100644 kaggle_environments/envs/werewolf/game/protocols/factory.py create mode 100644 kaggle_environments/envs/werewolf/game/protocols/vote.py create mode 100644 kaggle_environments/envs/werewolf/game/records.py create mode 100644 kaggle_environments/envs/werewolf/game/roles.py create mode 100644 kaggle_environments/envs/werewolf/game/states.py create mode 100644 kaggle_environments/envs/werewolf/game/test_actions.py diff --git a/kaggle_environments/envs/werewolf/GAME_RULE.md b/kaggle_environments/envs/werewolf/GAME_RULE.md new file mode 100644 index 00000000..d5de09ed --- /dev/null +++ b/kaggle_environments/envs/werewolf/GAME_RULE.md @@ -0,0 +1,75 @@ +# Werewolf: Game Rules + +Welcome to Werewolf, a game of social deduction, team collaboration, deception, and survival. Players are secretly assigned roles on one of two teams: the Village or the Werewolves. + +## Roles + +Each player is assigned one of the following roles: + +### Village Team + +The goal of the Village team is to exile all the werewolves. + +* **Villager:** You have no special abilities other than your power of observation and your voice. Use the discussion phase to identify suspicious behavior and vote to exile suspected werewolves. +* **Seer:** Each night, you may choose one player to investigate. You will learn if that player is a Werewolf or not. Your goal is to share this information strategically to help the village without revealing your identity too soon. +* **Doctor:** Each night, you may choose one player to protect. The player you protect cannot be eliminated by the werewolves that night. + +### Werewolf Team + +* **Werewolf:** Your goal is to eliminate villagers until the number of werewolves equals the number of remaining village members. Each night, you and your fellow werewolves will secretly agree on one player to eliminate. + +## Game Phases + +The game alternates between a Night phase and a Day phase. 
+ +### Night Phase 🐺 + +During the night, all players close their eyes. The moderator will ask players with special roles to wake up and perform their actions in this order: + +1. **Doctor:** Chooses one player to protect. +2. **Seer:** Chooses one player to investigate their alignment. +3. **Werewolves:** Silently vote on one player to eliminate. + +### Day Phase ☀️ + +1. **Announcement:** The moderator announces which player, if any, was eliminated during the night. That player is removed from the game and may not speak or participate further. +2. **Discussion:** The surviving players discuss who they think the werewolves are. +3. **Exile Vote:** Players vote on who to exile from the village. The player who receives the most votes is exiled, removed from the game, and their role is revealed. + +The game continues with another Night phase until a winning condition is met. + +## Customizable Rules + +Before the game begins, the following options must be decided. + +### 1. Doctor's Self-Save + +* **Option A (Self-Save Allowed):** The Doctor is allowed to choose themselves as the target of their protection. +* **Option B (No Self-Save):** The Doctor must choose another player to protect. + +### 2. Discussion Protocol + +* **Option A (Parallel Discussion):** All players may speak simultaneously for a number of rounds. +* **Option B (Round Robin):** Each player speaks one after another following a predefined order for a number of rounds. + +### 3. Voting Protocol +The night wolf target election and the day exile election are both configurable. All voting protocols follow a random +tie-breaking mechanism, where a random draw is used when there are multiple candidates with the same number of votes. + +* **Option A (Sequential Voting):** Voters cast their votes one after another, where each voter can see all earlier votes. +* **Option B (Parallel Voting):** All voters cast their votes simultaneously. 
+ +## Winning the Game + +A team wins as soon as their winning condition is met. + +* **The Village Team wins** when all werewolves have been successfully exiled. +* **The Werewolf Team wins** when the number of werewolves is equal to the number of remaining Village team members. + +### Rewards + +All members of the winning team will receive **1 reward**. This includes players who were eliminated before the end of the game. + +### Tie Game (Forfeit) + +If any back-end inference fails during the game, the match will immediately end. The game will be declared a **tie**, and no players will receive a reward. \ No newline at end of file diff --git a/kaggle_environments/envs/werewolf/__init__.py b/kaggle_environments/envs/werewolf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kaggle_environments/envs/werewolf/game/__init__.py b/kaggle_environments/envs/werewolf/game/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kaggle_environments/envs/werewolf/game/actions.py b/kaggle_environments/envs/werewolf/game/actions.py new file mode 100644 index 00000000..9b0a80da --- /dev/null +++ b/kaggle_environments/envs/werewolf/game/actions.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import re +from functools import lru_cache +from typing import Optional, Tuple + +from pydantic import Field, create_model, field_validator + +from .base import BaseAction, BaseState, PlayerID +from .consts import EventName, PerceivedThreatLevel, Phase +from .records import DoctorHealActionDataEntry, SeerInspectActionDataEntry + +ACTION_EVENT_MAP = {} + + +def register_event(event_name: EventName): + """A class decorator to register an EventName for an Action class.""" + + def decorator(cls): + ACTION_EVENT_MAP[cls.__name__] = event_name + setattr(cls, "event_name", event_name) + return cls + + return decorator + + +_REPLACEMENT_MAP = { + # 'kill' variations + "kill": "eliminate", + "kills": "eliminates", + "killed": "eliminated", + "killing": 
"eliminating", + "killer": "eliminator", + # 'lynch' variations + "lynch": "exile", + "lynches": "exiles", + "lynched": "exiled", + "lynching": "exiling", + # 'mislynch' variations + "mislynch": "mis-exile", + "mislynches": "mis-exiles", + "mislynched": "mis-exiled", + "mislynching": "mis-exiling", + # 'murder' variations + "murder": "remove", + "murders": "removes", + "murdered": "removed", + "murdering": "removing", + "murderer": "remover", +} + +_CENSOR_PATTERN = re.compile(r"\b(" + "|".join(_REPLACEMENT_MAP.keys()) + r")\b", re.IGNORECASE) + + +# Create a single, case-insensitive regex pattern from all map keys. +def replacer(match): + """ + Finds the correct replacement and applies case based on a specific heuristic. + """ + original_word = match.group(0) + replacement = _REPLACEMENT_MAP[original_word.lower()] + + # Rule 1: Preserve ALL CAPS. + if original_word.isupper(): + return replacement.upper() + + # Rule 2: Handle title-cased words with a more specific heuristic. + if original_word.istitle(): + # Preserve title case if it's the first word of the string OR + # if it's a form like "-ing" which can start a new clause. + return replacement.title() + + # Rule 3: For all other cases (e.g., "Kill" mid-sentence), default to lowercase. + return replacement.lower() + + +def filter_language(text): + """Remove inappropriate/violent language.""" + return _CENSOR_PATTERN.sub(replacer, text) + + +# ------------------------------------------------------------------ # +class Action(BaseAction): + """Root of the discriminated-union tree.""" + + day: int + phase: Phase + actor_id: PlayerID + reasoning: Optional[str] = Field( + default=None, + max_length=4096, + description="The self monologue that illustrate how you arrived at the action. 
" + "It will be invisible to other players.", + ) + + perceived_threat_level: PerceivedThreatLevel = Field( + default=PerceivedThreatLevel.SAFE, + description="The self perceived threat level you are currently experiencing from other players. " + "The assessment will be invisible to other players.", + ) + error: Optional[str] = None + raw_prompt: Optional[str] = None + raw_completion: Optional[str] = None + + @field_validator("reasoning", mode="before") + @classmethod + def filter_reasoning(cls, v): + if v is None: + return v + return filter_language(v) + + def serialize(self): + return {"action_type": self.__class__.__name__, "kwargs": self.model_dump()} + + @classmethod + def schema_for_player(cls, fields: Tuple = None, new_cls_name=None): + """Many of the fields are for internal game record. This method is used to convert the response schema + to a format friendly for players. + """ + fields = fields or [] + if not new_cls_name: + new_cls_name = cls.__name__ + "Data" + field_definitions = { + field: ( + cls.model_fields[field].annotation, + # Pass the entire FieldInfo object, not just the default value + cls.model_fields[field], + ) + for field in fields + if field in cls.model_fields + } + sub_cls = create_model(new_cls_name, **field_definitions) + subset_schema = sub_cls.model_json_schema() + return subset_schema + + @property + def action_field(self) -> Optional[str]: + return None + + def push_event(self, state: BaseState): + # The following is just for internal record keeping. 
+ data = self.model_dump() + state.push_event( + description=f"Player {self.actor_id}, you submitted {data}", + event_name=ACTION_EVENT_MAP[self.__class__.__name__], + public=False, + visible_to=[], + data=data, + ) + + +# ——— Mix-in for actions that need a target ------------------------ # +class TargetedAction(Action): + target_id: PlayerID = Field(description="The target player's id.") + + @classmethod + @lru_cache(maxsize=10) + def schema_for_player(cls, fields=None, new_cls_name=None): + fields = fields or ["perceived_threat_level", "reasoning", "target_id"] + return super(TargetedAction, cls).schema_for_player(fields, new_cls_name) + + @property + def action_field(self): + return "target_id" + + +# ——— Concrete leaf classes --------------------------------------- # +@register_event(EventName.HEAL_ACTION) +class HealAction(TargetedAction): + def push_event(self, state: BaseState): + action_data = DoctorHealActionDataEntry( + actor_id=self.actor_id, + target_id=self.target_id, + reasoning=self.reasoning, + perceived_threat_level=self.perceived_threat_level, + action=self, + ) + state.push_event( + description=f"Player {self.actor_id}, you chose to heal player {self.target_id}.", + event_name=EventName.HEAL_ACTION, + public=False, + visible_to=[self.actor_id], + data=action_data, + ) + + +@register_event(EventName.INSPECT_ACTION) +class InspectAction(TargetedAction): + def push_event(self, state: BaseState): + action_data = SeerInspectActionDataEntry( + actor_id=self.actor_id, + target_id=self.target_id, + reasoning=self.reasoning, + perceived_threat_level=self.perceived_threat_level, + action=self, + ) + state.push_event( + description=f"Player {self.actor_id}, you chose to inspect player {self.target_id}.", + event_name=EventName.INSPECT_ACTION, + public=False, + visible_to=[self.actor_id], + data=action_data, + ) + + +@register_event(EventName.VOTE_ACTION) +class VoteAction(TargetedAction): + pass + + +@register_event(EventName.ELIMINATE_PROPOSAL_ACTION) 
+class EliminateProposalAction(VoteAction): + pass + + +@register_event(EventName.DISCUSSION) +class ChatAction(Action): + message: str = Field(default="", max_length=4096) + + @field_validator("message", mode="before") + @classmethod + def filter_message(cls, v): + return filter_language(v) + + @classmethod + @lru_cache(maxsize=10) + def schema_for_player(cls, fields=None, new_cls_name=None): + fields = fields or ["perceived_threat_level", "reasoning", "message"] + return super(ChatAction, cls).schema_for_player(fields, new_cls_name) + + @property + def action_field(self): + return "message" + + +@register_event(EventName.NOOP_ACTION) +class NoOpAction(Action): + pass + + +# ------------------------------------------------------------ # +@register_event(EventName.BID_ACTION) +class BidAction(Action): + """ + An amount the actor is willing to pay this round. + Currency unit can be generic 'chips' or role-specific. + """ + + amount: int = Field(ge=0) + + @classmethod + @lru_cache(maxsize=10) + def schema_for_player(cls, fields=None, new_cls_name=None): + fields = fields or ["perceived_threat_level", "reasoning", "amount"] + return super(BidAction, cls).schema_for_player(fields, new_cls_name) + + @property + def action_field(self): + return "amount" + + +ACTIONS = [EliminateProposalAction, HealAction, InspectAction, VoteAction, ChatAction, BidAction, NoOpAction] + +ACTION_REGISTRY = {action.__name__: action for action in ACTIONS} + + +def create_action(serialized): + return ACTION_REGISTRY[serialized["action_type"]](**serialized.get("kwargs", {})) diff --git a/kaggle_environments/envs/werewolf/game/base.py b/kaggle_environments/envs/werewolf/game/base.py new file mode 100644 index 00000000..f36fe99b --- /dev/null +++ b/kaggle_environments/envs/werewolf/game/base.py @@ -0,0 +1,115 @@ +from abc import ABC, abstractmethod +from typing import Annotated, Any, Dict, List, Optional, Protocol, Type + +from pydantic import BaseModel, StringConstraints + +from .consts import 
EVENT_HANDLER_FOR_ATTR_NAME, MODERATOR_ID, EventName + +# The ID regex supports Unicode letters (\p{L}), numbers (\p{N}) and common symbol for ID. +ROBUST_ID_REGEX = r"^[\p{L}\p{N} _.-]+$" + +PlayerID = Annotated[str, StringConstraints(pattern=ROBUST_ID_REGEX, min_length=1, max_length=128)] + + +class BasePlayer(BaseModel, ABC): + id: PlayerID + """The unique id of the player. Also, how the player is referred to in the game.""" + + alive: bool = True + + @abstractmethod + def set_role_state(self, key, value): + """Set role related state, which is a dict.""" + + @abstractmethod + def get_role_state(self, key, default=None): + """Get role related state.""" + + +class BaseAction(BaseModel): + pass + + +class BaseState(BaseModel): + @abstractmethod + def push_event( + self, + description: str, + event_name: EventName, + public: bool, + visible_to: Optional[List[PlayerID]] = None, + data: Any = None, + source=MODERATOR_ID, + ): + """Publish an event.""" + + +class BaseEvent(BaseModel): + event_name: EventName + + +class BaseModerator(ABC): + @abstractmethod + def advance(self, player_actions: Dict[PlayerID, BaseAction]): + """Move one Kaggle environment step further. This is to be used within Kaggle 'interpreter'.""" + + @abstractmethod + def request_action( + self, + action_cls: Type[BaseAction], + player_id: PlayerID, + prompt: str, + data=None, + event_name=EventName.MODERATOR_ANNOUNCEMENT, + ): + """This can be used by event handler to request action from a player.""" + + @abstractmethod + def record_night_save(self, doctor_id: str, target_id: str): + """To be used by a special Role to perform night save. This is implemented in moderator level, since + coordinating between safe and night elimination is cross role activity. 
+ """ + + @property + @abstractmethod + def state(self) -> BaseState: + """Providing current state of the game, including player info, event messaging and caching.""" + + +def on_event(event_type: EventName): + def decorator(func): + setattr(func, EVENT_HANDLER_FOR_ATTR_NAME, event_type) + return func + + return decorator + + +class EventHandler(Protocol): + """A callable triggered by an event.""" + + def __call__(self, event: BaseEvent) -> Any: + pass + + +class RoleEventHandler(Protocol): + """A role specific event handler.""" + + def __call__(self, me: BasePlayer, moderator: BaseModerator, event: BaseEvent) -> Any: + pass + + +class BaseRole(BaseModel, ABC): + """Special abilities should be implemented as RoleEventHandler in each subclass of BaseRole, so that Moderator + doesn't need to be overwhelmed by role specific logic. + """ + + def get_event_handlers(self) -> Dict[EventName, RoleEventHandler]: + """Inspects the role instance and collects all methods decorated with @on_event""" + handlers = {} + for attr_name in dir(self): + if not attr_name.startswith("__"): + attr = getattr(self, attr_name) + if callable(attr) and hasattr(attr, EVENT_HANDLER_FOR_ATTR_NAME): + event_type = getattr(attr, EVENT_HANDLER_FOR_ATTR_NAME) + handlers[event_type] = attr + return handlers diff --git a/kaggle_environments/envs/werewolf/game/consts.py b/kaggle_environments/envs/werewolf/game/consts.py new file mode 100644 index 00000000..ee2ff9c7 --- /dev/null +++ b/kaggle_environments/envs/werewolf/game/consts.py @@ -0,0 +1,156 @@ +from enum import Enum + +MODERATOR_ID = "MODERATOR" + + +class StrEnum(str, Enum): + def __str__(self): + return str(self.value) + + def __repr__(self): + return str(self.value) + + +class Phase(StrEnum): + DAY = "Day" + NIGHT = "Night" + GAME_OVER = "Game Over" + + +DAY, NIGHT, GAME_OVER = Phase + + +class PhaseDivider(StrEnum): + NIGHT_START = "NIGHT START" + NIGHT_END = "NIGHT END" + DAY_START = "DAY START" + DAY_END = "DAY END" + NIGHT_VOTE_START = 
"NIGHT VOTE START" + NIGHT_VOTE_END = "NIGHT VOTE END" + DAY_CHAT_START = "DAY CHAT START" + DAY_CHAT_END = "DAY CHAT END" + DAY_VOTE_START = "DAY VOTE START" + DAY_VOTE_END = "DAY VOTE END" + + +class Team(StrEnum): + VILLAGERS = "Villagers" + WEREWOLVES = "Werewolves" + + +class RoleConst(StrEnum): + VILLAGER = "Villager" + WEREWOLF = "Werewolf" + DOCTOR = "Doctor" + SEER = "Seer" + + +class ActionType(StrEnum): + NO_OP = "NO_OP" + NIGHT_KILL_VOTE = "NIGHT_KILL_VOTE" + NIGHT_SAVE_TARGET = "NIGHT_SAVE_TARGET" + NIGHT_INSPECT_TARGET = "NIGHT_INSPECT_TARGET" + DAY_DISCUSS = "DAY_DISCUSS" + DAY_LYNCH_VOTE = "DAY_LYNCH_VOTE" + + +class PerceivedThreatLevel(StrEnum): + SAFE = "SAFE" + UNEASY = "UNEASY" + DANGER = "DANGER" + + +class EnvInfoKeys: + MODERATOR_OBS = "MODERATOR_OBSERVATION" + GAME_END = "GAME_END" + + +class ObsKeys: + RAW_OBSERVATION = "raw_observation" + + +class DetailedPhase(StrEnum): + def __new__(cls, value, category: Phase): + # This creates the string object from the value + obj = str.__new__(cls, value) + # This sets the _value_ attribute, which is what Enum uses internally + obj._value_ = value + # Now, attach your custom category attribute + obj.category = category + return obj + + # Night Phases + NIGHT_START = "NIGHT_START", NIGHT + NIGHT_AWAIT_ACTIONS = "NIGHT_AWAIT_ACTIONS", NIGHT + NIGHT_CONCLUDE = "NIGHT_CONCLUDE", NIGHT + + # Day Phases + DAY_START = "DAY_START", DAY + + DAY_BIDDING_AWAIT = "DAY_BIDDING_AWAIT", DAY + DAY_BIDDING_CONCLUDE = "DAY_BIDDING_CONCLUDE", DAY + + DAY_CHAT_AWAIT = "DAY_CHAT_AWAIT", DAY + DAY_CHAT_CONCLUDE = "DAY_CHAT_CONCLUDE", DAY + + DAY_VOTING_START = "DAY_VOTING_START", DAY + DAY_VOTING_AWAIT = "DAY_VOTING_AWAIT", DAY + DAY_VOTING_CONCLUDE = "DAY_VOTING_CONCLUDE", DAY + + # Game Over + GAME_OVER = "GAME_OVER", GAME_OVER + + +EVENT_HANDLER_FOR_ATTR_NAME = "_event_handler_for" + + +class EventName(str, Enum): + GAME_START = "game_start" + PHASE_CHANGE = "phase_change" + PHASE_DIVIDER = "phase_divider" + 
ELIMINATION = "elimination" + + VOTE_REQUEST = "vote_request" + VOTE_ACTION = "vote_action" + VOTE_RESULT = "vote_result" + VOTE_ORDER = "vote_order" + + HEAL_REQUEST = "heal_request" + HEAL_ACTION = "heal_action" + HEAL_RESULT = "heal_result" + + INSPECT_REQUEST = "inspect_request" + INSPECT_ACTION = "inspect_action" + INSPECT_RESULT = "inspect_result" + + CHAT_REQUEST = "chat_request" + DISCUSSION = "discussion" + DISCUSSION_ORDER = "discussion_order" + + BID_REQEUST = "bid_request" + BID_RESULT = "bid_result" + BID_ACTION = "bid_action" + BIDDING_INFO = "bidding_info" + + ELIMINATE_PROPOSAL_ACTION = "eliminate_proposal_action" + NOOP_ACTION = "no_op_action" + + GAME_END = "game_end" + MODERATOR_ANNOUNCEMENT = "moderator_announcement" + ACTION_CONFIRMATION = "action_confirmation" + ERROR = "error" + NIGHT_START = "night_start" + DAY_START = "day_start" + NIGHT_END = "night_end" + DAY_END = "day_end" + + +class RevealLevel(StrEnum): + NO_REVEAL = "no_reveal" + """No reveal during elimination.""" + + TEAM = "team" + """Only reveal team during elimination.""" + + ROLE = "role" + """Reveal detailed role information during elimination.""" diff --git a/kaggle_environments/envs/werewolf/game/engine.py b/kaggle_environments/envs/werewolf/game/engine.py new file mode 100644 index 00000000..be5c5091 --- /dev/null +++ b/kaggle_environments/envs/werewolf/game/engine.py @@ -0,0 +1,582 @@ +import json +from typing import Dict, List, Protocol, Sequence, Type + +from .actions import Action, BidAction, ChatAction, VoteAction +from .base import BaseModerator, PlayerID +from .consts import DetailedPhase, PhaseDivider, RevealLevel, RoleConst, Team +from .night_elimination_manager import NightEliminationManager +from .protocols.base import DiscussionProtocol, VotingProtocol +from .protocols.chat import BiddingDiscussion +from .records import ( + DayExileElectedDataEntry, + EventName, + GameEndResultsDataEntry, + GameStartDataEntry, + GameStartRoleDataEntry, + 
RequestWerewolfVotingDataEntry, + WerewolfNightEliminationElectedDataEntry, +) +from .roles import Player +from .states import GameState + + +class ActionQueue: + """A data structure for managing player ids in action specific queues.""" + + def __init__(self): + self._action_queue: Dict[str, List[PlayerID]] = {} + + def clear(self): + self._action_queue = {} + + def append(self, action_cls: Type[Action], player_id: PlayerID): + action_type = action_cls.__name__ + self._action_queue.setdefault(action_type, []) + if player_id in self._action_queue[action_type]: + raise ValueError(f"player {player_id} is already in the action queue. ") + self._action_queue[action_type].append(player_id) + + def extend(self, action_cls: Type[Action], player_ids: Sequence[PlayerID]): + for player_id in player_ids: + self.append(action_cls, player_id) + + def get(self, action_cls: Type[Action]) -> List[str]: + """return a list of player_id for the selected action.""" + return self._action_queue.get(action_cls.__name__, []) + + def get_active_player_ids(self) -> List[PlayerID]: + all_players = set() + for players in self._action_queue.values(): + all_players.update(players) + return list(all_players) + + +def phase_handler(phase: DetailedPhase): + """Decorator to register a method as a handler for a specific game phase.""" + + def decorator(func): + setattr(func, "_phase_handler_for", phase) + return func + + return decorator + + +class PhaseHandler(Protocol): + def __call__(self, player_actions: Dict[PlayerID, Action]) -> DetailedPhase: + pass + + +class Moderator(BaseModerator): + """Drives the finite-state machine for the game.""" + + def __init__( + self, + state: GameState, + discussion: DiscussionProtocol, + day_voting: VotingProtocol, # Renamed for clarity + night_voting: VotingProtocol, + night_elimination_reveal_level: RevealLevel = RevealLevel.ROLE, + day_exile_reveal_level: RevealLevel = RevealLevel.ROLE, + ): + self._state = state + self.discussion = discussion + 
self.day_voting = day_voting + self.night_voting = night_voting + + self._night_elimination_reveal_level = night_elimination_reveal_level + self._day_exile_reveal_level = day_exile_reveal_level + + self._active_night_roles_queue: List[Player] = [] + self._night_elimination_manager = NightEliminationManager( + self._state, reveal_level=self._night_elimination_reveal_level + ) + self._action_queue = ActionQueue() + + # This is for registering role specific event handling + self._register_player_handlers() + + # below is the state transition function table + # each transition function has the signature tr_func(actions: List[Action]) where the input is a list of actions + # with the length the same as the number of agents + self.detailed_phase = DetailedPhase.NIGHT_START + self._phase_handlers: Dict[DetailedPhase, PhaseHandler] = {} + self._register_phase_handlers() + + self._make_initial_announcements() + + @property + def state(self) -> GameState: + return self._state + + def _make_initial_announcements(self): + data = GameStartDataEntry( + player_ids=[p.id for p in self.state.alive_players()], + number_of_players=len(self.state.alive_players()), + role_counts=self.state.alive_player_counts_per_role(), + team_member_counts=self.state.alive_player_counts_per_team(), + day_discussion_protocol_name=self.discussion.__class__.__name__, + day_discussion_display_name=self.discussion.display_name, + day_discussion_protocol_rule=self.discussion.rule, + night_werewolf_discussion_protocol_name=self.night_voting.__class__.__name__, + night_werewolf_discussion_display_name=self.night_voting.display_name, + night_werewolf_discussion_protocol_rule=self.night_voting.rule, + day_voting_protocol_name=self.day_voting.__class__.__name__, + day_voting_display_name=self.day_voting.display_name, + day_voting_protocol_rule=self.day_voting.rule, + ) + + role_msg = "\n".join( + ["The following explain the function of each role."] + + [ + f" * Role name {role.name.value} - team 
{role.team.value} - {role.descriptions}" + for role in self.state.all_unique_roles + ] + ) + + if self._day_exile_reveal_level == RevealLevel.ROLE: + day_exile_reveal_msg = "If a player is exiled in the day, their role will be revealed." + elif self._day_exile_reveal_level == RevealLevel.TEAM: + day_exile_reveal_msg = "If a player is exiled in the day, their team will be revealed." + elif self._day_exile_reveal_level == RevealLevel.NO_REVEAL: + day_exile_reveal_msg = "If a player is exiled in the day, their team and role will NOT be revealed." + else: + raise ValueError(f"Unsupported day_exile_reveal_level = {self._day_exile_reveal_level}.") + + if self._night_elimination_reveal_level == RevealLevel.ROLE: + night_elimination_reveal_msg = "If a player is eliminated at night, their role will be revealed." + elif self._night_elimination_reveal_level == RevealLevel.TEAM: + night_elimination_reveal_msg = "If a player is eliminated at night, their team will be revealed." + elif self._night_elimination_reveal_level == RevealLevel.NO_REVEAL: + night_elimination_reveal_msg = ( + "If a player is eliminated at night, their team and role will NOT be revealed." + ) + else: + raise ValueError(f"Unsupported night_elimination_reveal_level = {self._night_elimination_reveal_level}.") + + description = "\n - ".join( + [ + "Werewolf game begins.", + f"**Player Roster:** {data.player_ids}", + f"**Alive Players:** {data.number_of_players}.", + f"**Role Counts:** {data.role_counts}.", + f"**Alive Team Member:** {data.team_member_counts}", + f"**Day Discussion:** {data.day_discussion_display_name}. {data.day_discussion_protocol_rule}", + f"**Day Exile Vote:** {data.day_voting_display_name}. {data.day_voting_protocol_rule}", + f"**Night Werewolf Vote:** {data.night_werewolf_discussion_display_name}. 
{data.night_werewolf_discussion_protocol_rule}", + role_msg, + day_exile_reveal_msg, + night_elimination_reveal_msg, + ] + ) + self.state.push_event( + description=description, event_name=EventName.MODERATOR_ANNOUNCEMENT, public=True, data=data + ) + # add role specific announcements + for player in self.state.alive_players(): + data = GameStartRoleDataEntry( + player_id=player.id, team=player.role.team, role=player.role.name, rule_of_role=player.role.descriptions + ) + self.state.push_event( + description=f'Your player id is "{data.player_id}". Your team is "{data.team}". Your role is "{data.role}".\n' + f"The rule of your role: {data.rule_of_role}", + event_name=EventName.GAME_START, + public=False, + visible_to=[player.id], + data=data, + ) + + def _register_phase_handlers(self): + """Collects all methods decorated with @phase_handler.""" + for attr_name in dir(self): + attr = getattr(self, attr_name) + if callable(attr) and hasattr(attr, "_phase_handler_for"): + phase = getattr(attr, "_phase_handler_for") + self._phase_handlers[phase] = attr + + def _register_player_handlers(self): + for player in self.state.players: + for event_name, handlers in player.get_event_handlers(self).items(): + for handler in handlers: + self.state.register_event_handler(event_name, handler) + + def request_action( + self, + action_cls: Type[Action], + player_id: PlayerID, + prompt: str, + data=None, + event_name=EventName.MODERATOR_ANNOUNCEMENT, + ): + """A public method for listeners to add a player to the action queue.""" + self._action_queue.append(action_cls, player_id) + # Create the corresponding data entry to prompt the player + self.state.push_event( + description=prompt, event_name=event_name, public=False, visible_to=[player_id], data=data + ) + + def confirm_action(self, player_actions: Dict[PlayerID, Action]): + for action in player_actions.values(): + # moderator confirming the action with players + action.push_event(state=self.state) + + def set_next_phase(self, 
new_detailed_phase: DetailedPhase, add_one_day: bool = False): + """Note: phase change is not the same as phase start, still need phase start at each block""" + old_detailed_phase = self.detailed_phase + self.detailed_phase = new_detailed_phase + self.state.detailed_phase = new_detailed_phase + self.state.phase = new_detailed_phase.category + + if add_one_day: + self.state.day_count += 1 + + self.state.push_event( + description=f"Transitioning from {old_detailed_phase} to {new_detailed_phase}.", + event_name=EventName.PHASE_CHANGE, + public=False, + ) + + def get_active_player_ids(self) -> List[PlayerID]: + return self._action_queue.get_active_player_ids() + + def record_night_save(self, doctor_id: PlayerID, target_id: PlayerID): + self._night_elimination_manager.record_save(doctor_id, target_id) + + def _call_handler(self, player_actions: Dict[PlayerID, Action]): + current_handler = self._phase_handlers.get(self.detailed_phase) + if current_handler: + next_detailed_phase = current_handler(player_actions) + else: + raise ValueError(f"Unhandled detailed_phase: {self.detailed_phase}") + add_one_day = True if next_detailed_phase == DetailedPhase.DAY_START else False + self.set_next_phase(next_detailed_phase, add_one_day=add_one_day) + + def advance(self, player_actions: Dict[PlayerID, Action]): + self.confirm_action(player_actions) + # Process the incoming actions for the current phase. + self._call_handler(player_actions) + + # Loop through automatic state transitions (those that don't need agent actions) + # This continues until the game is over or requires new agent input. + # this logic is required since Environments in core.py requires that there are some players being ACTIVE to + # continue. Otherwise, if all INACTIVE the game is marked done. + while not self.get_active_player_ids() and not self.is_game_over(): + self._call_handler({}) + + # After all transitions, check for game over. 
@phase_handler(DetailedPhase.NIGHT_START)
def _handle_night_start(self, player_actions: Dict[PlayerID, Action]) -> DetailedPhase:
    """Open the night and ask the alive werewolves to vote on a target.

    Clears the action queue, pushes the public night-start event, sends the
    private vote request to the alive werewolves and starts the night voting
    protocol.

    Args:
        player_actions: Not read by this handler.

    Returns:
        ``DetailedPhase.NIGHT_AWAIT_ACTIONS``.
    """
    self._action_queue.clear()
    self.state.add_phase_divider(PhaseDivider.NIGHT_START)
    self.state.push_event(
        description=f"Night {self.state.day_count} begins!", event_name=EventName.NIGHT_START, public=True
    )

    # initialize werewolves voting
    self.state.add_phase_divider(PhaseDivider.NIGHT_VOTE_START)
    alive_werewolves = self.state.alive_players_by_role(RoleConst.WEREWOLF)
    # De-duplicate while keeping roster order. A set of string ids iterates in a
    # hash-seed-dependent order, which made the recorded ``visible_to`` list
    # differ from run to run.
    alive_werewolf_ids = list(dict.fromkeys(p.id for p in alive_werewolves))
    potential_targets = self.state.alive_players_by_team(Team.VILLAGERS)  # Target non-werewolves

    data = RequestWerewolfVotingDataEntry(
        valid_targets=[f"{p.id}" for p in potential_targets],
        alive_werewolve_player_ids=[f"{p.id}" for p in alive_werewolves],
        voting_protocol_name=self.night_voting.__class__.__name__,
        voting_protocol_rule=self.night_voting.rule,
        action_json_schema=json.dumps(VoteAction.schema_for_player()),
    )
    self.state.push_event(
        description=f"Wake up Werewolves. Your fellow alive werewolves are: {data.alive_werewolve_player_ids}. "
        f"Choose one target player to eliminate tonight. "
        f"The voting rule ({data.voting_protocol_name}): {data.voting_protocol_rule} "
        f"Who would you like to eliminate tonight? Options: {data.valid_targets}.",
        event_name=EventName.VOTE_REQUEST,
        public=False,
        visible_to=alive_werewolf_ids,
        data=data,
    )
    self.night_voting.begin_voting(
        state=self.state, alive_voters=alive_werewolves, potential_targets=potential_targets
    )
    return DetailedPhase.NIGHT_AWAIT_ACTIONS
@phase_handler(DetailedPhase.DAY_START)
def _handle_day_start(self, player_actions: Dict[PlayerID, Action]) -> DetailedPhase:
    """Announce the new day and start the discussion protocol.

    Args:
        player_actions: Not read by this handler.

    Returns:
        ``DetailedPhase.DAY_BIDDING_AWAIT`` when the discussion protocol is a
        ``BiddingDiscussion``, otherwise ``DetailedPhase.DAY_CHAT_AWAIT``.
    """
    state = self.state
    state.add_phase_divider(PhaseDivider.DAY_START)
    self._action_queue.clear()
    self.night_step = 0  # Reset night step counter

    state.push_event(description=f"Day {state.day_count} begins.", event_name=EventName.DAY_START, public=True)

    discussion = self.discussion
    state.push_event(
        description=f"Villagers, let's decide who to exile. The discussion rule is: {discussion.rule}",
        event_name=EventName.MODERATOR_ANNOUNCEMENT,
        public=True,
        data={"discussion_rule": discussion.rule},
    )

    state.add_phase_divider(PhaseDivider.DAY_CHAT_START)
    discussion.begin(state)

    # A bidding-based discussion opens with a bidding round before any chat.
    starts_with_bidding = isinstance(self.discussion, BiddingDiscussion)
    return DetailedPhase.DAY_BIDDING_AWAIT if starts_with_bidding else DetailedPhase.DAY_CHAT_AWAIT
@phase_handler(DetailedPhase.DAY_CHAT_AWAIT)
def _handle_day_chat_await(self, player_actions: Dict[PlayerID, Action]) -> DetailedPhase:
    """Feed the submitted chat actions to the discussion protocol and pick the next phase.

    Args:
        player_actions: Actions keyed by player id for this tick.

    Returns:
        ``DAY_CHAT_CONCLUDE`` once the discussion is over, ``DAY_BIDDING_AWAIT``
        when a bidding discussion reports it is back in its bidding phase, and
        ``DAY_CHAT_AWAIT`` after queueing and prompting the next speakers.
    """
    expected_speakers = self._action_queue.get(ChatAction)
    self._action_queue.clear()
    submitted = list(player_actions.values())
    self.discussion.process_actions(submitted, expected_speakers, self.state)

    if self.discussion.is_discussion_over(self.state):
        return DetailedPhase.DAY_CHAT_CONCLUDE

    # Discussion continues: a bidding discussion may need another bidding round first.
    if isinstance(self.discussion, BiddingDiscussion) and self.discussion.is_bidding_phase():
        return DetailedPhase.DAY_BIDDING_AWAIT

    upcoming = self.discussion.speakers_for_tick(self.state)
    self._action_queue.extend(ChatAction, upcoming)
    self.discussion.prompt_speakers_for_tick(self.state, upcoming)
    return DetailedPhase.DAY_CHAT_AWAIT
@phase_handler(DetailedPhase.DAY_VOTING_AWAIT)
def _handle_day_voting_await(self, player_actions: Dict[PlayerID, Action]) -> DetailedPhase:
    """Collect this tick's day votes and either conclude or prompt the next voters.

    Args:
        player_actions: Actions keyed by player id for this tick.

    Returns:
        ``DAY_VOTING_CONCLUDE`` when the voting protocol reports it is done,
        otherwise ``DAY_VOTING_AWAIT``.
    """
    expected_voters = self._action_queue.get(VoteAction)
    self.day_voting.collect_votes(player_actions, self.state, expected_voters)
    self._action_queue.clear()  # Clear previous voters

    if self.day_voting.done():
        return DetailedPhase.DAY_VOTING_CONCLUDE

    upcoming_voter_ids = self.day_voting.get_next_voters()
    self._action_queue.extend(VoteAction, upcoming_voter_ids)
    # ``or ()`` keeps a falsy result from being iterated.
    for voter_id in upcoming_voter_ids or ():
        voter = self.state.get_player_by_id(voter_id)
        if not (voter and voter.alive):
            continue
        self.state.push_event(
            description=self.day_voting.get_voting_prompt(self.state, voter_id),
            event_name=EventName.VOTE_REQUEST,
            public=False,
            visible_to=[voter_id],
            visible_in_ui=False,
        )
    return DetailedPhase.DAY_VOTING_AWAIT
def _determine_and_log_winner(self):
    """Decide the winning team and publish a single public GAME_END event.

    Villagers win when no werewolf-team player is alive; in every other case the
    werewolves are declared winners. Every member of the winning team scores 1
    and every member of the losing team scores 0, alive or not.
    """
    # A GAME_END event means the result was already announced; never log it twice.
    if self.state.get_event_by_name(EventName.GAME_END):
        return

    alive = self.state.alive_players()
    wolves = [p for p in alive if p.role.team == Team.WEREWOLVES]
    villagers = [p for p in alive if p.role.team == Team.VILLAGERS]

    # Full team rosters, fetched once and shared by both outcomes.
    village_ids = [p.id for p in self.state.get_players_by_team(Team.VILLAGERS)]
    wolf_ids = [p.id for p in self.state.get_players_by_team(Team.WEREWOLVES)]

    if not wolves:
        winner_team = Team.VILLAGERS.value
        winner_message = "Game Over: Villagers Win!"
        reason = "Reason: All werewolves exiled."
        winner_ids, loser_ids = village_ids, wolf_ids
    else:
        winner_team = Team.WEREWOLVES.value
        winner_message = "Game Over: Werewolves Win!"
        # The original message ended with an unmatched ")" after the villager count.
        reason = (
            "Reason: len(werewolves) >= len(villagers). "
            f"Final counts: len(werewolves)={len(wolves)}, len(villagers)={len(villagers)}."
        )
        winner_ids, loser_ids = wolf_ids, village_ids

    # Winners are inserted first so the printed score mapping keeps its original order.
    scores = {pid: 1 for pid in winner_ids}
    scores.update({pid: 0 for pid in loser_ids})

    data = GameEndResultsDataEntry(
        winner_team=winner_team,
        winner_ids=winner_ids,
        loser_ids=loser_ids,
        scores=scores,
        reason=reason,
        last_day=self.state.day_count,
        last_phase=self.state.phase.value,
        survivors_until_last_round_and_role={p.id: p.role.name.value for p in alive},
        all_players_and_role={p.id: p.role.name.value for p in self.state.players},
        elimination_info=self.state.get_elimination_info(),
        all_players=[p.model_dump() for p in self.state.players],
    )

    self.state.push_event(
        description=f"{winner_message}\n{reason}\nScores: {scores}\n"
        f"Survivors: {data.survivors_until_last_round_and_role}\n"
        f"All player roles: {data.all_players_and_role}",
        event_name=EventName.GAME_END,
        public=True,
        data=data,
    )
class NightEliminationManager:
    """
    Manages the state and resolution of nighttime eliminations.

    Doctor saves are recorded during the night with ``record_save`` and matched
    against the werewolves' chosen target in ``resolve_elimination``.
    """

    def __init__(self, state: GameState, reveal_level: RevealLevel = RevealLevel.ROLE):
        self._state = state
        self._reveal_level = reveal_level
        self._saves: Dict[PlayerID, List[PlayerID]] = {}  # Key: target_id, Value: [doctor_id]

    def reset(self):
        """Clears all recorded actions for the start of a new night."""
        self._saves.clear()

    def record_save(self, doctor_id: PlayerID, target_id: PlayerID):
        """Records a save action from a Doctor."""
        self._saves.setdefault(target_id, []).append(doctor_id)

    def _werewolf_ids(self) -> List[PlayerID]:
        """Return the ids of every player on the werewolf team."""
        return [player.id for player in self._state.get_players_by_team(Team.WEREWOLVES)]

    def resolve_elimination(self, werewolf_target_id: Optional[PlayerID]):
        """
        Resolves the werewolf attack against any saves, eliminates a player
        if necessary, and pushes the resulting events to the game state.

        Args:
            werewolf_target_id: The id chosen by the werewolves, or a falsy value
                when no target was elected.
        """
        if not werewolf_target_id:
            self._state.push_event(
                description="Last night, the werewolves did not reach a consensus (or no valid target was chosen)."
                " No one was eliminated by werewolves.",
                event_name=EventName.MODERATOR_ANNOUNCEMENT,
                public=False,
                # Recipients are given as player ids, as at every other push_event
                # call site; Player objects were passed here before.
                visible_to=self._werewolf_ids(),
            )
            self._state.push_event(
                description="Last night, No one was eliminated.",
                event_name=EventName.MODERATOR_ANNOUNCEMENT,
                public=True,
            )
            return

        target_player = self._state.get_player_by_id(werewolf_target_id)
        if not target_player:
            self._state.push_event(
                description=f'Last night, werewolves targeted player "{werewolf_target_id}", but this player '
                f"could not be found. No one was eliminated by werewolves.",
                event_name=EventName.ERROR,
                public=False,
                visible_to=self._werewolf_ids(),
            )
            self._state.push_event(
                description="Last night, no one was eliminated.",
                event_name=EventName.MODERATOR_ANNOUNCEMENT,
                public=True,
            )
            return

        if werewolf_target_id in self._saves:
            # The player was saved: tell the saving doctors privately, everyone else
            # only learns that nobody was eliminated.
            saving_doctor_ids = self._saves[werewolf_target_id]
            save_data = DoctorSaveDataEntry(saved_player_id=werewolf_target_id)
            self._state.push_event(
                description=f'Your heal on player "{werewolf_target_id}" was successful!',
                event_name=EventName.HEAL_RESULT,
                public=False,
                data=save_data,
                visible_to=saving_doctor_ids,
            )
            self._state.push_event(
                description="Last night, no one was eliminated.",
                event_name=EventName.MODERATOR_ANNOUNCEMENT,
                public=True,
            )
            return

        # The player is eliminated. The role name is read before the elimination call.
        original_role_name = target_player.role.name
        self._state.eliminate_player(werewolf_target_id)

        team = None
        role = None
        descriptions = [f'Last night, player "{werewolf_target_id}" was eliminated by werewolves.']
        if self._reveal_level == RevealLevel.ROLE:
            team = target_player.role.team
            role = target_player.role.name
            descriptions.append(f'Their role was a "{original_role_name}".')
        elif self._reveal_level == RevealLevel.TEAM:
            team = target_player.role.team
            descriptions.append(f'Their team was "{team}".')

        data = WerewolfNightEliminationDataEntry(
            eliminated_player_id=werewolf_target_id,
            eliminated_player_role_name=role,
            eliminated_player_team_name=team,
        )
        description = " ".join(descriptions)
        self._state.push_event(description=description, event_name=EventName.ELIMINATION, public=True, data=data)
def _player_id_pattern(all_player_ids: List[PlayerID]) -> str:
    """Build a whole-word regex alternation over the given player IDs.

    Alternatives are tried left to right, so longer IDs are listed first;
    otherwise an ID that is a prefix of another one (``gpt`` vs ``gpt-4``)
    would win the match.
    """
    longest_first = sorted(all_player_ids, key=len, reverse=True)
    return r"\b(" + "|".join(re.escape(pid) for pid in longest_first) + r")\b"


def _extract_player_ids_from_string(text: str, all_player_ids: List[PlayerID]) -> List[PlayerID]:
    """Extract the distinct player IDs mentioned in a string.

    Args:
        text: The text to scan. Empty text yields no mentions.
        all_player_ids: Every ID that may be mentioned.

    Returns:
        The mentioned IDs without duplicates, sorted for a deterministic order.
    """
    if not text or not all_player_ids:
        return []
    found_ids = set(re.findall(_player_id_pattern(all_player_ids), text))
    return sorted(found_ids)


def _find_mentioned_players(text: str, all_player_ids: List[PlayerID]) -> List[PlayerID]:
    """
    Finds player IDs mentioned in a string of text, ordered by their first appearance.
    Player IDs are treated as whole words.
    Example: "I think gpt-4 is suspicious, what do you think John?" -> ["gpt-4", "John"]
    """
    if not text or not all_player_ids:
        return []

    # Deduplicate while preserving order of first appearance
    ordered_mentioned_ids = []
    seen = set()
    for match in re.finditer(_player_id_pattern(all_player_ids), text):
        player_id = match.group(1)
        if player_id not in seen:
            ordered_mentioned_ids.append(player_id)
            seen.add(player_id)

    return ordered_mentioned_ids
@staticmethod
def get_last_mentioned(state: GameState) -> Tuple[List[PlayerID], str]:
    """Return the players mentioned in the most recent chat message, plus that message.

    Days are scanned from the highest key of ``state.history`` downwards. Within
    a day only the latest DISCUSSION entry carrying a ``ChatDataEntry`` is
    considered; the scan stops at the first day whose latest chat message is
    non-empty.
    """
    message = ""
    for day in sorted(state.history.keys(), reverse=True):
        latest_chat = next(
            (
                entry
                for entry in reversed(state.history[day])
                if entry.event_name == EventName.DISCUSSION and isinstance(entry.data, ChatDataEntry)
            ),
            None,
        )
        if latest_chat is not None:
            message = latest_chat.data.message
        if message:
            break
    mentioned = _find_mentioned_players(message, state.all_player_ids)
    return mentioned, message
class DiscussionProtocol(GameProtocol):
    """Drives the order/shape of daytime conversation."""

    @abstractmethod
    def begin(self, state: GameState) -> None:
        """Optional hook - initialise timers, round counters, etc."""

    @abstractmethod
    def speakers_for_tick(self, state: GameState) -> Sequence[PlayerID]:
        """
        Return the IDs that are *allowed to send a chat action* this tick.
        Return an empty sequence when the discussion phase is over.
        """

    @abstractmethod
    def is_discussion_over(self, state: GameState) -> bool:
        """Returns True if the entire discussion (including any preliminary phases like bidding) is complete."""

    @abstractmethod
    def reset(self) -> None:
        """Resets the protocol to its initial state."""

    def process_actions(self, actions: List[Action], expected_speakers: Sequence[PlayerID], state: GameState) -> None:
        """
        Processes a batch of actions. Only ``ChatAction`` instances are handled here.

        A chat from an expected speaker is published as a public DISCUSSION event
        together with the player IDs it mentions. A chat from anyone else is
        recorded as an out-of-turn DISCUSSION event visible only to its author.

        Args:
            actions: The actions submitted this tick.
            expected_speakers: IDs allowed to speak this tick.
            state: The game state that receives the events.
        """
        # The roster does not change while the batch is processed, so build it once.
        all_player_ids = [p.id for p in state.players]
        for act in actions:
            if not isinstance(act, ChatAction):
                continue
            if expected_speakers and act.actor_id in expected_speakers:
                # Mentions are only stored on accepted messages, so only compute them here.
                data = ChatDataEntry(
                    actor_id=act.actor_id,
                    message=act.message,
                    reasoning=act.reasoning,
                    mentioned_player_ids=_extract_player_ids_from_string(act.message, all_player_ids),
                    perceived_threat_level=act.perceived_threat_level,
                    action=act,
                )
                state.push_event(
                    description=f'Player "{act.actor_id}" (chat): {act.message}',
                    # Make public for general discussion
                    event_name=EventName.DISCUSSION,
                    public=True,
                    source=act.actor_id,
                    data=data,
                )
            else:
                state.push_event(
                    description=f'Player "{act.actor_id}" (chat, out of turn): {act.message}',
                    event_name=EventName.DISCUSSION,  # Or a specific "INVALID_CHAT" type
                    visible_to=[act.actor_id],
                    public=False,
                    source=act.actor_id,
                )

    def call_for_actions(self, speakers: Sequence[PlayerID]) -> List[str]:
        """prepare moderator call for action for each player."""
        return [f'Player "{speaker_id}", it is your turn to speak.' for speaker_id in speakers]

    def prompt_speakers_for_tick(self, state: GameState, speakers: Sequence[PlayerID]) -> None:
        """
        Sends each current speaker a private CHAT_REQUEST event carrying the chat action schema.

        Args:
            state: The game state that receives the events.
            speakers: IDs of the players being asked to speak this tick.
        """
        # The schema is identical for every speaker; serialise it once.
        chat_schema_json = json.dumps(ChatAction.schema_for_player())
        for speaker_id, call_for_action in zip(speakers, self.call_for_actions(speakers)):
            state.push_event(
                description=call_for_action,
                event_name=EventName.CHAT_REQUEST,
                public=False,
                visible_to=[speaker_id],
                data=RequestVillagerToSpeakDataEntry(action_json_schema=chat_schema_json),
                visible_in_ui=False,
            )
@property
def bids(self) -> Dict[PlayerID, int]:
    """Return a shallow copy of the current bids, keyed by player id.

    ``dict(mapping)`` is used rather than ``dict(**mapping)`` because keyword
    unpacking only accepts string keys.
    """
    return dict(self._bids)
@property
def rule(self) -> str:
    """Human-readable description of the urgency bidding rules, one rule per line."""
    return "\n".join(
        [
            "Urgency-based bidding. Players bid with an urgency level (0-4).",
            "0: I would like to observe and listen for now.",
            "1: I have some general thoughts to share with the group.",
            "2: I have something critical and specific to contribute to this discussion.",
            "3: It is absolutely urgent for me to speak next.",
            "4: Someone has addressed me directly and I must respond.",
            # The comma after this item was missing, which fused it with the
            # tie-break rule as "Highest bidder wins.Ties are broken ...".
            "Highest bidder wins.",
            "Ties are broken by the following priority: (1) players mentioned in the previous turn's chat, "
            "(2) the least spoken player, (3) round robin order of the player list.",
        ]
    )
Treated as 0.", + event_name=EventName.ERROR, + public=False, + visible_to=[bid.actor_id], + ) + + def process_incoming_bids(self, actions: List[Action], state: GameState) -> None: + for act in actions: + if isinstance(act, BidAction): + self.accept(act, state) + + def is_finished(self, state: GameState) -> bool: + # This bidding round is considered "finished" when all alive players have bid. + return len(self._bids) >= len(state.alive_players()) + + def outcome(self, state: GameState) -> list[str]: + if not self._bids: + # If no one bids, deterministically pick the first alive player to speak. + alive_players = state.alive_players() + return [alive_players[0].id] if alive_players else [] + + max_bid = max(self._bids.values()) + highest_bidders = sorted([pid for pid, amt in self._bids.items() if amt == max_bid]) + + if len(highest_bidders) == 1: + return highest_bidders + + # Tie-breaking logic + candidates = highest_bidders + + # Rule 1: Players mentioned in the last turn + mentioned_in_tie = [pid for pid in candidates if pid in self._mentioned_last_turn] + if mentioned_in_tie: + candidates = mentioned_in_tie + + if len(candidates) == 1: + return candidates + + # Rule 2: The least spoken individual + speech_counts = Counter( + entry.data.actor_id + for day_events in state.history.values() + for entry in day_events + if entry.event_name == EventName.DISCUSSION and isinstance(entry.data, ChatDataEntry) + ) + + candidate_speech_counts = {pid: speech_counts.get(pid, 0) for pid in candidates} + min_spoken = min(candidate_speech_counts.values()) + least_spoken_candidates = sorted([pid for pid, count in candidate_speech_counts.items() if count == min_spoken]) + + if len(least_spoken_candidates) == 1: + return least_spoken_candidates + + candidates = least_spoken_candidates + + # Rule 3: Round robin order of the player list in state + for pid in state.all_player_ids: + if pid in candidates: + return [pid] + + # This part should be unreachable if candidates is a subset of 
@register_protocol(default_params={"max_rounds": 2, "assign_random_first_speaker": True})
class RoundRobinDiscussion(DiscussionProtocol):
    """Discussion in which alive players speak one at a time in a fixed rotation."""

    def __init__(self, max_rounds: int = 1, assign_random_first_speaker: bool = True):
        """
        Args:
            max_rounds: Number of passes over the alive players per discussion.
            assign_random_first_speaker: When true, the player list is rotated by a
                random offset the first time ``begin`` runs and that rotation is
                reused afterwards. When false, the state's player-list order is used.
        """
        self._player_ids = None
        self._first_player_idx = None
        self._assign_random_first_speaker = assign_random_first_speaker
        self._queue: deque[str] = deque()
        self.max_rounds = max_rounds

    @property
    def display_name(self) -> str:
        return "Roundrobin"

    @property
    def rule(self) -> str:
        return f"Players speak in round-robin order for {self.max_rounds} round(s)."

    def reset(self) -> None:
        """Drop any pending speakers; the cached rotation is kept."""
        self._queue = deque()

    def begin(self, state):
        """Build the speaking queue for this discussion and announce the order."""
        if self._player_ids is None:
            # The rotation is fixed once, on first use.
            self._player_ids = deque(state.all_player_ids)
            if self._assign_random_first_speaker:
                offset = random.randrange(len(self._player_ids))
                self._player_ids.rotate(offset)

        alive_order = [pid for pid in self._player_ids if state.is_alive(pid)]
        self._queue = deque(alive_order * self.max_rounds)
        if self.max_rounds <= 0 or not self._queue:
            return

        order_entry = DiscussionOrderDataEntry(chat_order_of_player_ids=alive_order)
        announcement = (
            "Discussion phase begins. Players will speak in round-robin order. "
            f"Starting from player {alive_order[0]} with the following order: {alive_order} "
            f"for {self.max_rounds} round(s)."
        )
        state.push_event(
            description=announcement,
            event_name=EventName.DISCUSSION_ORDER,
            public=True,
            data=order_entry,
        )

    def speakers_for_tick(self, state):
        """Pop and return the next speaker as a one-element list, or [] when done."""
        if not self._queue:
            return []
        return [self._queue.popleft()]

    def is_discussion_over(self, state: GameState) -> bool:
        # Over once every queued turn has been handed out.
        return len(self._queue) == 0
@register_protocol()
class RandomOrderDiscussion(DiscussionProtocol):
    """Each alive player speaks once, in an order shuffled when the discussion begins."""

    def __init__(self):
        self._steps = 0
        self._iters = None

    @property
    def display_name(self) -> str:
        return "Random Order Discussion"

    @property
    def rule(self) -> str:
        return "Players speak in a random order for one full round."

    def reset(self) -> None:
        """Clear the shuffled order and the remaining-turn counter."""
        self._steps = 0
        self._iters = None

    def begin(self, state):
        """Shuffle the alive players and allow exactly one turn per player."""
        alive_ids = [player.id for player in state.alive_players()]
        shuffled = random.sample(alive_ids, k=len(alive_ids))
        self._iters = itertools.cycle(shuffled)
        self._steps = len(alive_ids)  # one full round
        if self._steps <= 0:
            return
        state.push_event(
            description="Discussion phase begins. Players will speak in random order.",
            event_name=EventName.PHASE_CHANGE,
            public=True,
        )

    def speakers_for_tick(self, state):
        """Return the next speaker as a one-element list, or [] when no turns remain."""
        if self._steps != 0:
            self._steps -= 1
            return [next(self._iters)]
        return []

    def is_discussion_over(self, state: GameState) -> bool:
        return self._steps == 0
@register_protocol()
class ParallelDiscussion(DiscussionProtocol):
    """
    Every alive player may talk on each of `ticks` chat turns.
    Useful when you want simultaneous / overlapping chat.
    """

    def __init__(self, ticks: int = 3):
        self._remaining = 0
        self.ticks = ticks

    @property
    def display_name(self) -> str:
        return "Parallel Discussion"

    @property
    def rule(self) -> str:
        return f"All players may speak simultaneously for {self.ticks} tick(s)."

    def reset(self) -> None:
        """Set the remaining tick count back to zero."""
        self._remaining = 0

    def begin(self, state):
        """Arm the tick counter and announce the phase when there is at least one tick."""
        self._remaining = self.ticks
        if self.ticks <= 0:
            return
        state.push_event(
            description="Parallel discussion phase begins. All players may speak.",
            event_name=EventName.PHASE_CHANGE,
            public=True,
        )

    def speakers_for_tick(self, state):
        """Consume one tick and return every alive player's id, or [] when none remain."""
        if self._remaining != 0:
            self._remaining -= 1
            return [player.id for player in state.alive_players()]
        return []

    def call_for_actions(self, speakers: Sequence[str]) -> List[str]:
        """Return the same prompt once per speaker."""
        # speakers_for_tick has already decremented the counter, hence the +1.
        prompt = (
            f"Parallel discussion: All designated players may speak now or remain silent. "
            f"({self._remaining + 1} speaking opportunities remaining, including this one)."
        )
        return [prompt] * len(speakers)

    def is_discussion_over(self, state: GameState) -> bool:
        return self._remaining == 0
class BiddingDiscussionPhase(StrEnum):
    # The two sub-phases a bidding-driven discussion alternates between.
    BIDDING_PHASE = "bidding_phase"
    SPEAKING_PHASE = "speaking_phase"


class BiddingDiscussion(DiscussionProtocol, ABC):
    """Base for discussion protocols that alternate between bidding and speaking."""

    def __init__(self, bidding: Optional[BiddingProtocol] = None):
        """
        Args:
            bidding: Protocol used to collect bids; a SimpleBiddingProtocol is
                created when a falsy value is given.
        """
        self._phase = BiddingDiscussionPhase.BIDDING_PHASE
        self._bidding = bidding or SimpleBiddingProtocol()

    @property
    def bidding(self):
        """The bidding protocol in use."""
        return self._bidding

    @property
    def phase(self):
        """The current sub-phase."""
        return self._phase

    def set_phase(self, phase: BiddingDiscussionPhase):
        self._phase = phase

    def is_bidding_phase(self):
        return BiddingDiscussionPhase.BIDDING_PHASE == self._phase

    def is_speaking_phase(self):
        return BiddingDiscussionPhase.SPEAKING_PHASE == self._phase
@register_protocol(default_params={"max_turns": 8, "bid_result_public": True})
class TurnByTurnBiddingDiscussion(BiddingDiscussion):
    """
    A discussion protocol where players bid for the right to speak each turn.
    This protocol manages the entire bid-speak-bid-speak loop.
    """

    def __init__(self, bidding: Optional[BiddingProtocol] = None, max_turns: int = 8, bid_result_public: bool = True):
        """
        Args:
            bidding: Bidding protocol used to pick each turn's speaker.
            max_turns: Maximum number of speaking turns per discussion.
            bid_result_public: Whether the bid-result event is pushed as public.
        """
        super().__init__(bidding=bidding)
        self.max_turns = max_turns
        self._turns_taken = 0
        self._speaker: Optional[str] = None
        self._all_passed = False
        self._bid_result_public = bid_result_public

    def reset(self) -> None:
        """Return to the bidding phase with no turns taken and no speaker selected."""
        self.bidding.reset()
        self.set_phase(BiddingDiscussionPhase.BIDDING_PHASE)
        self._turns_taken = 0
        self._speaker = None
        self._all_passed = False

    @property
    def display_name(self) -> str:
        return "Turn-by-turn Bidding Driven Discussion"

    @property
    def rule(self) -> str:
        return "\n".join(
            [
                f"Players bid for the right to speak each turn for up to {self.max_turns} turns.",
                f"**Bidding Rule:** {self.bidding.display_name}. {self.bidding.rule}",
                "If everyone bids 0, moderator will directly move on to day voting and no one speaks.",
            ]
        )

    def begin(self, state: GameState) -> None:
        self.reset()
        self.bidding.begin(state)  # Initial setup for the first bidding round

    def is_discussion_over(self, state: GameState) -> bool:
        """True after ``max_turns`` turns or once every player has bid 0."""
        return self._turns_taken >= self.max_turns or self._all_passed

    def speakers_for_tick(self, state: GameState) -> Sequence[PlayerID]:
        """All alive players while bidding; only the bid winner while speaking."""
        if self.is_discussion_over(state):
            return []

        if self.is_bidding_phase():
            return [p.id for p in state.alive_players()]
        elif self.is_speaking_phase():
            return [self._speaker] if self._speaker else []
        return []

    def _process_bidding(self, actions: List[Action], state: GameState) -> None:
        """Collect bids, default missing ones to 0, and select the next speaker."""
        self.bidding.process_incoming_bids(actions, state)

        # Players who did not bid (timed out) are assumed to bid 0. The public
        # ``bids`` accessor is used instead of probing the private ``_bids``.
        all_alive_player_ids = [p.id for p in state.alive_players()]
        submitted = self.bidding.bids
        for player_id in all_alive_player_ids:
            if player_id not in submitted:
                default_bid = BidAction(
                    actor_id=player_id, amount=0, day=state.day_count, phase=state.phase.value
                )
                self.bidding.accept(default_bid, state)

        bids = self.bidding.bids
        if len(bids) >= len(all_alive_player_ids) and all(amount == 0 for amount in bids.values()):
            self._all_passed = True
            state.push_event(
                description="All players passed on speaking. Discussion ends.",
                event_name=EventName.MODERATOR_ANNOUNCEMENT,
                public=True,
            )
            return

        winner_list = self.bidding.outcome(state)
        self._speaker = winner_list[0] if winner_list else None

        if self._speaker:
            data = BidResultDataEntry(
                winner_player_ids=[self._speaker],
                bid_overview=bids,
                mentioned_players_in_previous_turn=self.bidding.get_last_mentioned(state)[0],
            )
            overview_text = ", ".join(f"{k}: {v}" for k, v in bids.items())
            state.push_event(
                description=f"Player {self._speaker} won the bid and will speak next.\n"
                f"Bid overview - {overview_text}.",
                event_name=EventName.BID_RESULT,
                public=self._bid_result_public,
                data=data,
            )
            self.set_phase(BiddingDiscussionPhase.SPEAKING_PHASE)
        else:
            # No one to speak: the turn still counts, then bid again.
            self._turns_taken += 1
            if not self.is_discussion_over(state):
                self.bidding.begin(state)  # Prepare for next bidding round

    def process_actions(self, actions: List[Action], expected_speakers: Sequence[PlayerID], state: GameState) -> None:
        """Handle one tick: resolve bids while bidding, or record the speaker's chat."""
        if self.is_bidding_phase():
            self._process_bidding(actions, state)
        elif self.is_speaking_phase():
            # Process the chat action from the designated speaker
            super().process_actions(actions, expected_speakers, state)
            self._turns_taken += 1

            # After speaking, go back to bidding unless the discussion has ended.
            if not self.is_discussion_over(state):
                self.set_phase(BiddingDiscussionPhase.BIDDING_PHASE)
                self._speaker = None
                self.bidding.begin(state)  # Reset bids and find new mentioned players

    def prompt_speakers_for_tick(self, state: GameState, speakers: Sequence[PlayerID]) -> None:
        """Push a bid request while bidding; otherwise defer to the base speaking prompt."""
        if self.is_bidding_phase():
            data = {"action_json_schema": json.dumps(BidAction.schema_for_player())}
            state.push_event(
                description=(
                    f"A new round of discussion begins. Place bid for a chance to speak. "
                    f"{self.max_turns - self._turns_taken} turns left to speak."
                ),
                # NOTE(review): "BID_REQEUST" is the spelling this file uses for the EventName member.
                event_name=EventName.BID_REQEUST,
                public=True,
                data=data,
                visible_in_ui=False,
            )
        elif self.is_speaking_phase() and self._speaker:
            super().prompt_speakers_for_tick(state, speakers)
@register_protocol(default_params={"max_rounds": 2, "bid_result_public": True})
class RoundByRoundBiddingDiscussion(BiddingDiscussion):
    """
    A discussion protocol where players bid at the start of each round to
    determine the speaking order for that round.

    In each of the N rounds:
    1. A bidding phase occurs where all alive players submit a bid (0-4).
    2. The speaking order is determined by sorting players by their bid amount
       (descending) and then by player ID (ascending) as a tie-breaker.
    3. A speaking phase occurs where each player speaks once according to the
       determined order.
    """

    def __init__(self, bidding: Optional[BiddingProtocol] = None, max_rounds: int = 2, bid_result_public: bool = True):
        """
        Args:
            bidding: The bidding protocol to use for determining speaking order.
            max_rounds: The total number of discussion rounds.
            bid_result_public: Whether to make the bidding results public.
        """
        super().__init__(bidding=bidding)
        self.max_rounds = max_rounds
        self._bid_result_public = bid_result_public
        self._current_round = 0
        self._speaking_queue: deque[str] = deque()
        self.reset()

    def reset(self) -> None:
        """Resets the protocol to its initial state."""
        self.bidding.reset()
        self.set_phase(BiddingDiscussionPhase.BIDDING_PHASE)
        self._current_round = 0
        self._speaking_queue = deque()

    @property
    def display_name(self) -> str:
        return "Round-by-round Bidding Driven Discussion"

    @property
    def rule(self) -> str:
        """A string describing the discussion rule in effect."""
        return "\n".join(
            [
                "Players speak in an order determined by bidding at the beginning of each round. "
                f"There will be {self.max_rounds} round(s) per day.",
                "In each round, all players may speak once.",
                f"**Bidding Rule:** {self.bidding.display_name}. {self.bidding.rule}",
            ]
        )

    def begin(self, state: GameState) -> None:
        """Initializes the protocol for the first round."""
        self.reset()
        self.bidding.begin(state)

    def is_discussion_over(self, state: GameState) -> bool:
        """Checks if all rounds have been completed."""
        return self._current_round >= self.max_rounds

    def speakers_for_tick(self, state: GameState) -> Sequence[PlayerID]:
        """Returns the players who are allowed to act in the current tick.

        In the speaking phase this pops the next player off the speaking queue.
        """
        if self.is_discussion_over(state):
            return []

        if self.is_bidding_phase():
            # In the bidding phase, all alive players can bid.
            return [p.id for p in state.alive_players()]
        elif self.is_speaking_phase():
            # In the speaking phase, the next player in the queue speaks.
            return [self._speaking_queue.popleft()] if self._speaking_queue else []
        return []

    def _process_bidding(self, actions: List[Action], state: GameState) -> None:
        """Collect bids, default missing ones to 0, and publish the round's speaking order."""
        self.bidding.process_incoming_bids(actions, state)

        # Assume a bid of 0 for any players who timed out. The public ``bids``
        # accessor is used unconditionally, as the ordering step below does.
        submitted = self.bidding.bids
        for player in state.alive_players():
            if player.id not in submitted:
                default_bid = BidAction(
                    actor_id=player.id, amount=0, day=state.day_count, phase=state.phase.value
                )
                self.bidding.accept(default_bid, state)

        # Sort by bid amount (desc) and then player ID (asc).
        sorted_bidders = sorted(self.bidding.bids.items(), key=lambda item: (-item[1], item[0]))
        self._speaking_queue = deque(player_id for player_id, _ in sorted_bidders)

        # Announce the speaking order for the round.
        data = DiscussionOrderDataEntry(chat_order_of_player_ids=list(self._speaking_queue))
        speaking_order_text = ", ".join(f"{pid} ({amount})" for pid, amount in sorted_bidders)
        state.push_event(
            description=f"Bidding for round {self._current_round + 1} has concluded. The speaking order, "
            f"with bid amounts in parentheses, is: {speaking_order_text}.",
            event_name=EventName.BID_RESULT,
            public=self._bid_result_public,
            data=data,
        )

        # Transition to the speaking phase.
        self.set_phase(BiddingDiscussionPhase.SPEAKING_PHASE)

    def process_actions(self, actions: List[Action], expected_speakers: Sequence[PlayerID], state: GameState) -> None:
        """Processes incoming actions from players."""
        if self.is_bidding_phase():
            self._process_bidding(actions, state)
        elif self.is_speaking_phase():
            # Process the chat action from the current speaker.
            super().process_actions(actions, expected_speakers, state)

            # The round is over once the speaking queue has been drained.
            if not self._speaking_queue:
                self._current_round += 1
                state.push_event(
                    description=f"End of discussion round {self._current_round}.",
                    event_name=EventName.PHASE_CHANGE,
                    public=True,
                )

                # If rounds remain, prepare for the next round's bidding.
                if not self.is_discussion_over(state):
                    self.set_phase(BiddingDiscussionPhase.BIDDING_PHASE)
                    self.bidding.begin(state)

    def prompt_speakers_for_tick(self, state: GameState, speakers: Sequence[PlayerID]) -> None:
        """Prompts the active players for their next action."""
        if self.is_bidding_phase():
            data = {"action_json_schema": json.dumps(BidAction.schema_for_player())}
            state.push_event(
                description=(
                    f"Round {self._current_round + 1} of {self.max_rounds} begins. "
                    "Place your bid to determine speaking order."
                ),
                # NOTE(review): "BID_REQEUST" is the spelling this file uses for the EventName member.
                event_name=EventName.BID_REQEUST,
                public=True,
                data=data,
                visible_in_ui=False,
            )
        elif self.is_speaking_phase():
            # The default prompt from the base class is sufficient for speaking.
            super().prompt_speakers_for_tick(state, speakers)
from typing import Any, Callable, Dict, Optional, Type

# The unified, flat registry. Maps class names to class objects and default params.
PROTOCOL_REGISTRY: Dict[str, Dict[str, Any]] = {}


def register_protocol(default_params: Optional[Dict[str, Any]] = None) -> Callable[[Type], Type]:
    """
    A decorator to register a protocol class in the central unified registry.
    The protocol is registered using its class name.

    Args:
        default_params: Keyword arguments applied by ``create_protocol`` unless the
            configuration overrides them. A copy is stored, so later mutation of
            the caller's dict does not leak into the registry.

    Raises:
        TypeError: (from the returned decorator) if a class with the same name is
            already registered.
    """
    stored_defaults = dict(default_params) if default_params else {}

    def decorator(cls: Type) -> Type:
        name = cls.__name__
        if name in PROTOCOL_REGISTRY:
            raise TypeError(f"Protocol '{name}' is already registered.")

        PROTOCOL_REGISTRY[name] = {"class": cls, "default_params": stored_defaults}
        return cls

    return decorator


def create_protocol(config: Optional[Dict[str, Any]], default_name: Optional[str] = None) -> Any:
    """
    Factory function to recursively create protocol instances from a configuration dictionary.

    Args:
        config: Mapping with an optional ``"name"`` (registered class name) and optional
            ``"params"`` (constructor keyword arguments). May be empty or ``None`` when
            ``default_name`` is given.
        default_name: Class name used when ``config`` is empty or has no ``"name"``.

    Returns:
        An instance of the registered class, built from its registered defaults
        overridden by ``config["params"]``.

    Raises:
        ValueError: If no name can be determined or the name is not registered.
    """
    if not config:
        if not default_name:
            # If no config and no default, we cannot proceed.
            raise ValueError("Cannot create protocol from an empty configuration without a default name.")
        config = {"name": default_name}

    # Fallback to default_name if 'name' is not in the config
    name = config.get("name", default_name)
    if not name:
        raise ValueError("Protocol name must be provided in config or as a default.")

    # An explicit `"params": None` is treated the same as a missing key.
    params = config.get("params") or {}

    protocol_info = PROTOCOL_REGISTRY.get(name)
    if not protocol_info:
        raise ValueError(f"Protocol '{name}' not found in the registry.")

    # Start with the protocol's defaults, then override with config params.
    merged_params = {**protocol_info["default_params"], **params}

    # A dict value that has a "name" key is itself a protocol config and is
    # instantiated recursively; every other value is passed through unchanged.
    final_params = {
        param_name: create_protocol(param_value)
        if isinstance(param_value, dict) and "name" in param_value
        else param_value
        for param_name, param_value in merged_params.items()
    }

    return protocol_info["class"](**final_params)
class TieBreak(StrEnum):
    # Randomly select from top ties.
    RANDOM = "random"

    # Tie result in no one elected.
    NO_EXILE = "no_elected"


ABSTAIN_VOTE = "-1"


class Ballot:
    """Holds one vote per voter and resolves the elected target."""

    def __init__(self, tie_selection: TieBreak = TieBreak.RANDOM):
        self._tie_selection = tie_selection
        self._ballots: Dict[PlayerID, PlayerID] = {}

    def reset(self):
        """Discard every recorded vote."""
        self._ballots = {}

    def add_vote(self, voter_id: PlayerID, target_id: PlayerID):
        """Records a vote from a voter for a target, replacing any earlier vote by that voter."""
        self._ballots[voter_id] = target_id

    def get_all_votes(self) -> Dict[PlayerID, PlayerID]:
        """Returns a copy of all recorded ballots."""
        return dict(self._ballots)

    def get_tally(self) -> Counter:
        """Returns a Counter of votes for each target, excluding abstained votes."""
        tally = Counter()
        for target in self._ballots.values():
            if target is None or target == ABSTAIN_VOTE:
                continue
            tally[target] += 1
        return tally

    def get_elected(self, potential_targets: List[PlayerID]) -> Optional[PlayerID]:
        """
        Tallies the votes and determines the elected player based on the tie-breaking rule.

        Returns None when the rule is not RANDOM and there is a tie or no valid vote.
        """
        ranked = self.get_tally().most_common()
        pick_randomly = self._tie_selection == TieBreak.RANDOM

        if not ranked:
            # No valid votes were cast.
            if pick_randomly and potential_targets:
                return random.choice(potential_targets)
            return None

        best_count = ranked[0][1]
        leaders = [target for target, votes in ranked if votes == best_count]
        if len(leaders) == 1:
            return leaders[0]

        # It's a tie.
        return random.choice(leaders) if pick_randomly else None
@register_protocol()
class SimultaneousMajority(VotingProtocol):
    """All expected voters vote in the same tick; the target with the most votes is elected."""

    def __init__(self, tie_break=TieBreak.RANDOM):
        """
        Args:
            tie_break: A TieBreak member or its string value (e.g. from a config dict).

        Raises:
            ValueError: If ``tie_break`` is not a valid TieBreak value.
        """
        # Validate and normalise before the value is used. Looking the value up
        # avoids `x not in TieBreak`, which raises TypeError for plain strings on
        # Python < 3.12 with a standard Enum.
        try:
            tie_break = TieBreak(tie_break)
        except ValueError as err:
            raise ValueError(
                f"Invalid tie_break value: {tie_break}. Must be one of {[t.value for t in TieBreak]}."
            ) from err

        self._expected_voters: List[PlayerID] = []
        self._potential_targets: List[PlayerID] = []
        self._current_game_state: Optional[GameState] = None  # To store state from begin_voting
        self._elected: Optional[PlayerID] = None
        self._done_tallying = False
        self._tie_break = tie_break
        self._ballot = Ballot(tie_selection=self._tie_break)

    def reset(self) -> None:
        """Clear votes, voter/target lists, the stored state reference and the cached result."""
        self._ballot.reset()
        self._expected_voters = []
        self._potential_targets = []
        self._current_game_state = None
        self._elected = None
        self._done_tallying = False

    @property
    def display_name(self) -> str:
        return "Simultaneous Majority Voting"

    @property
    def rule(self) -> str:
        rule = "Player with the most votes is exiled. "
        if self._tie_break == TieBreak.RANDOM:
            rule += (
                "Ties result in random selection amongst the top ties. "
                "If no valid vote available (if all casted abstained votes), "
                "will result in random elimination of one player."
            )
        elif self._tie_break == TieBreak.NO_EXILE:
            rule += "Ties result in no exile."
        return rule

    def begin_voting(self, state: GameState, alive_voters: Sequence[Player], potential_targets: Sequence[Player]):
        """Start a fresh vote, keeping only voters and targets that are alive right now."""
        self._ballot.reset()
        self._expected_voters = [p.id for p in alive_voters if p.alive]
        self._potential_targets = [p.id for p in potential_targets if p.alive]
        self._current_game_state = state  # Store the game state reference

    def collect_votes(self, player_actions: Dict[PlayerID, Action], state: GameState, expected_voters: List[PlayerID]):
        """Record the actions of expected voters; anyone who did not act abstains."""
        for actor_id, action in player_actions.items():
            if actor_id in expected_voters:
                self.collect_vote(action, state)

        # For any expected voter who didn't act, record an abstain vote.
        all_votes = self._ballot.get_all_votes()
        for player_id in expected_voters:
            if player_id not in all_votes:
                self._ballot.add_vote(player_id, ABSTAIN_VOTE)

    def collect_vote(self, vote_action: Action, state: GameState):
        """Validate a single action and record it as a vote or an abstention, logging the result."""
        actor_player = state.get_player_by_id(vote_action.actor_id)
        if not isinstance(vote_action, VoteAction):
            state.push_event(
                description=f'Invalid vote attempt by player "{vote_action.actor_id}". '
                f"Not a VoteAction; submitted {vote_action.__class__.__name__} instead. "
                f"Cast as abstained vote.",
                event_name=EventName.ERROR,
                public=False,
                visible_to=self._expected_voters,
                data={},
            )
            self._ballot.add_vote(vote_action.actor_id, ABSTAIN_VOTE)
            return

        # The data entry type depends on the game phase.
        if state.phase == Phase.NIGHT:
            data_entry_class = WerewolfNightVoteDataEntry
        else:
            data_entry_class = DayExileVoteDataEntry

        data = data_entry_class(
            actor_id=vote_action.actor_id,
            target_id=vote_action.target_id,
            reasoning=vote_action.reasoning,
            perceived_threat_level=vote_action.perceived_threat_level,
            action=vote_action,
        )

        # Voter must be expected and alive at the moment of casting vote
        if not (actor_player and actor_player.alive and vote_action.actor_id in self._expected_voters):
            state.push_event(
                description=f"Invalid vote attempt by {vote_action.actor_id}.",
                event_name=EventName.ERROR,
                public=False,
                data=data,
            )
            return

        # Prevent re-voting
        if vote_action.actor_id in self._ballot.get_all_votes():
            state.push_event(
                description=f'Invalid vote attempt by "{vote_action.actor_id}", already voted.',
                event_name=EventName.ERROR,
                public=False,
                visible_to=self._expected_voters,
                data=data,
            )
            return

        if vote_action.target_id in self._potential_targets:
            self._ballot.add_vote(vote_action.actor_id, vote_action.target_id)
            state.push_event(
                description=f'Player "{data.actor_id}" voted to eliminate "{data.target_id}". ',
                event_name=EventName.VOTE_ACTION,
                public=False,
                visible_to=self._expected_voters,
                data=data,
                source=vote_action.actor_id,
            )
        else:
            # A vote for an invalid target counts as an abstention.
            self._ballot.add_vote(vote_action.actor_id, ABSTAIN_VOTE)
            state.push_event(
                description=f'Invalid vote attempt by "{vote_action.actor_id}".',
                event_name=EventName.ERROR,
                public=False,
                visible_to=self._expected_voters,
                data=data,
            )

    def get_voting_prompt(self, state: GameState, player_id: PlayerID) -> str:
        """Build the prompt listing the targets that are still alive, plus the abstain option."""
        target_options = []
        for p_id in self._potential_targets:
            target = state.get_player_by_id(p_id)
            if target and target.alive:
                target_options.append(p_id)
        return f'Player "{player_id}", please cast your vote. Options: {target_options} or Abstain ("{ABSTAIN_VOTE}").'

    def get_current_tally_info(self, state: GameState) -> Dict[PlayerID, int]:
        return self._ballot.get_tally()

    def get_next_voters(self) -> List[PlayerID]:
        # For simultaneous, all expected voters vote at once, and only once.
        cast = self._ballot.get_all_votes()
        return [voter for voter in self._expected_voters if voter not in cast]

    def done(self) -> bool:
        """True once every expected voter has a recorded ballot (vote or abstention)."""
        cast = self._ballot.get_all_votes()
        return all(voter in cast for voter in self._expected_voters)

    def get_valid_targets(self) -> List[PlayerID]:
        # Return a copy of targets that were valid (alive) at the start of voting.
        return list(self._potential_targets)

    def get_elected(self) -> PlayerID | None:
        """
        Return the elected player, tallying once and caching the result.

        Raises:
            Exception: If not every expected voter has voted yet.
        """
        if not self.done():
            raise Exception("Voting is not done yet.")
        # Tally only once so a random tie-break is not re-rolled on later calls.
        if self._elected is None and not self._done_tallying:
            self._elected = self._ballot.get_elected(self._potential_targets)
            self._done_tallying = True
        return self._elected
+ self._player_ids = deque(state.all_player_ids) + if self._assign_random_first_voter: + self._player_ids.rotate(random.randrange(len(self._player_ids))) + alive_voter_ids = [p.id for p in alive_voters] + alive_voter_ids_set = set(alive_voter_ids) + self._ballot.reset() + self._expected_voters = [pid for pid in self._player_ids if pid in alive_voter_ids_set] + self._potential_targets = [p.id for p in potential_targets] + # The order of voting can be based on player ID, a random shuffle, or the order in alive_voters + # For simplicity, using the order from alive_voters. + self._voter_queue = list(self._expected_voters) + self._current_voter_index = 0 + self._current_game_state = state # Store the game state reference + + if self._expected_voters: + data = VoteOrderDataEntry(vote_order_of_player_ids=self._expected_voters) + state.push_event( + description=f"Voting starts from player {self._expected_voters[0]} " + f"with the following order: {self._expected_voters}", + event_name=EventName.VOTE_ORDER, + public=False, + visible_to=alive_voter_ids, + data=data, + ) + + def get_voting_prompt(self, state: GameState, player_id: PlayerID) -> str: + """ + Generates a prompt for the given player_id, assuming it's their turn. + """ + current_tally = self.get_current_tally_info(state) + + # Sort for consistent display + tally_str_parts = [] + for target_id, votes in sorted(current_tally.items(), key=lambda x: x[1], reverse=True): + tally_str_parts.append(f"{target_id}: {votes} vote(s)") + + tally_str = "; ".join(tally_str_parts) if tally_str_parts else "No votes cast yet." + + options_str_parts = [] + for p_target in state.alive_players(): # Iterate through all alive players for options + if p_target.id in self._potential_targets: + options_str_parts.append(f"{p_target.id}") + options_str = ", ".join(options_str_parts) + + return ( + f"{player_id}, it is your turn to vote. " + f"Current tally: {tally_str}. " + f"Options: {options_str} or Abstain (vote for {ABSTAIN_VOTE})." 
def collect_votes(self, player_actions: Dict[PlayerID, Action], state: GameState, expected_voters: List[PlayerID]):
    """Process the action of the single player whose turn it is.

    Only the first entry of ``expected_voters`` is considered. If that player
    submitted nothing (or a falsy action), a ``NoOpAction`` is recorded on
    their behalf. Does nothing once voting is complete.

    Args:
        player_actions: Actions submitted this step, keyed by player id.
        state: Current game state, forwarded to ``collect_vote``.
        expected_voters: Ids expected to act; normally exactly one id.
    """
    if self.done():
        return

    if not expected_voters:
        # Not expected while voting is still open; move the turn forward so
        # the round cannot stall.
        self._current_voter_index += 1
        return

    voter_id = expected_voters[0]
    submitted = player_actions.get(voter_id)
    if not submitted:
        # Timeout / missing action: record a NoOp for the expected voter.
        submitted = NoOpAction(actor_id=voter_id, day=state.day_count, phase=state.phase)
    self.collect_vote(submitted, state)
+ return + + if self.done(): + state.push_event( + description=f"Action ({vote_action.kind}) received from {vote_action.actor_id}, " + f"but voting is already complete.", + event_name=EventName.ERROR, + public=False, + visible_to=[vote_action.actor_id], + ) + return + + expected_voter_id = self._voter_queue[self._current_voter_index] + if vote_action.actor_id != expected_voter_id: + state.push_event( + description=f"Action ({vote_action.kind}) received from {vote_action.actor_id}, " + f"but it is {expected_voter_id}'s turn.", + event_name=EventName.ERROR, + public=False, # Or public if strict turn enforcement is announced + visible_to=[vote_action.actor_id, expected_voter_id], + ) + return + + actor_player = next((p for p in state.players if p.id == vote_action.actor_id), None) + if actor_player and actor_player.alive: + description_for_event = "" + involved_players_list = [vote_action.actor_id] # Actor is always involved + data = None + if isinstance(vote_action, NoOpAction): + self._ballot.add_vote(vote_action.actor_id, ABSTAIN_VOTE) # Treat NoOp as abstain + description_for_event = f"{vote_action.actor_id} chose to NoOp (treated as Abstain)." + + elif isinstance(vote_action, VoteAction): # This must be true if not NoOpAction + target_display: str + recorded_target_id = vote_action.target_id + if vote_action.target_id != ABSTAIN_VOTE and vote_action.target_id not in self._potential_targets: + # Invalid target chosen for VoteAction + state.push_event( + description=f"{vote_action.actor_id} attempted to vote for {vote_action.target_id} " + f"(invalid target). 
Vote recorded as Abstain.", + event_name=EventName.ERROR, + public=False, + visible_to=[vote_action.actor_id], + ) + recorded_target_id = ABSTAIN_VOTE # Treat invalid target as abstain + target_display = f"Invalid Target ({vote_action.target_id}), recorded as Abstain" + elif vote_action.target_id == ABSTAIN_VOTE: + # Explicit Abstain via VoteAction + target_display = "Abstain" + # recorded_target_id is already ABSTAIN_VOTE + else: + # Valid target chosen for VoteAction + target_display = f"{vote_action.target_id}" + involved_players_list.append(vote_action.target_id) # Add valid target to involved + + self._ballot.add_vote(vote_action.actor_id, recorded_target_id) + description_for_event = f"{vote_action.actor_id} has voted for {target_display}." + + # Add data entry for the vote + data_entry_class = DayExileVoteDataEntry if state.phase == Phase.DAY else WerewolfNightVoteDataEntry + data = data_entry_class( + actor_id=vote_action.actor_id, + target_id=recorded_target_id, + reasoning=vote_action.reasoning, + perceived_threat_level=vote_action.perceived_threat_level, + action=vote_action, + ) + + state.push_event( + description=description_for_event, + event_name=EventName.VOTE_ACTION, + public=False, + visible_to=self._expected_voters, + data=data, + source=vote_action.actor_id, + ) + self._current_voter_index += 1 + else: # Player not found, not alive, or (redundantly) not their turn + state.push_event( + description=f"Invalid action ({vote_action.kind}) attempt by {vote_action.actor_id} (player not found," + f" not alive, or not their turn). 
def get_elected(self) -> Optional[PlayerID]:
    """Return the outcome of the finished round, tallying at most once.

    The ballot is asked for the winner among the potential targets on the
    first call after voting completes; that answer is cached and returned on
    every later call.

    Raises:
        Exception: if voting has not finished yet.
    """
    voting_finished = self.done()
    if not voting_finished:
        raise Exception("Voting is not done yet.")

    already_resolved = self._done_tallying or self._elected is not None
    if not already_resolved:
        self._elected = self._ballot.get_elected(self._potential_targets)
        self._done_tallying = True
    return self._elected
class ActionDataMixin(BaseModel):
    """
    A mixin for action-related DataEntry models.
    Includes the actor performing the action and their private reasoning.

    Fields tagged with an ``access`` level in ``json_schema_extra`` are only
    included in an event view when the viewer's ``DataAccessLevel`` is at least
    that level (see ``Event.view_by_access``). The tag is passed through
    ``json_schema_extra`` explicitly because arbitrary extra keyword arguments
    to ``Field`` are deprecated in Pydantic v2.
    """

    actor_id: PlayerID
    reasoning: Optional[str] = Field(
        default=None,
        description="Private reasoning for moderator analysis.",
        json_schema_extra={"access": DataAccessLevel.PERSONAL},
    )
    perceived_threat_level: Optional[PerceivedThreatLevel] = Field(
        default=PerceivedThreatLevel.SAFE,
        json_schema_extra={"access": DataAccessLevel.PERSONAL},
    )
    action: Optional[BaseAction] = Field(
        default=None,
        json_schema_extra={"access": DataAccessLevel.PERSONAL},
    )
def view_by_access(self, user_level: DataAccessLevel) -> PlayerEventView:
    """Build the view of this event for a viewer with ``user_level`` access.

    When the payload is an ``ActionDataMixin``, each field is dumped only if
    it carries no ``json_schema_extra`` or if ``user_level`` is at least the
    ``access`` value stored there (defaulting to ``DataAccessLevel.PUBLIC``).
    Any other payload is passed through unchanged.
    """
    payload = self.data
    if isinstance(payload, ActionDataMixin):
        visible_fields = set()
        hidden_fields = set()
        for field_name, field_info in type(payload).model_fields.items():
            extra = field_info.json_schema_extra
            # Fields without extra metadata are always visible.
            permitted = (not extra) or user_level >= extra.get("access", DataAccessLevel.PUBLIC)
            if permitted:
                visible_fields.add(field_name)
            else:
                hidden_fields.add(field_name)
        payload = payload.model_dump(include=visible_fields, exclude=hidden_fields)

    return PlayerEventView(
        day=self.day,
        phase=self.phase,
        detailed_phase=self.detailed_phase,
        event_name=self.event_name,
        description=self.description,
        data=payload,
        source=self.source,
        created_at=self.created_at,
    )
PlayerID + team: Team + role: RoleConst + rule_of_role: str + + +class SetNewPhaseDataEntry(DataEntry): + new_detailed_phase: DetailedPhase + + +class PhaseDividerDataEntry(DataEntry): + divider_type: PhaseDivider + + +# --- Request for Action Data Entries --- +class RequestForActionDataEntry(DataEntry): + action_json_schema: str + + +class RequestDoctorSaveDataEntry(RequestForActionDataEntry): + valid_candidates: List[PlayerID] + + +class RequestSeerRevealDataEntry(RequestForActionDataEntry): + valid_candidates: List[PlayerID] + + +class RequestWerewolfVotingDataEntry(RequestForActionDataEntry): + valid_targets: List[PlayerID] + alive_werewolve_player_ids: List[PlayerID] + voting_protocol_name: str + voting_protocol_rule: str + + +class RequestVillagerToSpeakDataEntry(RequestForActionDataEntry): + pass + + +# --- Action and Result Data Entries --- +class SeerInspectResultDataEntry(DataEntry): + actor_id: PlayerID + target_id: PlayerID + role: Optional[RoleConst] + team: Optional[Team] + + +class TargetedActionDataEntry(ActionDataMixin, DataEntry): + target_id: PlayerID + + +class SeerInspectActionDataEntry(TargetedActionDataEntry): + """This records the Seer's choice of target to inspect.""" + + +class DoctorHealActionDataEntry(TargetedActionDataEntry): + """This records the Doctor's choice of target to heal.""" + + +class WerewolfNightVoteDataEntry(TargetedActionDataEntry): + """Records a werewolf's vote, including private reasoning.""" + + +class DayExileVoteDataEntry(TargetedActionDataEntry): + """Records a player's vote to exile, including private reasoning.""" + + +class DoctorSaveDataEntry(DataEntry): + """This records that a player was successfully saved by a doctor.""" + + saved_player_id: PlayerID + + +class VoteOrderDataEntry(DataEntry): + vote_order_of_player_ids: List[PlayerID] + + +class WerewolfNightEliminationElectedDataEntry(DataEntry): + """This record the elected elimination target by werewolves.""" + + elected_target_player_id: PlayerID + + 
class WerewolfNightEliminationDataEntry(DataEntry):
    """Records the player who ended up eliminated by the werewolves, i.e. without a doctor save.

    The role and team fields are optional and default to ``None``.
    NOTE(review): presumably they stay ``None`` when the reveal level hides
    them; confirm against the code that builds this entry.
    """

    eliminated_player_id: PlayerID
    eliminated_player_role_name: Optional[RoleConst] = None
    eliminated_player_team_name: Optional[Team] = None
def get_raw_observation(kaggle_observation) -> WerewolfObservationModel:
    """Rebuild the typed observation model from a Kaggle observation.

    Args:
        kaggle_observation: Kaggle observation that holds the dumped
            observation fields under ``ObsKeys.RAW_OBSERVATION``.

    Returns:
        A ``WerewolfObservationModel`` constructed from the stored fields.
    """
    stored_fields = kaggle_observation[ObsKeys.RAW_OBSERVATION]
    return WerewolfObservationModel(**stored_fields)
import EventName, Phase, RevealLevel, RoleConst, Team +from .records import ( + Event, + PlayerEventView, + RequestDoctorSaveDataEntry, + RequestSeerRevealDataEntry, + SeerInspectResultDataEntry, +) + +logger = logging.getLogger(__name__) + + +class Role(BaseRole): + model_config = ConfigDict(use_enum_values=True) + + name: RoleConst = Field(..., frozen=True) + team: Team + night_priority: int = 100 # lower number acts earlier + descriptions: str + + +class Werewolf(Role): + name: RoleConst = RoleConst.WEREWOLF + team: Team = Team.WEREWOLVES + night_priority: int = 2 + descriptions: str = "Each night, collaborates with fellow werewolves to vote on eliminating one player." + + +class Villager(Role): + name: RoleConst = RoleConst.VILLAGER + team: Team = Team.VILLAGERS + descriptions: str = "No special abilities. Participates in the daily vote to eliminate a suspected werewolf." + + +class DoctorDescription: + ALLOW_SELF_SAVE = "Each night, may protect one player from a werewolf attack. Doctor is allowed to save themselves during night time." + NO_SELF_SAVE = "Each night, may protect one player from a werewolf attack. Doctor is NOT allowed to save themselves during night time." + NO_CONSECUTIVE_SAVE = " Doctor is NOT allowed to save the same player on consecutive nights." 
class Doctor(Role):
    """Village-team role that may protect one player each night.

    ``allow_self_save`` controls whether the doctor may target themselves and
    ``allow_consecutive_saves`` whether the same player may be protected on
    back-to-back nights. When ``descriptions`` is left empty it is generated
    from those two flags.
    """

    name: RoleConst = RoleConst.DOCTOR
    team: Team = Team.VILLAGERS
    allow_self_save: bool = False
    allow_consecutive_saves: bool = True
    descriptions: str = ""

    @model_validator(mode="after")
    def set_descriptions_default(self) -> "Doctor":
        """Fill in the rule text when no explicit description was supplied."""
        if self.descriptions != "":
            return self
        rule_text = (
            DoctorDescription.ALLOW_SELF_SAVE if self.allow_self_save else DoctorDescription.NO_SELF_SAVE
        )
        if not self.allow_consecutive_saves:
            rule_text += DoctorDescription.NO_CONSECUTIVE_SAVE
        self.descriptions = rule_text
        return self

    @on_event(EventName.NIGHT_START)
    def on_night_starts(self, me: BasePlayer, moderator: BaseModerator, event: Event):
        """Ask a living doctor whom to protect tonight, listing the legal targets."""
        if not me.alive:
            return

        tonight = moderator.state.day_count
        previous_save_day = me.get_role_state(DoctorStateKey.LAST_SAVED_DAY, default=-1)
        blocked_player_id = me.get_role_state(DoctorStateKey.LAST_SAVED_PLAYER_ID)
        restrict_repeats = not self.allow_consecutive_saves

        # Reset consecutive save memory if a night was skipped.
        if restrict_repeats and previous_save_day != -1 and tonight > previous_save_day + 1:
            me.set_role_state(DoctorStateKey.LAST_SAVED_PLAYER_ID, None)
            blocked_player_id = None

        alive_ids = [p.id for p in moderator.state.alive_players()]

        excluded_ids = []
        if not self.allow_self_save:
            excluded_ids.append(me.id)

        prompt_parts = ["Wake up Doctor. Who would you like to save? "]
        if restrict_repeats and blocked_player_id:
            excluded_ids.append(blocked_player_id)
            prompt_parts.append(
                f'You cannot save the same player on consecutive nights. Player "{blocked_player_id}" is not a valid target this night. '
            )

        request_data = RequestDoctorSaveDataEntry(
            valid_candidates=[p_id for p_id in alive_ids if p_id not in excluded_ids],
            action_json_schema=json.dumps(HealAction.schema_for_player()),
        )
        prompt_parts.append(f"The options are {request_data.valid_candidates}.")

        moderator.request_action(
            action_cls=HealAction,
            player_id=me.id,
            prompt="".join(prompt_parts),
            data=request_data,
            event_name=EventName.HEAL_REQUEST,
        )

    @on_event(EventName.HEAL_ACTION)
    def on_heal_action(self, me: BasePlayer, moderator: BaseModerator, event: Event):
        """Validate this doctor's heal action and record the save if it is legal."""
        if not me.alive or event.data.actor_id != me.id:
            return

        action = event.data.action
        if not isinstance(action, HealAction):
            return

        rejection = None
        if not self.allow_self_save and action.target_id == me.id:
            rejection = (
                f'Player "{me.id}", doctor is not allowed to self save. '
                f"Your target is {action.target_id}, which is your own id."
            )
        elif not self.allow_consecutive_saves and action.target_id == me.get_role_state(
            DoctorStateKey.LAST_SAVED_PLAYER_ID
        ):
            rejection = (
                f'Player "{me.id}", you cannot save the same player on consecutive nights. '
                f'Your target "{action.target_id}" was also saved last night.'
            )

        if rejection is not None:
            # Illegal target: tell only the doctor and record nothing.
            moderator.state.push_event(
                description=rejection,
                event_name=EventName.ERROR,
                public=False,
                visible_to=[me.id],
            )
            return

        moderator.record_night_save(me.id, action.target_id)
        # Remember the save so the consecutive-save rule can be enforced later.
        me.set_role_state(DoctorStateKey.LAST_SAVED_PLAYER_ID, action.target_id)
        me.set_role_state(DoctorStateKey.LAST_SAVED_DAY, moderator.state.day_count)
class Seer(Role):
    """Village-team role that inspects one player per night.

    ``reveal_level`` selects what an inspection discloses: the target's role
    (together with its team) or only its team. ``RevealLevel.NO_REVEAL`` is
    rejected at validation time.
    """

    name: RoleConst = RoleConst.SEER
    team: Team = Team.VILLAGERS
    descriptions: str = ""
    reveal_level: RevealLevel = RevealLevel.ROLE

    @field_validator("reveal_level")
    @classmethod
    def validate_reveal_level(cls, v):
        """Reject a reveal level that would make inspections disclose nothing."""
        if v == RevealLevel.NO_REVEAL:
            raise ValueError(f"Setting reveal_level of Seer as {v}. Seer will become useless.")
        return v

    @model_validator(mode="after")
    def set_descriptions_default(self) -> "Seer":
        """Fill in the rule text from ``reveal_level`` when none was supplied."""
        if self.descriptions != "":
            return self
        if self.reveal_level == RevealLevel.ROLE:
            self.descriptions = SeerDescription.REVEAL_ROLE
        elif self.reveal_level == RevealLevel.TEAM:
            self.descriptions = SeerDescription.REVEAL_TEAM
        else:
            raise ValueError(f"reveal_level {self.reveal_level} not supported.")
        return self

    @on_event(EventName.NIGHT_START)
    def on_night_starts(self, me: BasePlayer, moderator: BaseModerator, event: Event):
        """Ask a living seer whom to inspect; every other living player is a candidate."""
        if not me.alive:
            return

        candidate_ids = [p.id for p in moderator.state.alive_players() if p != me]
        request_data = RequestSeerRevealDataEntry(
            valid_candidates=candidate_ids,
            action_json_schema=json.dumps(InspectAction.schema_for_player()),
        )
        prompt = (
            f"Wake up Seer. Who would you like to see their true {self.reveal_level}? "
            f"The options are {request_data.valid_candidates}."
        )
        moderator.request_action(
            action_cls=InspectAction,
            player_id=me.id,
            prompt=prompt,
            data=request_data,
            event_name=EventName.INSPECT_REQUEST,
        )

    @on_event(EventName.INSPECT_ACTION)
    def on_inspect_action(self, me: BasePlayer, moderator: BaseModerator, event: Event):
        """Answer this seer's inspection privately, or report an unknown target."""
        action = event.data.action
        if not me.alive or action.actor_id != me.id:
            return

        seer_id = me.id
        inspected = moderator.state.get_player_by_id(action.target_id)
        if not inspected:
            # Unknown target: report the problem to the seer only.
            moderator.state.push_event(
                description=f'Player "{seer_id}", you inspected player "{action.target_id}",'
                f" but this player could not be found.",
                event_name=EventName.ERROR,
                public=False,
                visible_to=[seer_id],
            )
            return

        seen_role = None
        seen_team = None
        reveal_text = ""
        if self.reveal_level == RevealLevel.ROLE:
            seen_role = inspected.role.name
            seen_team = inspected.role.team
            reveal_text = f'Their role is a "{inspected.role.name}" in team "{inspected.role.team.value}".'
        elif self.reveal_level == RevealLevel.TEAM:
            seen_team = inspected.role.team
            reveal_text = f"Their team is {seen_team}."

        result = SeerInspectResultDataEntry(
            actor_id=seer_id, target_id=action.target_id, role=seen_role, team=seen_team
        )
        moderator.state.push_event(
            description=f'Player "{seer_id}", you inspected {inspected.id}. ' + reveal_text,
            event_name=EventName.INSPECT_RESULT,
            public=False,
            visible_to=[seer_id],
            data=result,
        )
class Player(BasePlayer):
    """A seat in the game: the agent, its role, its alive/eliminated status,
    a queue of undelivered event views and a free-form per-role state dict."""

    model_config = ConfigDict(use_enum_values=True)

    id: PlayerID
    """The unique name of the player."""

    agent: Agent
    role: BaseRole
    alive: bool = True
    eliminated_during_day: int = -1
    """game starts at night 0, then day 1, night 1, day 2, ..."""

    eliminated_during_phase: Optional[Phase] = None

    _message_queue: Deque[PlayerEventView] = PrivateAttr(default_factory=deque)
    _role_state: Dict = PrivateAttr(default_factory=dict)

    def set_role_state(self, key, value):
        """Store ``value`` under ``key`` in the role-private state."""
        self._role_state[key] = value

    def get_role_state(self, key, default=None):
        """Return the role-private value for ``key``, or ``default`` if unset."""
        return self._role_state.get(key, default)

    def get_event_handlers(self, moderator: BaseModerator) -> Dict[EventName, List[EventHandler]]:
        """Return the role's event handlers pre-bound to this player and ``moderator``."""
        bound_handlers = defaultdict(list)
        role_handlers = self.role.get_event_handlers()
        for event_type in role_handlers:
            bound_handlers[event_type].append(partial(role_handlers[event_type], self, moderator))
        return bound_handlers

    def update(self, entry: PlayerEventView):
        """Queue an event view for delivery to this player."""
        self._message_queue.append(entry)

    def consume_messages(self) -> List[PlayerEventView]:
        """Return all queued event views in arrival order and empty the queue."""
        pending = [*self._message_queue]
        self._message_queue.clear()
        return pending

    def eliminate(self, day: int, phase: Phase):
        """Mark the player as eliminated and record the day and phase value."""
        self.alive = False
        self.eliminated_during_day = day
        self.eliminated_during_phase = phase.value

    def report_elimination(self):
        """Return this player's id together with when (day/phase) they were eliminated."""
        return dict(
            player_id=self.id,
            eliminated_during_day=self.eliminated_during_day,
            eliminated_during_phase=self.eliminated_during_phase,
        )
def create_players_from_agents_config(agents_config: List[Dict]) -> List[Player]:
    """Build one ``Player`` per agent configuration entry.

    Args:
        agents_config: Agent configuration dicts; each must contain an ``"id"``
            key plus the remaining ``Agent`` fields.

    Returns:
        Players in the same order as ``agents_config``, each with a role
        instance built from the agent's ``role`` and ``role_params``.

    Raises:
        ValueError: if two or more entries share the same non-None id.
    """
    # Counter keeps first-seen order, so duplicates are reported in config order.
    id_counts = Counter(agent_config["id"] for agent_config in agents_config)
    duplicates = [agent_id for agent_id, count in id_counts.items() if count > 1 and agent_id is not None]
    if duplicates:
        # Stringify ids so a malformed config with non-str ids still produces
        # this ValueError instead of a TypeError from str.join.
        raise ValueError(f"Duplicate agent ids found: {', '.join(map(str, duplicates))}")

    agents = [Agent(**agent_config) for agent_config in agents_config]
    return [
        Player(id=agent.id, agent=agent, role=ROLE_CLASS_MAP[agent.role](**agent.role_params)) for agent in agents
    ]
def revealed_players(self) -> Dict[PlayerID, RoleConst | Team | None]:
    """Map every eliminated player id to what has been disclosed about them.

    Night eliminations follow ``night_elimination_reveal_level`` and day
    exiles follow ``day_exile_reveal_level``: the value is the role name, the
    team, or ``None`` when nothing is revealed. Day entries are applied after
    night entries.
    """
    disclosure_plan = (
        (self.night_elimination_reveal_level, self._night_elimination_player_ids),
        (self.day_exile_reveal_level, self._day_exile_player_ids),
    )
    revealed = {}
    for reveal_level, eliminated_ids in disclosure_plan:
        if reveal_level == RevealLevel.ROLE:
            for pid in eliminated_ids:
                revealed[pid] = self.get_player_by_id(pid).role.name
        elif reveal_level == RevealLevel.TEAM:
            for pid in eliminated_ids:
                revealed[pid] = self.get_player_by_id(pid).role.team
        elif reveal_level == RevealLevel.NO_REVEAL:
            for pid in eliminated_ids:
                revealed[pid] = None
    return revealed
visible_to: Optional[List[PlayerID]] = None, + data: Optional[Union[DataEntry, Dict[str, Any]]] = None, + source=MODERATOR_ID, + visible_in_ui: bool = True, + ): + visible_to = visible_to or [] + # Night 0 will use day_count 0, Day 1 will use day_count 1, etc. + day_key = self.day_count + self.history.setdefault(day_key, []) + sys_entry = Event( + day=day_key, + phase=self.phase, + detailed_phase=self.detailed_phase, + event_name=event_name, + description=description, + public=public, + visible_to=visible_to or [], + data=data, + source=source, + visible_in_ui=visible_in_ui, + ) + + self.history[day_key].append(sys_entry) + self._event_by_type[event_name].append(sys_entry) + self._event_queue.append(sys_entry) + + public_view = sys_entry.view_by_access(user_level=DataAccessLevel.PUBLIC) + personal_view = sys_entry.view_by_access(user_level=DataAccessLevel.PERSONAL) + + # observers message pushing below + if public: + for player in self.players: + if player.id == source: + player.update(personal_view) + else: + player.update(public_view) + else: + for player_id in visible_to: + player = self.get_player_by_id(player_id) + if player: + if player.id == source: + player.update(personal_view) + else: + player.update(public_view) + + # publish events + self._event_bus.dispatch(sys_entry) + + def add_phase_divider(self, divider: PhaseDivider): + """The phase divider is used to clearly separate phase boundary. This is very useful + for visualizer updates, where some updates naturally takes a time slice of events as input. 
+ """ + self.push_event( + description=divider.value, + event_name=EventName.PHASE_DIVIDER, + public=False, + data=PhaseDividerDataEntry(divider_type=divider.value), + ) + + def eliminate_player(self, pid: PlayerID): + if pid not in self.all_player_ids: + logger.warning(f"Tried to eliminate {pid} who is not within valid player ids {self.all_player_ids}.") + return + player = self.get_player_by_id(pid) + if self.phase == Phase.NIGHT: + self._night_elimination_player_ids.append(pid) + else: + self._day_exile_player_ids.append(pid) + if player: + player.eliminate(day=self.day_count, phase=self.phase) + + def consume_messages(self) -> List[Event]: + messages = list(self._event_queue) + self._event_queue.clear() + return messages + + def get_elimination_info(self): + return [player.report_elimination() for player in self.players] + + def register_event_handler(self, event_name: EventName, handler: EventHandler): + self._event_bus.register(event_name, handler) + + +def get_last_action_request(event_views: Sequence[PlayerEventView], event_name: EventName) -> None | PlayerEventView: + """Get the action request from the new player history entry view updates.""" + return next((entry for entry in event_views if entry.event_name == event_name), None) diff --git a/kaggle_environments/envs/werewolf/game/test_actions.py b/kaggle_environments/envs/werewolf/game/test_actions.py new file mode 100644 index 00000000..73835df4 --- /dev/null +++ b/kaggle_environments/envs/werewolf/game/test_actions.py @@ -0,0 +1,45 @@ +import pytest + +from kaggle_environments.envs.werewolf.game.actions import filter_language + +test_data = [ + # Test 1: Basic lowercase substitution + ("We must kill the monster.", "We must eliminate the monster."), + # Test 2: Title case substitution + ("Killing is wrong.", "Eliminating is wrong."), + # Test 3: Uppercase substitution + ("The town should not LYNCH anyone.", "The town should not EXILE anyone."), + # Test 4: Word boundary check (should not affect "skill") 
+ ("His skill is unparalleled.", "His skill is unparalleled."), + # Test 5: Mixed case and multiple substitutions + ( + "The Mob will lynch the player they think will Kill them.", + "The Mob will exile the player they think will Eliminate them.", + ), + # Test 6: Handling different word endings ('-ed', '-s') + ("He killed the dragon, and she kills the goblin.", "He eliminated the dragon, and she eliminates the goblin."), + # Test 7: No inappropriate words, should return original string + ("This is a perfectly safe sentence.", "This is a perfectly safe sentence."), + # Test 8: A more complex sentence with a third rule ('murder') + ( + "The detective solved the Murder, preventing the killer from killing again.", + "The detective solved the Remove, preventing the eliminator from eliminating again.", + ), + # Test 9: A tricky title case that isn't at the start of a sentence + ("I think Killing is not the answer.", "I think Eliminating is not the answer."), +] + + +@pytest.mark.parametrize("input_text, expected_text", test_data) +def test_clean_script_scenarios(input_text, expected_text): + """ + Tests the clean_script_preserve_case function with various scenarios. + """ + assert filter_language(input_text) == expected_text + + +def test_empty_string(): + """ + Tests that an empty string input results in an empty string output. 
+ """ + assert filter_language("") == "" diff --git a/pyproject.toml b/pyproject.toml index 5a7ff867..dbbf86ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,19 +18,27 @@ dependencies = [ "numpy >= 2.2.6", "open_spiel >= 1.6.2", "pettingzoo == 1.24.0", + "pydantic >= 2.11.4", + "pygame", + "pyjson5", + "termcolor", "requests >= 2.25.1", "scipy >= 1.15.3", "shimmy >= 1.2.1", "stable-baselines3 == 2.7.0", "transformers >= 4.33.1", + "tenacity", + "google-auth >= 2.35.0" ] requires-python = ">=3.10" [dependency-groups] dev = [ - "ipython", "flit", + "ipython", + "python-dotenv", "pytest", + "PyYAML", "ruff", "uv", ] From b718bcece18892b731ec7b78ffd1f7d883df0526 Mon Sep 17 00:00:00 2001 From: Hann Wang Date: Fri, 3 Oct 2025 17:55:42 +0000 Subject: [PATCH 2/2] Add deterministic tests --- .../envs/werewolf/game/engine.py | 4 +- .../envs/werewolf/game/protocols/chat.py | 74 +-- .../werewolf/test_werewolf_deterministic.py | 171 +++++ kaggle_environments/envs/werewolf/werewolf.py | 600 ++++++++++++++++++ 4 files changed, 810 insertions(+), 39 deletions(-) create mode 100644 kaggle_environments/envs/werewolf/test_werewolf_deterministic.py create mode 100644 kaggle_environments/envs/werewolf/werewolf.py diff --git a/kaggle_environments/envs/werewolf/game/engine.py b/kaggle_environments/envs/werewolf/game/engine.py index be5c5091..55059e5c 100644 --- a/kaggle_environments/envs/werewolf/game/engine.py +++ b/kaggle_environments/envs/werewolf/game/engine.py @@ -393,9 +393,7 @@ def _handle_day_bidding_await(self, player_actions: Dict[PlayerID, Action]) -> D # We need to explicitly check if the bidding sub-phase is over # This requires a reference to the bidding protocol within BiddingDiscussion - assert isinstance(self.discussion, BiddingDiscussion) - bidding_protocol = self.discussion.bidding - if bidding_protocol.is_finished(self.state): + if self.discussion.bidding.is_finished(self.state): return DetailedPhase.DAY_BIDDING_CONCLUDE else: # Bidding is not over (e.g., 
sequential auction), get next bidders diff --git a/kaggle_environments/envs/werewolf/game/protocols/chat.py b/kaggle_environments/envs/werewolf/game/protocols/chat.py index 84dd3cf9..3ef946cd 100644 --- a/kaggle_environments/envs/werewolf/game/protocols/chat.py +++ b/kaggle_environments/envs/werewolf/game/protocols/chat.py @@ -248,48 +248,50 @@ def process_actions(self, actions: List[Action], expected_speakers: Sequence[Pla # Handle players who didn't bid (timed out) by assuming a bid of 0 all_alive_player_ids = [p.id for p in state.alive_players()] if hasattr(self.bidding, "_bids"): - for player_id in all_alive_player_ids: - if player_id not in self.bidding._bids: + for action, player_id in zip(actions, expected_speakers): + if not isinstance(action, BidAction): default_bid = BidAction( actor_id=player_id, amount=0, day=state.day_count, phase=state.phase.value ) self.bidding.accept(default_bid, state) bids = getattr(self.bidding, "_bids", {}) - if len(bids) >= len(all_alive_player_ids) and all(amount == 0 for amount in bids.values()): - self._all_passed = True - state.push_event( - description="All players passed on speaking. 
Discussion ends.", - event_name=EventName.MODERATOR_ANNOUNCEMENT, - public=True, - ) - return - - # Once all bids are in (or a timeout, handled by moderator's single tick), determine the winner - winner_list = self.bidding.outcome(state) - self._speaker = winner_list[0] if winner_list else None - - if self._speaker: - data = BidResultDataEntry( - winner_player_ids=[self._speaker], - bid_overview=self.bidding.bids, - mentioned_players_in_previous_turn=self.bidding.get_last_mentioned(state)[0], - ) - overview_text = ", ".join([f"{k}: {v}" for k, v in self.bidding.bids.items()]) - state.push_event( - description=f"Player {self._speaker} won the bid and will speak next.\n" - f"Bid overview - {overview_text}.", - event_name=EventName.BID_RESULT, - public=self._bid_result_public, - data=data, - ) - self.set_phase(BiddingDiscussionPhase.SPEAKING_PHASE) - else: - # No one to speak, advance turn count and bid again - self._turns_taken += 1 - if not self.is_discussion_over(state): - self.bidding.begin(state) # Prepare for next bidding round + if len(bids) >= len(all_alive_player_ids): + # If all bids are in + if all(amount == 0 for amount in bids.values()): + # If everyone decided to pass + self._all_passed = True + state.push_event( + description="All players passed on speaking. 
Discussion ends.", + event_name=EventName.MODERATOR_ANNOUNCEMENT, + public=True, + ) + return + else: + winner_list = self.bidding.outcome(state) + self._speaker = winner_list[0] if winner_list else None + if self._speaker: + data = BidResultDataEntry( + winner_player_ids=[self._speaker], + bid_overview=self.bidding.bids, + mentioned_players_in_previous_turn=self.bidding.get_last_mentioned(state)[0], + ) + overview_text = ", ".join([f"{k}: {v}" for k, v in self.bidding.bids.items()]) + state.push_event( + description=f"Player {self._speaker} won the bid and will speak next.\n" + f"Bid overview - {overview_text}.", + event_name=EventName.BID_RESULT, + public=self._bid_result_public, + data=data, + ) + self.set_phase(BiddingDiscussionPhase.SPEAKING_PHASE) + return + else: + self._turns_taken += 1 + if not self.is_discussion_over(state): + self.bidding.begin(state) + # continue bidding elif self.is_speaking_phase(): # Process the chat action from the designated speaker super().process_actions(actions, expected_speakers, state) @@ -299,7 +301,7 @@ def process_actions(self, actions: List[Action], expected_speakers: Sequence[Pla if not self.is_discussion_over(state): self.set_phase(BiddingDiscussionPhase.BIDDING_PHASE) self._speaker = None - self.bidding.begin(state) # Reset bids and find new mentioned players + self.bidding.begin(state) def prompt_speakers_for_tick(self, state: GameState, speakers: Sequence[PlayerID]) -> None: if self.is_bidding_phase(): diff --git a/kaggle_environments/envs/werewolf/test_werewolf_deterministic.py b/kaggle_environments/envs/werewolf/test_werewolf_deterministic.py new file mode 100644 index 00000000..ed49e657 --- /dev/null +++ b/kaggle_environments/envs/werewolf/test_werewolf_deterministic.py @@ -0,0 +1,171 @@ +import pytest + +from kaggle_environments import make +from kaggle_environments.envs.werewolf.game.protocols.vote import TieBreak +from kaggle_environments.envs.werewolf.game.consts import EnvInfoKeys, Team +from 
kaggle_environments.envs.werewolf.game.records import GameEndResultsDataEntry + +URLS = { + "gemini": "https://logos-world.net/wp-content/uploads/2025/01/Google-Gemini-Symbol.png", + "openai": "https://images.seeklogo.com/logo-png/46/1/chatgpt-logo-png_seeklogo-465219.png", + "claude": "https://images.seeklogo.com/logo-png/55/1/claude-logo-png_seeklogo-554534.png", + "grok": "https://images.seeklogo.com/logo-png/61/1/grok-logo-png_seeklogo-613403.png", +} + + +@pytest.fixture +def deterministic_agents_config(): + roles = ["Werewolf", "Werewolf", "Doctor", "Seer", "Villager", "Villager", "Villager"] + names = [f"player_{i}" for i in range(len(roles))] + thumbnails = [ + URLS["gemini"], + URLS["gemini"], + URLS["openai"], + URLS["openai"], + URLS["openai"], + URLS["claude"], + URLS["grok"], + ] + agents_config = [ + {"role": role, "id": name, "agent_id": "deterministic", "thumbnail": url} + for role, name, url in zip(roles, names, thumbnails) + ] + return agents_config + + +@pytest.fixture +def deterministic_config_options(): + options = { + "discussion_protocol": {"name": "RoundRobinDiscussion", + "params": {"max_rounds": 1, "assign_random_first_speaker": False}}, + "day_voting_protocol": {"name": "SequentialVoting", + "params": {"assign_random_first_voter": True, "tie_break": TieBreak.NO_EXILE}}, + "werewolf_night_vote_protocol": {"name": "SequentialVoting", + "params": {"assign_random_first_voter": True, "tie_break": TieBreak.NO_EXILE}} + } + return options + + +def test_game_result(deterministic_agents_config, deterministic_config_options): + """ + Tests that the deterministic werewolves vote to eliminate the first valid target. 
+ """ + env = make("werewolf", debug=True, + configuration={"agents": deterministic_agents_config, **deterministic_config_options}) + agents = ["deterministic"] * 7 + env.run(agents) + + result = GameEndResultsDataEntry(**env.info[EnvInfoKeys.GAME_END]) + + assert len(env.steps) == 24 + assert result.winner_team == Team.VILLAGERS + assert result.winner_ids == ['player_2', 'player_3', 'player_4', 'player_5', 'player_6'] + assert result.loser_ids == ['player_0', 'player_1'] + assert result.scores == {'player_2': 1, 'player_3': 1, 'player_4': 1, 'player_5': 1, 'player_6': 1, 'player_0': 0, + 'player_1': 0} + assert result.elimination_info == [ + {'player_id': 'player_0', 'eliminated_during_day': 1, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_1', 'eliminated_during_day': 2, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_2', 'eliminated_during_day': 0, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_3', 'eliminated_during_day': 1, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_4', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_5', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_6', 'eliminated_during_day': -1, 'eliminated_during_phase': None} + ] + + +def test_parallel_discussion_simultaneous_majority_vote(deterministic_agents_config, deterministic_config_options): + config = {'agents': deterministic_agents_config, **deterministic_config_options} + config.update({ + "discussion_protocol": {"name": "ParallelDiscussion", + "params": {"ticks": 2}}, + "day_voting_protocol": {"name": "SimultaneousMajority", + "params": {"tie_break": TieBreak.NO_EXILE}}, + "werewolf_night_vote_protocol": {"name": "SimultaneousMajority", + "params": {"tie_break": TieBreak.NO_EXILE}} + }) + + env = make("werewolf", debug=True, + configuration=config) + agents = ["deterministic"] * 7 + env.run(agents) + + result = 
GameEndResultsDataEntry(**env.info[EnvInfoKeys.GAME_END]) + + assert len(env.steps) == 11 + assert result.winner_team == Team.VILLAGERS + assert result.winner_ids == ['player_2', 'player_3', 'player_4', 'player_5', 'player_6'] + assert result.loser_ids == ['player_0', 'player_1'] + assert result.scores == {'player_2': 1, 'player_3': 1, 'player_4': 1, 'player_5': 1, 'player_6': 1, 'player_0': 0, + 'player_1': 0} + assert result.elimination_info == [ + {'player_id': 'player_0', 'eliminated_during_day': 1, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_1', 'eliminated_during_day': 2, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_2', 'eliminated_during_day': 0, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_3', 'eliminated_during_day': 1, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_4', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_5', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_6', 'eliminated_during_day': -1, 'eliminated_during_phase': None} + ] + + +def test_round_by_round_bidding_discussion_sequential_vote(deterministic_agents_config, deterministic_config_options): + config = {'agents': deterministic_agents_config, **deterministic_config_options} + config.update({ + "discussion_protocol": {"name": "RoundByRoundBiddingDiscussion", + "params": {"bidding": {"name": "UrgencyBiddingProtocol"}, "max_rounds": 2, + "bid_result_public": True}} + }) + env = make("werewolf", debug=True, + configuration=config) + agents = ["deterministic"] * 7 + env.run(agents) + + result = GameEndResultsDataEntry(**env.info[EnvInfoKeys.GAME_END]) + + assert len(env.steps) == 34 + assert result.winner_team == Team.VILLAGERS + assert result.winner_ids == ['player_2', 'player_3', 'player_4', 'player_5', 'player_6'] + assert result.loser_ids == ['player_0', 'player_1'] + assert result.scores == {'player_2': 1, 'player_3': 1, 'player_4': 1, 'player_5': 1, 
'player_6': 1, 'player_0': 0, + 'player_1': 0} + assert result.elimination_info == [ + {'player_id': 'player_0', 'eliminated_during_day': 1, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_1', 'eliminated_during_day': 2, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_2', 'eliminated_during_day': 0, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_3', 'eliminated_during_day': 1, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_4', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_5', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_6', 'eliminated_during_day': -1, 'eliminated_during_phase': None} + ] + + +def test_turn_by_turn_bidding(deterministic_agents_config, deterministic_config_options): + config = {'agents': deterministic_agents_config, **deterministic_config_options} + config.update({ + "discussion_protocol": {"name": "TurnByTurnBiddingDiscussion", + "params": {"bidding": {"name": "UrgencyBiddingProtocol"}, "max_turns": 10, + "bid_result_public": False}} + }) + env = make("werewolf", debug=True, + configuration=config) + agents = ["deterministic"] * 7 + env.run(agents) + + result = GameEndResultsDataEntry(**env.info[EnvInfoKeys.GAME_END]) + + assert len(env.steps) == 34 + assert result.winner_team == Team.VILLAGERS + assert result.winner_ids == ['player_2', 'player_3', 'player_4', 'player_5', 'player_6'] + assert result.loser_ids == ['player_0', 'player_1'] + assert result.scores == {'player_2': 1, 'player_3': 1, 'player_4': 1, 'player_5': 1, 'player_6': 1, 'player_0': 0, + 'player_1': 0} + assert result.elimination_info == [ + {'player_id': 'player_0', 'eliminated_during_day': 1, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_1', 'eliminated_during_day': 2, 'eliminated_during_phase': 'Day'}, + {'player_id': 'player_2', 'eliminated_during_day': 0, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_3', 
'eliminated_during_day': 1, 'eliminated_during_phase': 'Night'}, + {'player_id': 'player_4', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_5', 'eliminated_during_day': -1, 'eliminated_during_phase': None}, + {'player_id': 'player_6', 'eliminated_during_day': -1, 'eliminated_during_phase': None} + ] diff --git a/kaggle_environments/envs/werewolf/werewolf.py b/kaggle_environments/envs/werewolf/werewolf.py new file mode 100644 index 00000000..315810bd --- /dev/null +++ b/kaggle_environments/envs/werewolf/werewolf.py @@ -0,0 +1,600 @@ +import json +import logging +import random +from os import getenv, path +from typing import Callable, Dict, List, Optional + +from pydantic import BaseModel, Field + +from kaggle_environments.envs.werewolf.game.consts import DetailedPhase, EnvInfoKeys, PerceivedThreatLevel +from .game.actions import ( + Action, + BidAction, + ChatAction, + HealAction, + InspectAction, + NoOpAction, + VoteAction, + create_action, +) +from .game.base import PlayerID +from .game.consts import RoleConst +from .game.engine import Moderator +from .game.protocols.factory import create_protocol +from .game.records import WerewolfObservationModel, get_raw_observation, set_raw_observation +from .game.roles import create_players_from_agents_config +from .game.states import EventName, GameState, get_last_action_request +from .harness.base import LLMCostTracker, LLMWerewolfAgent + +logger = logging.getLogger(__name__) + +# --- Protocol Factory --- +DEFAULT_DISCUSSION_PROTOCOL_NAME = "RoundRobinDiscussion" +DEFAULT_VOTING_PROTOCOL_NAME = "SimultaneousMajority" +DEFAULT_BIDDING_PROTOCOL_NAME = "UrgencyBiddingProtocol" + + +class AgentCost(BaseModel): + total_cost: float = 0.0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + + +class AgentCostSummary(BaseModel): + agent_config: Dict + costs: AgentCost = Field(default_factory=AgentCost) + data: Optional[LLMCostTracker] = None + + +class CostSummary(BaseModel): + 
cost_per_agent: List[AgentCostSummary] = Field(default_factory=list) + total_cost: float = 0.0 + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + total_tokens: int = 0 + + +_PERCEIVED_THREAT_LEVELS = [item.value for item in PerceivedThreatLevel] + +def random_agent(obs): + raw_obs = get_raw_observation(obs) + + entries = raw_obs.new_player_event_views + current_phase = DetailedPhase(raw_obs.detailed_phase) + my_role = raw_obs.role + all_player_names = raw_obs.all_player_ids + my_id = raw_obs.player_id + alive_players = raw_obs.alive_players + day = raw_obs.day + phase = raw_obs.game_state_phase + common_args = {"day": day, "phase": phase, "actor_id": my_id} + + action = NoOpAction(**common_args, reasoning="There's nothing to be done.") # Default action + threat_level = random.choice(_PERCEIVED_THREAT_LEVELS) + + if current_phase == DetailedPhase.NIGHT_AWAIT_ACTIONS: + if my_role == RoleConst.WEREWOLF: + history_entry = get_last_action_request(entries, EventName.VOTE_REQUEST) + if history_entry: + valid_targets = history_entry.data.get("valid_targets") + if valid_targets: + target_id = random.choice(valid_targets) + action = VoteAction( + **common_args, + target_id=target_id, + reasoning="I randomly chose one.", + perceived_threat_level=threat_level, + ) + + elif my_role == RoleConst.DOCTOR: + history_entry = get_last_action_request(entries, EventName.HEAL_REQUEST) + if history_entry: + valid_targets = history_entry.data["valid_candidates"] + if valid_targets: + target_id = random.choice(valid_targets) + action = HealAction( + **common_args, + target_id=target_id, + reasoning="I randomly chose one to heal.", + perceived_threat_level=threat_level, + ) + + elif my_role == RoleConst.SEER: + history_entry = get_last_action_request(entries, EventName.INSPECT_REQUEST) + if history_entry: + valid_targets = history_entry.data["valid_candidates"] + if valid_targets: + target_id = random.choice(valid_targets) + action = InspectAction( + **common_args, + 
target_id=target_id, + reasoning="I randomly chose one to inspect.", + perceived_threat_level=threat_level, + ) + + elif current_phase in [DetailedPhase.DAY_BIDDING_AWAIT, DetailedPhase.DAY_CHAT_AWAIT]: + if current_phase == DetailedPhase.DAY_BIDDING_AWAIT: + if my_id in alive_players: + action = BidAction( + **common_args, + amount=random.randint(1, 4), + reasoning="I am bidding randomly.", + perceived_threat_level=threat_level, + ) + else: # It's a chat turn (DAY_CHAT_AWAIT) + if my_id in alive_players: + action = ChatAction( + **common_args, + message=random.choice( + [ + "Hello everyone!", + f"I suspect {random.choice(all_player_names)}.", + "Any information to share?", + "I am a simple Villager just trying to survive.", + "Let's think carefully before voting.", + ] + ), + reasoning="I randomly chose one message.", + perceived_threat_level=threat_level, + ) + + elif current_phase == DetailedPhase.DAY_VOTING_AWAIT: + if my_id in alive_players: + # A real agent would parse the prompt for valid targets + valid_targets = [p_id for p_id in alive_players if p_id != my_id] + if valid_targets: + action = VoteAction( + **common_args, + target_id=random.choice(valid_targets), + reasoning="I randomly chose one.", + perceived_threat_level=threat_level, + ) + + return action.serialize() + + +FIXED_MESSAGE = "I am a simple villager." +FIXED_REASONING = "I am going to do one fixed thing." 
def deterministic_agent(obs):
    """Return a fixed, reproducible action for the current observation.

    The agent never consults a random source: it always targets the first
    valid candidate, bids a constant amount, and sends a constant message,
    so that repeated runs with the same configuration emit the same actions.

    Args:
        obs: The agent observation handed over by the environment; the
            werewolf-specific payload is extracted with ``get_raw_observation``.

    Returns:
        The serialized action (``action.serialize()``) for this tick. A
        ``NoOpAction`` is returned when nothing is requested from the player.
    """
    raw_obs = get_raw_observation(obs)

    entries = raw_obs.new_player_event_views
    current_phase = DetailedPhase(raw_obs.detailed_phase)
    my_role = raw_obs.role
    my_id = raw_obs.player_id
    alive_players = raw_obs.alive_players
    day = raw_obs.day
    phase = raw_obs.game_state_phase
    common_args = {"day": day, "phase": phase, "actor_id": my_id}

    action = NoOpAction(**common_args, reasoning="There's nothing to be done.")  # Default action
    # Always report the same threat level. Drawing it at random would make the
    # serialized actions differ between otherwise identical runs.
    threat_level = _PERCEIVED_THREAT_LEVELS[0]

    if current_phase == DetailedPhase.NIGHT_AWAIT_ACTIONS:
        if my_role == RoleConst.WEREWOLF:
            history_entry = get_last_action_request(entries, EventName.VOTE_REQUEST)
            if history_entry:
                valid_targets = history_entry.data.get("valid_targets")
                if valid_targets:
                    # always select first valid
                    target_id = valid_targets[0]
                    action = VoteAction(
                        **common_args,
                        target_id=target_id,
                        reasoning=FIXED_REASONING,
                        perceived_threat_level=threat_level,
                    )

        elif my_role == RoleConst.DOCTOR:
            history_entry = get_last_action_request(entries, EventName.HEAL_REQUEST)
            if history_entry:
                valid_targets = history_entry.data["valid_candidates"]
                if valid_targets:
                    target_id = valid_targets[0]
                    action = HealAction(
                        **common_args,
                        target_id=target_id,
                        reasoning=FIXED_REASONING,
                        perceived_threat_level=threat_level,
                    )

        elif my_role == RoleConst.SEER:
            history_entry = get_last_action_request(entries, EventName.INSPECT_REQUEST)
            if history_entry:
                valid_targets = history_entry.data["valid_candidates"]
                if valid_targets:
                    target_id = valid_targets[0]
                    action = InspectAction(
                        **common_args,
                        target_id=target_id,
                        reasoning=FIXED_REASONING,
                        perceived_threat_level=threat_level,
                    )

    elif current_phase in [DetailedPhase.DAY_BIDDING_AWAIT, DetailedPhase.DAY_CHAT_AWAIT]:
        if current_phase == DetailedPhase.DAY_BIDDING_AWAIT:
            if my_id in alive_players:
                action = BidAction(
                    **common_args,
                    amount=4,
                    reasoning=FIXED_REASONING,
                    perceived_threat_level=threat_level,
                )
        else:  # It's a chat turn (DAY_CHAT_AWAIT)
            if my_id in alive_players:
                action = ChatAction(
                    **common_args,
                    message=FIXED_MESSAGE,
                    reasoning=FIXED_REASONING,
                    perceived_threat_level=threat_level,
                )

    elif current_phase == DetailedPhase.DAY_VOTING_AWAIT:
        if my_id in alive_players:
            # A real agent would parse the prompt for valid targets
            valid_targets = [p_id for p_id in alive_players if p_id != my_id]
            if valid_targets:
                action = VoteAction(
                    **common_args,
                    target_id=valid_targets[0],
                    reasoning=FIXED_REASONING,
                    perceived_threat_level=threat_level,
                )

    return action.serialize()


class AgentFactoryWrapper:
    """
    A wrapper that creates and manages separate agent instances for each player.
    This is necessary for stateful agents to be used in the agent registry,
    preventing them from sharing state (like memory or history) across different players.
    """

    def __init__(self, agent_class, **kwargs):
        """Store the agent class and the keyword arguments shared by every instance.

        Args:
            agent_class: Callable used to build one agent per player. It is
                invoked with the shared kwargs plus an ``agent_config`` kwarg.
            **kwargs: Keyword arguments forwarded unchanged to every instance.
        """
        self._agent_class = agent_class
        self._shared_kwargs = kwargs
        self._kwargs = {}  # store configs of individual agents
        self._instances = {}
        self._agent_configs = None

    @property
    def agent_class(self):
        """The class (or factory callable) used to create per-player agents."""
        return self._agent_class

    def get_instance(self, player_id: PlayerID):
        """Return the agent instance created for ``player_id``, or ``None`` if there is none yet."""
        return self._instances.get(player_id)

    def __call__(self, obs, config):
        """
        The main callable method for the agent. It routes the call to the correct
        player-specific agent instance.

        Args:
            obs: The agent observation; the player id is read from its raw werewolf payload.
            config: The environment configuration; ``config.agents`` supplies the
                per-player agent configs, keyed by their ``id``.

        Returns:
            The action returned by the player's agent instance, or a serialized
            ``NoOpAction`` when the observation carries no player id.
        """
        raw_obs = get_raw_observation(obs)
        player_id = raw_obs.player_id  # get the current active player id

        if not player_id:
            # This could happen on initial steps or for an inactive agent.
            # Returning a NO_OP action is a safe fallback.
            return NoOpAction(
                day=raw_obs.day,
                phase=raw_obs.game_state_phase,
                actor_id="unknown_fallback",
                reasoning="AgentFactoryWrapper: No player_id found in observation.",
            ).serialize()

        if not self._agent_configs:
            # Built lazily because the configuration is only available at call time.
            self._agent_configs = {agent_config.id: agent_config for agent_config in config.agents}

        if player_id not in self._instances:
            # Create a new agent instance for this player
            self._kwargs[player_id] = {"agent_config": self._agent_configs.get(player_id)}
            self._instances[player_id] = self._agent_class(**self._shared_kwargs, **self._kwargs[player_id])
        return self._instances[player_id](obs)

    def reset(self):
        """Discard every per-player instance together with the cached per-episode configuration."""
        self._instances.clear()
        self._kwargs.clear()
        # Drop the cached id -> config map as well, so that instances created
        # after the reset are built from the configuration passed to the next
        # call rather than from the previous episode's agents.
        self._agent_configs = None


# --- Agent Registry ---
LLM_SYSTEM_PROMPT = "You are a master strategist playing the game of Werewolf. Your goal is to win. You win as a team and not as individuals."


# *Package variable required by Kaggle Environments framework*
# These are base agents that the calling framework can choose from
# Provides a random_agent for testing and a convenient default 'llm' agent.

# Registry of agent callables/factories, keyed by the agent id used in the configuration.
agents = {
    "random": random_agent,
    "deterministic": deterministic_agent,
    "llm": AgentFactoryWrapper(
        LLMWerewolfAgent,
        model_name=getenv("WEREWOLF_LLM_MODEL", "gemini/gemini-2.5-pro"),
        system_prompt=LLM_SYSTEM_PROMPT,
    ),
}


def register_agents(agent_dict: Dict[str, Callable]):
    """Merge ``agent_dict`` into the module-level ``agents`` registry.

    Entries whose key already exists in the registry are overwritten.
    """
    agents.update(agent_dict)


def log_error(status_code, state, env):
    """Log every agent whose status equals ``status_code``.

    Args:
        status_code: The status string to look for (e.g. ``"TIMEOUT"``).
        state: Per-agent state entries; each exposes ``["status"]``.
        env: The environment; ``env.configuration["agents"][i]`` supplies the id
            of the agent at index ``i`` for the log message.

    Returns:
        True if at least one agent has the given status, False otherwise.
    """
    offending_indices = [i for i, player_state in enumerate(state) if player_state["status"] == status_code]
    if offending_indices:
        logger.error(f"{status_code} DETECTED")
        for i in offending_indices:
            agent_config = env.configuration["agents"][i]
            logger.error(f"agent_id={agent_config['id']} returns action with status code {status_code}.")
    return bool(offending_indices)


def interpreter(state, env):
    """
    * Required interface function for kaggle environments package *

    This is the primary interface for the kaggle environment (kEnv) to step the game forward.
    Briefly, the flow of logic is:
    Initialization - kEnv creates the werewolf object and chooses players. The schema definition
        for this is in werewolf.json
    1) kEnv calls interpreter() with the current game state recorded in env.game_state
    2) interpreter() reads the game state and any new player actions and updates
        the game state based on those actions and the flow of the game to env.game_state.
    3) interpreter() writes events to history data and also writes events about
        state changes in the game to env.game_state and returns back to kEnv
    4) kEnv parses out the relevant game events via agent logic in harness/base.py,
        constructs the final prompt, performs external API calls for models and records back
        to env.game_state
    Go back to 1 and continue

    For example - consider discussion and voting by villagers. werewolf.interpreter()
    updates the phase and writes a history entry that solicits players for discussion.
    kEnv calls agents to get their discussion and writes them to the history/game state.
    kEnv then calls interpreter(), which updates the game phase and writes a history entry
    soliciting votes for exile. kEnv then calls agents and associated models to get their votes
    and writes responses to the game state. env then calls interpreter() and the moderator
    collects votes, determines who was exiled, performs that action and advances the game phase
    and game state.
    And so on...

    Note - The UI is also updated after each call to interpreter() as that is the tick unit
    for the game.

    Note - the env framework assumes that there is an action to be done by a player, but
    for werewolf there are places where the moderator is the one taking the action (e.g.
    counting votes and performing exile) so some game 'ticks' are larger than others.

    state: list of dictionaries, one for each agent.
        Each dict has: {observation, action, reward, status, info}
    env: the kaggle_environments.Environment object itself including the env.game_state
    """
    # Check every status code (no short-circuit) so that each kind of failure gets logged.
    agent_error = False
    for status_code in ["TIMEOUT", "ERROR", "INVALID"]:
        if log_error(status_code, state, env):
            agent_error = True

    # --- Initialize Moderator and GameState if it's the start of an episode ---
    if not hasattr(env, "moderator") or env.done:  # env.done is true after reset by Kaggle core
        initialize_moderator(state, env)

    moderator: Moderator = env.moderator
    game_state: GameState = env.game_state

    # 1. Collect and parse actions from Kaggle agents
    parsed_player_actions = parse_player_actions(state, moderator, game_state)

    # 2. Advance the Moderator
    moderator.advance(parsed_player_actions)

    # 3. Record the game-end information if the game finished (or an agent failed)
    is_game_done = moderator.is_game_over() or agent_error
    current_info = {}
    if is_game_done:
        record_game_end(state, env, game_state, current_info, agent_error)

    # 4. Update Kaggle state (observations, statuses, info) from the advanced game
    active_player_ids_after_advance = set(moderator.get_active_player_ids())

    # 4.1. Accumulate God mode observations from env for rendering
    global_messages = env.game_state.consume_messages()
    global_data = [rec.serialize() for rec in global_messages]
    env.info[EnvInfoKeys.MODERATOR_OBS].append(global_data)

    # 4.2. Update observations for individual agents
    update_agent_messages(
        state, env, moderator, game_state, is_game_done, current_info, active_player_ids_after_advance, agent_error
    )
    return state


def collect_cost_summary(env) -> CostSummary:
    """Aggregate the LLM usage cost of every configured agent.

    One ``AgentCostSummary`` is appended per entry in ``env.configuration.agents``.
    Cost figures are only filled in for agents registered as an ``AgentFactoryWrapper``
    around an ``LLMWerewolfAgent`` subclass that has a live instance for the player.
    """
    cost_summary = CostSummary()

    for agent_config in env.configuration.agents:
        player_id = agent_config["id"]
        agent_id = agent_config["agent_id"]

        agent_cost_summary = AgentCostSummary(agent_config=agent_config)

        if isinstance(agents.get(agent_id), AgentFactoryWrapper) and issubclass(
            agents[agent_id].agent_class, LLMWerewolfAgent
        ):
            agent_instance = agents[agent_id].get_instance(player_id)
            if agent_instance:
                cost_tracker = agent_instance.cost_tracker
                agent_cost = AgentCost(
                    total_cost=cost_tracker.query_token_cost.total_costs_usd,
                    prompt_tokens=cost_tracker.prompt_token_cost.total_tokens,
                    completion_tokens=cost_tracker.completion_token_cost.total_tokens,
                )
                agent_cost_summary.costs = agent_cost
                agent_cost_summary.data = cost_tracker

                cost_summary.total_cost += agent_cost.total_cost
                cost_summary.total_prompt_tokens += agent_cost.prompt_tokens
                cost_summary.total_completion_tokens += agent_cost.completion_tokens

        cost_summary.cost_per_agent.append(agent_cost_summary)

    cost_summary.total_tokens = cost_summary.total_prompt_tokens + cost_summary.total_completion_tokens
    return cost_summary


def record_game_end(state, env, game_state, current_info, agent_error):
    """Write the end-of-game record to ``env.info`` and assign rewards.

    ``current_info`` is mutated in place: it receives the dumped GAME_END data (when
    present), the ``terminated_with_agent_error`` flag and the ``cost_summary``.
    Rewards are only written to ``state`` when a GAME_END entry with data exists.
    """
    # log game end to env.info using GameEndResultsDataEntry
    game_end_entry = next(iter(game_state.get_event_by_name(EventName.GAME_END)), None)
    # Single guard shared by both uses below, so an entry without data cannot
    # trigger an AttributeError when reading the scores.
    results = game_end_entry.data if game_end_entry else None
    if results:
        current_info.update(results.model_dump())
    # Record if terminated with agent error. If so, the game record is invalid.
    current_info["terminated_with_agent_error"] = agent_error

    # Record cost from endpoints if any.
    current_info["cost_summary"] = collect_cost_summary(env).model_dump()

    env.info[EnvInfoKeys.GAME_END] = current_info
    # Determine winner based on game_state.history's GAME_END entry
    if results:
        scores = results.scores
        for i, player_id in enumerate(env.player_id_str_list):
            state[i].reward = scores[player_id]


def update_agent_messages(
    state, env, moderator, game_state, is_game_done, current_info, active_player_ids_after_advance, agent_error
):
    """Refresh status, observation and info of every agent after the moderator advanced.

    Players that are not active are marked ``INACTIVE`` and skipped while the game is
    still running. Once the game is done, every player receives a final observation
    and is marked ``DONE``.
    """
    for player_index, player_state in enumerate(state):
        player_id_str = env.player_ids_map[player_index]

        # skip if player not active and game is not done
        if player_id_str not in active_player_ids_after_advance and not is_game_done:
            player_state.status = "INACTIVE"
            continue

        # set the status of active player to ACTIVE
        # NOTE(review): kept ahead of set_raw_observation in case it reads the status — confirm.
        player_state.status = "ACTIVE"
        player_obj = game_state.get_player_by_id(player_id_str)

        # Observation processing
        new_history_entries = player_obj.consume_messages()

        obs = WerewolfObservationModel(
            player_id=player_obj.id,
            role=player_obj.role.name,
            team=player_obj.role.team.value,
            is_alive=player_obj.alive,
            day=game_state.day_count,
            detailed_phase=moderator.detailed_phase.value,
            all_player_ids=game_state.all_player_ids,
            player_thumbnails=env.player_thumbnails,
            alive_players=[p.id for p in game_state.alive_players()],
            revealed_players=game_state.revealed_players(),
            new_visible_announcements=[entry.description for entry in new_history_entries],
            new_player_event_views=new_history_entries,
            game_state_phase=game_state.phase.value,
        )

        set_raw_observation(player_state, raw_obs=obs)

        # Status
        if is_game_done or agent_error:
            player_state.status = "DONE"
        elif player_id_str in active_player_ids_after_advance:
            player_state.status = "ACTIVE"
        else:
            player_state.status = "INACTIVE"

        # Info
        player_state.info = current_info


def parse_player_actions(state, moderator, game_state):
    """Deserialize the actions submitted by the currently active agents.

    Returns:
        A dict mapping player id to the ``Action`` built by ``create_action``. Only
        players that the moderator reports as active, whose Kaggle status is
        ``"ACTIVE"`` and that submitted a truthy action are included.
    """
    parsed_player_actions: Dict[str, Action] = {}
    # Build the set once so each membership test below is O(1).
    active_player_ids_from_moderator = set(moderator.get_active_player_ids())

    for sub_state, player in zip(state, game_state.players):
        player_id_str = player.id
        if player_id_str in active_player_ids_from_moderator and sub_state.status == "ACTIVE":
            serialized_action = sub_state.action
            if serialized_action:
                parsed_player_actions[player_id_str] = create_action(serialized_action)
    return parsed_player_actions


def initialize_moderator(state, env):
    """Create the game state, protocols and moderator for a new episode.

    Sets ``env.game_state``, ``env.moderator``, the player id lookups, the thumbnail
    map, ``env.player_full_visible_history_cache``, ``env.info`` and ``env.agents``.

    Raises:
        ValueError: If the configuration lists fewer agents than there are entries
            in ``state``.
    """
    num_players = len(state)

    agents_from_config = env.configuration.agents

    # below checks for configuration consistency with agent count. If inconsistent, it will cause down stream subtle error.
    if len(agents_from_config) < num_players:
        raise ValueError(
            f"Configuration has {len(agents_from_config)} agents, but {num_players} kaggle agents are present."
        )

    players = create_players_from_agents_config(agents_from_config)

    env.game_state = GameState(
        players=players,
        history={},
        night_elimination_reveal_level=env.configuration.night_elimination_reveal_level,
        day_exile_reveal_level=env.configuration.day_exile_reveal_level,
    )

    env.player_ids_map = {i: p.id for i, p in enumerate(players)}
    env.player_id_str_list = [p.id for p in players]

    env.player_thumbnails = {p.id: p.agent.thumbnail for p in players}
    # Initialize protocols from configuration or defaults
    discussion_protocol = create_protocol(
        env.configuration.get("discussion_protocol", {}), default_name=DEFAULT_DISCUSSION_PROTOCOL_NAME
    )
    day_voting_protocol = create_protocol(
        env.configuration.get("day_voting_protocol", {}), default_name=DEFAULT_VOTING_PROTOCOL_NAME
    )
    night_voting_protocol = create_protocol(
        env.configuration.get("werewolf_night_vote_protocol", {}), default_name=DEFAULT_VOTING_PROTOCOL_NAME
    )

    logger.info(
        f"Interpreter: Using Discussion: {type(discussion_protocol).__name__}, "
        f"Day Voting: {type(day_voting_protocol).__name__}, "
        f"Night WW Voting: {type(night_voting_protocol).__name__}"
    )

    env.moderator = Moderator(
        state=env.game_state,
        discussion=discussion_protocol,
        day_voting=day_voting_protocol,
        night_voting=night_voting_protocol,
        night_elimination_reveal_level=env.configuration.night_elimination_reveal_level,
        day_exile_reveal_level=env.configuration.day_exile_reveal_level,
    )

    env.player_full_visible_history_cache = {p_id: [] for p_id in env.player_id_str_list}
    env.info = {EnvInfoKeys.MODERATOR_OBS: []}
    env.agents = agents


def renderer(state, env):
    """Return a plain-text rendering of the pending game messages.

    The descriptions of the messages returned by ``game_state.consume_messages()``
    are joined with blank lines. A placeholder string is returned when the
    interpreter has not yet initialized the moderator and game state.
    """
    if not hasattr(env, "moderator") or not hasattr(env, "game_state"):
        return "Game not initialized by interpreter yet."

    game_state: GameState = env.game_state

    lines = [entry.description for entry in game_state.consume_messages()]
    return "\n\n".join(lines)


def html_renderer():
    """Return the contents of ``werewolf.js`` located next to this module."""
    js_path = path.abspath(path.join(path.dirname(__file__), "werewolf.js"))
    with open(js_path, encoding="utf-8") as buff:
        return buff.read()


# Load the environment specification once at import time. The encoding is explicit so
# decoding does not depend on the platform's locale default.
jsonpath = path.abspath(path.join(path.dirname(__file__), "werewolf.json"))
with open(jsonpath, encoding="utf-8") as handle:
    specification = json.load(handle)