Add 'device-with-speclist' bench #4648

Open · wants to merge 1 commit into base: main
326 changes: 326 additions & 0 deletions fbgemm_gpu/bench/tbe/tbe_training_benchmark.py
@@ -25,12 +25,15 @@
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
ComputeDevice,
DenseTableBatchedEmbeddingBagsCodegen,
get_available_compute_device,
SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.bench import (
benchmark_requests,
benchmark_requests_with_spec,
EmbeddingOpsCommonConfigLoader,
TBEBenchmarkingConfigLoader,
TBEDataConfigListLoader,
TBEDataConfigLoader,
)
from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
@@ -41,6 +44,13 @@
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)

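# Best-effort MTIA setup: importing the dynamic library registers the torch.mtia
# backend; any failure is swallowed so the benchmark still runs on CUDA/CPU.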
try:
import mtia.host_runtime.torch_mtia.dynamic_library # pyright: ignore # noqa: F401 # pyre-ignore[21]

torch.mtia.init()
except Exception:
pass


@click.group()
def cli() -> None:
@@ -288,5 +298,321 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
)


@cli.command()
@click.option(
"--emb-op-type",
default="split",
type=click.Choice(["split", "dense", "ssd"], case_sensitive=False),
help="The type of the embedding op to benchmark",
)
@click.option(
"--row-wise/--no-row-wise",
default=True,
help="Whether to use row-wise adagrad optimzier or not",
)
@click.option(
"--weighted-num-requires-grad",
type=int,
default=None,
help="The number of weighted tables that require gradient",
)
@click.option(
"--ssd-prefix",
type=str,
default="/tmp/ssd_benchmark",
help="SSD directory prefix",
)
@click.option(
"--pooling-list",
type=str,
default=None,
help="override pooling list",
)
@click.option("--cache-load-factor", default=0.2)
@TBEBenchmarkingConfigLoader.options
@TBEDataConfigListLoader.options
@EmbeddingOpsCommonConfigLoader.options
@click.pass_context
def device_with_speclist( # noqa C901
context: click.Context,
emb_op_type: click.Choice,
row_wise: bool,
weighted_num_requires_grad: Optional[int],
cache_load_factor: float,
# SSD params
ssd_prefix: str,
pooling_list: Optional[str],
# pyre-ignore[2]
**kwargs,
) -> None:
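"""Benchmark TBE training (forward and backward) configured from a per-table spec list.

Illustrative invocation; the data/benchmark/embedding-op flags are supplied by the
*ConfigLoader.options decorators above, and their exact names are assumptions:

python tbe_training_benchmark.py device-with-speclist \
--emb-op-type split --row-wise --pooling-list "20,40,60"
"""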
# Initialize random seeds
np.random.seed(42)
torch.manual_seed(42)

# Load general TBE benchmarking configuration from cli arguments
benchconfig = TBEBenchmarkingConfigLoader.load(context)

# Load TBE data configuration from cli arguments
tbeconfig = TBEDataConfigListLoader.load(context)

# Load common embedding op configuration from cli arguments
embconfig = EmbeddingOpsCommonConfigLoader.load(context)

# Generate feature_requires_grad
feature_requires_grad = (
tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
if weighted_num_requires_grad
else None
)

# Determine the optimizer
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

# Construct the common split arguments for the embedding op
common_split_args: Dict[str, Any] = embconfig.split_args() | {
"optimizer": optimizer,
"learning_rate": 0.1,
"eps": 0.1,
"feature_table_map": list(range(tbeconfig.T)),
}
batch_size_per_feature_per_rank = None
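# A non-None sigma_B indicates variable batch sizes (VBE): collect one batch
# size per feature for a single rank.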
if tbeconfig.batch_params.sigma_B is not None:
batch_size_per_feature_per_rank = [[b] for b in tbeconfig.batch_params.Bs]

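# Keep embedding weights on-device only when the compute device is CUDA;
# otherwise place them in host memory.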
managed_option = (
EmbeddingLocation.DEVICE
if get_available_compute_device() == ComputeDevice.CUDA
else EmbeddingLocation.HOST
)

if emb_op_type == "dense":
embedding_op = DenseTableBatchedEmbeddingBagsCodegen(
[
(
e,
d,
)
for e, d in zip(tbeconfig.Es, tbeconfig.Ds)
],
pooling_mode=embconfig.pooling_mode,
use_cpu=not torch.cuda.is_available(),
)
elif emb_op_type == "ssd":
assert (
torch.cuda.is_available()
), "SSDTableBatchedEmbeddingBags only supports GPU execution"
cache_set = max(sum(tbeconfig.batch_params.Bs), 1)
tempdir = tempfile.mkdtemp(prefix=ssd_prefix)
embedding_op = SSDTableBatchedEmbeddingBags(
embedding_specs=[(e, d) for e, d in zip(tbeconfig.Es, tbeconfig.Ds)],
cache_sets=cache_set,
ssd_storage_directory=tempdir,
ssd_cache_location=EmbeddingLocation.DEVICE,
ssd_rocksdb_shards=8,
**common_split_args,
)
else:
embedding_op = SplitTableBatchedEmbeddingBagsCodegen(
[
(
e,
d,
managed_option,
get_available_compute_device(),
)
for e, d in zip(tbeconfig.Es, tbeconfig.Ds)
],
cache_precision=(
embconfig.weights_dtype
if embconfig.cache_dtype is None
else embconfig.cache_dtype
),
cache_algorithm=CacheAlgorithm.LRU,
cache_load_factor=cache_load_factor,
device=get_device(),
**common_split_args,
)
embedding_op = embedding_op.to(get_device())

if embconfig.weights_dtype == SparseType.INT8:
# pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
# min_val: float, max_val: float) -> None, (self:
# SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
# None, Tensor, Module]` is not a function.
embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003)

avg_B = int(np.average(tbeconfig.batch_params.Bs))

nparams = sum(d * e for e, d in zip(tbeconfig.Es, tbeconfig.Ds))
param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0
output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0
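# Rough traffic model: bytes written for the output plus bytes read for the
# L embedding rows accessed per bag; converted to an effective GB/s below.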
if embconfig.pooling_mode.do_pooling():
read_write_bytes = (
output_size_multiplier * avg_B * sum(tbeconfig.Ds)
+ param_size_multiplier
* avg_B
* sum(tbeconfig.Ds)
* tbeconfig.pooling_params.L
)
else:
read_write_bytes = (
output_size_multiplier
* avg_B
* sum(tbeconfig.Ds)
* tbeconfig.pooling_params.L
+ param_size_multiplier
* avg_B
* sum(tbeconfig.Ds)
* tbeconfig.pooling_params.L
)

logging.info(f"Managed option: {embconfig.embedding_location}")
logging.info(
f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
)
logging.info(
f"Accessed weights per batch: {avg_B * sum(tbeconfig.Ds) * tbeconfig.pooling_params.L * param_size_multiplier / 1.0e9: .2f} GB"
)

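# An explicit --pooling-list overrides the configured pooling factors when
# generating requests.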
if pooling_list is not None:
pooling_list_extracted = [float(x) for x in pooling_list.split(",")]
tensor_pooling_list = torch.tensor(pooling_list_extracted)
requests = tbeconfig.generate_requests_with_Llist(
tensor_pooling_list,
benchconfig.num_requests,
batch_size_per_feature_per_rank,
)
else:
requests = tbeconfig.generate_requests(
benchconfig.num_requests, batch_size_per_feature_per_rank
)

# pyre-ignore[53]
def _kineto_trace_handler(p: profile, phase: str) -> None:
p.export_chrome_trace(
benchconfig.trace_url.format(
emb_op_type=emb_op_type, phase=phase, ospid=os.getpid()
)
)

# pyre-ignore[3,53]
def _context_factory(on_trace_ready: Callable[[profile], None]):
return (
profile(on_trace_ready=on_trace_ready, with_stack=True, record_shapes=True)
if benchconfig.export_trace
else nullcontext()
)

# batch_size_per_feature_per_rank is threaded through each request so the
# variable-batch (VBE) path is exercised.

if torch.cuda.is_available():
with _context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
# forward
time_per_iter = benchmark_requests_with_spec(
requests,
lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op.forward(
indices.to(dtype=tbeconfig.indices_params.index_dtype),
offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
),
flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
num_warmups=benchconfig.warmup_iterations,
iters=benchconfig.iterations,
)
else:
time_per_iter = benchmark_requests_with_spec(
requests,
lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op.forward(
indices.to(dtype=tbeconfig.indices_params.index_dtype),
offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
),
flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
num_warmups=benchconfig.warmup_iterations,
iters=benchconfig.iterations,
)

avg_E = int(np.average(tbeconfig.Es))
avg_D = int(np.average(tbeconfig.Ds))
logging.info(
f"Forward, B: {avg_B}, "
f"E: {avg_E}, T: {tbeconfig.T}, D: {avg_D}, L: {tbeconfig.pooling_params.L}, W: {tbeconfig.weighted}, "
f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)

if embconfig.output_dtype == SparseType.INT8:
# The backward benchmark is not representative for INT8 output; skip it.
return

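# VBE flattens the pooled output to 1D, so its gradient must be 1D as well
# (checked by the assert below); the fixed-batch pooled case uses a dense
# (B, sum(Ds)) gradient.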
if embconfig.pooling_mode.do_pooling():
if batch_size_per_feature_per_rank is None:
grad_output = torch.randn(avg_B, sum(tbeconfig.Ds)).to(get_device())
else:
output_size = sum(
[b * d for (b, d) in zip(tbeconfig.batch_params.Bs, tbeconfig.Ds)]
)
grad_output = torch.randn(output_size).to(get_device())

else:
grad_output = torch.randn(
avg_B * tbeconfig.T * tbeconfig.pooling_params.L,
avg_D,
).to(get_device())
assert (
batch_size_per_feature_per_rank is None or grad_output.dim() == 1
), f"VBE expects 1D grad_output but got {grad_output.shape}"
if torch.cuda.is_available():
with _context_factory(lambda p: _kineto_trace_handler(p, "fwd_bwd")):
# backward
time_per_iter = benchmark_requests_with_spec(
requests,
lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op(
indices.to(dtype=tbeconfig.indices_params.index_dtype),
offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
per_sample_weights,
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
bwd_only=True,
grad=grad_output,
num_warmups=benchconfig.warmup_iterations,
iters=benchconfig.iterations,
)
else:
time_per_iter = benchmark_requests_with_spec(
requests,
lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op(
indices.to(dtype=tbeconfig.indices_params.index_dtype),
offsets.to(dtype=tbeconfig.indices_params.offset_dtype),
per_sample_weights,
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb,
bwd_only=True,
grad=grad_output,
num_warmups=benchconfig.warmup_iterations,
iters=benchconfig.iterations,
)

logging.info(
f"Backward, B: {avg_B}, E: {avg_E}, T: {tbeconfig.T}, D: {avg_D}, L: {tbeconfig.pooling_params.L}, "
f"BW: {2 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "
f"T: {time_per_iter * 1.0e6:.0f}us"
)


if __name__ == "__main__":
cli()
4 changes: 3 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py
@@ -21,6 +21,7 @@
benchmark_pipelined_requests,
benchmark_requests,
benchmark_requests_refer,
benchmark_requests_with_spec,
benchmark_vbe,
)
from .benchmark_click_interface import TbeBenchClickInterface # noqa F401
@@ -30,9 +31,10 @@
EvalCompressionBenchmarkOutput,
)
from .reporter import BenchmarkReporter # noqa F401
from .tbe_data_config import TBEDataConfig # noqa F401
from .tbe_data_config import TBEDataConfig, TBEDataListConfig # noqa F401
from .tbe_data_config_loader import ( # noqa F401
TBEDataConfigHelperText,
TBEDataConfigListLoader,
TBEDataConfigLoader,
)
from .tbe_data_config_param_models import ( # noqa F401