
Commit 7f068c8

jeffkbkim authored and facebook-github-bot committed
[torchrec][LocalShardsWrapper] Implement tensor padding for local shards wrapper (pytorch#163183)
Summary: X-link: pytorch/torchrec#3382

This diff implements the constant padding functionality (`aten.constant_pad_nd.default`) for `LocalShardsWrapper`. The method applies constant padding to the local shards according to the provided padding specification. Depending on the sharding type (RW, CW), padding in the [left, right, top, bottom] directions is applied either to the first/last shard only or to all local shards.

New unit tests cover:
- 1D (RW) top/bottom padding
- 2D (CW) left, right, top, and bottom padding
- empty shards, number of dimensions > 2

Test Plan:
```
buck2 test fbcode//caffe2/test/distributed/tensor:shards_wrapper
2025-09-18T15:32:46.525914Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-aarch64-compat
2025-09-18T15:32:46.525953Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-compat
2025-09-18T15:32:46.525959Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-libcxx
Buck UI: https://www.internalfb.com/buck2/ffb34bcb-1555-4fa3-89c6-9c22d078606a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/12384899087608299
Network: Up: 159MiB  Down: 13GiB  (reSessionID-f734bd3c-19ca-44c9-919f-57203ac00be8)
Loading targets.   Remaining 0/5110  104336 dirs read, 1265395 targets declared
Analyzing targets. Remaining 0/80346  3349033 actions, 4142832 artifacts declared
Executing actions. Remaining 0/521855  149:06:17.8s exec time total
Command: test.     Finished 14 local, 397 remote, 199840 cache (99% hit)  148:27:40.9s exec time cached (99%)
Time elapsed: 8:55.5s
Tests finished: Pass 14. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D82663766
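Below is a minimal sketch of the padding semantics described above, using only the API surface exercised by the new tests; the shard values and pad spec are illustrative, not part of the commit:

```python
import torch
from torch.distributed.tensor._shards_wrapper import LocalShardsWrapper

# Two column-wise (CW) shards of a [2, 4] global tensor, at column offsets 0 and 2.
shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])

# The 2D pad spec is [left, right, top, bottom]. Under CW sharding, left
# padding lands only on the first shard and right padding only on the last
# shard, while top/bottom padding grows every shard, since each shard spans
# all rows.
padded = torch.ops.aten.constant_pad_nd.default(lsw, [1, 0, 0, 0], 0.0)
print(padded.local_shards()[0].shape)  # torch.Size([2, 3]) -- first shard grew
print(padded.local_shards()[1].shape)  # torch.Size([2, 2]) -- last shard unchanged
```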
1 parent 1aeac30 · commit 7f068c8

File tree

2 files changed: +497 -0 lines changed
Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest

import torch
from torch.distributed.tensor._shards_wrapper import LocalShardsWrapper

class LocalShardsWrapperPaddingTest(unittest.TestCase):
    """Test cases for constant padding functionality in LocalShardsWrapper."""

    def test_empty_shards_padding(self) -> None:
        """Test padding with empty shards list."""
        lsw = LocalShardsWrapper([], [])
        pad_spec = [1, 2, 3, 4]
        pad_value = 5.0

        self.assertRaises(
            Exception,
            torch.ops.aten.constant_pad_nd.default,
            lsw,
            pad_spec,
            pad_value,
        )

    def test_single_shard_padding_2d(self) -> None:
        """Test padding with single 2D shard."""
        tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        lsw = LocalShardsWrapper([tensor], [(0, 0)])
        pad_spec = [1, 2, 3, 4]  # [left=1, right=2, top=3, bottom=4]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)

        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_single_shard_padding_1d(self) -> None:
        """Test padding with single 1D shard."""
        tensor = torch.tensor([1.0, 2.0, 3.0])
        lsw = LocalShardsWrapper([tensor], [(0,)])
        pad_spec = [2, 1]  # [top=2, bottom=1]
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)

        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_cw_sharding_top_padding(self) -> None:
        """Test column-wise sharding with top padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # Both shards should have 2 rows added at top
        expected_shape = (4, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[1][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[0][2:], shard1)
        torch.testing.assert_close(result.local_shards()[1][2:], shard2)

    def test_cw_sharding_bottom_padding(self) -> None:
        """Test column-wise sharding with bottom padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        expected_shape = (3, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        torch.testing.assert_close(result.local_shards()[0][:2], shard1)
        torch.testing.assert_close(result.local_shards()[1][:2], shard2)
        torch.testing.assert_close(
            result.local_shards()[0][2:], torch.full((1, 2), -1.0)
        )
        torch.testing.assert_close(
            result.local_shards()[1][2:], torch.full((1, 2), -1.0)
        )

    def test_cw_sharding_left_padding(self) -> None:
        """Test column-wise sharding with left padding (affects first shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [3, 0, 0, 0]  # left=3
        pad_value = 2.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # First shard should have 3 columns added at left
        self.assertEqual(result.local_shards()[0].shape, (2, 5))
        self.assertEqual(result.local_shards()[1].shape, (2, 2))

        # Check content
        torch.testing.assert_close(
            result.local_shards()[0][:, :3], torch.full((2, 3), 2.0)
        )
        torch.testing.assert_close(result.local_shards()[0][:, 3:], shard1)
        torch.testing.assert_close(result.local_shards()[1], shard2)

    def test_cw_sharding_right_padding(self) -> None:
        """Test column-wise sharding with right padding (affects last shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 2, 0, 0]  # right=2
        pad_value = 3.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (2, 2))
        # Second shard should have 2 columns added at right
        self.assertEqual(result.local_shards()[1].shape, (2, 4))

        torch.testing.assert_close(result.local_shards()[0], shard1)
        torch.testing.assert_close(result.local_shards()[1][:, :2], shard2)
        torch.testing.assert_close(
            result.local_shards()[1][:, 2:], torch.full((2, 2), 3.0)
        )

    def test_cw_sharding_mixed_padding(self) -> None:
        """Test column-wise sharding with mixed padding directions."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [1, 2, 1, 1]  # [left=1, right=2, top=1, bottom=1]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (4, 3))
        self.assertEqual(result.local_shards()[1].shape, (4, 4))

    def test_rw_sharding_top_padding(self) -> None:
        """Test row-wise sharding with top padding (affects first shard only).

        In 1D RW sharding, conceptually we stack rows:
        [elem1, elem2, elem3, elem4, elem5, elem6] (global)
        shard1: [elem1, elem2, elem3] (top portion)
        shard2: [elem4, elem5, elem6] (bottom portion)

        Top padding adds elements at the beginning (affects first shard).
        """
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (5,))
        self.assertEqual(result.local_shards()[1].shape, (3,))

        torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2))
        torch.testing.assert_close(result.local_shards()[0][2:], shard1)
        torch.testing.assert_close(result.local_shards()[1], shard2)

    def test_rw_sharding_bottom_padding(self) -> None:
        """Test row-wise sharding with bottom padding (affects last shard only).

        In 1D RW sharding, conceptually we stack rows:
        [elem1, elem2, elem3, elem4, elem5, elem6] (global)
        shard1: [elem1, elem2, elem3] (top portion)
        shard2: [elem4, elem5, elem6] (bottom portion)

        Bottom padding adds elements at the end (affects last shard).
        """
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (3,))
        self.assertEqual(result.local_shards()[1].shape, (4,))

        torch.testing.assert_close(result.local_shards()[0], shard1)
        torch.testing.assert_close(result.local_shards()[1][:3], shard2)
        torch.testing.assert_close(result.local_shards()[1][3:], torch.tensor([-1.0]))

    def test_rw_sharding_mixed_padding(self) -> None:
        """Test row-wise sharding with mixed top/bottom padding."""
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (2,)])
        pad_spec = [1, 2]  # [top=1, bottom=2]
        pad_value = 5.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (3,))
        self.assertEqual(result.local_shards()[1].shape, (4,))

    def test_higher_dimensions_not_implemented(self) -> None:
        """Test that higher dimensional tensors raise NotImplementedError."""
        tensor_3d = torch.rand(2, 3, 4)  # 3D tensor
        lsw = LocalShardsWrapper([tensor_3d, tensor_3d], [(0, 0, 0), (2, 0, 0)])
        pad_spec = [1, 1, 1, 1, 1, 1]  # 3D padding spec
        pad_value = 0.0

        with self.assertRaises(NotImplementedError) as cm:
            torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIn("3D tensors is not supported", str(cm.exception))
        self.assertIn(
            "Only 1D and 2D tensors are currently supported", str(cm.exception)
        )

    def test_offsets_and_storage_metadata_after_padding_1D_rw(self) -> None:
        # Test 1D RW sharding with top+bottom padding
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        original_offsets = [(0,), (2,)]
        lsw = LocalShardsWrapper([shard1, shard2], original_offsets)

        # Check original storage metadata
        original_storage = lsw.storage_metadata()
        self.assertEqual(original_storage.size, torch.Size([4]))  # [1,2,3,4]
        self.assertEqual(len(original_storage.chunks), 2)
        self.assertEqual(original_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(original_storage.chunks[0].sizes, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].offsets, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].sizes, torch.Size([2]))

        pad_spec = [1, 1]  # add 1 element at top and bottom
        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, 0.0)

        expected_offsets = [
            torch.Size([0]),
            torch.Size([3]),
        ]  # Second shard's offset shifted by 1
        self.assertEqual(result.local_offsets(), expected_offsets)

        result_storage = result.storage_metadata()

        # Global tensor should be: [0, 1, 2, 3, 4, 0] shape=[6]
        expected_global_size = torch.Size([6])
        self.assertEqual(result_storage.size, expected_global_size)

        self.assertEqual(len(result_storage.chunks), 2)

        # First chunk: [3] elements at offset [0] (size increased by top padding)
        # Second chunk: [3] elements at offset [3] (size increased by bottom padding, offset shifted)
        self.assertEqual(result_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(result_storage.chunks[0].sizes, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].offsets, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].sizes, torch.Size([3]))

    def test_offsets_and_storage_metadata_after_padding_2D_cw(self) -> None:
        # Test 2D CW sharding with left+right padding
        shard1_2d = torch.tensor([[1.0, 2.0], [5.0, 6.0]])  # [2, 2] columns 0-1
        shard2_2d = torch.tensor([[3.0, 4.0], [7.0, 8.0]])  # [2, 2] columns 2-3
        original_offsets_2d = [(0, 0), (0, 2)]
        lsw_2d = LocalShardsWrapper([shard1_2d, shard2_2d], original_offsets_2d)

        pad_spec_2d = [1, 1, 0, 0]  # [left=1, right=1, top=0, bottom=0]
        result_2d = torch.ops.aten.constant_pad_nd.default(lsw_2d, pad_spec_2d, 0.0)

        expected_offsets_2d = [
            torch.Size([0, 0]),
            torch.Size([0, 3]),
        ]
        self.assertEqual(result_2d.local_offsets(), expected_offsets_2d)

        result_storage_2d = result_2d.storage_metadata()

        # Global tensor should go from [2,4] to [2,6] (add 1 left + 1 right)
        expected_global_size_2d = torch.Size([2, 6])  # [2, 4+1+1]
        self.assertEqual(result_storage_2d.size, expected_global_size_2d)

        # First chunk: [2,3] at offset [0,0] (size increased by left padding)
        # Second chunk: [2,3] at offset [0,3] (size increased by right padding, offset shifted)
        self.assertEqual(result_storage_2d.chunks[0].offsets, torch.Size([0, 0]))
        self.assertEqual(result_storage_2d.chunks[0].sizes, torch.Size([2, 3]))
        self.assertEqual(result_storage_2d.chunks[1].offsets, torch.Size([0, 3]))
        self.assertEqual(result_storage_2d.chunks[1].sizes, torch.Size([2, 3]))
```