Commit 025543a

emlin authored and facebook-github-bot committed
implement optimizer state with opt offloading (#4141)
Summary: X-link: facebookresearch/FBGEMM#1224

Implement split_optimizer_states for optimizer state dict integration.

Reviewed By: duduyi2013, bobbyliujb

Differential Revision: D74790121
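For orientation, a minimal sketch of how the new surface might be driven from checkpointing code (assuming an already-constructed SSDTableBatchedEmbeddingBags instance named emb; the argument values are illustrative):

    # One optimizer-state tensor per table; for EXACT_ROWWISE_ADAGRAD this is
    # the momentum1 state. no_snapshot/should_flush control whether a backend
    # snapshot is created and whether L1/L2 caches are flushed first.
    states = emb.split_optimizer_states(
        sorted_id_tensor=None,
        no_snapshot=False,
        should_flush=True,
    )

    # The same states keyed for optimizer state dict integration, one
    # {"momentum1": tensor} dict per table.
    named_states = emb.get_optimizer_state(
        sorted_id_tensor=None,
        no_snapshot=False,
        should_flush=True,
    )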
1 parent 157e88b commit 025543a

6 files changed: +435, -27 lines

fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py

Lines changed: 8 additions & 0 deletions
@@ -58,6 +58,14 @@ def state_size_dim(self, dtype: torch.dtype) -> int:
         """
         return int(math.ceil(self.state_size() / dtype.itemsize))
 
+    def dtype(self) -> torch.dtype:
+        """
+        Returns the dtype of the optimizer state
+        """
+        return {
+            EmbOptimType.EXACT_ROWWISE_ADAGRAD: torch.float32,
+        }.get(self, torch.float32)
+
 
 # Base class for quantization configuration (in case other numeric types have
 # configs)
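A small sketch of the accessor added above (assuming EmbOptimType is the optimizer enum this file defines; any optimizer without an explicit entry in the mapping falls back to torch.float32):

    import torch

    from fbgemm_gpu.split_embedding_configs import EmbOptimType

    # Explicitly mapped entry:
    assert EmbOptimType.EXACT_ROWWISE_ADAGRAD.dtype() == torch.float32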

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ class KVZCHParams(NamedTuple):
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = True
+    # streaming load/save checkpoint chunk size
+    streaming_ckpt_chunk_size: int = 1000000
 
     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
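A hedged construction sketch for the extended KVZCHParams (the values are illustrative; bucket_offsets entries are assumed to be (start, end) bucket-id pairs, matching how training.py unpacks them below):

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams

    params = KVZCHParams(
        bucket_offsets=[(0, 4)],  # one (start, end) pair per table
        bucket_sizes=[1024],
        enable_optimizer_offloading=True,
        streaming_ckpt_chunk_size=1_000_000,  # rows copied per chunk at checkpoint time
    )
    params.validate()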

fbgemm_gpu/fbgemm_gpu/tbe/ssd/common.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 # pyre-ignore-all-errors[56]
 
 import torch
+
 from fbgemm_gpu.utils.loader import load_torch_module
 
 try:

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 172 additions & 20 deletions
@@ -50,16 +50,13 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
     generate_vbe_metadata,
 )
-
 from torch import distributed as dist, nn, Tensor  # usort:skip
 from dataclasses import dataclass
 
 from fbgemm_gpu.tbe.ssd.common import tensor_pad4
-
 from torch.autograd.profiler import record_function
 
 from ..cache import get_unique_indices_v2
-
 from .common import ASSOC, pad4
 from .utils.partially_materialized_tensor import PartiallyMaterializedTensor
 
@@ -78,9 +75,9 @@ class IterData:
 
 @dataclass
 class KVZCHCachedData:
-    cached_id_tensor_per_table: List[torch.Tensor]
-    cached_weight_tensor_per_table: List[torch.Tensor]
     cached_optimizer_state_per_table: List[torch.Tensor]
+    cached_weight_tensor_per_table: List[torch.Tensor]
+    cached_id_tensor_per_table: List[torch.Tensor]
     cached_bucket_splits: List[torch.Tensor]
 
 
@@ -175,11 +172,13 @@ def __init__(
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
+        # Set the optimizer
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer
 
+        # Set the table weight and output dtypes
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
@@ -702,7 +701,9 @@ def __init__(
         momentum1_offsets = [0] + list(itertools.accumulate(rows))
         self._apply_split(
             SplitState(
-                dev_size=self.total_hash_size,
+                dev_size=(
+                    self.total_hash_size if not self.enable_optimizer_offloading else 0
+                ),
                 host_size=0,
                 uvm_size=0,
                 placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
@@ -1720,6 +1721,7 @@ def forward(
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
+        self.clear_cache()
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
             indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
         )
@@ -1877,10 +1879,30 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
             for t, row in enumerate(rows)
         ]
 
+    @torch.jit.ignore
+    def _split_optimizer_states_non_kv_zch(
+        self,
+    ) -> List[torch.Tensor]:
+        """
+        Returns a list of optimizer states, split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
+        so only momentum1 state is returned.
+        """
+        logging.info("_split_optimizer_states_non_kv_zch")
+        (rows, _) = zip(*self.embedding_specs)
+
+        rows_cumsum = [0] + list(itertools.accumulate(rows))
+
+        return [
+            self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
+            for t, row in enumerate(rows)
+        ]
+
     @torch.jit.export
     def split_optimizer_states(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]] = None,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[torch.Tensor]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
@@ -1897,14 +1919,126 @@ def split_optimizer_states(
         id consistency between weight and optimizer states.
 
         """
-        raise NotImplementedError(
-            "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags"
+
+        if not self.kv_zch_params:
+            return self._split_optimizer_states_non_kv_zch()
+
+        if self.load_state_dict:
+            # init for checkpointing loading
+            assert (
+                self._cached_kvzch_data is not None
+                and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+            ), "optimizer state is not initialized for load checkpointing"
+            return self._cached_kvzch_data.cached_optimizer_state_per_table
+
+        logging.info(
+            f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
+        )
+        start_time = time.time()
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
         )
 
+        opt_list = []
+        table_offset = 0
+
+        dtype = self.weights_precision.as_dtype()
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        pad4_optimizer_dim = pad4(optimizer_dim)
+        logging.info(
+            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+        )
+
+        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            # pyre-ignore
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = table_offset
+            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                opt_list.append(
+                    torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
+                    # empty optimizer state for module initialization
+                )
+            else:
+                if not self.enable_optimizer_offloading:
+                    # convert global id back to local id, then linearize with table offset
+                    local_id_tensor = (
+                        sorted_id_tensor[t]
+                        - bucket_id_start * bucket_size
+                        + table_offset
+                    )
+                    opt_list.append(
+                        self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
+                    )
+                else:
+                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    row_offset = table_offset - (bucket_id_start * bucket_size)
+                    # using KVTensorWrapper to query backend to avoid OOM memory, since
+                    # backend will return both weight and optimizer in one tensor, read the whole tensor
+                    # out could OOM CPU memory.
+                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                        shape=[emb_height, emb_opt_dim],
+                        dtype=dtype,
+                        row_offset=row_offset,
+                        snapshot_handle=snapshot_handle,
+                        materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
+                        sorted_indices=sorted_id_tensor[t],
+                    )
+                    (
+                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                        if self.backend_type == BackendType.SSD
+                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                    )
+                    opt_list.append(
+                        self.get_offloaded_optimizer_states(
+                            tensor_wrapper=tensor_wrapper,
+                            row=sorted_id_tensor[t].size(
+                                0
+                            ),  # we only need to copy the size of sorted_id_tensor
+                            optimizer_dim=optimizer_dim,
+                            start_dim_pos=pad4(emb_dim),
+                        )
+                    )
+            table_offset += emb_height
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
+        )
+        return opt_list
+
+    @torch.jit.export
+    def get_offloaded_optimizer_states(
+        self,
+        # pyre-ignore [2]
+        tensor_wrapper,
+        row: int,
+        optimizer_dim: int,
+        start_dim_pos: int,
+    ) -> torch.Tensor:
+        weight_dtype = self.weights_precision.as_dtype()
+        opt_state_t = torch.empty(
+            row, optimizer_dim, dtype=weight_dtype, device="cpu"
+        )  # 1D optimizer for OptimType.EXACT_ROWWISE_ADAGRAD
+
+        # pyre-ignore [16]
+        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            opt_state_t.narrow(0, i, length).copy_(
+                tensor_wrapper.narrow(0, i, length).narrow(
                    1, start_dim_pos, optimizer_dim
+                )
+            )
+        # view optimizer state back to correct dtype
+        return opt_state_t.view(-1).view(self.optimizer.dtype())
+
     @torch.jit.export
     def get_optimizer_state(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]],
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
@@ -1914,6 +2048,8 @@ def get_optimizer_state(
             ({"momentum1": states})
             for states in self.split_optimizer_states(
                 sorted_id_tensor=sorted_id_tensor,
+                no_snapshot=no_snapshot,
+                should_flush=should_flush,
             )
         ]
 
@@ -1963,8 +2099,32 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
         return splits
 
     def clear_cache(self) -> None:
+        # clear KV ZCH cache for checkpointing
         self._cached_kvzch_data = None
 
+    @torch.jit.ignore
+    # pyre-ignore [3] - do not definte snapshot class EmbeddingSnapshotHandleWrapper to avoid import dependency in other production code
+    def _may_create_snapshot_for_state_dict(
+        self,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ):
+        """
+        Create a rocksdb snapshot if needed.
+        """
+        # Force device synchronize for now
+        torch.cuda.synchronize()
+        snapshot_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                # Flush L1 and L2 caches
+                self.flush(force=should_flush)
+                snapshot_handle = self.ssd_db.create_snapshot()
+        elif self.backend_type == BackendType.DRAM:
+            self.flush(force=should_flush)
+        return snapshot_handle
+
     @torch.jit.export
     def split_embedding_weights(
         self,
@@ -1994,18 +2154,10 @@ def split_embedding_weights(
         3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
         where for the i th element, we have i + bucket_id_start = global bucket id
         """
-        # Force device synchronize for now
-        torch.cuda.synchronize()
-        snapshot_handle = None
-        if self.backend_type == BackendType.SSD:
-            # Create a rocksdb snapshot
-            if not no_snapshot:
-                if should_flush:
-                    # Flush L1 and L2 caches
-                    self.flush(force=True)
-                snapshot_handle = self.ssd_db.create_snapshot()
-        elif self.backend_type == BackendType.DRAM:
-            self.flush(force=True)
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
 
         dtype = self.weights_precision.as_dtype()
         pmt_splits = []
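The chunked copy inside get_offloaded_optimizer_states is what keeps peak host memory bounded: rather than materializing the full [rows, weights + optimizer] tensor from the backend, it narrows out chunk_size rows at a time and copies only the optimizer columns. A standalone sketch of the same pattern on plain tensors (the helper name and signature are illustrative, not FBGEMM API):

    import torch

    def copy_opt_columns_in_chunks(
        src: torch.Tensor,  # [rows, width]: padded weights followed by optimizer state
        start_col: int,     # first optimizer column, i.e. pad4(emb_dim)
        num_cols: int,      # optimizer_dim
        chunk_size: int,    # rows per copy; bounds peak memory
    ) -> torch.Tensor:
        rows = src.size(0)
        out = torch.empty(rows, num_cols, dtype=src.dtype, device="cpu")
        for i in range(0, rows, chunk_size):
            length = min(chunk_size, rows - i)
            # narrow(dim, start, length) returns a view, so each iteration
            # only materializes a chunk_size x num_cols copy
            out.narrow(0, i, length).copy_(
                src.narrow(0, i, length).narrow(1, start_col, num_cols)
            )
        return out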

fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 # pyre-strict
 # pyre-ignore-all-errors[3,6,56]
 
-import logging
 import math
 import tempfile
 