Skip to content

Commit 2ac0b21

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Fix metadata tensor size and add ut to cover that. (#4704)
Summary: Pull Request resolved: #4704 X-link: facebookresearch/FBGEMM#1729 As title, currently has size mismatch issue {F1981225318} Reviewed By: kathyxuyy Differential Revision: D80282519 fbshipit-source-id: df9e7a588bb7d73aaa0a993c98d669aaebc7761d
1 parent 3f017e6 commit 2ac0b21

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2982,7 +2982,7 @@ def split_embedding_weights(
29822982
bucket_ascending_id_tensor + table_offset,
29832983
torch.as_tensor(bucket_ascending_id_tensor.size(0)),
29842984
snapshot_handle,
2985-
)
2985+
).view(-1, 1)
29862986

29872987
# 3. convert local id back to global id
29882988
bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,6 +2250,8 @@ def test_kv_emb_state_dict(
22502250
rtol=tolerance,
22512251
)
22522252

2253+
self.assertTrue(len(metadata_list[table_index].size()) == 2)
2254+
22532255
@given(
22542256
**{
22552257
"T": st.integers(min_value=1, max_value=10),

0 commit comments

Comments
 (0)