@@ -278,15 +278,19 @@ def pause(self):
278
278
f"Collector pause() is not implemented for { type (self ).__name__ } ."
279
279
)
280
280
281
- def async_shutdown (self , timeout : float | None = None ) -> None :
281
+ def async_shutdown (
282
+ self , timeout : float | None = None , close_env : bool = True
283
+ ) -> None :
282
284
"""Shuts down the collector when started asynchronously with the `start` method.
283
285
284
286
Arg:
285
287
timeout (float, optional): The maximum time to wait for the collector to shutdown.
288
+ close_env (bool, optional): If True, the collector will close the contained environment.
289
+ Defaults to `True`.
286
290
287
291
.. seealso:: :meth:`~.start`
288
292
"""
289
- return self .shutdown (timeout = timeout )
293
+ return self .shutdown (timeout = timeout , close_env = close_env )
290
294
291
295
def update_policy_weights_ (
292
296
self ,
@@ -342,7 +346,7 @@ def next(self):
342
346
return None
343
347
344
348
@abc .abstractmethod
345
- def shutdown (self , timeout : float | None = None ) -> None :
349
+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
346
350
raise NotImplementedError
347
351
348
352
@abc .abstractmethod
@@ -1317,12 +1321,14 @@ def _run_iterator(self):
1317
1321
if self ._stop :
1318
1322
return
1319
1323
1320
- def async_shutdown (self , timeout : float | None = None ) -> None :
1324
+ def async_shutdown (
1325
+ self , timeout : float | None = None , close_env : bool = True
1326
+ ) -> None :
1321
1327
"""Finishes processes started by ray.init() during async execution."""
1322
1328
self ._stop = True
1323
1329
if hasattr (self , "_thread" ) and self ._thread .is_alive ():
1324
1330
self ._thread .join (timeout = timeout )
1325
- self .shutdown ()
1331
+ self .shutdown (close_env = close_env )
1326
1332
1327
1333
def _postproc (self , tensordict_out ):
1328
1334
if self .split_trajs :
@@ -1582,14 +1588,20 @@ def reset(self, index=None, **kwargs) -> None:
1582
1588
)
1583
1589
self ._shuttle ["collector" ] = collector_metadata
1584
1590
1585
- def shutdown (self , timeout : float | None = None ) -> None :
1586
- """Shuts down all workers and/or closes the local environment."""
1591
+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
1592
+ """Shuts down all workers and/or closes the local environment.
1593
+
1594
+ Args:
1595
+ timeout (float, optional): The timeout for closing pipes between workers.
1596
+ No effect for this class.
1597
+ close_env (bool, optional): Whether to close the environment. Defaults to `True`.
1598
+ """
1587
1599
if not self .closed :
1588
1600
self .closed = True
1589
1601
del self ._shuttle
1590
1602
if self ._use_buffers :
1591
1603
del self ._final_rollout
1592
- if not self .env .is_closed :
1604
+ if close_env and not self .env .is_closed :
1593
1605
self .env .close ()
1594
1606
del self .env
1595
1607
return
@@ -2391,8 +2403,17 @@ def __del__(self):
2391
2403
# __del__ will not affect the program.
2392
2404
pass
2393
2405
2394
- def shutdown (self , timeout : float | None = None ) -> None :
2395
- """Shuts down all processes. This operation is irreversible."""
2406
+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
2407
+ """Shuts down all processes. This operation is irreversible.
2408
+
2409
+ Args:
2410
+ timeout (float, optional): The timeout for closing pipes between workers.
2411
+ close_env (bool, optional): Whether to close the environment. Defaults to `True`.
2412
+ """
2413
+ if not close_env :
2414
+ raise RuntimeError (
2415
+ f"Cannot shutdown { type (self ).__name__ } collector without environment being closed."
2416
+ )
2396
2417
self ._shutdown_main (timeout )
2397
2418
2398
2419
def _shutdown_main (self , timeout : float | None = None ) -> None :
@@ -2665,7 +2686,11 @@ def next(self):
2665
2686
return super ().next ()
2666
2687
2667
2688
# for RPC
2668
- def shutdown (self , timeout : float | None = None ) -> None :
2689
+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
2690
+ if not close_env :
2691
+ raise RuntimeError (
2692
+ f"Cannot shutdown { type (self ).__name__ } collector without environment being closed."
2693
+ )
2669
2694
if hasattr (self , "out_buffer" ):
2670
2695
del self .out_buffer
2671
2696
if hasattr (self , "buffers" ):
@@ -3038,9 +3063,13 @@ def next(self):
3038
3063
return super ().next ()
3039
3064
3040
3065
# for RPC
3041
- def shutdown (self , timeout : float | None = None ) -> None :
3066
+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
3042
3067
if hasattr (self , "out_tensordicts" ):
3043
3068
del self .out_tensordicts
3069
+ if not close_env :
3070
+ raise RuntimeError (
3071
+ f"Cannot shutdown { type (self ).__name__ } collector without environment being closed."
3072
+ )
3044
3073
return super ().shutdown (timeout = timeout )
3045
3074
3046
3075
# for RPC
@@ -3382,8 +3411,8 @@ def next(self):
3382
3411
return super ().next ()
3383
3412
3384
3413
# for RPC
3385
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
    """Shuts the collector down by delegating to the parent class.

    Redefined on this class (see the ``# for RPC`` marker above it) without
    adding behaviour of its own.

    Args:
        timeout (float, optional): Forwarded unchanged to the parent ``shutdown``.
        close_env (bool, optional): Forwarded unchanged to the parent ``shutdown``.
            Defaults to `True`.
    """
    parent = super()
    return parent.shutdown(timeout=timeout, close_env=close_env)
3387
3416
3388
3417
# for RPC
3389
3418
def set_seed (self , seed : int , static_seed : bool = False ) -> int :
0 commit comments