From 7a1a83f1af2cdc4f4de8358f989c4a13c099b95a Mon Sep 17 00:00:00 2001 From: Gantaphon Chalumporn Date: Tue, 11 Nov 2025 14:34:08 -0800 Subject: [PATCH] Adding Kineto support to bench:sparse_ops (#5060) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2069 Support added to group-index-select-2d-bench for now. Differential Revision: D85199596 --- fbgemm_gpu/bench/sparse_ops_benchmark.py | 53 +++++++++++++++++------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index 136e117538..cf2c84bc1a 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -10,7 +10,10 @@ import functools import logging import math +import os import random +from contextlib import nullcontext +from typing import Callable import click import fbgemm_gpu @@ -240,6 +243,17 @@ def jagged_index_select_2d_ref( @click.option("--input-precision", type=str, default="fp32") @click.option("--sort-indices", type=bool, default=True) @click.option("--num-groups", default=32) +@click.option( + "--export-trace", + is_flag=True, + default=False, + help="Enable export of trace for profiling. Default is False.", +) +@click.option( + "--trace-url", + type=str, + default="group_index_select_2d_{phase}_trace_{ospid}.json", +) def group_index_select_2d_bench( row_size: int, batch_size: int, @@ -247,6 +261,8 @@ def group_index_select_2d_bench( input_precision: str, sort_indices: bool, num_groups: int, + export_trace: bool, + trace_url: str, ) -> None: def gen_inverse_index(curr_size: int, final_size: int) -> np.array: inverse_index = list(range(curr_size)) @@ -285,6 +301,13 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array: bench_kwargs = {"num_warmups": 10, "iters": 100} + def _kineto_trace_handler(p: profile, phase: str) -> None: + p.export_chrome_trace(trace_url.format(phase=phase, ospid=os.getpid())) + + # pyre-ignore[3] + def context_factory(on_trace_ready: Callable[[profile], None]): + return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext() + # Benchmark forward time_ref, output_ref = benchmark_torch_function( # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`. @@ -297,13 +320,14 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array: ) input_group = input.split(batch_size, 0) - time, output_group = benchmark_torch_function( - torch.ops.fbgemm.group_index_select_dim0, - (input_group, indices_group), - # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`. - # pyre-fixme[6]: For 3rd argument expected `str` but got `int`. - **bench_kwargs, - ) + with context_factory(lambda p: _kineto_trace_handler(p, "fwd")): + time, output_group = benchmark_torch_function( + torch.ops.fbgemm.group_index_select_dim0, + (input_group, indices_group), + # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`. + # pyre-fixme[6]: For 3rd argument expected `str` but got `int`. + **bench_kwargs, + ) logging.info( f"forward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), " f"fbgemm group {time:5f} sec ({num_bytes / time / 1e9:.5f} GB/s)" @@ -322,13 +346,14 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array: # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor], # typing.Tuple[Tensor, ...]]` but got `Tensor`. cat_output = torch.cat(output_group) - time, _ = benchmark_torch_function( - functools.partial(cat_output.backward, retain_graph=True), - (grad,), - # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`. - # pyre-fixme[6]: For 3rd argument expected `str` but got `int`. - **bench_kwargs, - ) + with context_factory(lambda p: _kineto_trace_handler(p, "bwd")): + time, _ = benchmark_torch_function( + functools.partial(cat_output.backward, retain_graph=True), + (grad,), + # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`. + # pyre-fixme[6]: For 3rd argument expected `str` but got `int`. + **bench_kwargs, + ) logging.info( f"backward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), " f"fbgemm group {time:.5f} sec ({num_bytes / time / 1e9:.5f} GB/s)"