
Commit 29f9c23 (1 parent: e7dc7d3)

Update

[ghstack-poisoned]

2 files changed: +117 -100 lines

examples/trees/mcts.py (+3 -2)

```diff
@@ -72,7 +72,8 @@ def tree_format_fn(tree):
 
 def get_best_move(fen, mcts_steps, rollout_steps):
     root = env.reset(TensorDict({"fen": fen}))
-    tree = torchrl.modules.mcts.MCTS(forest, root, env, mcts_steps, rollout_steps)
+    mcts = torchrl.modules.mcts.MCTS(mcts_steps, rollout_steps)
+    tree = mcts(forest, root, env)
     moves = []
 
     for subtree in tree.subtree:
@@ -96,7 +97,7 @@ def get_best_move(fen, mcts_steps, rollout_steps):
     return moves[0][1]
 
 
-for idx in range(3):
+for idx in range(30):
     print("==========")
     print(idx)
     print("==========")
```

torchrl/modules/mcts/mcts.py (+114 -98)

```diff
@@ -6,113 +6,129 @@
 import torch
 import torchrl
 from tensordict import TensorDict, TensorDictBase
+from tensordict.nn import TensorDictModuleBase
 
 from torchrl.data.map import MCTSForest, Tree
 from torchrl.envs import EnvBase
 
 C = 2.0**0.5
 
 
-# TODO: Allow user to specify different priority functions with PR #2358
-def _traversal_priority_UCB1(tree):
-    subtree = tree.subtree
-    visits = subtree.visits
-    reward_sum = subtree.wins
-
-    # If it's black's turn, flip the reward, since black wants to optimize for
-    # the lowest reward, not highest.
-    # TODO: Need a more generic way to do this, since not all use cases of MCTS
-    # will be two player turn based games.
-    if not subtree.rollout[0, 0]["turn"]:
-        reward_sum = -reward_sum
-
-    parent_visits = tree.visits
-    reward_sum = reward_sum.squeeze(-1)
-    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
-    priority[visits == 0] = float("inf")
-    return priority
-
-
-def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
-    done = False
-    trees_visited = [tree]
-
-    while not done:
-        if tree.subtree is None:
-            td_tree = tree.rollout[-1]["next"].clone()
-
-            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
-                actions = env.all_actions(td_tree)
-                subtrees = []
-
-                for action in actions:
-                    td = env.step(env.reset(td_tree).update(action))
-                    new_node = torchrl.data.Tree(
-                        rollout=td.unsqueeze(0),
-                        node_data=td["next"].select(*forest.node_map.in_keys),
-                        count=torch.tensor(0),
-                        wins=torch.zeros_like(td["next", env.reward_key]),
-                    )
-                    subtrees.append(new_node)
-
-                # NOTE: This whole script runs about 2x faster with lazy stack
-                # versus eager stack.
-                tree.subtree = TensorDict.lazy_stack(subtrees)
-                chosen_idx = torch.randint(0, len(subtrees), ()).item()
-                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]
+class MCTS(TensorDictModuleBase):
+    """Monte-Carlo tree search.
 
-            else:
-                rollout_state = td_tree
+    Attributes:
+        num_traversals (int): Number of times to traverse the tree.
+        rollout_max_steps (int): Maximum number of steps for each rollout.
 
-            if rollout_state["done"]:
-                rollout_reward = rollout_state[env.reward_key]
-            else:
-                rollout = env.rollout(
-                    max_steps=max_rollout_steps,
-                    tensordict=rollout_state,
-                )
-                rollout_reward = rollout[-1]["next", env.reward_key]
-            done = True
-
-        else:
-            priorities = _traversal_priority_UCB1(tree)
-            chosen_idx = torch.argmax(priorities).item()
-            tree = tree.subtree[chosen_idx]
-            trees_visited.append(tree)
-
-    for tree in trees_visited:
-        tree.visits += 1
-        tree.wins += rollout_reward
-
-
-def MCTS(
-    forest: MCTSForest,
-    root: TensorDictBase,
-    env: EnvBase,
-    num_steps: int,
-    max_rollout_steps: int | None = None,
-) -> Tree:
-    """Performs Monte-Carlo tree search in an environment.
-
-    Args:
-        forest (MCTSForest): Forest of the tree to update. If the tree does not
-            exist yet, it is added.
-        root (TensorDict): The root step of the tree to update.
-        env (EnvBase): Environment to performs actions in.
-        num_steps (int): Number of iterations to traverse.
-        max_rollout_steps (int): Maximum number of steps for each rollout.
+    Methods:
+        forward: Runs the tree search.
     """
-    for action in env.all_actions(root):
-        td = env.step(env.reset(root.clone()).update(action))
-        forest.extend(td.unsqueeze(0))
 
-    tree = forest.get_tree(root)
-
-    tree.wins = torch.zeros_like(td["next", env.reward_key])
-    for subtree in tree.subtree:
-        subtree.wins = torch.zeros_like(td["next", env.reward_key])
-
-    for _ in range(num_steps):
-        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)
+    def __init__(
+        self,
+        num_traversals: int,
+        rollout_max_steps: int | None = None,
+    ):
+        super().__init__()
+        self.num_traversals = num_traversals
+        self.rollout_max_steps = rollout_max_steps
+
+    def forward(
+        self,
+        forest: MCTSForest,
+        root: TensorDictBase,
+        env: EnvBase,
+    ) -> Tree:
+        """Performs Monte-Carlo tree search in an environment.
+
+        Args:
+            forest (MCTSForest): Forest of the tree to update. If the tree does not
+                exist yet, it is added.
+            root (TensorDict): The root step of the tree to update.
+            env (EnvBase): Environment to performs actions in.
+        """
+        for action in env.all_actions(root):
+            td = env.step(env.reset(root.clone()).update(action))
+            forest.extend(td.unsqueeze(0))
+
+        tree = forest.get_tree(root)
+
+        tree.wins = torch.zeros_like(td["next", env.reward_key])
+        for subtree in tree.subtree:
+            subtree.wins = torch.zeros_like(td["next", env.reward_key])
+
+        for _ in range(self.num_traversals):
+            self._traverse_MCTS_one_step(forest, tree, env, self.rollout_max_steps)
+
+        return tree
+
+    def _traverse_MCTS_one_step(self, forest, tree, env, rollout_max_steps):
+        done = False
+        trees_visited = [tree]
+
+        while not done:
+            if tree.subtree is None:
+                td_tree = tree.rollout[-1]["next"].clone()
+
+                if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
+                    actions = env.all_actions(td_tree)
+                    subtrees = []
+
+                    for action in actions:
+                        td = env.step(env.reset(td_tree).update(action))
+                        new_node = torchrl.data.Tree(
+                            rollout=td.unsqueeze(0),
+                            node_data=td["next"].select(*forest.node_map.in_keys),
+                            count=torch.tensor(0),
+                            wins=torch.zeros_like(td["next", env.reward_key]),
+                        )
+                        subtrees.append(new_node)
+
+                    # NOTE: This whole script runs about 2x faster with lazy stack
+                    # versus eager stack.
+                    tree.subtree = TensorDict.lazy_stack(subtrees)
+                    chosen_idx = torch.randint(0, len(subtrees), ()).item()
+                    rollout_state = subtrees[chosen_idx].rollout[-1]["next"]
+
+                else:
+                    rollout_state = td_tree
+
+                if rollout_state["done"]:
+                    rollout_reward = rollout_state[env.reward_key]
+                else:
+                    rollout = env.rollout(
+                        max_steps=rollout_max_steps,
+                        tensordict=rollout_state,
+                    )
+                    rollout_reward = rollout[-1]["next", env.reward_key]
+                done = True
 
-    return tree
+            else:
+                priorities = self._traversal_priority_UCB1(tree)
+                chosen_idx = torch.argmax(priorities).item()
+                tree = tree.subtree[chosen_idx]
+                trees_visited.append(tree)
+
+        for tree in trees_visited:
+            tree.visits += 1
+            tree.wins += rollout_reward
+
+    # TODO: Allow user to specify different priority functions with PR #2358
+    def _traversal_priority_UCB1(self, tree):
+        subtree = tree.subtree
+        visits = subtree.visits
+        reward_sum = subtree.wins
+
+        # If it's black's turn, flip the reward, since black wants to optimize for
+        # the lowest reward, not highest.
+        # TODO: Need a more generic way to do this, since not all use cases of MCTS
+        # will be two player turn based games.
+        if not subtree.rollout[0, 0]["turn"]:
+            reward_sum = -reward_sum
+
+        parent_visits = tree.visits
+        reward_sum = reward_sum.squeeze(-1)
+        priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
+        priority[visits == 0] = float("inf")
+        return priority
```
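
The refactor keeps the search loop itself unchanged: selection by the UCB1-style priority, expansion of a node once it has been visited, a random rollout capped at `rollout_max_steps`, and backpropagation of the rollout reward along `trees_visited`. For reference, textbook UCB1 scores a child i as wins_i / visits_i + C * sqrt(ln(parent_visits) / visits_i). The snippet below simply evaluates the priority expression from `_traversal_priority_UCB1` on made-up counts to show how unvisited children always win the argmax; the toy values are illustrative and not taken from the commit.

```python
# Toy evaluation of the priority expression used in _traversal_priority_UCB1,
# on dummy visit/win counts (illustrative only, not from the commit).
import torch

C = 2.0**0.5
visits = torch.tensor([0.0, 1.0, 4.0])      # per-child visit counts
reward_sum = torch.tensor([0.0, 0.5, 2.0])  # per-child accumulated wins
parent_visits = torch.tensor(5.0)           # visits of the parent node

priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
priority[visits == 0] = float("inf")        # unvisited children are expanded first
print(priority)  # roughly [inf, 2.29, 0.95]
```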
