Skip to content

Commit 44c69d6

Browse files
authored
Merge pull request #426 from Kaggle/wwolf_game
[Werewolf V0.2] - Game Logic
2 parents 00e0b83 + a93e5df commit 44c69d6

File tree

16 files changed

+4397
-0
lines changed

16 files changed

+4397
-0
lines changed
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from functools import lru_cache
5+
from typing import Optional, Tuple
6+
7+
from pydantic import Field, create_model, field_validator
8+
9+
from .base import BaseAction, BaseState, PlayerID
10+
from .consts import EventName, PerceivedThreatLevel, Phase
11+
from .records import DoctorHealActionDataEntry, SeerInspectActionDataEntry
12+
13+
ACTION_EVENT_MAP = {}
14+
15+
16+
def register_event(event_name: EventName):
17+
"""A class decorator to register an EventName for an Action class."""
18+
19+
def decorator(cls):
20+
ACTION_EVENT_MAP[cls.__name__] = event_name
21+
setattr(cls, "event_name", event_name)
22+
return cls
23+
24+
return decorator
25+
26+
27+
_REPLACEMENT_MAP = {
28+
# 'kill' variations
29+
"kill": "eliminate",
30+
"kills": "eliminates",
31+
"killed": "eliminated",
32+
"killing": "eliminating",
33+
"killer": "eliminator",
34+
# 'lynch' variations
35+
"lynch": "exile",
36+
"lynches": "exiles",
37+
"lynched": "exiled",
38+
"lynching": "exiling",
39+
# 'mislynch' variations
40+
"mislynch": "mis-exile",
41+
"mislynches": "mis-exiles",
42+
"mislynched": "mis-exiled",
43+
"mislynching": "mis-exiling",
44+
# 'murder' variations
45+
"murder": "remove",
46+
"murders": "removes",
47+
"murdered": "removed",
48+
"murdering": "removing",
49+
"murderer": "remover",
50+
}
51+
52+
_CENSOR_PATTERN = re.compile(r"\b(" + "|".join(_REPLACEMENT_MAP.keys()) + r")\b", re.IGNORECASE)
53+
54+
55+
# Create a single, case-insensitive regex pattern from all map keys.
56+
def replacer(match):
57+
"""
58+
Finds the correct replacement and applies case based on a specific heuristic.
59+
"""
60+
original_word = match.group(0)
61+
replacement = _REPLACEMENT_MAP[original_word.lower()]
62+
63+
# Rule 1: Preserve ALL CAPS.
64+
if original_word.isupper():
65+
return replacement.upper()
66+
67+
# Rule 2: Handle title-cased words with a more specific heuristic.
68+
if original_word.istitle():
69+
# Preserve title case if it's the first word of the string OR
70+
# if it's a form like "-ing" which can start a new clause.
71+
return replacement.title()
72+
73+
# Rule 3: For all other cases (e.g., "Kill" mid-sentence), default to lowercase.
74+
return replacement.lower()
75+
76+
77+
def filter_language(text):
78+
"""Remove inappropriate/violent language."""
79+
return _CENSOR_PATTERN.sub(replacer, text)
80+
81+
82+
# ------------------------------------------------------------------ #
83+
class Action(BaseAction):
84+
"""Root of the discriminated-union tree."""
85+
86+
day: int
87+
phase: Phase
88+
actor_id: PlayerID
89+
reasoning: Optional[str] = Field(
90+
default=None,
91+
max_length=4096,
92+
description="The self monologue that illustrate how you arrived at the action. "
93+
"It will be invisible to other players.",
94+
)
95+
96+
perceived_threat_level: PerceivedThreatLevel = Field(
97+
default=PerceivedThreatLevel.SAFE,
98+
description="The self perceived threat level you are currently experiencing from other players. "
99+
"The assessment will be invisible to other players.",
100+
)
101+
error: Optional[str] = None
102+
raw_prompt: Optional[str] = None
103+
raw_completion: Optional[str] = None
104+
105+
@field_validator("reasoning", mode="before")
106+
@classmethod
107+
def filter_reasoning(cls, v):
108+
if v is None:
109+
return v
110+
return filter_language(v)
111+
112+
def serialize(self):
113+
return {"action_type": self.__class__.__name__, "kwargs": self.model_dump()}
114+
115+
@classmethod
116+
def schema_for_player(cls, fields: Tuple = None, new_cls_name=None):
117+
"""Many of the fields are for internal game record. This method is used to convert the response schema
118+
to a format friendly for players.
119+
"""
120+
fields = fields or []
121+
if not new_cls_name:
122+
new_cls_name = cls.__name__ + "Data"
123+
field_definitions = {
124+
field: (
125+
cls.model_fields[field].annotation,
126+
# Pass the entire FieldInfo object, not just the default value
127+
cls.model_fields[field],
128+
)
129+
for field in fields
130+
if field in cls.model_fields
131+
}
132+
sub_cls = create_model(new_cls_name, **field_definitions)
133+
subset_schema = sub_cls.model_json_schema()
134+
return subset_schema
135+
136+
@property
137+
def action_field(self) -> Optional[str]:
138+
return None
139+
140+
def push_event(self, state: BaseState):
141+
# The following is just for internal record keeping.
142+
data = self.model_dump()
143+
state.push_event(
144+
description=f"Player {self.actor_id}, you submitted {data}",
145+
event_name=ACTION_EVENT_MAP[self.__class__.__name__],
146+
public=False,
147+
visible_to=[],
148+
data=data,
149+
)
150+
151+
152+
# ——— Mix-in for actions that need a target ------------------------ #
153+
class TargetedAction(Action):
154+
target_id: PlayerID = Field(description="The target player's id.")
155+
156+
@classmethod
157+
@lru_cache(maxsize=10)
158+
def schema_for_player(cls, fields=None, new_cls_name=None):
159+
fields = fields or ["perceived_threat_level", "reasoning", "target_id"]
160+
return super(TargetedAction, cls).schema_for_player(fields, new_cls_name)
161+
162+
@property
163+
def action_field(self):
164+
return "target_id"
165+
166+
167+
# ——— Concrete leaf classes --------------------------------------- #
168+
@register_event(EventName.HEAL_ACTION)
169+
class HealAction(TargetedAction):
170+
def push_event(self, state: BaseState):
171+
action_data = DoctorHealActionDataEntry(
172+
actor_id=self.actor_id,
173+
target_id=self.target_id,
174+
reasoning=self.reasoning,
175+
perceived_threat_level=self.perceived_threat_level,
176+
action=self,
177+
)
178+
state.push_event(
179+
description=f"Player {self.actor_id}, you chose to heal player {self.target_id}.",
180+
event_name=EventName.HEAL_ACTION,
181+
public=False,
182+
visible_to=[self.actor_id],
183+
data=action_data,
184+
)
185+
186+
187+
@register_event(EventName.INSPECT_ACTION)
188+
class InspectAction(TargetedAction):
189+
def push_event(self, state: BaseState):
190+
action_data = SeerInspectActionDataEntry(
191+
actor_id=self.actor_id,
192+
target_id=self.target_id,
193+
reasoning=self.reasoning,
194+
perceived_threat_level=self.perceived_threat_level,
195+
action=self,
196+
)
197+
state.push_event(
198+
description=f"Player {self.actor_id}, you chose to inspect player {self.target_id}.",
199+
event_name=EventName.INSPECT_ACTION,
200+
public=False,
201+
visible_to=[self.actor_id],
202+
data=action_data,
203+
)
204+
205+
206+
@register_event(EventName.VOTE_ACTION)
207+
class VoteAction(TargetedAction):
208+
pass
209+
210+
211+
@register_event(EventName.ELIMINATE_PROPOSAL_ACTION)
212+
class EliminateProposalAction(VoteAction):
213+
pass
214+
215+
216+
@register_event(EventName.DISCUSSION)
217+
class ChatAction(Action):
218+
message: str = Field(default="", max_length=4096)
219+
220+
@field_validator("message", mode="before")
221+
@classmethod
222+
def filter_message(cls, v):
223+
return filter_language(v)
224+
225+
@classmethod
226+
@lru_cache(maxsize=10)
227+
def schema_for_player(cls, fields=None, new_cls_name=None):
228+
fields = fields or ["perceived_threat_level", "reasoning", "message"]
229+
return super(ChatAction, cls).schema_for_player(fields, new_cls_name)
230+
231+
@property
232+
def action_field(self):
233+
return "message"
234+
235+
236+
@register_event(EventName.NOOP_ACTION)
237+
class NoOpAction(Action):
238+
pass
239+
240+
241+
# ------------------------------------------------------------ #
242+
@register_event(EventName.BID_ACTION)
243+
class BidAction(Action):
244+
"""
245+
An amount the actor is willing to pay this round.
246+
Currency unit can be generic 'chips' or role-specific.
247+
"""
248+
249+
amount: int = Field(ge=0)
250+
251+
@classmethod
252+
@lru_cache(maxsize=10)
253+
def schema_for_player(cls, fields=None, new_cls_name=None):
254+
fields = fields or ["perceived_threat_level", "reasoning", "amount"]
255+
return super(BidAction, cls).schema_for_player(fields, new_cls_name)
256+
257+
@property
258+
def action_field(self):
259+
return "amount"
260+
261+
262+
ACTIONS = [EliminateProposalAction, HealAction, InspectAction, VoteAction, ChatAction, BidAction, NoOpAction]
263+
264+
ACTION_REGISTRY = {action.__name__: action for action in ACTIONS}
265+
266+
267+
def create_action(serialized):
268+
return ACTION_REGISTRY[serialized["action_type"]](**serialized.get("kwargs", {}))
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Annotated, Any, Dict, List, Optional, Protocol, Type
3+
4+
from pydantic import BaseModel, StringConstraints
5+
6+
from .consts import EVENT_HANDLER_FOR_ATTR_NAME, MODERATOR_ID, EventName
7+
8+
# The ID regex supports Unicode letters (\p{L}), numbers (\p{N}) and common symbol for ID.
9+
ROBUST_ID_REGEX = r"^[\p{L}\p{N} _.-]+$"
10+
11+
PlayerID = Annotated[str, StringConstraints(pattern=ROBUST_ID_REGEX, min_length=1, max_length=128)]
12+
13+
14+
class BasePlayer(BaseModel, ABC):
15+
id: PlayerID
16+
"""The unique id of the player. Also, how the player is referred to in the game."""
17+
18+
alive: bool = True
19+
20+
@abstractmethod
21+
def set_role_state(self, key, value):
22+
"""Set role related state, which is a dict."""
23+
24+
@abstractmethod
25+
def get_role_state(self, key, default=None):
26+
"""Get role related state."""
27+
28+
29+
class BaseAction(BaseModel):
30+
pass
31+
32+
33+
class BaseState(BaseModel):
34+
@abstractmethod
35+
def push_event(
36+
self,
37+
description: str,
38+
event_name: EventName,
39+
public: bool,
40+
visible_to: Optional[List[PlayerID]] = None,
41+
data: Any = None,
42+
source=MODERATOR_ID,
43+
):
44+
"""Publish an event."""
45+
46+
47+
class BaseEvent(BaseModel):
48+
event_name: EventName
49+
50+
51+
class BaseModerator(ABC):
52+
@abstractmethod
53+
def advance(self, player_actions: Dict[PlayerID, BaseAction]):
54+
"""Move one Kaggle environment step further. This is to be used within Kaggle 'interpreter'."""
55+
56+
@abstractmethod
57+
def request_action(
58+
self,
59+
action_cls: Type[BaseAction],
60+
player_id: PlayerID,
61+
prompt: str,
62+
data=None,
63+
event_name=EventName.MODERATOR_ANNOUNCEMENT,
64+
):
65+
"""This can be used by event handler to request action from a player."""
66+
67+
@abstractmethod
68+
def record_night_save(self, doctor_id: str, target_id: str):
69+
"""To be used by a special Role to perform night save. This is implemented in moderator level, since
70+
coordinating between safe and night elimination is cross role activity.
71+
"""
72+
73+
@property
74+
@abstractmethod
75+
def state(self) -> BaseState:
76+
"""Providing current state of the game, including player info, event messaging and caching."""
77+
78+
79+
def on_event(event_type: EventName):
80+
def decorator(func):
81+
setattr(func, EVENT_HANDLER_FOR_ATTR_NAME, event_type)
82+
return func
83+
84+
return decorator
85+
86+
87+
class EventHandler(Protocol):
88+
"""A callable triggered by an event."""
89+
90+
def __call__(self, event: BaseEvent) -> Any:
91+
pass
92+
93+
94+
class RoleEventHandler(Protocol):
95+
"""A role specific event handler."""
96+
97+
def __call__(self, me: BasePlayer, moderator: BaseModerator, event: BaseEvent) -> Any:
98+
pass
99+
100+
101+
class BaseRole(BaseModel, ABC):
102+
"""Special abilities should be implemented as RoleEventHandler in each subclass of BaseRole, so that Moderator
103+
doesn't need to be overwhelmed by role specific logic.
104+
"""
105+
106+
def get_event_handlers(self) -> Dict[EventName, RoleEventHandler]:
107+
"""Inspects the role instance and collects all methods decorated with @on_event"""
108+
handlers = {}
109+
for attr_name in dir(self):
110+
if not attr_name.startswith("__"):
111+
attr = getattr(self, attr_name)
112+
if callable(attr) and hasattr(attr, EVENT_HANDLER_FOR_ATTR_NAME):
113+
event_type = getattr(attr, EVENT_HANDLER_FOR_ATTR_NAME)
114+
handlers[event_type] = attr
115+
return handlers

0 commit comments

Comments
 (0)