@@ -25,12 +25,15 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     ComputeDevice,
     DenseTableBatchedEmbeddingBagsCodegen,
+    get_available_compute_device,
     SplitTableBatchedEmbeddingBagsCodegen,
 )
 from fbgemm_gpu.tbe.bench import (
     benchmark_requests,
+    benchmark_requests_with_spec,
     EmbeddingOpsCommonConfigLoader,
     TBEBenchmarkingConfigLoader,
+    TBEDataConfigListLoader,
     TBEDataConfigLoader,
 )
 from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
@@ -41,6 +44,13 @@
 logger.setLevel(logging.DEBUG)
 logging.basicConfig(level=logging.DEBUG)
 
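+# Best-effort MTIA runtime initialization; silently skipped on hosts without the MTIA host runtime.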
+try:
+    import mtia.host_runtime.torch_mtia.dynamic_library  # pyright: ignore # noqa: F401 # pyre-ignore[21]
+
+    torch.mtia.init()
+except Exception:
+    pass
+
 
 @click.group()
 def cli() -> None:
@@ -288,5 +298,321 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
         )
 
 
+@cli.command()
+@click.option(
+    "--emb-op-type",
+    default="split",
+    type=click.Choice(["split", "dense", "ssd"], case_sensitive=False),
+    help="The type of the embedding op to benchmark",
+)
+@click.option(
+    "--row-wise/--no-row-wise",
+    default=True,
+    help="Whether to use the row-wise adagrad optimizer or not",
+)
+@click.option(
+    "--weighted-num-requires-grad",
+    type=int,
+    default=None,
+    help="The number of weighted tables that require gradient",
+)
+@click.option(
+    "--ssd-prefix",
+    type=str,
+    default="/tmp/ssd_benchmark",
+    help="SSD directory prefix",
+)
+@click.option(
+    "--pooling-list",
+    type=str,
+    default=None,
+    help="Comma-separated list of per-table pooling factors (L values) that overrides the configured pooling list",
+)
+@click.option("--cache-load-factor", default=0.2)
+@TBEBenchmarkingConfigLoader.options
+@TBEDataConfigListLoader.options
+@EmbeddingOpsCommonConfigLoader.options
+@click.pass_context
+def device_with_speclist(  # noqa C901
+    context: click.Context,
+    emb_op_type: click.Choice,
+    row_wise: bool,
+    weighted_num_requires_grad: Optional[int],
+    cache_load_factor: float,
+    # SSD params
+    ssd_prefix: str,
+    pooling_list: Optional[str],
+    # pyre-ignore[2]
+    **kwargs,
+) -> None:
+    # Initialize random seeds
+    np.random.seed(42)
+    torch.manual_seed(42)
+
+    # Load general TBE benchmarking configuration from cli arguments
+    benchconfig = TBEBenchmarkingConfigLoader.load(context)
+
+    # Load TBE data configuration from cli arguments
+    tbeconfig = TBEDataConfigListLoader.load(context)
+
+    # Load common embedding op configuration from cli arguments
+    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
+
+    # Generate feature_requires_grad
+    feature_requires_grad = (
+        tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
+        if weighted_num_requires_grad
+        else None
+    )
+
+    # Determine the optimizer
+    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD
+
+    # Construct the common split arguments for the embedding op
+    common_split_args: Dict[str, Any] = embconfig.split_args() | {
+        "optimizer": optimizer,
+        "learning_rate": 0.1,
+        "eps": 0.1,
+        "feature_table_map": list(range(tbeconfig.T)),
+    }
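+    # When the data config requests variable batch sizes (sigma_B is set), build a
+    # per-feature batch-size spec (single rank) so the op runs in variable-batch-size (VBE) mode.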
+    batch_size_per_feature_per_rank = None
+    if tbeconfig.batch_params.sigma_B is not None:
+        batch_size_per_feature_per_rank = []
+        for b in tbeconfig.batch_params.Bs:
+            batch_size_per_feature_per_rank.append([b])
+
+    managed_option = (
+        EmbeddingLocation.DEVICE
+        if get_available_compute_device() == ComputeDevice.CUDA
+        else EmbeddingLocation.HOST
+    )
+
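+    # Construct the embedding op for the requested backend: dense, SSD-backed, or the standard split TBE op.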
+    if emb_op_type == "dense":
+        embedding_op = DenseTableBatchedEmbeddingBagsCodegen(
+            [
+                (
+                    e,
+                    d,
+                )
+                for e, d in zip(tbeconfig.Es, tbeconfig.Ds)
+            ],
+            pooling_mode=embconfig.pooling_mode,
+            use_cpu=not torch.cuda.is_available(),
+        )
+    elif emb_op_type == "ssd":
+        assert (
+            torch.cuda.is_available()
+        ), "SSDTableBatchedEmbeddingBags only supports GPU execution"
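+        # Size the SSD cache by the summed per-table batch size (at least one cache set).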
+        cache_set = max(sum(tbeconfig.batch_params.Bs), 1)
+        tempdir = tempfile.mkdtemp(prefix=ssd_prefix)
+        embedding_op = SSDTableBatchedEmbeddingBags(
+            embedding_specs=[(e, d) for e, d in zip(tbeconfig.Es, tbeconfig.Ds)],
+            cache_sets=cache_set,
+            ssd_storage_directory=tempdir,
+            ssd_cache_location=EmbeddingLocation.DEVICE,
+            ssd_rocksdb_shards=8,
+            **common_split_args,
+        )
+    else:
+        embedding_op = SplitTableBatchedEmbeddingBagsCodegen(
+            [
+                (
+                    e,
+                    d,
+                    managed_option,
+                    get_available_compute_device(),
+                )
+                for e, d in zip(tbeconfig.Es, tbeconfig.Ds)
+            ],
+            cache_precision=(
+                embconfig.weights_dtype
+                if embconfig.cache_dtype is None
+                else embconfig.cache_dtype
+            ),
+            cache_algorithm=CacheAlgorithm.LRU,
+            cache_load_factor=cache_load_factor,
+            device=get_device(),
+            **common_split_args,
+        ).to(get_device())
+    embedding_op = embedding_op.to(get_device())
+
+    if embconfig.weights_dtype == SparseType.INT8:
+        # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
+        #  min_val: float, max_val: float) -> None, (self:
+        #  SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
+        #  None, Tensor, Module]` is not a function.
+        embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003)
+
+    avg_B = int(np.average(tbeconfig.batch_params.Bs))
+
+    nparams = sum(d * e for e, d in zip(tbeconfig.Es, tbeconfig.Ds))
+    param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0
+    output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0
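+    # Rough bytes-moved estimate per iteration: reading avg_B * sum(Ds) * L embedding-row
+    # bytes plus writing the op's output; used to report achieved bandwidth below.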
+    if embconfig.pooling_mode.do_pooling():
+        read_write_bytes = (
+            output_size_multiplier * avg_B * sum(tbeconfig.Ds)
+            + param_size_multiplier
+            * avg_B
+            * sum(tbeconfig.Ds)
+            * tbeconfig.pooling_params.L
+        )
+    else:
+        read_write_bytes = (
+            output_size_multiplier
+            * avg_B
+            * sum(tbeconfig.Ds)
+            * tbeconfig.pooling_params.L
+            + param_size_multiplier
+            * avg_B
+            * sum(tbeconfig.Ds)
+            * tbeconfig.pooling_params.L
+        )
+
+    logging.info(f"Managed option: {embconfig.embedding_location}")
+    logging.info(
+        f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
+        f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
+    )
+    logging.info(
+        f"Accessed weights per batch: {avg_B * sum(tbeconfig.Ds) * tbeconfig.pooling_params.L * param_size_multiplier / 1.0e9: .2f} GB"
+    )
+
+    if pooling_list is not None:
+        pooling_list_extracted = [float(x) for x in pooling_list.split(",")]
+        tensor_pooling_list = torch.tensor(pooling_list_extracted)
+        requests = tbeconfig.generate_requests_with_Llist(
+            tensor_pooling_list,
+            benchconfig.num_requests,
+            batch_size_per_feature_per_rank,
+        )
+    else:
+        requests = tbeconfig.generate_requests(
+            benchconfig.num_requests, batch_size_per_feature_per_rank
+        )
+
+    # pyre-ignore[53]
+    def _kineto_trace_handler(p: profile, phase: str) -> None:
+        p.export_chrome_trace(
+            benchconfig.trace_url.format(
+                emb_op_type=emb_op_type, phase=phase, ospid=os.getpid()
+            )
+        )
+
+    # pyre-ignore[3,53]
+    def _context_factory(on_trace_ready: Callable[[profile], None]):
+        return (
+            profile(on_trace_ready=on_trace_ready, with_stack=True, record_shapes=True)
+            if benchconfig.export_trace
+            else nullcontext()
+        )
+
+    # Benchmark the forward pass; benchmark_requests_with_spec forwards
+    # batch_size_per_feature_per_rank to the op so VBE runs are exercised as well.
+
+    if torch.cuda.is_available():
+        with _context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
+            # forward
+            time_per_iter = benchmark_requests_with_spec(
+                requests,
+                lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op.forward(
+                    indices.to(dtype=tbeconfig.indices_params.index_dtype),
+                    offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
+                    per_sample_weights,
+                    feature_requires_grad=feature_requires_grad,
+                    batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                ),
+                flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
+                num_warmups=benchconfig.warmup_iterations,
+                iters=benchconfig.iterations,
+            )
+    else:
+        time_per_iter = benchmark_requests_with_spec(
+            requests,
+            lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op.forward(
+                indices.to(dtype=tbeconfig.indices_params.index_dtype),
+                offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
+                per_sample_weights,
+                feature_requires_grad=feature_requires_grad,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+            ),
+            flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
+            num_warmups=benchconfig.warmup_iterations,
+            iters=benchconfig.iterations,
+        )
+
+    avg_E = int(np.average(tbeconfig.Es))
+    avg_D = int(np.average(tbeconfig.Ds))
+    logging.info(
+        f"Forward, B: {avg_B}, "
+        f"E: {avg_E}, T: {tbeconfig.T}, D: {avg_D}, L: {tbeconfig.pooling_params.L}, W: {tbeconfig.weighted}, "
+        f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
+        f"T: {time_per_iter * 1.0e6:.0f}us"
+    )
+
+    if embconfig.output_dtype == SparseType.INT8:
+        # The backward benchmark is not representative for INT8 output, so skip it
+        return
+
+    if embconfig.pooling_mode.do_pooling():
+        if batch_size_per_feature_per_rank is None:
+            grad_output = torch.randn(avg_B, sum(tbeconfig.Ds)).to(get_device())
+        else:
+            output_size = sum(
+                [b * d for (b, d) in zip(tbeconfig.batch_params.Bs, tbeconfig.Ds)]
+            )
+            grad_output = torch.randn(output_size).to(get_device())
+
+    else:
+        grad_output = torch.randn(
+            avg_B * tbeconfig.T * tbeconfig.pooling_params.L,
+            avg_D,
+        ).to(get_device())
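+    # With VBE (per-feature batch sizes), the op returns a flattened 1D output, so its gradient must be 1D as well.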
+    assert (
+        batch_size_per_feature_per_rank is None or grad_output.dim() == 1
+    ), f"VBE expects 1D grad_output but got {grad_output.shape}"
+    if torch.cuda.is_available():
+        with _context_factory(lambda p: _kineto_trace_handler(p, "fwd_bwd")):
+            # backward
+            time_per_iter = benchmark_requests_with_spec(
+                requests,
+                lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op(
+                    indices.to(dtype=tbeconfig.indices_params.index_dtype),
+                    offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
+                    per_sample_weights,
+                    batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                    feature_requires_grad=feature_requires_grad,
+                ),
+                flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
+                bwd_only=True,
+                grad=grad_output,
+                num_warmups=benchconfig.warmup_iterations,
+                iters=benchconfig.iterations,
+            )
+    else:
+        time_per_iter = benchmark_requests_with_spec(
+            requests,
+            lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op(
+                indices.to(dtype=tbeconfig.indices_params.index_dtype),
+                offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
+                per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                feature_requires_grad=feature_requires_grad,
+            ),
+            flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
+            bwd_only=True,
+            grad=grad_output,
+            num_warmups=benchconfig.warmup_iterations,
+            iters=benchconfig.iterations,
+        )
+
+    logging.info(
+        f"Backward, B: {avg_B}, E: {avg_E}, T: {tbeconfig.T}, D: {avg_D}, L: {tbeconfig.pooling_params.L}, "
+        f"BW: {2 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "
+        f"T: {time_per_iter * 1.0e6:.0f}us"
+    )
+
+
 if __name__ == "__main__":
     cli()
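
For reference, here is a minimal sketch of driving the new subcommand through click's test runner. It is not part of the diff: the module name `tbe_training_benchmark` is a placeholder for wherever this script lives, the command name assumes click's default underscore-to-dash conversion, and only options defined above are passed, so everything contributed by TBEBenchmarkingConfigLoader, TBEDataConfigListLoader, and EmbeddingOpsCommonConfigLoader keeps its defaults.

from click.testing import CliRunner

# Placeholder import path -- substitute the actual module/script that defines `cli`.
from tbe_training_benchmark import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "device-with-speclist",     # click's default name for device_with_speclist
        "--emb-op-type", "split",   # one of: split, dense, ssd
        "--row-wise",               # use EXACT_ROWWISE_ADAGRAD
        "--cache-load-factor", "0.2",
        "--pooling-list", "20,40",  # optional override: one pooling factor per table
    ],
)
print(result.exit_code, result.output)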