You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[torchrec][LocalShardsWrapper] Implement tensor padding for local shards wrapper (pytorch#163183)
Summary:
X-link: pytorch/torchrec#3382
This diff implements the constant padding functionality (aten.constant_pad_nd.default) for `LocalShardsWrapper`. The method applies constant padding to the local shards based on the provided padding specification.
Depending on the sharding type (RW, CW), the padding on [left, right, top, bottom] directions will be either applied to the first/last shard, or all local shards.
New unit tests cover:
- 1D (RW) top/bottom paddings
- 2D (CW) left, right, top, bottom paddings
- empty shards, number of dimensions > 2
Test Plan:
```
buck2 test fbcode//mode/opt fbcode//torchrec/distributed/tests:test_shards_wrapper
<...>
Buck UI: https://www.internalfb.com/buck2/9fff7732-346a-43eb-b1a0-f0e43e2e8815
Test UI: https://www.internalfb.com/intern/testinfra/testrun/18014398620870153
Network: Up: 110KiB Down: 95KiB (reSessionID-c0cdcb56-f82e-4f42-9fb8-54d8a3fb74eb)
Analyzing targets. Remaining 0/191
Executing actions. Remaining 0/12849 7.6s exec time total
Command: test. Finished 5 local
Time elapsed: 1:40.1s
Test execution completed but tests were skipped
Tests finished: Pass 14. Fail 0. Fatal 0. Skip 3. Build failure 0
```
Rollback Plan:
Differential Revision: D82663766
0 commit comments