Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions kaggle_environments/envs/open_spiel/open_spiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,32 @@ def random_agent(
return {"submission": int(action)}


def debug_agent(
observation: dict[str, Any],
configuration: dict[str, Any],
max_history_length: int = 8,
) -> int:
"""A built-in random agent specifically for OpenSpiel environments."""
del configuration
serialized_game_and_state = observation.get("serializedGameAndState")
if not serialized_game_and_state:
return None
game, state = pyspiel.deserialize_game_and_state(serialized_game_and_state)
if len(state.history()) >= max_history_length:
return {
"submission": pyspiel.INVALID_ACTION,
"status": "Max history length reached; intentionally submitting invalid action.",
}
legal_actions = observation.get("legalActions")
if not legal_actions:
return None
action = random.choice(legal_actions)
return {"submission": int(action)}


AGENT_REGISTRY = {
"random": random_agent,
"debug": debug_agent,
}


Expand Down
20 changes: 19 additions & 1 deletion kaggle_environments/envs/open_spiel/test_open_spiel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from absl.testing import absltest
import functools
import sys

from absl.testing import absltest
from kaggle_environments import make
import pyspiel
from . import open_spiel as open_spiel_env
Expand Down Expand Up @@ -91,6 +93,22 @@ def test_agent_error(self):
self.assertEqual(json["rewards"], [None, None])
self.assertEqual(json["statuses"], ["ERROR", "ERROR"])

def test_debug_agent(self):
env = make("open_spiel_chess", debug=True)
max_history_length = 5
debug_agent = functools.partial(
open_spiel_env.debug_agent,
max_history_length=max_history_length,
)
env.run([debug_agent, "random"])
json = env.toJSON()
self.assertEqual(json["rewards"], [-1, 1])
self.assertEqual(json["statuses"], ["DONE", "DONE"])
self.assertEqual(
len(json["info"]["actionHistory"]),
max_history_length,
)


if __name__ == '__main__':
absltest.main()