
[Feature Request] Dynamically Update Environment Parameters (and Reset Transforms) in MultiaSyncDataCollector Workers #2896


Open
1 task done
Xavier9031 opened this issue Apr 8, 2025 · 2 comments
Labels
enhancement New feature or request

Comments

@Xavier9031

Motivation

Currently, training paradigms that require dynamically changing environment configuration parameters during a single training run (e.g., curriculum learning, adaptive difficulty, switching between environment variants/tasks within a bandit framework) face significant performance bottlenecks when using MultiaSyncDataCollector.

My specific problem is that I need to periodically update a configuration parameter in the custom Gymnasium environment instances running on the worker processes. The only reliable way to do this currently is to shut down the entire MultiaSyncDataCollector and create a new one with an updated create_env_fn. This costs tens of seconds per configuration switch in my case, which makes training loops with frequent updates impractically slow. The frustrating part is that the core computation (environment steps and policy inference) is fast, but infrastructure management (process shutdown and restart) dominates the wall-clock time during these transitions.

Solution

I propose the addition of a mechanism within torchrl, specifically for MultiaSyncDataCollector (and potentially other parallel collectors), that allows users to:

  1. Broadcast Parameter Updates: Send new configuration parameter values from the main process to all active worker processes without terminating them.
  2. Invoke Environment Logic: Trigger specific methods within the environment instances on the workers to apply these new parameters (e.g., calling an internal env.update_config(new_param=value) method).
  3. Trigger Transform Re-initialization: Critically, provide a way to signal stateful transforms (like ObservationNorm, Compose, etc.) within the worker environments to re-initialize their state based on the new environment configuration. This might involve allowing users to specify which transform methods (e.g., transform.init_stats(), a custom transform.reset_state()) should be called after the environment parameters are updated.

This would allow for efficient, in-place updates of the environment setup across all workers, eliminating the costly shutdown/restart cycle.
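To make the ordering in point 3 concrete, here is a minimal, self-contained sketch in plain Python (no torchrl). `ToyEnv`, its `update_config` method, `ToyObsNorm`, and `apply_update` are hypothetical stand-ins: `ToyObsNorm.init_stats` only mimics the role of `ObservationNorm.init_stats` and is not torchrl's implementation. It shows that changing the environment's configuration without re-fitting the normalizer leaves observations far off-center, while re-running `init_stats` after the update brings them back near zero.

```python
import random
import statistics


class ToyEnv:
    """Hypothetical env whose observations depend on a config parameter."""

    def __init__(self, offset: float = 0.0, seed: int = 0):
        self.offset = offset
        self.rng = random.Random(seed)

    def update_config(self, offset: float) -> None:
        # In-place config change: no process restart needed.
        self.offset = offset

    def step(self) -> float:
        return self.offset + self.rng.gauss(0.0, 1.0)


class ToyObsNorm:
    """Toy stand-in for a stateful transform such as ObservationNorm."""

    def __init__(self):
        self.loc, self.scale = 0.0, 1.0

    def init_stats(self, env: ToyEnv, num_iter: int = 2000) -> None:
        # Re-estimate normalization statistics from fresh samples of the env.
        samples = [env.step() for _ in range(num_iter)]
        self.loc = statistics.fmean(samples)
        self.scale = statistics.pstdev(samples)

    def __call__(self, obs: float) -> float:
        return (obs - self.loc) / self.scale


def apply_update(env: ToyEnv, norm: ToyObsNorm, new_offset: float) -> None:
    # Order matters: change the env first, then re-fit the transform on it.
    env.update_config(new_offset)
    norm.init_stats(env)


env, norm = ToyEnv(offset=0.0), ToyObsNorm()
norm.init_stats(env)

env.update_config(10.0)  # config changed, normalization stats left stale
stale = statistics.fmean(norm(env.step()) for _ in range(1000))

apply_update(env, norm, 10.0)  # config changed and stats re-initialized
fresh = statistics.fmean(norm(env.step()) for _ in range(1000))

print(f"mean normalized obs, stale stats: {stale:.2f}, re-initialized: {fresh:.2f}")
```

In the proposed mechanism, something like `apply_update` is what each worker would need to run after receiving a broadcast parameter.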

Alternatives

  1. Current Workaround (Shutdown & Recreate): The primary alternative is the current working method: call collector.shutdown() and instantiate a new MultiaSyncDataCollector with the updated configuration in create_env_fn.
    • Pro: Guarantees correctness, including proper initialization of environments and transforms.
    • Con: Extremely slow due to process management overhead.
  2. Using collector.reset() with Arguments: I attempted to pass the new parameter via collector.reset(my_param=new_value), hoping it would be forwarded to env.reset() in the workers.
    • Con: This is not supported functionality (TypeError: reset() got an unexpected keyword argument 'my_param') and, more importantly, it wouldn't address the re-initialization requirement for stateful transforms like ObservationNorm.
  3. Manual IPC via collector.pipes: One could theoretically try sending custom messages through collector.pipes to the workers.
    • Con: Requires modifying torchrl's internal worker loop logic to handle these custom messages, making it brittle, hard to maintain, and breaking encapsulation. It also requires careful handling of synchronization and transform state.

Additional context

The need for dynamic updates is particularly relevant for complex training procedures where the environment's characteristics change over time based on agent progress or predefined schedules. The correct handling of stateful transforms (ObservationNorm being a key example) is essential for stability, as using stale normalization statistics after a parameter change can lead to incorrect observations and poor learning. The error message TypeError: MultiaSyncDataCollector.reset() got an unexpected keyword argument '...' confirms the limitation of the current reset approach for passing parameters.


Checklist

  • I have checked that there is no similar issue in the repo (required)
@Xavier9031 Xavier9031 added the enhancement New feature or request label Apr 8, 2025
@vmoens
Contributor

vmoens commented Apr 8, 2025

Hello!
We're working on a weight updater API that will be able to do that.
The idea is that you'll have a local WeightReceiver that will implement the dispatching of the weights to the model and env, and a WeightSender that will send the weights from the main data collectors to the leaves.
Both of these will have access to the data collector (through a weakref) such that you can do weight_receiver.collector.env.load_state_dict(env_state_dict).
This is still WIP

cc @Darktex RE discussion on how to handle the weight sync within the collector
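The sender/receiver split described above can be pictured with the following sketch. It is an illustration only: the class names echo the comment, but the `receive`/`send` methods and the toy collector and env are assumptions made for this example, not the torchrl API (which is still WIP). Everything runs in one process here, whereas real receivers would live in the worker processes.

```python
import weakref


class ToyEnv:
    """Hypothetical env exposing a state_dict-style config interface."""

    def __init__(self):
        self.config = {"difficulty": 0}

    def state_dict(self) -> dict:
        return dict(self.config)

    def load_state_dict(self, state_dict: dict) -> None:
        self.config.update(state_dict)


class ToyCollector:
    def __init__(self):
        self.env = ToyEnv()


class WeightReceiver:
    """Illustrative only: holds a weak reference to its collector."""

    def __init__(self, collector: ToyCollector):
        self._collector_ref = weakref.ref(collector)

    @property
    def collector(self) -> ToyCollector:
        collector = self._collector_ref()
        if collector is None:
            raise RuntimeError("collector has been garbage-collected")
        return collector

    def receive(self, env_state_dict: dict) -> None:
        # Dispatch the payload to the env owned by the collector.
        self.collector.env.load_state_dict(env_state_dict)


class WeightSender:
    """Illustrative only: pushes the same payload to every receiver."""

    def __init__(self, receivers: list):
        self.receivers = receivers

    def send(self, env_state_dict: dict) -> None:
        for receiver in self.receivers:
            receiver.receive(env_state_dict)


collectors = [ToyCollector() for _ in range(3)]
sender = WeightSender([WeightReceiver(c) for c in collectors])
sender.send({"difficulty": 2})
print([c.env.state_dict() for c in collectors])
```

Holding the collector through `weakref.ref` means the receiver never keeps the collector alive on its own, so shutting the collector down is not blocked by a lingering receiver.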

@Xavier9031
Author

Hi, @vmoens

Thank you for your response—I’m really encouraged to hear that you’re working on a weight updater API. This approach seems promising in avoiding the need to shut down and restart the entire MultiaSyncDataCollector when updating environment configurations.

I’d also like to share some additional performance data I gathered with different worker counts. For clarity, I’ve organized the numbers into the following table:

| Worker Count | Environment Creation | Data Collector Creation | Data Collection | Environment Switch Time |
|---|---|---|---|---|
| 1 | ~7.99 sec | ~2.83 sec | ~24.39 sec | ~18.62 sec |
| 4 | ~7.64 sec | ~8.40 sec | ~25.40 sec | ~19.86 sec |
| 10 | ~7.65 sec | ~19.66 sec | ~27.63 sec | ~20.74 sec |
| 20 | ~8.05 sec | ~38.05 sec | ~34.41 sec | ~26.59 sec |

One aspect I find puzzling is that the data collection step after switching the environment takes noticeably longer with more workers (~24.39 sec with 1 worker vs. ~34.41 sec with 20). I would have expected better scaling with more workers. This could indicate additional overhead, perhaps from inter-worker communication, synchronization, or resource contention.
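A quick back-of-the-envelope pass over the table's own numbers makes the scaling more concrete: an ordinary least-squares slope of collector creation time against worker count, and the 20-worker to 1-worker ratios for data collection and switch time. This only summarizes the measurements; it does not identify the cause of the slowdown.

```python
# Timings in seconds, copied from the table above, indexed by worker count.
workers = [1, 4, 10, 20]
collector_creation = [2.83, 8.40, 19.66, 38.05]
data_collection = [24.39, 25.40, 27.63, 34.41]
env_switch = [18.62, 19.86, 20.74, 26.59]


def ols_slope(xs, ys):
    """Ordinary least-squares slope of ys against xs."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    return sxy / sxx


creation_slope = ols_slope(workers, collector_creation)
collection_ratio = data_collection[-1] / data_collection[0]
switch_ratio = env_switch[-1] / env_switch[0]

print(f"collector creation: ~{creation_slope:.2f} sec per extra worker")  # ~1.85
print(f"data collection, 20 vs 1 workers: {collection_ratio:.2f}x")       # 1.41x
print(f"environment switch, 20 vs 1 workers: {switch_ratio:.2f}x")        # 1.43x
```

So collector creation grows by roughly 1.85 sec per additional worker, while data collection and the environment switch are each roughly 1.4x slower at 20 workers than at 1.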

I also want to clarify that these experimental results were obtained by forcibly injecting messages into the workers’ communication channels. Here’s a brief explanation of the approach I used:

  • Main Process:
    In the main application, I iterate over the collector’s pipes and send a tuple containing the configuration parameter (or any parameter to be updated) along with the string "reset". For example:

    # Send a (payload, command) tuple to every worker; `parm` holds the new config value.
    for pipe in collector.pipes:
        pipe.send((parm, "reset"))

    This ensures that each worker receives an explicit instruction to perform a reset.

  • Worker Process (Collector Side):
    I modified the source code in collectors.py so that when a worker receives a message with the "reset" command, it logs the received data and calls its internal reset function. Depending on whether a parameter is provided, it either calls inner_collector.reset() or calls a version of reset that accepts a parameter. After completing the reset, the worker sends a confirmation message back to the main process. All exceptions during this process are caught and logged, and an error message is sent back if necessary. This forced message injection mechanism demonstrates how an in-place update might work, without incurring the overhead of shutting down and recreating the entire data collector.
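The message protocol described above (the main process broadcasts `(param, "reset")`, each worker applies it, resets, and replies with a confirmation or an error) can be sketched without touching torchrl. This is a simplified stand-in, not torchrl's worker loop: `ToyEnv`, `update_config`, and the `"reset_done"`/`"error"`/`"close"` messages are hypothetical, and the workers run as threads over `multiprocessing.Pipe` connections so the example stays self-contained (real collector workers are separate processes).

```python
import threading
from multiprocessing import Pipe


class ToyEnv:
    """Hypothetical env with an in-place config hook."""

    def __init__(self):
        self.param = None
        self.resets = 0

    def update_config(self, param) -> None:
        self.param = param

    def reset(self) -> None:
        self.resets += 1


def worker_loop(pipe, env: ToyEnv) -> None:
    """Simplified stand-in for a collector worker loop (not torchrl's code)."""
    while True:
        data, msg = pipe.recv()
        if msg == "close":
            pipe.send((None, "closed"))
            return
        if msg == "reset":
            try:
                if data is not None:
                    env.update_config(data)  # apply the new parameter first
                env.reset()
                pipe.send((env.param, "reset_done"))  # confirmation
            except Exception as err:  # report failures instead of dying
                pipe.send((repr(err), "error"))


NUM_WORKERS = 3
envs = [ToyEnv() for _ in range(NUM_WORKERS)]
pipes, threads = [], []
for env in envs:
    parent_end, child_end = Pipe()
    t = threading.Thread(target=worker_loop, args=(child_end, env))
    t.start()
    pipes.append(parent_end)
    threads.append(t)

# Broadcast first, then collect confirmations, so workers apply the update concurrently.
for pipe in pipes:
    pipe.send((0.5, "reset"))
confirmations = [pipe.recv() for pipe in pipes]
for data, msg in confirmations:
    if msg != "reset_done":
        raise RuntimeError(f"worker failed to apply update: {data}")

# Shut the toy workers down cleanly.
for pipe in pipes:
    pipe.send((None, "close"))
for pipe in pipes:
    pipe.recv()
for t in threads:
    t.join()

print(confirmations)
```

Sending to every worker before reading any reply lets the updates overlap; waiting for each confirmation immediately after its send would serialize the switch across workers.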

A few points I’d like to clarify further:

  1. Dynamic Environment Updates:
    Will the weight updater API support not only model weight updates but also the dynamic update of environment configuration parameters? For instance, could the API facilitate invoking something like env.load_state_dict(env_state_dict) on each worker so that the environment state (and its configuration) is updated without a full restart?

  2. Resetting Stateful Transforms:
    As mentioned in my original issue, it’s crucial to reinitialize stateful transforms (e.g., ObservationNorm) when the environment configuration changes. Will the API provide a mechanism or callbacks to trigger methods such as transform.init_stats() or a custom reset function for these transforms?

  3. Performance with More Workers:
    The table indicates that data collection after an environment switch becomes slower with a larger number of workers. Do you have any insights into why this might be happening, or if there are potential optimizations for a multi-worker setup? Any feedback on this would be greatly appreciated.

Could you also share any additional details on the API’s progress or expected timeline, or point me to any early documentation/examples? I’m eager to test these changes and provide further feedback if needed.

Thank you again for your hard work on addressing these issues—I look forward to hearing more!
