from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
    generate_vbe_metadata,
)
-
from torch import distributed as dist, nn, Tensor  # usort:skip
from dataclasses import dataclass

from fbgemm_gpu.tbe.ssd.common import tensor_pad4
-
from torch.autograd.profiler import record_function

from ..cache import get_unique_indices_v2
-
from .common import ASSOC, pad4
from .utils.partially_materialized_tensor import PartiallyMaterializedTensor

@@ -78,9 +75,9 @@ class IterData:

@dataclass
class KVZCHCachedData:
-    cached_id_tensor_per_table: List[torch.Tensor]
-    cached_weight_tensor_per_table: List[torch.Tensor]
    cached_optimizer_state_per_table: List[torch.Tensor]
+    cached_weight_tensor_per_table: List[torch.Tensor]
+    cached_id_tensor_per_table: List[torch.Tensor]
    cached_bucket_splits: List[torch.Tensor]

@@ -175,11 +172,13 @@ def __init__(
    ) -> None:
        super(SSDTableBatchedEmbeddingBags, self).__init__()

+        # Set the optimizer
        assert optimizer in (
            OptimType.EXACT_ROWWISE_ADAGRAD,
        ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
        self.optimizer = optimizer

+        # Set the table weight and output dtypes
        assert weights_precision in (SparseType.FP32, SparseType.FP16)
        self.weights_precision = weights_precision
        self.output_dtype: int = output_dtype.as_int()
@@ -702,7 +701,9 @@ def __init__(
        momentum1_offsets = [0] + list(itertools.accumulate(rows))
        self._apply_split(
            SplitState(
-                dev_size=self.total_hash_size,
+                dev_size=(
+                    self.total_hash_size if not self.enable_optimizer_offloading else 0
+                ),
                host_size=0,
                uvm_size=0,
                placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
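
The change above makes the device-side momentum1 allocation conditional: when optimizer offloading is enabled, the rowwise-Adagrad state lives in the SSD/DRAM backend next to the weights, so nothing needs to be reserved on the device. A minimal sketch of that sizing decision (`momentum1_dev_size` is a hypothetical helper for illustration, not an fbgemm_gpu API):

```python
# Hypothetical helper illustrating the dev_size choice in the hunk above.
def momentum1_dev_size(total_hash_size: int, enable_optimizer_offloading: bool) -> int:
    # With offloading, momentum1 is stored row-by-row in the KV backend
    # (appended after the padded weight), so the device buffer can be empty.
    return 0 if enable_optimizer_offloading else total_hash_size
```
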
@@ -1720,6 +1721,7 @@ def forward(
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
        # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
    ) -> Tensor:
+        self.clear_cache()
        indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
            indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
        )
@@ -1877,10 +1879,30 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
            for t, row in enumerate(rows)
        ]

+    @torch.jit.ignore
+    def _split_optimizer_states_non_kv_zch(
+        self,
+    ) -> List[torch.Tensor]:
+        """
+        Returns a list of optimizer states, split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
+        so only the momentum1 state is returned.
+        """
+        logging.info("_split_optimizer_states_non_kv_zch")
+        (rows, _) = zip(*self.embedding_specs)
+
+        rows_cumsum = [0] + list(itertools.accumulate(rows))
+
+        return [
+            self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
+            for t, row in enumerate(rows)
+        ]
+
    @torch.jit.export
    def split_optimizer_states(
        self,
        sorted_id_tensor: Optional[List[torch.Tensor]] = None,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
    ) -> List[torch.Tensor]:
        """
        Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
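
In the non-KV-ZCH path added above, momentum1 is one flat device tensor covering all tables, and each table's slice is recovered from the cumulative row counts. A small self-contained worked example of that slicing, assuming three tables with 100, 200, and 50 rows (the tensor here stands in for `self.momentum1_dev`):

```python
import itertools
import torch

# Hypothetical table sizes; momentum1 stands in for self.momentum1_dev.
rows = [100, 200, 50]
momentum1 = torch.arange(sum(rows), dtype=torch.float32)

rows_cumsum = [0] + list(itertools.accumulate(rows))  # [0, 100, 300, 350]
per_table = [
    momentum1[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
    for t, row in enumerate(rows)
]
assert [t.numel() for t in per_table] == rows
```
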
@@ -1897,14 +1919,126 @@ def split_optimizer_states(
        id consistency between weight and optimizer states.

        """
-        raise NotImplementedError(
-            "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags"
+
+        if not self.kv_zch_params:
+            return self._split_optimizer_states_non_kv_zch()
+
+        if self.load_state_dict:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+                and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+            ), "optimizer state is not initialized for load checkpointing"
+            return self._cached_kvzch_data.cached_optimizer_state_per_table
+
+        logging.info(
+            f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
+        )
+        start_time = time.time()
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
        )

+        opt_list = []
+        table_offset = 0
+
+        dtype = self.weights_precision.as_dtype()
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        pad4_optimizer_dim = pad4(optimizer_dim)
+        logging.info(
+            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+        )
+
+        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            # pyre-ignore
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = table_offset
+            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                opt_list.append(
+                    torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
+                    # empty optimizer state for module initialization
+                )
+            else:
+                if not self.enable_optimizer_offloading:
+                    # convert global id back to local id, then linearize with table offset
+                    local_id_tensor = (
+                        sorted_id_tensor[t]
+                        - bucket_id_start * bucket_size
+                        + table_offset
+                    )
+                    opt_list.append(
+                        self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
+                    )
+                else:
+                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    row_offset = table_offset - (bucket_id_start * bucket_size)
+                    # Query the backend through KVTensorWrapper to avoid OOM: the backend
+                    # returns weight and optimizer state in one tensor, and reading the
+                    # whole tensor out at once could exhaust CPU memory.
+                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                        shape=[emb_height, emb_opt_dim],
+                        dtype=dtype,
+                        row_offset=row_offset,
+                        snapshot_handle=snapshot_handle,
+                        materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
+                        sorted_indices=sorted_id_tensor[t],
+                    )
+                    (
+                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                        if self.backend_type == BackendType.SSD
+                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                    )
+                    opt_list.append(
+                        self.get_offloaded_optimizer_states(
+                            tensor_wrapper=tensor_wrapper,
+                            row=sorted_id_tensor[t].size(
+                                0
+                            ),  # we only need to copy as many rows as sorted_id_tensor has
+                            optimizer_dim=optimizer_dim,
+                            start_dim_pos=pad4(emb_dim),
+                        )
+                    )
+            table_offset += emb_height
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
+        )
+        return opt_list
+
+    @torch.jit.export
+    def get_offloaded_optimizer_states(
+        self,
+        # pyre-ignore [2]
+        tensor_wrapper,
+        row: int,
+        optimizer_dim: int,
+        start_dim_pos: int,
+    ) -> torch.Tensor:
+        weight_dtype = self.weights_precision.as_dtype()
+        opt_state_t = torch.empty(
+            row, optimizer_dim, dtype=weight_dtype, device="cpu"
+        )  # 1D optimizer for OptimType.EXACT_ROWWISE_ADAGRAD
+
+        # pyre-ignore [16]
+        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            opt_state_t.narrow(0, i, length).copy_(
+                tensor_wrapper.narrow(0, i, length).narrow(
+                    1, start_dim_pos, optimizer_dim
+                )
+            )
+        # view the optimizer state back as its correct dtype
+        return opt_state_t.view(-1).view(self.optimizer.dtype())
+
    @torch.jit.export
    def get_optimizer_state(
        self,
        sorted_id_tensor: Optional[List[torch.Tensor]],
+        no_snapshot: bool = True,
+        should_flush: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
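
`get_offloaded_optimizer_states` reads the offloaded state in `streaming_ckpt_chunk_size` chunks so that only one chunk of rows is materialized on the host at a time, and it slices only the optimizer columns out of the combined [padded weight | optimizer] row the backend returns. A self-contained sketch of that chunked narrow/copy pattern, with a plain tensor standing in for the KVTensorWrapper and hypothetical sizes:

```python
import torch

# Hypothetical sizes; backend_rows stands in for one table's KVTensorWrapper view:
# each row is [padded weight | optimizer state].
rows, padded_emb_dim, optimizer_dim, chunk_size = 1000, 128, 2, 256
backend_rows = torch.randn(rows, padded_emb_dim + optimizer_dim)

opt_state = torch.empty(rows, optimizer_dim)
for i in range(0, rows, chunk_size):
    length = min(chunk_size, rows - i)
    # Copy only the optimizer columns of this chunk of rows.
    opt_state.narrow(0, i, length).copy_(
        backend_rows.narrow(0, i, length).narrow(1, padded_emb_dim, optimizer_dim)
    )
assert torch.equal(opt_state, backend_rows[:, padded_emb_dim:])
```
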
@@ -1914,6 +2048,8 @@ def get_optimizer_state(
            ({"momentum1": states})
            for states in self.split_optimizer_states(
                sorted_id_tensor=sorted_id_tensor,
+                no_snapshot=no_snapshot,
+                should_flush=should_flush,
            )
        ]

@@ -1963,8 +2099,32 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
        return splits

    def clear_cache(self) -> None:
+        # clear KV ZCH cache for checkpointing
        self._cached_kvzch_data = None

+    @torch.jit.ignore
+    # pyre-ignore [3] - do not define snapshot class EmbeddingSnapshotHandleWrapper here, to avoid an import dependency in other production code
+    def _may_create_snapshot_for_state_dict(
+        self,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ):
+        """
+        Create a rocksdb snapshot if needed.
+        """
+        # Force device synchronize for now
+        torch.cuda.synchronize()
+        snapshot_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                # Flush L1 and L2 caches
+                self.flush(force=should_flush)
+                snapshot_handle = self.ssd_db.create_snapshot()
+        elif self.backend_type == BackendType.DRAM:
+            self.flush(force=should_flush)
+        return snapshot_handle
+
    @torch.jit.export
    def split_embedding_weights(
        self,
@@ -1994,18 +2154,10 @@ def split_embedding_weights(
        3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
            where for the i th element, we have i + bucket_id_start = global bucket id
        """
-        # Force device synchronize for now
-        torch.cuda.synchronize()
-        snapshot_handle = None
-        if self.backend_type == BackendType.SSD:
-            # Create a rocksdb snapshot
-            if not no_snapshot:
-                if should_flush:
-                    # Flush L1 and L2 caches
-                    self.flush(force=True)
-                snapshot_handle = self.ssd_db.create_snapshot()
-        elif self.backend_type == BackendType.DRAM:
-            self.flush(force=True)
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )

        dtype = self.weights_precision.as_dtype()
        pmt_splits = []
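
With `_may_create_snapshot_for_state_dict` factored out, both state-dict export paths share the same snapshot/flush policy: `no_snapshot=False` asks for a rocksdb snapshot on the SSD backend, and `should_flush` forces the L1/L2 caches to be flushed before reading. A hedged usage sketch; the module instance and the per-table id tensors are assumed for illustration, only the flag semantics come from this diff:

```python
# `emb` is an already-constructed SSDTableBatchedEmbeddingBags; `ids_per_table`
# is a per-table list of sorted id tensors (both assumed here).
splits = emb.split_embedding_weights(no_snapshot=False, should_flush=True)
opt_states = emb.split_optimizer_states(
    sorted_id_tensor=ids_per_table,
    no_snapshot=False,
    should_flush=True,
)
# get_optimizer_state wraps the same states as {"momentum1": tensor} dicts per table.
opt_dicts = emb.get_optimizer_state(
    sorted_id_tensor=ids_per_table,
    no_snapshot=False,
    should_flush=True,
)
```
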