@@ -314,7 +314,7 @@ def __init__(
            f"weights precision: {weights_precision}, "
            f"output dtype: {output_dtype}, "
            f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
        )
        self.register_buffer(
            "lxu_cache_state",
@@ -1983,7 +1983,6 @@ def split_optimizer_states(
                dtype=dtype,
                row_offset=row_offset,
                snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                sorted_indices=sorted_id_tensor[t],
            )
            (
@@ -2112,6 +2111,7 @@ def _may_create_snapshot_for_state_dict(
        """
        Create a rocksdb snapshot if needed.
        """
+        start_time = time.time()
        # Force device synchronize for now
        torch.cuda.synchronize()
        snapshot_handle = None
@@ -2120,7 +2120,13 @@ def _may_create_snapshot_for_state_dict(
            if not no_snapshot:
                # Flush L1 and L2 caches
                self.flush(force=should_flush)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
                snapshot_handle = self.ssd_db.create_snapshot()
+                logging.info(
+                    f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+                )
        elif self.backend_type == BackendType.DRAM:
            self.flush(force=should_flush)
        return snapshot_handle
@@ -2131,7 +2137,7 @@ def split_embedding_weights(
        no_snapshot: bool = True,
        should_flush: bool = False,
    ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        List[PartiallyMaterializedTensor] | List[torch.Tensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
@@ -2160,6 +2166,17 @@ def split_embedding_weights(
        )

        dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for load checkpointing"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
        pmt_splits = []
        bucket_sorted_id_splits = [] if self.kv_zch_params else None
        active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2168,18 +2185,15 @@ def split_embedding_weights(
        for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
            bucket_ascending_id_tensor = None
            bucket_t = None
+            row_offset = table_offset
            if self.kv_zch_params:
                bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                # pyre-ignore
                bucket_size = self.kv_zch_params.bucket_sizes[i]

                # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                # 1. get all keys from backend for one table
                unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                    table_input_id_start,
@@ -2192,15 +2206,38 @@ def split_embedding_weights(
                    torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                        unordered_id_tensor,
                        0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                        bucket_size,
                    )
                )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                active_id_cnt_per_bucket_split.append(bucket_t)

+                # For the KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
+                # but the backend uses local ids during training, so the KVTensorWrapper needs to convert
+                # each global id to a local id first and then linearize it with the table offset, i.e.
+                # x + table_offset - local_shard_offset; so row_offset is set to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                shape=[
                    (
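The row_offset arithmetic in the hunk above is easier to see with concrete numbers. The following standalone sketch is illustrative only and is not part of this PR; table_offset, bucket_id_start, bucket_size, and the sample id are made-up values.

# Illustrative sketch only (not part of this diff); all numbers are hypothetical.
table_offset = 1000          # rows of all preceding tables
bucket_id_start = 2          # first bucket owned by this shard
bucket_size = 100            # ids per bucket
local_shard_offset = bucket_id_start * bucket_size  # 200

row_offset = table_offset - local_shard_offset      # 800

# A global id x owned by this shard maps to backend row
# (x - local_shard_offset) + table_offset, which is exactly x + row_offset.
x = 250
assert x + row_offset == (x - local_shard_offset) + table_offset == 1050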
@@ -2211,33 +2248,184 @@ def split_embedding_weights(
                    emb_dim,
                ],
                dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                snapshot_handle=snapshot_handle,
+                # Set bucket_ascending_id_tensor on the kvt wrapper, so narrow() will follow
+                # the id order when returning embedding weights.
                sorted_indices=(
                    bucket_ascending_id_tensor if self.kv_zch_params else None
                ),
            )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
            table_offset += emb_height
            pmt_splits.append(
                PartiallyMaterializedTensor(
                    tensor_wrapper,
                    True if self.kv_zch_params else False,
                )
            )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
        return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

    @torch.jit.ignore
    def apply_state_dict(self) -> None:
        # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
        # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, round to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write weights, optimizer state, and ids to the backend using the kvt wrapper.
+        To avoid over-using memory, the weights and ids are written to the backend in a rolling-window (chunked) manner.
+
+        Args:
+            weight_state (torch.Tensor): The weight state tensor to be written.
+            opt_state (torch.Tensor): The optimizer state tensor to be written.
+            id_tensor (torch.Tensor): The id tensor to be written.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4 bytes alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

    @torch.jit.ignore
    def enable_load_state_dict_mode(self) -> None:
        # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # for checkpoint loading, we need to store the weight and id tensors temporarily in memory
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

    @torch.jit.export
    def set_learning_rate(self, lr: float) -> None:
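Taken together, these methods suggest a checkpoint-restore flow along the lines of the sketch below. This is a hedged illustration rather than code from this PR: `tbe` stands for a KV ZCH table-batched embedding module, and `checkpointed` for per-table tensors the caller has already read from a checkpoint through its own mechanism.

# Hedged sketch of the restore flow implied by the methods above; `tbe` and `checkpointed`
# are hypothetical stand-ins, not APIs introduced by this PR.
def restore_from_checkpoint(tbe, checkpointed) -> None:
    # Tell the module how many rows each table will receive, then allocate the CPU-side caches.
    tbe.local_weight_counts = [ids.size(0) for ids in checkpointed.ids_per_table]
    tbe.enable_load_state_dict_mode()

    # Fill the cached per-table tensors that split_embedding_weights() exposes in this mode.
    for i in range(len(checkpointed.ids_per_table)):
        tbe._cached_kvzch_data.cached_weight_tensor_per_table[i].copy_(
            checkpointed.weights_per_table[i]
        )
        tbe._cached_kvzch_data.cached_id_tensor_per_table[i].copy_(
            checkpointed.ids_per_table[i]
        )
        tbe._cached_kvzch_data.cached_optimizer_state_per_table[i].copy_(
            checkpointed.opt_state_per_table[i]
        )

    # Stream the cached weights, ids, and optimizer state into the SSD/DRAM backend.
    tbe.apply_state_dict()

While load_state_dict mode is active, split_embedding_weights() short-circuits and returns these cached tensors directly (see the earlier hunk), which is what lets a checkpointing framework populate them before apply_state_dict() writes them to the backend.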