@@ -318,7 +318,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -2009,7 +2009,6 @@ def split_optimizer_states(
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                 sorted_indices=sorted_id_tensor[t],
             )
             (
@@ -2134,7 +2133,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        List[PartiallyMaterializedTensor] | List[torch.Tensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
@@ -2157,6 +2156,7 @@ def split_embedding_weights(
             3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
                 where for the i th element, we have i + bucket_id_start = global bucket id
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
@@ -2166,11 +2166,28 @@ def split_embedding_weights(
             if should_flush:
                 # Flush L1 and L2 caches
                 self.flush(force=True)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
             snapshot_handle = self.ssd_db.create_snapshot()
+            logging.info(
+                f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+            )
         elif self.backend_type == BackendType.DRAM:
             self.flush(force=True)

         dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for load checkpointing"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2179,18 +2196,15 @@ def split_embedding_weights(
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]

                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
@@ -2203,15 +2217,38 @@ def split_embedding_weights(
                     torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                         unordered_id_tensor,
                         0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                         bucket_size,
                     )
                 )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)

+                # For KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
+                # but the backend uses local ids during training, so the KVTensorWrapper needs to
+                # convert a global id to a local id first and then linearize it with the table offset;
+                # the formula is x + table_offset - local_shard_offset. To achieve this, row_offset
+                # is set to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
@@ -2222,33 +2259,184 @@ def split_embedding_weights(
                     emb_dim,
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # set bucket_ascending_id_tensor on the kvt wrapper, so narrow will follow the id order
+                # when returning embedding weights.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
                     tensor_wrapper,
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

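The row_offset arithmetic above is easier to see with concrete numbers; here is a minimal sketch with a made-up table layout (the numbers are illustrative, not taken from this patch):

    # Hypothetical layout: this table starts at table_offset = 1000 in the linearized
    # id space, and the local shard owns buckets [4, 8) with bucket_size = 50, so the
    # shard's first global id within the table is 4 * 50 = 200.
    table_offset = 1000
    bucket_id_start, bucket_size = 4, 50
    row_offset = table_offset - bucket_id_start * bucket_size  # 800

    # A global id x = 230 (30th slot of bucket 4) lands on backend row
    # x + row_offset = 1030, i.e. local id (230 - 200) shifted by the table offset.
    x = 230
    assert x + row_offset == (x - bucket_id_start * bucket_size) + table_offset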
     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, round to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
2355
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        This function writes weights, optimizer state and ids to the backend using the kvt wrapper.
+        To avoid overusing memory, the weights and ids are written to the backend in a rolling window manner.
+
+        Args:
+            weight_state (torch.tensor): The weight state tensor to be written.
+            opt_state (torch.tensor): The optimizer state tensor to be written.
+            id_tensor (torch.tensor): The id tensor to be written.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4 bytes alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

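Each buffered row in the rolling-window write packs the embedding weights first and the optimizer state after the padded boundary D_rounded. A rough sketch of that layout with made-up sizes (emb_dim, optimizer_dim and cache_row_dim are illustrative, and pad4 is assumed to round the width up for alignment, per the comment above):

    import torch

    emb_dim, optimizer_dim, cache_row_dim = 6, 2, 10
    d_rounded = 8  # stand-in for pad4(emb_dim)

    weights = torch.arange(emb_dim, dtype=torch.float32)
    opt = torch.tensor([0.5, 0.25])

    row = torch.zeros(cache_row_dim)
    row[:emb_dim] = weights                           # columns [0, 6): weights
    row[d_rounded : d_rounded + optimizer_dim] = opt  # columns [8, 10): optimizer state
    # columns [6, 8) remain as alignment padding, matching chunk_buffer in the method above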
     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # for checkpoint loading, we need to store the weight and id tensors temporarily in memory
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
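Taken together, the new methods sketch a checkpoint-restore flow: enable load-state-dict mode, let the checkpoint loader fill the cached tensors that split_embedding_weights returns while in that mode, then apply them to the backend. A rough usage sketch under those assumptions (the tbe handle and the loader step are illustrative, not part of this patch):

    # Assumes tbe is a KV ZCH SSD/DRAM TBE module and tbe.local_weight_counts
    # has already been set per table.
    tbe.enable_load_state_dict_mode()      # allocates the _cached_kvzch_data buffers

    # While load_state_dict is True, split_embedding_weights() returns the cached
    # in-memory tensors, so the checkpoint loader can copy weights/ids into them.
    weights, ids, bucket_splits = tbe.split_embedding_weights(no_snapshot=True)
    # ... checkpoint loader fills weights / ids / bucket_splits in place ...

    tbe.apply_state_dict()                 # writes cached states to the backend and clears the cache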