# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest

import torch
from torch.distributed.tensor._shards_wrapper import LocalShardsWrapper
from torch.testing._internal.common_utils import run_tests


class LocalShardsWrapperPaddingTest(unittest.TestCase):
    """Test cases for constant padding functionality in LocalShardsWrapper."""
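
    # The pad specs below follow the aten.constant_pad_nd convention of
    # padding from the last dimension backwards: for a 2D tensor the spec
    # reads [left, right, top, bottom]; for a 1D tensor, [start, end]
    # (called top/bottom here, since 1D row-wise shards stack vertically).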

    def test_empty_shards_padding(self) -> None:
        """Test that padding an empty shards list raises."""
        lsw = LocalShardsWrapper([], [])
        pad_spec = [1, 2, 3, 4]
        pad_value = 5.0

        with self.assertRaises(Exception):
            torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

    def test_single_shard_padding_2d(self) -> None:
        """Test padding with a single 2D shard."""
        tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        lsw = LocalShardsWrapper([tensor], [(0, 0)])
        pad_spec = [1, 2, 3, 4]  # [left=1, right=2, top=3, bottom=4]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)

        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_single_shard_padding_1d(self) -> None:
        """Test padding with a single 1D shard."""
        tensor = torch.tensor([1.0, 2.0, 3.0])
        lsw = LocalShardsWrapper([tensor], [(0,)])
        pad_spec = [2, 1]  # [top=2, bottom=1]
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)

        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_cw_sharding_top_padding(self) -> None:
        """Test column-wise sharding with top padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # Both shards should have 2 rows added at the top
        expected_shape = (4, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[1][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[0][2:], shard1)
        torch.testing.assert_close(result.local_shards()[1][2:], shard2)

    def test_cw_sharding_bottom_padding(self) -> None:
        """Test column-wise sharding with bottom padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        expected_shape = (3, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        torch.testing.assert_close(result.local_shards()[0][:2], shard1)
        torch.testing.assert_close(result.local_shards()[1][:2], shard2)
        torch.testing.assert_close(
            result.local_shards()[0][2:], torch.full((1, 2), -1.0)
        )
        torch.testing.assert_close(
            result.local_shards()[1][2:], torch.full((1, 2), -1.0)
        )

    def test_cw_sharding_left_padding(self) -> None:
        """Test column-wise sharding with left padding (affects first shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [3, 0, 0, 0]  # left=3
        pad_value = 2.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # Only the first shard should have 3 columns added at the left
        self.assertEqual(result.local_shards()[0].shape, (2, 5))
        self.assertEqual(result.local_shards()[1].shape, (2, 2))

        # Check content
        torch.testing.assert_close(
            result.local_shards()[0][:, :3], torch.full((2, 3), 2.0)
        )
        torch.testing.assert_close(result.local_shards()[0][:, 3:], shard1)
        torch.testing.assert_close(result.local_shards()[1], shard2)

    def test_cw_sharding_right_padding(self) -> None:
        """Test column-wise sharding with right padding (affects last shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 2, 0, 0]  # right=2
        pad_value = 3.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (2, 2))
        # Only the second shard should have 2 columns added at the right
        self.assertEqual(result.local_shards()[1].shape, (2, 4))

        torch.testing.assert_close(result.local_shards()[0], shard1)
        torch.testing.assert_close(result.local_shards()[1][:, :2], shard2)
        torch.testing.assert_close(
            result.local_shards()[1][:, 2:], torch.full((2, 2), 3.0)
        )

    def test_cw_sharding_mixed_padding(self) -> None:
        """Test column-wise sharding with mixed padding directions."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [1, 2, 1, 1]  # [left=1, right=2, top=1, bottom=1]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (4, 3))
        self.assertEqual(result.local_shards()[1].shape, (4, 4))
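
        # Content check, assuming the split shown by the single-direction
        # tests above: left padding lands on the first shard, right padding
        # on the last, and top/bottom padding on every shard.
        expected_shard1 = torch.tensor(
            [[0.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 5.0, 6.0], [0.0, 0.0, 0.0]]
        )
        expected_shard2 = torch.tensor(
            [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [7.0, 8.0, 0.0], [0.0, 0.0, 0.0]]
        )
        torch.testing.assert_close(result.local_shards()[0], expected_shard1)
        torch.testing.assert_close(result.local_shards()[1], expected_shard2)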

    def test_rw_sharding_top_padding(self) -> None:
        """Test row-wise sharding with top padding (affects first shard only).

        In 1D RW sharding, conceptually we stack rows:
            [elem1, elem2, elem3, elem4, elem5, elem6]  (global)
            shard1: [elem1, elem2, elem3]  (top portion)
            shard2: [elem4, elem5, elem6]  (bottom portion)

        Top padding adds elements at the beginning (affects the first shard).
        """
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (5,))
        self.assertEqual(result.local_shards()[1].shape, (3,))

        torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2))
        torch.testing.assert_close(result.local_shards()[0][2:], shard1)
        torch.testing.assert_close(result.local_shards()[1], shard2)

    def test_rw_sharding_bottom_padding(self) -> None:
        """Test row-wise sharding with bottom padding (affects last shard only).

        In 1D RW sharding, conceptually we stack rows:
            [elem1, elem2, elem3, elem4, elem5, elem6]  (global)
            shard1: [elem1, elem2, elem3]  (top portion)
            shard2: [elem4, elem5, elem6]  (bottom portion)

        Bottom padding adds elements at the end (affects the last shard).
        """
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (3,))
        self.assertEqual(result.local_shards()[1].shape, (4,))

        torch.testing.assert_close(result.local_shards()[0], shard1)
        torch.testing.assert_close(result.local_shards()[1][:3], shard2)
        torch.testing.assert_close(result.local_shards()[1][3:], torch.tensor([-1.0]))

    def test_rw_sharding_mixed_padding(self) -> None:
        """Test row-wise sharding with mixed top/bottom padding."""
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (2,)])
        pad_spec = [1, 2]  # [top=1, bottom=2]
        pad_value = 5.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        self.assertEqual(result.local_shards()[0].shape, (3,))
        self.assertEqual(result.local_shards()[1].shape, (4,))
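
        # Content check, assuming the split shown by the single-direction
        # tests above: top padding lands on the first shard, bottom on the last.
        torch.testing.assert_close(
            result.local_shards()[0], torch.tensor([5.0, 1.0, 2.0])
        )
        torch.testing.assert_close(
            result.local_shards()[1], torch.tensor([3.0, 4.0, 5.0, 5.0])
        )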

    def test_higher_dimensions_not_implemented(self) -> None:
        """Test that higher-dimensional tensors raise NotImplementedError."""
        tensor_3d = torch.rand(2, 3, 4)  # 3D tensor
        lsw = LocalShardsWrapper([tensor_3d, tensor_3d], [(0, 0, 0), (2, 0, 0)])
        pad_spec = [1, 1, 1, 1, 1, 1]  # 3D padding spec
        pad_value = 0.0

        with self.assertRaises(NotImplementedError) as cm:
            torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIn("3D tensors is not supported", str(cm.exception))
        self.assertIn(
            "Only 1D and 2D tensors are currently supported", str(cm.exception)
        )

    def test_offsets_and_storage_metadata_after_padding_1D_rw(self) -> None:
        """Test 1D RW sharding offsets and storage metadata with top+bottom padding."""
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        original_offsets = [(0,), (2,)]
        lsw = LocalShardsWrapper([shard1, shard2], original_offsets)

        # Check the original storage metadata
        original_storage = lsw.storage_metadata()
        self.assertEqual(original_storage.size, torch.Size([4]))  # [1, 2, 3, 4]
        self.assertEqual(len(original_storage.chunks), 2)
        self.assertEqual(original_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(original_storage.chunks[0].sizes, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].offsets, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].sizes, torch.Size([2]))

        pad_spec = [1, 1]  # add 1 element at the top and 1 at the bottom
        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, 0.0)

        # The second shard's offset is shifted by 1 by the top padding
        expected_offsets = [
            torch.Size([0]),
            torch.Size([3]),
        ]
        self.assertEqual(result.local_offsets(), expected_offsets)

        result_storage = result.storage_metadata()

        # Global tensor should be [0, 1, 2, 3, 4, 0], shape=[6]
        expected_global_size = torch.Size([6])
        self.assertEqual(result_storage.size, expected_global_size)

        self.assertEqual(len(result_storage.chunks), 2)

        # First chunk: 3 elements at offset [0] (size increased by top padding)
        # Second chunk: 3 elements at offset [3] (size increased by bottom padding, offset shifted)
        self.assertEqual(result_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(result_storage.chunks[0].sizes, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].offsets, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].sizes, torch.Size([3]))

    def test_offsets_and_storage_metadata_after_padding_2D_cw(self) -> None:
        """Test 2D CW sharding offsets and storage metadata with left+right padding."""
        shard1_2d = torch.tensor([[1.0, 2.0], [5.0, 6.0]])  # [2, 2], columns 0-1
        shard2_2d = torch.tensor([[3.0, 4.0], [7.0, 8.0]])  # [2, 2], columns 2-3
        original_offsets_2d = [(0, 0), (0, 2)]
        lsw_2d = LocalShardsWrapper([shard1_2d, shard2_2d], original_offsets_2d)

        pad_spec_2d = [1, 1, 0, 0]  # [left=1, right=1, top=0, bottom=0]
        result_2d = torch.ops.aten.constant_pad_nd.default(lsw_2d, pad_spec_2d, 0.0)

        # The second shard's column offset is shifted by 1 by the left padding
        expected_offsets_2d = [
            torch.Size([0, 0]),
            torch.Size([0, 3]),
        ]
        self.assertEqual(result_2d.local_offsets(), expected_offsets_2d)

        result_storage_2d = result_2d.storage_metadata()

        # Global tensor should go from [2, 4] to [2, 6] (1 column left + 1 right)
        expected_global_size_2d = torch.Size([2, 6])
        self.assertEqual(result_storage_2d.size, expected_global_size_2d)

        self.assertEqual(len(result_storage_2d.chunks), 2)

        # First chunk: [2, 3] at offset [0, 0] (size increased by left padding)
        # Second chunk: [2, 3] at offset [0, 3] (size increased by right padding, offset shifted)
        self.assertEqual(result_storage_2d.chunks[0].offsets, torch.Size([0, 0]))
        self.assertEqual(result_storage_2d.chunks[0].sizes, torch.Size([2, 3]))
        self.assertEqual(result_storage_2d.chunks[1].offsets, torch.Size([0, 3]))
        self.assertEqual(result_storage_2d.chunks[1].sizes, torch.Size([2, 3]))


if __name__ == "__main__":
    run_tests()