
Commit a86c9a1

jeffkbkim authored and facebook-github-bot committed
[torchrec][LocalShardsWrapper] Implement tensor padding for local shards wrapper (pytorch#163183)
Summary:
X-link: pytorch/torchrec#3382

This diff implements constant padding (`aten.constant_pad_nd.default`) for `LocalShardsWrapper`. The method applies constant padding to the local shards based on the provided padding specification. Depending on the sharding type (RW, CW), the padding in the [left, right, top, bottom] directions is applied either to the first/last shard only or to all local shards.

New unit tests cover:
- 1D (RW) top/bottom paddings
- 2D (CW) left, right, top, and bottom paddings
- empty shards and tensors with more than 2 dimensions

Test Plan:
```
buck2 test fbcode//caffe2/test/distributed/tensor:shards_wrapper
2025-09-18T15:32:46.525914Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-aarch64-compat
2025-09-18T15:32:46.525953Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-compat
2025-09-18T15:32:46.525959Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-libcxx
Buck UI: https://www.internalfb.com/buck2/ffb34bcb-1555-4fa3-89c6-9c22d078606a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/12384899087608299
Network: Up: 159MiB  Down: 13GiB  (reSessionID-f734bd3c-19ca-44c9-919f-57203ac00be8)
Loading targets.   Remaining 0/5110       104336 dirs read, 1265395 targets declared
Analyzing targets. Remaining 0/80346      3349033 actions, 4142832 artifacts declared
Executing actions. Remaining 0/521855     149:06:17.8s exec time total
Command: test.     Finished 14 local, 397 remote, 199840 cache (99% hit)
                   148:27:40.9s exec time cached (99%)
Time elapsed: 8:55.5s
Tests finished: Pass 14. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D82663766
1 parent 62a746f commit a86c9a1
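
For reference, the padding specification follows the standard `constant_pad_nd` convention: pad amounts come in pairs, last dimension first, so a 2D spec reads [left, right, top, bottom] and a 1D spec reads [front, back]. A minimal sketch on a plain tensor (the same baseline the new tests compare against via `torch.nn.functional.pad`):

```python
import torch
import torch.nn.functional as F

# 2D spec [left=1, right=2, top=3, bottom=4]: columns grow by 1+2, rows by 3+4.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
padded = F.pad(x, [1, 2, 3, 4], mode="constant", value=0.0)
print(padded.shape)  # torch.Size([9, 5])

# 1D spec [2, 1]: two elements prepended, one appended.
y = torch.tensor([1.0, 2.0, 3.0])
print(F.pad(y, [2, 1], value=-1.0))  # tensor([-1., -1., 1., 2., 3., -1.])
```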

File tree

2 files changed: +503 −0 lines changed

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest

import torch
from torch.distributed.tensor._shards_wrapper import LocalShardsWrapper
from torch.testing._internal.common_utils import run_tests


class LocalShardsWrapperPaddingTest(unittest.TestCase):
    """Test cases for constant padding functionality in LocalShardsWrapper."""

    def test_empty_shards_padding(self) -> None:
        """Test padding with empty shards list."""
        lsw = LocalShardsWrapper([], [])
        pad_spec = [1, 2, 3, 4]
        pad_value = 5.0

        self.assertRaises(
            Exception,
            torch.ops.aten.constant_pad_nd.default,
            lsw,
            pad_spec,
            pad_value,
        )

    def test_single_shard_padding_2d(self) -> None:
        """Test padding with single 2D shard."""
        tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        lsw = LocalShardsWrapper([tensor], [(0, 0)])
        pad_spec = [1, 2, 3, 4]  # [left=1, right=2, top=3, bottom=4]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)

        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_single_shard_padding_1d(self) -> None:
        """Test padding with single 1D shard."""
        tensor = torch.tensor([1.0, 2.0, 3.0])
        lsw = LocalShardsWrapper([tensor], [(0,)])
        pad_spec = [2, 1]  # [top=2, bottom=1]
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)

        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_cw_sharding_top_padding(self) -> None:
        """Test column-wise sharding with top padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # Both shards should have 2 rows added at top
        expected_shape = (4, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[1][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[0][2:], shard1)
        torch.testing.assert_close(result.local_shards()[1][2:], shard2)

    def test_cw_sharding_bottom_padding(self) -> None:
        """Test column-wise sharding with bottom padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        expected_shape = (3, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        torch.testing.assert_close(result.local_shards()[0][:2], shard1)
        torch.testing.assert_close(result.local_shards()[1][:2], shard2)
        torch.testing.assert_close(
            result.local_shards()[0][2:], torch.full((1, 2), -1.0)
        )
        torch.testing.assert_close(
            result.local_shards()[1][2:], torch.full((1, 2), -1.0)
        )

    def test_cw_sharding_left_padding(self) -> None:
        """Test column-wise sharding with left padding (affects first shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [3, 0, 0, 0]  # left=3
        pad_value = 2.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # First shard should have 3 columns added at left
        self.assertEqual(result.local_shards()[0].shape, (2, 5))
        self.assertEqual(result.local_shards()[1].shape, (2, 2))

        # Check content
        torch.testing.assert_close(
            result.local_shards()[0][:, :3], torch.full((2, 3), 2.0)
        )
        torch.testing.assert_close(result.local_shards()[0][:, 3:], shard1)
        torch.testing.assert_close(result.local_shards()[1], shard2)

    def test_cw_sharding_right_padding(self) -> None:
        """Test column-wise sharding with right padding (affects last shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 2, 0, 0]  # right=2
        pad_value = 3.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (2, 2))
        # Second shard should have 2 columns added at right
        self.assertEqual(result.local_shards()[1].shape, (2, 4))

        torch.testing.assert_close(result.local_shards()[0], shard1)
        torch.testing.assert_close(result.local_shards()[1][:, :2], shard2)
        torch.testing.assert_close(
            result.local_shards()[1][:, 2:], torch.full((2, 2), 3.0)
        )

    def test_cw_sharding_mixed_padding(self) -> None:
        """Test column-wise sharding with mixed padding directions."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [1, 2, 1, 1]  # [left=1, right=2, top=1, bottom=1]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (4, 3))
        self.assertEqual(result.local_shards()[1].shape, (4, 4))

    def test_rw_sharding_top_padding(self) -> None:
        """Test row-wise sharding with top padding (affects first shard only).

        In 1D RW sharding, conceptually we stack rows:
        [elem1, elem2, elem3, elem4, elem5, elem6] (global)
        shard1: [elem1, elem2, elem3] (top portion)
        shard2: [elem4, elem5, elem6] (bottom portion)

        Top padding adds elements at the beginning (affects first shard).
        """
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (5,))
        self.assertEqual(result.local_shards()[1].shape, (3,))

        torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2))
        torch.testing.assert_close(result.local_shards()[0][2:], shard1)
        torch.testing.assert_close(result.local_shards()[1], shard2)

    def test_rw_sharding_bottom_padding(self) -> None:
        """Test row-wise sharding with bottom padding (affects last shard only).

        In 1D RW sharding, conceptually we stack rows:
        [elem1, elem2, elem3, elem4, elem5, elem6] (global)
        shard1: [elem1, elem2, elem3] (top portion)
        shard2: [elem4, elem5, elem6] (bottom portion)

        Bottom padding adds elements at the end (affects last shard).
        """
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (3,))
        self.assertEqual(result.local_shards()[1].shape, (4,))

        torch.testing.assert_close(result.local_shards()[0], shard1)
        torch.testing.assert_close(result.local_shards()[1][:3], shard2)
        torch.testing.assert_close(result.local_shards()[1][3:], torch.tensor([-1.0]))

    def test_rw_sharding_mixed_padding(self) -> None:
        """Test row-wise sharding with mixed top/bottom padding."""
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (2,)])
        pad_spec = [1, 2]  # [top=1, bottom=2]
        pad_value = 5.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (3,))
        self.assertEqual(result.local_shards()[1].shape, (4,))

    def test_higher_dimensions_not_implemented(self) -> None:
        """Test that higher dimensional tensors raise NotImplementedError."""
        tensor_3d = torch.rand(2, 3, 4)  # 3D tensor
        lsw = LocalShardsWrapper([tensor_3d, tensor_3d], [(0, 0, 0), (2, 0, 0)])
        pad_spec = [1, 1, 1, 1, 1, 1]  # 3D padding spec
        pad_value = 0.0

        with self.assertRaises(NotImplementedError) as cm:
            torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIn("3D tensors is not supported", str(cm.exception))
        self.assertIn(
            "Only 1D and 2D tensors are currently supported", str(cm.exception)
        )

    def test_offsets_and_storage_metadata_after_padding_1D_rw(self) -> None:
        # Test 1D RW sharding with top+bottom padding
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        original_offsets = [(0,), (2,)]
        lsw = LocalShardsWrapper([shard1, shard2], original_offsets)

        # Check original storage metadata
        original_storage = lsw.storage_metadata()
        self.assertEqual(original_storage.size, torch.Size([4]))  # [1,2,3,4]
        self.assertEqual(len(original_storage.chunks), 2)
        self.assertEqual(original_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(original_storage.chunks[0].sizes, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].offsets, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].sizes, torch.Size([2]))

        pad_spec = [1, 1]  # add 1 element at top and bottom
        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, 0.0)

        expected_offsets = [
            torch.Size([0]),
            torch.Size([3]),
        ]  # Second shard's offset shifted by 1
        self.assertEqual(result.local_offsets(), expected_offsets)

        result_storage = result.storage_metadata()

        # Global tensor should be: [0, 1, 2, 3, 4, 0] shape=[6]
        expected_global_size = torch.Size([6])
        self.assertEqual(result_storage.size, expected_global_size)

        self.assertEqual(len(result_storage.chunks), 2)

        # First chunk: [3] elements at offset [0] (size increased by top padding)
        # Second chunk: [3] elements at offset [3] (size increased by bottom padding, offset shifted)
        self.assertEqual(result_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(result_storage.chunks[0].sizes, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].offsets, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].sizes, torch.Size([3]))

    def test_offsets_and_storage_metadata_after_padding_2D_cw(self) -> None:
        # Test 2D CW sharding with left+right padding
        shard1_2d = torch.tensor([[1.0, 2.0], [5.0, 6.0]])  # [2, 2] columns 0-1
        shard2_2d = torch.tensor([[3.0, 4.0], [7.0, 8.0]])  # [2, 2] columns 2-3
        original_offsets_2d = [(0, 0), (0, 2)]
        lsw_2d = LocalShardsWrapper([shard1_2d, shard2_2d], original_offsets_2d)

        pad_spec_2d = [1, 1, 0, 0]  # [left=1, right=1, top=0, bottom=0]
        result_2d = torch.ops.aten.constant_pad_nd.default(lsw_2d, pad_spec_2d, 0.0)

        expected_offsets_2d = [
            torch.Size([0, 0]),
            torch.Size([0, 3]),
        ]
        self.assertEqual(result_2d.local_offsets(), expected_offsets_2d)

        result_storage_2d = result_2d.storage_metadata()

        # Global tensor should go from [2,4] to [2,6] (add 1 left + 1 right)
        expected_global_size_2d = torch.Size([2, 6])  # [2, 4+1+1]
        self.assertEqual(result_storage_2d.size, expected_global_size_2d)

        # First chunk: [2,3] at offset [0,0] (size increased by left padding)
        # Second chunk: [2,3] at offset [0,3] (size increased by right padding, offset shifted)
        self.assertEqual(result_storage_2d.chunks[0].offsets, torch.Size([0, 0]))
        self.assertEqual(result_storage_2d.chunks[0].sizes, torch.Size([2, 3]))
        self.assertEqual(result_storage_2d.chunks[1].offsets, torch.Size([0, 3]))
        self.assertEqual(result_storage_2d.chunks[1].sizes, torch.Size([2, 3]))


if __name__ == "__main__":
    run_tests()
```
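
To make the first/last-shard rule concrete, here is a minimal sketch of the semantics the tests above encode for 2D column-wise shards. This is an illustration on plain tensors, not the actual `LocalShardsWrapper` implementation; `pad_cw_shards` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F

def pad_cw_shards(shards, offsets, pad_spec, value=0.0):
    """Hypothetical helper mirroring the CW rule the tests encode:
    left padding grows the first shard, right padding grows the last,
    and top/bottom padding grows every shard."""
    left, right, top, bottom = pad_spec
    padded, new_offsets = [], []
    for i, (shard, (row_off, col_off)) in enumerate(zip(shards, offsets)):
        l = left if i == 0 else 0                 # only the first shard grows left
        r = right if i == len(shards) - 1 else 0  # only the last shard grows right
        padded.append(F.pad(shard, [l, r, top, bottom], value=value))
        # Every shard after the first shifts right by the left-padding amount.
        new_offsets.append((row_off, col_off if i == 0 else col_off + left))
    return padded, new_offsets

shards = [torch.tensor([[1.0, 2.0], [5.0, 6.0]]), torch.tensor([[3.0, 4.0], [7.0, 8.0]])]
out, offs = pad_cw_shards(shards, [(0, 0), (0, 2)], [1, 1, 0, 0])
print(out[0].shape, out[1].shape, offs)  # torch.Size([2, 3]) torch.Size([2, 3]) [(0, 0), (0, 3)]
```

Under these assumptions the shard shapes and the shifted column offset match what `test_offsets_and_storage_metadata_after_padding_2D_cw` asserts: both [2, 2] shards grow to [2, 3], and the second shard's offset moves from (0, 2) to (0, 3).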
