From 5ca57a7b425e004d42624fdd961e058f1aefad23 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:03:25 +0530 Subject: [PATCH 01/40] Add DLRM-DCNv2 example for MLPerf --- .../configs/datasets/dummy_dataset.py | 140 ++++++++ .../configs/models/default_model.py | 19 ++ .../configs/training/default_training.py | 7 + examples/mlperf_dlrm_dcnv2/configs/v6e_16.py | 15 + examples/mlperf_dlrm_dcnv2/configs/v6e_8.py | 15 + examples/mlperf_dlrm_dcnv2/dataloader.py | 69 ++++ examples/mlperf_dlrm_dcnv2/main.py | 232 +++++++++++++ examples/mlperf_dlrm_dcnv2/model.py | 321 ++++++++++++++++++ 8 files changed, 818 insertions(+) create mode 100644 examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/models/default_model.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/training/default_training.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/v6e_16.py create mode 100644 examples/mlperf_dlrm_dcnv2/configs/v6e_8.py create mode 100644 examples/mlperf_dlrm_dcnv2/dataloader.py create mode 100644 examples/mlperf_dlrm_dcnv2/main.py create mode 100644 examples/mlperf_dlrm_dcnv2/model.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py b/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py new file mode 100644 index 00000000..469ad206 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py @@ -0,0 +1,140 @@ +from keras.utils import Config + +# === Dataset === +dataset_config = Config() +dataset_config.file_pattern = None +# Features +dataset_config.label = "clicked" +dataset_config.dense = [f"int-feature-{i}" for i in range(13)] +dataset_config.sparse = [ + { + "name": "categorical-feature-14", + "vocabulary_size": 40000000, + "multi_hot_size": 3, + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "multi_hot_size": 2, + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "multi_hot_size": 2, + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "multi_hot_size": 6, + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "multi_hot_size": 7, + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "multi_hot_size": 3, + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "multi_hot_size": 8, + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "multi_hot_size": 6, + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "multi_hot_size": 9, + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "multi_hot_size": 5, + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "multi_hot_size": 1, + }, + { + "name": 
"categorical-feature-33", + "vocabulary_size": 40000000, + "multi_hot_size": 12, + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "multi_hot_size": 100, + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "multi_hot_size": 27, + }, + { + "name": "categorical-feature-36", + "vocabulary_size": 590152, + "multi_hot_size": 10, + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "multi_hot_size": 3, + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "multi_hot_size": 1, + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "multi_hot_size": 1, + }, +] diff --git a/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py b/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py new file mode 100644 index 00000000..1a07e9d9 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py @@ -0,0 +1,19 @@ +from keras.utils import Config + +# === Model === +model_config = Config() +# Embedding +model_config.embedding_dim = 128 +model_config.allow_id_dropping = True +model_config.embedding_threshold = 21000 +model_config.max_ids_per_partition = 4096 +model_config.max_unique_ids_per_partition = 2048 +model_config.learning_rate = 0.005 + +# MLP +model_config.bottom_mlp_dims = [512, 256, 128] +model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] + +# DCN +model_config.num_dcn_layers = 3 +model_config.dcn_projection_dim = 512 diff --git a/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py b/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py new file mode 100644 index 00000000..b758bc59 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py @@ -0,0 +1,7 @@ +from keras.utils import Config + +# === Training Hyperparameters === +training_config = Config() +training_config.learning_rate = 0.005 +training_config.global_batch_size = 128 +training_config.num_epochs = 1 diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py b/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py new file mode 100644 index 00000000..b246d773 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py @@ -0,0 +1,15 @@ +from configs.datasets.dummy_dataset import dataset_config +from configs.models.default_model import model_config +from configs.training.default_training import training_config +from keras.utils import Config + +config = Config() + +config.experiment_name = "v6e_16" +config.model_dir = "./v6e_16" + +config.dataset = dataset_config +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py b/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py new file mode 100644 index 00000000..552f25d0 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py @@ -0,0 +1,15 @@ +from configs.datasets.dummy_dataset import dataset_config +from configs.models.default_model import model_config +from configs.training.default_training import training_config +from keras.utils import Config + +config = Config() + +config.experiment_name = "v6e_8" +config.model_dir = "./v6e_8" + +config.dataset = dataset_config +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/mlperf_dlrm_dcnv2/dataloader.py b/examples/mlperf_dlrm_dcnv2/dataloader.py new file mode 100644 index 00000000..467f390a --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/dataloader.py @@ -0,0 +1,69 @@ +import numpy as np +import tensorflow as tf + + +def 
_get_dummy_batch(batch_size, large_emb_features, small_emb_features): + """Returns a dummy batch of data in the final desired structure.""" + + # Labels + data = { + "clicked": np.random.randint(0, 2, size=(batch_size,), dtype=np.int64) + } + + # Dense features + dense_input_list = [ + np.random.uniform(0.0, 0.9, size=(batch_size, 1)).astype(np.float32) + for _ in range(13) + ] + data["dense_input"] = np.concatenate(dense_input_list, axis=-1) + + # Sparse features + large_emb_inputs = {} + for large_emb_feature in large_emb_features: + vocabulary_size = large_emb_feature["vocabulary_size"] + multi_hot_size = large_emb_feature["multi_hot_size"] + idx = large_emb_feature["name"].split("-")[-1] + + large_emb_inputs[f"cat_{idx}_id"] = np.random.randint( + low=0, + high=vocabulary_size, + size=(batch_size, multi_hot_size), + dtype=np.int64, + ) + + data["large_emb_inputs"] = large_emb_inputs + + # Dense lookup features + small_emb_inputs = {} + for small_emb_feature in small_emb_features: + vocabulary_size = small_emb_feature["vocabulary_size"] + multi_hot_size = small_emb_feature["multi_hot_size"] + idx = small_emb_feature["name"].split("-")[-1] + + # TODO: We don't need this custom renaming. Remove later, when we + # shift from dummy data to actual data. + small_emb_inputs[f"cat_{idx}_id"] = np.random.randint( + low=0, + high=vocabulary_size, + size=(batch_size, multi_hot_size), + dtype=np.int64, + ) + + if small_emb_inputs: + data["small_emb_inputs"] = small_emb_inputs + + return data + + +def create_dummy_dataset(batch_size, large_emb_features, small_emb_features): + """Creates a TF dataset from cached dummy data of the final batch size.""" + dummy_data = _get_dummy_batch( + batch_size, large_emb_features, small_emb_features + ) + + # Separate labels from features to create a `(features, labels)` tuple. + labels = dummy_data.pop("clicked") + features = dummy_data + + dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) + return dataset diff --git a/examples/mlperf_dlrm_dcnv2/main.py b/examples/mlperf_dlrm_dcnv2/main.py new file mode 100644 index 00000000..115648cf --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/main.py @@ -0,0 +1,232 @@ +import argparse +import importlib +import os + +os.environ["KERAS_BACKEND"] = "jax" + +import keras + +import keras_rs + +from .dataloader import create_dummy_dataset +from .model import DLRMDCNV2 + +SEED = 1337 + + +def main( + file_pattern, + dense_features, + large_emb_features, + small_emb_features, + label, + embedding_dim, + allow_id_dropping, + max_ids_per_partition, + max_unique_ids_per_partition, + embedding_learning_rate, + bottom_mlp_dims, + top_mlp_dims, + num_dcn_layers, + dcn_projection_dim, + learning_rate, + global_batch_size, + num_epochs, +): + # Set DDP as Keras distribution strategy + devices = keras.distribution.list_devices(device_type="tpu") + distribution = keras.distribution.DataParallel(devices=devices) + keras.distribution.set_distribution(distribution) + num_processes = distribution._num_process() + + per_host_batch_size = global_batch_size // num_processes + + # === Distributed embeddings' configs for sparse features === + feature_configs = {} + for large_emb_feature in large_emb_features: + # Rename these features to something shorter; was facing some weird + # issues with the longer names. 
+ feature_name = ( + large_emb_feature["name"] + .replace("-", "_") + .replace("egorical_feature", "") + ) + vocabulary_size = large_emb_feature["vocabulary_size"] + multi_hot_size = large_emb_feature["multi_hot_size"] + + table_config = keras_rs.layers.TableConfig( + name=f"{feature_name}_table", + vocabulary_size=vocabulary_size, + embedding_dim=embedding_dim, + # TODO(abheesht): Verify. + initializer=keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + seed=SEED, + ), + optimizer=keras.optimizers.Adagrad( + learning_rate=embedding_learning_rate + ), + combiner="sum", + placement="sparsecore", + # TODO: These two args are not getting passed down to + # `jax-tpu-embedding` properly, seems like. + max_ids_per_partition=max_ids_per_partition, + max_unique_ids_per_partition=max_unique_ids_per_partition, + ) + feature_configs[f"{feature_name}_id"] = keras_rs.layers.FeatureConfig( + name=feature_name.replace("id", ""), + table=table_config, + # TODO: Verify whether it should be `(bsz, 1)` or + # `(bsz, multi_hot_size)`. + input_shape=(per_host_batch_size, multi_hot_size), + output_shape=(per_host_batch_size, embedding_dim), + ) + + # === Instantiate model === + # We instantiate the model first, because we need to preprocess sparse + # inputs using the distributed embedding layer defined inside the model + # class. + print("===== Initialising model =====") + model = DLRMDCNV2( + large_emb_feature_configs=feature_configs, + small_emb_features=small_emb_features, + embedding_dim=embedding_dim, + bottom_mlp_dims=bottom_mlp_dims, + top_mlp_dims=top_mlp_dims, + num_dcn_layers=num_dcn_layers, + dcn_projection_dim=dcn_projection_dim, + seed=SEED, + dtype="float32", + name="dlrm_dcn_v2", + ) + model.compile( + loss=keras.losses.BinaryCrossentropy(), + optimizer=keras.optimizers.Adagrad(learning_rate=learning_rate), + metrics=[keras.metrics.BinaryAccuracy()], + ) + + # === Load dataset === + print("===== Loading dataset =====") + train_ds = create_dummy_dataset( + batch_size=per_host_batch_size, + large_emb_features=large_emb_features, + small_emb_features=small_emb_features, + ) + # For the multi-host case, the dataset has to be distributed manually. + # See note here: + # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. + if num_processes > 1: + train_ds = distribution.distribute_dataset(train_ds) + distribution.auto_shard_dataset = False + + def generator(dataset, training=False): + """Converts tf.data Dataset to a Python generator and preprocesses + sparse features. + """ + for features, labels in dataset: + preprocessed_large_embeddings = model.embedding_layer.preprocess( + features["large_emb_inputs"], training=training + ) + + x = { + "dense_input": features["dense_input"], + "large_emb_inputs": preprocessed_large_embeddings, + "small_emb_inputs": features["small_emb_inputs"], + } + y = labels + yield (x, y) + + train_generator = generator(train_ds, training=True) + for first_batch in train_generator: + model(first_batch[0]) + break + + # Train the model. + model.fit(train_generator, epochs=1) + + +if __name__ == "__main__": + keras.config.disable_traceback_filtering() + + print("====== Launching train script =======") + parser = argparse.ArgumentParser( + description=( + "Benchmark the DLRM-DCNv2 model on the Criteo dataset (MLPerf)" + ) + ) + parser.add_argument( + "--config_name", type=str, help="Name of the `.py` config file." 
+ ) + args = parser.parse_args() + + print(f"===== Reading config from {args.config_name} ======") + config = getattr(importlib.import_module("configs"), args.config_name) + + # === Unpack args from config === + + # == Dataset config == + ds_cfg = config["dataset"] + # File path + file_pattern = ds_cfg["file_pattern"] + # Features + label = ds_cfg["label"] + dense_features = ds_cfg["dense"] + emb_features = ds_cfg["sparse"] + + # == Model config == + model_cfg = config["model"] + # Embedding + embedding_dim = model_cfg["embedding_dim"] + allow_id_dropping = model_cfg["allow_id_dropping"] + embedding_threshold = model_cfg["embedding_threshold"] + max_ids_per_partition = model_cfg["max_ids_per_partition"] + max_unique_ids_per_partition = model_cfg["max_unique_ids_per_partition"] + embedding_learning_rate = model_cfg["learning_rate"] + # MLP + bottom_mlp_dims = model_cfg["bottom_mlp_dims"] + top_mlp_dims = model_cfg["top_mlp_dims"] + # DCN + num_dcn_layers = model_cfg["num_dcn_layers"] + dcn_projection_dim = model_cfg["dcn_projection_dim"] + + # == Training config == + training_cfg = config["training"] + learning_rate = training_cfg["learning_rate"] + global_batch_size = training_cfg["global_batch_size"] + num_epochs = training_cfg["num_epochs"] + + # For features which have vocabulary_size < embedding_threshold, we can + # just do a normal dense lookup for those instead of have distributed + # embeddings. We could ideally pass `placement = default_device` to + # `keras_rs.layers.TableConfig` directly (and wouldn't have to do this + # separation of features), but doing it that way will necessarily require + # a separate optimiser for the embedding layer. + small_emb_features = [] + large_emb_features = [] + for emb_feature in emb_features: + if emb_feature["vocabulary_size"] < embedding_threshold: + small_emb_features.append(emb_feature) + else: + large_emb_features.append(emb_feature) + + main( + file_pattern, + dense_features, + large_emb_features, + small_emb_features, + label, + embedding_dim, + allow_id_dropping, + max_ids_per_partition, + max_unique_ids_per_partition, + embedding_learning_rate, + bottom_mlp_dims, + top_mlp_dims, + num_dcn_layers, + dcn_projection_dim, + learning_rate, + global_batch_size, + num_epochs, + ) diff --git a/examples/mlperf_dlrm_dcnv2/model.py b/examples/mlperf_dlrm_dcnv2/model.py new file mode 100644 index 00000000..b77579f9 --- /dev/null +++ b/examples/mlperf_dlrm_dcnv2/model.py @@ -0,0 +1,321 @@ +from typing import Any, TypeAlias + +import keras +from keras import ops + +import keras_rs + +Tensor: TypeAlias = Any + + +def _clone_initializer( + initializer: keras.initializers.Initializer, + seed: int | keras.random.SeedGenerator, +): + """Clones the provided initializer with a new seed. + + This function creates a new instance of a Keras initializer from an + existing one, but with a different seed. This is useful for ensuring + different weights in a model are initialized with different seeds. + + Args: + initializer: a keras.initializers.Initializer instance. The initializer + to be cloned. + seed: int, or a keras.random.SeedGenerator() instance. The random seed. + + Returns: + A new `keras.initializers.Initializer` instance configured with the + provided seed. 
+ """ + config = initializer.get_config() + config.pop("seed") + config = {**config, "seed": seed} + initializer_class: type[keras.initializers.Initializer] = ( + initializer.__class__ + ) + return initializer_class.from_config(config) + + +class DLRMDCNV2(keras.Model): + def __init__( + self, + large_emb_feature_configs: dict[str, keras_rs.layers.FeatureConfig], + small_emb_features: list, + embedding_dim: int, + bottom_mlp_dims: list[int], + top_mlp_dims: list[int], + num_dcn_layers: int, + dcn_projection_dim: int, + seed: int | keras.random.SeedGenerator | None = None, + dtype: str | None = None, + name: str | None = None, + **kwargs: Any, + ): + """DLRM-DCNv2 model. + + The model processes two types of input features: + 1. Dense Features: Continuous-valued features that are processed by + a multi-layer perceptron (the "bottom MLP"). + 2. Sparse Features: High-cardinality categorical features that are + first mapped into low-dimensional embedding vectors using the + `keras_rs.layers.DistributedEmbedding` layer. This layer is highly + optimized for large-scale recommendation models, especially on TPUs + with SparseCore, as it can shard large embedding tables across + multiple accelerator chips for improved performance. On other + hardware (GPUs, CPUs), it functions like a standard embedding layer. + + The output of the bottom MLP and the embedding vectors are then + concatenated and fed into a DCN block for learning feature interactions. + The output of the DCN block is then processed by another MLP + (the "top MLP") to produce a final prediction. + + Args: + large_emb_feature_configs: A dictionary with features names as keys + and `keras_rs.layers.FeatureConfig` objects as values. These + configs link features to their corresponding embedding tables + (`keras_rs.layers.TableConfig`), specifying parameters like + vocabulary size, embedding dimension, and hardware placement + strategy. + bottom_mlp_dims: A list of integers specifying the number of units + in each layer of the bottom MLP. + top_mlp_dims: A list of integers specifying the number of units in + each layer of the top MLP. The last value is the final output + dimension (e.g., 1 for binary classification). + num_dcn_layers: The number of feature-crossing layers in the DCNv2 + block. + dcn_projection_dim: The projection dimension used within each DCNv2 + cross-layer. + seed: The random seed. + dtype: Optional dtype. + name: The name of the layer. 
+ """ + super().__init__(dtype=dtype, name=name, **kwargs) + self.seed = seed + + # === Layers ==== + + # Bottom MLP for encoding dense features + self.bottom_mlp = keras.Sequential( + self._get_mlp_layers( + dims=bottom_mlp_dims, + intermediate_activation="relu", + final_activation="relu", + ), + name="bottom_mlp", + ) + # Distributed embeddings for large embedding tables + self.embedding_layer = keras_rs.layers.DistributedEmbedding( + feature_configs=large_emb_feature_configs, + table_stacking="auto_stacking", + dtype=dtype, + name="embedding_layer", + ) + # Embedding layers for small embedding tables + self.small_embedding_layers = None + if small_emb_features: + self.small_embedding_layers = [ + keras.layers.Embedding( + input_dim=small_emb_feature["vocabulary_size"], + output_dim=embedding_dim, + embeddings_initializer="zeros", + name=f"small_embedding_layer_{i}", + ) + for i, small_emb_feature in enumerate(small_emb_features) + ] + # DCN for "interactions" + self.dcn_block = DCNBlock( + num_layers=num_dcn_layers, + projection_dim=dcn_projection_dim, + seed=seed, + dtype=dtype, + name="dcn_block", + ) + # Top MLP for predictions + self.top_mlp = keras.Sequential( + self._get_mlp_layers( + dims=top_mlp_dims, + intermediate_activation="relu", + final_activation="sigmoid", + ), + name="top_mlp", + ) + + # === Passed attributes === + self.large_emb_feature_configs = large_emb_feature_configs + self.small_emb_features = small_emb_features + self.embedding_dim = embedding_dim + self.bottom_mlp_dims = bottom_mlp_dims + self.top_mlp_dims = top_mlp_dims + self.num_dcn_layers = num_dcn_layers + self.dcn_projection_dim = dcn_projection_dim + + def call(self, inputs: dict[str, Tensor]) -> Tensor: + """Forward pass of the model. + + Args: + inputs: A dictionary containing `"dense_features"` and + `"preprocessed_large_emb_features"` as keys. + """ + # Inputs + dense_input = inputs["dense_input"] + large_emb_inputs = inputs["large_emb_inputs"] + + # Embed features. + dense_output = self.bottom_mlp(dense_input) + # jax.debug.print("dense_ouput {}", dense_output.shape) + large_embeddings = self.embedding_layer(large_emb_inputs) + small_embeddings = [] + if self.small_emb_features: + small_emb_inputs = inputs["small_emb_inputs"] + for small_emb_input, embedding_layer in zip( + small_emb_inputs.values(), self.small_embedding_layers + ): + embedding = embedding_layer(small_emb_input) + embedding = ops.sum(embedding, axis=-2) + small_embeddings.append(embedding) + + small_embeddings = ops.concatenate(small_embeddings, axis=-1) + + # Interaction + x = ops.concatenate( + [dense_output, small_embeddings, *large_embeddings.values()], + axis=-1, + ) + # jax.debug.print("x {}", x.shape) + x = self.dcn_block(x) + + # Predictions + outputs = self.top_mlp(x) + return outputs + + def _get_mlp_layers( + self, + dims: list[int], + intermediate_activation: str | keras.layers.Activation, + final_activation: str | keras.layers.Activation, + ) -> list[keras.layers.Layer]: + """Creates a list of Dense layers. + + Args: + dims: list. Output dimensions of the dense layers to be created. + intermediate_activation: string or `keras.layers.Activation`. The + activation to be used in all layers, save the last. + final_activation: str or `keras.layers.Activation`. The activation + to be used in the last layer. + + Returns: + A list of `keras.layers.Dense` layers. 
+ """ + initializer = keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + seed=self.seed, + ) + + layers = [ + keras.layers.Dense( + units=dim, + activation=intermediate_activation, + kernel_initializer=_clone_initializer( + initializer, seed=self.seed + ), + bias_initializer=_clone_initializer( + initializer, seed=self.seed + ), + dtype=self.dtype, + ) + for dim in dims[:-1] + ] + layers += [ + keras.layers.Dense( + units=dims[-1], + activation=final_activation, + kernel_initializer=_clone_initializer( + initializer, seed=self.seed + ), + bias_initializer=_clone_initializer( + initializer, seed=self.seed + ), + dtype=self.dtype, + ) + ] + return layers + + def get_config(self): + """Returns the config of the model.""" + config = super().get_config() + config.update( + { + "large_emb_feature_configs": self.large_emb_feature_configs, + "small_emb_features": self.small_emb_features, + "embedding_dim": self.embedding_dim, + "bottom_mlp_dims": self.bottom_mlp_dims, + "top_mlp_dims": self.top_mlp_dims, + "num_dcn_layers": self.num_dcn_layers, + "dcn_projection_dim": self.dcn_projection_dim, + "seed": self.seed, + } + ) + return config + + +class DCNBlock(keras.layers.Layer): + def __init__( + self, + num_layers: int, + projection_dim: int, + seed: int | keras.random.SeedGenerator, + dtype: str | None = None, + name: str | None = None, + **kwargs, + ): + """ + A block of Deep & Cross Network V2 (DCNv2) layers. + + This layer implements the "cross network" part of the DCNv2 architecture + by stacking multiple `keras_rs.layers.FeatureCross` layers, which learn + feature interactions. + + Args: + num_layers: The number of `FeatureCross` layers to stack. + projection_dim: The dimensionality of the low-rank projection used + within each cross layer. + seed: The random seed for initializers. + dtype: Optional dtype. + name: The name of the layer. 
+ """ + super().__init__(dtype=dtype, name=name, **kwargs) + + # Layers + self.layers = [ + keras_rs.layers.FeatureCross( + projection_dim=projection_dim, + kernel_initializer=keras.initializers.GlorotUniform(seed=seed), + bias_initializer="zeros", + dtype=dtype, + ) + for _ in range(num_layers) + ] + + # Passed attributes + self.num_layers = num_layers + self.projection_dim = projection_dim + self.seed = seed + + def call(self, x0): + xl = x0 + for layer in self.layers: + xl = layer(x0, xl) + return xl + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "projection_dim": self.projection_dim, + "seed": self.seed, + } + ) + return config From 090098df8d03d99cd540f539d0b3ad4fb892ea1a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:05:15 +0530 Subject: [PATCH 02/40] Fix table_stacking arg --- examples/mlperf_dlrm_dcnv2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlperf_dlrm_dcnv2/model.py b/examples/mlperf_dlrm_dcnv2/model.py index b77579f9..76cf3d09 100644 --- a/examples/mlperf_dlrm_dcnv2/model.py +++ b/examples/mlperf_dlrm_dcnv2/model.py @@ -106,7 +106,7 @@ def __init__( # Distributed embeddings for large embedding tables self.embedding_layer = keras_rs.layers.DistributedEmbedding( feature_configs=large_emb_feature_configs, - table_stacking="auto_stacking", + table_stacking="auto", dtype=dtype, name="embedding_layer", ) From 1dd94228bf168142928797223699b2339c7bd401 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:08:17 +0530 Subject: [PATCH 03/40] Rename dir --- .../configs/datasets/dummy_dataset.py | 0 .../configs/models/default_model.py | 0 .../configs/training/default_training.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_16.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_8.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/dataloader.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/main.py | 0 examples/{mlperf_dlrm_dcnv2 => ml_perf}/model.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/datasets/dummy_dataset.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/models/default_model.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/training/default_training.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_16.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/configs/v6e_8.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/dataloader.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/main.py (100%) rename examples/{mlperf_dlrm_dcnv2 => ml_perf}/model.py (100%) diff --git a/examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/datasets/dummy_dataset.py rename to examples/ml_perf/configs/datasets/dummy_dataset.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/models/default_model.py b/examples/ml_perf/configs/models/default_model.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/models/default_model.py rename to examples/ml_perf/configs/models/default_model.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/training/default_training.py b/examples/ml_perf/configs/training/default_training.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/training/default_training.py rename to examples/ml_perf/configs/training/default_training.py 
diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/v6e_16.py rename to examples/ml_perf/configs/v6e_16.py diff --git a/examples/mlperf_dlrm_dcnv2/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/configs/v6e_8.py rename to examples/ml_perf/configs/v6e_8.py diff --git a/examples/mlperf_dlrm_dcnv2/dataloader.py b/examples/ml_perf/dataloader.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/dataloader.py rename to examples/ml_perf/dataloader.py diff --git a/examples/mlperf_dlrm_dcnv2/main.py b/examples/ml_perf/main.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/main.py rename to examples/ml_perf/main.py diff --git a/examples/mlperf_dlrm_dcnv2/model.py b/examples/ml_perf/model.py similarity index 100% rename from examples/mlperf_dlrm_dcnv2/model.py rename to examples/ml_perf/model.py From 25233c9e87b03210e1e4b2e8ee80f4a3d26699dd Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:34:48 +0530 Subject: [PATCH 04/40] Add blank __init__ file to configs --- examples/ml_perf/configs/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 examples/ml_perf/configs/__init__.py diff --git a/examples/ml_perf/configs/__init__.py b/examples/ml_perf/configs/__init__.py new file mode 100644 index 00000000..e69de29b From 3d519d2b6f0e9c51c43981d2bdb5ac19b1a336f7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:37:50 +0530 Subject: [PATCH 05/40] Fix imports --- examples/ml_perf/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 115648cf..d0b49c9e 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -162,7 +162,10 @@ def generator(dataset, training=False): args = parser.parse_args() print(f"===== Reading config from {args.config_name} ======") - config = getattr(importlib.import_module("configs"), args.config_name) + config = getattr( + importlib.import_module(".configs", package=__package__), + args.config_name + ) # === Unpack args from config === From eef6568dc05bc25f375605f4f57409e1ff418b2e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:42:48 +0530 Subject: [PATCH 06/40] Fix imports --- examples/ml_perf/main.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index d0b49c9e..96412c81 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -163,9 +163,8 @@ def generator(dataset, training=False): print(f"===== Reading config from {args.config_name} ======") config = getattr( - importlib.import_module(".configs", package=__package__), - args.config_name - ) + importlib.import_module(f".configs.{args.config_name}", package=__package__) + ).config # === Unpack args from config === From fc77ad49986fe0cbebc8520526b61290519511fc Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:45:12 +0530 Subject: [PATCH 07/40] Fix imports --- examples/ml_perf/configs/datasets/__init__.py | 0 examples/ml_perf/configs/models/__init__.py | 0 examples/ml_perf/configs/training/__init__.py | 0 examples/ml_perf/configs/v6e_16.py | 6 +++--- examples/ml_perf/configs/v6e_8.py | 6 +++--- 5 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 examples/ml_perf/configs/datasets/__init__.py create mode 100644 
examples/ml_perf/configs/models/__init__.py create mode 100644 examples/ml_perf/configs/training/__init__.py diff --git a/examples/ml_perf/configs/datasets/__init__.py b/examples/ml_perf/configs/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/models/__init__.py b/examples/ml_perf/configs/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/training/__init__.py b/examples/ml_perf/configs/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index b246d773..c9c50aff 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -1,6 +1,6 @@ -from configs.datasets.dummy_dataset import dataset_config -from configs.models.default_model import model_config -from configs.training.default_training import training_config +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config from keras.utils import Config config = Config() diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py index 552f25d0..4e3904e2 100644 --- a/examples/ml_perf/configs/v6e_8.py +++ b/examples/ml_perf/configs/v6e_8.py @@ -1,6 +1,6 @@ -from configs.datasets.dummy_dataset import dataset_config -from configs.models.default_model import model_config -from configs.training.default_training import training_config +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config from keras.utils import Config config = Config() From 0a31e0fb5f34608ce4a359dd16c9e5dd8910df03 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:47:45 +0530 Subject: [PATCH 08/40] Fix imports --- examples/ml_perf/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 96412c81..078e63f3 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -162,9 +162,11 @@ def generator(dataset, training=False): args = parser.parse_args() print(f"===== Reading config from {args.config_name} ======") - config = getattr( - importlib.import_module(f".configs.{args.config_name}", package=__package__) - ).config + config = ( + importlib.import_module( + f".configs.{args.config_name}", package=__package__ + ).config + ) # === Unpack args from config === From 53e9c2304588cdba2cf26bb915f3e93b7ada7fb2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 12:48:47 +0530 Subject: [PATCH 09/40] Fix num_processes --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 078e63f3..9abcb718 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -37,7 +37,7 @@ def main( devices = keras.distribution.list_devices(device_type="tpu") distribution = keras.distribution.DataParallel(devices=devices) keras.distribution.set_distribution(distribution) - num_processes = distribution._num_process() + num_processes = distribution._num_process per_host_batch_size = global_batch_size // num_processes From f22b5ffdaf343daa3e3d33973311ec2c6b4977fb Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 13:50:13 +0530 Subject: [PATCH 10/40] Add bash script --- examples/ml_perf/configs/v6e_16.py | 3 +- 
examples/ml_perf/configs/v6e_8.py | 3 +- examples/ml_perf/main.py | 8 +-- examples/ml_perf/model.py | 4 +- examples/ml_perf/run.sh | 106 +++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 examples/ml_perf/run.sh diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py index c9c50aff..4b6df8df 100644 --- a/examples/ml_perf/configs/v6e_16.py +++ b/examples/ml_perf/configs/v6e_16.py @@ -1,7 +1,8 @@ +from keras.utils import Config + from .datasets.dummy_dataset import dataset_config from .models.default_model import model_config from .training.default_training import training_config -from keras.utils import Config config = Config() diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py index 4e3904e2..fcd81e39 100644 --- a/examples/ml_perf/configs/v6e_8.py +++ b/examples/ml_perf/configs/v6e_8.py @@ -1,7 +1,8 @@ +from keras.utils import Config + from .datasets.dummy_dataset import dataset_config from .models.default_model import model_config from .training.default_training import training_config -from keras.utils import Config config = Config() diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 9abcb718..5bea8c72 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -162,11 +162,9 @@ def generator(dataset, training=False): args = parser.parse_args() print(f"===== Reading config from {args.config_name} ======") - config = ( - importlib.import_module( - f".configs.{args.config_name}", package=__package__ - ).config - ) + config = importlib.import_module( + f".configs.{args.config_name}", package=__package__ + ).config # === Unpack args from config === diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 76cf3d09..a9f01f36 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -117,7 +117,9 @@ def __init__( keras.layers.Embedding( input_dim=small_emb_feature["vocabulary_size"], output_dim=embedding_dim, - embeddings_initializer="zeros", + embeddings_initializer=keras.initializers.LecunNormal( + seed=self.seed, + ), name=f"small_embedding_layer_{i}", ) for i, small_emb_feature in enumerate(small_emb_features) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh new file mode 100644 index 00000000..f5d840a5 --- /dev/null +++ b/examples/ml_perf/run.sh @@ -0,0 +1,106 @@ +#!/bin/bash + +# ============================================================================== +# Script Configuration & Argument Handling +# ============================================================================== +# This script accepts up to three optional arguments: +# 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) +# 2. Zone (default: us-east5-a) +# 3. Project (default: tpu-prod-env-one-vm) + +ACCELERATOR_TYPE=${1:-"v6e-8"} +ZONE=${2:-"us-east5-a"} +PROJECT=${3:-"tpu-prod-env-one-vm"} + +# Validate the provided accelerator type +if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then + echo "Error: Invalid accelerator type '${ACCELERATOR_TYPE}'." >&2 + echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project]" >&2 + exit 1 +fi + +# ============================================================================== +# Environment Variables +# ============================================================================== +# TPU name is generated dynamically. Zone and Project are set from args or defaults. 
+export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}" +export ZONE +export PROJECT + +echo ">>> Using Configuration:" +echo " Accelerator: ${ACCELERATOR_TYPE}" +echo " TPU Name: ${TPU_NAME}" +echo " Zone: ${ZONE}" +echo " Project: ${PROJECT}" +echo "--------------------------------------------------" + + +# ============================================================================== +# TPU VM Creation +# ============================================================================== +echo ">>> Creating TPU VM: ${TPU_NAME} with accelerator ${ACCELERATOR_TYPE}..." +gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \ + --zone=${ZONE} \ + --accelerator-type=${ACCELERATOR_TYPE} \ + --version=v2-alpha-tpuv6e \ + --project=${PROJECT} \ + --metadata=enable-oslogin=TRUE \ + --scopes=https://www.googleapis.com/auth/cloud-platform + + +# ============================================================================== +# Setup Python Virtual Environment on all workers +# ============================================================================== +echo ">>> Creating Python virtual environment..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="sudo apt-get update && sudo apt install -y python3.10-venv && python3 -m venv .keras-env" + + +# ============================================================================== +# Clone KerasRS and Install Dependencies +# ============================================================================== +echo ">>> Cloning KerasRS and installing dependencies..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && git clone https://github.com/abheesht17/keras-rs.git && cd keras-rs && git checkout ml-perf && pip install -e . && pip install tensorflow-datasets && pip uninstall -y tensorflow keras && pip install git+https://github.com/keras-team/keras.git && pip install jax-tpu-embedding tensorflow-cpu" + + +# ============================================================================== +# Install TPU-compatible JAX +# ============================================================================== +echo ">>> Re-installing JAX for TPU compatibility..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && pip uninstall -y jax jaxlib && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" + + +# ============================================================================== +# Verify Installation +# ============================================================================== +echo ">>> Verifying JAX installation..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python script.py" + + +# ============================================================================== +# Run Training Script +# ============================================================================== +# The config path is now also set dynamically. +echo ">>> Running the main script with config for ${ACCELERATOR_TYPE}..." 
+gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${ACCELERATOR_TYPE}" + +echo ">>> Script finished." From 41f297753c6f45d807d654ac8857c9dc771da20f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 13:51:31 +0530 Subject: [PATCH 11/40] Add bash script --- examples/ml_perf/run.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index f5d840a5..d84bd7d1 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -22,7 +22,6 @@ fi # ============================================================================== # Environment Variables # ============================================================================== -# TPU name is generated dynamically. Zone and Project are set from args or defaults. export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}" export ZONE export PROJECT @@ -95,7 +94,6 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ # ============================================================================== # Run Training Script # ============================================================================== -# The config path is now also set dynamically. echo ">>> Running the main script with config for ${ACCELERATOR_TYPE}..." gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ From 09ca14c381bfc8e1ba3c54faf639eb54137d8b5a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 14:02:12 +0530 Subject: [PATCH 12/40] Modify bash script to take in config name --- examples/ml_perf/run.sh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index d84bd7d1..b34953f2 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -7,10 +7,12 @@ # 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) # 2. Zone (default: us-east5-a) # 3. Project (default: tpu-prod-env-one-vm) +# 4. Config Name (default: derived from accelerator type, e.g., v6e_8) ACCELERATOR_TYPE=${1:-"v6e-8"} ZONE=${2:-"us-east5-a"} PROJECT=${3:-"tpu-prod-env-one-vm"} +USER_CONFIG_NAME=${4} # Capture the fourth argument # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then @@ -26,11 +28,19 @@ export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}" export ZONE export PROJECT +# Use the user-provided config name if it exists, otherwise derive it. +if [[ -n "${USER_CONFIG_NAME}" ]]; then + export CONFIG_NAME=${USER_CONFIG_NAME} +else + export CONFIG_NAME=${ACCELERATOR_TYPE//-/_} +fi + echo ">>> Using Configuration:" echo " Accelerator: ${ACCELERATOR_TYPE}" echo " TPU Name: ${TPU_NAME}" echo " Zone: ${ZONE}" echo " Project: ${PROJECT}" +echo " Config Name: ${CONFIG_NAME}" echo "--------------------------------------------------" @@ -99,6 +109,6 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${ACCELERATOR_TYPE}" + --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${CONFIG_NAME}" echo ">>> Script finished." 
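The preceding patches make `--config_name` resolve to a module under `examples/ml_perf/configs/` through `importlib.import_module(f".configs.{args.config_name}", package=__package__).config`, so adding an experiment only requires another config module that exposes a top-level `config`. A minimal sketch follows; the module name `v6e_8_smoke` and the batch-size override are invented for illustration and are not part of this patch series, while the rest mirrors the existing `v6e_8.py`.

# examples/ml_perf/configs/v6e_8_smoke.py (hypothetical, for illustration only)
from keras.utils import Config

from .datasets.dummy_dataset import dataset_config
from .models.default_model import model_config
from .training.default_training import training_config

config = Config()

config.experiment_name = "v6e_8_smoke"
config.model_dir = "./v6e_8_smoke"

config.dataset = dataset_config
config.model = model_config
config.training = training_config

# Hypothetical override: `main.py` reads this value back with dict-style
# access, e.g. `config["training"]["global_batch_size"]`.
config.training.global_batch_size = 8192

config.freeze()

Such a module would then be selected via the bash script's optional fourth argument, e.g. `bash examples/ml_perf/run.sh v6e-8 us-east5-a tpu-prod-env-one-vm v6e_8_smoke`.
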
From 2a1c759251adc0e19c58592da66a56d3447a3341 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 19:54:33 +0530 Subject: [PATCH 13/40] Add way to load real dataset --- .../ml_perf/configs/datasets/dummy_dataset.py | 26 ++ examples/ml_perf/dataloader.py | 237 +++++++++++++----- examples/ml_perf/main.py | 16 +- examples/ml_perf/run.sh | 3 +- 4 files changed, 222 insertions(+), 60 deletions(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 469ad206..0cbda44f 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -11,130 +11,156 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "multi_hot_size": 3, + "new_name": "cat_14_id", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "multi_hot_size": 2, + "new_name": "cat_15_id", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "multi_hot_size": 1, + "new_name": "cat_16_id", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "multi_hot_size": 2, + "new_name": "cat_17_id", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "multi_hot_size": 6, + "new_name": "cat_18_id", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "multi_hot_size": 1, + "new_name": "cat_19_id", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "multi_hot_size": 1, + "new_name": "cat_20_id", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "multi_hot_size": 1, + "new_name": "cat_21_id", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "multi_hot_size": 1, + "new_name": "cat_22_id", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "multi_hot_size": 7, + "new_name": "cat_23_id", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "multi_hot_size": 3, + "new_name": "cat_24_id", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "multi_hot_size": 8, + "new_name": "cat_25_id", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "multi_hot_size": 1, + "new_name": "cat_26_id", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "multi_hot_size": 6, + "new_name": "cat_27_id", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "multi_hot_size": 9, + "new_name": "cat_28_id", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "multi_hot_size": 5, + "new_name": "cat_29_id", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "multi_hot_size": 1, + "new_name": "cat_30_id", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "multi_hot_size": 1, + "new_name": "cat_31_id", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "multi_hot_size": 1, + "new_name": "cat_32_id", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "multi_hot_size": 12, + "new_name": "cat_33_id", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "multi_hot_size": 100, + "new_name": "cat_34_id", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "multi_hot_size": 27, + "new_name": "cat_35_id", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "multi_hot_size": 10, + "new_name": "cat_36_id", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "multi_hot_size": 3, + "new_name": "cat_37_id", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "multi_hot_size": 1, + "new_name": "cat_38_id", }, { "name": "categorical-feature-39", 
"vocabulary_size": 36, "multi_hot_size": 1, + "new_name": "cat_39_id", }, ] diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 467f390a..2b34f352 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -2,68 +2,191 @@ import tensorflow as tf -def _get_dummy_batch(batch_size, large_emb_features, small_emb_features): - """Returns a dummy batch of data in the final desired structure.""" - - # Labels - data = { - "clicked": np.random.randint(0, 2, size=(batch_size,), dtype=np.int64) - } - - # Dense features - dense_input_list = [ - np.random.uniform(0.0, 0.9, size=(batch_size, 1)).astype(np.float32) - for _ in range(13) - ] - data["dense_input"] = np.concatenate(dense_input_list, axis=-1) - - # Sparse features - large_emb_inputs = {} - for large_emb_feature in large_emb_features: - vocabulary_size = large_emb_feature["vocabulary_size"] - multi_hot_size = large_emb_feature["multi_hot_size"] - idx = large_emb_feature["name"].split("-")[-1] - - large_emb_inputs[f"cat_{idx}_id"] = np.random.randint( - low=0, - high=vocabulary_size, - size=(batch_size, multi_hot_size), - dtype=np.int64, - ) +class DataLoader: + def __init__( + self, + file_pattern, + batch_size, + dense_features, + large_emb_features, + small_emb_features, + label, + training=False, + ): + # Passed attributes. + self.file_pattern = file_pattern + self.batch_size = batch_size + self.dense_features = dense_features + self.large_emb_features = large_emb_features + self.small_emb_features = small_emb_features + self.label = label + self.training = training + + # Derived attributes. + self._return_dummy_dataset = file_pattern is None + + def _get_dummy_batch(self): + """Returns a dummy batch of data in the final desired structure.""" + + # Labels + data = { + "clicked": np.random.randint( + 0, 2, size=(self.batch_size,), dtype=np.int64 + ) + } + + # Dense features + dense_input_list = [ + np.random.uniform(0.0, 0.9, size=(self.batch_size, 1)).astype( + np.float32 + ) + for _ in range(13) + ] + data["dense_input"] = np.concatenate(dense_input_list, axis=-1) + + # Sparse features + large_emb_inputs = {} + for large_emb_feature in self.large_emb_features: + name = large_emb_feature["name"] + new_name = large_emb_feature.get("new_name", name) + vocabulary_size = large_emb_feature["vocabulary_size"] + multi_hot_size = large_emb_feature["multi_hot_size"] + + large_emb_inputs[new_name] = np.random.randint( + low=0, + high=vocabulary_size, + size=(self.batch_size, multi_hot_size), + dtype=np.int64, + ) + + data["large_emb_inputs"] = large_emb_inputs + + # Dense lookup features + small_emb_inputs = {} + for small_emb_feature in self.small_emb_features: + name = small_emb_feature["name"] + new_name = small_emb_feature.get("new_name", name) + vocabulary_size = small_emb_feature["vocabulary_size"] + multi_hot_size = small_emb_feature["multi_hot_size"] + + small_emb_inputs[new_name] = np.random.randint( + low=0, + high=vocabulary_size, + size=(self.batch_size, multi_hot_size), + dtype=np.int64, + ) + + if small_emb_inputs: + data["small_emb_inputs"] = small_emb_inputs + + return data + + def _create_dummy_dataset(self): + """Creates a TF dummy dataset (randomly initialised).""" + dummy_data = self._get_dummy_batch() + + # Separate labels from features to create a `(features, labels)` tuple. 
+ labels = dummy_data.pop("clicked") + features = dummy_data + + dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) + return dataset - data["large_emb_inputs"] = large_emb_inputs - - # Dense lookup features - small_emb_inputs = {} - for small_emb_feature in small_emb_features: - vocabulary_size = small_emb_feature["vocabulary_size"] - multi_hot_size = small_emb_feature["multi_hot_size"] - idx = small_emb_feature["name"].split("-")[-1] - - # TODO: We don't need this custom renaming. Remove later, when we - # shift from dummy data to actual data. - small_emb_inputs[f"cat_{idx}_id"] = np.random.randint( - low=0, - high=vocabulary_size, - size=(batch_size, multi_hot_size), - dtype=np.int64, + def _get_feature_spec(self): + feature_spec = { + self.label: tf.io.FixedLenFeature( + [self.batch_size], + dtype=tf.int64, + ) + } + + for dense_feat in self.dense_features: + feature_spec[dense_feat] = tf.io.FixedLenFeature( + [self.batch_size], + dtype=tf.float32, + ) + + for emb_feat in self.large_emb_features + self.small_emb_features: + name = emb_feat["name"] + feature_spec[name] = tf.io.FixedLenFeature( + [self.batch_size], + dtype=tf.string, + ) + + return feature_spec + + def _preprocess(self, example): + # Read example. + feature_spec = self.get_feature_spec() + example = tf.io.parse_single_example(example, feature_spec) + + # Dense features + dense_input = tf.stack( + [ + tf.reshape(example[dense_feature], [self.batch_size, 1]) + for dense_feature in self.dense_features + ], + axis=-1, ) - if small_emb_inputs: - data["small_emb_inputs"] = small_emb_inputs + def _get_emb_inputs(emb_features): + emb_inputs = {} + for emb_feature in emb_features: + name = emb_feature["name"] + new_name = emb_feature.get("new_name", name) + multi_hot_size = emb_feature["multi_hot_size"] + + raw_values = tf.io.decode_raw(example[name], tf.int64) + raw_values = tf.reshape( + raw_values, [self.batch_size, multi_hot_size] + ) + emb_inputs[new_name] = raw_values + return emb_inputs + + # Sparse features + large_emb_inputs = _get_emb_inputs(self.large_emb_features) + small_emb_inputs = _get_emb_inputs(self.small_emb_features) + + # Labels + labels = tf.reshape(example[self.label], [self.batch_size]) - return data + x = { + "dense_input": dense_input, + "large_emb_inputs": large_emb_inputs, + } + if small_emb_inputs: + x["small_emb_inputs"] = small_emb_inputs + return (x, labels) + + def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): + if self._return_dummy_dataset: + return self._create_dummy_dataset() + + dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) + + # Shard the dataset across hosts/workers. + # TODO: Do we need to do this if we are distributing the dataset + # manually using distribution.distribute_dataset(...)? + if num_processes > 1: + dataset = dataset.shard(num_processes, process_id) + + dataset = tf.data.TFRecordDataset( + dataset, + buffer_size=None, + num_parallel_reads=tf.data.AUTOTUNE, + ) + + # Process example. + dataset = dataset.map( + lambda x: self._preprocess(x), + num_parallel_calls=tf.data.AUTOTUNE, + ) -def create_dummy_dataset(batch_size, large_emb_features, small_emb_features): - """Creates a TF dataset from cached dummy data of the final batch size.""" - dummy_data = _get_dummy_batch( - batch_size, large_emb_features, small_emb_features - ) + # Shuffle dataset if in training mode. 
+ if self.training and shuffle_buffer and shuffle_buffer > 0: + dataset = dataset.shuffle(shuffle_buffer) - # Separate labels from features to create a `(features, labels)` tuple. - labels = dummy_data.pop("clicked") - features = dummy_data + dataset = dataset.prefetch(tf.data.AUTOTUNE) - dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) - return dataset + return dataset diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 5bea8c72..da7e35f8 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -8,7 +8,7 @@ import keras_rs -from .dataloader import create_dummy_dataset +from .dataloader import DataLoader from .model import DLRMDCNV2 SEED = 1337 @@ -20,6 +20,7 @@ def main( large_emb_features, small_emb_features, label, + shuffle_buffer, embedding_dim, allow_id_dropping, max_ids_per_partition, @@ -109,10 +110,18 @@ def main( # === Load dataset === print("===== Loading dataset =====") - train_ds = create_dummy_dataset( + train_ds = DataLoader( + file_pattern=file_pattern, batch_size=per_host_batch_size, + dense_features=dense_features, large_emb_features=large_emb_features, small_emb_features=small_emb_features, + label=label, + training=True, + ).create_dataset( + process_id=distribution._process_id, + num_processes=num_processes, + shuffle_buffer=shuffle_buffer, ) # For the multi-host case, the dataset has to be distributed manually. # See note here: @@ -172,6 +181,8 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] + # Shuffling + shuffle_buffer = ds_cfg["shuffle_buffer"] # Features label = ds_cfg["label"] dense_features = ds_cfg["dense"] @@ -219,6 +230,7 @@ def generator(dataset, training=False): large_emb_features, small_emb_features, label, + shuffle_buffer, embedding_dim, allow_id_dropping, max_ids_per_partition, diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index b34953f2..9c7a6784 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -euo pipefail # ============================================================================== # Script Configuration & Argument Handling @@ -12,7 +13,7 @@ ACCELERATOR_TYPE=${1:-"v6e-8"} ZONE=${2:-"us-east5-a"} PROJECT=${3:-"tpu-prod-env-one-vm"} -USER_CONFIG_NAME=${4} # Capture the fourth argument +USER_CONFIG_NAME=${4} # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then From b50a7f5a25402030727f34ad4083816d9bc8d73e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 20:09:00 +0530 Subject: [PATCH 14/40] Add way to load real dataset (1) --- examples/ml_perf/main.py | 4 +-- examples/ml_perf/run.sh | 60 ++++++++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index da7e35f8..2ae6a078 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -182,7 +182,7 @@ def generator(dataset, training=False): # File path file_pattern = ds_cfg["file_pattern"] # Shuffling - shuffle_buffer = ds_cfg["shuffle_buffer"] + shuffle_buffer = ds_cfg.get("shuffle_buffer", None) # Features label = ds_cfg["label"] dense_features = ds_cfg["dense"] @@ -211,7 +211,7 @@ def generator(dataset, training=False): num_epochs = training_cfg["num_epochs"] # For features which have vocabulary_size < embedding_threshold, we can - # just do a normal dense lookup for those instead of have distributed + # just 
do a normal dense lookup for those instead of having distributed # embeddings. We could ideally pass `placement = default_device` to # `keras_rs.layers.TableConfig` directly (and wouldn't have to do this # separation of features), but doing it that way will necessarily require diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index 9c7a6784..aebf840a 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -4,7 +4,7 @@ set -euo pipefail # ============================================================================== # Script Configuration & Argument Handling # ============================================================================== -# This script accepts up to three optional arguments: +# This script accepts up to four optional arguments: # 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) # 2. Zone (default: us-east5-a) # 3. Project (default: tpu-prod-env-one-vm) @@ -13,12 +13,12 @@ set -euo pipefail ACCELERATOR_TYPE=${1:-"v6e-8"} ZONE=${2:-"us-east5-a"} PROJECT=${3:-"tpu-prod-env-one-vm"} -USER_CONFIG_NAME=${4} +USER_CONFIG_NAME=${4:-""} # Initialize with an empty string if not provided # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then echo "Error: Invalid accelerator type '${ACCELERATOR_TYPE}'." >&2 - echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project]" >&2 + echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project] [config_name]" >&2 exit 1 fi @@ -48,36 +48,62 @@ echo "--------------------------------------------------" # ============================================================================== # TPU VM Creation # ============================================================================== -echo ">>> Creating TPU VM: ${TPU_NAME} with accelerator ${ACCELERATOR_TYPE}..." -gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \ - --zone=${ZONE} \ - --accelerator-type=${ACCELERATOR_TYPE} \ - --version=v2-alpha-tpuv6e \ - --project=${PROJECT} \ - --metadata=enable-oslogin=TRUE \ - --scopes=https://www.googleapis.com/auth/cloud-platform +echo ">>> Checking for existing TPU VM: ${TPU_NAME}..." +if gcloud alpha compute tpus tpu-vm describe ${TPU_NAME} --zone=${ZONE} --project=${PROJECT} &> /dev/null; then + echo ">>> TPU VM '${TPU_NAME}' already exists. Skipping creation." +else + echo ">>> Creating TPU VM: ${TPU_NAME} with accelerator ${ACCELERATOR_TYPE}..." + gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \ + --zone=${ZONE} \ + --accelerator-type=${ACCELERATOR_TYPE} \ + --version=v2-alpha-tpuv6e \ + --project=${PROJECT} \ + --metadata=enable-oslogin=TRUE \ + --scopes=https://www.googleapis.com/auth/cloud-platform +fi # ============================================================================== -# Setup Python Virtual Environment on all workers +# Setup Python venv on all workers # ============================================================================== -echo ">>> Creating Python virtual environment..." +echo ">>> Checking for Python virtual environment..." gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="sudo apt-get update && sudo apt install -y python3.10-venv && python3 -m venv .keras-env" + --command="sudo apt-get update && sudo apt install -y python3.10-venv && if [ ! 
-d '.keras-env' ]; then echo '>>> Creating .keras-env...'; python3 -m venv .keras-env; else echo '>>> .keras-env already exists.'; fi" # ============================================================================== -# Clone KerasRS and Install Dependencies +# Clone/Update KerasRS and Install Dependencies # ============================================================================== -echo ">>> Cloning KerasRS and installing dependencies..." +echo ">>> Cloning or updating KerasRS and installing dependencies..." gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && git clone https://github.com/abheesht17/keras-rs.git && cd keras-rs && git checkout ml-perf && pip install -e . && pip install tensorflow-datasets && pip uninstall -y tensorflow keras && pip install git+https://github.com/keras-team/keras.git && pip install jax-tpu-embedding tensorflow-cpu" + --command=" + set -e # Ensure script exits on error + source .keras-env/bin/activate + + if [ ! -d 'keras-rs' ]; then + echo '>>> Cloning keras-rs repository...' + git clone https://github.com/abheesht17/keras-rs.git + cd keras-rs + git checkout ml-perf + else + echo '>>> keras-rs repository exists. Pulling latest changes...' + cd keras-rs + git checkout ml-perf # Ensure we are on the correct branch + git pull + fi + + echo '>>> Installing/updating dependencies...' + pip install -e . + pip uninstall -y tensorflow keras + pip install git+https://github.com/keras-team/keras.git + pip install jax-tpu-embedding tensorflow-cpu + " # ============================================================================== From fe8dc414520457a8290eb129cf204791c34a9c3e Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:09:57 +0530 Subject: [PATCH 15/40] Add dataset path --- .../ml_perf/configs/v6e_8_full_dataset.py | 20 ++++++++ examples/ml_perf/run.sh | 50 +++++++++++++++---- 2 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 examples/ml_perf/configs/v6e_8_full_dataset.py diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py new file mode 100644 index 00000000..71fd3a5a --- /dev/null +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -0,0 +1,20 @@ +from keras.utils import Config + +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config + +config = Config() + +config.experiment_name = "v6e_8_full_dataset" +config.model_dir = "./v6e_8_full_dataset" + +config.dataset = dataset_config +config.dataset.file_pattern = ( + "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" + "train-00000-of-01024tfrecord" +) +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index aebf840a..122dd92c 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -5,20 +5,50 @@ set -euo pipefail # Script Configuration & Argument Handling # ============================================================================== # This script accepts up to four optional arguments: -# 1. Accelerator Type (default: v6e-8, options: v6e-8, v6e-16) -# 2. Zone (default: us-east5-a) -# 3. Project (default: tpu-prod-env-one-vm) -# 4. 
Config Name (default: derived from accelerator type, e.g., v6e_8) - -ACCELERATOR_TYPE=${1:-"v6e-8"} -ZONE=${2:-"us-east5-a"} -PROJECT=${3:-"tpu-prod-env-one-vm"} -USER_CONFIG_NAME=${4:-""} # Initialize with an empty string if not provided +# 1. --accelerator-type (default: v6e-8, options: v6e-8, v6e-16) +# 2. --zone (default: us-east5-a) +# 3. --project (default: tpu-prod-env-one-vm) +# 4. --config-name (default: derived from accelerator type, e.g., v6e_8) + +# Defaults +ACCELERATOR_TYPE="v6e-8" +ZONE="us-east5-a" +PROJECT="tpu-prod-env-one-vm" +USER_CONFIG_NAME="" + +# ============================================================================== +# Argument Parsing +# ============================================================================== + +show_help() { +cat << EOF +Usage: $0 [--accelerator-type ] [--zone ] [--project ] [--config-name ] +Options: + --accelerator-type The type of TPU accelerator (default: v6e-8). Options: v6e-8, v6e-16. + --zone The GCP zone for the TPU VM (default: us-east5-a). + --project The GCP project ID (default: tpu-prod-env-one-vm). + --config-name The specific configuration name to use for the training script. + (default: derived from accelerator type, e.g., v6e_8). + -h, --help Show this help message. +EOF +} + + +while [[ "$#" -gt 0 ]]; do + case $1 in + --accelerator-type) ACCELERATOR_TYPE="$2"; shift ;; + --zone) ZONE="$2"; shift ;; + --project) PROJECT="$2"; shift ;; + --config-name) USER_CONFIG_NAME="$2"; shift ;; + *) echo "Unknown parameter passed: $1"; show_help; exit 1 ;; + esac + shift +done # Validate the provided accelerator type if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then echo "Error: Invalid accelerator type '${ACCELERATOR_TYPE}'." >&2 - echo "Usage: $0 [v6e-8|v6e-16] [gcp_zone] [gcp_project] [config_name]" >&2 + show_help exit 1 fi From ca297d45b25af03993b31d95edf7bdda8b74ecb4 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:12:57 +0530 Subject: [PATCH 16/40] Dataloader fixes (1) --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 2b34f352..89556b5f 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -117,7 +117,7 @@ def _get_feature_spec(self): def _preprocess(self, example): # Read example. 
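# Illustrative sketch (not part of the patch): the parsing step below in
# isolation. Feature names, dtypes and the per-record batch size are
# assumptions based on this example's configs; the real spec comes from
# `_get_feature_spec()`.
import tensorflow as tf

record_batch_size = 4224  # assumed: each serialized example holds a full batch
feature_spec = {
    "clicked": tf.io.FixedLenFeature([record_batch_size], dtype=tf.int64),
    "int-feature-1": tf.io.FixedLenFeature([record_batch_size], dtype=tf.float32),
    # Categorical ids are stored as raw bytes and decoded later with
    # `tf.io.decode_raw`.
    "categorical-feature-14": tf.io.FixedLenFeature([record_batch_size], dtype=tf.string),
}
# Given one element `serialized_example` of a TFRecordDataset:
# parsed = tf.io.parse_single_example(serialized_example, feature_spec)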
- feature_spec = self.get_feature_spec() + feature_spec = self._get_feature_spec() example = tf.io.parse_single_example(example, feature_spec) # Dense features From 35c3d61aa97a40b935c29c0f3dc4b15ef7b3ab82 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:16:59 +0530 Subject: [PATCH 17/40] Dataloader fixes (2) --- examples/ml_perf/configs/datasets/dummy_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 0cbda44f..418eedba 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -5,7 +5,7 @@ dataset_config.file_pattern = None # Features dataset_config.label = "clicked" -dataset_config.dense = [f"int-feature-{i}" for i in range(13)] +dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] dataset_config.sparse = [ { "name": "categorical-feature-14", From 9cc9b88fb7071b2fd081cf4bfdb4424e838ff665 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 18 Aug 2025 21:21:52 +0530 Subject: [PATCH 18/40] Dataloader fixes (3) --- examples/ml_perf/configs/v6e_8_full_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index 71fd3a5a..ef0f5347 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -16,5 +16,6 @@ ) config.model = model_config config.training = training_config +config.training.batch_size = 4224 config.freeze() From a0431ba0a43182823bf1feb50b405fda01dc71d2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 05:56:23 +0530 Subject: [PATCH 19/40] Feature naming edit --- examples/ml_perf/main.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2ae6a078..28545a42 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -47,11 +47,7 @@ def main( for large_emb_feature in large_emb_features: # Rename these features to something shorter; was facing some weird # issues with the longer names. - feature_name = ( - large_emb_feature["name"] - .replace("-", "_") - .replace("egorical_feature", "") - ) + feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] multi_hot_size = large_emb_feature["multi_hot_size"] From 2b9538f79ad8f3961b394022a36a7f57ab232e95 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 05:56:55 +0530 Subject: [PATCH 20/40] Feature naming edit --- examples/ml_perf/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 28545a42..11d6c8d1 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -45,8 +45,6 @@ def main( # === Distributed embeddings' configs for sparse features === feature_configs = {} for large_emb_feature in large_emb_features: - # Rename these features to something shorter; was facing some weird - # issues with the longer names. 
feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] multi_hot_size = large_emb_feature["multi_hot_size"] From 187ccd50a962d31ffdf9ca9df5bbb44b96480739 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 06:03:52 +0530 Subject: [PATCH 21/40] Feature naming edit --- examples/ml_perf/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index a9f01f36..713e02c1 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -166,8 +166,9 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: dense_output = self.bottom_mlp(dense_input) # jax.debug.print("dense_ouput {}", dense_output.shape) large_embeddings = self.embedding_layer(large_emb_inputs) - small_embeddings = [] + small_embeddings = None if self.small_emb_features: + small_embeddings = [] small_emb_inputs = inputs["small_emb_inputs"] for small_emb_input, embedding_layer in zip( small_emb_inputs.values(), self.small_embedding_layers @@ -179,11 +180,10 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor: small_embeddings = ops.concatenate(small_embeddings, axis=-1) # Interaction - x = ops.concatenate( - [dense_output, small_embeddings, *large_embeddings.values()], - axis=-1, - ) - # jax.debug.print("x {}", x.shape) + to_concatenate = [dense_output, *large_embeddings.values()] + if small_embeddings is not None: + to_concatenate += [small_embeddings] + x = ops.concatenate(to_concatenate, axis=-1) x = self.dcn_block(x) # Predictions From a98d431649a31102713ae46b6afed8ff373893af Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Tue, 19 Aug 2025 06:04:32 +0530 Subject: [PATCH 22/40] Feature naming edit --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 11d6c8d1..8a01a8c3 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -147,7 +147,7 @@ def generator(dataset, training=False): break # Train the model. 
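# Illustrative sketch (not part of the patch): a standalone shape check for
# the `to_concatenate` interaction change above, with toy sizes. The real
# batch size, embedding_dim and large/small feature split come from the
# configs.
from keras import ops

batch_size, embedding_dim = 4, 8
dense_output = ops.zeros((batch_size, embedding_dim))
large_embeddings = {
    f"cat_{i}": ops.zeros((batch_size, embedding_dim)) for i in (14, 23)
}
small_embeddings = ops.zeros((batch_size, 2 * embedding_dim))  # pre-concatenated

to_concatenate = [dense_output, *large_embeddings.values()]
if small_embeddings is not None:
    to_concatenate += [small_embeddings]
x = ops.concatenate(to_concatenate, axis=-1)
assert x.shape == (batch_size, 5 * embedding_dim)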
- model.fit(train_generator, epochs=1) + model.fit(train_generator, epochs=num_epochs) if __name__ == "__main__": From d15957d74937eea5dd1d7dcfddc399c4f2e3391d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 09:25:00 +0530 Subject: [PATCH 23/40] Actual dataset loading fixes (1) --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 89556b5f..a085ab18 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -96,7 +96,7 @@ def _get_feature_spec(self): feature_spec = { self.label: tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.int64, + dtype=tf.float32, ) } From 8870c8df77acff1498880e6c7f3213502eb96c9c Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 09:42:18 +0530 Subject: [PATCH 24/40] Fix feature spec dtypes --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index a085ab18..0184ee56 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -110,7 +110,7 @@ def _get_feature_spec(self): name = emb_feat["name"] feature_spec[name] = tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.string, + dtype=tf.int64, ) return feature_spec From e971f190e196fda777ad860521672f17c6d4091b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 10:24:09 +0530 Subject: [PATCH 25/40] Fix feature spec dtypes --- examples/ml_perf/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 0184ee56..89556b5f 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -96,7 +96,7 @@ def _get_feature_spec(self): feature_spec = { self.label: tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.float32, + dtype=tf.int64, ) } @@ -110,7 +110,7 @@ def _get_feature_spec(self): name = emb_feat["name"] feature_spec[name] = tf.io.FixedLenFeature( [self.batch_size], - dtype=tf.int64, + dtype=tf.string, ) return feature_spec From 02c1881a7d6cac530e0d17fe14846e9c272abde2 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 11:13:45 +0530 Subject: [PATCH 26/40] Allow different batch sizes from file batch size --- .../ml_perf/configs/v6e_8_full_dataset.py | 4 +++- examples/ml_perf/dataloader.py | 24 ++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index ef0f5347..eca6c7ed 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -14,8 +14,10 @@ "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" "train-00000-of-01024tfrecord" ) +# The path which we are reading from already has the batched dataset. +config.dataset.file_batch_size = 4224 config.model = model_config config.training = training_config -config.training.batch_size = 4224 +config.training.batch_size = 256 config.freeze() diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 89556b5f..bfa1a963 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -7,6 +7,7 @@ def __init__( self, file_pattern, batch_size, + file_batch_size, dense_features, large_emb_features, small_emb_features, @@ -16,6 +17,7 @@ def __init__( # Passed attributes. 
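# Illustrative sketch (not part of the patch): why the loader tracks both
# `file_batch_size` and `batch_size`. Records on disk are pre-batched (4224
# rows per serialized example in the full-dataset config), so the pipeline
# unbatches and then re-batches to the training batch size. Toy values:
import tensorflow as tf

file_batch_size, train_batch_size = 8, 3
ds = tf.data.Dataset.from_tensors(tf.range(file_batch_size))  # one on-disk "record"
ds = ds.unbatch().batch(train_batch_size, drop_remainder=True)
print([batch.numpy().tolist() for batch in ds])  # [[0, 1, 2], [3, 4, 5]]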
self.file_pattern = file_pattern self.batch_size = batch_size + self.file_batch_size = file_batch_size self.dense_features = dense_features self.large_emb_features = large_emb_features self.small_emb_features = small_emb_features @@ -95,21 +97,21 @@ def _create_dummy_dataset(self): def _get_feature_spec(self): feature_spec = { self.label: tf.io.FixedLenFeature( - [self.batch_size], + [self.file_batch_size], dtype=tf.int64, ) } for dense_feat in self.dense_features: feature_spec[dense_feat] = tf.io.FixedLenFeature( - [self.batch_size], + [self.file_batch_size], dtype=tf.float32, ) for emb_feat in self.large_emb_features + self.small_emb_features: name = emb_feat["name"] feature_spec[name] = tf.io.FixedLenFeature( - [self.batch_size], + [self.file_batch_size], dtype=tf.string, ) @@ -123,7 +125,7 @@ def _preprocess(self, example): # Dense features dense_input = tf.stack( [ - tf.reshape(example[dense_feature], [self.batch_size, 1]) + tf.reshape(example[dense_feature], [self.file_batch_size, 1]) for dense_feature in self.dense_features ], axis=-1, @@ -138,7 +140,7 @@ def _get_emb_inputs(emb_features): raw_values = tf.io.decode_raw(example[name], tf.int64) raw_values = tf.reshape( - raw_values, [self.batch_size, multi_hot_size] + raw_values, [self.file_batch_size, multi_hot_size] ) emb_inputs[new_name] = raw_values return emb_inputs @@ -148,7 +150,7 @@ def _get_emb_inputs(emb_features): small_emb_inputs = _get_emb_inputs(self.small_emb_features) # Labels - labels = tf.reshape(example[self.label], [self.batch_size]) + labels = tf.reshape(example[self.label], [self.file_batch_size]) x = { "dense_input": dense_input, @@ -179,14 +181,20 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): # Process example. dataset = dataset.map( - lambda x: self._preprocess(x), - num_parallel_calls=tf.data.AUTOTUNE, + self._preprocess, num_parallel_calls=tf.data.AUTOTUNE ) + dataset.unbatch() # Shuffle dataset if in training mode. 
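# Illustrative sketch (not part of the patch): one common way to give each
# host its own slice of the input files before the parse/unbatch/re-batch
# steps in `create_dataset`; the actual file-reading code in this loader may
# differ.
import tensorflow as tf

def make_per_host_file_dataset(file_pattern, process_id, num_processes):
    files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    files = files.shard(num_shards=num_processes, index=process_id)
    return files.interleave(
        tf.data.TFRecordDataset,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )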
if self.training and shuffle_buffer and shuffle_buffer > 0: dataset = dataset.shuffle(shuffle_buffer) + dataset = dataset.batch( + self.batch_size, + drop_remainder=True, + num_parallel_calls=tf.data.AUTOTUNE, + ) + dataset = dataset.prefetch(tf.data.AUTOTUNE) return dataset From b5db3043ffe9a7d9e36d1c6dd238dcc476de68fe Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 11:18:16 +0530 Subject: [PATCH 27/40] Allow different batch sizes from file batch size (fixes) --- examples/ml_perf/main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8a01a8c3..102fac26 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -32,6 +32,7 @@ def main( dcn_projection_dim, learning_rate, global_batch_size, + file_batch_size, num_epochs, ): # Set DDP as Keras distribution strategy @@ -107,6 +108,7 @@ def main( train_ds = DataLoader( file_pattern=file_pattern, batch_size=per_host_batch_size, + file_batch_size=file_batch_size, dense_features=dense_features, large_emb_features=large_emb_features, small_emb_features=small_emb_features, @@ -175,6 +177,8 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] + # File batch size + file_batch_size = ds_cfg.get("file_batch_size", None) # Shuffling shuffle_buffer = ds_cfg.get("shuffle_buffer", None) # Features @@ -236,5 +240,6 @@ def generator(dataset, training=False): dcn_projection_dim, learning_rate, global_batch_size, + file_batch_size, num_epochs, ) From e98bdf94177a4f04384ebc940fd6ee1015366d60 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 11:27:44 +0530 Subject: [PATCH 28/40] Fix feature naming --- .../ml_perf/configs/datasets/dummy_dataset.py | 52 +++++++++---------- examples/ml_perf/dataloader.py | 6 +-- examples/ml_perf/main.py | 2 +- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 418eedba..411b3b41 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -11,156 +11,156 @@ "name": "categorical-feature-14", "vocabulary_size": 40000000, "multi_hot_size": 3, - "new_name": "cat_14_id", + "new_name": "cat_14", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, "multi_hot_size": 2, - "new_name": "cat_15_id", + "new_name": "cat_15", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, "multi_hot_size": 1, - "new_name": "cat_16_id", + "new_name": "cat_16", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, "multi_hot_size": 2, - "new_name": "cat_17_id", + "new_name": "cat_17", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, "multi_hot_size": 6, - "new_name": "cat_18_id", + "new_name": "cat_18", }, { "name": "categorical-feature-19", "vocabulary_size": 3, "multi_hot_size": 1, - "new_name": "cat_19_id", + "new_name": "cat_19", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, "multi_hot_size": 1, - "new_name": "cat_20_id", + "new_name": "cat_20", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, "multi_hot_size": 1, - "new_name": "cat_21_id", + "new_name": "cat_21", }, { "name": "categorical-feature-22", "vocabulary_size": 63, "multi_hot_size": 1, - "new_name": "cat_22_id", + "new_name": "cat_22", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, "multi_hot_size": 7, - "new_name": 
"cat_23_id", + "new_name": "cat_23", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, "multi_hot_size": 3, - "new_name": "cat_24_id", + "new_name": "cat_24", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, "multi_hot_size": 8, - "new_name": "cat_25_id", + "new_name": "cat_25", }, { "name": "categorical-feature-26", "vocabulary_size": 10, "multi_hot_size": 1, - "new_name": "cat_26_id", + "new_name": "cat_26", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, "multi_hot_size": 6, - "new_name": "cat_27_id", + "new_name": "cat_27", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, "multi_hot_size": 9, - "new_name": "cat_28_id", + "new_name": "cat_28", }, { "name": "categorical-feature-29", "vocabulary_size": 155, "multi_hot_size": 5, - "new_name": "cat_29_id", + "new_name": "cat_29", }, { "name": "categorical-feature-30", "vocabulary_size": 4, "multi_hot_size": 1, - "new_name": "cat_30_id", + "new_name": "cat_30", }, { "name": "categorical-feature-31", "vocabulary_size": 976, "multi_hot_size": 1, - "new_name": "cat_31_id", + "new_name": "cat_31", }, { "name": "categorical-feature-32", "vocabulary_size": 14, "multi_hot_size": 1, - "new_name": "cat_32_id", + "new_name": "cat_32", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, "multi_hot_size": 12, - "new_name": "cat_33_id", + "new_name": "cat_33", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, "multi_hot_size": 100, - "new_name": "cat_34_id", + "new_name": "cat_34", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, "multi_hot_size": 27, - "new_name": "cat_35_id", + "new_name": "cat_35", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, "multi_hot_size": 10, - "new_name": "cat_36_id", + "new_name": "cat_36", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, "multi_hot_size": 3, - "new_name": "cat_37_id", + "new_name": "cat_37", }, { "name": "categorical-feature-38", "vocabulary_size": 108, "multi_hot_size": 1, - "new_name": "cat_38_id", + "new_name": "cat_38", }, { "name": "categorical-feature-39", "vocabulary_size": 36, "multi_hot_size": 1, - "new_name": "cat_39_id", + "new_name": "cat_39", }, ] diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index bfa1a963..fdf836a1 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -54,7 +54,7 @@ def _get_dummy_batch(self): vocabulary_size = large_emb_feature["vocabulary_size"] multi_hot_size = large_emb_feature["multi_hot_size"] - large_emb_inputs[new_name] = np.random.randint( + large_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, size=(self.batch_size, multi_hot_size), @@ -71,7 +71,7 @@ def _get_dummy_batch(self): vocabulary_size = small_emb_feature["vocabulary_size"] multi_hot_size = small_emb_feature["multi_hot_size"] - small_emb_inputs[new_name] = np.random.randint( + small_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, size=(self.batch_size, multi_hot_size), @@ -142,7 +142,7 @@ def _get_emb_inputs(emb_features): raw_values = tf.reshape( raw_values, [self.file_batch_size, multi_hot_size] ) - emb_inputs[new_name] = raw_values + emb_inputs[f"{new_name}_id"] = raw_values return emb_inputs # Sparse features diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 102fac26..2658f98c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -72,7 +72,7 @@ def main( 
max_unique_ids_per_partition=max_unique_ids_per_partition, ) feature_configs[f"{feature_name}_id"] = keras_rs.layers.FeatureConfig( - name=feature_name.replace("id", ""), + name=feature_name, table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or # `(bsz, multi_hot_size)`. From 73ca47710afb3966c6ff9795277b00d2c34dc920 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 12:55:16 +0530 Subject: [PATCH 29/40] Fix batching --- examples/ml_perf/configs/v6e_8_full_dataset.py | 4 ++++ examples/ml_perf/dataloader.py | 2 +- examples/ml_perf/main.py | 5 +++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py index eca6c7ed..8489b084 100644 --- a/examples/ml_perf/configs/v6e_8_full_dataset.py +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -14,6 +14,10 @@ "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" "train-00000-of-01024tfrecord" ) +config.dataset.val_file_pattern = ( + "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/" + "train-00000-of-01024tfrecord" +) # The path which we are reading from already has the batched dataset. config.dataset.file_batch_size = 4224 config.model = model_config diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index fdf836a1..5a14257b 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -183,7 +183,7 @@ def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): dataset = dataset.map( self._preprocess, num_parallel_calls=tf.data.AUTOTUNE ) - dataset.unbatch() + dataset = dataset.unbatch() # Shuffle dataset if in training mode. if self.training and shuffle_buffer and shuffle_buffer > 0: diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2658f98c..ce12751b 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -16,6 +16,7 @@ def main( file_pattern, + val_file_pattern, dense_features, large_emb_features, small_emb_features, @@ -124,6 +125,7 @@ def main( # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. 
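# Illustrative sketch (not part of the patch): the per-process information
# used for manual dataset distribution is also available from public JAX
# APIs, which avoids reaching into the private `distribution._process_id`.
import jax

process_id = jax.process_index()     # this host's index, in [0, num_processes)
num_processes = jax.process_count()  # total number of hosts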
if num_processes > 1: train_ds = distribution.distribute_dataset(train_ds) + # eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False def generator(dataset, training=False): @@ -144,6 +146,7 @@ def generator(dataset, training=False): yield (x, y) train_generator = generator(train_ds, training=True) + # eval_generator = generator(eval_ds, training=False) for first_batch in train_generator: model(first_batch[0]) break @@ -177,6 +180,7 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] + val_file_pattern = ds_cfg("val_file_pattern", None) # File batch size file_batch_size = ds_cfg.get("file_batch_size", None) # Shuffling @@ -224,6 +228,7 @@ def generator(dataset, training=False): main( file_pattern, + val_file_pattern, dense_features, large_emb_features, small_emb_features, From c2ad8a971a515fd62eb3d28f87c27c9f8c1bbb6a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 12:59:52 +0530 Subject: [PATCH 30/40] Fix batching --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index ce12751b..e7a5f735 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -180,7 +180,7 @@ def generator(dataset, training=False): ds_cfg = config["dataset"] # File path file_pattern = ds_cfg["file_pattern"] - val_file_pattern = ds_cfg("val_file_pattern", None) + val_file_pattern = ds_cfg.get("val_file_pattern", None) # File batch size file_batch_size = ds_cfg.get("file_batch_size", None) # Shuffling From a0801439f3a9b602613ec6186c9840ef9a6298a3 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 13:48:24 +0530 Subject: [PATCH 31/40] Fix dense features --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 5a14257b..e0eafc7f 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -123,7 +123,7 @@ def _preprocess(self, example): example = tf.io.parse_single_example(example, feature_spec) # Dense features - dense_input = tf.stack( + dense_input = tf.concatenate( [ tf.reshape(example[dense_feature], [self.file_batch_size, 1]) for dense_feature in self.dense_features From 28b7189f187d6dfbb7af73f34e7e5b2d2537337f Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 13:51:32 +0530 Subject: [PATCH 32/40] Fix dense features concat --- examples/ml_perf/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index e0eafc7f..9f4df58d 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -123,7 +123,7 @@ def _preprocess(self, example): example = tf.io.parse_single_example(example, feature_spec) # Dense features - dense_input = tf.concatenate( + dense_input = tf.concat( [ tf.reshape(example[dense_feature], [self.file_batch_size, 1]) for dense_feature in self.dense_features From a47817d6eda032f385380092e90bc72a22418c1b Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 14:21:01 +0530 Subject: [PATCH 33/40] Rename multi_hot_size to feature_list_length --- .../ml_perf/configs/datasets/dummy_dataset.py | 52 +++++++++---------- examples/ml_perf/dataloader.py | 12 ++--- examples/ml_perf/main.py | 6 +-- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git 
a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 411b3b41..510f8b62 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -10,157 +10,157 @@ { "name": "categorical-feature-14", "vocabulary_size": 40000000, - "multi_hot_size": 3, + "feature_list_length": 3, "new_name": "cat_14", }, { "name": "categorical-feature-15", "vocabulary_size": 39060, - "multi_hot_size": 2, + "feature_list_length": 2, "new_name": "cat_15", }, { "name": "categorical-feature-16", "vocabulary_size": 17295, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_16", }, { "name": "categorical-feature-17", "vocabulary_size": 7424, - "multi_hot_size": 2, + "feature_list_length": 2, "new_name": "cat_17", }, { "name": "categorical-feature-18", "vocabulary_size": 20265, - "multi_hot_size": 6, + "feature_list_length": 6, "new_name": "cat_18", }, { "name": "categorical-feature-19", "vocabulary_size": 3, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_19", }, { "name": "categorical-feature-20", "vocabulary_size": 7122, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_20", }, { "name": "categorical-feature-21", "vocabulary_size": 1543, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_21", }, { "name": "categorical-feature-22", "vocabulary_size": 63, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_22", }, { "name": "categorical-feature-23", "vocabulary_size": 40000000, - "multi_hot_size": 7, + "feature_list_length": 7, "new_name": "cat_23", }, { "name": "categorical-feature-24", "vocabulary_size": 3067956, - "multi_hot_size": 3, + "feature_list_length": 3, "new_name": "cat_24", }, { "name": "categorical-feature-25", "vocabulary_size": 405282, - "multi_hot_size": 8, + "feature_list_length": 8, "new_name": "cat_25", }, { "name": "categorical-feature-26", "vocabulary_size": 10, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_26", }, { "name": "categorical-feature-27", "vocabulary_size": 2209, - "multi_hot_size": 6, + "feature_list_length": 6, "new_name": "cat_27", }, { "name": "categorical-feature-28", "vocabulary_size": 11938, - "multi_hot_size": 9, + "feature_list_length": 9, "new_name": "cat_28", }, { "name": "categorical-feature-29", "vocabulary_size": 155, - "multi_hot_size": 5, + "feature_list_length": 5, "new_name": "cat_29", }, { "name": "categorical-feature-30", "vocabulary_size": 4, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_30", }, { "name": "categorical-feature-31", "vocabulary_size": 976, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_31", }, { "name": "categorical-feature-32", "vocabulary_size": 14, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_32", }, { "name": "categorical-feature-33", "vocabulary_size": 40000000, - "multi_hot_size": 12, + "feature_list_length": 12, "new_name": "cat_33", }, { "name": "categorical-feature-34", "vocabulary_size": 40000000, - "multi_hot_size": 100, + "feature_list_length": 100, "new_name": "cat_34", }, { "name": "categorical-feature-35", "vocabulary_size": 40000000, - "multi_hot_size": 27, + "feature_list_length": 27, "new_name": "cat_35", }, { "name": "categorical-feature-36", "vocabulary_size": 590152, - "multi_hot_size": 10, + "feature_list_length": 10, "new_name": "cat_36", }, { "name": "categorical-feature-37", "vocabulary_size": 12973, - "multi_hot_size": 3, + 
"feature_list_length": 3, "new_name": "cat_37", }, { "name": "categorical-feature-38", "vocabulary_size": 108, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_38", }, { "name": "categorical-feature-39", "vocabulary_size": 36, - "multi_hot_size": 1, + "feature_list_length": 1, "new_name": "cat_39", }, ] diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 9f4df58d..5d65c49c 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -52,12 +52,12 @@ def _get_dummy_batch(self): name = large_emb_feature["name"] new_name = large_emb_feature.get("new_name", name) vocabulary_size = large_emb_feature["vocabulary_size"] - multi_hot_size = large_emb_feature["multi_hot_size"] + feature_list_length = large_emb_feature["feature_list_length"] large_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, - size=(self.batch_size, multi_hot_size), + size=(self.batch_size, feature_list_length), dtype=np.int64, ) @@ -69,12 +69,12 @@ def _get_dummy_batch(self): name = small_emb_feature["name"] new_name = small_emb_feature.get("new_name", name) vocabulary_size = small_emb_feature["vocabulary_size"] - multi_hot_size = small_emb_feature["multi_hot_size"] + feature_list_length = small_emb_feature["feature_list_length"] small_emb_inputs[f"{new_name}_id"] = np.random.randint( low=0, high=vocabulary_size, - size=(self.batch_size, multi_hot_size), + size=(self.batch_size, feature_list_length), dtype=np.int64, ) @@ -136,11 +136,11 @@ def _get_emb_inputs(emb_features): for emb_feature in emb_features: name = emb_feature["name"] new_name = emb_feature.get("new_name", name) - multi_hot_size = emb_feature["multi_hot_size"] + feature_list_length = emb_feature["feature_list_length"] raw_values = tf.io.decode_raw(example[name], tf.int64) raw_values = tf.reshape( - raw_values, [self.file_batch_size, multi_hot_size] + raw_values, [self.file_batch_size, feature_list_length] ) emb_inputs[f"{new_name}_id"] = raw_values return emb_inputs diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index e7a5f735..a6bce733 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -49,7 +49,7 @@ def main( for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] vocabulary_size = large_emb_feature["vocabulary_size"] - multi_hot_size = large_emb_feature["multi_hot_size"] + feature_list_length = large_emb_feature["feature_list_length"] table_config = keras_rs.layers.TableConfig( name=f"{feature_name}_table", @@ -76,8 +76,8 @@ def main( name=feature_name, table=table_config, # TODO: Verify whether it should be `(bsz, 1)` or - # `(bsz, multi_hot_size)`. - input_shape=(per_host_batch_size, multi_hot_size), + # `(bsz, feature_list_length)`. 
+ input_shape=(per_host_batch_size, feature_list_length), output_shape=(per_host_batch_size, embedding_dim), ) From 9a33f091d31bdfc06acf8b6c60d4e5a85cb225d6 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 14:24:35 +0530 Subject: [PATCH 34/40] Rename sparse to lookup --- examples/ml_perf/configs/datasets/dummy_dataset.py | 2 +- examples/ml_perf/dataloader.py | 6 +++--- examples/ml_perf/main.py | 12 ++++++------ examples/ml_perf/model.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py index 510f8b62..aac66c57 100644 --- a/examples/ml_perf/configs/datasets/dummy_dataset.py +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -6,7 +6,7 @@ # Features dataset_config.label = "clicked" dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] -dataset_config.sparse = [ +dataset_config.lookup = [ { "name": "categorical-feature-14", "vocabulary_size": 40000000, diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py index 5d65c49c..ce2e7286 100644 --- a/examples/ml_perf/dataloader.py +++ b/examples/ml_perf/dataloader.py @@ -46,7 +46,7 @@ def _get_dummy_batch(self): ] data["dense_input"] = np.concatenate(dense_input_list, axis=-1) - # Sparse features + # Big embedding features large_emb_inputs = {} for large_emb_feature in self.large_emb_features: name = large_emb_feature["name"] @@ -63,7 +63,7 @@ def _get_dummy_batch(self): data["large_emb_inputs"] = large_emb_inputs - # Dense lookup features + # Small embedding features small_emb_inputs = {} for small_emb_feature in self.small_emb_features: name = small_emb_feature["name"] @@ -145,7 +145,7 @@ def _get_emb_inputs(emb_features): emb_inputs[f"{new_name}_id"] = raw_values return emb_inputs - # Sparse features + # Embedding/lookup features large_emb_inputs = _get_emb_inputs(self.large_emb_features) small_emb_inputs = _get_emb_inputs(self.small_emb_features) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index a6bce733..a0d2eac7 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -44,7 +44,7 @@ def main( per_host_batch_size = global_batch_size // num_processes - # === Distributed embeddings' configs for sparse features === + # === Distributed embeddings' configs for lookup features === feature_configs = {} for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] @@ -82,9 +82,9 @@ def main( ) # === Instantiate model === - # We instantiate the model first, because we need to preprocess sparse - # inputs using the distributed embedding layer defined inside the model - # class. + # We instantiate the model first, because we need to preprocess large + # embedding feature inputs using the distributed embedding layer defined + # inside the model class. print("===== Initialising model =====") model = DLRMDCNV2( large_emb_feature_configs=feature_configs, @@ -130,7 +130,7 @@ def main( def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses - sparse features. + large embedding features. 
""" for features, labels in dataset: preprocessed_large_embeddings = model.embedding_layer.preprocess( @@ -188,7 +188,7 @@ def generator(dataset, training=False): # Features label = ds_cfg["label"] dense_features = ds_cfg["dense"] - emb_features = ds_cfg["sparse"] + emb_features = ds_cfg["lookup"] # == Model config == model_cfg = config["model"] diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py index 713e02c1..4f84bbbf 100644 --- a/examples/ml_perf/model.py +++ b/examples/ml_perf/model.py @@ -56,7 +56,7 @@ def __init__( The model processes two types of input features: 1. Dense Features: Continuous-valued features that are processed by a multi-layer perceptron (the "bottom MLP"). - 2. Sparse Features: High-cardinality categorical features that are + 2. Lookup Features: High-cardinality categorical features that are first mapped into low-dimensional embedding vectors using the `keras_rs.layers.DistributedEmbedding` layer. This layer is highly optimized for large-scale recommendation models, especially on TPUs From a66e1c6fd15ecf12e7aa4f3db2a902bd4a2c54f1 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:04:33 +0530 Subject: [PATCH 35/40] Debug --- examples/ml_perf/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index a0d2eac7..2d7504e2 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -80,6 +80,8 @@ def main( input_shape=(per_host_batch_size, feature_list_length), output_shape=(per_host_batch_size, embedding_dim), ) + + print("-->", os.environ['XLA_FLAGS']) # === Instantiate model === # We instantiate the model first, because we need to preprocess large From a56532a196beaf88fffda3c9f6279b88e449da34 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:13:51 +0530 Subject: [PATCH 36/40] Try out XLA flags --- examples/ml_perf/main.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 2d7504e2..43e06320 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -45,6 +45,13 @@ def main( per_host_batch_size = global_batch_size // num_processes # === Distributed embeddings' configs for lookup features === + # Set XLA flags. + os.environ['XLA_FLAGS'] = ( + "--xla_sparse_core_max_ids_per_partition_per_sample=" + f"{max_ids_per_partition} " + "--xla_sparse_core_max_unique_ids_per_partition_per_sample=" + f"{max_unique_ids_per_partition}" + ) feature_configs = {} for large_emb_feature in large_emb_features: feature_name = large_emb_feature["new_name"] @@ -80,8 +87,6 @@ def main( input_shape=(per_host_batch_size, feature_list_length), output_shape=(per_host_batch_size, embedding_dim), ) - - print("-->", os.environ['XLA_FLAGS']) # === Instantiate model === # We instantiate the model first, because we need to preprocess large From 0a9d00b8df5667ff37094d603420c9b04e616c72 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:20:00 +0530 Subject: [PATCH 37/40] Format --- examples/ml_perf/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 43e06320..c085c1f0 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -46,7 +46,7 @@ def main( # === Distributed embeddings' configs for lookup features === # Set XLA flags. 
- os.environ['XLA_FLAGS'] = ( + os.environ["XLA_FLAGS"] = ( "--xla_sparse_core_max_ids_per_partition_per_sample=" f"{max_ids_per_partition} " "--xla_sparse_core_max_unique_ids_per_partition_per_sample=" From 42c40223e7ebd01d88dcc477884b6c5a869b301d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 20 Aug 2025 19:23:32 +0530 Subject: [PATCH 38/40] Format --- examples/ml_perf/main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index c085c1f0..8dbb38cb 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -135,6 +135,10 @@ def main( # eval_ds = distribution.distribute_dataset(eval_ds) distribution.auto_shard_dataset = False + # Print one sample. + for element in train_ds.take(1): + print(">>> train sample", element[0]) + def generator(dataset, training=False): """Converts tf.data Dataset to a Python generator and preprocesses large embedding features. From 6e76b59eb2f77f47763fbeb789aee74d3ee195a7 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 3 Sep 2025 19:57:13 +0530 Subject: [PATCH 39/40] Copy over Antonio's fixes --- .../embedding/jax/distributed_embedding.py | 204 +++--------------- .../layers/embedding/jax/embedding_utils.py | 90 ++++++-- 2 files changed, 102 insertions(+), 192 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 2562a8be..72f504af 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -15,7 +15,6 @@ table_stacking as jte_table_stacking, ) from jax_tpu_embedding.sparsecore.utils import utils as jte_utils -from keras.src import backend from keras_rs.src import types from keras_rs.src.layers.embedding import base_distributed_embedding @@ -247,23 +246,6 @@ def _create_sparsecore_distribution( ) return sparsecore_distribution, sparsecore_layout - def _create_cpu_distribution( - self, cpu_axis_name: str = "cpu" - ) -> tuple[ - keras.distribution.ModelParallel, keras.distribution.TensorLayout - ]: - """Share a variable across all CPU processes.""" - cpu_devices = jax.devices("cpu") - device_mesh = keras.distribution.DeviceMesh( - (len(cpu_devices),), [cpu_axis_name], cpu_devices - ) - replicated_layout = keras.distribution.TensorLayout([], device_mesh) - layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh) - cpu_distribution = keras.distribution.ModelParallel( - layout_map=layout_map - ) - return cpu_distribution, replicated_layout - def _add_sparsecore_weight( self, name: str, @@ -405,11 +387,6 @@ def sparsecore_build( self._sparsecore_layout = sparsecore_layout self._sparsecore_distribution = sparsecore_distribution - # Distribution for CPU operations. - cpu_distribution, cpu_layout = self._create_cpu_distribution() - self._cpu_distribution = cpu_distribution - self._cpu_layout = cpu_layout - mesh = sparsecore_distribution.device_mesh.backend_mesh global_device_count = mesh.devices.size num_sc_per_device = jte_utils.num_sparsecores_per_device( @@ -466,10 +443,6 @@ def sparsecore_build( # Collect all stacked tables. table_specs = embedding_utils.get_table_specs(feature_specs) table_stacks = embedding_utils.get_table_stacks(table_specs) - stacked_table_specs = { - stack_name: stack[0].stacked_table_spec - for stack_name, stack in table_stacks.items() - } # Create variables for all stacked tables and slot variables. 
with sparsecore_distribution.scope(): @@ -502,50 +475,6 @@ def sparsecore_build( ) self._iterations.overwrite_with_gradient = True - with cpu_distribution.scope(): - # Create variables to track static buffer size and max IDs for each - # table during preprocessing. These variables are shared across all - # processes on CPU. We don't add these via `add_weight` because we - # can't have them passed to the training function. - replicated_zeros_initializer = ShardedInitializer( - "zeros", cpu_layout - ) - - with backend.name_scope(self.name, caller=self): - self._preprocessing_buffer_size = { - table_name: backend.Variable( - initializer=replicated_zeros_initializer, - shape=(), - dtype=backend.standardize_dtype("int32"), - trainable=False, - name=table_name + ":preprocessing:buffer_size", - ) - for table_name in stacked_table_specs.keys() - } - self._preprocessing_max_unique_ids_per_partition = { - table_name: backend.Variable( - shape=(), - name=table_name - + ":preprocessing:max_unique_ids_per_partition", - initializer=replicated_zeros_initializer, - dtype=backend.standardize_dtype("int32"), - trainable=False, - ) - for table_name in stacked_table_specs.keys() - } - - self._preprocessing_max_ids_per_partition = { - table_name: backend.Variable( - shape=(), - name=table_name - + ":preprocessing:max_ids_per_partition", - initializer=replicated_zeros_initializer, - dtype=backend.standardize_dtype("int32"), - trainable=False, - ) - for table_name in stacked_table_specs.keys() - } - self._config = jte_embedding_lookup.EmbeddingLookupConfiguration( feature_specs, mesh=mesh, @@ -660,76 +589,35 @@ def _sparsecore_preprocess( mesh.devices.item(0) ) - # Get current buffer size/max_ids. - previous_max_ids_per_partition = keras.tree.map_structure( - lambda max_ids_per_partition: max_ids_per_partition.value.item(), - self._preprocessing_max_ids_per_partition, - ) - previous_max_unique_ids_per_partition = keras.tree.map_structure( - lambda max_unique_ids_per_partition: ( - max_unique_ids_per_partition.value.item() - ), - self._preprocessing_max_unique_ids_per_partition, - ) - previous_buffer_size = keras.tree.map_structure( - lambda buffer_size: buffer_size.value.item(), - self._preprocessing_buffer_size, - ) - preprocessed, stats = embedding_utils.stack_and_shard_samples( self._config.feature_specs, samples, local_device_count, global_device_count, num_sc_per_device, - static_buffer_size=previous_buffer_size, ) - # Extract max unique IDs and buffer sizes. - # We need to replicate this value across all local CPU devices. if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + prev_stats = embedding_utils.get_stacked_table_stats( + self._config.feature_specs + ) + + # Take the maximum with existing stats. + stats = keras.tree.map_structure(max, prev_stats, stats) + + # Flatten the stats so we can more efficiently transfer them + # between hosts. We use jax.tree because we will later need to + # unflatten. + flat_stats, stats_treedef = jax.tree.flatten(stats) + + # In the case of multiple local CPU devices per host, we need to + # replicate the stats to placate JAX collectives. num_local_cpu_devices = jax.local_device_count("cpu") - local_max_ids_per_partition = { - table_name: np.repeat( - # Maximum across all partitions and previous max. 
- np.maximum( - np.max(elems), - previous_max_ids_per_partition[table_name], - ), - num_local_cpu_devices, - ) - for table_name, elems in stats.max_ids_per_partition.items() - } - local_max_unique_ids_per_partition = { - name: np.repeat( - # Maximum across all partitions and previous max. - np.maximum( - np.max(elems), - previous_max_unique_ids_per_partition[name], - ), - num_local_cpu_devices, - ) - for name, elems in stats.max_unique_ids_per_partition.items() - } - local_buffer_size = { - table_name: np.repeat( - np.maximum( - np.max( - # Round values up to the next multiple of 8. - # Currently using this as a proxy for the actual - # required buffer size. - ((elems + 7) // 8) * 8 - ) - * global_device_count - * num_sc_per_device - * local_device_count - * num_sc_per_device, - previous_buffer_size[table_name], - ), - num_local_cpu_devices, - ) - for table_name, elems in stats.max_ids_per_partition.items() - } + tiled_stats = np.tile( + np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1) + ) # Aggregate variables across all processes/devices. max_across_cpus = jax.pmap( @@ -737,48 +625,24 @@ def _sparsecore_preprocess( x, "all_cpus" ), axis_name="all_cpus", - devices=self._cpu_layout.device_mesh.backend_mesh.devices, - ) - new_max_ids_per_partition = max_across_cpus( - local_max_ids_per_partition - ) - new_max_unique_ids_per_partition = max_across_cpus( - local_max_unique_ids_per_partition + backend="cpu", ) - new_buffer_size = max_across_cpus(local_buffer_size) - - # Assign new preprocessing parameters. - with self._cpu_distribution.scope(): - # For each process, all max ids/buffer sizes are replicated - # across all local devices. Take the value from the first - # device. - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_max_ids_per_partition, - new_max_ids_per_partition, - ) - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_max_unique_ids_per_partition, - new_max_unique_ids_per_partition, - ) - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_buffer_size, - new_buffer_size, - ) - # Update parameters in the underlying feature specs. - int_max_ids_per_partition = keras.tree.map_structure( - lambda varray: varray.item(), new_max_ids_per_partition - ) - int_max_unique_ids_per_partition = keras.tree.map_structure( - lambda varray: varray.item(), - new_max_unique_ids_per_partition, + flat_stats = max_across_cpus(tiled_stats)[0].tolist() + stats = jax.tree.unflatten(stats_treedef, flat_stats) + + # Update configuration and repeat preprocessing if stats changed. + if stats != prev_stats: + embedding_utils.update_stacked_table_stats( + self._config.feature_specs, stats ) - embedding_utils.update_stacked_table_specs( + + # Re-execute preprocessing with consistent input statistics. 
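# Illustrative sketch (not part of the patch): the cross-host max-reduction
# used above, in isolation. Each process contributes one row per local CPU
# device; `pmax` over the named axis returns the elementwise maximum of the
# flattened stats to every device, so all hosts end up with identical specs.
import jax
import numpy as np

local_flat_stats = np.array([3, 5, 7], dtype=np.int32)  # toy per-host stats
tiled = np.tile(local_flat_stats, (jax.local_device_count("cpu"), 1))
max_across_cpus = jax.pmap(
    lambda x: jax.lax.pmax(x, "all_cpus"), axis_name="all_cpus", backend="cpu"
)
flat_max = max_across_cpus(tiled)[0]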
+ preprocessed, _ = embedding_utils.stack_and_shard_samples( self._config.feature_specs, - int_max_ids_per_partition, - int_max_unique_ids_per_partition, + samples, + local_device_count, + global_device_count, + num_sc_per_device, ) return {"inputs": preprocessed} diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py index 393c197c..38e69f7d 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_utils.py +++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py @@ -35,6 +35,12 @@ class ShardedCooMatrix(NamedTuple): values: ArrayLike +class InputStatsPerTable(NamedTuple): + max_ids_per_partition: int + max_unique_ids_per_partition: int + required_buffer_size_per_device: int + + def _round_up_to_multiple(value: int, multiple: int) -> int: return ((value + multiple - 1) // multiple) * multiple @@ -335,19 +341,47 @@ def get_table_stacks( return stacked_table_specs -def update_stacked_table_specs( +def get_stacked_table_stats( feature_specs: Nested[FeatureSpec], - max_ids_per_partition: Mapping[str, int], - max_unique_ids_per_partition: Mapping[str, int], +) -> dict[str, InputStatsPerTable]: + """Extracts the stacked-table input statistics from the feature specs. + + Args: + feature_specs: Feature specs from which to extracts the statistics. + + Returns: + A mapping of stacked table names to input statistics per table. + """ + stacked_table_specs: dict[str, StackedTableSpec] = {} + for feature_spec in jax.tree.flatten(feature_specs)[0]: + feature_spec = typing.cast(FeatureSpec, feature_spec) + stacked_table_spec = typing.cast( + StackedTableSpec, feature_spec.table_spec.stacked_table_spec + ) + stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec + + stats: dict[str, InputStatsPerTable] = {} + for stacked_table_spec in stacked_table_specs.values(): + buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device + buffer_size = buffer_size or 0 + stats[stacked_table_spec.stack_name] = InputStatsPerTable( + max_ids_per_partition=stacked_table_spec.max_ids_per_partition, + max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition, + required_buffer_size_per_device=buffer_size, + ) + + return stats + + +def update_stacked_table_stats( + feature_specs: Nested[FeatureSpec], + stats: Mapping[str, InputStatsPerTable], ) -> None: - """Updates properties in the supplied feature specs. + """Updates stacked-table input properties in the supplied feature specs. Args: feature_specs: Feature specs to update in-place. - max_ids_per_partition: Mapping of table stack name to - new `max_ids_per_partition` for the stack. - max_unique_ids_per_partition: Mapping of table stack name to - new `max_unique_ids_per_partition` for the stack. + stats: Per-stacked-table input statistics. """ # Collect table specs and stacked table specs. table_specs: dict[str, TableSpec] = {} @@ -363,18 +397,17 @@ def update_stacked_table_specs( stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec # Replace fields in the stacked_table_specs. 
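# Illustrative sketch (not part of the patch): the immutable-update pattern
# used below. The stacked table specs are dataclasses, so updated copies are
# produced with `dataclasses.replace` rather than in-place mutation; a toy
# frozen dataclass stands in here.
import dataclasses

@dataclasses.dataclass(frozen=True)
class ToySpec:
    max_ids_per_partition: int
    suggested_coo_buffer_size_per_device: int | None = None

spec = ToySpec(max_ids_per_partition=256)
spec = dataclasses.replace(spec, max_ids_per_partition=512)
assert spec.max_ids_per_partition == 512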
- stacked_table_specs = { - stack_name: dataclasses.replace( + stack_names = stacked_table_specs.keys() + for stack_name in stack_names: + stack_stats = stats[stack_name] + stacked_table_spec = stacked_table_specs[stack_name] + buffer_size = stack_stats.required_buffer_size_per_device or None + stacked_table_specs[stack_name] = dataclasses.replace( stacked_table_spec, - max_ids_per_partition=max_ids_per_partition[ - stacked_table_spec.stack_name - ], - max_unique_ids_per_partition=max_unique_ids_per_partition[ - stacked_table_spec.stack_name - ], + max_ids_per_partition=stack_stats.max_ids_per_partition, + max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition, + suggested_coo_buffer_size_per_device=buffer_size, ) - for stack_name, stacked_table_spec in stacked_table_specs.items() - } # Insert new stacked tables into tables. for table_spec in table_specs.values(): @@ -534,7 +567,7 @@ def stack_and_shard_samples( global_device_count: int, num_sc_per_device: int, static_buffer_size: int | Mapping[str, int] | None = None, -) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]: +) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]: """Prepares input samples for use in embedding lookups. Args: @@ -544,8 +577,8 @@ def stack_and_shard_samples( global_device_count: Number of global JAX devices. num_sc_per_device: Number of sparsecores per device. static_buffer_size: The static buffer size to use for the samples. - Defaults to None, in which case an upper-bound for the buffer size - will be automatically determined. + Defaults to None, in which case an upper-bound for the buffer size + will be automatically determined. Returns: The preprocessed inputs, and statistics useful for updating FeatureSpecs @@ -579,6 +612,7 @@ def collect_tokens_and_weights( ) out: dict[str, ShardedCooMatrix] = {} + out_stats: dict[str, InputStatsPerTable] = {} tables_names = preprocessed_inputs.lhs_row_pointers.keys() for table_name in tables_names: shard_ends = preprocessed_inputs.lhs_row_pointers[table_name] @@ -592,5 +626,17 @@ def collect_tokens_and_weights( row_ids=preprocessed_inputs.lhs_sample_ids[table_name], values=preprocessed_inputs.lhs_gains[table_name], ) + out_stats[table_name] = InputStatsPerTable( + max_ids_per_partition=np.max( + stats.max_ids_per_partition[table_name] + ), + max_unique_ids_per_partition=np.max( + stats.max_unique_ids_per_partition[table_name] + ), + required_buffer_size_per_device=np.max( + stats.required_buffer_size_per_sc[table_name] + ) + * num_sc_per_device, + ) - return out, stats + return out, out_stats From ed0372761c16965de3d20e4609ed9c04b1781764 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 10 Sep 2025 11:09:37 +0530 Subject: [PATCH 40/40] Change dataloader to global bsz --- examples/ml_perf/main.py | 2 +- examples/ml_perf/run.sh | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py index 8dbb38cb..3f661a8c 100644 --- a/examples/ml_perf/main.py +++ b/examples/ml_perf/main.py @@ -115,7 +115,7 @@ def main( print("===== Loading dataset =====") train_ds = DataLoader( file_pattern=file_pattern, - batch_size=per_host_batch_size, + batch_size=global_batch_size, file_batch_size=file_batch_size, dense_features=dense_features, large_emb_features=large_emb_features, diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh index 122dd92c..7a774221 100644 --- a/examples/ml_perf/run.sh +++ b/examples/ml_perf/run.sh @@ -101,7 +101,7 @@ gcloud alpha 
compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="sudo apt-get update && sudo apt install -y python3.10-venv && if [ ! -d '.keras-env' ]; then echo '>>> Creating .keras-env...'; python3 -m venv .keras-env; else echo '>>> .keras-env already exists.'; fi" + --command="sudo apt-get update && sudo apt install -y python3.12-venv && if [ ! -d '.keras-env' ]; then echo '>>> Creating .keras-env...'; python3.12 -m venv .keras-env; else echo '>>> .keras-env already exists.'; fi" # ============================================================================== @@ -155,7 +155,7 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python script.py" + --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python3.12 script.py" # ============================================================================== @@ -166,6 +166,6 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT} \ --zone ${ZONE} \ --worker=all \ - --command="source .keras-env/bin/activate && cd keras-rs && python3 -m examples.ml_perf.main --config_name ${CONFIG_NAME}" + --command="source .keras-env/bin/activate && cd keras-rs && python3.12 -m examples.ml_perf.main --config_name ${CONFIG_NAME}" echo ">>> Script finished."
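
The sketch below is not part of any patch above; it is a minimal, single-process illustration of how the new `embedding_utils` statistics helpers (`get_stacked_table_stats`, `update_stacked_table_stats`, `InputStatsPerTable`, and the revised `stack_and_shard_samples` return value) are intended to fit together. It assumes `feature_specs` and `samples` are already built; the wrapper name, import path, and device-count arguments are placeholders, and the real `_sparsecore_preprocess` additionally aggregates the observed statistics across hosts with `jax.pmap` before updating the specs.

# Usage sketch (illustrative only; names below are placeholders).
from keras_rs.src.layers.embedding.jax import embedding_utils


def preprocess_with_stats_sync(
    feature_specs, samples, local_device_count, global_device_count, num_sc_per_device
):
    # Limits currently recorded in the stacked-table feature specs.
    prev_stats = embedding_utils.get_stacked_table_stats(feature_specs)

    # First pass: stack/shard the samples and observe the stats they require.
    preprocessed, observed = embedding_utils.stack_and_shard_samples(
        feature_specs,
        samples,
        local_device_count,
        global_device_count,
        num_sc_per_device,
    )

    # Grow the recorded limits to cover what was observed (elementwise max).
    new_stats = {
        name: embedding_utils.InputStatsPerTable(
            max_ids_per_partition=max(
                prev.max_ids_per_partition,
                observed[name].max_ids_per_partition,
            ),
            max_unique_ids_per_partition=max(
                prev.max_unique_ids_per_partition,
                observed[name].max_unique_ids_per_partition,
            ),
            required_buffer_size_per_device=max(
                prev.required_buffer_size_per_device,
                observed[name].required_buffer_size_per_device,
            ),
        )
        for name, prev in prev_stats.items()
    }

    # If the limits changed, write them back into the feature specs and
    # re-run preprocessing so the output matches the updated specs.
    if new_stats != prev_stats:
        embedding_utils.update_stacked_table_stats(feature_specs, new_stats)
        preprocessed, _ = embedding_utils.stack_and_shard_samples(
            feature_specs,
            samples,
            local_device_count,
            global_device_count,
            num_sc_per_device,
        )

    return {"inputs": preprocessed}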