|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import re |
| 4 | +from functools import lru_cache |
| 5 | +from typing import Optional, Tuple |
| 6 | + |
| 7 | +from pydantic import Field, create_model, field_validator |
| 8 | + |
| 9 | +from .base import BaseAction, BaseState, PlayerID |
| 10 | +from .consts import EventName, PerceivedThreatLevel, Phase |
| 11 | +from .records import DoctorHealActionDataEntry, SeerInspectActionDataEntry |
| 12 | + |
| 13 | +ACTION_EVENT_MAP = {} |
| 14 | + |
| 15 | + |
| 16 | +def register_event(event_name: EventName): |
| 17 | + """A class decorator to register an EventName for an Action class.""" |
| 18 | + |
| 19 | + def decorator(cls): |
| 20 | + ACTION_EVENT_MAP[cls.__name__] = event_name |
| 21 | + setattr(cls, "event_name", event_name) |
| 22 | + return cls |
| 23 | + |
| 24 | + return decorator |
| 25 | + |
| 26 | + |
| 27 | +_REPLACEMENT_MAP = { |
| 28 | + # 'kill' variations |
| 29 | + "kill": "eliminate", |
| 30 | + "kills": "eliminates", |
| 31 | + "killed": "eliminated", |
| 32 | + "killing": "eliminating", |
| 33 | + "killer": "eliminator", |
| 34 | + # 'lynch' variations |
| 35 | + "lynch": "exile", |
| 36 | + "lynches": "exiles", |
| 37 | + "lynched": "exiled", |
| 38 | + "lynching": "exiling", |
| 39 | + # 'mislynch' variations |
| 40 | + "mislynch": "mis-exile", |
| 41 | + "mislynches": "mis-exiles", |
| 42 | + "mislynched": "mis-exiled", |
| 43 | + "mislynching": "mis-exiling", |
| 44 | + # 'murder' variations |
| 45 | + "murder": "remove", |
| 46 | + "murders": "removes", |
| 47 | + "murdered": "removed", |
| 48 | + "murdering": "removing", |
| 49 | + "murderer": "remover", |
| 50 | +} |
| 51 | + |
| 52 | +_CENSOR_PATTERN = re.compile(r"\b(" + "|".join(_REPLACEMENT_MAP.keys()) + r")\b", re.IGNORECASE) |
| 53 | + |
| 54 | + |
| 55 | +# Create a single, case-insensitive regex pattern from all map keys. |
| 56 | +def replacer(match): |
| 57 | + """ |
| 58 | + Finds the correct replacement and applies case based on a specific heuristic. |
| 59 | + """ |
| 60 | + original_word = match.group(0) |
| 61 | + replacement = _REPLACEMENT_MAP[original_word.lower()] |
| 62 | + |
| 63 | + # Rule 1: Preserve ALL CAPS. |
| 64 | + if original_word.isupper(): |
| 65 | + return replacement.upper() |
| 66 | + |
| 67 | + # Rule 2: Handle title-cased words with a more specific heuristic. |
| 68 | + if original_word.istitle(): |
| 69 | + # Preserve title case if it's the first word of the string OR |
| 70 | + # if it's a form like "-ing" which can start a new clause. |
| 71 | + return replacement.title() |
| 72 | + |
| 73 | + # Rule 3: For all other cases (e.g., "Kill" mid-sentence), default to lowercase. |
| 74 | + return replacement.lower() |
| 75 | + |
| 76 | + |
| 77 | +def filter_language(text): |
| 78 | + """Remove inappropriate/violent language.""" |
| 79 | + return _CENSOR_PATTERN.sub(replacer, text) |
| 80 | + |
| 81 | + |
| 82 | +# ------------------------------------------------------------------ # |
| 83 | +class Action(BaseAction): |
| 84 | + """Root of the discriminated-union tree.""" |
| 85 | + |
| 86 | + day: int |
| 87 | + phase: Phase |
| 88 | + actor_id: PlayerID |
| 89 | + reasoning: Optional[str] = Field( |
| 90 | + default=None, |
| 91 | + max_length=4096, |
| 92 | + description="The self monologue that illustrate how you arrived at the action. " |
| 93 | + "It will be invisible to other players.", |
| 94 | + ) |
| 95 | + |
| 96 | + perceived_threat_level: PerceivedThreatLevel = Field( |
| 97 | + default=PerceivedThreatLevel.SAFE, |
| 98 | + description="The self perceived threat level you are currently experiencing from other players. " |
| 99 | + "The assessment will be invisible to other players.", |
| 100 | + ) |
| 101 | + error: Optional[str] = None |
| 102 | + raw_prompt: Optional[str] = None |
| 103 | + raw_completion: Optional[str] = None |
| 104 | + |
| 105 | + @field_validator("reasoning", mode="before") |
| 106 | + @classmethod |
| 107 | + def filter_reasoning(cls, v): |
| 108 | + if v is None: |
| 109 | + return v |
| 110 | + return filter_language(v) |
| 111 | + |
| 112 | + def serialize(self): |
| 113 | + return {"action_type": self.__class__.__name__, "kwargs": self.model_dump()} |
| 114 | + |
| 115 | + @classmethod |
| 116 | + def schema_for_player(cls, fields: Tuple = None, new_cls_name=None): |
| 117 | + """Many of the fields are for internal game record. This method is used to convert the response schema |
| 118 | + to a format friendly for players. |
| 119 | + """ |
| 120 | + fields = fields or [] |
| 121 | + if not new_cls_name: |
| 122 | + new_cls_name = cls.__name__ + "Data" |
| 123 | + field_definitions = { |
| 124 | + field: ( |
| 125 | + cls.model_fields[field].annotation, |
| 126 | + # Pass the entire FieldInfo object, not just the default value |
| 127 | + cls.model_fields[field], |
| 128 | + ) |
| 129 | + for field in fields |
| 130 | + if field in cls.model_fields |
| 131 | + } |
| 132 | + sub_cls = create_model(new_cls_name, **field_definitions) |
| 133 | + subset_schema = sub_cls.model_json_schema() |
| 134 | + return subset_schema |
| 135 | + |
| 136 | + @property |
| 137 | + def action_field(self) -> Optional[str]: |
| 138 | + return None |
| 139 | + |
| 140 | + def push_event(self, state: BaseState): |
| 141 | + # The following is just for internal record keeping. |
| 142 | + data = self.model_dump() |
| 143 | + state.push_event( |
| 144 | + description=f"Player {self.actor_id}, you submitted {data}", |
| 145 | + event_name=ACTION_EVENT_MAP[self.__class__.__name__], |
| 146 | + public=False, |
| 147 | + visible_to=[], |
| 148 | + data=data, |
| 149 | + ) |
| 150 | + |
| 151 | + |
| 152 | +# ——— Mix-in for actions that need a target ------------------------ # |
| 153 | +class TargetedAction(Action): |
| 154 | + target_id: PlayerID = Field(description="The target player's id.") |
| 155 | + |
| 156 | + @classmethod |
| 157 | + @lru_cache(maxsize=10) |
| 158 | + def schema_for_player(cls, fields=None, new_cls_name=None): |
| 159 | + fields = fields or ["perceived_threat_level", "reasoning", "target_id"] |
| 160 | + return super(TargetedAction, cls).schema_for_player(fields, new_cls_name) |
| 161 | + |
| 162 | + @property |
| 163 | + def action_field(self): |
| 164 | + return "target_id" |
| 165 | + |
| 166 | + |
| 167 | +# ——— Concrete leaf classes --------------------------------------- # |
| 168 | +@register_event(EventName.HEAL_ACTION) |
| 169 | +class HealAction(TargetedAction): |
| 170 | + def push_event(self, state: BaseState): |
| 171 | + action_data = DoctorHealActionDataEntry( |
| 172 | + actor_id=self.actor_id, |
| 173 | + target_id=self.target_id, |
| 174 | + reasoning=self.reasoning, |
| 175 | + perceived_threat_level=self.perceived_threat_level, |
| 176 | + action=self, |
| 177 | + ) |
| 178 | + state.push_event( |
| 179 | + description=f"Player {self.actor_id}, you chose to heal player {self.target_id}.", |
| 180 | + event_name=EventName.HEAL_ACTION, |
| 181 | + public=False, |
| 182 | + visible_to=[self.actor_id], |
| 183 | + data=action_data, |
| 184 | + ) |
| 185 | + |
| 186 | + |
| 187 | +@register_event(EventName.INSPECT_ACTION) |
| 188 | +class InspectAction(TargetedAction): |
| 189 | + def push_event(self, state: BaseState): |
| 190 | + action_data = SeerInspectActionDataEntry( |
| 191 | + actor_id=self.actor_id, |
| 192 | + target_id=self.target_id, |
| 193 | + reasoning=self.reasoning, |
| 194 | + perceived_threat_level=self.perceived_threat_level, |
| 195 | + action=self, |
| 196 | + ) |
| 197 | + state.push_event( |
| 198 | + description=f"Player {self.actor_id}, you chose to inspect player {self.target_id}.", |
| 199 | + event_name=EventName.INSPECT_ACTION, |
| 200 | + public=False, |
| 201 | + visible_to=[self.actor_id], |
| 202 | + data=action_data, |
| 203 | + ) |
| 204 | + |
| 205 | + |
| 206 | +@register_event(EventName.VOTE_ACTION) |
| 207 | +class VoteAction(TargetedAction): |
| 208 | + pass |
| 209 | + |
| 210 | + |
| 211 | +@register_event(EventName.ELIMINATE_PROPOSAL_ACTION) |
| 212 | +class EliminateProposalAction(VoteAction): |
| 213 | + pass |
| 214 | + |
| 215 | + |
| 216 | +@register_event(EventName.DISCUSSION) |
| 217 | +class ChatAction(Action): |
| 218 | + message: str = Field(default="", max_length=4096) |
| 219 | + |
| 220 | + @field_validator("message", mode="before") |
| 221 | + @classmethod |
| 222 | + def filter_message(cls, v): |
| 223 | + return filter_language(v) |
| 224 | + |
| 225 | + @classmethod |
| 226 | + @lru_cache(maxsize=10) |
| 227 | + def schema_for_player(cls, fields=None, new_cls_name=None): |
| 228 | + fields = fields or ["perceived_threat_level", "reasoning", "message"] |
| 229 | + return super(ChatAction, cls).schema_for_player(fields, new_cls_name) |
| 230 | + |
| 231 | + @property |
| 232 | + def action_field(self): |
| 233 | + return "message" |
| 234 | + |
| 235 | + |
| 236 | +@register_event(EventName.NOOP_ACTION) |
| 237 | +class NoOpAction(Action): |
| 238 | + pass |
| 239 | + |
| 240 | + |
| 241 | +# ------------------------------------------------------------ # |
| 242 | +@register_event(EventName.BID_ACTION) |
| 243 | +class BidAction(Action): |
| 244 | + """ |
| 245 | + An amount the actor is willing to pay this round. |
| 246 | + Currency unit can be generic 'chips' or role-specific. |
| 247 | + """ |
| 248 | + |
| 249 | + amount: int = Field(ge=0) |
| 250 | + |
| 251 | + @classmethod |
| 252 | + @lru_cache(maxsize=10) |
| 253 | + def schema_for_player(cls, fields=None, new_cls_name=None): |
| 254 | + fields = fields or ["perceived_threat_level", "reasoning", "amount"] |
| 255 | + return super(BidAction, cls).schema_for_player(fields, new_cls_name) |
| 256 | + |
| 257 | + @property |
| 258 | + def action_field(self): |
| 259 | + return "amount" |
| 260 | + |
| 261 | + |
| 262 | +ACTIONS = [EliminateProposalAction, HealAction, InspectAction, VoteAction, ChatAction, BidAction, NoOpAction] |
| 263 | + |
| 264 | +ACTION_REGISTRY = {action.__name__: action for action in ACTIONS} |
| 265 | + |
| 266 | + |
| 267 | +def create_action(serialized): |
| 268 | + return ACTION_REGISTRY[serialized["action_type"]](**serialized.get("kwargs", {})) |
0 commit comments