From 3deb6dbbd6797c8699d56e8fac6296114c1d3144 Mon Sep 17 00:00:00 2001 From: Kaustubh Vartak Date: Tue, 26 Aug 2025 13:56:05 -0700 Subject: [PATCH] Fix init param test for EBC Summary: We were skipping loading of state dict for CPU models causing the state dict comparison between sharded and unsharded model to not match. I removed this condition so that we skip only for meta devices. https://www.internalfb.com/intern/test/562950052233328 Differential Revision: D81039244 --- torchrec/distributed/embeddingbag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 1551b585f..2851bf286 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -596,9 +596,8 @@ def __init__( if env.process_group and dist.get_backend(env.process_group) != "fake": self._initialize_torch_state() - if module.device not in ["meta", "cpu"] and module.device.type not in [ + if module.device not in ["meta"] and module.device.type not in [ "meta", - "cpu", ]: self.load_state_dict(module.state_dict(), strict=False)