Skip to content

Commit 5056a62

Browse files
committed
[Feature] IsaacLab wrapper
ghstack-source-id: eb3721a Pull-Request-resolved: #2937
1 parent 01399e0 commit 5056a62

File tree

13 files changed

+375
-58
lines changed

13 files changed

+375
-58
lines changed

.github/unittest/linux_libs/scripts_gym/setup_env.sh

-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ set -e
1010
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
1111
# Avoid error: "fatal: unsafe repository"
1212
apt-get update && apt-get install -y git wget gcc g++
13-
1413
apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
1514
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev
1615

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#!/usr/bin/env bash
#
# CI script for the IsaacLab wrapper tests.
#
# Steps, in order:
#   1. install system packages (git, compilers, GL/EGL libraries);
#   2. install Miniconda and create a conda env under <repo>/env;
#   3. install a pinned pytorch, isaacsim and IsaacLab into that env;
#   4. install tensordict and torchrl;
#   5. run the tests in test/test_libs.py selected by `-k isaac`.

# Strict mode is enabled once, before anything else runs, so that the exports
# below are covered by -u / pipefail as well.
set -euo pipefail
# Echo every line as it is read, which makes the CI log self-describing.
set -v

#if [[ "${{ github.ref }}" =~ release/* ]]; then
#  export RELEASE=1
#  export TORCH_VERSION=stable
#else
export RELEASE=0
# Not referenced again in this script: pytorch is pinned explicitly below.
export TORCH_VERSION=nightly
#fi

export PYTHON_VERSION="3.10"
export CU_VERSION="12.8"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
export TD_GET_DEFAULTS_TO_NONE=1
export OMNI_KIT_ACCEPT_EULA=yes

nvidia-smi

# Setup
apt-get update && apt-get install -y git wget gcc g++
apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev

# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

cd "${root_dir}"

# install conda
printf "* Installing conda\n"
# Fetch the installer over HTTPS: it is executed right after the download.
wget -O miniconda.sh "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh"
bash ./miniconda.sh -b -f -p "${conda_dir}"
eval "$("${conda_dir}/bin/conda" shell.bash hook)"

# Reuse the exported PYTHON_VERSION so the interpreter version is set in one place.
conda create --prefix "${env_dir}" python="${PYTHON_VERSION}" -y
conda activate "${env_dir}"

# Pin pytorch to 2.5.1 for IsaacLab
conda install pytorch==2.5.1 torchvision==0.20.1 pytorch-cuda=12.4 -c pytorch -c nvidia -y

conda run -p "${env_dir}" pip install --upgrade pip
conda run -p "${env_dir}" pip install 'isaacsim[all,extscache]==4.5.0' --extra-index-url https://pypi.nvidia.com
conda install conda-forge::"cmake>3.22" -y

git clone https://github.com/isaac-sim/IsaacLab.git
cd IsaacLab
conda run -p "${env_dir}" ./isaaclab.sh --install sb3
cd ../

# install tensordict
if [[ "$RELEASE" == 0 ]]; then
  # NOTE(review): cmake was already installed from conda-forge above; this
  # second install from the anaconda channel looks redundant - confirm before
  # removing.
  conda install "anaconda::cmake>=3.22" -y
  conda run -p "${env_dir}" python3 -m pip install "pybind11[global]"
  conda run -p "${env_dir}" python3 -m pip install git+https://github.com/pytorch/tensordict.git
else
  conda run -p "${env_dir}" python3 -m pip install tensordict
fi

# smoke test
conda run -p "${env_dir}" python -c "import tensordict"

printf "* Installing torchrl\n"
conda run -p "${env_dir}" python setup.py develop
conda run -p "${env_dir}" python -c "import torchrl"

# Install pytest
conda run -p "${env_dir}" python -m pip install pytest pytest-cov pytest-mock pytest-instafail pytest-rerunfailures pytest-error-for-skips pytest-asyncio

# Run tests
conda run -p "${env_dir}" python -m pytest test/test_libs.py -k isaac -s

.github/workflows/test-linux-libs.yml

+18
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,24 @@ jobs:
230230
./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
231231
./.github/unittest/linux_libs/scripts_gym/post_process.sh
232232
233+
unittests-isaaclab:
234+
strategy:
235+
matrix:
236+
python_version: ["3.10"]
237+
cuda_arch_version: ["12.8"]
238+
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments/Isaac') }}
239+
uses: vmoens/test-infra/.github/workflows/isaac_linux_job_v2.yml@main
240+
with:
241+
repository: pytorch/rl
242+
runner: "linux.g5.4xlarge.nvidia.gpu"
243+
docker-image: "nvcr.io/nvidia/isaac-lab:2.1.0"
244+
test-infra-repository: vmoens/test-infra
245+
gpu-arch-type: cuda
246+
gpu-arch-version: "12.8"
247+
timeout: 120
248+
script: |
249+
./.github/unittest/linux_libs/scripts_isaaclab/isaac.sh
250+
233251
unittests-jumanji:
234252
strategy:
235253
matrix:

docs/source/reference/envs.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ the following function will return ``1`` when queried:
14171417
HabitatEnv
14181418
IsaacGymEnv
14191419
IsaacGymWrapper
1420+
IsaacLabWrapper
14201421
JumanjiEnv
14211422
JumanjiWrapper
14221423
MeltingpotEnv

test/test_libs.py

+89-26
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,6 @@
3232
import pytest
3333
import torch
3434

35-
if os.getenv("PYTORCH_TEST_FBCODE"):
36-
from pytorch.rl.test._utils_internal import (
37-
_make_multithreaded_env,
38-
CARTPOLE_VERSIONED,
39-
get_available_devices,
40-
get_default_devices,
41-
HALFCHEETAH_VERSIONED,
42-
PENDULUM_VERSIONED,
43-
PONG_VERSIONED,
44-
rand_reset,
45-
retry,
46-
rollout_consistency_assertion,
47-
)
48-
else:
49-
from _utils_internal import (
50-
_make_multithreaded_env,
51-
CARTPOLE_VERSIONED,
52-
get_available_devices,
53-
get_default_devices,
54-
HALFCHEETAH_VERSIONED,
55-
PENDULUM_VERSIONED,
56-
PONG_VERSIONED,
57-
rand_reset,
58-
retry,
59-
rollout_consistency_assertion,
60-
)
6135
from packaging import version
6236
from tensordict import (
6337
assert_allclose_td,
@@ -155,6 +129,33 @@
155129
ValueOperator,
156130
)
157131

132+
if os.getenv("PYTORCH_TEST_FBCODE"):
133+
from pytorch.rl.test._utils_internal import (
134+
_make_multithreaded_env,
135+
CARTPOLE_VERSIONED,
136+
get_available_devices,
137+
get_default_devices,
138+
HALFCHEETAH_VERSIONED,
139+
PENDULUM_VERSIONED,
140+
PONG_VERSIONED,
141+
rand_reset,
142+
retry,
143+
rollout_consistency_assertion,
144+
)
145+
else:
146+
from _utils_internal import (
147+
_make_multithreaded_env,
148+
CARTPOLE_VERSIONED,
149+
get_available_devices,
150+
get_default_devices,
151+
HALFCHEETAH_VERSIONED,
152+
PENDULUM_VERSIONED,
153+
PONG_VERSIONED,
154+
rand_reset,
155+
retry,
156+
rollout_consistency_assertion,
157+
)
158+
158159
_has_d4rl = importlib.util.find_spec("d4rl") is not None
159160

160161
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
@@ -166,6 +167,9 @@
166167
_has_minari = importlib.util.find_spec("minari") is not None
167168

168169
_has_gymnasium = importlib.util.find_spec("gymnasium") is not None
170+
171+
_has_isaaclab = importlib.util.find_spec("isaaclab") is not None
172+
169173
_has_gym_regular = importlib.util.find_spec("gym") is not None
170174
if _has_gymnasium:
171175
set_gym_backend("gymnasium").set()
@@ -4541,6 +4545,65 @@ def test_render(self, rollout_steps):
45414545
assert not torch.equal(rollout_penultimate_image, image_from_env)
45424546

45434547

4548+
@pytest.mark.skipif(not _has_isaaclab, reason="Isaaclab not found")
class TestIsaacLab:
    """Tests for ``IsaacLabWrapper`` on the ``Isaac-Ant-v0`` task.

    One wrapped environment is created per class (see the ``env`` fixture)
    and shared by every test below.
    """

    @pytest.fixture(scope="class")
    def env(self):
        """Yield a wrapped ``Isaac-Ant-v0`` environment and close it afterwards."""
        torch.manual_seed(0)
        import argparse

        # The Isaac app has to be launched (headless here) before the
        # isaaclab task modules are imported further down.
        from isaaclab.app import AppLauncher

        cli = argparse.ArgumentParser(description="Train an RL agent with TorchRL.")
        AppLauncher.add_app_launcher_args(cli)
        launcher_args, _unparsed = cli.parse_known_args(["--headless"])
        AppLauncher(launcher_args)

        # Imports and env
        import gymnasium as gym
        import isaaclab_tasks  # noqa: F401
        from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
        from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

        torchrl_logger.info("Making IsaacLab env...")
        handle = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
        torchrl_logger.info("Wrapping IsaacLab env...")
        try:
            handle = IsaacLabWrapper(handle)
            yield handle
        finally:
            # Closes whichever object ``handle`` is bound to: the wrapper, or
            # the raw gym env if wrapping itself raised.
            torchrl_logger.info("Closing IsaacLab env...")
            handle.close()
            torchrl_logger.info("Closed")

    def test_isaaclab(self, env):
        """The env reports a batch size of 4096, is batched, and passes the spec check."""
        assert env.batch_size == (4096,)
        assert env._is_batched
        torchrl_logger.info("Checking env specs...")
        env.check_env_specs(break_when_any_done="both")
        torchrl_logger.info("Check succeeded!")

    def test_isaac_collector(self, env):
        """The first batch produced by a ``SyncDataCollector`` has shape ``(4096, 1)``."""
        collector = SyncDataCollector(
            env,
            env.rand_action,
            frames_per_batch=1000,
            total_frames=100_000_000,
        )
        try:
            # Only the first batch is inspected.
            for batch in collector:
                assert batch.shape == (4096, 1)
                break
        finally:
            # We must do that, otherwise `__del__` calls `shutdown` and the next test will fail
            collector.shutdown(close_env=False)

    def test_isaaclab_reset(self, env):
        """Observations stored at done steps contain no finite value."""
        # Make a rollout that will stop as soon as a trajectory reaches a done state
        rollout = env.rollout(1_000_000)

        done_mask = rollout["next", "done"].squeeze(-1)
        obs_at_done = rollout["next", "policy"][done_mask]
        assert not obs_at_done.isfinite().any()
4605+
4606+
45444607
if __name__ == "__main__":
45454608
args, unknown = argparse.ArgumentParser().parse_known_args()
45464609
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/collectors/collectors.py

+43-14
Original file line numberDiff line numberDiff line change
@@ -278,15 +278,19 @@ def pause(self):
278278
f"Collector pause() is not implemented for {type(self).__name__}."
279279
)
280280

281-
def async_shutdown(self, timeout: float | None = None) -> None:
281+
def async_shutdown(
282+
self, timeout: float | None = None, close_env: bool = True
283+
) -> None:
282284
"""Shuts down the collector when started asynchronously with the `start` method.
283285
284286
Args:
285287
timeout (float, optional): The maximum time to wait for the collector to shutdown.
288+
close_env (bool, optional): If True, the collector will close the contained environment.
289+
Defaults to `True`.
286290
287291
.. seealso:: :meth:`~.start`
288292
"""
289-
return self.shutdown(timeout=timeout)
293+
return self.shutdown(timeout=timeout, close_env=close_env)
290294

291295
def update_policy_weights_(
292296
self,
@@ -342,7 +346,7 @@ def next(self):
342346
return None
343347

344348
@abc.abstractmethod
345-
def shutdown(self, timeout: float | None = None) -> None:
349+
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
346350
raise NotImplementedError
347351

348352
@abc.abstractmethod
@@ -1317,12 +1321,14 @@ def _run_iterator(self):
13171321
if self._stop:
13181322
return
13191323

1320-
def async_shutdown(self, timeout: float | None = None) -> None:
1324+
def async_shutdown(
1325+
self, timeout: float | None = None, close_env: bool = True
1326+
) -> None:
13211327
"""Finishes processes started by ray.init() during async execution."""
13221328
self._stop = True
13231329
if hasattr(self, "_thread") and self._thread.is_alive():
13241330
self._thread.join(timeout=timeout)
1325-
self.shutdown()
1331+
self.shutdown(close_env=close_env)
13261332

13271333
def _postproc(self, tensordict_out):
13281334
if self.split_trajs:
@@ -1582,14 +1588,20 @@ def reset(self, index=None, **kwargs) -> None:
15821588
)
15831589
self._shuttle["collector"] = collector_metadata
15841590

1585-
def shutdown(self, timeout: float | None = None) -> None:
1586-
"""Shuts down all workers and/or closes the local environment."""
1591+
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
1592+
"""Shuts down all workers and/or closes the local environment.
1593+
1594+
Args:
1595+
timeout (float, optional): The timeout for closing pipes between workers.
1596+
No effect for this class.
1597+
close_env (bool, optional): Whether to close the environment. Defaults to `True`.
1598+
"""
15871599
if not self.closed:
15881600
self.closed = True
15891601
del self._shuttle
15901602
if self._use_buffers:
15911603
del self._final_rollout
1592-
if not self.env.is_closed:
1604+
if close_env and not self.env.is_closed:
15931605
self.env.close()
15941606
del self.env
15951607
return
@@ -2391,8 +2403,17 @@ def __del__(self):
23912403
# __del__ will not affect the program.
23922404
pass
23932405

2394-
def shutdown(self, timeout: float | None = None) -> None:
2395-
"""Shuts down all processes. This operation is irreversible."""
2406+
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
2407+
"""Shuts down all processes. This operation is irreversible.
2408+
2409+
Args:
2410+
timeout (float, optional): The timeout for closing pipes between workers.
2411+
close_env (bool, optional): Whether to close the environment. Defaults to `True`.
2412+
"""
2413+
if not close_env:
2414+
raise RuntimeError(
2415+
f"Cannot shutdown {type(self).__name__} collector without environment being closed."
2416+
)
23962417
self._shutdown_main(timeout)
23972418

23982419
def _shutdown_main(self, timeout: float | None = None) -> None:
@@ -2665,7 +2686,11 @@ def next(self):
26652686
return super().next()
26662687

26672688
# for RPC
2668-
def shutdown(self, timeout: float | None = None) -> None:
2689+
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
2690+
if not close_env:
2691+
raise RuntimeError(
2692+
f"Cannot shutdown {type(self).__name__} collector without environment being closed."
2693+
)
26692694
if hasattr(self, "out_buffer"):
26702695
del self.out_buffer
26712696
if hasattr(self, "buffers"):
@@ -3038,9 +3063,13 @@ def next(self):
30383063
return super().next()
30393064

30403065
# for RPC
3041-
def shutdown(self, timeout: float | None = None) -> None:
3066+
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
30423067
if hasattr(self, "out_tensordicts"):
30433068
del self.out_tensordicts
3069+
if not close_env:
3070+
raise RuntimeError(
3071+
f"Cannot shutdown {type(self).__name__} collector without environment being closed."
3072+
)
30443073
return super().shutdown(timeout=timeout)
30453074

30463075
# for RPC
@@ -3382,8 +3411,8 @@ def next(self):
33823411
return super().next()
33833412

33843413
# for RPC
3385-
def shutdown(self, timeout: float | None = None) -> None:
3386-
return super().shutdown(timeout=timeout)
3414+
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
3415+
return super().shutdown(timeout=timeout, close_env=close_env)
33873416

33883417
# for RPC
33893418
def set_seed(self, seed: int, static_seed: bool = False) -> int:

0 commit comments

Comments
 (0)