diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 1551b585f..2851bf286 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -596,9 +596,8 @@ def __init__( if env.process_group and dist.get_backend(env.process_group) != "fake": self._initialize_torch_state() - if module.device not in ["meta", "cpu"] and module.device.type not in [ + if module.device not in ["meta"] and module.device.type not in [ "meta", - "cpu", ]: self.load_state_dict(module.state_dict(), strict=False)