Skip to content

Commit 4f5149c

Browse files
emlinfacebook-github-bot
authored andcommitted
support get state dict and apply state dict (#4145)
Summary: X-link: pytorch/torchrec#2976 X-link: facebookresearch/FBGEMM#1226 **Saving State Dict** When saving state dict, we convert IDs from local to global. This allows us to avoid shifting IDs based on sharding decisions when tables are resharded. **Checkpoint Loading Mode** We have enabled load state dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading. **Current Solution** The current solution involves caching all data in Python tensors, following these steps: - Set self.local_weight_counts based on checkpoint bucket tensor size. - Enable load state dict mode to initialize local cache tensors. - Call state_dict to get empty tensors for the checkpoint loader. - Write checkpoint data to cached tensors from persisted checkpoint by the checkpoint loader. - Call apply_state_dict to write all cached tensors to the backend. **Apply State Dict Flow** During the apply_state_dict step, we perform the following operations: - If optimizer offloading is enabled: - Loop through chunks of weight and optimizer. - Concatenate weight and optimizer together. - Write to backend using KVTensorWrapper interface. - If optimizer offloading is disabled: - Set optimizer to device tensor based on ID. - Write ID weight to backend for each table. **Limitations** The current solution has two limitations: - Memory overhead: - When writing data to the backend, the Python tensor's memory cannot be released until the whole tensor data is duplicated in the backend. This can lead to high memory usage, especially when dealing with single large tables. - Performance regression: - With optimizer offloading, we need to concatenate weight and optimizer together before writing to the backend. To avoid triple one large tensor's memory, we loop through smaller chunks during writing, which can cause performance regression. **Future Improvements** After the first version e2e is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead. Differential Revision: D74790154
1 parent e478779 commit 4f5149c

File tree

2 files changed

+417
-17
lines changed

2 files changed

+417
-17
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 205 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def __init__(
318318
f"weights precision: {weights_precision}, "
319319
f"output dtype: {output_dtype}, "
320320
f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
321-
f"zero_collision_config: {kv_zch_params}"
321+
f"kv_zch_params: {kv_zch_params}"
322322
)
323323
self.register_buffer(
324324
"lxu_cache_state",
@@ -2009,7 +2009,6 @@ def split_optimizer_states(
20092009
dtype=dtype,
20102010
row_offset=row_offset,
20112011
snapshot_handle=snapshot_handle,
2012-
materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
20132012
sorted_indices=sorted_id_tensor[t],
20142013
)
20152014
(
@@ -2134,7 +2133,7 @@ def split_embedding_weights(
21342133
no_snapshot: bool = True,
21352134
should_flush: bool = False,
21362135
) -> Tuple[ # TODO: make this a NamedTuple for readability
2137-
List[PartiallyMaterializedTensor],
2136+
List[PartiallyMaterializedTensor] | List[torch.Tensor],
21382137
Optional[List[torch.Tensor]],
21392138
Optional[List[torch.Tensor]],
21402139
]:
@@ -2157,6 +2156,7 @@ def split_embedding_weights(
21572156
3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
21582157
where for the i th element, we have i + bucket_id_start = global bucket id
21592158
"""
2159+
start_time = time.time()
21602160
# Force device synchronize for now
21612161
torch.cuda.synchronize()
21622162
snapshot_handle = None
@@ -2166,11 +2166,28 @@ def split_embedding_weights(
21662166
if should_flush:
21672167
# Flush L1 and L2 caches
21682168
self.flush(force=True)
2169+
logging.info(
2170+
f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
2171+
)
21692172
snapshot_handle = self.ssd_db.create_snapshot()
2173+
logging.info(
2174+
f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
2175+
)
21702176
elif self.backend_type == BackendType.DRAM:
21712177
self.flush(force=True)
21722178

21732179
dtype = self.weights_precision.as_dtype()
2180+
if self.load_state_dict and self.kv_zch_params:
2181+
# init for checkpointing loading
2182+
assert (
2183+
self._cached_kvzch_data is not None
2184+
), "weight id and bucket state are not initialized for load checkpointing"
2185+
return (
2186+
self._cached_kvzch_data.cached_weight_tensor_per_table,
2187+
self._cached_kvzch_data.cached_id_tensor_per_table,
2188+
self._cached_kvzch_data.cached_bucket_splits,
2189+
)
2190+
21742191
pmt_splits = []
21752192
bucket_sorted_id_splits = [] if self.kv_zch_params else None
21762193
active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2179,18 +2196,15 @@ def split_embedding_weights(
21792196
for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
21802197
bucket_ascending_id_tensor = None
21812198
bucket_t = None
2199+
row_offset = table_offset
21822200
if self.kv_zch_params:
21832201
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
21842202
# pyre-ignore
21852203
bucket_size = self.kv_zch_params.bucket_sizes[i]
21862204

21872205
# linearize with table offset
2188-
table_input_id_start = (
2189-
min(bucket_id_start * bucket_size, emb_height) + table_offset
2190-
)
2191-
table_input_id_end = (
2192-
min(bucket_id_end * bucket_size, emb_height) + table_offset
2193-
)
2206+
table_input_id_start = table_offset
2207+
table_input_id_end = table_offset + emb_height
21942208
# 1. get all keys from backend for one table
21952209
unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
21962210
table_input_id_start,
@@ -2203,15 +2217,38 @@ def split_embedding_weights(
22032217
torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
22042218
unordered_id_tensor,
22052219
0, # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
2206-
bucket_id_start,
2207-
bucket_id_end,
2220+
0, # local bucket offset
2221+
bucket_id_end - bucket_id_start, # local bucket num
22082222
bucket_size,
22092223
)
22102224
)
2211-
# pyre-ignore
2225+
# 3. convert local id back to global id
2226+
bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
2227+
2228+
if (
2229+
bucket_ascending_id_tensor.size(0) == 0
2230+
and self.local_weight_counts[i] > 0
2231+
):
2232+
logging.info(
2233+
f"resetting bucket id tensor with {self.local_weight_counts[i]}"
2234+
)
2235+
bucket_ascending_id_tensor = torch.zeros(
2236+
(self.local_weight_counts[i], 1),
2237+
device=torch.device("cpu"),
2238+
dtype=torch.int64,
2239+
)
2240+
# self.local_weight_counts[i] = 0 # Reset the count
2241+
2242+
# pyre-ignore [16] bucket_sorted_id_splits is not None
22122243
bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
22132244
active_id_cnt_per_bucket_split.append(bucket_t)
22142245

2246+
# for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
2247+
# but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id
2248+
# first, then linearize the local id with table offset, the formulat is x + table_offset - local_shard_offset
2249+
# to achieve this, the row_offset will be set to (table_offset - local_shard_offset)
2250+
row_offset = table_offset - (bucket_id_start * bucket_size)
2251+
22152252
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
22162253
shape=[
22172254
(
@@ -2222,33 +2259,184 @@ def split_embedding_weights(
22222259
emb_dim,
22232260
],
22242261
dtype=dtype,
2225-
row_offset=table_offset,
2262+
row_offset=row_offset,
22262263
snapshot_handle=snapshot_handle,
2264+
# set bucket_ascending_id_tensor to kvt wrapper, so narrow will follow the id order to return
2265+
# embedding weights.
22272266
sorted_indices=(
22282267
bucket_ascending_id_tensor if self.kv_zch_params else None
22292268
),
22302269
)
2231-
# TODO add if else support in the future for dram integration.
2232-
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2270+
(
2271+
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2272+
if self.backend_type == BackendType.SSD
2273+
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2274+
)
22332275
table_offset += emb_height
22342276
pmt_splits.append(
22352277
PartiallyMaterializedTensor(
22362278
tensor_wrapper,
22372279
True if self.kv_zch_params else False,
22382280
)
22392281
)
2282+
logging.info(
2283+
f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
2284+
)
22402285
return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
22412286

22422287
@torch.jit.ignore
22432288
def apply_state_dict(self) -> None:
22442289
# After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
22452290
# Caller should call this function to apply the cached states to backend.
2246-
pass
2291+
if self.load_state_dict is False:
2292+
return
2293+
self.load_state_dict = False
2294+
assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
2295+
assert (
2296+
self._cached_kvzch_data is not None
2297+
and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
2298+
), "optimizer state is not initialized for load checkpointing"
2299+
assert (
2300+
self._cached_kvzch_data.cached_weight_tensor_per_table is not None
2301+
and self._cached_kvzch_data.cached_id_tensor_per_table is not None
2302+
), "weight and id state is not initialized for load checkpointing"
2303+
2304+
# Compute the number of elements of cache_dtype needed to store the
2305+
# optimizer state, round to the nearest 4
2306+
# optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
2307+
# apply weight and optimizer state per table
2308+
table_offset = 0
2309+
for i, (emb_height, _) in enumerate(self.embedding_specs):
2310+
# pyre-ignore [16]
2311+
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
2312+
# pyre-ignore [16]
2313+
bucket_size = self.kv_zch_params.bucket_sizes[i]
2314+
row_offset = table_offset - bucket_id_start * bucket_size
2315+
2316+
if self.enable_optimizer_offloading:
2317+
# pyre-ignore [16]
2318+
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
2319+
# pyre-ignore [16]
2320+
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
2321+
self.streaming_write_weight_and_id_per_table(
2322+
weight_state,
2323+
opt_state,
2324+
# pyre-ignore [16]
2325+
self._cached_kvzch_data.cached_id_tensor_per_table[i],
2326+
row_offset,
2327+
)
2328+
self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
2329+
self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
2330+
else:
2331+
weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
2332+
id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
2333+
local_id = id + row_offset
2334+
logging.info(
2335+
f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
2336+
)
2337+
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
2338+
t_device = self.momentum1_dev.device
2339+
self.momentum1_dev.index_put_(
2340+
indices=(
2341+
local_id.to(t_device).view(-1),
2342+
), # expects tuple of tensors
2343+
values=opt_state.to(t_device),
2344+
)
2345+
self.ssd_db.set_cuda(
2346+
local_id.view(-1),
2347+
weight,
2348+
torch.as_tensor(local_id.size(0)),
2349+
1,
2350+
False,
2351+
)
2352+
table_offset += emb_height
2353+
self.clear_cache()
2354+
2355+
@torch.jit.ignore
2356+
def streaming_write_weight_and_id_per_table(
2357+
self,
2358+
weight_state: torch.Tensor,
2359+
opt_state: torch.Tensor,
2360+
id_tensor: torch.Tensor,
2361+
row_offset: int,
2362+
) -> None:
2363+
"""
2364+
This function is used to write weight, optimizer and id to the backend using kvt wrapper.
2365+
to avoid over use memory, we will write the weight and id to backend in a rolling window manner
2366+
2367+
Args:
2368+
weight_state (torch.tensor): The weight state tensor to be written.
2369+
opt_state (torch.tensor): The optimizer state tensor to be written.
2370+
id_tensor (torch.tensor): The id tensor to be written.
2371+
"""
2372+
D_rounded = pad4(weight_state.size(1)) # padded to 4 bytes alignment
2373+
dtype = self.weights_precision.as_dtype()
2374+
kvt = torch.classes.fbgemm.KVTensorWrapper(
2375+
db=self.ssd_db,
2376+
shape=[weight_state.size(0), self.cache_row_dim],
2377+
dtype=dtype,
2378+
row_offset=row_offset,
2379+
snapshot_handle=None,
2380+
sorted_indices=id_tensor,
2381+
)
2382+
# TODO: make chunk_size configurable or dynamic
2383+
chunk_size = 10000
2384+
row = weight_state.size(0)
2385+
optimizer_dim = self.optimizer.state_size_dim(dtype)
2386+
opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
2387+
for i in range(0, row, chunk_size):
2388+
length = min(chunk_size, row - i)
2389+
chunk_buffer = torch.empty(
2390+
length,
2391+
self.cache_row_dim,
2392+
dtype=dtype,
2393+
device="cpu",
2394+
)
2395+
chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
2396+
chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
2397+
i : i + length, :
2398+
]
2399+
kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
22472400

22482401
@torch.jit.ignore
22492402
def enable_load_state_dict_mode(self) -> None:
22502403
# Enable load state dict mode before loading checkpoint
2251-
pass
2404+
if self.load_state_dict:
2405+
return
2406+
self.load_state_dict = True
2407+
2408+
dtype = self.weights_precision.as_dtype()
2409+
self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
2410+
for i, (_, emb_dim) in enumerate(self.embedding_specs):
2411+
# for checkpointing loading, we need to store the weight and id tensor temporarily in memory
2412+
assert (
2413+
self.local_weight_counts[i] > 0
2414+
), f"local_weight_counts for table {i} is not set"
2415+
# pyre-ignore [16]
2416+
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
2417+
rows = self.local_weight_counts[i]
2418+
weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
2419+
opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
2420+
# pyre-ignore [16]
2421+
self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
2422+
# pyre-ignore [16]
2423+
self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
2424+
logging.info(
2425+
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
2426+
)
2427+
id_tensor = torch.zeros(
2428+
(self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
2429+
)
2430+
# pyre-ignore [16]
2431+
self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
2432+
# pyre-ignore [16]
2433+
self._cached_kvzch_data.cached_bucket_splits.append(
2434+
torch.empty(
2435+
(bucket_id_end - bucket_id_start, 1),
2436+
dtype=torch.int64,
2437+
device="cpu",
2438+
)
2439+
)
22522440

22532441
@torch.jit.export
22542442
def set_learning_rate(self, lr: float) -> None:

0 commit comments

Comments
 (0)