1010import functools
1111import logging
1212import math
13+ import os
1314import random
15+ from contextlib import nullcontext
16+ from typing import Callable
1417
1518import click
1619import fbgemm_gpu
@@ -240,13 +243,26 @@ def jagged_index_select_2d_ref(
240243@click .option ("--input-precision" , type = str , default = "fp32" )
241244@click .option ("--sort-indices" , type = bool , default = True )
242245@click .option ("--num-groups" , default = 32 )
246+ @click .option (
247+ "--export-trace" ,
248+ is_flag = True ,
249+ default = False ,
250+ help = "Enable export of trace for profiling. Default is False." ,
251+ )
252+ @click .option (
253+ "--trace-url" ,
254+ type = str ,
255+ default = "group_index_select_2d_{phase}_trace_{ospid}.json" ,
256+ )
243257def group_index_select_2d_bench (
244258 row_size : int ,
245259 batch_size : int ,
246260 unique_batch_size : int ,
247261 input_precision : str ,
248262 sort_indices : bool ,
249263 num_groups : int ,
264+ export_trace : bool ,
265+ trace_url : str ,
250266) -> None :
251267 def gen_inverse_index (curr_size : int , final_size : int ) -> np .array :
252268 inverse_index = list (range (curr_size ))
@@ -285,6 +301,13 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
285301
286302 bench_kwargs = {"num_warmups" : 10 , "iters" : 100 }
287303
def _kineto_trace_handler(p: profile, phase: str) -> None:
    """Export the profiler's Chrome trace for one benchmark phase.

    The output location is produced by filling the ``{phase}`` and
    ``{ospid}`` placeholders of the enclosing ``trace_url`` value with
    ``phase`` and the current process id.
    """
    output_path = trace_url.format(phase=phase, ospid=os.getpid())
    p.export_chrome_trace(output_path)
306+
# pyre-ignore[3]
def context_factory(on_trace_ready: Callable[[profile], None]):
    """Return the context manager used to wrap a benchmarked call.

    When the enclosing ``export_trace`` flag is set, this is a ``profile``
    constructed with the given ``on_trace_ready`` callback; otherwise it is
    a ``nullcontext`` that does nothing.
    """
    if not export_trace:
        return nullcontext()
    return profile(on_trace_ready=on_trace_ready)
310+
288311 # Benchmark forward
289312 time_ref , output_ref = benchmark_torch_function (
290313 # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
@@ -297,13 +320,14 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
297320 )
298321
299322 input_group = input .split (batch_size , 0 )
300- time , output_group = benchmark_torch_function (
301- torch .ops .fbgemm .group_index_select_dim0 ,
302- (input_group , indices_group ),
303- # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
304- # pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
305- ** bench_kwargs ,
306- )
323+ with context_factory (lambda p : _kineto_trace_handler (p , "fwd" )):
324+ time , output_group = benchmark_torch_function (
325+ torch .ops .fbgemm .group_index_select_dim0 ,
326+ (input_group , indices_group ),
327+ # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
328+ # pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
329+ ** bench_kwargs ,
330+ )
307331 logging .info (
308332 f"forward: PyTorch batch { time_ref :.5f} sec ({ num_bytes / time_ref / 1e9 :.5f} GB/s), "
309333 f"fbgemm group { time :5f} sec ({ num_bytes / time / 1e9 :.5f} GB/s)"
@@ -322,13 +346,14 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
322346 # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
323347 # typing.Tuple[Tensor, ...]]` but got `Tensor`.
324348 cat_output = torch .cat (output_group )
325- time , _ = benchmark_torch_function (
326- functools .partial (cat_output .backward , retain_graph = True ),
327- (grad ,),
328- # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
329- # pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
330- ** bench_kwargs ,
331- )
349+ with context_factory (lambda p : _kineto_trace_handler (p , "bwd" )):
350+ time , _ = benchmark_torch_function (
351+ functools .partial (cat_output .backward , retain_graph = True ),
352+ (grad ,),
353+ # pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
354+ # pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
355+ ** bench_kwargs ,
356+ )
332357 logging .info (
333358 f"backward: PyTorch batch { time_ref :.5f} sec ({ num_bytes / time_ref / 1e9 :.5f} GB/s), "
334359 f"fbgemm group { time :.5f} sec ({ num_bytes / time / 1e9 :.5f} GB/s)"
0 commit comments