def load_attention_model_weights(
    filename="model_attention.pth", directory="axelrod/data"
):
    """Load the attention model weights from a checkpoint file.

    Parameters
    ----------
    filename : str
        Name of the checkpoint file.
    directory : str
        Directory holding the file. ``directory / filename`` is resolved
        with ``axl_filename`` before loading.

    Returns
    -------
    The object stored in the checkpoint, with tensors mapped onto the CPU.
    NOTE(review): callers pass the result to ``load_state_dict``, so this is
    presumably a state dict; confirm against the shipped file.
    """
    path = str(axl_filename(pathlib.Path(directory) / filename))
    # ``torch.load`` is built on pickle. Asking explicitly for
    # ``weights_only=True`` restricts unpickling to tensors and plain
    # containers, independently of the default of the installed torch version.
    return torch.load(
        path, map_location=torch.device("cpu"), weights_only=True
    )
def compute_features(
    player: Player, opponent: Player, right_pad: bool = False
) -> torch.IntTensor:
    """Encode the most recent rounds of a match as a 1-D tensor of tokens.

    Index 0 always holds ``CLS_TOKEN``. The following indices hold one
    ``GameState`` value per round, most recent round first, built from at
    most the last ``MEMORY_LENGTH`` entries of each history.

    Parameters
    ----------
    player : Player
        Player whose action comes first in every encoded pair.
    opponent : Player
        Player whose action comes second in every encoded pair.
    right_pad : bool
        When True the result has ``MEMORY_LENGTH + 1`` entries and the unused
        trailing entries hold ``PAD_TOKEN``. When False the length is the
        number of retained player moves plus one.

    Returns
    -------
    torch.IntTensor
        Token tensor of dtype ``torch.int``.
    """
    # Keep only the newest MEMORY_LENGTH moves and put the newest one first.
    own_moves = player.history[-MEMORY_LENGTH:][::-1]
    their_moves = opponent.history[-MEMORY_LENGTH:][::-1]

    if right_pad:
        total_tokens = MEMORY_LENGTH + 1
    else:
        total_tokens = len(own_moves) + 1

    tokens = torch.full((total_tokens,), PAD_TOKEN, dtype=torch.int)
    tokens[0] = CLS_TOKEN

    round_states = [
        int(actions_to_game_state(mine, theirs))
        for mine, theirs in zip(own_moves, their_moves)
    ]
    if round_states:
        # Positions after the encoded rounds keep the PAD_TOKEN fill value.
        tokens[1 : len(round_states) + 1] = torch.tensor(
            round_states, dtype=torch.int
        )
    return tokens
class PlayerSelfAttention(nn.Module):
    """Multi-head self-attention built on ``scaled_dot_product_attention``.

    The config must provide ``hidden_size``, ``num_attention_heads`` and
    ``attention_probs_dropout_prob``.
    """

    def __init__(self, config: PlayerConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / self.num_attention_heads
        )
        projection_width = self.num_attention_heads * self.attention_head_size
        self.all_head_size = projection_width

        # Created in query/key/value order so parameter initialisation draws
        # random numbers in the same sequence as before.
        self.query = nn.Linear(config.hidden_size, projection_width)
        self.key = nn.Linear(config.hidden_size, projection_width)
        self.value = nn.Linear(config.hidden_size, projection_width)

        self.dropout_prob = config.attention_probs_dropout_prob

    def _transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into heads and move heads to dim 1."""
        *leading_dims, _ = x.size()
        per_head = x.view(
            *leading_dims, self.num_attention_heads, self.attention_head_size
        )
        return per_head.permute(0, 2, 1, 3)

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """
        Turn a `[bsz, seq_len]` keep-mask into an additive
        `[bsz, 1, seq_len, seq_len]` mask: 0 where the input is 1 and the most
        negative finite value of `dtype` where the input is 0.
        """
        batch, source_len = mask.size()
        keep = (
            mask[:, None, None, :]
            .expand(batch, 1, source_len, source_len)
            .to(dtype)
        )
        blocked = 1.0 - keep
        lowest = torch.finfo(dtype).min
        return blocked.masked_fill(blocked.to(torch.bool), lowest)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and merge the heads back together.

        ``hidden_states`` is unpacked as three dimensions and
        ``attention_mask`` as two; the result has shape
        ``(batch, seq_len, all_head_size)``.
        """
        batch, seq_len, _ = hidden_states.size()
        queries = self._transpose_for_scores(self.query(hidden_states))
        keys = self._transpose_for_scores(self.key(hidden_states))
        values = self._transpose_for_scores(self.value(hidden_states))

        additive_mask = self._expand_mask(attention_mask, queries.dtype)
        # Attention dropout is only active while the module is training.
        active_dropout = self.dropout_prob if self.training else 0.0

        context = torch.nn.functional.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=additive_mask,
            dropout_p=active_dropout,
        )
        merged = context.transpose(1, 2)
        return merged.reshape(batch, seq_len, self.all_head_size)
class PlayerOutput(nn.Module):
    """Project back to ``hidden_size``, apply dropout, add a residual, normalise.

    The config must provide ``intermediate_size``, ``hidden_size``,
    ``layer_norm_eps`` and ``hidden_dropout_prob``.
    """

    def __init__(self, config: PlayerConfig):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        # The residual is added before normalisation.
        return self.LayerNorm(projected + input_tensor)
class PlayerModel(nn.Module):
    """Encoder model that maps a tokenised game history to one output value.

    The forward pass chains the embeddings, the encoder stack, the pooler and
    a final linear layer with a single output feature.
    """

    # Not read anywhere in the visible code.
    # NOTE(review): presumably consumed by external tooling; confirm.
    _no_split_modules = ["PlayerEmbeddings"]

    def __init__(self, config: PlayerConfig):
        super().__init__()
        self.config = config
        self.embeddings = PlayerEmbeddings(config)
        self.encoder = PlayerEncoder(config)
        self.pooler = PlayerPooler(config)

        self.action = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the output of the ``action`` layer for ``input_ids``.

        Parameters
        ----------
        input_ids : torch.Tensor
            Token ids handed to the embedding layer, which also derives the
            attention mask from them.

        Returns
        -------
        torch.Tensor
            Output of the final linear layer applied to the pooled encoder
            output; its last dimension has size 1.
        """
        embedding_output, attention_mask = self.embeddings(input_ids=input_ids)
        sequence_output = self.encoder(embedding_output, attention_mask)
        pooled_output = self.pooler(sequence_output)
        return self.action(pooled_output)

    def __eq__(self, other: object) -> bool:
        """Treat every ``PlayerModel`` as equal to every other one.

        The comparison looks at the type only, never at config or weights.
        """
        return isinstance(other, PlayerModel)

    def __hash__(self) -> int:
        """Return one hash shared by all instances, matching ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to ``None``. That makes
        the module unhashable and breaks ``nn.Module`` methods that keep
        modules in a set, such as ``named_modules()`` and therefore
        ``parameters()``.
        """
        return hash(PlayerModel)
class TestFeatureComputation(unittest.TestCase):
    """Test the feature computation functionality."""

    @staticmethod
    def _played_pair():
        """Return two mock players after a four-turn match between them."""
        player = axl.MockPlayer(actions=[C, D, C, D])
        opponent = axl.MockPlayer(actions=[D, C, C, D])
        # Play the actions to populate both histories.
        axl.Match((player, opponent), turns=4).play()
        return player, opponent

    def _assert_history_tokens(self, features):
        """Check the CLS token and the four rounds, most recent round first."""
        self.assertEqual(features[0].item(), 0)  # CLS token
        self.assertEqual(features[1].item(), GameState.DefectDefect)
        self.assertEqual(features[2].item(), GameState.CooperateCooperate)
        self.assertEqual(features[3].item(), GameState.DefectCooperate)
        self.assertEqual(features[4].item(), GameState.CooperateDefect)

    def test_compute_features(self):
        """Test that unpadded features are computed correctly."""
        player, opponent = self._played_pair()

        features = compute_features(player, opponent)

        # Check the shape and type
        self.assertIsInstance(features, torch.Tensor)
        self.assertEqual(features.shape, (len(player.history) + 1,))

        self._assert_history_tokens(features)

    def test_compute_features_right_pad(self):
        """Test that right-padded features have full length and a padded tail."""
        # Imported locally so this test does not depend on the module-level
        # import list of the test file.
        from axelrod.strategies.attention import PAD_TOKEN

        player, opponent = self._played_pair()

        features = compute_features(player, opponent, True)

        # Check the shape and type
        self.assertIsInstance(features, torch.Tensor)
        self.assertEqual(features.shape, (MEMORY_LENGTH + 1,))

        self._assert_history_tokens(features)

        # Everything after the four played rounds must be padding.
        self.assertEqual(
            features[5:].tolist(), [PAD_TOKEN] * (MEMORY_LENGTH - 4)
        )

    def test_actions_to_game_state(self):
        """Test the mapping from actions to game states."""
        self.assertEqual(
            actions_to_game_state(C, C), GameState.CooperateCooperate
        )
        self.assertEqual(actions_to_game_state(C, D), GameState.CooperateDefect)
        self.assertEqual(actions_to_game_state(D, C), GameState.DefectCooperate)
        self.assertEqual(actions_to_game_state(D, D), GameState.DefectDefect)
def test_load_attention_model_weights(self):
    """Test that load_attention_model_weights forwards to torch.load."""
    # Stand-in for the checkpoint contents.
    mock_weights = {
        "layer1": torch.tensor([1.0, 2.0]),
        "layer2": torch.tensor([3.0, 4.0]),
    }

    # Patch torch.load so no file is read from disk.
    with patch(
        "axelrod.load_data_.torch.load", return_value=mock_weights
    ) as mock_load:
        result = load_attention_model_weights()

        # torch.load must be called exactly once.
        mock_load.assert_called_once()

        # The path names the checkpoint file and tensors are mapped to CPU.
        args, kwargs = mock_load.call_args
        self.assertIn("model_attention.pth", args[0])
        self.assertEqual(kwargs["map_location"], torch.device("cpu"))

        # Identity check: ``==`` on dicts of multi-element tensors only
        # succeeds through the identity shortcut and cannot report a real
        # mismatch, so assert that the loaded object is passed through.
        self.assertIs(result, mock_weights)
100644 --- a/docs/reference/strategy_index.rst +++ b/docs/reference/strategy_index.rst @@ -18,6 +18,8 @@ Here are the docstrings of all the strategies in the library. :members: .. automodule:: axelrod.strategies.appeaser :members: +.. automodule:: axelrod.strategies.attention + :members: .. automodule:: axelrod.strategies.averagecopier :members: .. automodule:: axelrod.strategies.axelrod_first diff --git a/docs/requirements.txt b/docs/requirements.txt index b2c933e0a..0f4be075a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ docutils>=0.18.1 numpy==1.24.3 # numpy isn't mocked due to complex use in doctests mock>=5.1.0 +torch>=2.6.0 \ No newline at end of file diff --git a/docs/tutorials/running_axelrods_first_tournament/_static/running_axelrods_first_tournament/main.py b/docs/tutorials/running_axelrods_first_tournament/_static/running_axelrods_first_tournament/main.py index f4f826a30..a3c739223 100644 --- a/docs/tutorials/running_axelrods_first_tournament/_static/running_axelrods_first_tournament/main.py +++ b/docs/tutorials/running_axelrods_first_tournament/_static/running_axelrods_first_tournament/main.py @@ -2,9 +2,10 @@ Script to obtain plots for the running axelrod tournament tutorial. 
""" -import axelrod as axl import matplotlib.pyplot as plt +import axelrod as axl + first_tournament_participants_ordered_by_reported_rank = [ s() for s in axl.axelrod_first_strategies ] diff --git a/pyproject.toml b/pyproject.toml index 8aec1b8ac..40d963607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "scipy>=1.3.3", "toolz>=0.8.2", "tqdm>=4.39.0", + "torch>=2.6.0", ] [project.optional-dependencies] diff --git a/run_mypy.py b/run_mypy.py index 9286bca9a..98f95be2c 100755 --- a/run_mypy.py +++ b/run_mypy.py @@ -20,6 +20,7 @@ "axelrod/strategies/ann.py", "axelrod/strategies/apavlov.py", "axelrod/strategies/appeaser.py", + "axelrod/strategies/attention.py", "axelrod/strategies/averagecopier.py", "axelrod/strategies/axelrod_first.py", "axelrod/strategies/axelrod_second.py", diff --git a/setup.py b/setup.py index 25014f364..ed16d6a98 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -from collections import defaultdict import os import pathlib +from collections import defaultdict + from setuptools import setup # Read in the requirements files. diff --git a/tox.ini b/tox.ini index 7afca6c9d..d2d88bf3f 100644 --- a/tox.ini +++ b/tox.ini @@ -28,6 +28,7 @@ deps = isort black numpy==1.26.4 + torch==2.6.0 mypy types-setuptools commands =