Skip to content

Commit 231555d

Browse files
committed
[BugFix] Better make_composite_from_td
ghstack-source-id: c99dfa6 Pull-Request-resolved: #2952
1 parent 5056a62 commit 231555d

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchrl/envs/utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ def make_composite_from_td(
938938
# of unbounded values.
939939
def make_shape(shape):
940940
if shape or not unsqueeze_null_shapes:
941-
if dynamic_shape:
941+
if dynamic_shape and shape:
942942
return shape[:-1] + (-1,)
943943
else:
944944
return shape
@@ -954,7 +954,8 @@ def make_shape(shape):
954954
if is_tensor_collection(tensor) and not is_non_tensor(tensor)
955955
else NonTensor(
956956
shape=tensor.shape,
957-
example_data=tensor.data,
957+
# Assume all the non-tensors have the same datatype
958+
example_data=tensor.view(-1)[0].data,
958959
device=tensor.device,
959960
)
960961
if is_non_tensor(tensor)

0 commit comments

Comments
 (0)