Commit 1813e8e

[BugFix] Fix shifted value computation with an LSTM
ghstack-source-id: 9ccbf82
Pull Request resolved: #2941

1 parent 231555d

6 files changed: +274 −61 lines


test/test_cost.py

+80 −1

@@ -46,16 +46,20 @@
 from torchrl._utils import _standardize
 from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
 from torchrl.data.postprocs.postprocs import MultiStep
-from torchrl.envs import EnvBase
+from torchrl.envs import EnvBase, GymEnv, InitTracker, SerialEnv
+from torchrl.envs.libs.gym import _has_gym
 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
 from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
 from torchrl.modules import (
     DistributionalQValueActor,
+    GRUModule,
+    LSTMModule,
     OneHotCategorical,
     QValueActor,
     recurrent_mode,
     SafeSequential,
+    set_recurrent_mode,
     WorldModelWrapper,
 )
 from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
@@ -146,6 +150,7 @@
         dtype_fixture,
         get_available_devices,
         get_default_devices,
+        PENDULUM_VERSIONED,
     )
     from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv
 else:
@@ -154,6 +159,7 @@
         dtype_fixture,
         get_available_devices,
         get_default_devices,
+        PENDULUM_VERSIONED,
     )
     from mocking_classes import ContinuousActionConvMockEnv

@@ -13755,6 +13761,79 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:


 class TestValues:
+    @pytest.mark.skipif(not _has_gym, reason="requires gym")
+    @pytest.mark.parametrize("module", ["lstm", "gru"])
+    def test_gae_recurrent(self, module):
+        # Checks that shifted=True and shifted=False give the same GAE result
+        # when a recurrent module (LSTM or GRU) is used
+        env = SerialEnv(
+            2,
+            [
+                functools.partial(
+                    TransformedEnv, GymEnv(PENDULUM_VERSIONED()), InitTracker()
+                )
+                for _ in range(2)
+            ],
+        )
+        env.set_seed(0)
+        torch.manual_seed(0)
+        if module == "lstm":
+            recurrent_module = LSTMModule(
+                input_size=env.observation_spec["observation"].shape[-1],
+                hidden_size=64,
+                in_keys=["observation", "rs_h", "rs_c"],
+                out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
+                python_based=True,
+                dropout=0,
+            )
+        elif module == "gru":
+            recurrent_module = GRUModule(
+                input_size=env.observation_spec["observation"].shape[-1],
+                hidden_size=64,
+                in_keys=["observation", "rs_h"],
+                out_keys=["intermediate", ("next", "rs_h")],
+                python_based=True,
+                dropout=0,
+            )
+        else:
+            raise NotImplementedError
+        recurrent_module.eval()
+        mlp_value = MLP(num_cells=[64], out_features=1)
+        value_net = Seq(
+            recurrent_module,
+            Mod(mlp_value, in_keys=["intermediate"], out_keys=["state_value"]),
+        )
+        mlp_policy = MLP(num_cells=[64], out_features=1)
+        policy_net = Seq(
+            recurrent_module,
+            Mod(mlp_policy, in_keys=["intermediate"], out_keys=["action"]),
+        )
+        env = env.append_transform(recurrent_module.make_tensordict_primer())
+        vals = env.rollout(1000, policy_net, break_when_any_done=False)
+        value_net(vals.copy())
+
+        # Shifted: values and next values come from a single value-network call
+        gae_shifted = GAE(
+            gamma=0.9,
+            lmbda=0.99,
+            value_network=value_net,
+            shifted=True,
+        )
+        with set_recurrent_mode(True):
+            r0 = gae_shifted(vals.copy())
+            a0 = r0["advantage"]
+
+        # Not shifted: root and "next" values are computed separately
+        gae = GAE(
+            gamma=0.9,
+            lmbda=0.99,
+            value_network=value_net,
+            shifted=False,
+            deactivate_vmap=True,
+        )
+        with set_recurrent_mode(True):
+            r1 = gae(vals.copy())
+            a1 = r1["advantage"]
+        torch.testing.assert_close(a0, a1)
+
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99])
     @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99])

test/test_utils.py

+29 −0

@@ -16,6 +16,8 @@

 import torch

+from torchrl.objectives.utils import _pseudo_vmap
+
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import capture_log_records, get_default_devices
 else:
@@ -416,6 +418,33 @@ def str_to_tensor(s):
     assert len(records) == 1


+def add_one(x):
+    return x + 1
+
+
+@pytest.mark.parametrize("in_dim, out_dim", [(0, 0), (0, 1), (1, 0), (1, 1)])
+def test_vmap_in_out_dims(in_dim, out_dim):
+    # Create a tensor with a batch dimension
+    x = torch.arange(10).reshape(2, 5)
+    # Move the batch dimension to match in_dim
+    x_moved = torch.moveaxis(x, 0, in_dim)
+    # Apply vmap with the specified in_dims and out_dims
+    vmapped_add_one = torch.vmap(add_one, in_dims=in_dim, out_dims=out_dim)
+    actual_result = vmapped_add_one(x_moved)
+    pseudo_vmapped_add_one = _pseudo_vmap(add_one, in_dims=in_dim, out_dims=out_dim)
+    pseudo_actual_result = pseudo_vmapped_add_one(x_moved)
+
+    # Expected result: add_one applied to each batch element separately
+    expected_result = x + 1
+    # Move the output dimension back to match the expected result
+    if out_dim == 1:
+        actual_result = torch.moveaxis(actual_result, out_dim, 0)
+        pseudo_actual_result = torch.moveaxis(pseudo_actual_result, out_dim, 0)
+    # Assert the results are as expected
+    assert torch.allclose(actual_result, expected_result)
+    assert torch.allclose(pseudo_actual_result, expected_result)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
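As a quick illustration of the in_dims/out_dims semantics the new test exercises (a standalone sketch, independent of the test file):

import torch

def add_one(x):
    return x + 1

x = torch.arange(10).reshape(2, 5)
# in_dims=1: map over dim 1 (5 slices, each of shape (2,));
# out_dims=0: stack the 5 results along dim 0 -> output shape (5, 2).
y = torch.vmap(add_one, in_dims=1, out_dims=0)(x)
assert y.shape == (5, 2)
assert torch.equal(y, (x + 1).T)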

torchrl/data/tensor_specs.py

+4 −4

@@ -571,7 +571,7 @@ class TensorSpec(metaclass=abc.ABCMeta):
     shape: torch.Size
     space: None | Box
     device: torch.device | None = None
-    dtype: torch.dtype = torch.float
+    dtype: torch.dtype = torch.get_default_dtype()
     domain: str = ""
     _encode_memo_dict: dict[Any, Callable[[Any], Any]] = field(
         default_factory=dict,
@@ -1682,7 +1682,7 @@ class OneHot(TensorSpec):
     shape: torch.Size
     space: CategoricalBox
     device: torch.device | None = None
-    dtype: torch.dtype = torch.float
+    dtype: torch.dtype = torch.get_default_dtype()
     domain: str = ""
     _encode_memo_dict: dict[Any, Callable[[Any], Any]] = field(
         default_factory=dict,
@@ -2067,7 +2067,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):

     def _project(self, val: torch.Tensor) -> torch.Tensor:
         if self.mask is None:
-            out = torch.multinomial(val.to(torch.float), 1).squeeze(-1)
+            out = torch.multinomial(val.to(torch.get_default_dtype()), 1).squeeze(-1)
             out = torch.nn.functional.one_hot(out, self.space.n).to(self.dtype)
             return out
         shape = self.mask.shape
@@ -3735,7 +3735,7 @@ class Categorical(TensorSpec):
     shape: torch.Size
     space: CategoricalBox
     device: torch.device | None = None
-    dtype: torch.dtype = torch.float
+    dtype: torch.dtype = torch.get_default_dtype()
     domain: str = ""

     # SPEC_HANDLED_FUNCTIONS = {}
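The motivation for dropping the hard-coded torch.float: it always means float32 and ignores the globally configured default dtype. A minimal sketch of the distinction (plain PyTorch, no TorchRL-specific assumptions):

import torch

torch.set_default_dtype(torch.float64)
assert torch.float == torch.float32                 # hard-coded: always float32
assert torch.get_default_dtype() == torch.float64   # tracks the global setting
torch.set_default_dtype(torch.float32)              # restore the usual default

One subtlety: as a dataclass field default, torch.get_default_dtype() is evaluated when the class body runs (at import time), so it reflects the default dtype in effect then rather than at spec-construction time.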

torchrl/modules/tensordict_module/rnn.py

+10 −4

@@ -730,10 +730,16 @@ def forward(self, tensordict: TensorDictBase):
         # packed sequences do not help to get the accurate last hidden values
         # if splits is not None:
         #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
-        if hidden0 is not None:
+
+        if not self.recurrent_mode and hidden0 is not None:
+            # We zero the hidden states when the lstm is called step by step,
+            # as we assume a non-zero hidden state comes from the previous trajectory.
+            # With the recurrent_mode=True option, the lstm can be called from
+            # any intermediate state, hence zeroing should not be done.
             is_init_expand = expand_as_right(is_init, hidden0)
             hidden0 = torch.where(is_init_expand, 0, hidden0)
             hidden1 = torch.where(is_init_expand, 0, hidden1)
+
         val, hidden0, hidden1 = self._lstm(
             value, batch, steps, device, dtype, hidden0, hidden1
         )
@@ -782,8 +788,8 @@ def _lstm(
         )

         # we only need the first hidden state
-        _hidden0_in = hidden0_in[:, 0]
-        _hidden1_in = hidden1_in[:, 0]
+        _hidden0_in = hidden0_in[..., 0, :, :]
+        _hidden1_in = hidden1_in[..., 0, :, :]
         hidden = (
             _hidden0_in.transpose(-3, -2).contiguous(),
             _hidden1_in.transpose(-3, -2).contiguous(),
@@ -1517,7 +1523,7 @@ def forward(self, tensordict: TensorDictBase):
         # packed sequences do not help to get the accurate last hidden values
         # if splits is not None:
         #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
-        if is_init.any() and hidden is not None:
+        if not self.recurrent_mode and is_init.any() and hidden is not None:
             is_init_expand = expand_as_right(is_init, hidden)
             hidden = torch.where(is_init_expand, 0, hidden)
         val, hidden = self._gru(value, batch, steps, device, dtype, hidden)
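Two things change here. First, the hidden-state zeroing at is_init steps is now skipped in recurrent mode: when the module runs step by step during collection, a non-zero incoming hidden state is assumed to leak from the previous trajectory, but in recurrent mode the caller may legitimately start from any intermediate state (as the value estimator does), so the state must be left untouched. Second, indexing the first time step with [..., 0, :, :] instead of [:, 0] keeps working when extra leading batch dimensions are present. A hedged sketch of how the mode is toggled (mirroring the new test; keys and sizes are illustrative):

import torch
from torchrl.modules import LSTMModule, set_recurrent_mode

lstm = LSTMModule(
    input_size=3,
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
)
# Default (step-by-step) mode: hidden states are zeroed at is_init steps.
with set_recurrent_mode(True):
    # Sequence mode: whole trajectories are processed in one call and,
    # after this fix, incoming hidden states are used as-is.
    pass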

torchrl/objectives/utils.py

+56 −3

@@ -8,14 +8,15 @@
 import re
 import warnings
 from enum import Enum
-from typing import Iterable
+from typing import Any, Callable, Iterable

 import torch
 from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import TensorDictModule
 from torch import nn, Tensor
 from torch.nn import functional as F
 from torch.nn.modules import dropout
+from torch.utils._pytree import tree_map

 try:
     from torch import vmap
@@ -527,7 +528,7 @@ def new_func(self, netname=None):
     return new_func


-def _vmap_func(module, *args, func=None, **kwargs):
+def _vmap_func(module, *args, func=None, pseudo_vmap: bool = False, **kwargs):
     try:

         def decorated_module(*module_args_params):
@@ -539,7 +540,9 @@ def decorated_module(*module_args_params):
             else:
                 return getattr(module, func)(*module_args)

-        return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+        if not pseudo_vmap:
+            return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+        return _pseudo_vmap(decorated_module, *args, **kwargs)

     except RuntimeError as err:
         if re.match(
@@ -550,6 +553,56 @@
            ) from err


+def _pseudo_vmap(
+    func: Callable,
+    in_dims: Any = 0,
+    out_dims: Any = 0,
+    randomness: str | None = None,
+    *,
+    chunk_size=None,
+):
+    if randomness is not None and randomness not in ("different", "error"):
+        raise ValueError(
+            f"pseudo_vmap only supports 'different' or 'error' randomness modes, but got {randomness=}. If another mode is required, please "
+            "submit an issue in TorchRL."
+        )
+    if isinstance(in_dims, int):
+        in_dims = (in_dims,)
+    if isinstance(out_dims, int):
+        out_dims = (out_dims,)
+    from tensordict.nn.functional_modules import _exclude_td_from_pytree
+
+    def _unbind(d, x):
+        if d is not None and hasattr(x, "unbind"):
+            return x.unbind(d)
+        # Generator that repeats the value for non-batched arguments
+        return (x for _ in range(1000))
+
+    def _stack(d, x):
+        if d is not None:
+            return torch.stack(list(x), d)
+        return x
+
+    @functools.wraps(func)
+    def new_func(*args, **kwargs):
+        with _exclude_td_from_pytree():
+            # Unbind the inputs along their batch dims and iterate over the slices
+            vs = zip(*tuple(tree_map(_unbind, in_dims, args)))
+            rs = []
+            for v in vs:
+                r = func(*v, **kwargs)
+                if not isinstance(r, tuple):
+                    r = (r,)
+                rs.append(r)
+            # Re-stack the outputs along the requested out_dims
+            rs = tuple(zip(*rs))
+            vs = tuple(tree_map(_stack, out_dims, rs))
+            if len(vs) == 1:
+                return vs[0]
+            return vs
+
+    return new_func
+
+
 def _reduce(tensor: torch.Tensor, reduction: str) -> float | torch.Tensor:
     """Reduces a tensor given the reduction method."""
     if reduction == "none":
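_pseudo_vmap emulates the torch.vmap interface with a plain Python loop: it unbinds the inputs along in_dims, applies the function slice by slice, and re-stacks the outputs along out_dims. The point, presumably, is to offer a fallback when vmap cannot be used (the new test pairs deactivate_vmap=True with a stateful recurrent value network). A small usage sketch, mirroring the new test in test/test_utils.py:

import torch
from torchrl.objectives.utils import _pseudo_vmap

def add_one(x):
    return x + 1

x = torch.arange(10).reshape(2, 5)
# Both transforms map add_one over dim 0 and agree on the result;
# _pseudo_vmap trades vmap's fused execution for an explicit loop.
assert torch.equal(torch.vmap(add_one)(x), _pseudo_vmap(add_one)(x))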
