Reinforcement Learning Template #276
Draft: theory-in-progress wants to merge 8 commits into pytorch-ignite:main from theory-in-progress:template-reinforcement-learning
Commits (8)
557431b  Adding the Reinforcement Learning Template (theory-in-progress)
60025af  Merge branch 'pytorch-ignite:main' into template-reinforcement-learning (theory-in-progress)
1c268b1  DQN CarRacing-v2 Template (theory-in-progress)
1eb91a9  Merge branch 'main' into template-reinforcement-learning (theory-in-progress)
9490220  Modified RL Template (theory-in-progress)
ef1a873  Update README and requirements (theory-in-progress)
316e5bf  Merge branch 'main' into template-reinforcement-learning (theory-in-progress)
7619b24  Modify in the colab, remove from requirements (theory-in-progress)
src/templates/template-reinforcement-learning/advantage_actor_critic_a2c.py (new file, 234 additions, 0 deletions)
from collections import deque, namedtuple
from shutil import copy
from typing import Any

import ignite.distributed as idist
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ignite.engine import Engine, Events
from ignite.utils import manual_seed
from torch.distributions import Categorical

from utils import *

# from matplotlib import pyplot as plt

try:
    import gymnasium as gym
except ImportError:
    raise ModuleNotFoundError("Please install gymnasium: pip install gymnasium[box2d]")

SavedAction = namedtuple("SavedAction", ["log_prob", "value"])

eps = np.finfo(np.float32).eps.item()


class ActorCriticNetwork(nn.Module):
    def __init__(self, n_actions):
        super(ActorCriticNetwork, self).__init__()
        self.LeakyReLU = nn.LeakyReLU()
        self.Sigmoid = nn.Sigmoid()
        self.Softplus = nn.Softplus()

        # REVIEW:
        # OPTIMIZE:
        self.conv1 = nn.Conv2d(3, 8, kernel_size=7, stride=4, padding=0)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(576, 512)
        self.fc_critic2 = nn.Linear(512, 1)
        self.fc_actor2 = nn.Linear(512, 256)
        self.fc_actor3 = nn.Linear(256, n_actions)

        self.flatten = nn.Flatten()

        self.saved_actions = []
        self.rewards = []
        self.saved_log_probs = []

    def forward(self, observation):
        # state = torch.Tensor(observation).to(self.device)

        # shared convolutional trunk
        x = self.LeakyReLU(self.conv1(observation))
        x = self.pool(x)
        x = self.LeakyReLU(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)

        # actor head: action probabilities
        dist = self.LeakyReLU(self.fc_actor2(x))
        dist = self.Softplus(self.fc_actor3(dist))
        actor = F.softmax(dist, dim=1)

        # critic head: state value
        critic = self.fc_critic2(x)

        return actor, critic


# choose an action for the discrete action space
def choose_action(policy, observation):
    observation = observation.float().unsqueeze(0)
    # NHWC observation -> NCHW input expected by the conv layers
    state = torch.transpose(observation, 1, 3)
    probabilities, value = policy(state)
    # probabilities = F.softmax(probabilities)

    action_probs = Categorical(probabilities)
    action = action_probs.sample()

    log_probs = action_probs.log_prob(action)
    policy.saved_actions.append(SavedAction(log_probs, value))
    policy.saved_log_probs.append(log_probs)

    return action.item()


def learn(policy, optimizer, gamma):
    R = 0
    saved_actions = policy.saved_actions
    policy_losses = []  # list to save actor (policy) loss
    value_losses = []  # list to save critic (value) loss
    returns = deque()  # list to save the true values

    for r in policy.rewards[::-1]:
        # calculate the discounted value
        R = r + gamma * R
        returns.appendleft(R)

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)

    for (log_prob, value), R in zip(saved_actions, returns):
        advantage = R - value.item()

        # calculate actor (policy) loss
        policy_losses.append(-log_prob * advantage)

        # calculate critic (value) loss using L1 smooth loss
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))

    # reset gradients
    optimizer.zero_grad()

    # sum up all the values of policy_losses and value_losses
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

    # perform backprop
    loss.backward()
    optimizer.step()

    # reset rewards and action buffer
    del policy.rewards[:]
    del policy.saved_actions[:]


EPISODE_STARTED = Events.EPOCH_STARTED
EPISODE_COMPLETED = Events.EPOCH_COMPLETED


def run(local_rank: int, env: Any, config: Any):
    # make seed
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # create output folder and copy config file to output dir
    config.output_dir = setup_output_dir(config, rank)
    if rank == 0:
        copy(config.config, f"{config.output_dir}/config-lock.yaml")

    # create wrapper for saving video
    if config.render:

        def trigger(episode):
            return episode % config.save_every_episodes == 0

        env = gym.wrappers.RecordVideo(env, config.recordings_path, trigger)

    # device, policy, optimizer
    device = idist.device()
    actor_critic = ActorCriticNetwork(env.action_space.n).to(device)
    optimizer = idist.auto_optim(optim.Adam(actor_critic.parameters(), lr=config.lr, betas=(0.9, 0.999)))

    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # self.to(self.device)

    # maximum number of timesteps per episode (one Engine epoch)
    timesteps = range(10000)

    def run_single_timestep(engine, timestep):
        observation = engine.state.observation

        # select action from the policy
        observation = torch.Tensor(observation).to(device)
        action = choose_action(actor_critic, observation)

        engine.state.observation, reward, done, _, _ = env.step(action)

        if config.render:
            env.render()

        actor_critic.rewards.append(reward)
        engine.state.ep_reward += reward
        if done:
            engine.terminate_epoch()
            engine.state.timestep = timestep

    trainer = Engine(run_single_timestep)
    trainer.state.running_reward = 10

    @trainer.on(EPISODE_STARTED)
    def reset_environment_state():
        # reset environment and episode reward
        torch.manual_seed(config.seed + trainer.state.epoch)
        trainer.state.observation, _ = env.reset(seed=config.seed + trainer.state.epoch)
        trainer.state.ep_reward = 0

    @trainer.on(EPISODE_COMPLETED)
    def update_model():
        # update cumulative reward
        trainer.state.running_reward = 0.05 * trainer.state.ep_reward + (1 - 0.05) * trainer.state.running_reward
        # perform backprop
        learn(actor_critic, optimizer, config.gamma)

    @trainer.on(EPISODE_COMPLETED(every=config.log_every_episodes))
    def log_episode():
        i_episode = trainer.state.epoch
        print(
            f"Episode {i_episode}\tLast reward: {trainer.state.ep_reward:.2f}"
            f"\tAverage reward: {trainer.state.running_reward:.2f}"
        )

    @trainer.on(EPISODE_COMPLETED)
    def should_finish_training():
        # check if we have "solved" the environment (running reward above the env's reward threshold)
        running_reward = trainer.state.running_reward
        if running_reward > env.spec.reward_threshold:
            print(
                f"Solved! Running reward is now {running_reward} and "
                f"the last episode runs to {trainer.state.timestep} time steps!"
            )
            trainer.should_terminate = True

    trainer.run(timesteps, max_epochs=config.max_episodes)


def main():
    config = setup_config()
    env = gym.make("CarRacing-v2", continuous=False, render_mode="rgb_array" if config.render else None)
    with idist.Parallel(config.backend) as p:
        p.run(run, env=env, config=config)


if __name__ == "__main__":
    main()
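As a quick reference for reviewers, the discounted-return normalization inside learn() can be reproduced in isolation. The sketch below is illustrative only (the reward values are made up) and simply mirrors the loop above:

# Illustrative only: the same discounted-return computation used in learn(),
# applied to a toy reward sequence.
from collections import deque

import numpy as np
import torch

eps = np.finfo(np.float32).eps.item()
gamma = 0.99
rewards = [1.0, 0.0, 2.0, 1.0]  # hypothetical per-step rewards for one episode

R = 0.0
returns = deque()
for r in reversed(rewards):
    R = r + gamma * R  # G_t = r_t + gamma * G_{t+1}
    returns.appendleft(R)

returns = torch.tensor(list(returns))
# normalize to zero mean / unit variance, exactly as learn() does
returns = (returns - returns.mean()) / (returns.std() + eps)
print(returns)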
New YAML configuration file (9 additions, 0 deletions):
seed: 666
render: true
gamma: 0.99
recordings_path: ./recordings
lr: 0.0003
max_episodes: 10000
log_every_episodes: 1
save_every_episodes: 10
output_dir: ./logs
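These fields are consumed directly by the training script (gamma in learn(), lr for Adam, max_episodes as the trainer's max_epochs, and so on); config.backend and config.config, referenced in main() and run(), do not appear here and presumably come from setup_config() in utils.py. A minimal loading sketch, assuming PyYAML and a hypothetical file path (the template's own parsing lives in utils.py):

# Illustrative only: load the YAML above and sanity-check a few fields.
import yaml  # PyYAML

with open("config.yaml") as f:  # hypothetical path to the file above
    cfg = yaml.safe_load(f)

assert 0.0 < cfg["gamma"] <= 1.0  # discount factor passed to learn()
assert cfg["lr"] > 0  # Adam learning rate
print(cfg["max_episodes"], cfg["log_every_episodes"], cfg["output_dir"])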
@theory-in-progress @vfdev-5 I think it would be better if we defined the network in a separate file so that it can be edited easily. Also, I was looking at torchrl, and I think it provides some abstractions over the contents of ActorCriticNetwork. Should we use it?
Yes, good point. I had also been thinking that we could use torchrl. Thanks!
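A minimal sketch of the first suggestion, assuming the class is moved verbatim into a sibling model.py (the file name and the helper below are hypothetical, not part of this PR):

# advantage_actor_critic_a2c.py after the hypothetical split: the network is
# imported from its own module instead of being defined inline.
from model import ActorCriticNetwork  # model.py would hold the class from the diff above

import ignite.distributed as idist
import torch.optim as optim


def build_actor_critic(n_actions: int, lr: float):
    """Construct the policy/value network and its optimizer, as run() currently does."""
    device = idist.device()
    actor_critic = ActorCriticNetwork(n_actions).to(device)
    optimizer = idist.auto_optim(optim.Adam(actor_critic.parameters(), lr=lr, betas=(0.9, 0.999)))
    return actor_critic, optimizer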