System Info
PyTorch version: 2.1.0
Python version: 3.10
OS: Ubuntu 22.04
🐛 Describe the bug
When passing `sharding_strategy` as a command line argument to the FSDP config, FSDP initialization fails with a `KeyError` because the string value is never converted to the `ShardingStrategy` enum.
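For context, the lookup fails because `SHARDING_STRATEGY_MAP` (the private mapping in `torch/distributed/fsdp/_init_utils.py` referenced in the traceback below) is keyed by `ShardingStrategy` enum members, not by strings. A minimal illustration of the mismatch:

```python
from torch.distributed.fsdp import ShardingStrategy
# Private PyTorch module; import shown only to illustrate the failure mode.
from torch.distributed.fsdp._init_utils import SHARDING_STRATEGY_MAP

SHARDING_STRATEGY_MAP[ShardingStrategy.FULL_SHARD]  # works: keys are enum members
SHARDING_STRATEGY_MAP["FULL_SHARD"]                 # KeyError: 'FULL_SHARD'
```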
Running with `--fsdp_config.sharding_strategy "FULL_SHARD"` results in:
Error logs
```
Traceback (most recent call last):
  File "finetuning.py", line 272, in main
    model = FSDP(
  ...
  File "torch/distributed/fsdp/_init_utils.py", line 652, in _init_param_handle_from_params
    SHARDING_STRATEGY_MAP[state.sharding_strategy],
KeyError: 'FULL_SHARD'
```
Expected behavior
The string value `"FULL_SHARD"` should be converted to the `ShardingStrategy.FULL_SHARD` enum member when the config is updated, before it reaches FSDP.
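A minimal sketch of the expected conversion. `coerce_sharding_strategy` is a hypothetical helper name, and where the call belongs depends on how the config update code is structured; `ShardingStrategy[name]` is the standard enum name lookup:

```python
from torch.distributed.fsdp import ShardingStrategy

def coerce_sharding_strategy(value):
    # Hypothetical helper: map a CLI string like "FULL_SHARD" to the
    # corresponding ShardingStrategy member; pass enum values through as-is.
    if isinstance(value, str):
        return ShardingStrategy[value]  # raises KeyError for unknown names
    return value

# Usage: coerce_sharding_strategy("FULL_SHARD") -> ShardingStrategy.FULL_SHARD
```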