
support get state dict and apply state dict #4145


Open
emlin wants to merge 2 commits into main from export-D74790154

Conversation

emlin
Contributor

@emlin emlin commented May 17, 2025

Summary:

  • While saving the state dict, convert IDs from local to global, so that when a table is resharded there is no need to shift IDs based on the sharding decision.
  • Enable load-state-dict mode for checkpoint loading.
    Since the checkpoint client does not support ordered tensor loading, we implemented a short-term solution that caches all ID, weight, and bucket tensors in memory and, once all data is loaded, applies everything to the backend.

The current solution caches all data in Python tensors; here is the flow for using these interfaces (a calling-sequence sketch follows the list):

  • set self.local_weight_counts based on the checkpoint bucket tensor size
  • call enable_load_state_dict_mode to initialize the local cache tensors
  • call state_dict to get the empty tensors for the checkpoint loader
  • let the checkpoint loader write the checkpoint data into the cached tensors
  • call apply_state_dict to write all cached tensors to the backend
    • in apply_state_dict:
      • if optimizer offloading is enabled:
        • loop over chunks of weight and optimizer
        • concatenate each weight chunk with its optimizer chunk
        • write the result to the backend through the KVTensorWrapper interface
      • if optimizer offloading is disabled:
        • set the optimizer to a device tensor based on ID
        • write the IDs and weights to the backend
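
For concreteness, here is a minimal sketch of that calling sequence. Only self.local_weight_counts, enable_load_state_dict_mode, state_dict, and apply_state_dict come from this PR; the emb module, the checkpoint_loader object, and its load call are hypothetical placeholders.

```python
# Hypothetical sketch of the checkpoint-loading flow described above.
# `emb`, `checkpoint_loader`, and `bucket_sizes` are placeholders; only the
# local_weight_counts / enable_load_state_dict_mode / state_dict /
# apply_state_dict names come from this PR.
def load_from_checkpoint(emb, checkpoint_loader, bucket_sizes: list[int]) -> None:
    # 1. Tell the module how many rows each local bucket holds, based on the
    #    bucket tensor sizes recorded in the checkpoint.
    emb.local_weight_counts = bucket_sizes

    # 2. Allocate the in-memory cache tensors (IDs, weights, buckets).
    emb.enable_load_state_dict_mode()

    # 3. state_dict now hands out the empty cached tensors; the checkpoint
    #    loader fills them in place.
    sd = emb.state_dict()
    checkpoint_loader.load(sd)  # placeholder for the real loader API

    # 4. Push everything that was cached into the backend.
    emb.apply_state_dict()
```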

A drawback of this solution is that, when writing data to the backend, the Python tensors' memory cannot be released until all of the tensor data has been duplicated in the backend. For a short time one table's weight tensor is held twice, so we need to make sure the memory capacity is sufficient.
In addition, with optimizer offloading we need to concatenate the weight and optimizer tensors before writing to the backend, since the input data must be contiguous. To avoid tripling the memory consumption, we write the data chunk by chunk in a for loop, which causes a write performance regression.

After the first end-to-end version is ready, we'll support unordered loading from the backend to improve performance and also reduce memory overhead.

Saving State Dict
When saving state dict, we convert IDs from local to global. This allows us to avoid shifting IDs based on sharding decisions when tables are resharded.

Checkpoint Loading Mode
We have enabled load state dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading.

Current Solution
The current solution involves caching all data in Python tensors, following these steps:

  • Set self.local_weight_counts based on checkpoint bucket tensor size.
  • Enable load state dict mode to initialize local cache tensors.
  • Call state_dict to get empty tensors for the checkpoint loader.
  • Let the checkpoint loader write the checkpoint data from the persisted checkpoint into the cached tensors.
  • Call apply_state_dict to write all cached tensors to the backend.

Apply State Dict Flow
During the apply_state_dict step, we perform the following operations (a chunked-write sketch follows the list):

  • If optimizer offloading is enabled:
    • Loop through chunks of weight and optimizer.
    • Concatenate weight and optimizer together.
    • Write to backend using KVTensorWrapper interface.
  • If optimizer offloading is disabled:
    • Set optimizer to device tensor based on ID.
    • Write the IDs and weights to the backend for each table.
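
A minimal sketch of the chunked write path in the optimizer-offloading case is below. KVTensorWrapper is named in this PR, but its exact API is not shown here, so the write_to_backend callable and the chunk size are assumptions for illustration.

```python
import torch

CHUNK_ROWS = 65536  # assumed chunk size; trades peak memory for write count

def apply_offloaded_state(ids: torch.Tensor,
                          weights: torch.Tensor,
                          opt_states: torch.Tensor,
                          write_to_backend) -> None:
    """Write cached rows to the backend in chunks (illustrative only)."""
    num_rows = ids.numel()
    for start in range(0, num_rows, CHUNK_ROWS):
        end = min(start + CHUNK_ROWS, num_rows)
        # The backend expects contiguous rows, so each weight chunk is
        # concatenated with its optimizer chunk instead of materializing one
        # giant [rows, dim + opt_dim] tensor up front.
        row_chunk = torch.cat(
            [weights[start:end], opt_states[start:end]], dim=1
        ).contiguous()
        # Placeholder for the write through the KVTensorWrapper interface.
        write_to_backend(ids[start:end], row_chunk)
```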

Limitations
The current solution has two limitations (a rough memory estimate follows the list):

  • Memory overhead:
    • When writing data to the backend, the Python tensor's memory cannot be released until the whole tensor has been duplicated in the backend. This can lead to high memory usage, especially for a single large table.
  • Performance regression:
    • With optimizer offloading, we need to concatenate the weight and optimizer tensors before writing to the backend. To avoid tripling one large tensor's memory footprint, we loop through smaller chunks during writing, which can cause a performance regression.
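
As a rough, purely illustrative estimate of the first limitation: the temporary duplication for one table scales with rows × dim × bytes per element (the sizes below are hypothetical).

```python
# Back-of-the-envelope estimate of the temporary host-memory duplication for
# one table while apply_state_dict runs (hypothetical sizes, fp32 weights).
rows, dim = 100_000_000, 128      # one large embedding table
bytes_per_elem = 4                # fp32
extra_bytes = rows * dim * bytes_per_elem
print(f"~{extra_bytes / 2**30:.1f} GiB held in Python until the backend copy finishes")
# -> ~47.7 GiB
```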

Future Improvements
After the first end-to-end version is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead.

Differential Revision: D74790154


netlify bot commented May 17, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

| Name | Link |
| --- | --- |
| 🔨 Latest commit | d2568ed |
| 🔍 Latest deploy log | https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/682c0f2c92c281000877b6f5 |
| 😎 Deploy Preview | https://deploy-preview-4145--pytorch-fbgemm-docs.netlify.app |

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D74790154

@emlin emlin force-pushed the export-D74790154 branch from 14bc777 to 4f5149c Compare May 17, 2025 06:46
emlin added a commit to emlin/FBGEMM that referenced this pull request May 17, 2025
emlin added a commit to emlin/torchrec that referenced this pull request May 17, 2025
@emlin emlin force-pushed the export-D74790154 branch from 4f5149c to 2c5efc0 Compare May 19, 2025 22:26
emlin added a commit to emlin/FBGEMM that referenced this pull request May 19, 2025
emlin added a commit to emlin/torchrec that referenced this pull request May 19, 2025
emlin added a commit to emlin/FBGEMM that referenced this pull request May 19, 2025
@emlin emlin force-pushed the export-D74790154 branch from 2c5efc0 to a92923e Compare May 19, 2025 22:43
emlin added a commit to emlin/torchrec that referenced this pull request May 19, 2025
One of the two commits in this pull request:

Summary:
Pull Request resolved: pytorch#4141

X-link: facebookresearch/FBGEMM#1224

implement split_optimizer_states for optimizer state dict integration

Differential Revision: D74790121

Reviewed By: duduyi2013, bobbyliujb
emlin added a commit to emlin/torchrec that referenced this pull request May 20, 2025
@emlin emlin force-pushed the export-D74790154 branch from a92923e to 224d6ad Compare May 20, 2025 05:07
emlin added a commit to emlin/FBGEMM that referenced this pull request May 20, 2025
@emlin emlin force-pushed the export-D74790154 branch from 224d6ad to 87ef8a6 Compare May 20, 2025 05:07
emlin added a commit to emlin/FBGEMM that referenced this pull request May 20, 2025
emlin added a commit to emlin/torchrec that referenced this pull request May 20, 2025
The other commit in this pull request:

Summary:
X-link: pytorch/torchrec#2976

Pull Request resolved: pytorch#4145

X-link: facebookresearch/FBGEMM#1226

Reviewed By: bobbyliujb

Differential Revision: D74790154
emlin added a commit to emlin/torchrec that referenced this pull request May 20, 2025
@emlin emlin force-pushed the export-D74790154 branch from 87ef8a6 to d2568ed Compare May 20, 2025 05:12
duduyi2013 added a commit to duduyi2013/FBGEMM that referenced this pull request May 20, 2025