Skip to content

Commit 7a1a83f

Browse files
gchalump and facebook-github-bot
authored and committed
Adding Kineto support to bench:sparse_ops (#5060)
Summary: X-link: facebookresearch/FBGEMM#2069. Support added to group-index-select-2d-bench for now. Differential Revision: D85199596.
1 parent 648e57a commit 7a1a83f

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

fbgemm_gpu/bench/sparse_ops_benchmark.py

Lines changed: 39 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,10 @@
1010
import functools
1111
import logging
1212
import math
13+
import os
1314
import random
15+
from contextlib import nullcontext
16+
from typing import Callable
1417

1518
import click
1619
import fbgemm_gpu
@@ -240,13 +243,26 @@ def jagged_index_select_2d_ref(
240243
@click.option("--input-precision", type=str, default="fp32")
241244
@click.option("--sort-indices", type=bool, default=True)
242245
@click.option("--num-groups", default=32)
246+
@click.option(
247+
"--export-trace",
248+
is_flag=True,
249+
default=False,
250+
help="Enable export of trace for profiling. Default is False.",
251+
)
252+
@click.option(
253+
"--trace-url",
254+
type=str,
255+
default="group_index_select_2d_{phase}_trace_{ospid}.json",
256+
)
243257
def group_index_select_2d_bench(
244258
row_size: int,
245259
batch_size: int,
246260
unique_batch_size: int,
247261
input_precision: str,
248262
sort_indices: bool,
249263
num_groups: int,
264+
export_trace: bool,
265+
trace_url: str,
250266
) -> None:
251267
def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
252268
inverse_index = list(range(curr_size))
@@ -285,6 +301,13 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
285301

286302
bench_kwargs = {"num_warmups": 10, "iters": 100}
287303

304+
def _kineto_trace_handler(p: profile, phase: str) -> None:
305+
p.export_chrome_trace(trace_url.format(phase=phase, ospid=os.getpid()))
306+
307+
# pyre-ignore[3]
308+
def context_factory(on_trace_ready: Callable[[profile], None]):
309+
return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
310+
288311
# Benchmark forward
289312
time_ref, output_ref = benchmark_torch_function(
290313
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
@@ -297,13 +320,14 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
297320
)
298321

299322
input_group = input.split(batch_size, 0)
300-
time, output_group = benchmark_torch_function(
301-
torch.ops.fbgemm.group_index_select_dim0,
302-
(input_group, indices_group),
303-
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
304-
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
305-
**bench_kwargs,
306-
)
323+
with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
324+
time, output_group = benchmark_torch_function(
325+
torch.ops.fbgemm.group_index_select_dim0,
326+
(input_group, indices_group),
327+
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
328+
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
329+
**bench_kwargs,
330+
)
307331
logging.info(
308332
f"forward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), "
309333
f"fbgemm group {time:5f} sec ({num_bytes / time / 1e9:.5f} GB/s)"
@@ -322,13 +346,14 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
322346
# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
323347
# typing.Tuple[Tensor, ...]]` but got `Tensor`.
324348
cat_output = torch.cat(output_group)
325-
time, _ = benchmark_torch_function(
326-
functools.partial(cat_output.backward, retain_graph=True),
327-
(grad,),
328-
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
329-
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
330-
**bench_kwargs,
331-
)
349+
with context_factory(lambda p: _kineto_trace_handler(p, "bwd")):
350+
time, _ = benchmark_torch_function(
351+
functools.partial(cat_output.backward, retain_graph=True),
352+
(grad,),
353+
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
354+
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
355+
**bench_kwargs,
356+
)
332357
logging.info(
333358
f"backward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), "
334359
f"fbgemm group {time:.5f} sec ({num_bytes / time / 1e9:.5f} GB/s)"

0 commit comments

Comments
 (0)