
Commit 921ac09

jeffkbkim authored and facebook-github-bot committed
[torchrec][LocalShardsWrapper] Implement tensor padding for local shards wrapper (pytorch#163183)
Summary:

X-link: pytorch/torchrec#3382

This diff implements the constant padding functionality (aten.constant_pad_nd.default) for `LocalShardsWrapper`. The method applies constant padding to the local shards based on the provided padding specification. Depending on the sharding type (RW, CW), the padding in the [left, right, top, bottom] directions is applied either to only the first/last shard or to all local shards.

New unit tests cover:
- 1D (RW) top/bottom paddings
- 2D (CW) left, right, top, bottom paddings
- empty shards, number of dimensions > 2

Test Plan:
```
buck2 test fbcode//mode/opt fbcode//torchrec/distributed/tests:test_shards_wrapper <...>
Buck UI: https://www.internalfb.com/buck2/9fff7732-346a-43eb-b1a0-f0e43e2e8815
Test UI: https://www.internalfb.com/intern/testinfra/testrun/18014398620870153
Network: Up: 110KiB  Down: 95KiB  (reSessionID-c0cdcb56-f82e-4f42-9fb8-54d8a3fb74eb)
Analyzing targets. Remaining 0/191
Executing actions. Remaining 0/12849  7.6s exec time total
Command: test. Finished 5 local
Time elapsed: 1:40.1s
Test execution completed but tests were skipped
Tests finished: Pass 14. Fail 0. Fatal 0. Skip 3. Build failure 0
```

Rollback Plan:

Differential Revision: D82663766
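For illustration, a minimal sketch (not taken from this commit or its tests; the shard shapes and pad values are assumed) of the CW padding rules described above, expressed with plain tensors and `torch.nn.functional.pad`:

```python
# Assumed example: two column-wise (CW) shards of a logical 4 x 8 tensor.
# pad_spec ordering follows F.pad: for 2D input it is
# [pad_left, pad_right, pad_top, pad_bottom].
import torch
import torch.nn.functional as F

shard0 = torch.ones(4, 4)  # columns 0..3
shard1 = torch.ones(4, 4)  # columns 4..7

pad_left, pad_right, pad_top, pad_bottom = 1, 2, 1, 0
pad_value = 0.0

# Row padding (top/bottom) is applied to every CW shard...
shard0 = F.pad(shard0, [0, 0, pad_top, pad_bottom], value=pad_value)
shard1 = F.pad(shard1, [0, 0, pad_top, pad_bottom], value=pad_value)
# ...while column padding only touches the boundary shards.
shard0 = F.pad(shard0, [pad_left, 0, 0, 0], value=pad_value)   # first shard absorbs the left pad
shard1 = F.pad(shard1, [0, pad_right, 0, 0], value=pad_value)  # last shard absorbs the right pad

# The padded shards stitch back into the padded logical tensor (5 x 11).
assert shard0.shape == (5, 5) and shard1.shape == (5, 6)
```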
1 parent 7a0f933 · commit 921ac09

File tree

1 file changed: +188 -0 lines changed

torch/distributed/tensor/_shards_wrapper.py

Lines changed: 188 additions & 0 deletions
```diff
@@ -109,6 +109,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[
             aten.detach.default: cls.handle_detach,
             aten.clone.default: cls.handle_clone,
             aten.new_empty.default: cls.handle_new_empty,
+            aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
         }
 
         if func in dispatcher:
@@ -223,6 +224,193 @@ def handle_new_empty(args, kwargs) -> "LocalShardsWrapper":
             self_ls.local_offsets(),
         )
 
+    @staticmethod
+    def handle_constant_pad_nd(args, kwargs) -> "LocalShardsWrapper":
+        """
+        Apply constant padding to LocalShardsWrapper.
+
+        The padding is based off of the following ideas:
+        - The resulting wrapper represents the padded version of the logical tensor.
+        - Each shard is padded based on the sharding type + dimension that is padded.
+        - For instance, CW shards padded on the left most col will have only padding on the first CW shard.
+        - Padding the top row will apply to all CW shards.
+        """
+        self_lsw = args[0]
+        pad_spec = args[1]
+        pad_value = args[2] if len(args) > 2 else 0.0
+
+        if len(self_lsw.local_shards()) == 0:
+            raise NotImplementedError(
+                "Padding empty LocalShardsWrapper is not supported."
+            )
+
+        local_shards = self_lsw.local_shards()
+
+        if len(local_shards) == 1:
+            padded_shard = torch.nn.functional.pad(
+                local_shards[0], pad_spec, mode="constant", value=pad_value
+            )
+            return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())
+
+        padded_shards = list(local_shards)
+
+        if local_shards[0].ndim == 2:
+            # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            if pad_top > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_bottom > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_left > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0],
+                    [pad_left, 0, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+            if pad_right > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1],
+                    [0, pad_right, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+        elif local_shards[0].ndim == 1:
+            # 1D Row-wise sharding: [pad_top, pad_bottom]
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            if pad_top > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
+                )
+            if pad_bottom > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
+                )
+        else:
+            raise NotImplementedError(
+                f"Padding for {local_shards[0].ndim}D tensors is not supported. "
+                f"Only 1D and 2D tensors are currently supported."
+            )
+
+        # Update offsets and storage metadata
+        original_storage = self_lsw.storage_metadata()
+        updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
+            original_storage,
+            self_lsw.local_offsets(),
+            pad_spec,
+            local_shards[0].ndim,
+            padded_shards,
+        )
+
+        result = LocalShardsWrapper(padded_shards, updated_offsets)
+        result._storage_meta = updated_storage
+        return result
+
+    @staticmethod
+    def _compute_updated_metadata(
+        original_storage: TensorStorageMetadata,
+        original_offsets: list[torch.Size],
+        pad_spec: list[int],
+        ndim: int,
+        padded_shards: list[torch.Tensor],
+    ) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]:
+        """
+        Compute updated offsets and storage metadata after padding is applied.
+
+        Args:
+            original_storage: Original storage metadata
+            original_offsets: Original shard offsets
+            pad_spec: Padding specification
+            ndim: Number of dimensions (1=RW or 2=CW)
+            padded_shards: Padded shard tensors
+
+        Returns:
+            Tuple of (updated_offsets, updated_storage_metadata)
+        """
+        if ndim == 1:  # 1D RW
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                if i == 0:
+                    # First shard: offset stays the same (absorbs top padding)
+                    updated_offsets.append(tuple(offset))
+                else:
+                    # Subsequent shards: shift by top padding amount
+                    new_offset = (offset[0] + pad_top,)
+                    updated_offsets.append(new_offset)
+
+            new_global_size = torch.Size(
+                [original_storage.size[0] + pad_top + pad_bottom]
+            )
+
+        elif ndim == 2:  # 2D CW
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                row_offset = offset[0]
+                col_offset = offset[1]
+
+                # Top/bottom padding doesn't affect offsets
+                # Left padding affects column offsets
+                if i == 0:
+                    # First shard: column offset stays the same (absorbs left padding)
+                    new_offset = (row_offset, col_offset)
+                else:
+                    # Subsequent shards: shift column offset by left padding amount
+                    new_offset = (row_offset, col_offset + pad_left)
+
+                updated_offsets.append(new_offset)
+
+            new_global_size = torch.Size(
+                [
+                    original_storage.size[0] + pad_top + pad_bottom,
+                    original_storage.size[1] + pad_left + pad_right,
+                ]
+            )
+
+        else:
+            raise NotImplementedError(f"Metadata computation for {ndim}D not supported")
+
+        updated_chunks = [
+            ChunkStorageMetadata(
+                offsets=torch.Size(offset),
+                sizes=shard.size(),
+            )
+            for offset, shard in zip(updated_offsets, padded_shards)
+        ]
+
+        updated_storage = TensorStorageMetadata(
+            properties=original_storage.properties,
+            size=new_global_size,
+            chunks=updated_chunks,
+        )
+
+        return updated_offsets, updated_storage
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
```
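To make the offset bookkeeping in `_compute_updated_metadata` concrete, here is a small arithmetic sketch of the 1D (RW) case; the shard lengths, offsets, and pad amounts below are invented for illustration:

```python
# Hypothetical 1D (RW) numbers, not from the commit: two row-wise shards of a
# length-8 tensor, padded with pad_spec = [pad_top, pad_bottom] = [2, 3].
pad_top, pad_bottom = 2, 3
shard_lengths = [4, 4]          # local shard sizes along dim 0
offsets = [(0,), (4,)]          # original shard offsets in the global tensor

# The first shard absorbs the top padding, so its offset is unchanged;
# every subsequent shard is shifted down by pad_top.
new_offsets = [offsets[0]] + [(o[0] + pad_top,) for o in offsets[1:]]

# The boundary shards grow by the padding they absorb.
new_lengths = (
    [shard_lengths[0] + pad_top]
    + shard_lengths[1:-1]
    + [shard_lengths[-1] + pad_bottom]
)

# The global (logical) size grows by the total padding.
new_global_size = sum(shard_lengths) + pad_top + pad_bottom

assert new_offsets == [(0,), (6,)]
assert new_lengths == [6, 7]
assert new_global_size == 13
```

The 2D (CW) case follows the same pattern, except that only column offsets shift by `pad_left` while row offsets stay fixed.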
