[BUG] Memory Leak in BraxEnv with requires_grad=True #2837

Open
mondeg0 opened this issue Mar 7, 2025 · 2 comments
mondeg0 commented Mar 7, 2025

Describe the bug

When using BraxEnv with requires_grad=True, there appears to be a memory leak on the CPU side. Memory usage keeps increasing over time, which can be observed with tools like htop. This happens even when explicitly detaching and cloning next_td and calling backward() to release the graph.

To Reproduce

Simply run the following code and observe the increase in memory usage.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from brax import envs
from torchrl.envs.libs.brax import BraxEnv

class Actor(nn.Module):
    def __init__(self, obs_size, action_size, hidden_size, policy_std_init=0.05):
        super(Actor, self).__init__()
        self.mu_net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, obs):
        loc = self.mu_net(obs)
        return torch.tanh(loc)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
env_name = "hopper"
seed = 13
num_envs = 2
epochs = 10000

policy = Actor(11, 3, 64).to(device)
optimizer = optim.Adam(policy.parameters(), lr=1e-4)  # renamed to avoid shadowing torch.optim

env = BraxEnv(env_name, batch_size=[num_envs], requires_grad=True, device=device)
env.set_seed(seed)

next_td = env.reset()
for i in range(epochs):
    print(i)
    next_td["action"] = policy(next_td["observation"])
    out_td, next_td = env.step_and_maybe_reset(next_td)

    if out_td["next", "done"].any():
        loss = out_td["next", "observation"].sum()  # just for demonstration purpose
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        next_td = next_td.detach().clone()

Expected behavior

The memory usage should remain stable over time instead of continuously increasing.

System info

Installation was done with pip.

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

outputs:

>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.1 2.2.3 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] linux

Additional context

This issue seems to be related to the way BraxEnv manages memory when requires_grad=True. Even though next_td is detached and cloned, memory continues to accumulate on the CPU.
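
To put numbers on that growth rather than eyeballing htop, a small helper along the following lines could be dropped into the training loop (a minimal sketch; psutil is an extra dependency and is not part of the repro above):

import os
import psutil

proc = psutil.Process(os.getpid())

def rss_mb() -> float:
    # Resident set size (CPU memory) of the current process, in MB.
    return proc.memory_info().rss / 1024 ** 2

# e.g. inside the training loop:
#     if i % 100 == 0:
#         print(f"step {i}: {rss_mb():.2f} MB")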

Reason and Possible fixes

Some list or buffer is accumulating things under the hood?

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)
mondeg0 added the bug label Mar 7, 2025
vmoens (Contributor) commented Mar 7, 2025

Will look into this, thanks for reporting!

jacr13 commented Mar 20, 2025

It seems that deleting ctx.vjp_fn before returning the gradients helps prevent the memory from exploding on the CPU:

        grads = (grad_action, *grad_state_qp.values())

        del ctx.vjp_fn  # Free memory
        return (None, None, *grads)
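
For context, below is a minimal, self-contained sketch of the pattern involved, using a toy function rather than TorchRL's actual BraxEnv code: a torch.autograd.Function that stores the jax.vjp pullback on ctx keeps every residual buffer captured by that closure alive for as long as the graph node lives, so dropping it at the end of backward releases that memory as early as possible.

import jax
import jax.numpy as jnp
import numpy as np
import torch

def jax_fn(x):
    # Toy stand-in for the (vmapped) Brax step.
    return jnp.tanh(x) * 2.0

class JaxBridge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x_jax = jnp.asarray(x.detach().cpu().numpy())
        out, vjp_fn = jax.vjp(jax_fn, x_jax)
        # The pullback closure captures the residuals of the forward
        # pass; storing it on ctx keeps those buffers alive until the
        # graph node itself is destroyed.
        ctx.vjp_fn = vjp_fn
        return torch.tensor(np.asarray(out))

    @staticmethod
    def backward(ctx, grad_out):
        (grad_x,) = ctx.vjp_fn(jnp.asarray(grad_out.detach().cpu().numpy()))
        del ctx.vjp_fn  # free the residuals as soon as they have been used
        return torch.tensor(np.asarray(grad_x))

x = torch.randn(4, requires_grad=True)
JaxBridge.apply(x).sum().backward()

In the Brax wrapper the pullback presumably captures the whole batched pipeline state as well, which would explain why freeing it eagerly makes such a visible difference on the CPU side.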

Another thing that seems to help is to jit the call to vjp:

    def _step_and_vjp(self, state, actions):
        new_state, vjp_fn = self.jax.vjp(self._vmap_env_step, state, actions)
        return new_state, vjp_fn

    def _init_env(self) -> int | None:
        jax = self.jax
        self._key = None
        self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset))
        if self.requires_grad:
            self._vmap_env_step = jax.vmap(self._env.step)
            self._vmap_jit_env_step = jax.jit(self._step_and_vjp)
        else:
            self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step))
        self._state_example = self._make_state_example()
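
Presumably, keeping the vjp inside a single jitted function lets XLA reuse the pullback's buffers across calls instead of re-allocating them on every step, which would also explain why it reduces, but does not fully eliminate, the growth. With this change the forward pass calls the jitted _step_and_vjp instead of invoking jax.vjp itself; the full diff is in the attached libs_brax.txt below.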

But there is still a small increase in memory; see this example of the output:

[MEMORY] 949 - step -: 2788.57 MB (+0.12 MB)
[MEMORY]                in forward call vjp: 2789.70 MB (+1.12 MB)
[MEMORY]                in forward to td: 2789.82 MB (+0.12 MB)
[MEMORY]                in forward tensors to ndarrays: 2789.95 MB (+0.12 MB)
[MEMORY]        960 - loss back -: 2790.07 MB (+0.12 MB)
[MEMORY]                in forward to td: 2790.20 MB (+0.12 MB)
[MEMORY]                in forward call vjp: 2792.07 MB (+1.88 MB)
[MEMORY] 969 - step -: 2792.20 MB (+0.12 MB)
[MEMORY]                in forward to td: 2792.32 MB (+0.12 MB)
[MEMORY]                in forward flatten batch size: 2792.45 MB (+0.12 MB)
[MEMORY]                in forward to td: 2792.57 MB (+0.12 MB)
[MEMORY]        990 - loss back -: 2792.70 MB (+0.12 MB)
[MEMORY]                in forward to td: 2792.82 MB (+0.12 MB)
[MEMORY]                in forward to td: 2792.95 MB (+0.12 MB)

I attached the files to reproduce.

main.txt
libs_brax.txt
