
Commit 224d6ad

emlin authored and facebook-github-bot committed
support get state dict and apply state dict (#4145)
Summary:
X-link: pytorch/torchrec#2976
X-link: facebookresearch/FBGEMM#1226

# Functions

**Saving State Dict**

When saving the state dict, we convert IDs from local to global (see the ID-conversion sketch below). This allows us to avoid shifting IDs based on sharding decisions when tables are resharded.

**Checkpoint Loading Mode**

We have enabled a load-state-dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading.

# Current Solution

The current solution caches all data in Python tensors, following these steps (a call-sequence sketch follows this summary):
- Set self.local_weight_counts based on the checkpoint bucket tensor size.
- Enable load-state-dict mode to initialize the local cache tensors.
- Call state_dict to get empty tensors for the checkpoint loader.
- Let the checkpoint loader write data from the persisted checkpoint into the cached tensors.
- Call apply_state_dict to write all cached tensors to the backend.

**Apply State Dict Flow**

During the apply_state_dict step, we perform the following operations (see the chunked-write sketch below):
- If optimizer offloading is enabled:
  - Loop through chunks of weight and optimizer state.
  - Concatenate the weight and optimizer state together.
  - Write to the backend through the KVTensorWrapper interface.
- If optimizer offloading is disabled:
  - Set the optimizer state on the device tensor based on ID.
  - Write the IDs and weights to the backend for each table.

# Limitations

The current solution has two limitations:
- Memory overhead:
  - When writing data to the backend, the Python tensor's memory cannot be released until the whole tensor's data is duplicated in the backend. This can lead to high memory usage, especially when dealing with single large tables.
- Performance regression:
  - With optimizer offloading, we need to concatenate the weight and optimizer state before writing to the backend. To avoid tripling a large tensor's memory footprint, we loop through smaller chunks during writing, which can cause a performance regression.

# Future Improvements

After the first end-to-end version is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead.

Reviewed By: bobbyliujb

Differential Revision: D74790154
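To make the local/global ID arithmetic concrete, here is a minimal sketch; the constants (`bucket_id_start`, `bucket_size`, `table_offset`) are hypothetical stand-ins for the values that come from `kv_zch_params.bucket_offsets[i]`, `kv_zch_params.bucket_sizes[i]`, and the running table offset in `split_embedding_weights`:

```python
import torch

# Hypothetical shard parameters; in the real code these come from
# kv_zch_params.bucket_offsets[i], kv_zch_params.bucket_sizes[i], and the
# running table_offset in split_embedding_weights.
bucket_id_start, bucket_size, table_offset = 8, 1024, 100_000
local_ids = torch.tensor([[0], [3], [77]], dtype=torch.int64)

# Saving: convert local IDs to global IDs (what bucket_ascending_id_tensor.add_
# does in the diff), so resharding does not shift checkpointed IDs.
global_ids = local_ids + bucket_id_start * bucket_size

# Training/loading: the backend keys rows by local ID linearized with the table
# offset, so KVTensorWrapper is given row_offset = table_offset - local_shard_offset.
row_offset = table_offset - bucket_id_start * bucket_size
backend_keys = global_ids + row_offset
assert torch.equal(backend_keys, local_ids + table_offset)
```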
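The five loading steps map onto the new TBE methods roughly as follows. This is a hedged sketch, not the exact integration: `tbe` (an SSD/DRAM TBE module from this file) and `checkpoint_rows_per_table` (per-table row counts derived from the checkpoint bucket tensors) are assumed placeholders.

```python
# Sketch of the checkpoint-loading call sequence under the stated assumptions.
tbe.local_weight_counts = checkpoint_rows_per_table    # step 1: size the caches
tbe.enable_load_state_dict_mode()                      # step 2: allocate cached tensors
weights, ids, buckets = tbe.split_embedding_weights()  # step 3: hand empty cached tensors to the loader
# step 4: the checkpoint loader copies persisted data into weights/ids/buckets,
# in any order, since the checkpoint client does not support ordered loading.
tbe.apply_state_dict()                                 # step 5: flush the caches to the backend
```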
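The offloading branch delegates to streaming_write_weight_and_id_per_table (shown in the diff below). The rolling-window idea can be sketched in isolation like this; the function name, the `kvt` handle, and the dimension parameters are assumptions standing in for the real KVTensorWrapper setup:

```python
import torch

def chunked_weight_opt_write(
    weights: torch.Tensor,  # [rows, D] cached weight tensor
    opt_2d: torch.Tensor,   # [rows, optimizer_dim] cached optimizer state
    ids: torch.Tensor,      # [rows, 1] cached global IDs
    kvt,                    # assumed KVTensorWrapper already bound to the backend
    d_rounded: int,         # weight dim padded to 4-byte alignment (pad4)
    cache_row_dim: int,     # full backend row width (weights + optimizer)
    chunk_size: int = 10000,
) -> None:
    """Concatenate weight and optimizer rows chunk by chunk so that only one
    chunk_size x cache_row_dim buffer is alive at a time, instead of
    duplicating the whole table in memory."""
    rows = weights.size(0)
    for i in range(0, rows, chunk_size):
        n = min(chunk_size, rows - i)
        buf = torch.empty(n, cache_row_dim, dtype=weights.dtype, device="cpu")
        buf[:, : weights.size(1)] = weights[i : i + n]
        buf[:, d_rounded : d_rounded + opt_2d.size(1)] = opt_2d[i : i + n]
        kvt.set_weights_and_ids(buf, ids[i : i + n].view(-1))
```

This is the memory/performance trade-off named in the limitations: smaller chunks bound peak memory but add per-chunk write overhead.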
1 parent 025543a · commit 224d6ad

File tree

2 files changed: +417, -17 lines changed


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 205 additions & 17 deletions
@@ -314,7 +314,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -1983,7 +1983,6 @@ def split_optimizer_states(
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                 sorted_indices=sorted_id_tensor[t],
             )
             (
@@ -2112,6 +2111,7 @@ def _may_create_snapshot_for_state_dict(
         """
         Create a rocksdb snapshot if needed.
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
@@ -2120,7 +2120,13 @@ def _may_create_snapshot_for_state_dict(
             if not no_snapshot:
                 # Flush L1 and L2 caches
                 self.flush(force=should_flush)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
                 snapshot_handle = self.ssd_db.create_snapshot()
+                logging.info(
+                    f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+                )
         elif self.backend_type == BackendType.DRAM:
             self.flush(force=should_flush)
         return snapshot_handle
@@ -2131,7 +2137,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        List[PartiallyMaterializedTensor] | List[torch.Tensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
@@ -2160,6 +2166,17 @@ def split_embedding_weights(
             )

         dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for load checkpointing"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2168,18 +2185,15 @@ def split_embedding_weights(
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]

                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
@@ -2192,15 +2206,38 @@ def split_embedding_weights(
                     torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                         unordered_id_tensor,
                         0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                         bucket_size,
                     )
                 )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)

+                # For KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
+                # but the backend uses local ids during training, so KVTensorWrapper must first map
+                # each global id to a local id and then linearize it with the table offset; the formula
+                # is x + table_offset - local_shard_offset, so row_offset is set to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
@@ -2211,33 +2248,184 @@ def split_embedding_weights(
                         emb_dim,
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # set bucket_ascending_id_tensor to kvt wrapper, so narrow will follow the id order to return
+                # embedding weights.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
                     tensor_wrapper,
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, rounded to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write the weight, optimizer state, and id tensors to the backend using the kvt wrapper.
+        To avoid overusing memory, the weights and ids are written to the backend in a rolling-window manner.
+
+        Args:
+            weight_state (torch.tensor): The weight state tensor to be written.
+            opt_state (torch.tensor): The optimizer state tensor to be written.
+            id_tensor (torch.tensor): The id tensor to be written.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4-byte alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # for checkpoint loading, we need to store the weight and id tensors temporarily in memory
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
