From cd7dfeaa94182e5f2ecbf73d97d5e00039d14d83 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 11:57:27 +0000 Subject: [PATCH 01/92] Add gfx950 build support + fp16 fix + index type fix --- fbgemm_gpu/cmake/Hip.cmake | 8 ++++++++ .../embedding_backward_split_template.cu | 2 +- ..._backward_split_device_kernel_template.hip | 2 +- .../include/fbgemm_gpu/rocm/cdna_guard.h | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 20 ++++++++++++++++++- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 2 +- 6 files changed, 31 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 17640b7254..2011a34c33 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,6 +78,14 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) + list(APPEND HIP_CXX_FLAGS -g) + list(APPEND HIP_CXX_FLAGS -ggdb) + + # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) + #list(APPEND HIP_CXX_FLAGS -adhln) + list(APPEND HIP_CXX_FLAGS -save-temps) + list(APPEND HIP_CXX_FLAGS -fverbose-asm) + set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76eba64c99..76a2b347d8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1193,7 +1193,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..5acc61382e 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -179,7 +179,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..c96da01063 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -215,6 +215,24 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } + +}; + + template < typename emb_t, int32_t embedding_dim, @@ -471,7 +489,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index dfea2dce8a..361059020e 100644 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - std::execution::par, + // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From 602b7bfa12525434295d69a1cad89c9732b96061 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 13:16:41 +0000 Subject: [PATCH 02/92] Change int64_t to index_t as template parameters in load_raw_per_warp --- .../rocm/embedding_backward_split_device_kernel_template.hip | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 5acc61382e..d5841d6e00 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -452,7 +452,7 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); From a587e06ddb5b3b0a06397a94ac7907adcad4b5a8 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 14:39:22 +0000 Subject: [PATCH 03/92] Implement llvm fp16 buffer load for gfx950 --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c96da01063..4b33fd1422 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -60,7 +60,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, @@ -72,7 +77,12 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, From 48a10bf2c6c1f73a8cb55669f72b2a9e72f2a51b Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:23:47 +0000 Subject: [PATCH 04/92] Fix c-style half to float cast --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 4b33fd1422..238a83440a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -261,7 +261,14 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } From d4acaba5d865b555436015a706f962cb893ecec8 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:24:29 +0000 Subject: [PATCH 05/92] Patch 256 half stores --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 238a83440a..974eae2594 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -294,6 +294,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + } +}; + + template <> struct store_row_per_warp { static __device__ void run(float* acc, float* p_output, int lane_id) { From a6636f0dddf6e9a374151b56ef551429928aa3af Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Fri, 8 Aug 2025 05:02:58 +0000 Subject: [PATCH 06/92] cta_per_row workgroup optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76a2b347d8..9412edc1a5 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1042,7 +1042,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,7 +1053,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, (kMaxThreads/2)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From a15fb0900a18cac9e5254b2c23159270c3d97ac9 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 11 Aug 2025 21:06:48 +0000 Subject: [PATCH 07/92] Added mi350 guards --- ...ding_backward_split_indice_weights_template.cu | 15 ++++++++++++++- .../backward/embedding_backward_split_template.cu | 10 ++++++++++ .../forward/embedding_forward_split_template.cu | 14 ++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 6d38d1d99a..9e1f71ef4e --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -359,7 +363,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); - + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 9412edc1a5..9e9e7aac68 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -652,6 +652,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 6574bda45e..bbd62a8bbc --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -454,6 +458,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} From 6af95e05fbc3bc901f9bc46a9eb6dcfb7c815988 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 15:09:39 +0000 Subject: [PATCH 08/92] Fix index overflow in row load --- ..._backward_split_device_kernel_template.hip | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d5841d6e00..d1a874805a 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -227,7 +227,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -236,7 +236,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -250,7 +250,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -290,7 +290,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -301,7 +301,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -328,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -337,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -352,7 +352,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -363,7 +363,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -383,7 +383,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -394,7 +394,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -420,7 +420,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -441,7 +441,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); From be5f1b8666d6f98a15ab086be55e4e8d390de1ad Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 20:13:09 +0000 Subject: [PATCH 09/92] cta_per_row workgroup reduce by 4 optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 9e9e7aac68..c59f6fe9aa 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1052,7 +1052,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1063,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/2)), + div_round_up(total_unique_indices, (kMaxThreads/4)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From acef908f0ec48e34911b7cadb83a81b003a9ddfa Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 13 Aug 2025 13:21:38 +0000 Subject: [PATCH 10/92] Fix mixed_D frontend to backend connection --- .../training/backward/embedding_backward_split_template.cu | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 1 + .../split_table_batched_embeddings_ops_training.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c59f6fe9aa..c8a846a552 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1203,7 +1203,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 3720f1ea42..20c055e917 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -698,6 +698,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} // Default values for Dynamo tracing diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index a572de0738..a0bc843902 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -813,7 +813,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2505,6 +2505,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2523,6 +2524,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2532,6 +2534,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) From 33f4ad96584936cedd8233fc43144422dee7928b Mon Sep 17 00:00:00 2001 From: kudomcho Date: Fri, 15 Aug 2025 15:32:19 +0000 Subject: [PATCH 11/92] changed max_segment_length_per_cta to 4096 --- .../training/backward/embedding_backward_split_template.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c8a846a552..1ddcea55b2 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { From aaf19666d909849a4e0cf050ad204839e8baa764 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Mon, 18 Aug 2025 22:32:58 +0000 Subject: [PATCH 12/92] added rocm guards and removed comment --- .../embedding_backward_split_template.cu | 19 ++++++++++++++++--- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 1 - 2 files changed, 16 insertions(+), 4 deletions(-) mode change 100644 => 100755 fbgemm_gpu/src/tbe/eeg/indices_generator.cpp diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 1ddcea55b2..099c7e5685 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #ifdef USE_ROCM + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #else + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + #endif Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { @@ -1052,7 +1056,11 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #ifdef USE_ROCM + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #else + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1071,12 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/4)), + #ifdef USE_ROCM + div_round_up(total_unique_indices, (kMaxThreads/4)), + #else + div_round_up(total_unique_indices, kMaxThreads), + #endif + get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755 index 361059020e..715acd8c0c --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,6 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From 48e7f97b19aabcf26e87e7c14f7c9c705138e37f Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 20 Aug 2025 03:00:56 +0000 Subject: [PATCH 13/92] clean debug statements in Hip.cmake --- fbgemm_gpu/cmake/Hip.cmake | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 2011a34c33..17640b7254 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,14 +78,6 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) - list(APPEND HIP_CXX_FLAGS -g) - list(APPEND HIP_CXX_FLAGS -ggdb) - - # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) - #list(APPEND HIP_CXX_FLAGS -adhln) - list(APPEND HIP_CXX_FLAGS -save-temps) - list(APPEND HIP_CXX_FLAGS -fverbose-asm) - set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use From 750bee4ab280a186893c4db29ed35fdff2b5e8d9 Mon Sep 17 00:00:00 2001 From: Shreya Date: Thu, 28 Aug 2025 11:43:32 -0500 Subject: [PATCH 14/92] Merge pull request #121 warp per row wg change --- .../embedding_backward_split_template.cu | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 099c7e5685..2425322948 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1056,10 +1056,21 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); + int32_t total_L = indices.numel(); #ifdef USE_ROCM - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1){ + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else{ + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } #else int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t work_group_size = kMaxThreads; #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, @@ -1071,17 +1082,13 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - #ifdef USE_ROCM - div_round_up(total_unique_indices, (kMaxThreads/4)), - #else - div_round_up(total_unique_indices, kMaxThreads), - #endif - + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, + // (64, 2) dim3(kThreadGroupSize, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), @@ -1185,7 +1192,18 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #ifdef USE_ROCM + int32_t num_warp_per_row_groups; + + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + #else + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #endif int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { From f0acbc3cdce30702c2f9d136af9b8493064a516f Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 2 Sep 2025 09:25:03 +0000 Subject: [PATCH 15/92] Guard f16 llvm intrinsics with ROCm >=7.0 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 974eae2594..46c4603381 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,6 +24,7 @@ #include #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); From 0ee2366b676f8dae828cfecba0ee850c4d457f10 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 18 Sep 2025 16:28:31 +0000 Subject: [PATCH 16/92] fix the bug in dimention 160 in ROCm optimization --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 46c4603381..8a97579d6a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -165,7 +165,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { From e33120d2eb73fd3a8a14cabd0b0de71a1f61f43b Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 19 Aug 2025 13:41:17 +0000 Subject: [PATCH 17/92] Cleanup optimized warp_per_raw kernel --- fbgemm_gpu/cmake/tbe_sources.py | 2 - .../genscript/generate_backward_split.py | 10 +- ...ing_backward_split_kernel_warp_template.cu | 40 +++----- .../embedding_backward_split_template.cu | 18 ++-- ..._backward_split_device_kernel_template.hip | 94 +++++-------------- 5 files changed, 54 insertions(+), 110 deletions(-) diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 82092cc173..b38f862564 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a817232910..5acb6f2e7f 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..1158721526 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,14 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +620,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - + auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +636,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,22 +648,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, @@ -680,16 +667,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 2425322948..fb125101e7 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,15 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -862,8 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1226,8 +1233,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d1a874805a..951cff4399 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -221,21 +208,15 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ @@ -244,23 +225,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,23 +261,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,21 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} + table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); @@ -346,23 +314,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,23 +341,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,12 +374,9 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -435,12 +392,9 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( From 3447ef0f889bc772891f1fa0f7b8b0fad767cd8a Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 20 Aug 2025 12:15:37 +0000 Subject: [PATCH 18/92] Add 320 embedding dim support for optimized warp_per_row kernel --- ...ing_backward_split_kernel_warp_template.cu | 2 +- .../embedding_backward_split_template.cu | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 1158721526..e61b3fc0aa 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -766,7 +766,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fb125101e7..7eb2b6880f 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1243,7 +1243,7 @@ Tensor {{ embedding_cuda_op }}( if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 8a97579d6a..5b9d69d910 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -205,6 +205,22 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + } +}; + template struct load_row_per_warp { static __device__ void @@ -304,6 +320,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + p_output[lane_id + 256] = acc[4]; + } +}; + template <> struct store_row_per_warp { From a1361ab4f62acb0df9051031ee1adb10c8af344f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Sep 2025 19:34:16 +0000 Subject: [PATCH 19/92] changed the max length per warp and cta per row WG size --- .../backward/embedding_backward_split_host_template.cpp | 2 +- .../training/backward/embedding_backward_split_template.cu | 6 +----- .../training/index_select/batch_index_select_dim0_host.cpp | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 2 +- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index e071d88768..6d3769534e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -949,7 +949,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 7eb2b6880f..86d4ce8b8b 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -987,11 +987,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - #ifdef USE_ROCM - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; - #else - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - #endif + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 18378b6106..00673abc8b 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 4096; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 20c055e917..46384be1bb 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1006,7 +1006,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From 9c2fd1d37a0f1167a6fe865da9c3d60dee5a872f Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 9 Sep 2025 20:25:30 +0000 Subject: [PATCH 20/92] added DPP and changed max length per warp to 16k --- .../embedding_backward_split_host_template.cpp | 2 +- .../index_select/batch_index_select_dim0_host.cpp | 4 ++-- .../embedding_split_host_pt2_autograd_template.cpp | 2 +- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 14 ++++++++------ 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 6d3769534e..05b93d9d7e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -949,7 +949,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 00673abc8b..02529f2d89 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -342,7 +342,7 @@ class BatchIndexSelectDim0GPUOp Tensor grad_dev_weights; TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 46384be1bb..8fb2cdf2ed 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1006,7 +1006,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index 0d65c4798a..a1d9819017 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,11 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); } DEVICE_INLINE void syncwarp() { From 54690c9f6d952531a4dbce5649ef7348a13492e0 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 19:33:44 +0000 Subject: [PATCH 21/92] guard max segment warp based on emb dim --- ...dding_split_host_pt2_autograd_template.cpp | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 8fb2cdf2ed..c587ccb83a 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1006,7 +1006,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 16384; + int32_t max_segment_length_per_warp = 64; + // Workaround. Should not be upstreamed in any way. + // Redistribute all cta_per_row work to warp_per_row. + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + const auto T = weights_offsets.sym_numel(); + const auto B = (offsets.size(0) - 1) / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From d666611892106d3c8d104a8f274da8d4005adf32 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 22:00:20 +0000 Subject: [PATCH 22/92] added guarding opt of max segment for the case batch size list=1 --- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index c587ccb83a..fa6a27ab55 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1009,6 +1009,7 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. + int32_t total_L = indices.numel(); {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and @@ -1017,9 +1018,10 @@ static torch::autograd::variable_list backward( (not is_index_select) and (not dense) %} const auto T = weights_offsets.sym_numel(); - const auto B = (offsets.size(0) - 1) / T; + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && (max_D == {{ kDimSize }})) + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } From df863d02ccd522ad16d9e860153278e77b0220b1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Sep 2025 09:26:57 +0000 Subject: [PATCH 23/92] opt for grad_indice_weights kernel --- ..._backward_split_indice_weights_template.cu | 77 ++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 9e1f71ef4e..b30e3e5c77 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -214,7 +214,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void ) {%- endif %} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + int32_t j = 0; + {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + if (placement == PlacementType::MANAGED_CACHING) { + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + } else { + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + } + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + {%- endif %} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); From e0bee9fc762fe7ef20dab03907f999ef291332a9 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 23 Sep 2025 02:09:26 +0000 Subject: [PATCH 24/92] added store row per warp on emb 192 and added accuracy test functionality --- ...plit_table_batched_embeddings_benchmark.py | 223 +++++++++++++----- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 125 ++++++++-- .../fbgemm_gpu/rocm/split_embeddings_common.h | 18 +- 3 files changed, 277 insertions(+), 89 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..3fad8f53fe 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -7,7 +7,8 @@ # pyre-strict - +import gzip +import yaml import logging import os import tempfile @@ -1011,7 +1012,15 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) +@click.pass_context def device_with_spec( # noqa C901 + ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1031,7 +1040,39 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, ) -> None: + if load: + with open(f"{load}/params.yaml", "r") as f: + ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) + alpha = ctx.params["alpha"] + bag_size_list = ctx.params["bag_size_list"] + bag_size_sigma_list = ctx.params["bag_size_sigma_list"] + batch_size = ctx.params["batch_size"] + embedding_dim_list = ctx.params["embedding_dim_list"] + weights_precision = ctx.params["weights_precision"] + cache_precision = ctx.params["cache_precision"] + stoc = ctx.params["stoc"] + iters = ctx.params["iters"] + warmup_runs = ctx.params["warmup_runs"] + managed = ctx.params["managed"] + num_embeddings_list = ctx.params["num_embeddings_list"] + reuse = ctx.params["reuse"] + row_wise = ctx.params["row_wise"] + weighted = ctx.params["weighted"] + pooling = ctx.params["pooling"] + bounds_check_mode = ctx.params["bounds_check_mode"] + flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] + output_dtype = ctx.params["output_dtype"] + random_weights = ctx.params["random_weights"] + compressed = ctx.params["compressed"] + slice_min = ctx.params["slice_min"] + slice_max = ctx.params["slice_max"] np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1040,6 +1081,11 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" + params = ctx.params + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1118,6 +1164,22 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) + elif random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") + + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1130,53 +1192,68 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - - del all_requests - + + if load: + requests = [] + for i in range(iters): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") + requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) + else: + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + del all_requests assert len(requests) == iters - + if save: + for i in range(iters): + req = requests[i] + torch.save(req.indices, f"{save}/{i}_indices.pt") + torch.save(req.offsets, f"{save}/{i}_offsets.pt") + torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") + torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") + sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1203,34 +1280,44 @@ def device_with_spec( # noqa C901 # forward time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) + if output_dtype == SparseType.INT8: # backward bench not representative return - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if load: + grad_output = torch.load(f"{load}/grad_output.pt") else: # Obtain B * L from indices len # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + + if save: + torch.save(grad_output, f"{save}/grad_output.pt") # backward time_per_iter = benchmark_requests( requests, @@ -1244,6 +1331,12 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1243f14db4..1bda3188e5 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,6 +11,7 @@ import statistics import threading import time +import gzip from subprocess import Popen from typing import Callable, Optional @@ -18,7 +19,7 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device - +from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen logging.basicConfig(level=logging.DEBUG) @@ -241,36 +242,43 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - + import copy if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + if not (load or save): + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - + sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -278,7 +286,86 @@ def benchmark_requests( # noqa: C901 else: start_events = [] end_events = [] + if save and emb: + for it in range(iters): + req = requests[it % num_reqs] + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + if compressed: + with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: + torch.save(out, f) + else: + torch.save(out, f"{save}/{it}_fwd_grad_out.pt") + + out.backward(grad) + torch.cuda.synchronize() + torch.save(out, f"{save}/{it}_bwd_grad_out.pt") + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref, atol=1.0e-4, rtol=1.0e-4) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref, atol=1.0e-4, rtol=1.0e-4) + print("PASS") for it in range(iters): req = requests[it % num_reqs] diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 5b9d69d910..745499ac08 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,7 +24,6 @@ #include #include #include -#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -62,7 +61,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -79,7 +78,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); @@ -165,7 +164,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 160); + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { @@ -320,6 +319,15 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + *(reinterpret_cast(&out[64]) + lane_id) = *reinterpret_cast(acc + 2); + } +}; + template <> struct store_row_per_warp { static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { @@ -619,4 +627,4 @@ __device__ inline void magic_div_u32_run_with_mod( quo = magic_div_u32_run(mdiv, n); rem = n - quo * d; } -} // namespace fbgemm_gpu::rocm +} // namespace fbgemm_gpu::rocm \ No newline at end of file From ca829505bc8ada311198df97baa7e29df169672b Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 22 Sep 2025 16:09:05 +0000 Subject: [PATCH 25/92] workgroup tuning and loop unrolled --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 index a39d33e391..25aca3506b --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -469,10 +469,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -641,7 +641,7 @@ batch_index_select_dim0_codegen_forward_kernel( {%- endif %} {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 4; + constexpr int kManualUnrollLength = 8; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 7ad444bd195b2d8a2b56ad0d6143e992c09c1215 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Sep 2025 22:38:17 +0200 Subject: [PATCH 26/92] specialize --- ..._backward_split_indice_weights_template.cu | 145 ++++++++++++------ 1 file changed, 95 insertions(+), 50 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index b30e3e5c77..0052d96406 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -217,33 +217,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void int32_t j = 0; {%- if not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only - for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { - const auto offset_idx_j0 = shfl_sync(offset_idx, j); - const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); - const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); - const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); - - const auto cache_idx_j0 = shfl_sync(cache_idx, j); - const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); - const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); - const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); - - at::acc_type grad_indice_weight0 = 0.0; - at::acc_type grad_indice_weight1 = 0.0; - at::acc_type grad_indice_weight2 = 0.0; - at::acc_type grad_indice_weight3 = 0.0; - - [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); - [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); - [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); - [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { - const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - Vec4T> weight0, weight1, weight2, weight3; - if (placement == PlacementType::MANAGED_CACHING) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; weight0 = (cache_idx_j0 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : weight_row0.load(d); @@ -259,33 +308,29 @@ __global__ __launch_bounds__(kForwardMaxThreads) void weight3 = (cache_idx_j3 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : weight_row3.load(d); - } else { - weight0 = weight_row0.load(d); - weight1 = weight_row1.load(d); - weight2 = weight_row2.load(d); - weight3 = weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; } - grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + - weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; - grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + - weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; - grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + - weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; - grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + - weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; - } - - grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); - grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); - grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); - grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - if (threadIdx.x == 0) { - grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; - grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; - grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; - grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } } } {%- endif %} @@ -447,7 +492,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( TORCH_WARN_ONCE("Running on CDNA architecture"); } #endif - + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] From 970229ba828aa19a5bfbf9c05791adecefd3bc4c Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 24 Sep 2025 00:48:35 +0000 Subject: [PATCH 27/92] explicitly link to tbb --- cmake/modules/CppLibrary.cmake | 12 ++++++++++++ cmake/modules/GpuCppLibrary.cmake | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 92a93a60b6..388d3ac779 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,6 +168,18 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + target_link_libraries(${lib_name} PUBLIC TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + target_link_libraries(${lib_name} PUBLIC ${TBB_LIB}) + endif() + endif() + # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index 51c30df750..e662848348 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,6 +302,18 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + list(APPEND library_dependencies TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + list(APPEND library_dependencies ${TBB_LIB}) + endif() + endif() + # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies}) From 539985c2b18521610d79fb222514bfd691e69b82 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Thu, 25 Sep 2025 19:00:23 +0000 Subject: [PATCH 28/92] added warpReduceAllSum with rocm guards --- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) mode change 100644 => 100755 fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh old mode 100644 new mode 100755 index a1d9819017..d51e3fa475 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -140,11 +140,19 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { - return rocm::wave_reduce< - rocm::reduce_op::sum, // Sum reduction - T, // Data type - ReduceWidth // Wave/Warp size - >(val); + #ifdef USE_ROCM + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); + #else + #pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); + } + return val; + #endif } DEVICE_INLINE void syncwarp() { From e3d477397b0af79048be7541e1c7dbfc069b29db Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 13 Oct 2025 20:34:59 +0000 Subject: [PATCH 29/92] revert unroll and wg tuning --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 25aca3506b..4dd60af489 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -469,10 +469,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -641,7 +641,7 @@ batch_index_select_dim0_codegen_forward_kernel( {%- endif %} {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 8; + constexpr int kManualUnrollLength = 4; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 9505ffe1da60d8f8582d12b13a0de44928ac1dfe Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 13 Oct 2025 15:46:07 -0500 Subject: [PATCH 30/92] Minor update embedding_forward_split_kernel_template.cu --- .../forward/embedding_forward_split_kernel_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 4dd60af489..a39d33e391 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -469,10 +469,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { + for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group From 8709307b06d590a4f0bbcaec5cac68dea9da44cc Mon Sep 17 00:00:00 2001 From: Li Li Date: Fri, 17 Oct 2025 21:17:37 +0000 Subject: [PATCH 31/92] add tbb-devel to the install_build_tools () --- .github/scripts/utils_build.bash | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 82fa3e26a2..709e7b62f4 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -370,6 +370,7 @@ install_build_tools () { patchelf \ rhash \ scikit-build \ + tbb-devel \ tbb \ wheel \ xz \ From 6a3d3cba7da38c1c6959f0b987a31d33e626a1c5 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 21 Oct 2025 18:54:56 +0000 Subject: [PATCH 32/92] fix lint issues --- ...plit_table_batched_embeddings_benchmark.py | 35 +++++++++---------- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 4 +-- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 3fad8f53fe..02d3820b07 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -1192,7 +1192,6 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: requests = [] for i in range(iters): @@ -1253,7 +1252,6 @@ def device_with_spec( # noqa C901 torch.save(req.offsets, f"{save}/{i}_offsets.pt") torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") - sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1280,23 +1278,22 @@ def device_with_spec( # noqa C901 # forward time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) - + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative @@ -1315,7 +1312,7 @@ def device_with_spec( # noqa C901 # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - + if save: torch.save(grad_output, f"{save}/grad_output.pt") # backward diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1bda3188e5..22c79ed807 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -298,11 +298,11 @@ def benchmark_requests( # noqa: C901 torch.save(out, f) else: torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - + out.backward(grad) torch.cuda.synchronize() torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - + if sliced: for id, t in enumerate(emb.split_embedding_weights()): if compressed: From 6351c434e5a127739f51f4cbbca2edc8fdb8b131 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 21 Oct 2025 21:23:38 +0000 Subject: [PATCH 33/92] solve lint issues --- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 18 +++++++----------- .../fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 22c79ed807..591ce4a6c4 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -307,14 +307,13 @@ def benchmark_requests( # noqa: C901 for id, t in enumerate(emb.split_embedding_weights()): if compressed: with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max,:].clone(), f) + torch.save(t[slice_min:slice_max, :].clone(), f) else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - else: if compressed: with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: @@ -325,11 +324,9 @@ def benchmark_requests( # noqa: C901 if load and emb: for it in range(iters): req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() out = emb(indices, offsets, weights) torch.cuda.synchronize() - out.backward(grad) torch.cuda.synchronize() emb_ref = copy.deepcopy(emb) @@ -339,8 +336,8 @@ def benchmark_requests( # noqa: C901 emb_ref.load_state_dict(torch.load(f)) else: emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: for id, t in enumerate(emb.split_embedding_weights()): if compressed: @@ -348,15 +345,15 @@ def benchmark_requests( # noqa: C901 w_ref = torch.load(f) else: w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + torch.testing.assert_close(t[slice_min:slice_max, :], w_ref, msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) else: for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) print("PASS") - print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") @@ -368,7 +365,6 @@ def benchmark_requests( # noqa: C901 print("PASS") for it in range(iters): req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() if bwd_only: # Run forward before profiling if does backward only diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 745499ac08..aa869fe2b5 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -627,4 +627,4 @@ __device__ inline void magic_div_u32_run_with_mod( quo = magic_div_u32_run(mdiv, n); rem = n - quo * d; } -} // namespace fbgemm_gpu::rocm \ No newline at end of file +} // namespace fbgemm_gpu::rocm From 1e9b3f323b3c3ef9b3c2ed6084ff2316b18cc384 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 22 Oct 2025 18:45:41 +0000 Subject: [PATCH 34/92] applied jinja is_rocm onto optimizations for backward and forward parameters --- ..._backward_split_indice_weights_template.cu | 5 ++++- .../embedding_backward_split_template.cu | 22 ++++++++++--------- .../embedding_forward_split_template.cu | 4 ++-- .../batch_index_select_dim0_host.cpp | 4 ++-- ...dding_split_host_pt2_autograd_template.cpp | 4 ++++ 5 files changed, 24 insertions(+), 15 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 0052d96406..c58ba89f78 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -213,7 +213,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} - + {%- if is_rocm %} int32_t j = 0; {%- if not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only @@ -335,6 +335,9 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } {%- endif %} for (; j < kWarpSize && l_start + j < L; ++j) { + {%- else %} // if is_rocm + for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 86d4ce8b8b..759bbfd9bb 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -987,8 +987,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; - + {% if is_rocm %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + {% else %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + {%- endif %} Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { long_run_id_to_really_long_run_ids = @@ -1059,8 +1062,8 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t total_L = indices.numel(); - #ifdef USE_ROCM + {% if is_rocm %} + int32_t total_L = indices.numel(); int32_t num_cta_per_row_groups; int32_t work_group_size; if (total_L/total_B > 1){ @@ -1071,10 +1074,10 @@ Tensor {{ embedding_cuda_op }}( num_cta_per_row_groups = kMaxThreads / kWarpSize; work_group_size = kMaxThreads; } - #else + {%- else %} int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; int32_t work_group_size = kMaxThreads; - #endif + {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1091,7 +1094,6 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - // (64, 2) dim3(kThreadGroupSize, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), @@ -1195,7 +1197,7 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - #ifdef USE_ROCM + {%- if is_rocm %} int32_t num_warp_per_row_groups; if (total_L/total_B > 1){ @@ -1204,9 +1206,9 @@ Tensor {{ embedding_cuda_op }}( else{ num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; } - #else + {%- else %} int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - #endif + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index bbd62a8bbc..b6e9c94745 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -458,7 +458,7 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); - #ifdef USE_ROCM + {% if is_rocm %} if (!rocm::is_supported_cdna()) { TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); } @@ -466,7 +466,7 @@ batch_index_select_dim0_codegen_forward_cuda( // Ensure we're running on a supported CDNA architecture (including MI350) TORCH_WARN_ONCE("Running on CDNA architecture"); } - #endif + {%- endif %} {%- if not nobag %} int32_t T = D_offsets.numel() - 1; diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 02529f2d89..608f6017ec 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -342,7 +342,7 @@ class BatchIndexSelectDim0GPUOp Tensor grad_dev_weights; TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 16384; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 16384; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index fa6a27ab55..825681a57c 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -698,8 +698,10 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + {% if is_rocm %} const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} + {%- endif %} // Default values for Dynamo tracing // SymInt does not support bitshifts operator @@ -1009,7 +1011,9 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. + {% if is_rocm %} int32_t total_L = indices.numel(); + {%- endif %} {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and From 46b9f805703b0713eb654f8e9a1408572add920a Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:54:56 +0000 Subject: [PATCH 35/92] Guard supported grad_t for optimized warp_per_row dispatch --- .../training/backward/embedding_backward_split_template.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 759bbfd9bb..18beeae1ff 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1238,7 +1238,9 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) + constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; + + if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} From ab5cf5da8cad945100f12a1bb9a4a7306d05fdf2 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:56:05 +0000 Subject: [PATCH 36/92] Forward index_t to the optimizer --- .../backward/embedding_backward_split_kernel_warp_template.cu | 2 +- .../rocm/embedding_backward_split_device_kernel_template.hip | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index e61b3fc0aa..b757f64d36 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -650,7 +650,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd opt_karg.weight_decay = weight_decay; rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, emb_t, cache_t, diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 951cff4399..87d259ebee 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -27,7 +27,7 @@ #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { -template +template struct rowwise_adagrad_optimizer_t { __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) @@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) { if constexpr(segment_split == 0) { From 5164f6ea28e70da67e656504c2ca33a8f8e423f4 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 2 Sep 2025 09:25:03 +0000 Subject: [PATCH 37/92] Guard f16 llvm intrinsics with ROCm >=7.0 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index aa869fe2b5..08e1efa3e9 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,6 +24,7 @@ #include #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); From cde00fcdc90bc86e4ac204b827517fc5e6338db7 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:59:56 +0000 Subject: [PATCH 38/92] Fix buffer offset for emb_dim == 160 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 08e1efa3e9..c1d98d3e9f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -165,7 +165,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { From 5d73b9c7d8026d2ae922869f306908ddaebd0400 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 27 Oct 2025 14:36:52 +0000 Subject: [PATCH 39/92] Remove sanity check --- ...plit_table_batched_embeddings_benchmark.py | 190 +++++------------- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 123 ++---------- 2 files changed, 70 insertions(+), 243 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 02d3820b07..4ffb7341a5 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -7,8 +7,7 @@ # pyre-strict -import gzip -import yaml + import logging import os import tempfile @@ -1012,15 +1011,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options -@click.option("--save", type=str, default=None) -@click.option("--load", type=str, default=None) -@click.option("--random-weights", is_flag=True, default=False) -@click.option("--compressed", is_flag=True, default=False) -@click.option("--slice-min", type=int, default=None) -@click.option("--slice-max", type=int, default=None) -@click.pass_context def device_with_spec( # noqa C901 - ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1040,39 +1031,7 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, - save: str, - load: str, - random_weights: bool, - compressed: bool, - slice_min: int, - slice_max: int, ) -> None: - if load: - with open(f"{load}/params.yaml", "r") as f: - ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) - alpha = ctx.params["alpha"] - bag_size_list = ctx.params["bag_size_list"] - bag_size_sigma_list = ctx.params["bag_size_sigma_list"] - batch_size = ctx.params["batch_size"] - embedding_dim_list = ctx.params["embedding_dim_list"] - weights_precision = ctx.params["weights_precision"] - cache_precision = ctx.params["cache_precision"] - stoc = ctx.params["stoc"] - iters = ctx.params["iters"] - warmup_runs = ctx.params["warmup_runs"] - managed = ctx.params["managed"] - num_embeddings_list = ctx.params["num_embeddings_list"] - reuse = ctx.params["reuse"] - row_wise = ctx.params["row_wise"] - weighted = ctx.params["weighted"] - pooling = ctx.params["pooling"] - bounds_check_mode = ctx.params["bounds_check_mode"] - flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] - output_dtype = ctx.params["output_dtype"] - random_weights = ctx.params["random_weights"] - compressed = ctx.params["compressed"] - slice_min = ctx.params["slice_min"] - slice_max = ctx.params["slice_max"] np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1081,11 +1040,6 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" - params = ctx.params - if save: - os.makedirs(f"{save}", exist_ok=True) - with open(f"{save}/params.yaml", "w") as f: - yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1164,22 +1118,6 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) - elif random_weights: - emb.init_embedding_weights_uniform(-1.0, 1.0) - - if save: - if compressed: - with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/model_state.pth") - - if load: - if compressed: - with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: - emb.load_state_dict(torch.load(f)) - else: - emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1192,66 +1130,53 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: - requests = [] - for i in range(iters): - indices = torch.load(f"{load}/{i}_indices.pt") - offsets = torch.load(f"{load}/{i}_offsets.pt") - per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") - Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") - requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) - else: - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - del all_requests + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + + del all_requests + assert len(requests) == iters - if save: - for i in range(iters): - req = requests[i] - torch.save(req.indices, f"{save}/{i}_indices.pt") - torch.save(req.offsets, f"{save}/{i}_offsets.pt") - torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") - torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") + sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1299,22 +1224,13 @@ def device_with_spec( # noqa C901 # backward bench not representative return - if load: - grad_output = torch.load(f"{load}/grad_output.pt") + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) else: # Obtain B * L from indices len # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) - else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - - if save: - torch.save(grad_output, f"{save}/grad_output.pt") + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) # backward time_per_iter = benchmark_requests( requests, @@ -1328,12 +1244,6 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, - emb=emb, - save=save, - load=load, - compressed=compressed, - slice_min=slice_min, - slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 591ce4a6c4..1243f14db4 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,7 +11,6 @@ import statistics import threading import time -import gzip from subprocess import Popen from typing import Callable, Optional @@ -19,7 +18,7 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device -from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen + logging.basicConfig(level=logging.DEBUG) @@ -242,43 +241,36 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, - emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, - save: Optional[str] = None, - load: Optional[str] = None, - compressed: bool = False, - slice_min: Optional[int] = None, - slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - import copy + if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - if not (load or save): - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - sliced = slice_min is not None and slice_max is not None + if torch.cuda.is_available(): torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -286,85 +278,10 @@ def benchmark_requests( # noqa: C901 else: start_events = [] end_events = [] - if save and emb: - for it in range(iters): - req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - if compressed: - with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: - torch.save(out, f) - else: - torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - - out.backward(grad) - torch.cuda.synchronize() - torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max, :].clone(), f) - else: - torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - else: - torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") - torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - else: - if compressed: - with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") - - if load and emb: - for it in range(iters): - req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - out.backward(grad) - torch.cuda.synchronize() - emb_ref = copy.deepcopy(emb) - if not sliced: - if compressed: - with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: - emb_ref.load_state_dict(torch.load(f)) - else: - emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: - w_ref = torch.load(f) - else: - w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max, :], w_ref, - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - else: - for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - print("PASS") - print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) - - if sliced: - m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") - m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") - else: - m_dev_ref = emb_ref.momentum1_dev - m_uvm_ref = emb_ref.momentum1_uvm - torch.testing.assert_close(emb.momentum1_dev, m_dev_ref, atol=1.0e-4, rtol=1.0e-4) - torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref, atol=1.0e-4, rtol=1.0e-4) - print("PASS") for it in range(iters): req = requests[it % num_reqs] + indices, offsets, weights = req.unpack_3() if bwd_only: # Run forward before profiling if does backward only From 919db74ae443d1871190b85ba4103d30d830e10e Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 27 Oct 2025 20:14:58 +0000 Subject: [PATCH 40/92] address the potential lint issues and revert the change in indices_generator.cpp --- ...dding_backward_split_kernel_warp_template.cu | 11 +++++------ .../embedding_backward_split_template.cu | 13 ++++++------- ...ng_backward_split_device_kernel_template.hip | 17 ++++++++--------- .../forward/embedding_forward_split_template.cu | 2 +- ...bedding_split_host_pt2_autograd_template.cpp | 14 +++++++------- .../fbgemm_gpu/rocm/split_embeddings_common.h | 4 ++-- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 1 + 7 files changed, 30 insertions(+), 32 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index b757f64d36..7b3b5b653a 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,13 +32,13 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and not dense and - not nobag and + not nobag and not is_index_select and - not is_gwd_kernel and - not vbe and + not is_gwd_kernel and + not vbe and not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" @@ -621,7 +621,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endif %} ) { int32_t T = D_offsets.size(0) - 1; - auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 18beeae1ff..72cf189ccc 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,13 +48,13 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and not dense and - not nobag and + not nobag and not is_index_select and - not is_gwd_kernel and - not vbe and + not is_gwd_kernel and + not vbe and not ssd %} template < @@ -669,7 +669,7 @@ Tensor {{ embedding_cuda_op }}( TORCH_WARN_ONCE("Running on CDNA architecture"); } #endif - + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} @@ -1199,7 +1199,6 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for warp_per_row {%- if is_rocm %} int32_t num_warp_per_row_groups; - if (total_L/total_B > 1){ num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; } diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 87d259ebee..2a747731cc 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -225,7 +225,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -234,7 +234,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -270,7 +270,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ -301,7 +301,6 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; @@ -314,7 +313,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -323,7 +322,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ -341,7 +340,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -350,7 +349,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index b6e9c94745..f4de721bc9 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -467,7 +467,7 @@ batch_index_select_dim0_codegen_forward_cuda( TORCH_WARN_ONCE("Running on CDNA architecture"); } {%- endif %} - + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 825681a57c..4c00a1ba9f 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1014,18 +1014,18 @@ static torch::autograd::variable_list backward( {% if is_rocm %} int32_t total_L = indices.numel(); {%- endif %} - {%- if (not nobag) and - (optimizer == "rowwise_adagrad") and - (not vbe) and - (not is_gwd) and - (not ssd) and - (not is_index_select) and + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and (not dense) %} const auto T = weights_offsets.sym_numel(); auto total_B = (offsets.size(0) - 1); const auto B = total_B / T; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c1d98d3e9f..b5aa74c1ab 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -217,7 +217,7 @@ struct load_row_per_warp { *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; } }; @@ -335,7 +335,7 @@ struct store_row_per_warp { auto out = reinterpret_cast(p_output); out[lane_id] = *reinterpret_cast(acc); out[lane_id + 64] = *reinterpret_cast(&acc[2]); - p_output[lane_id + 256] = acc[4]; + p_output[lane_id + 256] = acc[4]; } }; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index 715acd8c0c..dfea2dce8a 100755 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,6 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( + std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From 3df3c9130e4a0036c9e9fd898dd7a77815dcc7ea Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 27 Oct 2025 20:37:50 +0000 Subject: [PATCH 41/92] addresss code style issue --- .../training/index_select/batch_index_select_dim0_host.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 608f6017ec..18378b6106 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; From 6c3a362d66e893857fcec66df3e70ce712ec715a Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 24 Oct 2025 14:44:16 +0000 Subject: [PATCH 42/92] Remove general load/store methods --- ..._backward_split_device_kernel_template.hip | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 397 ++++++++++++------ 2 files changed, 259 insertions(+), 140 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2a747731cc..cd3d645775 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -410,6 +410,6 @@ L_tail_grad_acc: optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b5aa74c1ab..e6e575e6e5 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,7 +21,11 @@ * ******************************************************************************/ #pragma once + +#include #include +#include + #include #include #include @@ -47,10 +51,10 @@ union amdgcn_buffer_resource { }; template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr, const int32_t size_in_bytes = 0xFFFFFFFF) { amdgcn_buffer_resource buffer_resource; buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; + buffer_resource.range = size_in_bytes; buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 return buffer_resource.content; @@ -60,8 +64,8 @@ __device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) + int32_t soffset = 0, + int32_t glc_slc = 0) #if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else @@ -71,33 +75,59 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.load.f32"); __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) + int32_t soffset = 0, + int32_t glc_slc = 0) #if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); #endif +__device__ void llvm_amdgcn_raw_buffer_store_fp16( + const half vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.store.f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16x2( + const half2 vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.store.v2f16"); +#endif + __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.f32"); __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( floatx2_t vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ @@ -107,35 +137,15 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host side + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { - emb_data[i] = 0.f; + static_assert(false, "HIP: Optimized load operation is not supported yet"); } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); } - } - } }; template @@ -145,7 +155,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); } }; @@ -156,7 +166,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); } }; @@ -165,15 +175,11 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 160); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(half) * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -184,9 +190,9 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -198,10 +204,10 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 256); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); } }; @@ -210,35 +216,15 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 320); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(half) * 320); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -256,9 +242,97 @@ struct load_row_per_warp { lane_id ); } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(float) * 160); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + } }; +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(float) * 320); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + emb_data[4] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 256) * sizeof(float)); + } +}; template < typename emb_t, @@ -291,116 +365,161 @@ struct accumulate_row_per_warp { } }; -template +template struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host function + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } + static_assert(false, "HIP: Optimized load operation is not supported yet"); } } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - out[lane_id + 64] = *reinterpret_cast(&acc[2]); +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - *(reinterpret_cast(&out[64]) + lane_id) = *reinterpret_cast(acc + 2); +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - out[lane_id + 64] = *reinterpret_cast(&acc[2]); - p_output[lane_id + 256] = acc[4]; +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); } }; +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[4], out_res, (lane_id + 256) * sizeof(half)); + } +}; + +template +struct store_row_per_warp { + static __device__ void run( + const c10::Half* emb_data, + c10::Half* p_emb_table, + int lane_id) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + reinterpret_cast(p_emb_table), + lane_id + ); + } +}; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32( + acc[0], out_res, lane_id * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } + lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(&acc[2]), + out_res, + (lane_id + 64) * sizeof(floatx2_t)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), + *reinterpret_cast(&acc[2]), out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); + (lane_id + 64) * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 256) * sizeof(float)); } }; From 8cb68388175b0d10a020f2eb63ff400fa2690fb8 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 24 Oct 2025 14:57:41 +0000 Subject: [PATCH 43/92] Move weight type check to compile-time --- .../training/backward/embedding_backward_split_template.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 72cf189ccc..82acd61baa 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1234,9 +1234,7 @@ Tensor {{ embedding_cuda_op }}( const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); - const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half - || dev_weights.scalar_type() == at::ScalarType::Float; - + constexpr bool supported_weights_type = std::is_same_v || std::is_same_v; constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) From ab6fa10667be47bc9a68e0636ba48bb1b6994ab5 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 27 Oct 2025 11:51:43 +0000 Subject: [PATCH 44/92] Switch to 256B stores for float type --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 54 +++++++------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index e6e575e6e5..5475f74ddd 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -449,8 +449,7 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32( - acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); } }; @@ -458,10 +457,8 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); } }; @@ -469,12 +466,9 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; @@ -482,12 +476,9 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; @@ -495,14 +486,10 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); } }; @@ -510,16 +497,11 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[4], out_res, (lane_id + 256) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[4], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[5], out_res, (lane_id + 256) * sizeof(float)); } }; From c5a915d8ec18e18c4581d284ac508bcd297c2da9 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 28 Oct 2025 19:16:27 +0000 Subject: [PATCH 45/92] removed guard rocm on mixed_D and refactored mixed_D var assignment --- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 4c00a1ba9f..789877f69f 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -698,9 +698,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; - {% if is_rocm %} - const auto mixed_D = aux_bool[IDX_MIXED_D]; - {%- endif %} + const auto mixed_D = static_cast(aux_bool[IDX_MIXED_D]); {%- endif %} // Default values for Dynamo tracing @@ -813,7 +811,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; - ctx->saved_data["mixed_D"] = static_cast(aux_bool[IDX_MIXED_D]); + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; From ca4701f275e0f84102bdfe148c98bbabc83d6641 Mon Sep 17 00:00:00 2001 From: Wulley Date: Sun, 2 Nov 2025 03:11:09 +0000 Subject: [PATCH 46/92] hack param --- .../embedding_backward_split_kernel_warp_template.cu | 4 ++-- .../backward/embedding_backward_split_template.cu | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 7b3b5b653a..c960ad9d9d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,7 +32,7 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -546,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if is_optimized_hip_kernel_supported_mode_ori %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 82acd61baa..8be79b8816 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,7 +48,7 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -236,7 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if is_optimized_hip_kernel_supported_mode_ori %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -870,7 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if is_optimized_hip_kernel_supported_mode_ori %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1230,7 +1230,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if is_optimized_hip_kernel_supported_mode_ori %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); From 5bf0cf6885dceb6fdb0fb165a901a7e4a7e3674f Mon Sep 17 00:00:00 2001 From: Wulley Date: Mon, 27 Oct 2025 06:36:55 +0000 Subject: [PATCH 47/92] support opt code_gen --- ...ing_backward_split_kernel_warp_template.cu | 339 ++++++++++++++++++ .../embedding_backward_split_template.cu | 209 ++++++++++- 2 files changed, 546 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index c960ad9d9d..959c617efd 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -41,6 +41,14 @@ not vbe and not ssd %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" @@ -341,6 +349,258 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( } } +{%- if is_optimized_hip_kernel_supported_mode %} +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t"}}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { + {%- if not nobag %} + int32_t T = D_offsets.size(0) - 1; + {%- else %} + int32_t T = weights_offsets.size(0); + {%- endif %} + const auto start_run_id = blockIdx.x * blockDim.y + threadIdx.y; + +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE + const unsigned int shfl_sync_mask = + ((1L << kThreadGroupSize) - 1) << + (threadIdx.y % (kWarpSize / kThreadGroupSize) * kThreadGroupSize); +#else + const unsigned int shfl_sync_mask = 0xffffffffu; +#endif + +#define BROADCAST(val, srcLane) __builtin_amdgcn_readlane(val,srcLane) + + constexpr int VEC_WIDTH = 4; + constexpr auto kIsInt8 = std::is_same::value; + + struct SharedMemory> smem; + const int32_t grad_sum_stride = max_D / VEC_WIDTH; + auto* smem_grad_sum = (kUseVecBlocking || kIsInt8) + ? smem.getPointer() + threadIdx.y * grad_sum_stride + : nullptr; + + constexpr int num_unroll = 32; + + auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); + + for (uint32_t out_run_id = start_run_id * num_unroll; out_run_id < num_run_id; out_run_id += gridDim.x * blockDim.y * num_unroll) { + auto stride = gridDim.x * blockDim.y; + auto num_valid_id = min(num_unroll, num_run_id - out_run_id); + auto is_valid = threadIdx.x < num_valid_id; + + int32_t s_segment_start = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x)] : -1; + int32_t s_segment_end = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x + 1)] : -1; + int64_t s_idx = is_valid? sorted_linear_indices_run[out_run_id + threadIdx.x] : -1; + + {%- if not nobag %} + uint32_t s_t_0 = is_valid? reinterpret_cast(&sorted_infos[0])[s_segment_start] : -1; + s_t_0 = s_t_0 >> info_B_num_bits; + {%- else %} + auto s_t_0 = is_valid? sorted_infos[s_segment_start] : -1; + s_t_0 = s_t_0 % T; + {%- endif %} + + int64_t s_hash_size = is_valid? hash_size_cumsum[s_t_0] : -1; + s_idx -= s_hash_size; + {%- if not nobag %} + int32_t s_D_offsets_0 = is_valid? D_offsets[s_t_0] : 0; + int32_t s_D_offsets_1 = is_valid? D_offsets[s_t_0 + 1] : 0; + auto s_D = s_D_offsets_1 - s_D_offsets_0; + {%- endif %} + + int32_t s_table_unique_indice_offset = is_valid? table_unique_indices_offsets[s_t_0] : 0; + int64_t s_weights_offset = is_valid? weights_offsets[s_t_0] : 0; + int64_t s_momentum1_offset = is_valid? momentum1_offsets[s_t_0] : 0; + int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; + int32_t s_momentum1_placement = is_valid? momentum1_placements[s_t_0] : 0; + + at::acc_type* __restrict__ s_momentum1; + if (static_cast(s_momentum1_placement) == PlacementType::DEVICE) { + s_momentum1 = &momentum1_dev[s_momentum1_offset]; + } else { + s_momentum1 = &momentum1_uvm[s_momentum1_offset]; + } + + for (auto i = 0; i < num_valid_id; ++i) { + auto run_id = out_run_id + i; + auto t_0 = BROADCAST(s_t_0, i); + auto idx = BROADCAST(s_idx, i); + auto segment_start = BROADCAST(s_segment_start, i); + auto segment_end = BROADCAST(s_segment_end, i); + auto D = BROADCAST(s_D, i); + int32_t table_unique_indice_offset = BROADCAST(s_table_unique_indice_offset, i); + const int32_t SL = segment_end - segment_start; + + const int64_t weights_offset = SHFL_SYNC(s_weights_offset, i); + const auto weights_placement = static_cast(SHFL_SYNC(s_weights_placement, i)); + + const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); + const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); + auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); + auto momentum1_val = momentum1[idx]; + + if (SL >= max_segment_length_per_warp) { + continue; + } + + // now, each segment corresponds to exactly one table `t` and row in + // that table (`idx`). Thus, we can hoist out some of the book-keeping. + + const int32_t SL_per_warp = div_round_up(SL, blockDim.y); + const int32_t sl_start = 0; + const int32_t sl_end = SL; + + Vec4TAcc grad_sum[kFixedMaxVecsPerThread]; + constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH; + const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth; + + compute_grad_sum_{{ kdesc }}< + grad_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_sum, + smem_grad_sum, + grad_output, + {%- if not nobag or is_index_select %} + D_offsets, + {%- endif %} + D, + T, + sorted_infos, + {%- if weighted %} + sorted_indice_weights, + {%- endif %} + {%- if not nobag and vbe %} + B_offsets, + row_output_offsets, + {%- endif %} + {%- if not nobag %} + info_B_num_bits, + info_B_mask, + {%- endif %} + segment_start, + sl_start, + sl_end, + shfl_sync_mask, + num_vecs + ); + + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; + split_rowwise_adagrad_table_update_kernel< + emb_t, + cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_name + "_ph_t" }}, + {%- endfor %} + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + dev_weights, + uvm_weights, + lxu_cache_weights, + weights_placements, + weights_offsets, + sorted_{{ locs_or_addrs_tensor }}, + grad_sum, + smem_grad_sum, + smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum) + stochastic_rounding, + stochastic_rounding_philox_args, + run_id, + use_uniq_cache_locations + ? (run_id - table_unique_indices_offsets[t_0]) + : segment_start, + D, + t_0, + idx, + {%- if is_gwd_kernel %} + global_weight_decay, + {%- elif has_global_weight_decay_support %} + {# /* cases where gwd is not enabled/supported */ #} + 1, // global_weight_decay + {%- endif %} + shfl_sync_mask, + max_vecs, + momentum1, momentum1_val, learning_rate, eps, weight_decay, weight_decay_mode, max_norm + ); + } + } +} +{%- endif %} + //////////////////////////////////////////////////////////////////////////////// // Explicit Template Instantiations @@ -455,6 +715,85 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row }} {%- endif %} ); + +{%- if is_optimized_hip_kernel_supported_mode %} + +template __global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 +< {{ emb_type }}, + {{ grad_type }}, + {{ cache_type }}, + {{ index_type }}, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_type_combo[ph_name].primitive_type }}, + {%- endfor %} + {{ kFixedMaxVecsPerThread }}, + {{ kThreadGroupSize }}, + {{ kUseVecBlocking }} +> ( + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, + pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | + replace_pta_namespace() | + replace_placeholder_types(ph_type_combo) | + join(",\n ") | + replace("cache_t", cache_type) + }} + {%- endif %} +); + +{%- endif %} {%- endmacro %} {%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 8be79b8816..41679bd7ad 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -56,6 +56,14 @@ using namespace fbgemm_gpu; not is_gwd_kernel and not vbe and not ssd %} + +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} template < typename emb_t, @@ -307,6 +315,147 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endif %} ); {%- endif %} + +{%- if is_optimized_hip_kernel_supported_mode %} + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} // if optimizer != "none" + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + const pta::PackedTensorAccessor32 long_run_ids, + const pta::PackedTensorAccessor32 num_long_run_ids, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- if optimizer == "none" %} + const int32_t max_D, + {%- endif %} + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const pta::PackedTensorAccessor32 long_run_id_to_really_long_run_ids, + pta::PackedTensorAccessor32, 2, at::RestrictPtrTraits> temp_grad_accum, + pta::PackedTensorAccessor32 grad_accum_counter, + const int32_t max_segment_length_per_cta, + const bool use_deterministic_algorithms, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); +{%- endif %} + {% if is_index_select %} namespace index_select { {% else %} @@ -877,7 +1026,25 @@ Tensor {{ embedding_cuda_op }}( wdesc, vdesc, ) - %} + %} + {%- endif %} + + {%- if is_optimized_hip_kernel_supported_mode %} + {%- set hip_mixed_d_warp_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} + + {%- set hip_mixed_d_cta_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_cta_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} {%- endif %} AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { @@ -1029,6 +1196,10 @@ Tensor {{ embedding_cuda_op }}( {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D}, aligned_grad_output.options().dtype(std::is_same::value ? at::kDouble : at::kFloat)); + {%- if is_optimized_hip_kernel_supported_mode %} + const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); + {%- endif %} + DISPATCH_PLACEHOLDER_TYPES( {%- for ph_name in args.placeholder_tensor_names %} {{ ph_name + "_dev" }}.scalar_type(), @@ -1047,7 +1218,7 @@ Tensor {{ embedding_cuda_op }}( ) %} - const auto backward_cta_per_row_kernel = + auto backward_cta_per_row_kernel = {{ cta_kernel }} ; + + {%- if is_optimized_hip_kernel_supported_mode %} + if (use_hip_kernel && mixed_D) { + backward_cta_per_row_kernel = + {{ hip_mixed_d_cta_kernel }} + ; + } + {%- endif %} // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); @@ -1196,6 +1384,23 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + {%- if is_optimized_hip_kernel_supported_mode %} + if (use_hip_kernel && mixed_D) { + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + } + {%- endif %} + // Compute shared memory size for warp_per_row {%- if is_rocm %} int32_t num_warp_per_row_groups; From b72bdd8700b9a31d79fe595666ee4b6a2fa33cc4 Mon Sep 17 00:00:00 2001 From: yadai Date: Wed, 6 Aug 2025 11:29:38 +0000 Subject: [PATCH 48/92] support subwarp --- ...plit_table_batched_embeddings_benchmark.py | 525 +++++++++++------- fbgemm_gpu/codegen/genscript/optimizers.py | 36 ++ ...ding_backward_split_kernel_cta_template.cu | 2 +- ...ing_backward_split_kernel_warp_template.cu | 114 ++-- .../embedding_backward_split_template.cu | 23 +- ...optimizer_split_device_kernel_template.cuh | 198 ++++++- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 251 ++++++++- 7 files changed, 882 insertions(+), 267 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..2d3755fe06 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -8,11 +8,13 @@ # pyre-strict +import gzip import logging import os import tempfile from contextlib import nullcontext -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional +import yaml import click import numpy as np @@ -1011,7 +1013,31 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options +@click.option("--batch-size", default=512) +@click.option("--embedding-dim-list", type=str, default="128") +@click.option("--weights-precision", type=SparseType, default=SparseType.FP32) +@click.option("--cache-precision", type=SparseType, default=None) +@click.option("--stoc", is_flag=True, default=False) +@click.option("--iters", default=100) +@click.option("--warmup-runs", default=0) +@click.option("--managed", default="device") +@click.option("--num-embeddings-list", type=str, default="100000") +@click.option("--reuse", default=0.0) +@click.option("--row-wise/--no-row-wise", default=True) +@click.option("--weighted", is_flag=True, default=False) +@click.option("--pooling", type=str, default="sum") +@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) +@click.option("--flush-gpu-cache-size-mb", default=0) +@click.option("--output-dtype", type=SparseType, default=SparseType.FP32) +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) +@click.pass_context def device_with_spec( # noqa C901 + ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1031,7 +1057,40 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, ) -> None: + if load: + with open(f"{load}/params.yaml", "r") as f: + ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) + alpha = ctx.params["alpha"] + bag_size_list = ctx.params["bag_size_list"] + bag_size_sigma_list = ctx.params["bag_size_sigma_list"] + batch_size = ctx.params["batch_size"] + embedding_dim_list = ctx.params["embedding_dim_list"] + weights_precision = ctx.params["weights_precision"] + cache_precision = ctx.params["cache_precision"] + stoc = ctx.params["stoc"] + iters = ctx.params["iters"] + warmup_runs = ctx.params["warmup_runs"] + managed = ctx.params["managed"] + num_embeddings_list = ctx.params["num_embeddings_list"] + reuse = ctx.params["reuse"] + row_wise = ctx.params["row_wise"] + weighted = ctx.params["weighted"] + pooling = ctx.params["pooling"] + bounds_check_mode = ctx.params["bounds_check_mode"] + flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] + output_dtype = ctx.params["output_dtype"] + random_weights = ctx.params["random_weights"] + compressed = ctx.params["compressed"] + slice_min = ctx.params["slice_min"] + slice_max = ctx.params["slice_max"] + np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1040,6 +1099,12 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" + + params = ctx.params + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1118,6 +1183,22 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) + elif random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") + + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1130,52 +1211,68 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - - del all_requests + if load: + requests = [] + for i in range(iters): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") + requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) + else: + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + del all_requests + assert len(requests) == iters + if save: + for i in range(iters): + req = requests[i] + torch.save(req.indices, f"{save}/{i}_indices.pt") + torch.save(req.offsets, f"{save}/{i}_offsets.pt") + torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") + torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: @@ -1201,36 +1298,44 @@ def device_with_spec( # noqa C901 f"Accessed weights per batch: {B * sum_DLs * param_size_multiplier / 1.0e9: .2f} GB" ) + if load is None and save is None: # forward - time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) - logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + time_per_iter = benchmark_requests( + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) + logging.info( + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative return - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if load: + grad_output = torch.load(f"{load}/grad_output.pt") else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + + if save: + torch.save(grad_output, f"{save}/grad_output.pt") + # backward time_per_iter = benchmark_requests( requests, @@ -1244,6 +1349,12 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " @@ -1256,19 +1367,19 @@ def device_with_spec( # noqa C901 @click.option( "--batch-size-list", type=str, - required=True, + required=False, help="A comma separated list of batch sizes (B) for each table.", ) @click.option( "--embedding-dim-list", type=str, - required=True, + required=False, help="A comma separated list of embedding dimensions (D) for each table.", ) @click.option( "--bag-size-list", type=str, - required=True, + required=False, help="A comma separated list of bag sizes (L) for each table.", ) @click.option( @@ -1281,7 +1392,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-embeddings-list", type=str, - required=True, + required=False, help="A comma separated list of number of embeddings (E) for each table.", ) @click.option( @@ -1294,7 +1405,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-tables", type=int, - required=True, + required=False, help="The number of tables.", ) @click.option( @@ -1303,16 +1414,12 @@ def device_with_spec( # noqa C901 default=False, help="Whether the table is weighted or not", ) -@click.option( - "--print-kernel-summary", - is_flag=True, - default=False, - help="Whether the table is weighted or not", -) -@click.option("--ssd", is_flag=True, default=False) -@click.option( - "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix" -) +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) @TBEBenchmarkingConfigLoader.options @EmbeddingOpsCommonConfigLoader.options @click.pass_context @@ -1326,9 +1433,12 @@ def vbe( alpha_list: str, num_tables: int, weighted: bool, - print_kernel_summary: bool, - ssd: bool, - ssd_prefix: str, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, # pyre-ignore[2] **kwargs, ) -> None: @@ -1340,6 +1450,28 @@ def vbe( np.random.seed(42) torch.manual_seed(42) + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(context.params, f, sort_keys=False) + + if load: + with open(f"{load}/params.yaml", "r") as f: + context.params = yaml.load(f, Loader=yaml.UnsafeLoader) + params = context.params + batch_size_list = params["batch_size_list"] + embedding_dim_list = params["embedding_dim_list"] + bag_size_list = params["bag_size_list"] + bag_size_sigma_list = params["bag_size_sigma_list"] + num_embeddings_list = params["num_embeddings_list"] + alpha_list = params["alpha_list"] + num_tables = params["num_tables"] + weighted = params["weighted"] + random_weights = params["random_weights"] + compressed = params["compressed"] + slice_min = params["slice_min"] + slice_max = params["slice_max"] + # Load general TBE benchmarking configuration from cli arguments benchconfig = TBEBenchmarkingConfigLoader.load(context) if benchconfig.num_requests != benchconfig.iterations: @@ -1348,6 +1480,9 @@ def vbe( if benchconfig.flush_gpu_cache_size_mb != 0: raise ValueError("--bench-flush-gpu-cache-size is not supported.") + if benchconfig.export_trace: + raise ValueError("--bench-export-trace is not supported.") + # Load common embedding op configuration from cli arguments embconfig = EmbeddingOpsCommonConfigLoader.load(context) if embconfig.uvm_host_mapped: @@ -1384,126 +1519,122 @@ def vbe( else EmbeddingLocation.HOST ) - common_split_args: dict[str, Any] = { - "weights_precision": embconfig.weights_dtype, - "stochastic_rounding": embconfig.stochastic_rounding, - "output_dtype": embconfig.output_dtype, - "pooling_mode": embconfig.pooling_mode, - "bounds_check_mode": embconfig.bounds_check_mode, - "optimizer": optimizer, - "learning_rate": 0.1, - "eps": 0.1, - "feature_table_map": list(range(T)), - } - - if ssd: - cache_set = max(T * max(Bs), 1) - tempdir = tempfile.mkdtemp(prefix=ssd_prefix) - emb = SSDTableBatchedEmbeddingBags( - [(E, D) for E, D in zip(Es, Ds)], - cache_sets=cache_set, - ssd_storage_directory=tempdir, - ssd_cache_location=EmbeddingLocation.DEVICE, - ssd_rocksdb_shards=8, - **common_split_args, - ) - else: - emb = SplitTableBatchedEmbeddingBagsCodegen( - [ - ( - E, - D, - managed_option, - get_available_compute_device(), - ) - for E, D in zip(Es, Ds) - ], - cache_precision=embconfig.cache_dtype, - **common_split_args, - ) - emb = emb.to(get_device()) - all_requests = { - "indices": [[] for _ in range(benchconfig.iterations)], - "offsets": [[] for _ in range(benchconfig.iterations)], - "weights": [[] for _ in range(benchconfig.iterations)], - } - for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): - # Generate a request for a single table. - local_requests = generate_requests( - benchconfig.iterations, - B, - 1, - L, - E, - alpha=alpha, - weighted=weighted, - sigma_L=sigma_L, - zipf_oversample_ratio=3 if L > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - - # Store requests for each table in all_requests. - for i, req in enumerate(local_requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - # pyre-ignore[53] - def _kineto_trace_handler( - p: profile, emb_op_type: str = "vbe", print_summary: bool = False - ) -> None: - p.export_chrome_trace( - benchconfig.trace_url.format(emb_op_type=emb_op_type, ospid=os.getpid()) - ) - if print_summary: - print(p.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + E, + D, + managed_option, + get_available_compute_device(), + ) + for E, D in zip(Es, Ds) + ], + optimizer=optimizer, + learning_rate=0.1, + eps=0.1, + cache_precision=embconfig.cache_dtype, + weights_precision=embconfig.weights_dtype, + stochastic_rounding=embconfig.stochastic_rounding, + output_dtype=embconfig.output_dtype, + pooling_mode=embconfig.pooling_mode, + bounds_check_mode=embconfig.bounds_check_mode, + ).to(get_device()) + + if random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") - emb_op_type = "vbe" + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) - # pyre-ignore[3, 53] - def context_factory(on_trace_ready: Callable[[profile], None]): - return ( - profile(on_trace_ready=on_trace_ready) - if benchconfig.export_trace - else nullcontext() - ) + if load: + requests = [] + for i in range(benchconfig.iterations): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + requests.append((indices, offsets, per_sample_weights)) + else: + all_requests = { + "indices": [[] for _ in range(benchconfig.iterations)], + "offsets": [[] for _ in range(benchconfig.iterations)], + "weights": [[] for _ in range(benchconfig.iterations)], + } + for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): + # Generate a request for a single table. + local_requests = generate_requests( + benchconfig.iterations, + B, + 1, + L, + E, + alpha=alpha, + weighted=weighted, + sigma_L=sigma_L, + zipf_oversample_ratio=3 if L > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) - # Combine the requests for all tables by - requests = [ - ( - torch.concat(all_requests["indices"][i]), - torch.concat(all_requests["offsets"][i]), - torch.concat(all_requests["weights"][i]) if weighted else None, - ) - for i in range(benchconfig.iterations) - ] + # Store requests for each table in all_requests. + for i, req in enumerate(local_requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + # Combine the requests for all tables by + requests = [ + ( + torch.concat(all_requests["indices"][i]), + torch.concat(all_requests["offsets"][i]), + torch.concat(all_requests["weights"][i]) if weighted else None, + ) + for i in range(benchconfig.iterations) + ] + + del all_requests - del all_requests + if save: + for i, (indices, offsets, weights) in enumerate(requests): + torch.save(indices, f"{save}/{i}_indices.pt") + torch.save(offsets, f"{save}/{i}_offsets.pt") + torch.save(weights, f"{save}/{i}_per_sample_weights.pt") - with context_factory( - lambda p: _kineto_trace_handler(p, emb_op_type, print_kernel_summary) - ): - fwd_time_sec, bwd_time_sec = benchmark_vbe( - requests, - func=lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - batch_size_per_feature_per_rank=[[B] for B in Bs], - ), - num_warmups=benchconfig.warmup_iterations, - ) + fwd_time_sec, bwd_time_sec = benchmark_vbe( + requests, + func=lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank=[[B] for B in Bs], + ), + num_warmups=benchconfig.warmup_iterations, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, + ) logging.info( f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n" f"fwd: {fwd_time_sec * 1.0e6:.0f}us, bwd: {bwd_time_sec * 1.0e6:.0f}us" ) - if __name__ == "__main__": cli() diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index c61e6843f9..8c25dc0d8f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -197,6 +197,9 @@ def rowwise_adagrad() -> Dict[str, Any]: at::acc_type multiplier = 0.0; at::acc_type correction = 0.0; + """ + split_precomputation_preload = split_precomputation + split_precomputation += """ if (threadIdx.x == 0) { auto new_sum_square_grads = g_avg_square; @@ -228,6 +231,38 @@ def rowwise_adagrad() -> Dict[str, Any]: multiplier = SHFL_SYNC(multiplier, 0); correction = SHFL_SYNC(correction, 0); """ + split_precomputation_preload += """ + if (threadIdx.x == 0) { + auto new_sum_square_grads = g_avg_square; + + // Update the optimizer state. Use optimizer state offloading only if + // SSD and if enabled by the user + if (enable_optimizer_offloading) { + // Fetch the pointer to the optimizer state along the cache row + auto* optimizer = weight_row_template.template optimizer_state_ptr(); + new_sum_square_grads += optimizer->momentum; + optimizer->momentum = new_sum_square_grads; + + } else { + new_sum_square_grads += momentum1_val; + momentum1[idx] = new_sum_square_grads; + } + + multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); + if (weight_decay_mode == 1) { + // L2 regularization + correction = 1.0 - multiplier * weight_decay; + } else if (weight_decay_mode == 2 || weight_decay_mode == 5) { + // Decoupled weight decay + correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; + } + } + multiplier = SHFL_SYNC(multiplier, 0); + correction = SHFL_SYNC(correction, 0); + """ split_weight_update_cpu = """ at::acc_type g_local_sum_square = 0.0; for (int64_t d = 0; d < D; ++d) { @@ -275,6 +310,7 @@ def rowwise_adagrad() -> Dict[str, Any]: }, ), "split_precomputation": split_precomputation, + "split_precomputation_preload": split_precomputation_preload, "split_weight_update": split_weight_update, "split_post_update": split_post_update, "split_weight_update_cpu": split_weight_update_cpu, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 25f7119a7a..b10eb1312e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -625,7 +625,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 959c617efd..185bf8650e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -426,6 +426,8 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} const auto start_run_id = blockIdx.x * blockDim.y + threadIdx.y; +#define SUBWARP_SHFL_SYNC(val, srcLane) __shfl_sync(UINT64_MAX, val, srcLane, kThreadGroupSize) + #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = ((1L << kThreadGroupSize) - 1) << @@ -445,7 +447,7 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc ? smem.getPointer() + threadIdx.y * grad_sum_stride : nullptr; - constexpr int num_unroll = 32; + constexpr int num_unroll = kThreadGroupSize; auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); @@ -476,39 +478,49 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc int32_t s_table_unique_indice_offset = is_valid? table_unique_indices_offsets[s_t_0] : 0; int64_t s_weights_offset = is_valid? weights_offsets[s_t_0] : 0; - int64_t s_momentum1_offset = is_valid? momentum1_offsets[s_t_0] : 0; int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; - int32_t s_momentum1_placement = is_valid? momentum1_placements[s_t_0] : 0; - at::acc_type* __restrict__ s_momentum1; - if (static_cast(s_momentum1_placement) == PlacementType::DEVICE) { - s_momentum1 = &momentum1_dev[s_momentum1_offset]; + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ s_{{ tensor }}; + const auto s_{{ tensor }}_placement = {{ tensor }}_placements[s_t_0]; + const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[s_t_0]; + if (static_cast(s_{{ tensor }}_placement) == PlacementType::DEVICE) { + s_{{ tensor }} = &{{ tensor }}_dev[s_{{ tensor }}_offset]; } else { - s_momentum1 = &momentum1_uvm[s_momentum1_offset]; + s_{{ tensor }} = &{{ tensor }}_uvm[s_{{ tensor }}_offset]; } + {{ args.split_tensor_types[tensor] }} s_{{tensor}}_val = is_valid? s_{{tensor}}[s_idx] : 0; + + {%- endfor %} for (auto i = 0; i < num_valid_id; ++i) { - auto run_id = out_run_id + i; - auto t_0 = BROADCAST(s_t_0, i); - auto idx = BROADCAST(s_idx, i); - auto segment_start = BROADCAST(s_segment_start, i); - auto segment_end = BROADCAST(s_segment_end, i); - auto D = BROADCAST(s_D, i); - int32_t table_unique_indice_offset = BROADCAST(s_table_unique_indice_offset, i); + auto segment_start = SUBWARP_SHFL_SYNC(s_segment_start, i); + auto segment_end = SUBWARP_SHFL_SYNC(s_segment_end, i); const int32_t SL = segment_end - segment_start; - - const int64_t weights_offset = SHFL_SYNC(s_weights_offset, i); - const auto weights_placement = static_cast(SHFL_SYNC(s_weights_placement, i)); - - const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); - const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); - auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); - auto momentum1_val = momentum1[idx]; - if (SL >= max_segment_length_per_warp) { continue; } + auto run_id = out_run_id + i; + auto t_0 = SUBWARP_SHFL_SYNC(s_t_0, i); + auto idx = SUBWARP_SHFL_SYNC(s_idx, i); + + {%- if not nobag %} + auto D = SUBWARP_SHFL_SYNC(s_D, i); + {%- endif %} + int32_t table_unique_indice_offset = SUBWARP_SHFL_SYNC(s_table_unique_indice_offset, i); + + {%- for tensor in args.split_tensors %} + const auto {{ tensor }}_placement = SUBWARP_SHFL_SYNC(s_{{ tensor }}_placement, i); + const int64_t {{ tensor }}_offset = SUBWARP_SHFL_SYNC(s_{{ tensor }}_offset, i); + {{ args.split_tensor_types[tensor] }} {{tensor}}_val = SUBWARP_SHFL_SYNC(s_{{ tensor }}_val, i); + {%- endfor %} + + // const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); + // const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); + // auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); + // auto momentum1_val = momentum1[idx]; + // now, each segment corresponds to exactly one table `t` and row in // that table (`idx`). Thus, we can hoist out some of the book-keeping. @@ -558,7 +570,11 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc // when kUseVecBlocking == false const int32_t max_vecs = kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; - split_rowwise_adagrad_table_update_kernel< + + {%- if not dense and optimizer != "none" %} + const int64_t weights_offset = SUBWARP_SHFL_SYNC(s_weights_offset, i); + const int32_t weights_placement = SUBWARP_SHFL_SYNC(s_weights_placement, i); + {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, {%- for ph_name in args.placeholder_tensor_names %} @@ -571,8 +587,8 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc dev_weights, uvm_weights, lxu_cache_weights, - weights_placements, - weights_offsets, + weights_placement, + weights_offset, sorted_{{ locs_or_addrs_tensor }}, grad_sum, smem_grad_sum, @@ -594,8 +610,42 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} shfl_sync_mask, max_vecs, - momentum1, momentum1_val, learning_rate, eps, weight_decay, weight_decay_mode, max_norm + {%- if ssd %} + enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + {{ tensor }}_placement, + {{ tensor }}_offset, + {{ tensor }}_val, + {%- endfor %} + {{ args.split_kernel_arg_names | join(", ") }} + ); + {%- else %} + // Write deduplicated gradient to grad_dev_weights gradient is sparse + // for split_embedding and dense for dense_embedding + {%- if dense %} + const int64_t weights_offset = weights_offsets[t_0]; + {%- else %} + // Compute offset of sparse gradient + const int64_t weights_offset = run_id * max_D; + idx = 0; + {%- endif %} + store_grad_sum< + emb_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_dev_weights, + grad_sum, + kUseVecBlocking ? smem_grad_sum : nullptr, + D, + weights_offset, + idx, + max_vecs ); + {%- endif %} // if not dense and optimizer != "none" } } } @@ -877,7 +927,7 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif @@ -1101,10 +1151,10 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} - {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} - {%- for cache_type in ['float', 'at::Half'] %} - {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} + {%- for emb_type in (['float', 'at::Half', 'at::BFloat16'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} + {%- for cache_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for index_type in ['int32_t', 'int64_t', 'at::BFloat16'] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 41679bd7ad..81fed327e7 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1232,9 +1232,9 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; {%- if is_optimized_hip_kernel_supported_mode %} - if (use_hip_kernel && mixed_D) { + if (!kUseVecBlocking) { backward_cta_per_row_kernel = - {{ hip_mixed_d_cta_kernel }} + {{ cta_kernel }} ; + 1, + 32, + false>; } {%- endif %} @@ -1282,7 +1282,7 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(kThreadGroupSize, num_cta_per_row_groups), + dim3(32, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1385,7 +1385,8 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; {%- if is_optimized_hip_kernel_supported_mode %} - if (use_hip_kernel && mixed_D) { + if (!kUseVecBlocking) { + printf("%s:%d call here\n", __FILE__, __LINE__); backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} ; + 1, + 32, + false>; } {%- endif %} @@ -1429,6 +1430,7 @@ Tensor {{ embedding_cuda_op }}( } auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + // auto blockSize = dim3(32, num_warp_per_row_groups); int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), @@ -1470,7 +1472,6 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} #endif - FBGEMM_LAUNCH_KERNEL( backward_warp_per_row_kernel, warp_per_row_grid_size, diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index e4fb6c548c..ef1a011e1d 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -11,8 +11,42 @@ #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" -#define GROUP_REDUCE_ALL_SUM(val, ...) \ - warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + +template +DEVICE_INLINE __device__ T subwarp_reduce_add(T value) { + static_assert(kThreadGroupSize == 8 || kThreadGroupSize == 16 || kThreadGroupSize == 32 || kThreadGroupSize == 64, "Wavefront size must be 16/32/64"); + if (kThreadGroupSize == 16) { + // Reduce across 4 groups of 16 threads + value += __shfl_xor(value, 1, 16); + value += __shfl_xor(value, 2, 16); + value += __shfl_xor(value, 4, 16); + value += __shfl_xor(value, 8, 16); + } else if (kThreadGroupSize == 32) { + // Reduce across 2 groups of 32 threads + value += __shfl_xor(value, 1, 32); + value += __shfl_xor(value, 2, 32); + value += __shfl_xor(value, 4, 32); + value += __shfl_xor(value, 8, 32); + value += __shfl_xor(value, 16, 32); + } else if (kThreadGroupSize == 64) { + value += __shfl_xor(value, 1, 64); + value += __shfl_xor(value, 2, 64); + value += __shfl_xor(value, 4, 64); + value += __shfl_xor(value, 8, 64); + value += __shfl_xor(value, 16, 64); + value += __shfl_xor(value, 32, 64); + } + return value; +} + +#define GROUP_REDUCE_ALL_SUM(val, ...) subwarp_reduce_add(val) {%- set mdesc = "ssd" if ssd else "split" %} {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} @@ -176,4 +210,164 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } +{%- if is_optimized_hip_kernel_supported_mode %} +template < + typename emb_t, + typename cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {%- set ph_type = "{}_ph_t".format(ph_name) %} + typename {{ ph_type }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize = kWarpSize, + int32_t VEC_WIDTH, + bool kUseVecBlocking +> +DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( + pta::PackedTensorAccessor64& dev_weights, + pta::PackedTensorAccessor64& uvm_weights, + pta::PackedTensorAccessor64& lxu_cache_weights, + const int32_t weights_placement, + const int64_t weights_offset, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, + Vec4TAcc* grad_sum, + Vec4TAcc* smem_grad_sum, + Vec4TAcc* shared_weight_update_row, + const bool stochastic_rounding, + const at::PhiloxCudaState& stochastic_rounding_philox_args, + const uint32_t run_id, + const uint32_t cache_loc_run_id, + const int32_t D, + const int32_t t, + const int64_t idx, + {%- if has_global_weight_decay_support %} + const float global_weight_decay, + {%- endif %} + const uint32_t shfl_sync_mask, + const int32_t max_vecs_per_thread, + {%- if ssd %} + const bool enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + const int32_t {{ tensor }}_placement, + const int64_t {{ tensor }}_offset, + const int64_t {{ tensor }}_val, + {%- endfor %} + {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} +) { + constexpr auto kIsInt8 = std::is_same_v; + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; + emb_t* __restrict__ weights {nullptr}; + cache_t* __restrict__ cache_weights {nullptr}; + int32_t D_emb = D; + if constexpr (kIsInt8) { + D_emb += kINT8QparamsBytes; + } + if (static_cast(weights_placement) == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset + idx * D_emb]; + } else { + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset + idx * D_emb]" }}; + } + if (static_cast(weights_placement) == PlacementType::MANAGED_CACHING) { + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; + } + {%- endif %} + } + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + // const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + // const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; + if (static_cast({{ tensor }}_placement) == PlacementType::DEVICE) { + {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset]; + } else { + {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset]; + } + {%- endfor %} + + auto weight_row_template = + WeightRow>( + weights, + cache_weights, + D, + stochastic_rounding, + &stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x); + + float2 qparams_template; + if constexpr (kIsInt8) { + if (!cache_weights) { + qparams_template = weight_row_template.load_qparams(); + } + } + + {%- if not ssd %} + [[maybe_unused]] constexpr auto enable_optimizer_offloading = false; + {%- endif %} + + {{ split_precomputation_preload }} + + {# /* Note: technically, global weight decay (gwd) compensation should be done before + `split_precomputation`). But since decouple mode in `rowwise_adagrad` only computes correction, + the order of applying gwd does not matter. We perform gwd update before `split_weight_update` + below to minimize number of times to load weights. + So, note that the behavior may be different if you want to enable gwd for other optimizers + such as `lamb` or `partial_rowwise_lamb`. + */#} + float2 qparams_new; + {{ + generate_optimized_grad_sum_loop_access( + """ + Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); + Vec4TAcc& grad = {grad_vec}; + {global_weight_decay_update} + {split_weight_update} + if (kIsInt8 && !cache_weights) { + shared_weight_update_row[d_vec] = weight_new; + } else { + // qparams_new not used if type is not int8 + weight_row_template.store(weight_new, d, qparams_new); + } + """, + other_formats={ + "split_weight_update": split_weight_update, + "global_weight_decay_update": "weight_new.mul_(global_weight_decay);" if has_global_weight_decay_support else "" + }, + ) + }} + + if constexpr (kIsInt8) { + if (!cache_weights) { + // Calculate new qparams after row update + qparams_new = thrust_find_qparams>( + shared_weight_update_row, D); + weight_row_template.store_qparams(qparams_new); + + // Fetch cached updated row from shared mem and quantize on-the-fly + // when saving to lowp embedding + for (int32_t vec = 0; + (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D; + ++vec) { + const auto d_vec = vec * kThreadGroupSize + threadIdx.x; + const int32_t d = d_vec * VEC_WIDTH; + weight_row_template.store( + shared_weight_update_row[d_vec], + d, + qparams_new); + } + } + } + + {{ split_post_update }} +} +{%- endif %} + // clang-format on diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1243f14db4..da502d1c21 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,6 +11,7 @@ import statistics import threading import time +import gzip from subprocess import Popen from typing import Callable, Optional @@ -18,6 +19,9 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device +from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen + +import copy logging.basicConfig(level=logging.DEBUG) @@ -241,35 +245,43 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - + import copy if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + if not (load or save): + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters + sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() @@ -279,6 +291,94 @@ def benchmark_requests( # noqa: C901 start_events = [] end_events = [] + if save and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + if compressed: + with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: + torch.save(out, f) + else: + torch.save(out, f"{save}/{it}_fwd_grad_out.pt") + + out.backward(grad) + torch.cuda.synchronize() + torch.save(out, f"{save}/{it}_bwd_grad_out.pt") + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + + out_ref = torch.load(f"{load}/{it}_fwd_grad_out.pt") + torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) + + print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) + print("FWD PASS") + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) + print("PASS") + + for it in range(iters): req = requests[it % num_reqs] @@ -602,6 +702,12 @@ def benchmark_vbe( requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], num_warmups: int = 0, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> tuple[float, float]: """ A benchmark function to return the average execution time in seconds of @@ -626,14 +732,16 @@ def benchmark_vbe( """ use_cuda = torch.cuda.is_available() + sliced = slice_min is not None and slice_max is not None + if not (load or save): # Warm-ups. - for _ in range(num_warmups): - # Warm-up using the first request as done in benchmark_requests - indices, offsets, weights = requests[0] - out = func(indices, offsets, weights) - grad = torch.rand_like(out) - out.backward(grad) + for _ in range(num_warmups): + # Warm-up using the first request as done in benchmark_requests + indices, offsets, weights = requests[0] + out = func(indices, offsets, weights) + grad = torch.rand_like(out) + out.backward(grad) iters = len(requests) if use_cuda: @@ -647,6 +755,101 @@ def benchmark_vbe( fwd_times_sec = [] bwd_times_sec = [] + if save and emb: + for it, req in enumerate(requests): + + indices, offsets, weights = req + out = func(indices, offsets, weights) + torch.cuda.synchronize() + + torch.save(out, f"{save}/{it}_fwd_out.pt") + + grad = torch.rand_like(out) + if compressed: + with gzip.open(f"{save}/{it}_grad.pt.gz", "wb") as f: + torch.save(grad, f) + else: + torch.save(grad, f"{save}/{it}_grad.pt") + + out.backward(grad) + torch.cuda.synchronize() + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it, req in enumerate(requests): + + indices, offsets, weights = req + out = func(indices, offsets, weights) + torch.cuda.synchronize() + + out_ref = torch.load(f"{load}/{it}_fwd_out.pt") + torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) + + print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) + print("FWD PASS") + + if compressed: + with gzip.open(f"{load}/{it}_grad.pt.gz", "rb") as f: + grad = torch.load(f) + else: + grad = torch.load(f"{load}/{it}_grad.pt") + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) + print("PASS") + + for i, (indices, offsets, weights) in enumerate(requests): # forward if use_cuda: @@ -699,4 +902,4 @@ def benchmark_vbe( # pyre-ignore[61] bwd_time_sec = statistics.median(bwd_times_sec) - return fwd_time_sec, bwd_time_sec + return fwd_time_sec, bwd_time_sec \ No newline at end of file From 6343a4f05229f2de0039662482d8206e2e51cdbb Mon Sep 17 00:00:00 2001 From: Wulley Date: Tue, 28 Oct 2025 08:59:14 +0000 Subject: [PATCH 49/92] update subwarp kernel --- ...ing_backward_split_kernel_warp_template.cu | 1 + .../embedding_backward_split_template.cu | 49 ++++++++++++++----- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 185bf8650e..3a8f4977ab 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -46,6 +46,7 @@ not dense and not is_index_select and not is_gwd_kernel and + not nobag and not vbe and not ssd %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 81fed327e7..8f2a50c3fe 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -62,7 +62,8 @@ using namespace fbgemm_gpu; not dense and not is_index_select and not is_gwd_kernel and - not vbe and + not vbe and + not nobag and not ssd %} template < @@ -1231,8 +1232,10 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; {%- if is_optimized_hip_kernel_supported_mode %} - if (!kUseVecBlocking) { + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); + if (max_D <= 128) { backward_cta_per_row_kernel = {{ cta_kernel }} ; + + auto cta_blockSize = dim3(32, num_cta_per_row_groups); } + {%- else %} + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); {%- endif %} // Compute shared memory size for cta_per_row @@ -1282,7 +1289,7 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(32, num_cta_per_row_groups), + cta_blockSize, cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1384,9 +1391,10 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; {%- if is_optimized_hip_kernel_supported_mode %} - if (!kUseVecBlocking) { - printf("%s:%d call here\n", __FILE__, __LINE__); + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + if (use_hip_kernel && mixed_D) { backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} ; + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking>; + if (max_D <= 128) { + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + + blockSize = dim3(32, num_warp_per_row_groups); + } } + {%- else %} + // Compute shared memory size for warp_per_row + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); {%- endif %} - // Compute shared memory size for warp_per_row {%- if is_rocm %} int32_t num_warp_per_row_groups; if (total_L/total_B > 1){ @@ -1414,6 +1440,7 @@ Tensor {{ embedding_cuda_op }}( {%- else %} int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; {%- endif %} + int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { @@ -1428,10 +1455,6 @@ Tensor {{ embedding_cuda_op }}( backward_warp_per_row_kernel, used_shared_bytes); } - - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - // auto blockSize = dim3(32, num_warp_per_row_groups); - int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), get_max_thread_blocks_()); From c3860725f1d72639e302963bfc6a3942c72b4313 Mon Sep 17 00:00:00 2001 From: xzhu Date: Mon, 27 Oct 2025 03:02:34 +0000 Subject: [PATCH 50/92] grad sum kernel unroll improvement --- ..._backward_split_device_kernel_template.cuh | 144 +++++++++++++----- 1 file changed, 106 insertions(+), 38 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..d58f67bcb0 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -14,6 +14,98 @@ using namespace fbgemm_gpu; +// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1) +#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i); +#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i); +#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i); +#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i); +#define IDX_WEIGHT(i, j) at::acc_type idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i); + +#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j); +#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j); +#define REPEAT_2(X, j) X(1, j); +#define REPEAT_1(X, j) // No additional variables needed for block size 1 + +#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n); +#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); +#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n); +#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1 + +// Helper macro: Generate block_size Vec4TAcc objects (i from 1 to block_size-1) +// if nobag and is_index_select +#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]); +// elif nobag +#define GRAD_VEC_N(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[l_j_##i][d]); +// elif vbe +#define GRAD_VEC_V(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]); +// else +#define GRAD_VEC(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d); + +// Helper macro: Generate block_size fma_ calls (i from 1 to block_size-1) +#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i); +// Helper macro: Generate block_size add_ calls (i from 1 to block_size-1) +#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i); + +// Core macro: Process blocks of specified size (block_size = 8/4/2/1) +// Parameters: +// - block_size: Size of each block to process +// - unroll_count: Number of unroll iterations for the inner loop +#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \ + for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \ + {%- if nobag %} + int32_t l_j_0 = SHFL_SYNC(l, j); \ + REPEAT_##block_size(L, j) \ + {%- elif vbe %} + /* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \ + const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \ + /* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \ + REPEAT_##block_size(GRAD_OFFSET, j) \ + {%- else %} + int32_t b_j_0 = SHFL_SYNC(b, j); \ + REPEAT_##block_size(B, j) \ + int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \ + REPEAT_##block_size(D_START, j) \ + {%- endif %} + {%- if weighted %} + at::acc_type idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \ + REPEAT_##block_size(IDX_WEIGHT, j) \ + {%- endif %} + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + \ + for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \ + const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \ + /* Generate block_size Vec4TAcc objects and accumulate them */ \ + Vec4TAcc grad_out_vec_0( \ + {%- if nobag and is_index_select %} + &grad_output[grad_offset + l_j_0 * grad_stride + d] \ + {%- elif nobag %} + &grad_output[l_j_0][d] \ + {%- elif vbe %} + &grad_output[0][grad_offset_j_0 + d] \ + {%- else %} + &grad_output[b_j_0][0] + D_start_j_0 + d \ + {%- endif %} + ); \ + {%- if nobag and is_index_select %} + REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \ + {%- elif nobag %} + REPEAT_##block_size(GRAD_VEC_N, d) \ + {%- elif vbe %} + REPEAT_##block_size(GRAD_VEC_V, d) \ + {%- else %} + REPEAT_##block_size(GRAD_VEC, d) \ + {%- endif %} + \ + {%- if weighted %} + grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \ + REPEAT_##block_size(FMA_GRAD, vec) \ + {%- else %} + grad_sum[vec].add_(grad_out_vec_0); \ + REPEAT_##block_size(ADD_GRAD, vec) \ + {%- endif %} + } \ + } + {%- if gen_once %} {#- /* The kernels in this section will be generated only once for all TBE configs @@ -141,45 +233,21 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( ? sorted_indice_weights[segment_start + sl_j] : 0.0; {%- endif %} - for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) { - {%- if nobag %} - int32_t l_j = SHFL_SYNC(l, j); - {%- elif vbe %} - const auto grad_offset_j = SHFL_SYNC(grad_offset, j); - {%- else %} - int32_t b_j = SHFL_SYNC(b, j); - int32_t D_start_j = SHFL_SYNC(D_start, j); - {%- endif %} - - {%- if weighted %} - at::acc_type idx_weight_j = SHFL_SYNC(idx_weight, j); - {%- endif %} + int32_t j = 0; - {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} - - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) { - const int32_t d = {{ d }}; - Vec4TAcc grad_out_vec( - {%- if nobag and is_index_select %} - // grad_output is 1d - &grad_output[grad_offset + l_j * grad_stride + d] - {%- elif nobag %} - &grad_output[l_j][d] - {%- elif vbe %} - &grad_output[0][grad_offset_j + d] - {%- else %} - &grad_output[b_j][0] + D_start_j + d - {%- endif %} // if nobag - ); - - {%- if weighted %} - grad_sum[vec].fma_(grad_out_vec, idx_weight_j); - {%- else %} - grad_sum[vec].add_(grad_out_vec); - {%- endif %} - } - } + // Process blocks of different sizes with loop unrolling + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} From 7bf6dd8a366fa8e74b45460283e5272ae9bd1aac Mon Sep 17 00:00:00 2001 From: yadai Date: Wed, 29 Oct 2025 08:29:02 +0000 Subject: [PATCH 51/92] fix performance issuse --- .../embedding_backward_split_template.cu | 69 +++++++++---------- ...optimizer_split_device_kernel_template.cuh | 2 +- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 8f2a50c3fe..8181853423 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1232,7 +1232,22 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + {% if is_rocm %} + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1){ + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else{ + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } + {%- else %} + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t work_group_size = kMaxThreads; + {%- endif %} {%- if is_optimized_hip_kernel_supported_mode %} auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); if (max_D <= 128) { @@ -1249,30 +1264,15 @@ Tensor {{ embedding_cuda_op }}( 32, false>; - auto cta_blockSize = dim3(32, num_cta_per_row_groups); + cta_blockSize = dim3(32, num_cta_per_row_groups); } {%- else %} auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); {%- endif %} + // printf("%s:%d %d\n", __FILE__, __LINE__, num_cta_per_row_groups); // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - {% if is_rocm %} - int32_t total_L = indices.numel(); - int32_t num_cta_per_row_groups; - int32_t work_group_size; - if (total_L/total_B > 1){ - num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; - work_group_size = (kMaxThreads/4); - } - else{ - num_cta_per_row_groups = kMaxThreads / kWarpSize; - work_group_size = kMaxThreads; - } - {%- else %} - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; - int32_t work_group_size = kMaxThreads; - {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1391,9 +1391,20 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if is_rocm %} + int32_t num_warp_per_row_groups; + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + {%- else %} + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- endif %} auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + {%- if is_optimized_hip_kernel_supported_mode %} + // printf("%s:%d warp kernel %d %d %d\n", __FILE__, __LINE__, num_warp_per_row_groups, use_hip_kernel, mixed_D); if (use_hip_kernel && mixed_D) { backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} @@ -1420,27 +1431,11 @@ Tensor {{ embedding_cuda_op }}( 1, 32, false>; - blockSize = dim3(32, num_warp_per_row_groups); + // printf("%s:%d warp kernel %d\n", __FILE__, __LINE__, num_warp_per_row_groups); } } - {%- else %} - // Compute shared memory size for warp_per_row - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); {%- endif %} - - {%- if is_rocm %} - int32_t num_warp_per_row_groups; - if (total_L/total_B > 1){ - num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; - } - else{ - num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - } - {%- else %} - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - {%- endif %} - int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index ef1a011e1d..514d8428b9 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -251,7 +251,7 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {%- for tensor in args.split_tensors %} const int32_t {{ tensor }}_placement, const int64_t {{ tensor }}_offset, - const int64_t {{ tensor }}_val, + const {{ args.split_tensor_types[tensor] }} {{ tensor }}_val, {%- endfor %} {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} ) { From fb7f0a88ad06a82626bce116ca963d78d898347c Mon Sep 17 00:00:00 2001 From: Wulley Date: Sun, 2 Nov 2025 08:03:11 +0000 Subject: [PATCH 52/92] fix vbe opt not imply --- ...plit_table_batched_embeddings_benchmark.py | 527 +++++++----------- ..._backward_split_device_kernel_template.cuh | 8 +- ...ing_backward_split_kernel_warp_template.cu | 20 +- .../embedding_backward_split_template.cu | 38 +- ...optimizer_split_device_kernel_template.cuh | 16 +- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 249 +-------- 6 files changed, 259 insertions(+), 599 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 2d3755fe06..4dd8b3dbb3 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -8,13 +8,11 @@ # pyre-strict -import gzip import logging import os import tempfile from contextlib import nullcontext -from typing import Any, Callable, Dict, Optional -import yaml +from typing import Any, Callable, Optional import click import numpy as np @@ -1013,31 +1011,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options -@click.option("--batch-size", default=512) -@click.option("--embedding-dim-list", type=str, default="128") -@click.option("--weights-precision", type=SparseType, default=SparseType.FP32) -@click.option("--cache-precision", type=SparseType, default=None) -@click.option("--stoc", is_flag=True, default=False) -@click.option("--iters", default=100) -@click.option("--warmup-runs", default=0) -@click.option("--managed", default="device") -@click.option("--num-embeddings-list", type=str, default="100000") -@click.option("--reuse", default=0.0) -@click.option("--row-wise/--no-row-wise", default=True) -@click.option("--weighted", is_flag=True, default=False) -@click.option("--pooling", type=str, default="sum") -@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) -@click.option("--flush-gpu-cache-size-mb", default=0) -@click.option("--output-dtype", type=SparseType, default=SparseType.FP32) -@click.option("--save", type=str, default=None) -@click.option("--load", type=str, default=None) -@click.option("--random-weights", is_flag=True, default=False) -@click.option("--compressed", is_flag=True, default=False) -@click.option("--slice-min", type=int, default=None) -@click.option("--slice-max", type=int, default=None) -@click.pass_context def device_with_spec( # noqa C901 - ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1057,40 +1031,7 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, - save: str, - load: str, - random_weights: bool, - compressed: bool, - slice_min: int, - slice_max: int, ) -> None: - if load: - with open(f"{load}/params.yaml", "r") as f: - ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) - alpha = ctx.params["alpha"] - bag_size_list = ctx.params["bag_size_list"] - bag_size_sigma_list = ctx.params["bag_size_sigma_list"] - batch_size = ctx.params["batch_size"] - embedding_dim_list = ctx.params["embedding_dim_list"] - weights_precision = ctx.params["weights_precision"] - cache_precision = ctx.params["cache_precision"] - stoc = ctx.params["stoc"] - iters = ctx.params["iters"] - warmup_runs = ctx.params["warmup_runs"] - managed = ctx.params["managed"] - num_embeddings_list = ctx.params["num_embeddings_list"] - reuse = ctx.params["reuse"] - row_wise = ctx.params["row_wise"] - weighted = ctx.params["weighted"] - pooling = ctx.params["pooling"] - bounds_check_mode = ctx.params["bounds_check_mode"] - flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] - output_dtype = ctx.params["output_dtype"] - random_weights = ctx.params["random_weights"] - compressed = ctx.params["compressed"] - slice_min = ctx.params["slice_min"] - slice_max = ctx.params["slice_max"] - np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1099,12 +1040,6 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" - - params = ctx.params - if save: - os.makedirs(f"{save}", exist_ok=True) - with open(f"{save}/params.yaml", "w") as f: - yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1183,22 +1118,6 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) - elif random_weights: - emb.init_embedding_weights_uniform(-1.0, 1.0) - - if save: - if compressed: - with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/model_state.pth") - - if load: - if compressed: - with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: - emb.load_state_dict(torch.load(f)) - else: - emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1211,68 +1130,52 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: - requests = [] - for i in range(iters): - indices = torch.load(f"{load}/{i}_indices.pt") - offsets = torch.load(f"{load}/{i}_offsets.pt") - per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") - Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") - requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) - else: - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + + del all_requests - del all_requests - assert len(requests) == iters - if save: - for i in range(iters): - req = requests[i] - torch.save(req.indices, f"{save}/{i}_indices.pt") - torch.save(req.offsets, f"{save}/{i}_offsets.pt") - torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") - torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: @@ -1298,44 +1201,36 @@ def device_with_spec( # noqa C901 f"Accessed weights per batch: {B * sum_DLs * param_size_multiplier / 1.0e9: .2f} GB" ) - if load is None and save is None: # forward - time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) - logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + time_per_iter = benchmark_requests( + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) + logging.info( + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative return - if load: - grad_output = torch.load(f"{load}/grad_output.pt") + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) else: - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) - else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - - if save: - torch.save(grad_output, f"{save}/grad_output.pt") - + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) # backward time_per_iter = benchmark_requests( requests, @@ -1349,12 +1244,6 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, - emb=emb, - save=save, - load=load, - compressed=compressed, - slice_min=slice_min, - slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " @@ -1367,19 +1256,19 @@ def device_with_spec( # noqa C901 @click.option( "--batch-size-list", type=str, - required=False, + required=True, help="A comma separated list of batch sizes (B) for each table.", ) @click.option( "--embedding-dim-list", type=str, - required=False, + required=True, help="A comma separated list of embedding dimensions (D) for each table.", ) @click.option( "--bag-size-list", type=str, - required=False, + required=True, help="A comma separated list of bag sizes (L) for each table.", ) @click.option( @@ -1392,7 +1281,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-embeddings-list", type=str, - required=False, + required=True, help="A comma separated list of number of embeddings (E) for each table.", ) @click.option( @@ -1405,7 +1294,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-tables", type=int, - required=False, + required=True, help="The number of tables.", ) @click.option( @@ -1414,12 +1303,16 @@ def device_with_spec( # noqa C901 default=False, help="Whether the table is weighted or not", ) -@click.option("--save", type=str, default=None) -@click.option("--load", type=str, default=None) -@click.option("--random-weights", is_flag=True, default=False) -@click.option("--compressed", is_flag=True, default=False) -@click.option("--slice-min", type=int, default=None) -@click.option("--slice-max", type=int, default=None) +@click.option( + "--print-kernel-summary", + is_flag=True, + default=False, + help="Whether the table is weighted or not", +) +@click.option("--ssd", is_flag=True, default=False) +@click.option( + "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix" +) @TBEBenchmarkingConfigLoader.options @EmbeddingOpsCommonConfigLoader.options @click.pass_context @@ -1433,12 +1326,9 @@ def vbe( alpha_list: str, num_tables: int, weighted: bool, - save: str, - load: str, - random_weights: bool, - compressed: bool, - slice_min: int, - slice_max: int, + print_kernel_summary: bool, + ssd: bool, + ssd_prefix: str, # pyre-ignore[2] **kwargs, ) -> None: @@ -1450,28 +1340,6 @@ def vbe( np.random.seed(42) torch.manual_seed(42) - if save: - os.makedirs(f"{save}", exist_ok=True) - with open(f"{save}/params.yaml", "w") as f: - yaml.dump(context.params, f, sort_keys=False) - - if load: - with open(f"{load}/params.yaml", "r") as f: - context.params = yaml.load(f, Loader=yaml.UnsafeLoader) - params = context.params - batch_size_list = params["batch_size_list"] - embedding_dim_list = params["embedding_dim_list"] - bag_size_list = params["bag_size_list"] - bag_size_sigma_list = params["bag_size_sigma_list"] - num_embeddings_list = params["num_embeddings_list"] - alpha_list = params["alpha_list"] - num_tables = params["num_tables"] - weighted = params["weighted"] - random_weights = params["random_weights"] - compressed = params["compressed"] - slice_min = params["slice_min"] - slice_max = params["slice_max"] - # Load general TBE benchmarking configuration from cli arguments benchconfig = TBEBenchmarkingConfigLoader.load(context) if benchconfig.num_requests != benchconfig.iterations: @@ -1480,9 +1348,6 @@ def vbe( if benchconfig.flush_gpu_cache_size_mb != 0: raise ValueError("--bench-flush-gpu-cache-size is not supported.") - if benchconfig.export_trace: - raise ValueError("--bench-export-trace is not supported.") - # Load common embedding op configuration from cli arguments embconfig = EmbeddingOpsCommonConfigLoader.load(context) if embconfig.uvm_host_mapped: @@ -1519,122 +1384,126 @@ def vbe( else EmbeddingLocation.HOST ) - emb = SplitTableBatchedEmbeddingBagsCodegen( - [ - ( - E, - D, - managed_option, - get_available_compute_device(), - ) - for E, D in zip(Es, Ds) - ], - optimizer=optimizer, - learning_rate=0.1, - eps=0.1, - cache_precision=embconfig.cache_dtype, - weights_precision=embconfig.weights_dtype, - stochastic_rounding=embconfig.stochastic_rounding, - output_dtype=embconfig.output_dtype, - pooling_mode=embconfig.pooling_mode, - bounds_check_mode=embconfig.bounds_check_mode, - ).to(get_device()) - - if random_weights: - emb.init_embedding_weights_uniform(-1.0, 1.0) - - if save: - if compressed: - with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/model_state.pth") - - if load: - if compressed: - with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: - emb.load_state_dict(torch.load(f)) - else: - emb.load_state_dict(torch.load(f"{load}/model_state.pth")) + common_split_args: dict[str, Any] = { + "weights_precision": embconfig.weights_dtype, + "stochastic_rounding": embconfig.stochastic_rounding, + "output_dtype": embconfig.output_dtype, + "pooling_mode": embconfig.pooling_mode, + "bounds_check_mode": embconfig.bounds_check_mode, + "optimizer": optimizer, + "learning_rate": 0.1, + "eps": 0.1, + "feature_table_map": list(range(T)), + } - if load: - requests = [] - for i in range(benchconfig.iterations): - indices = torch.load(f"{load}/{i}_indices.pt") - offsets = torch.load(f"{load}/{i}_offsets.pt") - per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") - requests.append((indices, offsets, per_sample_weights)) + if ssd: + cache_set = max(T * max(Bs), 1) + tempdir = tempfile.mkdtemp(prefix=ssd_prefix) + emb = SSDTableBatchedEmbeddingBags( + [(E, D) for E, D in zip(Es, Ds)], + cache_sets=cache_set, + ssd_storage_directory=tempdir, + ssd_cache_location=EmbeddingLocation.DEVICE, + ssd_rocksdb_shards=8, + **common_split_args, + ) else: - all_requests = { - "indices": [[] for _ in range(benchconfig.iterations)], - "offsets": [[] for _ in range(benchconfig.iterations)], - "weights": [[] for _ in range(benchconfig.iterations)], - } - for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): - # Generate a request for a single table. - local_requests = generate_requests( - benchconfig.iterations, - B, - 1, - L, - E, - alpha=alpha, - weighted=weighted, - sigma_L=sigma_L, - zipf_oversample_ratio=3 if L > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + E, + D, + managed_option, + get_available_compute_device(), + ) + for E, D in zip(Es, Ds) + ], + cache_precision=embconfig.cache_dtype, + **common_split_args, + ) + emb = emb.to(get_device()) + all_requests = { + "indices": [[] for _ in range(benchconfig.iterations)], + "offsets": [[] for _ in range(benchconfig.iterations)], + "weights": [[] for _ in range(benchconfig.iterations)], + } + for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): + # Generate a request for a single table. + local_requests = generate_requests( + benchconfig.iterations, + B, + 1, + L, + E, + alpha=alpha, + weighted=weighted, + sigma_L=sigma_L, + zipf_oversample_ratio=3 if L > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) - # Store requests for each table in all_requests. - for i, req in enumerate(local_requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - # Combine the requests for all tables by - requests = [ - ( - torch.concat(all_requests["indices"][i]), - torch.concat(all_requests["offsets"][i]), - torch.concat(all_requests["weights"][i]) if weighted else None, - ) - for i in range(benchconfig.iterations) - ] - - del all_requests + # Store requests for each table in all_requests. + for i, req in enumerate(local_requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) - if save: - for i, (indices, offsets, weights) in enumerate(requests): - torch.save(indices, f"{save}/{i}_indices.pt") - torch.save(offsets, f"{save}/{i}_offsets.pt") - torch.save(weights, f"{save}/{i}_per_sample_weights.pt") + # pyre-ignore[53] + def _kineto_trace_handler( + p: profile, emb_op_type: str = "vbe", print_summary: bool = False + ) -> None: + p.export_chrome_trace( + benchconfig.trace_url.format(emb_op_type=emb_op_type, ospid=os.getpid()) + ) + if print_summary: + print(p.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - fwd_time_sec, bwd_time_sec = benchmark_vbe( - requests, - func=lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - batch_size_per_feature_per_rank=[[B] for B in Bs], - ), - num_warmups=benchconfig.warmup_iterations, - emb=emb, - save=save, - load=load, - compressed=compressed, - slice_min=slice_min, - slice_max=slice_max, - ) + emb_op_type = "vbe" + + # pyre-ignore[3, 53] + def context_factory(on_trace_ready: Callable[[profile], None]): + return ( + profile(on_trace_ready=on_trace_ready) + if benchconfig.export_trace + else nullcontext() + ) + + # Combine the requests for all tables by + requests = [ + ( + torch.concat(all_requests["indices"][i]), + torch.concat(all_requests["offsets"][i]), + torch.concat(all_requests["weights"][i]) if weighted else None, + ) + for i in range(benchconfig.iterations) + ] + + del all_requests + + with context_factory( + lambda p: _kineto_trace_handler(p, emb_op_type, print_kernel_summary) + ): + fwd_time_sec, bwd_time_sec = benchmark_vbe( + requests, + func=lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank=[[B] for B in Bs], + ), + num_warmups=benchconfig.warmup_iterations, + ) logging.info( f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n" f"fwd: {fwd_time_sec * 1.0e6:.0f}us, bwd: {bwd_time_sec * 1.0e6:.0f}us" ) + if __name__ == "__main__": - cli() + cli() \ No newline at end of file diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index d58f67bcb0..6e25c40f10 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -236,9 +236,11 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( int32_t j = 0; // Process blocks of different sizes with loop unrolling - #pragma unroll kFixedMaxVecsPerThread - PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ - vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + if constexpr (sizeof(grad_t) <= 2) { + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + } #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 3a8f4977ab..56b1fc344d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -41,14 +41,13 @@ not vbe and not ssd %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and - not dense and - not is_index_select and - not is_gwd_kernel and - not nobag and - not vbe and - not ssd %} +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -350,7 +349,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( } } -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template < typename emb_t, typename grad_t, @@ -453,7 +452,6 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); for (uint32_t out_run_id = start_run_id * num_unroll; out_run_id < num_run_id; out_run_id += gridDim.x * blockDim.y * num_unroll) { - auto stride = gridDim.x * blockDim.y; auto num_valid_id = min(num_unroll, num_run_id - out_run_id); auto is_valid = threadIdx.x < num_valid_id; @@ -767,7 +765,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} ); -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template __global__ __launch_bounds__(kBackwardMaxThreads) void hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 8181853423..7e53e32cc6 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -57,14 +57,13 @@ using namespace fbgemm_gpu; not vbe and not ssd %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and - not dense and - not is_index_select and - not is_gwd_kernel and - not vbe and - not nobag and - not ssd %} +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} template < typename emb_t, @@ -317,7 +316,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd ); {%- endif %} -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template < typename emb_t, @@ -1030,7 +1029,7 @@ Tensor {{ embedding_cuda_op }}( %} {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if enable_optimized_hip_mixed_D_kernel %} {%- set hip_mixed_d_warp_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1038,14 +1037,6 @@ Tensor {{ embedding_cuda_op }}( vdesc, ) %} - - {%- set hip_mixed_d_cta_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_cta_per_row_1".format( - ndesc, - optimizer, - wdesc, - vdesc, - ) - %} {%- endif %} AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { @@ -1197,7 +1188,7 @@ Tensor {{ embedding_cuda_op }}( {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D}, aligned_grad_output.options().dtype(std::is_same::value ? at::kDouble : at::kFloat)); - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if enable_optimized_hip_mixed_D_kernel %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); {%- endif %} @@ -1248,7 +1239,7 @@ Tensor {{ embedding_cuda_op }}( int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; int32_t work_group_size = kMaxThreads; {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if enable_optimized_hip_mixed_D_kernel %} auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); if (max_D <= 128) { backward_cta_per_row_kernel = @@ -1403,9 +1394,12 @@ Tensor {{ embedding_cuda_op }}( int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; {%- endif %} auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - {%- if is_optimized_hip_kernel_supported_mode %} - // printf("%s:%d warp kernel %d %d %d\n", __FILE__, __LINE__, num_warp_per_row_groups, use_hip_kernel, mixed_D); + {%- if enable_optimized_hip_mixed_D_kernel %} + {%- if vbe %} + if (use_hip_kernel) { + {%- else %} if (use_hip_kernel && mixed_D) { + {%- endif %} backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} DEVICE_INLINE __device__ T subwarp_reduce_add(T value) { @@ -210,7 +210,7 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template < typename emb_t, typename cache_t, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index da502d1c21..ae805870bd 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,7 +11,6 @@ import statistics import threading import time -import gzip from subprocess import Popen from typing import Callable, Optional @@ -19,9 +18,6 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device -from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen - -import copy logging.basicConfig(level=logging.DEBUG) @@ -245,43 +241,35 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, - emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, - save: Optional[str] = None, - load: Optional[str] = None, - compressed: bool = False, - slice_min: Optional[int] = None, - slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - import copy + if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - if not (load or save): - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() @@ -291,94 +279,6 @@ def benchmark_requests( # noqa: C901 start_events = [] end_events = [] - if save and emb: - for it in range(iters): - req = requests[it % num_reqs] - - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - if compressed: - with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: - torch.save(out, f) - else: - torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - - out.backward(grad) - torch.cuda.synchronize() - torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max,:].clone(), f) - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") - torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - - else: - if compressed: - with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") - - if load and emb: - for it in range(iters): - req = requests[it % num_reqs] - - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - - out_ref = torch.load(f"{load}/{it}_fwd_grad_out.pt") - torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) - - print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) - print("FWD PASS") - - out.backward(grad) - torch.cuda.synchronize() - emb_ref = copy.deepcopy(emb) - if not sliced: - if compressed: - with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: - emb_ref.load_state_dict(torch.load(f)) - else: - emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: - w_ref = torch.load(f) - else: - w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - else: - for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - print("PASS") - - print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) - if sliced: - m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") - m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") - else: - m_dev_ref = emb_ref.momentum1_dev - m_uvm_ref = emb_ref.momentum1_uvm - torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) - torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) - print("PASS") - - for it in range(iters): req = requests[it % num_reqs] @@ -702,12 +602,6 @@ def benchmark_vbe( requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], num_warmups: int = 0, - emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, - save: Optional[str] = None, - load: Optional[str] = None, - compressed: bool = False, - slice_min: Optional[int] = None, - slice_max: Optional[int] = None, ) -> tuple[float, float]: """ A benchmark function to return the average execution time in seconds of @@ -732,16 +626,14 @@ def benchmark_vbe( """ use_cuda = torch.cuda.is_available() - sliced = slice_min is not None and slice_max is not None - if not (load or save): # Warm-ups. - for _ in range(num_warmups): - # Warm-up using the first request as done in benchmark_requests - indices, offsets, weights = requests[0] - out = func(indices, offsets, weights) - grad = torch.rand_like(out) - out.backward(grad) + for _ in range(num_warmups): + # Warm-up using the first request as done in benchmark_requests + indices, offsets, weights = requests[0] + out = func(indices, offsets, weights) + grad = torch.rand_like(out) + out.backward(grad) iters = len(requests) if use_cuda: @@ -755,101 +647,6 @@ def benchmark_vbe( fwd_times_sec = [] bwd_times_sec = [] - if save and emb: - for it, req in enumerate(requests): - - indices, offsets, weights = req - out = func(indices, offsets, weights) - torch.cuda.synchronize() - - torch.save(out, f"{save}/{it}_fwd_out.pt") - - grad = torch.rand_like(out) - if compressed: - with gzip.open(f"{save}/{it}_grad.pt.gz", "wb") as f: - torch.save(grad, f) - else: - torch.save(grad, f"{save}/{it}_grad.pt") - - out.backward(grad) - torch.cuda.synchronize() - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max,:].clone(), f) - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") - torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - - else: - if compressed: - with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") - - if load and emb: - for it, req in enumerate(requests): - - indices, offsets, weights = req - out = func(indices, offsets, weights) - torch.cuda.synchronize() - - out_ref = torch.load(f"{load}/{it}_fwd_out.pt") - torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) - - print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) - print("FWD PASS") - - if compressed: - with gzip.open(f"{load}/{it}_grad.pt.gz", "rb") as f: - grad = torch.load(f) - else: - grad = torch.load(f"{load}/{it}_grad.pt") - - out.backward(grad) - torch.cuda.synchronize() - emb_ref = copy.deepcopy(emb) - if not sliced: - if compressed: - with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: - emb_ref.load_state_dict(torch.load(f)) - else: - emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: - w_ref = torch.load(f) - else: - w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - else: - for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - print("PASS") - - print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) - if sliced: - m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") - m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") - else: - m_dev_ref = emb_ref.momentum1_dev - m_uvm_ref = emb_ref.momentum1_uvm - torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) - torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) - print("PASS") - - for i, (indices, offsets, weights) in enumerate(requests): # forward if use_cuda: From bec6a6928acab3bb5f2b57178a5da7cbf0785b0e Mon Sep 17 00:00:00 2001 From: Wulley Date: Mon, 3 Nov 2025 03:08:43 +0000 Subject: [PATCH 53/92] fix smybol bug & rm comment --- .../embedding_backward_split_kernel_warp_template.cu | 12 ++++++------ .../backward/embedding_backward_split_template.cu | 10 ++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 56b1fc344d..091b8d8001 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,7 +32,7 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} -{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -934,7 +934,7 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} -{%- if is_optimized_hip_kernel_supported_mode_ori %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -1150,10 +1150,10 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} - {%- for emb_type in (['float', 'at::Half', 'at::BFloat16'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} - {%- for cache_type in ['float', 'at::Half', 'at::BFloat16'] %} - {%- for index_type in ['int32_t', 'int64_t', 'at::BFloat16'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} + {%- for cache_type in ['float', 'at::Half'] %} + {%- for index_type in ['int32_t', 'int64_t'] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 7e53e32cc6..f88b413bdb 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,7 +48,7 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} -{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -244,7 +244,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_optimized_hip_kernel_supported_mode_ori %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -1019,7 +1019,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode_ori %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1261,7 +1261,6 @@ Tensor {{ embedding_cuda_op }}( auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); {%- endif %} - // printf("%s:%d %d\n", __FILE__, __LINE__, num_cta_per_row_groups); // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( @@ -1426,7 +1425,6 @@ Tensor {{ embedding_cuda_op }}( 32, false>; blockSize = dim3(32, num_warp_per_row_groups); - // printf("%s:%d warp kernel %d\n", __FILE__, __LINE__, num_warp_per_row_groups); } } {%- endif %} @@ -1449,7 +1447,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_optimized_hip_kernel_supported_mode_ori %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); From 9555b3be689ddb00f3def87e0622681313cf8407 Mon Sep 17 00:00:00 2001 From: Nicolas De Carli Date: Tue, 28 Oct 2025 14:07:51 -0700 Subject: [PATCH 54/92] Remove AVX compilation on aarch64 (#5065) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2072 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5065 We are removing AVX2/AVX512 compilation of fbgemm for aarch64. These modules are no longer used, thus removing them should decrease build time and code size. Reviewed By: mcfi, YifanYuan3 Differential Revision: D85603930 fbshipit-source-id: 42ed869c5ba406bdc1430d8e94b3726440eeb904 --- include/fbgemm/Fbgemm.h | 4 +- include/fbgemm/FbgemmConvert.h | 8 ++-- include/fbgemm/FbgemmEmbedding.h | 2 +- include/fbgemm/FbgemmI8DepthwiseAvx2.h | 4 ++ include/fbgemm/FbgemmSparse.h | 2 +- include/fbgemm/OutputProcessing-inl.h | 4 +- include/fbgemm/QuantUtilsAvx2.h | 14 ++++--- include/fbgemm/QuantUtilsAvx512.h | 2 +- src/FbgemmBfloat16Convert.cc | 4 +- src/FbgemmFP16.cc | 4 +- src/FbgemmFP16UKernelsAvx2.h | 4 ++ src/FbgemmFloat16Convert.cc | 4 +- src/FbgemmSparseDense.cc | 4 +- src/GroupwiseConv.cc | 4 +- src/PackWeightsForConv.cc | 4 +- src/PackWeightsForDirectConv.cc | 4 +- src/QuantUtilsNeon.cc | 52 +++++++++++++++++++++++++- src/TransposeUtils.cc | 5 +-- 18 files changed, 92 insertions(+), 37 deletions(-) diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index bc784f8035..4d1d2959ef 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -616,7 +616,7 @@ class FBGEMM_API PackWeightsForConv { return W_im2col_packed_; } -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) std::shared_ptr getPackedWForDepthwise() { return W_dw_packed_; } @@ -672,7 +672,7 @@ class FBGEMM_API PackWeightsForConv { const conv_param_t conv_param_; // Packed weights if we use im2col based convolution implementation std::shared_ptr> W_im2col_packed_; -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) // Packed weights if we use depthwise convolution implementation std::shared_ptr W_dw_packed_; #endif // __aarch64__ diff --git a/include/fbgemm/FbgemmConvert.h b/include/fbgemm/FbgemmConvert.h index cf404d2056..88dd5e8e30 100644 --- a/include/fbgemm/FbgemmConvert.h +++ b/include/fbgemm/FbgemmConvert.h @@ -47,6 +47,7 @@ FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size); FBGEMM_API void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size); +#if !defined(__aarch64__) /** * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers. * @@ -58,10 +59,8 @@ FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size); * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers. * */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size); -#endif /** * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers. @@ -74,7 +73,6 @@ Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size); * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers. * */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size); #endif @@ -124,6 +122,7 @@ Float16ToFloat_simd(const float16* src, float* dst, size_t size); * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers. * */ +#if !defined(__aarch64__) FBGEMM_API void FloatToFloat16_avx2( const float* src, float16* dst, @@ -134,7 +133,6 @@ FBGEMM_API void FloatToFloat16_avx2( * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers. * */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void FloatToFloat16_avx512( const float* src, float16* dst, @@ -152,6 +150,7 @@ FBGEMM_API void FloatToFloat16_sve2( size_t size, bool do_clip = false); +#if !defined(__aarch64__) /** * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers. * @@ -163,7 +162,6 @@ Float16ToFloat_avx2(const float16* src, float* dst, size_t size); * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers. * */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void Float16ToFloat_avx512(const float16* src, float* dst, size_t size); #endif diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index 073b9f8655..12eb8babd6 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -349,7 +349,7 @@ FBGEMM_API bool EmbeddingSpMDMBlockSize1_( bool use_offsets = true, bool is_bf16 = false); -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) template void compressed_indices_remap_avx512( std::int32_t offsets_numel, diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h index 7aadb91290..4533902234 100644 --- a/include/fbgemm/FbgemmI8DepthwiseAvx2.h +++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h @@ -8,6 +8,8 @@ #pragma once +#if !defined(__aarch64__) + #include #include "fbgemm/ConvUtils.h" #include "fbgemm/FbgemmBuild.h" @@ -110,3 +112,5 @@ FBGEMM_API void depthwise_3d_same_pad( int num_threads = 1); } // namespace fbgemm + +#endif // !defined(__aarch64__) diff --git a/include/fbgemm/FbgemmSparse.h b/include/fbgemm/FbgemmSparse.h index 82e8f889c6..dc00338fb7 100644 --- a/include/fbgemm/FbgemmSparse.h +++ b/include/fbgemm/FbgemmSparse.h @@ -166,7 +166,7 @@ void SparseDenseMMAvx2( int ldc, bool accum = false); -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) void SparseDenseMMAvx512( int M, int N, diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h index 5faabe7eeb..ec70a49aa4 100644 --- a/include/fbgemm/OutputProcessing-inl.h +++ b/include/fbgemm/OutputProcessing-inl.h @@ -125,7 +125,7 @@ ReQuantizeOutput::f( } } -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) } else if constexpr ( instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) { @@ -249,7 +249,7 @@ inline int ReQuantizeForFloat::f( } } -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) } else if constexpr ( instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) { bool b_symmetric = diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 0f2859c8ff..6a0d85deb3 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -52,6 +52,13 @@ struct FBGEMM_API RequantizationParams { TensorQuantizationParams target_qparams; }; +/// @ingroup fbgemm-quant-utils-avx2 +/// +/// @brief Find the min and max value in a float matrix. +void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len); + +#if !defined(__aarch64__) + //////////////////////////////////////////////////////////////////////////////// // Utility functions //////////////////////////////////////////////////////////////////////////////// @@ -77,11 +84,6 @@ void FusedQuantizeDequantizeAvx2( /// this paper. uint32_t FBGEMM_API Xor128(); -/// @ingroup fbgemm-quant-utils-avx2 -/// -/// @brief Find the min and max value in a float matrix. -void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len); - void RequantizeFixedPointAvx2( const std::int32_t* src, std::uint8_t* dst, @@ -176,4 +178,6 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( int input_columns, OutputType* output); +#endif // !defined(__aarch64__) + } // namespace fbgemm diff --git a/include/fbgemm/QuantUtilsAvx512.h b/include/fbgemm/QuantUtilsAvx512.h index c4b01817bd..1ad1efe71e 100644 --- a/include/fbgemm/QuantUtilsAvx512.h +++ b/include/fbgemm/QuantUtilsAvx512.h @@ -9,7 +9,7 @@ #pragma once #include "Types.h" -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) #include #include "./FbgemmBuild.h" // @manual diff --git a/src/FbgemmBfloat16Convert.cc b/src/FbgemmBfloat16Convert.cc index 34baed622b..4c7a358d94 100644 --- a/src/FbgemmBfloat16Convert.cc +++ b/src/FbgemmBfloat16Convert.cc @@ -29,7 +29,7 @@ namespace fbgemm { void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support()) { FloatToBfloat16_avx512(src, dst, size); } else if (fbgemmHasAvx2Support()) { @@ -48,7 +48,7 @@ void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) { void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support()) { Bfloat16ToFloat_avx512(src, dst, size); } else if (fbgemmHasAvx2Support()) { diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index 106f071953..79eda23712 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -35,7 +35,7 @@ namespace { // the restrictions of ymm register numbers (16). constexpr kernel_array_t kernel_fp16_avx2 = { nullptr, -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, @@ -79,7 +79,7 @@ constexpr kernel_array_t kernel_fp16_neon = { constexpr kernel_array_t kernel_fp16_avx512_256 = { nullptr, -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, diff --git a/src/FbgemmFP16UKernelsAvx2.h b/src/FbgemmFP16UKernelsAvx2.h index 888bae1833..455c49fdd5 100644 --- a/src/FbgemmFP16UKernelsAvx2.h +++ b/src/FbgemmFP16UKernelsAvx2.h @@ -16,6 +16,8 @@ namespace fbgemm { using GemmParamsFP16 = GemmParams; +#if !defined(__aarch64__) + void NOINLINE gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); @@ -23,4 +25,6 @@ void NOINLINE gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); +#endif // !defined(__aarch64__) + } // namespace fbgemm diff --git a/src/FbgemmFloat16Convert.cc b/src/FbgemmFloat16Convert.cc index 9519d6cb62..1f76baeafc 100644 --- a/src/FbgemmFloat16Convert.cc +++ b/src/FbgemmFloat16Convert.cc @@ -23,7 +23,7 @@ void FloatToFloat16_simd( bool do_clip) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support()) { FloatToFloat16_avx512(src, dst, size, do_clip); } else if (fbgemmHasAvx2Support()) { @@ -42,7 +42,7 @@ void FloatToFloat16_simd( void Float16ToFloat_simd(const float16* src, float* dst, size_t size) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support()) { Float16ToFloat_avx512(src, dst, size); } else if (fbgemmHasAvx2Support()) { diff --git a/src/FbgemmSparseDense.cc b/src/FbgemmSparseDense.cc index 1e2122d78f..eb8a82f60c 100644 --- a/src/FbgemmSparseDense.cc +++ b/src/FbgemmSparseDense.cc @@ -193,7 +193,7 @@ void SparseDenseMM( float* C, int ldc, bool accum) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) // Run time CPU detection static const auto iset = fbgemmInstructionSet(); @@ -229,7 +229,7 @@ FBGEMM_API void fbgemmSparseDenseInt8MM( return; } -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) // Run time CPU detection static const auto iset = fbgemmInstructionSet(); diff --git a/src/GroupwiseConv.cc b/src/GroupwiseConv.cc index f92408f2ec..38ec4910b0 100644 --- a/src/GroupwiseConv.cc +++ b/src/GroupwiseConv.cc @@ -121,7 +121,7 @@ static jit_conv_kernel_fp getOrCreateConvKernel( accum); if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512VnniSupport()) { return GenConvKernel::codeCache_ .getOrCreate(kernelSig, [&]() { @@ -954,7 +954,7 @@ static void dispatchOutputProcessing( } if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) { REQUANTIZE_C_PER_G(Avx512); } else if (fbgemmHasAvx2Support() || fbgemmHasArmNeonSupport()) { diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 040dbe682c..7008da3e8f 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -25,7 +25,7 @@ PackWeightsForConv::PackWeightsForConv( // FbgemmConv.cc switch (ConvFastPath(conv_p)) { case optimized_conv_t::depthwise: { -#if !defined(FBGEMM_FBCODE) && defined(__aarch64__) +#if defined(__aarch64__) throw std::runtime_error( "PackWeightsForConv::PackWeightsForConv(): No fallback available for aarch64"); #else @@ -98,7 +98,7 @@ PackWeightsForConv::PackWeightsForConv( template void PackWeightsForConv::unpack(T* origin_buf) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (W_dw_packed_) { W_dw_packed_->unpack(origin_buf); } else diff --git a/src/PackWeightsForDirectConv.cc b/src/PackWeightsForDirectConv.cc index 3be4528642..01fcecc892 100644 --- a/src/PackWeightsForDirectConv.cc +++ b/src/PackWeightsForDirectConv.cc @@ -239,7 +239,7 @@ void fbgemmDirectConv( return; } -#if !defined(FBGEMM_FBCODE) && defined(__aarch64__) +#if defined(__aarch64__) throw std::runtime_error( "fbgemmDirectConv(): No fallback available for aarch64"); #else @@ -459,7 +459,7 @@ void fbgemmDirectConv( } } // else SPATIAL_DIM -#endif // defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#endif // !defined(__aarch64__) } #define INSTANTIATE_REQUANTIZE_SPATIAL_DIM( \ diff --git a/src/QuantUtilsNeon.cc b/src/QuantUtilsNeon.cc index a8835f0e05..dfb27fe8f8 100644 --- a/src/QuantUtilsNeon.cc +++ b/src/QuantUtilsNeon.cc @@ -6,13 +6,15 @@ * LICENSE file in the root directory of this source tree. */ -#include "fbgemm/Utils.h" +#if defined(__aarch64__) -#if HAVE_SVE +#include "fbgemm/Utils.h" #define FBGEMM_EXPORTS #include // @manual +#if HAVE_SVE #include // @manual +#endif #include // @manual #include //for std::min/std::max @@ -31,6 +33,50 @@ using namespace std; //////////////////////////////////////////////////////////////////////////////// // Utility functions +void FindMinMax(const float* m, float* min, float* max, int64_t len) { + if (__builtin_expect(len <= 0, 0)) { + *min = 0.0f; + *max = 0.0f; + return; + } + + float first = *m; + + float32x4_t temp_min_0 = vdupq_n_f32(first); + float32x4_t temp_min_1 = vdupq_n_f32(first); + float32x4_t temp_max_0 = vdupq_n_f32(first); + float32x4_t temp_max_1 = vdupq_n_f32(first); + uint64_t i = 0; + uint64_t count = static_cast(len); + uint64_t loopBound = count - (count % 8); + + for (; i < loopBound; i += 8) { + float32x4_t v0 = vld1q_f32(m + i); + float32x4_t v1 = vld1q_f32(m + i + 4); + temp_min_0 = vminq_f32(temp_min_0, v0); + temp_min_1 = vminq_f32(temp_min_1, v1); + temp_max_0 = vmaxq_f32(temp_max_0, v0); + temp_max_1 = vmaxq_f32(temp_max_1, v1); + } + + temp_min_0 = vminq_f32(temp_min_0, temp_min_1); + temp_max_0 = vmaxq_f32(temp_max_0, temp_max_1); + + float tmp_min_s = vminvq_f32(temp_min_0); + float tmp_max_s = vmaxvq_f32(temp_max_0); + + for (; i < count; i++) { + float tmp = *m; + tmp_min_s = std::min(tmp_min_s, tmp); + tmp_max_s = std::max(tmp_max_s, tmp); + } + + *min = tmp_min_s; + *max = tmp_max_s; +} + +#if HAVE_SVE + template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, @@ -141,6 +187,8 @@ INSTANTIATE_QuantizationNeonFunctions8Bits(float16) // clang-format on #undef INSTANTIATE_QuantizationNeonFunctions8Bits +#endif // HAVE_SVE + } // namespace fbgemm #endif // __aarch64__ diff --git a/src/TransposeUtils.cc b/src/TransposeUtils.cc index aecec554da..cb1cb58d5a 100644 --- a/src/TransposeUtils.cc +++ b/src/TransposeUtils.cc @@ -57,14 +57,11 @@ void transpose_simd( #else static const auto iset = fbgemmInstructionSet(); // Run time CPU detection -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) if (isZmm(iset)) { internal::transpose_avx512(M, N, src, ld_src, dst, ld_dst); } else if (isYmm(iset)) { internal::transpose_avx2(M, N, src, ld_src, dst, ld_dst); - } else -#endif - { + } else { transpose_ref(M, N, src, ld_src, dst, ld_dst); } From 9d29ec169bee08869dd21f31d25698c50566bacb Mon Sep 17 00:00:00 2001 From: Emma Lin Date: Tue, 28 Oct 2025 20:10:27 -0700 Subject: [PATCH 55/92] add auto feature score collection to EC (#5030) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5030 X-link: https://github.com/meta-pytorch/torchrec/pull/3474 X-link: https://github.com/facebookresearch/FBGEMM/pull/2043 Enable feature score auto collection in ShardedEmbeddingCollection based on static feature to score mapping. If user needs custom score for specific id, they can disable auto collection and then change model code explicitly to collect score for each id. Here is the sample eviction policy config in embedding_table config to enable auto score collection: virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( training_id_eviction_trigger_count=260_000_000, # 260M training_id_keep_count=160_000_000, # 160M enable_auto_feature_score_collection=True, feature_score_mapping={ "sparse_public_original_content_creator": 1.0, }, feature_score_default_value=0.5, ), Additionally the counter collected previously during EC dedup is not used by kvzch backend, so this diff removed that counter and allow KJT to transfer a single float32 weight tensor to backend. This allows feature score collection for EBC since there could have another float weight for EBC pooling already. Reviewed By: RachelZheng, EddyLXJ Differential Revision: D83945722 fbshipit-source-id: 2dc71f6601de055b982f62ca3d73cdbe5fba2dce --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 2 +- .../src/dram_kv_embedding_cache/dram_kv_embedding_cache.h | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index a497cf9a5b..32fb3991f7 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -2089,7 +2089,7 @@ def _prefetch( # noqa C901 torch.tensor( [weights.shape[0]], device="cpu", dtype=torch.long ), - weights.cpu().view(torch.float32).view(-1, 2), + weights.cpu(), ) # Generate row addresses (pointing to either L1 or the current diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 9738b846cc..98f3a44e35 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -770,7 +770,6 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { CHECK_EQ(indices.size(0), engege_rates.size(0)); auto indices_data_ptr = indices.data_ptr(); auto engage_rate_ptr = engege_rates.data_ptr(); - int64_t stride = 2; { auto before_write_lock_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -785,8 +784,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { index_iter++) { const auto& id_index = *index_iter; auto id = int64_t(indices_data_ptr[id_index]); - float engege_rate = - float(engage_rate_ptr[id_index * stride + 0]); + float engege_rate = float(engage_rate_ptr[id_index]); // use mempool weight_type* block = nullptr; auto before_lookup_cache_ts = From e9e5fff5ca6f7fbfb55f89203e395f6dff8332af Mon Sep 17 00:00:00 2001 From: Gantaphon Chalumporn Date: Wed, 29 Oct 2025 09:27:47 -0700 Subject: [PATCH 56/92] Add kineto tracing to bench:jagged_tensor (#5061) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5061 X-link: https://github.com/facebookresearch/FBGEMM/pull/2063 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5039 X-link: https://github.com/facebookresearch/FBGEMM/pull/2048 Add kineto tracing to bench:jagged_tensor - only to keyed-jagged-index-select-dim1 for now Reviewed By: spcyppt Differential Revision: D85169086 fbshipit-source-id: 309a28eb93553196949e98d864ecc4b683b12e1f --- fbgemm_gpu/bench/jagged_tensor_benchmark.py | 61 +++++++++++++++------ 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index bcc3e27488..51375f0a64 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -10,8 +10,11 @@ import functools import logging +import os import random +from contextlib import nullcontext from dataclasses import dataclass +from typing import Callable import click import fbgemm_gpu @@ -542,6 +545,17 @@ def ref( @click.option("--has-weights", is_flag=True, default=False) @click.option("--weight-type", type=str, default="float") @click.option("--use-selected-lengths-sum", is_flag=True, default=False) +@click.option( + "--export-trace", + is_flag=True, + default=False, + help="Enable export of trace for profiling. Default is False.", +) +@click.option( + "--trace-url", + type=str, + default="keyed_jagged_index_select_dim1_{phase}_trace_{ospid}.json", +) def keyed_jagged_index_select_dim1( num_batches: int, max_seq_length: int, @@ -551,6 +565,8 @@ def keyed_jagged_index_select_dim1( has_weights: bool, weight_type: str, use_selected_lengths_sum: bool, + export_trace: bool, + trace_url: str, ) -> None: jagged_tensor_types = { "float": torch.float, @@ -622,20 +638,28 @@ def keyed_jagged_index_select_dim1( if is_float: values.requires_grad = True - time, output = benchmark_torch_function( - torch.ops.fbgemm.keyed_jagged_index_select_dim1, - ( - values, - lengths, - offsets, - indices, - input_batch_size, - weights, - selected_lengths_sum, - ), - iters=1000, - ) - output = output[0] + def _kineto_trace_handler(p: profile, phase: str) -> None: + p.export_chrome_trace(trace_url.format(phase=phase, ospid=os.getpid())) + + # pyre-ignore[3] + def context_factory(on_trace_ready: Callable[[profile], None]): + return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext() + + with context_factory(lambda p: _kineto_trace_handler(p, "fwd")): + time, output = benchmark_torch_function( + torch.ops.fbgemm.keyed_jagged_index_select_dim1, + ( + values, + lengths, + offsets, + indices, + input_batch_size, + weights, + selected_lengths_sum, + ), + iters=1000, + ) + output = output[0] # Prepare inputs for the reference run ref_inputs = [] @@ -687,9 +711,12 @@ def keyed_jagged_index_select_dim1_ref( return grad = torch.rand_like(output) - time, _ = benchmark_torch_function( - functools.partial(output.backward, retain_graph=True), (grad,), iters=1000 - ) + + with context_factory(lambda p: _kineto_trace_handler(p, "bwd")): + time, _ = benchmark_torch_function( + functools.partial(output.backward, retain_graph=True), (grad,), iters=1000 + ) + time_ref, _ = benchmark_torch_function( functools.partial(output_ref.backward, retain_graph=True), (grad,), iters=1000 ) From 678eaf7eafceaa88b40f3b2fb2bcbebcf7a09590 Mon Sep 17 00:00:00 2001 From: Eddy Li Date: Wed, 29 Oct 2025 10:10:45 -0700 Subject: [PATCH 57/92] Adding python api to support sync trigger evict (#4984) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4984 X-link: https://github.com/facebookresearch/FBGEMM/pull/1997 As title, `has_running_evict` and `trigger_feature_evict` are needed to support sync trigger eviction Reviewed By: kathyxuyy Differential Revision: D83896308 fbshipit-source-id: 2c68a691ff66ca68c225528cdc7a8c7d50aab516 --- .../dram_kv_embedding_cache.h | 18 ++++++++---------- .../dram_kv_embedding_cache_wrapper.h | 8 ++++++++ .../embedding_rocksdb_wrapper.h | 8 ++++++++ .../kv_db_table_batched_embeddings.cpp | 8 ++++++++ .../kv_db_table_batched_embeddings.h | 4 ++++ .../ssd_split_table_batched_embeddings.cpp | 8 ++++++++ 6 files changed, 44 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 98f3a44e35..3f2848d4a3 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -1175,17 +1175,8 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { void compact() override {} - void trigger_feature_evict( - std::optional inplace_update_ts = std::nullopt) { + void trigger_feature_evict() { if (feature_evict_) { - if (inplace_update_ts.has_value() && - feature_evict_config_.value()->trigger_strategy_ == - EvictTriggerStrategy::BY_TIMESTAMP_THRESHOLD) { - auto* tt_evict = dynamic_cast*>( - feature_evict_.get()); - CHECK(tt_evict != nullptr); - tt_evict->set_eviction_timestamp_threshold(inplace_update_ts.value()); - } feature_evict_->trigger_evict(); } } @@ -1269,6 +1260,13 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { } } + bool is_evicting() override { + if (feature_evict_) { + return feature_evict_->is_evicting(); + } + return false; + } + // for inference only, this logs the total hit/miss count // this should be called at the end of full/delta snapshot chunk by chunk // update diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h index 8e70b41b93..11c4e43930 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h @@ -179,6 +179,14 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder { impl_->set_backend_return_whole_row(backend_return_whole_row); } + void trigger_feature_evict() { + impl_->trigger_feature_evict(); + } + + bool is_evicting() { + return impl_->is_evicting(); + } + void set_feature_score_metadata_cuda( at::Tensor indices, at::Tensor count, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h index 8cebdef1eb..4ca404c157 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h @@ -236,6 +236,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { impl_->set_backend_return_whole_row(backend_return_whole_row); } + void trigger_feature_evict() { + impl_->trigger_feature_evict(); + } + + bool is_evicting() { + return impl_->is_evicting(); + } + private: friend class KVTensorWrapper; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index eb95d343e6..e0077058ee 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -383,6 +383,14 @@ void EmbeddingKVDB::set_backend_return_whole_row( return; } +void EmbeddingKVDB::trigger_feature_evict() { + return; +} + +bool EmbeddingKVDB::is_evicting() { + return false; +} + void EmbeddingKVDB::set( const at::Tensor& indices, const at::Tensor& weights, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 94e1a62711..a8082af235 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -301,6 +301,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this { FBEXCEPTION("Not implemented"); } + virtual void trigger_feature_evict(); + + virtual bool is_evicting(); + /** * @brief need to support set backend_return_whole_row from frontend * if one model changed from SSD to DRAM, or vice versa we need to diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 0b95285a8f..64d4dc134c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -880,6 +880,10 @@ static auto embedding_rocks_db_wrapper = { torch::arg("backend_return_whole_row"), }) + .def( + "trigger_feature_evict", + &EmbeddingRocksDBWrapper::trigger_feature_evict) + .def("is_evicting", &EmbeddingRocksDBWrapper::is_evicting) .def("stream_sync_cuda", &EmbeddingRocksDBWrapper::stream_sync_cuda) .def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda) .def("compact", &EmbeddingRocksDBWrapper::compact) @@ -980,6 +984,10 @@ static auto dram_kv_embedding_cache_wrapper = { torch::arg("backend_return_whole_row"), }) + .def( + "trigger_feature_evict", + &DramKVEmbeddingCacheWrapper::trigger_feature_evict) + .def("is_evicting", &DramKVEmbeddingCacheWrapper::is_evicting) .def("set", &DramKVEmbeddingCacheWrapper::set) .def( "set_range_to_storage", From f1eb5b6dcb8cfd6f0cb1572fe3cf42f820ddb5e9 Mon Sep 17 00:00:00 2001 From: Eddy Li Date: Wed, 29 Oct 2025 17:03:30 -0700 Subject: [PATCH 58/92] Adding KVZCHEvictionTBEConfig in FBGEEM (#5058) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5058 X-link: https://github.com/facebookresearch/FBGEMM/pull/2067 See diff D85604160, this KVZCHEvictionTBEConfig is in FBGEMM and used in torchrec. Both FBGEEM and torchrec are open source in github. It is required to land first, otherwise torchrec github build will throw error {F1983027645} Reviewed By: emlin Differential Revision: D83896528 fbshipit-source-id: 7a8bacc3d0ee1f53a797dac6ba2647d372a15074 --- .../split_table_batched_embeddings_ops_common.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 4d55ed2738..bd43100cb0 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -240,6 +240,19 @@ def validate(self) -> None: ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled" +class KVZCHEvictionTBEConfig(NamedTuple): + # Eviction trigger model for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem + kvzch_eviction_trigger_mode: int = 2 # mem_util + # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode. + eviction_free_mem_threshold_gb: int = 200 # 200GB + # Number of batches between checks for free memory threshold when using free_mem trigger mode. + eviction_free_mem_check_interval_batch: int = 1000 + # The width of each feature score bucket used for threshold calculation in feature score-based eviction. + threshold_calculation_bucket_stride: float = 0.2 + # Total number of feature score buckets used for threshold calculation in feature score-based eviction. + threshold_calculation_bucket_num: Optional[int] = 1000000 # 1M + + class BackendType(enum.IntEnum): SSD = 0 DRAM = 1 From 40a39cd6f9acafc8fa64da217f6af3f1b84f8bc4 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 30 Oct 2025 06:07:33 -0700 Subject: [PATCH 59/92] remove pt2 compliant xfails for jagged ops (#5068) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2075 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5068 letting CI tell me what tests to run to fix these ops for pt2 Reviewed By: ezyang Differential Revision: D85630006 fbshipit-source-id: dbf3795f7b1ed847882e17f3d24566e8339729fc --- .../jagged_tensor_ops/jagged_tensor_ops.cu | 5 --- .../jagged_tensor_ops_autograd.cpp | 42 +++++++++++++++++- .../jagged_tensor_ops_cpu.cpp | 4 +- .../jagged_tensor_ops_meta.cpp | 9 ++-- fbgemm_gpu/test/jagged/common.py | 11 +---- .../test/jagged/dense_to_jagged_test.py | 5 +++ .../jagged/jagged_index_select_2d_test.py | 20 +++++++++ .../jagged/jagged_to_padded_dense_test.py | 44 +++++++++++++++++++ 8 files changed, 117 insertions(+), 23 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu index dbccc6fdfd..bb6e3e2b96 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu @@ -8,11 +8,6 @@ #include "common.cuh" -FBGEMM_OP_DISPATCH(CUDA, "dense_to_jagged", fbgemm_gpu::dense_to_jagged); -FBGEMM_OP_DISPATCH( - CUDA, - "jagged_to_padded_dense", - fbgemm_gpu::jagged_to_padded_dense); FBGEMM_OP_DISPATCH( CUDA, "jagged_dense_elementwise_add", diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp index 1a20f680b8..ac14fdd975 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp @@ -48,6 +48,8 @@ class JaggedToPaddedDenseOp const std::vector& offsets, at::ArrayRef max_lengths, const double padding_value)>(); + + at::AutoDispatchBelowAutograd mode; Tensor padded_values = op.call(values, offsets, max_lengths, padding_value); return {padded_values}; @@ -286,6 +288,7 @@ class DenseToJaggedOp : public torch::autograd::Function { const Tensor& dense, const std::vector& offsets, std::optional total_L)>(); + at::AutoDispatchBelowAutograd mode; auto output = op.call(dense, offsets, total_L); return {output}; @@ -785,7 +788,7 @@ class JaggedSliceOp : public torch::autograd::Function { } // namespace ///@ingroup jagged-tensor-ops-cpu -Tensor jagged_to_padded_dense( +Tensor jagged_to_padded_dense_forward_autograd( const Tensor& values, const std::vector& offsets, const c10::SymIntArrayRef max_lengths, @@ -793,6 +796,22 @@ Tensor jagged_to_padded_dense( return JaggedToPaddedDenseOp::apply( values, offsets, max_lengths, padding_value)[0]; } +Tensor jagged_to_padded_dense( + const Tensor& values, + const std::vector& offsets, + const c10::SymIntArrayRef max_lengths, + const double padding_value) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "") + .typed& offsets, + at::ArrayRef max_lengths, + const double padding_value)>(); + Tensor output = op.call(values, offsets, max_lengths, padding_value); + return output; +} ///@ingroup jagged-tensor-ops-cpu /// Output = x + y where x is jagged, y and output are dense @@ -855,7 +874,20 @@ std::tuple> dense_to_jagged( const Tensor& dense, const std::vector& offsets, std::optional total_L) { - return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets}; + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "") + .typed& offsets, + std::optional total_L)>(); + auto output = op.call(dense, offsets, total_L); + return {output, offsets}; +} +Tensor dense_to_jagged_forward_autograd( + const Tensor& dense, + const std::vector& offsets, + std::optional total_L) { + return DenseToJaggedOp::apply(dense, offsets, total_L)[0]; } ///@ingroup jagged-tensor-ops-cpu @@ -973,6 +1005,12 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) { m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm)); m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm)); m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice)); + m.impl( + "jagged_to_padded_dense_forward", + TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_forward_autograd)); + m.impl( + "dense_to_jagged_forward", + TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_autograd)); } TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) { diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp index eb047b882e..c5512509ff 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp @@ -1818,13 +1818,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { DISPATCH_TO_CPU("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense); DISPATCH_TO_CPU("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense); - DISPATCH_TO_CPU("dense_to_jagged", fbgemm_gpu::dense_to_jagged); DISPATCH_TO_CPU( "dense_to_jagged_forward", fbgemm_gpu::dense_to_jagged_forward); - DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense); DISPATCH_TO_CPU( "jagged_to_padded_dense_forward", - fbgemm_gpu::jagged_to_padded_dense_forward); + fbgemm_gpu::jagged_to_padded_dense_forward_cpu); DISPATCH_TO_CPU( "jagged_to_padded_dense_backward", fbgemm_gpu::jagged_to_padded_dense_backward); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp index 43cbb1c9bf..87c2ad23f0 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp @@ -53,18 +53,21 @@ Tensor jagged_to_padded_dense_meta( Tensor jagged_to_padded_dense_backward_meta( const at::Tensor& grad_output, - const std::vector& /*offsets*/, + const std::vector& offsets, at::SymInt total_L) { const auto& grad_padded_values = grad_output; - at::SymInt D = grad_padded_values.sym_size(-1); + const bool D_folded = grad_padded_values.dim() == offsets.size() + 1; + const auto& grad_padded_values_view = + D_folded ? grad_padded_values.unsqueeze(-1) : grad_padded_values; + at::SymInt D = grad_padded_values_view.sym_size(-1); // Initialize with zeros so output will be zero for the portion truncated // in forward. auto grad_values = at::zeros_symint({std::move(total_L), D}, grad_padded_values.options()); TORCH_CHECK(grad_values.is_meta()); - return grad_values; + return D_folded ? grad_values.squeeze(-1) : grad_values; } Tensor jagged_dense_dense_elementwise_add_jagged_output_forward_meta( diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index d8838b8447..2cdb9078ec 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -10,7 +10,6 @@ import itertools import sys -import unittest from typing import Callable import fbgemm_gpu @@ -43,15 +42,7 @@ # Please avoid putting tests here, you should put operator-specific # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. -additional_decorators: dict[str, list[Callable]] = { - "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ - # This operator has been grandfathered in. We need to fix this test failure. - unittest.expectedFailure, - ], - "test_pt2_compliant_tag_fbgemm_jagged_to_padded_dense": [ - unittest.expectedFailure, - ], -} +additional_decorators: dict[str, list[Callable]] = {} def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor: diff --git a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py index 0e6e08e56a..d03823c364 100644 --- a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py +++ b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py @@ -80,6 +80,11 @@ def _test_dense_to_jagged( jagged_values.backward(ref_output_values) torch.testing.assert_close(dense.grad, ref_values) + torch.library.opcheck( + torch.ops.fbgemm.dense_to_jagged, + (dense.detach().requires_grad_(True), offsets), + ) + @given( num_jagged_dim=st.integers(1, 5), outer_dense_size=st.integers(0, 5), diff --git a/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py b/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py index 7433edbeb3..70b2ef276a 100644 --- a/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py +++ b/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py @@ -158,6 +158,26 @@ def test_jagged_index_select_2d( rtol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None, atol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None, ) + if known_shape: + with torch.no_grad(): + tmp_output, _ = torch.ops.fbgemm.jagged_index_select( + values, lengths, indices + ) + num_dense_output_rows = tmp_output.shape[0] + torch.library.opcheck( + torch.ops.fbgemm.jagged_index_select.default, + ( + values.detach().requires_grad_(), + lengths, + indices, + num_dense_output_rows, + ), + ) + else: + torch.library.opcheck( + torch.ops.fbgemm.jagged_index_select.default, + (values.detach().requires_grad_(), lengths, indices), + ) @given( max_seq_length=st.integers(5, 10), diff --git a/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py b/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py index 1242470d18..24a8567bee 100644 --- a/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py +++ b/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py @@ -113,6 +113,50 @@ def test_jagged_to_padded_dense( rtol=1e-3, ) + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c, d): + return torch.ops.fbgemm.jagged_to_padded_dense(a, b, c, d) + + with torch.inference_mode(): + gm = torch.export.export( + Mod(), + ( + x_values.float().requires_grad_(True), + x_offsets, + max_lengths.astype(int).tolist(), + padding_value, + ), + ).run_decompositions() + num_fw_ops = len( + [ + x + for x in gm.graph.nodes + if x.target is torch.ops.fbgemm.jagged_to_padded_dense_forward.default + ] + ) + num_composite_ops = len( + [ + x + for x in gm.graph.nodes + if x.target is torch.ops.fbgemm.jagged_to_padded_dense.default + ] + ) + self.assertEqual(num_fw_ops, 1) + self.assertEqual(num_composite_ops, 0) + + torch.library.opcheck( + torch.ops.fbgemm.jagged_to_padded_dense, + ( + x_values.float().requires_grad_(True), + x_offsets, + max_lengths, + padding_value, + ), + ) + @given( num_jagged_dim=st.integers(1, 5), outer_dense_size=st.integers(0, 5), From 9eef03126c15517cd4fb1fc06f4af1a29b5f16bc Mon Sep 17 00:00:00 2001 From: Ahmed Shuaibi Date: Thu, 30 Oct 2025 16:11:17 -0700 Subject: [PATCH 60/92] log all table names in TBE Summary: - add set of table names to set to simplify debugging X-link: https://github.com/facebookresearch/FBGEMM/pull/2076 Reviewed By: spcyppt Differential Revision: D85865458 Pulled By: ashuaibi7 fbshipit-source-id: 7dbee3e10f49c158b0c10e8cf0c5926072bf8499 --- .../split_table_batched_embeddings_ops_training.py | 2 +- fbgemm_gpu/test/tbe/utils/split_embeddings_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index a0bc843902..cd923b0e20 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1556,7 +1556,7 @@ def get_table_name_for_logging(table_names: Optional[list[str]]) -> str: table_name_set = set(table_names) if len(table_name_set) == 1: return next(iter(table_name_set)) - return f"<{len(table_name_set)} tables>" + return f"<{len(table_name_set)} tables>: {table_name_set}" @staticmethod def get_prefetch_passes( diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py index 3dd1bc2cd4..b6864a3ac1 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py @@ -178,17 +178,17 @@ def test_get_table_name_for_logging(self) -> None: SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2"] ), - "<2 tables>", + "<2 tables>: {'t1', 't2'}", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2", "t1"] ), - "<2 tables>", + "<2 tables>: {'t1', 't2'}", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging([]), - "<0 tables>", + "<0 tables>: set()", ) @unittest.skipIf(*gpu_unavailable) From c5619f2513f36d3e677cb6ab531883ccc603fe2b Mon Sep 17 00:00:00 2001 From: Tom Lin Date: Thu, 30 Oct 2025 19:37:09 -0700 Subject: [PATCH 61/92] Add sync ops and update the method names to be more generic for future integration + Abstract out APIs from dram_kv into an interface (#5069) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2077 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5069 - Refactor Dram KV inference embedding implementation by abstracting out APIs into an interface - Add sync ops and update the method names to be more generic for future integration Reviewed By: emlin Differential Revision: D85733409 fbshipit-source-id: 7d95914698500e6052ff716eaca8302f896f9455 --- .../dram_kv_embedding_inference_wrapper.cpp | 52 ++++-- .../dram_kv_embedding_inference_wrapper.h | 15 +- .../dram_kv_inference_embedding.h | 52 ++++-- .../kv_inference_embedding_interface.h | 158 ++++++++++++++++++ 4 files changed, 240 insertions(+), 37 deletions(-) create mode 100644 fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp index 8145a42023..6361c4878a 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp @@ -10,11 +10,17 @@ #include #include #include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu:fbgemm_gpu +#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h" DEFINE_int64( dram_kv_embedding_num_shards, 32, "Number of shards for DRAM KV inference embedding"); +DEFINE_bool( + kv_embedding_async_get_set, + true, + "Whether to use async get/set for DRAM KV inference embedding." + "This should be true for dram but might be different for other non-Dram backends."); namespace fbgemm_gpu { @@ -52,10 +58,10 @@ void DramKVEmbeddingInferenceWrapper::init( << ", row_alignment: " << row_alignment << ", scale_bias_size_in_bytes: " << scale_bias_size_in_bytes << ", max_row_bytes_: " << max_row_bytes_; - if (dram_kv_ != nullptr) { + if (kv_backend_ != nullptr) { return; } - dram_kv_ = std::make_shared>( + kv_backend_ = std::make_shared>( max_row_bytes_, uniform_init_lower_, uniform_init_upper_, @@ -86,14 +92,19 @@ void DramKVEmbeddingInferenceWrapper::init( disable_random_init_); } -std::shared_ptr> -DramKVEmbeddingInferenceWrapper::get_dram_kv() { - return dram_kv_; +int64_t DramKVEmbeddingInferenceWrapper::get_max_row_bytes() const { + return max_row_bytes_; } -void DramKVEmbeddingInferenceWrapper::set_dram_kv( - std::shared_ptr> dram_kv) { - dram_kv_ = std::move(dram_kv); +std::shared_ptr> +DramKVEmbeddingInferenceWrapper::get_kv_backend() { + return kv_backend_; +} + +void DramKVEmbeddingInferenceWrapper::set_kv_backend( + std::shared_ptr> + kv_backend) { + kv_backend_ = std::move(kv_backend); } void DramKVEmbeddingInferenceWrapper::set_embeddings( @@ -106,8 +117,13 @@ void DramKVEmbeddingInferenceWrapper::set_embeddings( inplacee_update_ts = static_cast(inplace_update_ts_opt.value()); } - folly::coro::blockingWait(dram_kv_->inference_set_kv_db_async( - indices, weights, count, inplacee_update_ts)); + + if (FLAGS_kv_embedding_async_get_set) { + folly::coro::blockingWait(kv_backend_->inference_set_kv_db_async( + indices, weights, count, inplacee_update_ts)); + } else { + kv_backend_->set_kv_db_sync(indices, weights, count, inplacee_update_ts); + } } at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( @@ -119,24 +135,30 @@ at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( max_row_bytes_, }, at::kByte); - folly::coro::blockingWait(dram_kv_->get_kv_db_async(indices, weights, count)); + + if (FLAGS_kv_embedding_async_get_set) { + folly::coro::blockingWait( + kv_backend_->get_kv_db_async(indices, weights, count)); + } else { + kv_backend_->get_kv_db_sync(indices, weights, count); + } return weights; } void DramKVEmbeddingInferenceWrapper::log_inplace_update_stats() { - dram_kv_->log_inplace_update_stats(); + kv_backend_->log_inplace_update_stats(); } void DramKVEmbeddingInferenceWrapper::trigger_evict( int64_t inplace_update_ts_64b) { uint32_t inplace_update_ts_32b = static_cast(inplace_update_ts_64b); - dram_kv_->trigger_feature_evict(inplace_update_ts_32b); - dram_kv_->resume_ongoing_eviction(); + kv_backend_->trigger_feature_evict(inplace_update_ts_32b); + kv_backend_->resume_ongoing_eviction(); } void DramKVEmbeddingInferenceWrapper::wait_evict_completion() { - dram_kv_->wait_until_eviction_done(); + kv_backend_->wait_until_eviction_done(); } c10::List DramKVEmbeddingInferenceWrapper::serialize() const { diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h index 7af9a83d74..1c0af807e9 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h @@ -10,9 +10,10 @@ #include #include -#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h" +#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h" DECLARE_int64(dram_kv_embedding_num_shards); +DECLARE_bool(kv_embedding_async_get_set); namespace fbgemm_gpu { @@ -46,22 +47,26 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder { void wait_evict_completion(); - std::shared_ptr> get_dram_kv(); + std::shared_ptr> + get_kv_backend(); - void set_dram_kv( - std::shared_ptr> dram_kv); + void set_kv_backend( + std::shared_ptr> + kv_backend); c10::List serialize() const; void deserialize(const c10::List& states); + int64_t get_max_row_bytes() const; + private: int64_t num_shards_ = 32; double uniform_init_lower_ = 0.0; double uniform_init_upper_ = 0.0; bool disable_random_init_ = false; - std::shared_ptr> dram_kv_; + std::shared_ptr> kv_backend_; int64_t max_row_bytes_ = 0; }; diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h index e40d0ffd5f..57c1f28160 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h @@ -30,6 +30,7 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" #include "feature_evict.h" #include "fixed_block_pool.h" +#include "kv_inference_embedding_interface.h" namespace kv_mem { @@ -64,7 +65,8 @@ namespace kv_mem { /// @brief An implementation of EmbeddingKVDB for ZCH v.Next /// template -class DramKVInferenceEmbedding { +class DramKVInferenceEmbedding + : public KVInferenceEmbeddingInterface { public: /// DramKVInferenceEmbedding constructor /// @@ -163,7 +165,7 @@ class DramKVInferenceEmbedding { double uniform_init_lower, double uniform_init_upper, int64_t row_storage_bitwidth, - bool disable_random_init) { + bool disable_random_init) override { for (auto i = 0; i < num_shards; ++i) { auto* gen = at::check_generator( at::detail::getDefaultCPUGenerator()); @@ -181,11 +183,26 @@ class DramKVInferenceEmbedding { disable_random_init_ = disable_random_init; } + void set_kv_db_sync( + const at::Tensor& /*indices*/, + const at::Tensor& /*weights*/, + const at::Tensor& /*count*/, + std::optional /*inplace_update_ts*/) override { + throw std::runtime_error("set_kv_db_sync is not implemented for DRAM"); + } + + void get_kv_db_sync( + const at::Tensor& /*indices*/, + const at::Tensor& /*weights*/, + const at::Tensor& /*count*/) override { + throw std::runtime_error("get_kv_db_sync is not implemented for DRAM"); + } + folly::SemiFuture> inference_set_kv_db_async( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - std::optional inplace_update_ts) { + std::optional inplace_update_ts) override { std::vector>> futures; auto shardid_to_indexes = shard_input(indices, count); @@ -552,15 +569,15 @@ class DramKVInferenceEmbedding { folly::SemiFuture> get_kv_db_async( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) { + const at::Tensor& count) override { current_iter_++; return get_kv_db_async_impl(indices, weights, count); } - void compact() {} + void compact() override {} void trigger_feature_evict( - std::optional inplace_update_ts = std::nullopt) { + std::optional inplace_update_ts = std::nullopt) override { if (feature_evict_) { if (inplace_update_ts.has_value() && feature_evict_config_.value()->trigger_strategy_ == @@ -574,7 +591,7 @@ class DramKVInferenceEmbedding { } } - void maybe_evict() { + void maybe_evict() override { if (!feature_evict_config_.has_value()) { return; } @@ -603,25 +620,25 @@ class DramKVInferenceEmbedding { } // wait until eviction finishes, if any - void wait_until_eviction_done() { + void wait_until_eviction_done() override { if (feature_evict_) { feature_evict_->wait_until_eviction_done(); } } - size_t get_map_used_memsize_in_bytes() const { + size_t get_map_used_memsize_in_bytes() const override { return kv_store_.getUsedMemSizeInBytes(); } - size_t get_map_actual_used_chunk_in_bytes() const { + size_t get_map_actual_used_chunk_in_bytes() const override { return kv_store_.getActualUsedChunkInBytes(); } - size_t get_num_rows() const { + size_t get_num_rows() const override { return kv_store_.getNumRows(); } - void resume_ongoing_eviction(bool force_resume = false) { + void resume_ongoing_eviction(bool force_resume = false) override { if (!force_resume) { return; } @@ -630,7 +647,7 @@ class DramKVInferenceEmbedding { } } - void pause_ongoing_eviction(bool force_pause = false) { + void pause_ongoing_eviction(bool force_pause = false) override { if (!force_pause) { return; } @@ -648,7 +665,7 @@ class DramKVInferenceEmbedding { // for inference only, this logs the total hit/miss count // this should be called at the end of full/delta snapshot chunk by chunk // update - void log_inplace_update_stats() { + void log_inplace_update_stats() override { int reset_val = 0; auto inplace_update_hit_cnt = inplace_update_hit_cnt_.exchange(reset_val); @@ -661,7 +678,8 @@ class DramKVInferenceEmbedding { << (total_cnt > 0 ? (double)inplace_update_hit_cnt / total_cnt : 0.0); } - std::optional get_feature_evict_metric() const { + std::optional get_feature_evict_metric() + const override { if (!feature_evict_config_.has_value()) { return std::nullopt; } @@ -789,11 +807,11 @@ class DramKVInferenceEmbedding { return shardid_to_indexes; } - void flush_or_compact(const int64_t timestep) {} + void flush_or_compact(const int64_t timestep) override {} std::vector get_dram_kv_perf( const int64_t step, - const int64_t interval) { + const int64_t interval) override { std::vector ret(23, 0); // num metrics if (step > 0 && step % interval == 0) { int reset_val = 0; diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h b/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h new file mode 100644 index 0000000000..0f9090221b --- /dev/null +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include + +#include "feature_evict.h" + +namespace kv_mem { + +/// @ingroup KVMemEmbedding +/// +/// @brief Interface for KV Inference Embedding implementations +/// +/// This interface defines the core API that all KV embedding implementations +/// must provide, enabling different backend implementations (DRAM, SSD, etc.) +/// to be used interchangeably. +/// +template +class KVInferenceEmbeddingInterface { + public: + virtual ~KVInferenceEmbeddingInterface() = default; + + /// Initialize the initializers for weight initialization + /// + /// @param num_shards number of shards for the kvstore + /// @param max_D the maximum dimension of embedding tensor + /// @param uniform_init_lower the lower bound of the uniform distribution + /// @param uniform_init_upper the upper bound of the uniform distribution + /// @param row_storage_bitwidth storage bitwidth for each row + /// @param disable_random_init whether to disable random initialization + virtual void initialize_initializers( + int64_t num_shards, + int64_t max_D, + double uniform_init_lower, + double uniform_init_upper, + int64_t row_storage_bitwidth, + bool disable_random_init) = 0; + + /// Set embeddings in the KV store (sync version) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor containing embeddings + /// @param count A single element tensor with number of indices to process + /// @param inplace_update_ts Optional timestamp for inplace update + virtual void set_kv_db_sync( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count, + std::optional inplace_update_ts) = 0; + + /// Get embeddings from KV store (sync version) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor to be filled with embeddings + /// @param count A single element tensor with number of indices to process + virtual void get_kv_db_sync( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count) = 0; + + /// Set embeddings in the KV store (async inference version) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor containing embeddings + /// @param count A single element tensor with number of indices to process + /// @param inplace_update_ts Optional timestamp for inplace update + /// @return SemiFuture for async completion + virtual folly::SemiFuture> inference_set_kv_db_async( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count, + std::optional inplace_update_ts) = 0; + + /// Get embeddings from KV store (async) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor to be filled with embeddings + /// @param count A single element tensor with number of indices to process + /// @return SemiFuture for async completion + virtual folly::SemiFuture> get_kv_db_async( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count) = 0; + + /// Compact the KV store (placeholder for future implementations) + virtual void compact() = 0; + + /// Trigger feature eviction + /// + /// @param inplace_update_ts Optional timestamp for eviction threshold + virtual void trigger_feature_evict( + std::optional inplace_update_ts = std::nullopt) = 0; + + /// Maybe trigger eviction based on configured trigger mode + virtual void maybe_evict() = 0; + + /// Wait until ongoing eviction completes + virtual void wait_until_eviction_done() = 0; + + /// Get the total memory used by the KV store + /// + /// @return Memory size in bytes + virtual size_t get_map_used_memsize_in_bytes() const = 0; + + /// Get the actual memory used by allocated chunks + /// + /// @return Memory size in bytes + virtual size_t get_map_actual_used_chunk_in_bytes() const = 0; + + /// Get the number of rows in the KV store + /// + /// @return Number of rows + virtual size_t get_num_rows() const = 0; + + /// Resume ongoing eviction + /// + /// @param force_resume Force resume even if not paused + virtual void resume_ongoing_eviction(bool force_resume = false) = 0; + + /// Pause ongoing eviction + /// + /// @param force_pause Force pause even if not running + virtual void pause_ongoing_eviction(bool force_pause = false) = 0; + + /// Log statistics for inplace update (inference only) + virtual void log_inplace_update_stats() = 0; + + /// Get feature eviction metrics + /// + /// @return Optional metrics tensors + virtual std::optional get_feature_evict_metric() + const = 0; + + /// Get performance metrics + /// + /// @param step Current step/iteration + /// @param interval Reporting interval + /// @return Vector of performance metrics + virtual std::vector get_dram_kv_perf( + const int64_t step, + const int64_t interval) = 0; + + /// Flush or compact at a specific timestep + /// + /// @param timestep The timestep for flush/compact + virtual void flush_or_compact(const int64_t timestep) = 0; +}; + +} // namespace kv_mem From 962f0137f2c82994a5764abb956d79927917e5d8 Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Thu, 30 Oct 2025 21:51:35 -0700 Subject: [PATCH 62/92] Cutlass Qtile Size shrunk to 64 (#5072) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5072 X-link: https://github.com/facebookresearch/FBGEMM/pull/2078 Changing the QtileSize to 64. I see good improvement > 20 %.. For correctness this includes changing the TMEM atoms and introducing warp sync for row stats. Perf: ``` (Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD) cutlass_blackwell_fmha_decode-gbps Improvment with Qtile = 64 (16, 1, 256, 256, 8, 1, 128) 238.2206209 1.31463193 (16, 1, 512, 512, 8, 1, 128) 410.8838061 1.315872068 (16, 1, 1024, 1024, 8, 1, 128) 660.5696208 1.335567769 (16, 1, 2048, 2048, 8, 1, 128) 916.5460174 1.310093116 (16, 1, 4096, 4096, 8, 1, 128) 1133.690174 1.258896694 (16, 1, 8192, 8192, 8, 1, 128) 1271.341515 1.229311967 (32, 1, 256, 256, 8, 1, 128) 468.9034945 1.295635241 (32, 1, 512, 512, 8, 1, 128) 799.2689835 1.280831124 (32, 1, 1024, 1024, 8, 1, 128) 1285.452285 1.293538886 (32, 1, 2048, 2048, 8, 1, 128) 1797.074701 1.269787171 (32, 1, 4096, 4096, 8, 1, 128) 2210.946865 1.229703361 (32, 1, 8192, 8192, 8, 1, 128) 2498.665399 1.212166122 (64, 1, 256, 256, 8, 1, 128) 893.9747894 1.302172409 (64, 1, 512, 512, 8, 1, 128) 1493.150844 1.274679551 (64, 1, 1024, 1024, 8, 1, 128) 2309.825211 1.220419935 (64, 1, 2048, 2048, 8, 1, 128) 3012.271892 1.159444905 (64, 1, 4096, 4096, 8, 1, 128) 3552.001019 1.089389445 (64, 1, 8192, 8192, 8, 1, 128) 4348.016208 1.131298153 (128, 1, 256, 256, 8, 1, 128) 1549.388365 1.233405251 (128, 1, 512, 512, 8, 1, 128) 2480.52007 1.210676964 (128, 1, 1024, 1024, 8, 1, 128) 3360.125922 1.145674899 (128, 1, 2048, 2048, 8, 1, 128) 4103.461192 1.093136854 (128, 1, 4096, 4096, 8, 1, 128) 4783.429328 1.095583284 ``` Reviewed By: jianyuh, v0i0 Differential Revision: D85155388 fbshipit-source-id: ec3e43e2c7b0ce68c8eebc3fac74db6c9b66de07 --- .../blackwell_gen_impl.cu | 2 +- ...m100_fmha_gen_mainloop_warpspecialized.hpp | 111 ++++++++++-------- .../test/attention/blackwell_fmha_test.py | 12 +- 3 files changed, 76 insertions(+), 49 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 5b618b6526..5227510217 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -301,7 +301,7 @@ at::Tensor dispatch_fmha_gen_fwd( return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] { return DISPATCH_KERNEL_TYPE(static_cast(kernel_type), KType, [&] { - GenRunner, Shape<_1, _1, _1>> + GenRunner, Shape<_1, _1, _1>> runner; return runner.fmha_fwd(q, k, v, seqlen_kv, batch_idx); }); diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index e8e9aafceb..1738c121f1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -526,35 +526,48 @@ struct Sm100FmhaGenMainloopWarpspecialized { PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { - - Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + Tensor tScS = + typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); - Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); - tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); - Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); - - auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; - Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); - tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); - Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); - - // local changes - // Each thread owns a single row + Tensor tStS_v = + tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); + tStS_v.data() = + uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = + tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); + + auto tilePlikeFP32 = _32{}; // 32 for FP32 + // size<1>(TileShapeQK{}) / Int{} * Int{}; + + // tilePlikeFP32 = 64/4*2 = 32 for BF16 + // Preserve hierarchical structure: ((16, 4), 32) = 16*4*32 = 2048 elements + Tensor tStS_P = tStS.compose( + make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32))); + tStS_P.data() = warp_uniform( + uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose( + make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32))); + + // Select TMEM operation based on K dimension (number of columns) + // For K=64: 64 rows × 64 cols = 4,096 elements → use 16dp32b4x + // For K=128: 64 rows × 128 cols = 8,192 elements → use 16dp32b8x using TMEM_LOAD = conditional_t< - size<1>(TileShapeQK{}) < _128{}, - SM100_TMEM_LOAD_32dp32b8x, - SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + size<1>(TileShapeQK{}) == _64{}, + SM100_TMEM_LOAD_16dp32b16x, // For K=64: 4,096 elements + SM100_TMEM_LOAD_16dp32b8x>; // For K=128: 8,192 elements + using TMEM_STORE = conditional_t< - size<1>(TileShapeQK{}) < _128{}, - SM100_TMEM_STORE_32dp32b16x, - SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem - using TMEM_STORE_V = - SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + size<1>(TileShapeQK{}) == _64{}, + SM100_TMEM_STORE_16dp32b8x, // For K=64, BF16: 2,048 elements + SM100_TMEM_STORE_16dp32b8x>; - int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + // TMEM_STORE_V: Store row statistics (old_max, new_max) for online softmax + // correction Always 64 rows × 2 cols = 128 FP32 elements + using TMEM_STORE_V = SM100_TMEM_STORE_16dp32b2x; auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); @@ -604,7 +617,8 @@ struct Sm100FmhaGenMainloopWarpspecialized { row_max = ::fmax(row_max, row_max_2); row_max = ::fmax(row_max, row_max_3); } - + ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16); + row_max = ::fmax(row_max, shuffled_row_max); ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); @@ -661,18 +675,12 @@ struct Sm100FmhaGenMainloopWarpspecialized { order_s.arrive(); } - // this prevents register spills in fp16 - if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { - if (i == size(tTMEM_LOADrS) - 6) { - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); - } - } } // tmem_store(reg_S8) -> op_P - CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); - CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + // CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + // CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4); cutlass::arch::fence_view_async_tmem_store(); @@ -716,6 +724,8 @@ struct Sm100FmhaGenMainloopWarpspecialized { if (final_call) { // re-acquire the S part in the final step pipeline_s.consumer_wait(pipeline_s_consumer_state); + // Sync threads 0 and 16 to get the sum of row_sum between them + row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16); Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; @@ -802,18 +812,24 @@ struct Sm100FmhaGenMainloopWarpspecialized { // As opposed to the softmax, we do not have enough registers here // to load all of the values (for tile kv = 128), so we loop // good values would be either 32 or 64 - const int kCorrectionTileSize = 32 / sizeof(ElementOut); + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + // TODO: load all values - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOgO = mma.get_slice(0).partition_C(gO); - - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO_i = tOtO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); + Tensor tOgO_i = tOgO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); Tensor tOtO0 = tOtO_i; tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); @@ -901,16 +917,18 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32; - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem - using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = SM100_TMEM_LOAD_16dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_16dp32b16x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO_i = tOtO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; @@ -992,10 +1010,11 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); - Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tStS_v = tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); - using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + using TMEM_LOAD_V = + SM100_TMEM_LOAD_16dp32b2x; // 4x32 threads with 2 cols of 32b elem auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 3ce07debff..42057aeafe 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -687,13 +687,13 @@ def _execute_cutlass_blackwell_attn_varlen( sm_scale, num_groups, ) - for dtype in [torch.bfloat16, torch.float8_e4m3fn] + for dtype in [torch.bfloat16] for seqlen_k in [64, 128, 256, 1024] for batch_size in [1, 2] for is_mqa in [True, False] for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)] for head_dim in [128] - for sm_scale in [None, 1.0 / head_dim] + for sm_scale in [None] for num_groups in [1, 2] ] ) @@ -711,6 +711,14 @@ def test_decode( ) -> None: seqlen_q = 1 causal = True + if True: + print( + f"Running test_decode with params: " + f"dtype={dtype}, seqlen_k={seqlen_k}, batch_size={batch_size}, " + f"is_mqa={is_mqa}, window_size={window_size}, head_dim={head_dim}, " + f"sm_scale={sm_scale}, q_heads={q_heads}" + ) + self._execute_cutlass_blackwell_attn_dense( batch_size, seqlen_q, From 8e60e437df6d544b3968e9105a3d65624f921153 Mon Sep 17 00:00:00 2001 From: Gefei Zuo Date: Fri, 31 Oct 2025 12:38:47 -0700 Subject: [PATCH 63/92] Mapping utilities (#5073) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5073 X-link: https://github.com/facebookresearch/FBGEMM/pull/2079 Compile-time static/const mapping utilities for: 1. constexpr value -> constexpr value 2. constexpr value -> type Useful when developing template-heavy cutlass code. Reviewed By: jianyuh Differential Revision: D85893168 fbshipit-source-id: 691dbb90e17c88dfc384432908e8ffdb8c0b2a04 --- .../collective/fmha_common.hpp | 89 +++++++++++++++++-- 1 file changed, 80 insertions(+), 9 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp index 2d3e2b166d..1e0ea6d449 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp @@ -78,10 +78,10 @@ to_tiled_mma_sm100_ts( TiledMMA, cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant>, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, TAs...>, TMs...>) { return TiledMMA, + a_major, + b_major, + a_neg, + b_neg>, TAs...>, TMs...>) { return TiledMMA +struct kValTyPair { + static constexpr auto key = keyVal; + using valueT = _valueT; +}; + +template +struct kValTyMap { + template + using query = std::conditional_t< + QueryKey == FirstMapping::key, + typename FirstMapping::valueT, + typename kValTyMap::template query>; +}; + +template +struct kValTyMap { + template + using query = std::conditional_t< + QueryKey == LastMapping::key, + typename LastMapping::valueT, + Default>; +}; + +} // namespace constexpr_type_map + +namespace constexpr_constexpr_map { + +template +struct kValValPair { + static constexpr auto key = keyVal; + static constexpr auto value = valueVal; +}; + +template +struct kValValMap { + using ValType = std::add_const_t; + static_assert( + std::is_same_v, + "Map value type mismatch"); + static_assert( + (std::is_same_v && ...), + "Map value type mismatch"); + template + static constexpr decltype(FirstMapping::value) query = + (QueryKey == FirstMapping::key) + ? FirstMapping::value + : kValValMap::template query; +}; + +template +struct kValValMap { + using ValType = std::add_const_t; + static_assert( + std::is_same_v, + "Map value type mismatch"); + template + static constexpr decltype(LastMapping::value) query = + (QueryKey == LastMapping::key) ? LastMapping::value : Default; +}; + +} // namespace constexpr_constexpr_map From 7d494da6eccee819c8903fd6d35f087523b7dd94 Mon Sep 17 00:00:00 2001 From: Nicolas De Carli Date: Fri, 31 Oct 2025 13:20:54 -0700 Subject: [PATCH 64/92] Fix build break (#5076) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2081 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5076 D85603930 removed AVX from aarch64 compilation, and broke Sigrid build. Proposed changes fix the build break. There is an fbgemm routine without ref implementation, so we need to implement a NEON port at a later diff. For now, four AVX2 files are compiled with the NEON package Reviewed By: YifanYuan3 Differential Revision: D85918535 fbshipit-source-id: 3a9892535046a0edc05c8c1fccbb0dac8ca8de35 --- include/fbgemm/Fbgemm.h | 4 ++-- include/fbgemm/FbgemmI8DepthwiseAvx2.h | 4 ---- include/fbgemm/QuantUtilsAvx2.h | 4 ++++ src/PackWeightsForConv.cc | 6 +++--- src/PackWeightsForDirectConv.cc | 2 +- src/QuantUtilsAvx2.cc | 4 ++++ 6 files changed, 14 insertions(+), 10 deletions(-) diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 4d1d2959ef..bc784f8035 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -616,7 +616,7 @@ class FBGEMM_API PackWeightsForConv { return W_im2col_packed_; } -#if !defined(__aarch64__) +#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) std::shared_ptr getPackedWForDepthwise() { return W_dw_packed_; } @@ -672,7 +672,7 @@ class FBGEMM_API PackWeightsForConv { const conv_param_t conv_param_; // Packed weights if we use im2col based convolution implementation std::shared_ptr> W_im2col_packed_; -#if !defined(__aarch64__) +#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) // Packed weights if we use depthwise convolution implementation std::shared_ptr W_dw_packed_; #endif // __aarch64__ diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h index 4533902234..7aadb91290 100644 --- a/include/fbgemm/FbgemmI8DepthwiseAvx2.h +++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h @@ -8,8 +8,6 @@ #pragma once -#if !defined(__aarch64__) - #include #include "fbgemm/ConvUtils.h" #include "fbgemm/FbgemmBuild.h" @@ -112,5 +110,3 @@ FBGEMM_API void depthwise_3d_same_pad( int num_threads = 1); } // namespace fbgemm - -#endif // !defined(__aarch64__) diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 6a0d85deb3..7c6fe24396 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -96,6 +96,8 @@ void RequantizeAvx2( int len, const RequantizationParams& params); +#endif // !defined(__aarch64__) + /// @ingroup fbgemm-quant-utils-avx2 /// /// Requantize with avx2 and bias is fused. @@ -145,6 +147,8 @@ FBGEMM_API void requantizeForFloatAvx2( int ld_in, const requantizationForFloatParams_t& r); +#if !defined(__aarch64__) + template void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2( const InputType* input, diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 7008da3e8f..8870e6a903 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -25,7 +25,7 @@ PackWeightsForConv::PackWeightsForConv( // FbgemmConv.cc switch (ConvFastPath(conv_p)) { case optimized_conv_t::depthwise: { -#if defined(__aarch64__) +#if !defined(FBGEMM_FBCODE) && defined(__aarch64__) throw std::runtime_error( "PackWeightsForConv::PackWeightsForConv(): No fallback available for aarch64"); #else @@ -61,7 +61,7 @@ PackWeightsForConv::PackWeightsForConv( break; } case optimized_conv_t::directconv: { -#if defined(__aarch64__) +#if !defined(FBGEMM_FBCODE) && defined(__aarch64__) throw std::runtime_error( "PackWeightsForConv::PackWeightsForConv(): No fallback available for aarch64"); #else @@ -98,7 +98,7 @@ PackWeightsForConv::PackWeightsForConv( template void PackWeightsForConv::unpack(T* origin_buf) { -#if !defined(__aarch64__) +#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) if (W_dw_packed_) { W_dw_packed_->unpack(origin_buf); } else diff --git a/src/PackWeightsForDirectConv.cc b/src/PackWeightsForDirectConv.cc index 01fcecc892..db33d43d65 100644 --- a/src/PackWeightsForDirectConv.cc +++ b/src/PackWeightsForDirectConv.cc @@ -239,7 +239,7 @@ void fbgemmDirectConv( return; } -#if defined(__aarch64__) +#if !defined(FBGEMM_FBCODE) && defined(__aarch64__) throw std::runtime_error( "fbgemmDirectConv(): No fallback available for aarch64"); #else diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index ab6274d571..89deb44d39 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -280,6 +280,8 @@ SPECIALIZE_FUSEDDQAVX2(int8_t) #undef SPECIALIZE_FUSEDDQAVX2 +#ifndef __aarch64__ + void FindMinMax(const float* m, float* min, float* max, int64_t len) { if (len <= 0) { *min = 0.0f; @@ -317,6 +319,8 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) { *max = temp_max; } +#endif + //////////////////////////////////////////////////////////////////////////////// // Requantization (with floats) From 9db5454d0d6f0e3680661cb404c427703b593f3b Mon Sep 17 00:00:00 2001 From: Eddy Li Date: Fri, 31 Oct 2025 22:39:34 -0700 Subject: [PATCH 65/92] Free mem trigger with all2all for sync trigger eviction (#5062) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5062 X-link: https://github.com/meta-pytorch/torchrec/pull/3490 X-link: https://github.com/facebookresearch/FBGEMM/pull/2070 Before KVZCH is using ID_COUNT and MEM_UTIL eviction trigger mode, both are very tricky and hard for model engineer to decide what num to use for the id count or mem util threshold. Besides that, the eviction start time is out of sync after some time in training, which can cause great qps drop during eviction. This diff is adding support for free memory trigger eviction. It will check how many free memory left every N batch in every rank and if free memory below the threshold, it will trigger eviction in all tbes of all ranks using all reduce. In this way, we can force the start time of eviction in all ranks. Reviewed By: emlin Differential Revision: D85604160 fbshipit-source-id: 177ec779960a4ac9bfc3d41f38beeb7e56665db8 --- ...lit_table_batched_embeddings_ops_common.py | 25 ++- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 144 ++++++++++++++++-- fbgemm_gpu/requirements.txt | 1 + fbgemm_gpu/requirements_genai.txt | 1 + .../dram_kv_embedding_cache.h | 5 + .../dram_kv_embedding_cache/feature_evict.h | 11 +- 6 files changed, 161 insertions(+), 26 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index bd43100cb0..01832dfbc1 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -86,19 +86,19 @@ class EvictionPolicy(NamedTuple): None # feature_score_counter_decay_rates for each table if eviction strategy is feature score ) training_id_eviction_trigger_count: Optional[list[int]] = ( - None # training_id_eviction_trigger_count for each table + None # Number of training IDs that, when exceeded, will trigger eviction for each table. ) training_id_keep_count: Optional[list[int]] = ( - None # training_id_keep_count for each table + None # Target number of training IDs to retain in each table after eviction. ) l2_weight_thresholds: Optional[list[float]] = ( None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm ) threshold_calculation_bucket_stride: Optional[float] = ( - 0.2 # threshold_calculation_bucket_stride if eviction strategy is feature score + 0.2 # The width of each feature score bucket used for threshold calculation in feature score-based eviction. ) threshold_calculation_bucket_num: Optional[int] = ( - 1000000 # 1M, threshold_calculation_bucket_num if eviction strategy is feature score + 1000000 # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction. ) interval_for_insufficient_eviction_s: int = ( # wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient @@ -114,10 +114,16 @@ class EvictionPolicy(NamedTuple): 24 * 3600 # 1 day, interval for feature statistics decay ) meta_header_lens: Optional[list[int]] = None # metaheader length for each table + eviction_free_mem_threshold_gb: Optional[int] = ( + None # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode. + ) + eviction_free_mem_check_interval_batch: Optional[int] = ( + None # Number of batches between checks for free memory threshold when using free_mem trigger mode. + ) def validate(self) -> None: - assert self.eviction_trigger_mode in [0, 1, 2, 3, 4], ( - "eviction_trigger_mode must be 0, 1, 2, 3 or 4 " + assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], ( + "eviction_trigger_mode must be 0, 1, 2, 3, 4, 5" f"actual {self.eviction_trigger_mode}" ) if self.eviction_trigger_mode == 0: @@ -143,6 +149,13 @@ def validate(self) -> None: assert ( self.training_id_eviction_trigger_count is not None ), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4" + elif self.eviction_trigger_mode == 5: + assert ( + self.eviction_free_mem_threshold_gb is not None + ), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5" + assert ( + self.eviction_free_mem_check_interval_batch is not None + ), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5" if self.eviction_strategy == 0: assert self.ttls_in_mins is not None, ( diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 32fb3991f7..59ea7f3b70 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -18,8 +18,9 @@ import time from functools import cached_property from math import floor, log2 -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, ClassVar, Optional, Union import torch # usort:skip +import weakref # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers @@ -34,6 +35,7 @@ BoundsCheckMode, CacheAlgorithm, EmbeddingLocation, + EvictionPolicy, get_bounds_check_version_for_platform, KVZCHParams, PoolingMode, @@ -54,6 +56,8 @@ from torch import distributed as dist, nn, Tensor # usort:skip from dataclasses import dataclass +import psutil + from torch.autograd.profiler import record_function from ..cache import get_unique_indices_v2 @@ -100,6 +104,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module): _local_instance_index: int = -1 res_params: RESParams table_names: list[str] + _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet() + _first_instance_ref: ClassVar[weakref.ref] = None + _eviction_triggered: ClassVar[bool] = False def __init__( self, @@ -179,6 +186,7 @@ def __init__( table_names: Optional[list[str]] = None, use_rowwise_bias_correction: bool = False, # For Adam use optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006 + pg: Optional[dist.ProcessGroup] = None, ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() @@ -567,6 +575,10 @@ def __init__( # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend self.load_state_dict: bool = False + SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self) + if SSDTableBatchedEmbeddingBags._first_instance_ref is None: + SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self) + # create tbe unique id using rank index | local tbe idx if tbe_unique_id == -1: SSDTableBatchedEmbeddingBags._local_instance_index += 1 @@ -584,6 +596,7 @@ def __init__( self.tbe_unique_id = tbe_unique_id self.l2_cache_size = l2_cache_size logging.info(f"tbe_unique_id: {tbe_unique_id}") + self.enable_free_mem_trigger_eviction: bool = False if self.backend_type == BackendType.SSD: logging.info( f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, " @@ -688,25 +701,31 @@ def __init__( if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb else self.l2_cache_size ) + kv_zch_params = self.kv_zch_params + eviction_policy = self.kv_zch_params.eviction_policy + if eviction_policy.eviction_trigger_mode == 5: + # If trigger mode is free_mem(5), populate config + self.set_free_mem_eviction_trigger_config(eviction_policy) + # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters. eviction_config = torch.classes.fbgemm.FeatureEvictConfig( - self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count - self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score - self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration + eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count + eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score + eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util - self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp - self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter - self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter - self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score - self.kv_zch_params.eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table - self.kv_zch_params.eviction_policy.training_id_keep_count, # training_id_keep_count for each table - self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm + eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp + eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter + eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter + eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score + eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table + eviction_policy.training_id_keep_count, # training_id_keep_count for each table + eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm table_dims.tolist() if table_dims is not None else None, - self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score - self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score - self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s, - self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s, - self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s, + eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score + eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score + eviction_policy.interval_for_insufficient_eviction_s, + eviction_policy.interval_for_sufficient_eviction_s, + eviction_policy.interval_for_feature_statistics_decay_s, ) self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper( self.cache_row_dim, @@ -1065,6 +1084,8 @@ def __init__( self.bounds_check_version: int = get_bounds_check_version_for_platform() + self._pg = pg + @cached_property def cache_row_dim(self) -> int: """ @@ -2042,6 +2063,9 @@ def _prefetch( # noqa C901 if dist.get_rank() == 0: self._report_kv_backend_stats() + # May trigger eviction if free mem trigger mode enabled before get cuda + self.may_trigger_eviction() + # Fetch data from SSD if linear_cache_indices.numel() > 0: self.record_function_via_dummy_profile( @@ -4650,3 +4674,91 @@ def direct_write_embedding( ) # Return control to the main stream without waiting for the backend operation to complete + + def get_free_cpu_memory_gb(self) -> float: + mem = psutil.virtual_memory() + return mem.available / (1024**3) + + @classmethod + def trigger_evict_in_all_tbes(cls) -> None: + for tbe in cls._all_tbe_instances: + tbe.ssd_db.trigger_feature_evict() + + @classmethod + def tbe_has_ongoing_eviction(cls) -> bool: + for tbe in cls._all_tbe_instances: + if tbe.ssd_db.is_evicting(): + return True + return False + + def set_free_mem_eviction_trigger_config( + self, eviction_policy: EvictionPolicy + ) -> None: + self.enable_free_mem_trigger_eviction = True + self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode + assert ( + eviction_policy.eviction_free_mem_check_interval_batch is not None + ), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode" + self.eviction_free_mem_check_interval_batch: int = ( + eviction_policy.eviction_free_mem_check_interval_batch + ) + assert ( + eviction_policy.eviction_free_mem_threshold_gb is not None + ), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode" + self.eviction_free_mem_threshold_gb: int = ( + eviction_policy.eviction_free_mem_threshold_gb + ) + logging.info( + f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}" + ) + + def may_trigger_eviction(self) -> None: + def is_first_tbe() -> bool: + first = SSDTableBatchedEmbeddingBags._first_instance_ref + return first is not None and first() is self + + # We assume that the eviction time is less than free mem check interval time + # So every time we reach this check, all evictions in all tbes should be finished. + # We only need to check the first tbe because all tbes share the same free mem, + # once the first tbe detect need to trigger eviction, it will call trigger func + # in all tbes from _all_tbe_instances + if ( + self.enable_free_mem_trigger_eviction + and self.step % self.eviction_free_mem_check_interval_batch == 0 + and self.training + and is_first_tbe() + ): + if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction(): + SSDTableBatchedEmbeddingBags._eviction_triggered = False + + free_cpu_mem_gb = self.get_free_cpu_memory_gb() + local_evict_trigger = int( + free_cpu_mem_gb < self.eviction_free_mem_threshold_gb + ) + tensor_flag = torch.tensor( + local_evict_trigger, + device=self.current_device, + dtype=torch.int, + ) + world_size = dist.get_world_size(self._pg) + if world_size > 1: + dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg) + global_evict_trigger = tensor_flag.item() + else: + global_evict_trigger = local_evict_trigger + if ( + global_evict_trigger >= 1 + and SSDTableBatchedEmbeddingBags._eviction_triggered + ): + logging.warning( + f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true" + ) + if ( + global_evict_trigger >= 1 + and not SSDTableBatchedEmbeddingBags._eviction_triggered + ): + SSDTableBatchedEmbeddingBags._eviction_triggered = True + SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes() + logging.info( + f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction" + ) diff --git a/fbgemm_gpu/requirements.txt b/fbgemm_gpu/requirements.txt index c1f0bb92ff..dcd13bfcd9 100644 --- a/fbgemm_gpu/requirements.txt +++ b/fbgemm_gpu/requirements.txt @@ -29,3 +29,4 @@ setuptools_git_versioning tabulate patchelf fairscale +psutil diff --git a/fbgemm_gpu/requirements_genai.txt b/fbgemm_gpu/requirements_genai.txt index 59741362a5..722de8de37 100644 --- a/fbgemm_gpu/requirements_genai.txt +++ b/fbgemm_gpu/requirements_genai.txt @@ -30,3 +30,4 @@ setuptools_git_versioning tabulate patchelf fairscale +psutil diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 3f2848d4a3..4d1d2895a6 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -1212,6 +1212,11 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { } break; } + case EvictTriggerMode::FREE_MEM: { + // For free mem eviction, all conditions checked in frontend, no check + // option in backend + return; + } default: break; } diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h index e0443ee640..5637224754 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h @@ -34,7 +34,8 @@ enum class EvictTriggerMode { ITERATION, // Trigger based on iteration steps MEM_UTIL, // Trigger based on memory usage MANUAL, // Manually triggered by upstream - ID_COUNT // Trigger based on id count + ID_COUNT, // Trigger based on id count + FREE_MEM, // Trigger based on free memory }; inline std::string to_string(EvictTriggerMode mode) { switch (mode) { @@ -48,6 +49,8 @@ inline std::string to_string(EvictTriggerMode mode) { return "MANUAL"; case EvictTriggerMode::ID_COUNT: return "ID_COUNT"; + case EvictTriggerMode::FREE_MEM: + return "FREE_MEM"; } } @@ -184,6 +187,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { eviction_trigger_stats_log += "]"; break; } + case EvictTriggerMode::FREE_MEM: { + break; + } default: throw std::runtime_error("Unknown evict trigger mode"); } @@ -202,7 +208,6 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { case EvictTriggerStrategy::BY_FEATURE_SCORE: { CHECK(feature_score_counter_decay_rates_.has_value()); - CHECK(training_id_eviction_trigger_count_.has_value()); CHECK(training_id_keep_count_.has_value()); CHECK(threshold_calculation_bucket_stride_.has_value()); CHECK(threshold_calculation_bucket_num_.has_value()); @@ -210,8 +215,6 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { LOG(INFO) << "eviction config, trigger mode:" << to_string(trigger_mode_) << eviction_trigger_stats_log << ", strategy: " << to_string(trigger_strategy_) - << ", training_id_eviction_trigger_count: " - << training_id_eviction_trigger_count_.value() << ", training_id_keep_count:" << training_id_keep_count_.value() << ", ttls_in_mins: " << ttls_in_mins_.value() From a515b0305a9e323b0d926b6d5ab9543d91a8d1c7 Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Fri, 31 Oct 2025 23:06:51 -0700 Subject: [PATCH 66/92] General adoption for Mtile = 64 (#5075) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5075 X-link: https://github.com/facebookresearch/FBGEMM/pull/2080 This diff generalizes the work in (D85155388) based on Gefei's diff D85631781 . Compared to D85631781, we avoid registers warp shuffling by using 32b TMEM atoms. This diff supports: 1. Different dtypes (fp8, bf16) 2. Different mtiles (128, 64) Reviewed By: v0i0 Differential Revision: D85893883 fbshipit-source-id: 25e93e627c573a120ab46336d3f234064c5ae066 --- ...m100_fmha_gen_mainloop_warpspecialized.hpp | 295 ++++++++++++------ .../test/attention/blackwell_fmha_test.py | 2 +- 2 files changed, 193 insertions(+), 104 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index 1738c121f1..1be2e43145 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -41,10 +41,13 @@ #include "collective/fmha_common.hpp" #include "collective/fmha_fusion.hpp" #include "collective/sm100_fmha_load_cpasync_warpspecialized.hpp" +#include "cutlass/detail/dependent_false.hpp" namespace cutlass::fmha::collective { using namespace cute; +using namespace constexpr_type_map; +using namespace constexpr_constexpr_map; template< class Element_, @@ -85,10 +88,32 @@ struct Sm100FmhaGenMainloopWarpspecialized { using StrideO = decltype(replace<0>(StrideO_{}, 0)); using Mask = Mask_; + using TileM = decltype(get<0>(TileShape{})); // seq Q dim + static_assert(TileM::value == 64 || TileM::value == 128, "Only expecting TileM to be 64 or 128"); static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2; - // local changes - static constexpr int StageCountKV = StageCountQ * (sizeof(Element) == 1 ? 11 : 5) ; - + // Choose StageCountKV based on: + // - Tile shape on the M (i.e., Query) dimension + // - Element size + using StageCountKVSelector = kValTyMap< + void, + kValTyPair<64, + kValValMap< + 65536 /* default, arbitrarily large to trigger smem OOM error */, + kValValPair<1, 12>, // fp8 + kValValPair<2, 6> // bf16/fp16 + >>, + kValTyPair<128, + kValValMap< + 65536 /* default, arbitrarily large to trigger smem OOM error */, + kValValPair<1, 11>, // fp8 + kValValPair<2, 5> // bf16/fp16 + >> + >; + static constexpr int StageCountKV = StageCountQ * + StageCountKVSelector:: + template query:: + template query; + using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; @@ -129,28 +154,52 @@ struct Sm100FmhaGenMainloopWarpspecialized { }; }; + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1, + kIdxStatsEnd = 2 + }; + + // Each storage reserves kTMEM_V_COLUMNS for row max/sum stats + // TileM=64 uses 16dp64b --> two threads processing a row + // TileM=128 uses 32dp32b --> one thread processing a row + using kTMEM_V_COLUMNS = typename kValTyMap>, + kValTyPair<128, Int> + >::template query; + + // TMEM column allocation, offset will be used to calc the lower 16-bit of tmem addresses. + // TMEM row/lane dimension is for the Q dim. enum class TmemAllocation : uint32_t { - kSizeS = 128, - kSizeO = 128, - kSizeP = 32, + kSizeS = get<1>(TileShapeQK{}), // i.e., KV dim in a tile + kSizeO = get<2>(TileShapeQK{}), // i.e., head dim + // carve kSizeS to two parts: first 1/4 for V0/V1 stats storage; the rest for P0/P1 + // 1/4 is wasting some storage here but there seems to be column-wise address alignment requirements not found in spec. + // Since there is enough storage left for P0/P1, chose to not debug alignment issues. + kSizeV = kSizeS / 2, + // P will be casted to the same type as V + kSizeP = kSizeS * sizeof(Element) / sizeof(float), S0 = 0, S1 = S0 + kSizeS, V0 = S0, // stats storage from softmax to correction V1 = S1, - P0 = S0 + kSizeP, - P1 = S1 + kSizeP, + P0 = V0 + kSizeV, + P1 = V1 + kSizeV, O0 = S1 + kSizeS, O1 = O0 + kSizeO, kEnd = O1 + kSizeO }; - - // indices for V0 / V1 - enum : int { - kIdxOldRowMax = 0, - kIdxNewRowMax = 1, - kIdxFinalRowSum = 0, - kIdxFinalRowMax = 1 - }; + static_assert(static_cast(TmemAllocation::kEnd) <= 512, "Exceeds TMEM 512 columns"); + static_assert( + static_cast(TmemAllocation::kSizeV) + static_cast(TmemAllocation::kSizeP) <= + static_cast(TmemAllocation::kSizeS), + "Not enough storage to carve V and P out of S"); + static_assert( + static_cast(kTMEM_V_COLUMNS::value) <= static_cast(TmemAllocation::kSizeV), + "Not enough storage reserved for V"); // from load to mma warp, protects q in smem using PipelineQ = cutlass::PipelineUmmaConsumerAsync< @@ -533,41 +582,41 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); - Tensor tStS_v = - tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); - tStS_v.data() = + Tensor tStS_v = + tStS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); - Tensor tScS_v = - tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); + Tensor tScS_v = + tScS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); - auto tilePlikeFP32 = _32{}; // 32 for FP32 - // size<1>(TileShapeQK{}) / Int{} * Int{}; - - // tilePlikeFP32 = 64/4*2 = 32 for BF16 - // Preserve hierarchical structure: ((16, 4), 32) = 16*4*32 = 2048 elements + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; Tensor tStS_P = tStS.compose( - make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32))); + make_layout(make_shape(TileM{}, tilePlikeFP32))); tStS_P.data() = warp_uniform( - uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); Tensor tScS_P = tScS.compose( - make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32))); + make_layout(make_shape(TileM{}, tilePlikeFP32))); + + // needed number of cols to load from tmem to reg + constexpr int kConversionsPerStep = 2; + constexpr int kTmemLoadNcells = cute::min(32, size<1>(TileShapeQK{}) / kConversionsPerStep); + constexpr int kTmemStoreNcells = kTmemLoadNcells * sizeof_bits_v / sizeof_bits_v; - // Select TMEM operation based on K dimension (number of columns) - // For K=64: 64 rows × 64 cols = 4,096 elements → use 16dp32b4x - // For K=128: 64 rows × 128 cols = 8,192 elements → use 16dp32b8x - using TMEM_LOAD = conditional_t< - size<1>(TileShapeQK{}) == _64{}, - SM100_TMEM_LOAD_16dp32b16x, // For K=64: 4,096 elements - SM100_TMEM_LOAD_16dp32b8x>; // For K=128: 8,192 elements + using TMEM_LOAD_1xOP = typename kValTyMap, + // Each thread owns a single row + kValTyPair<128, SM100_TMEM_LOAD_32dp32b1x> + >::template query; + using TMEM_STORE_1xOP = decltype(TMEM::tmem_load_to_store(TMEM_LOAD_1xOP{})); + using TMEM_LOAD = decltype(TMEM::op_repeater()); + using TMEM_STORE = decltype(TMEM::op_repeater()); - using TMEM_STORE = conditional_t< - size<1>(TileShapeQK{}) == _64{}, - SM100_TMEM_STORE_16dp32b8x, // For K=64, BF16: 2,048 elements - SM100_TMEM_STORE_16dp32b8x>; + using TMEM_STORE_V = typename kValTyMap, + kValTyPair<128, SM100_TMEM_STORE_32dp32b2x> // 4x32 threads with 2 cols of 32b elem + >::template query; - // TMEM_STORE_V: Store row statistics (old_max, new_max) for online softmax - // correction Always 64 rows × 2 cols = 128 FP32 elements - using TMEM_STORE_V = SM100_TMEM_STORE_16dp32b2x; auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); @@ -616,12 +665,15 @@ struct Sm100FmhaGenMainloopWarpspecialized { row_max = ::fmax(row_max_0, row_max_1); row_max = ::fmax(row_max, row_max_2); row_max = ::fmax(row_max, row_max_3); + if constexpr (TileM{} == 64) { + ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16); + row_max = ::fmax(row_max, shuffled_row_max); + } } - ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16); - row_max = ::fmax(row_max, shuffled_row_max); ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + static_assert(size(tTMEM_STOREVrS) == 2); tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); @@ -639,48 +691,64 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); - constexpr int kConversionsPerStep = 2; + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); NumericArrayConverter convert; - const int kReleasePipeCount = 10; // must be multiple of 2 order_s.wait(); + static_assert(kReleasePipeCount % kConversionsPerStep == 0); + static_assert(kConversionsPerStep == 2); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { - float2 in = make_float2( - tTMEM_LOADrS(i + 0), - tTMEM_LOADrS(i + 1) - ); - float2 out; - cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); - tTMEM_LOADrS(i + 0) = out.x; - tTMEM_LOADrS(i + 1) = out.y; - - tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); - tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); - - Array in_conv; + { CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kConversionsPerStep; j++) { - in_conv[j] = tTMEM_LOADrS(i + j); - } - tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + for (int i = 0; i < size(tTMEM_LOADrS); i += kConversionsPerStep) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); - - if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { - order_s.arrive(); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + if constexpr (TileM::value == 128) { + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + //this prevents register spills in fp16 + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } } - - } + } // tmem_store(reg_S8) -> op_P - // CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); - // CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); - copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4); + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + if constexpr (TileM::value == 128) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + } else { + copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4); + } cutlass::arch::fence_view_async_tmem_store(); @@ -722,10 +790,14 @@ struct Sm100FmhaGenMainloopWarpspecialized { row_sum = local_row_sum; if (final_call) { + if constexpr (TileM{} == 64) { + // Sync threads 0 and 16 to get the sum of row_sum between them + row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16); + } + // re-acquire the S part in the final step pipeline_s.consumer_wait(pipeline_s_consumer_state); - // Sync threads 0 and 16 to get the sum of row_sum between them - row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16); + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; @@ -815,21 +887,31 @@ struct Sm100FmhaGenMainloopWarpspecialized { const int kCorrectionTileSize = 32 / sizeof(ElementOut); // TODO: load all values - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + // Choose TMEM OP based on + // - TileM shape + // - kCorrectionTileSize + using TMEM_LOAD_OPMAP = kValTyMap + > + >, + kValTyPair<128, + kValTyMap + >> // 4x32 threads with 64 cols of 32b elem + >; + using TMEM_LOAD = typename TMEM_LOAD_OPMAP::template query::template query; typename CollectiveMmaPV::TiledMma mma; Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOgO = mma.get_slice(0).partition_C(gO); - - Tensor tOtO_i = tOtO.compose(make_layout( - make_shape(make_shape(_16{}, _4{}), Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout( - make_shape(make_shape(_16{}, _4{}), Int{}))); - Tensor tOgO_i = tOgO.compose(make_layout( - make_shape(make_shape(_16{}, _4{}), Int{}))); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(TileM{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(TileM{}, Int{}))); + Tensor tOgO_i = tOgO.compose(make_layout(make_shape(TileM{}, Int{}))); Tensor tOtO0 = tOtO_i; tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); @@ -895,13 +977,13 @@ struct Sm100FmhaGenMainloopWarpspecialized { tCd(j) = convert.convert(tCs(j)); } - Tensor tSMgO_i = recast(tTMEM_LOADgO_i); - Tensor tSMrO_i = recast(tSMrO); + Tensor tSMgO_i = recast(tTMEM_LOADgO_i); + Tensor tSMrO_i = recast(tSMrO); - // could use masking do this right for smaller D - if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { + // could use masking do this right for smaller D + if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMgO_i); - } + } } } @@ -917,18 +999,22 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32; - using TMEM_LOAD = SM100_TMEM_LOAD_16dp32b16x; // 4x32 threads with 64 cols of 32b elem - using TMEM_STORE = SM100_TMEM_STORE_16dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = typename kValTyMap, + kValTyPair<128, SM100_TMEM_LOAD_32dp32b32x> // 4x32 threads with 64 cols of 32b elem + >::template query; + using TMEM_STORE = typename kValTyMap, + kValTyPair<128, SM100_TMEM_STORE_32dp32b32x> // 4x32 threads with 64 cols of 32b elem + >::template query; typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - - Tensor tOtO_i = tOtO.compose(make_layout( - make_shape(make_shape(_16{}, _4{}), Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout( - make_shape(make_shape(_16{}, _4{}), Int{}))); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(TileM{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(TileM{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; @@ -1009,13 +1095,15 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - - Tensor tStS_v = tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); - Tensor tScS_v = tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); - using TMEM_LOAD_V = - SM100_TMEM_LOAD_16dp32b2x; // 4x32 threads with 2 cols of 32b elem + Tensor tStS_v = tStS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); + using TMEM_LOAD_V = + typename kValTyMap, + kValTyPair<128, SM100_TMEM_LOAD_32dp32b2x> // 4x32 threads with 2 cols of 32b elem + >::template query; auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); @@ -1043,6 +1131,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + static_assert(size(tTMEM_LOADVrS) == 2); // read row_wise new global max copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 42057aeafe..a3a51d15b4 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -687,7 +687,7 @@ def _execute_cutlass_blackwell_attn_varlen( sm_scale, num_groups, ) - for dtype in [torch.bfloat16] + for dtype in [torch.bfloat16, torch.float8_e4m3fn] for seqlen_k in [64, 128, 256, 1024] for batch_size in [1, 2] for is_mqa in [True, False] From 270edf4d0050ec9f51c076b7a2d2647fb33d54a9 Mon Sep 17 00:00:00 2001 From: Joey Yang Date: Mon, 3 Nov 2025 21:58:21 -0800 Subject: [PATCH 67/92] Map hash_zch_identities to corresponding unique indices in TBE (#5077) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2082 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5077 This change selects the `hash_zch_identities` that corresponds with unique indices during TBE prefetch. This is specifically required for MPZCH tables, which need both the slot index and the corresponding identities for correct lookup behavior. Without the identities, the inference side cannot correctly verify if it's using the correct slot, leading to potential lookup errors. Reviewed By: chouxi Differential Revision: D85999577 fbshipit-source-id: 3c8a4add1dd112e9a746b334e7046bb442ea977b --- ...t_table_batched_embeddings_ops_training.py | 113 +++++++++++---- .../training/store_prefetched_tensors_test.py | 134 ++++++++++++++++++ 2 files changed, 219 insertions(+), 28 deletions(-) create mode 100644 fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index cd923b0e20..7d86474cc1 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -200,6 +200,13 @@ class RESParams: ) # table sizes for the global rows the TBE holds +@dataclass(frozen=True) +class PrefetchedInfo: + linear_unique_indices: torch.Tensor + linear_unique_indices_length: torch.Tensor + hash_zch_identities: Optional[torch.Tensor] + + def construct_split_state( embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]], rowwise: bool, @@ -2100,6 +2107,12 @@ def forward( # noqa: C901 requires this information for allocating the weight gradient tensor in the backward pass. + hash_zch_identities (Optional[Tensor]): The original raw IDs before + remapping to ZCH (Zero-Collision Hash) table slots. This tensor is + populated when using Multi-Probe Zero Collision Hash (MPZCH) modules + and is required for Raw Embedding Streaming (RES) to maintain + consistency between training and inference. + Returns: A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` = batch size and `total_D` = the sum of all embedding dimensions in the @@ -2217,7 +2230,6 @@ def forward( # noqa: C901 # In forward, we don't enable multi-pass prefetch as we want the process # to be as fast as possible and memory usage doesn't matter (will be recycled # by dense fwd/bwd) - # TODO: Properly pass in the hash_zch_identities self._prefetch( indices, offsets, @@ -4143,6 +4155,60 @@ def raw_embedding_stream(self) -> None: False, # blocking_tensor_copy ) + @staticmethod + @torch.jit.ignore + def _get_prefetched_info( + linear_cache_indices_merged: torch.Tensor, + total_cache_hash_size: int, + hash_zch_identities: Optional[torch.Tensor], + ) -> PrefetchedInfo: + compute_inverse_indices = hash_zch_identities is not None + ( + linear_unique_indices, + linear_unique_indices_length, + linear_unique_indices_cnt, + linear_unique_inverse_indices, + ) = torch.ops.fbgemm.get_unique_indices_with_inverse( + linear_cache_indices_merged, + total_cache_hash_size, + compute_count=compute_inverse_indices, + compute_inverse_indices=compute_inverse_indices, + ) + # linear_unique_indices is the result after deduplication and sorting + linear_unique_indices = linear_unique_indices.narrow( + 0, 0, linear_unique_indices_length[0] + ) + + if hash_zch_identities is None: + return PrefetchedInfo( + linear_unique_indices, + linear_unique_indices_length, + None, + ) + + # Compute cumulative sum as indices for selecting unique elements to + # map hash_zch_identities to linear_unique_indices + count_cum_sum = torch.ops.fbgemm.asynchronous_complete_cumsum( + linear_unique_indices_cnt + ) + count_cum_sum = count_cum_sum.narrow(0, 0, linear_unique_indices_length[0]) + + # Select indices corresponding to first occurrence of each unique element + linear_unique_inverse_indices = linear_unique_inverse_indices.index_select( + dim=0, index=count_cum_sum + ) + + # Map hash_zch_identities to unique indices + hash_zch_identities_cpu = hash_zch_identities.index_select( + dim=0, index=linear_unique_inverse_indices + ).to(device=torch.device("cpu")) + + return PrefetchedInfo( + linear_unique_indices, + linear_unique_indices_length, + hash_zch_identities_cpu, + ) + @torch.jit.ignore def _store_prefetched_tensors( self, @@ -4153,35 +4219,26 @@ def _store_prefetched_tensors( NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional. This function stores the prefetched tensors for the raw embedding streaming. """ - if self.enable_raw_embedding_streaming: - with record_function( - "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid) - ): + if not self.enable_raw_embedding_streaming: + return + + with record_function( + "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid) + ): + # Process hash_zch_identities using helper function + prefetched_info = self._get_prefetched_info( + linear_cache_indices_merged, + self.total_cache_hash_size, + hash_zch_identities, + ) + + self.prefetched_info.append( ( - linear_unique_indices, - linear_unique_indices_length, - _, - ) = torch.ops.fbgemm.get_unique_indices( - linear_cache_indices_merged, - self.total_cache_hash_size, - compute_count=False, - ) - linear_unique_indices = linear_unique_indices.narrow( - 0, 0, linear_unique_indices_length[0] - ) - self.prefetched_info.append( - ( - linear_unique_indices, - linear_unique_indices_length, - ( - hash_zch_identities.index_select( - dim=0, index=linear_unique_indices - ).to(device=torch.device("cpu")) - if hash_zch_identities is not None - else None - ), - ) + prefetched_info.linear_unique_indices, + prefetched_info.linear_unique_indices_length, + prefetched_info.hash_zch_identities, ) + ) @torch.jit.ignore def __report_input_params_factory( diff --git a/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py b/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py new file mode 100644 index 0000000000..b74690b2cc --- /dev/null +++ b/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch + +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) + +from ..common import open_source + +if open_source: + # pyre-ignore[21] + from test_utils import gpu_unavailable +else: + from fbgemm_gpu.test.test_utils import gpu_unavailable + + +class StorePrefetchedTensorsTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + def test_get_prefetched_info(self) -> None: + hash_zch_identities = torch.tensor( + [ + [3350213393928437575], # for index 54 + [6548733451892409412], # for index 27 + [4126118985661274454], # for index 43 + [2565973416302224539], # for index 90 + ], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + total_cache_hash_size = 100 + linear_cache_indices_merged = torch.tensor( + [54, 27, 43, 90], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + + prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info( + linear_cache_indices_merged, + total_cache_hash_size, + hash_zch_identities, + ) + + self.assertEqual( + [27, 43, 54, 90], + prefetched_info.linear_unique_indices.tolist(), + ) + self.assertEqual( + prefetched_info.linear_unique_indices_length[0].item(), + 4, + ) + assert prefetched_info.hash_zch_identities is not None + self.assertEqual( + prefetched_info.hash_zch_identities.shape[0], + 4, + ) + self.assertEqual( + [ + [6548733451892409412], + [4126118985661274454], + [3350213393928437575], + [2565973416302224539], + ], + prefetched_info.hash_zch_identities.tolist(), + ) + + @unittest.skipIf(*gpu_unavailable) + def test_get_prefetched_info_with_duplicate_hash_zch_identities(self) -> None: + """ + Test that duplicate cache indices are correctly deduplicated. + When the same cache index appears multiple times with the same identity, + only the first occurrence should be kept in the output. + """ + hash_zch_identities = torch.tensor( + [ + [3350213393928437575], # for index 54 (first occurrence) + [6548733451892409412], # for index 27 + [3350213393928437575], # for index 54 (duplicate - same identity) + [4126118985661274454], # for index 43 + [6548733451892409412], # for index 27 (duplicate - same identity) + [3350213393928437575], # for index 54 (duplicate - same identity) + [2565973416302224539], # for index 90 + ], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + total_cache_hash_size = 100 + linear_cache_indices_merged = torch.tensor( + [54, 27, 54, 43, 27, 54, 90], # Duplicates: 54 appears 3x, 27 appears 2x + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + + prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info( + linear_cache_indices_merged, + total_cache_hash_size, + hash_zch_identities, + ) + + self.assertEqual( + [27, 43, 54, 90], + prefetched_info.linear_unique_indices.tolist(), + ) + self.assertEqual( + prefetched_info.linear_unique_indices_length[0].item(), + 4, + ) + assert prefetched_info.hash_zch_identities is not None + self.assertEqual( + prefetched_info.hash_zch_identities.shape[0], + 4, + ) + self.assertEqual( + [ + [6548733451892409412], # for index 27 + [4126118985661274454], # for index 43 + [3350213393928437575], # for index 54 + [2565973416302224539], # for index 90 + ], + prefetched_info.hash_zch_identities.tolist(), + ) + + +if __name__ == "__main__": + unittest.main() From d79485ea10b04777ce71c418dd1a2d887e2615c0 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 3 Nov 2025 22:18:28 -0800 Subject: [PATCH 68/92] Don't use 'not defined' in C++ preprocessing (#5025) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2086 While not is a valid C++ keyword, MSVC issues the following warnings ``` C:\actions-runner\_work\pytorch\pytorch\third_party\fbgemm\include\fbgemm\./FloatConversion.h(292): warning C4067: unexpected tokens following preprocessor directive - expected a newline ``` Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5025 Reviewed By: spcyppt Differential Revision: D86135907 Pulled By: q10 fbshipit-source-id: 3d55410aa1f6f4f1a4511d2881d1b0ba05ea5c5a --- include/fbgemm/FloatConversion.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/fbgemm/FloatConversion.h b/include/fbgemm/FloatConversion.h index f2628450e4..b88630d9b1 100644 --- a/include/fbgemm/FloatConversion.h +++ b/include/fbgemm/FloatConversion.h @@ -289,7 +289,7 @@ inline float cpu_half2float_ref(const float16 h) { // Same as the previous function, but use the built-in fp16 to fp32 // conversion provided by the compiler inline float cpu_half2float(const float16 h) { -#if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) +#if defined(HAS_NATIVE_FP16_TYPE) && !defined(MISSING_GNU_F2H_IEEE) __fp16 h_fp16 = NAN; std::memcpy(&h_fp16, &h, sizeof(__fp16)); return h_fp16; @@ -299,7 +299,7 @@ inline float cpu_half2float(const float16 h) { } inline float16 cpu_float2half(const float f) { -#if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) +#if defined(HAS_NATIVE_FP16_TYPE) && !defined(MISSING_GNU_F2H_IEEE) __fp16 h = f; float16 res = 0; std::memcpy(&res, &h, sizeof(__fp16)); From 063214f179af31e21adf907104ac7ad5b768233b Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 4 Nov 2025 02:44:57 -0800 Subject: [PATCH 69/92] Remove Python 3.9 support (#5081) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2088 - Remove Python 3.9 support, following PyTorch nightlies Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5081 Reviewed By: spcyppt Differential Revision: D86168579 Pulled By: q10 fbshipit-source-id: f15a5107ab9f86c7c07e704f510faab312bac858 --- .github/workflows/fbgemm_gpu_ci_cpu.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/fbgemm_gpu_ci_cpu.yml b/.github/workflows/fbgemm_gpu_ci_cpu.yml index 911944438a..5f5475acfc 100644 --- a/.github/workflows/fbgemm_gpu_ci_cpu.yml +++ b/.github/workflows/fbgemm_gpu_ci_cpu.yml @@ -75,7 +75,7 @@ jobs: { arch: arm, instance: "linux.arm64.m7g.4xlarge" }, ] build-target: [ "default" ] - python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ] + python-version: [ "3.10", "3.11", "3.12", "3.13" ] compiler: [ "gcc", "clang" ] steps: @@ -149,7 +149,7 @@ jobs: { arch: arm, instance: "linux.arm64.m7g.4xlarge", timeout: 30 }, ] build-target: [ "default" ] - python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ] + python-version: [ "3.10", "3.11", "3.12", "3.13" ] compiler: [ "gcc", "clang" ] needs: build_artifact From 0baae825dbcabf7fd52230e0a1323b2b208bc62b Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 4 Nov 2025 03:41:42 -0800 Subject: [PATCH 70/92] group_index_select_or_add_2d_kernel forward pass optimization (#5080) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5080 X-link: https://github.com/facebookresearch/FBGEMM/pull/2087 This PR introduces optimization for `group_index_select_or_add_2d_kernel` (`USE_INDEX_SELECT==true`) kernel with primary focus on `float` type and relatively small embedding dimensions. 2 things are implemented: 1) Extracted the common variables out of the loop to omit unnecessary synchronizations on memory load (compiler won't do that automatically) 2) Switch to 32 threads logical wave sizes to reduce granularity losses. Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5078 Reviewed By: spcyppt, haoyuz Differential Revision: D86135611 Pulled By: q10 fbshipit-source-id: f4fb9966f5f5180c4dde2aed92ca726c260b7743 --- .../src/sparse_ops/sparse_group_index.cu | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index c1ac40dea6..96c57cde68 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -12,10 +12,18 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +#ifdef USE_ROCM +// The wave size is forced to be 32 on ROCm devices in favor +// of granularity losses reduction. +constexpr int EMULATED_WARP_SIZE = 32; +#else +constexpr int EMULATED_WARP_SIZE = kWarpSize; +#endif + // TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = - GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; + GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = @@ -43,12 +51,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; + int32_t num_cols = 0; + int32_t warps_per_row = 0; + + if constexpr (!USE_VAR_COLS) { + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + } + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { - int32_t member_id, member_warp_id, num_cols, warps_per_row; - if (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / kWarpSize]; + int32_t member_id = 0; + int32_t member_warp_id = 0; + if constexpr (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; if (threadIdx.x == 0) { binary_search_range( &member_ids[threadIdx.y], @@ -63,8 +80,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_warp_id = warp_id - warp_offsets_group[member_id]; } else { // All columns are the same - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } @@ -82,7 +97,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { // Compile time conditional - if (USE_INDEX_SELECT) { + if constexpr (USE_INDEX_SELECT) { output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { gpuAtomicAddNoReturn( @@ -113,13 +128,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda( at::cuda::OptionalCUDAGuard device_guard(device); // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; + uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), max_grid_size); - dim3 block_size(kWarpSize, num_warps_per_threadblock, 1); + dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ FBGEMM_LAUNCH_KERNEL( \ From f1f2449fd9b3d1af5bf585b24d4b0f498c4d65a6 Mon Sep 17 00:00:00 2001 From: Zhengjun Xing Date: Tue, 4 Nov 2025 14:10:21 -0800 Subject: [PATCH 71/92] Fix OSError: [Errno 24] Too many open files in multi-copy benchmark (#5083) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5083 X-link: https://github.com/facebookresearch/FBGEMM/pull/2089 When running benchmarks with a large number of copies, the process may raise: OSError: [Errno 24] Too many open files. Example command: (fbgemm_gpu_env)$ ulimit -n 1048576 (fbgemm_gpu_env)$ python ./bench/tbe/tbe_inference_benchmark.py nbit-cpu \ --num-embeddings=40000000 --bag-size=2 --embedding-dim=96 \ --batch-size=162 --num-tables=8 --weights-precision=int4 \ --output-dtype=fp32 --copies=96 --iters=30000 PyTorch multiprocessing provides two shared-memory strategies: 1.file_descriptor (default) 2.file_system The default file_descriptor strategy uses file descriptors as shared memory handles, which can result in a large number of open FDs when many tensors are shared. If the total number of open FDs exceeds the system limit and cannot be raised, the file_system strategy should be used instead. This patch allows switching to the file_system strategy by setting: export PYTORCH_SHARE_STRATEGY='file_system' Reference: https://pytorch.org/docs/stable/multiprocessing.html#sharing-strategies Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5037 Reviewed By: spcyppt Differential Revision: D86135817 Pulled By: q10 fbshipit-source-id: 15f6fe7e1de5e9fef828f5a1496dc1cf9b41c293 --- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index ae805870bd..f0ac6f1a70 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -153,6 +153,13 @@ def benchmark_cpu_requests_mp( float: The average runtime per iteration in seconds. """ + import os + + strategy = os.environ.get("PYTORCH_SHARE_STRATEGY") + current_strategy = torch.multiprocessing.get_sharing_strategy() + if strategy is not None and current_strategy != strategy: + torch.multiprocessing.set_sharing_strategy(strategy) + cpu_bm_barrier.create_barrier(num_copies) worker_pool = torch.multiprocessing.Pool(num_copies) From b48b0b7f3ce3d955539d53880f3bd093b51bb8db Mon Sep 17 00:00:00 2001 From: Eddy Li Date: Tue, 4 Nov 2025 19:00:17 -0800 Subject: [PATCH 72/92] Support eval mode for st publish (#5085) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5085 X-link: https://github.com/facebookresearch/FBGEMM/pull/2093 As title, in silvertorch bulk eval, they will not call eval() for the module but using torch.no_grad() to run. https://www.internalfb.com/code/fbsource/[324dbccd0ab0]/fbcode/dper_lib/silvertorch/core/publish/data_processing/bulk_eval_dmp_gpu.py?lines=1057 So set a eval mode to turn the self.training to False in tbe for bulk eval. Reviewed By: emlin Differential Revision: D86220286 fbshipit-source-id: 9a48c7b4dc09767c99a545d1f25e53bf4265079f --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 59ea7f3b70..4cdbe4a2eb 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -4762,3 +4762,9 @@ def is_first_tbe() -> bool: logging.info( f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction" ) + + def reset_inference_mode(self) -> None: + """ + Reset the inference mode + """ + self.eval() From 1a0eb0f4a2cb44752f0cf3a0a6bca141bd2a5d42 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 4 Nov 2025 22:12:41 -0800 Subject: [PATCH 73/92] Fix test reliability with table order (#5087) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2097 - Fix test reliability with table order Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5087 Reviewed By: spcyppt Differential Revision: D86242426 Pulled By: q10 fbshipit-source-id: 4ec307ff8fd9151bddb6bf7354bfe06f67a1fa0b --- .github/workflows/_fbgemm_gpu_cuda_test.yml | 3 +++ .../split_table_batched_embeddings_ops_training.py | 2 +- fbgemm_gpu/setup.py | 4 +--- fbgemm_gpu/test/tbe/utils/split_embeddings_test.py | 6 +++--- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/_fbgemm_gpu_cuda_test.yml b/.github/workflows/_fbgemm_gpu_cuda_test.yml index 03d619cae0..692b6ab7ac 100644 --- a/.github/workflows/_fbgemm_gpu_cuda_test.yml +++ b/.github/workflows/_fbgemm_gpu_cuda_test.yml @@ -132,6 +132,9 @@ jobs: # clang-16: error: unknown argument: '-fno-tree-loop-vectorize' run: . $PRELUDE; install_cxx_compiler $BUILD_ENV gcc + - name: Install Build Tools + run: . $PRELUDE; install_build_tools $BUILD_ENV + - name: Install CUDA run: . $PRELUDE; install_cuda $BUILD_ENV ${{ matrix.cuda-version }} diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 7d86474cc1..4f1741b3dc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1560,7 +1560,7 @@ def get_table_name_for_logging(table_names: Optional[list[str]]) -> str: return "" # Do this because sometimes multiple shards of the same table could appear # in one TBE. - table_name_set = set(table_names) + table_name_set = sorted(set(table_names)) if len(table_name_set) == 1: return next(iter(table_name_set)) return f"<{len(table_name_set)} tables>: {table_name_set}" diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 97600fc0bb..dd3246539a 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# @licenselint-loose-mode - import argparse import logging import os @@ -655,7 +653,7 @@ def main(argv: list[str]) -> None: ] + [ f"Programming Language :: Python :: {x}" - for x in ["3", "3.9", "3.10", "3.11", "3.12", "3.13"] + for x in ["3", "3.10", "3.11", "3.12", "3.13"] ], ) diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py index b6864a3ac1..1ebdaddaa5 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py @@ -178,17 +178,17 @@ def test_get_table_name_for_logging(self) -> None: SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2"] ), - "<2 tables>: {'t1', 't2'}", + "<2 tables>: ['t1', 't2']", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2", "t1"] ), - "<2 tables>: {'t1', 't2'}", + "<2 tables>: ['t1', 't2']", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging([]), - "<0 tables>: set()", + "<0 tables>: []", ) @unittest.skipIf(*gpu_unavailable) From 9b996a2a08688f6441dde424944e24aa9c47bbb1 Mon Sep 17 00:00:00 2001 From: Nicolas De Carli Date: Wed, 5 Nov 2025 09:58:34 -0800 Subject: [PATCH 74/92] Add NEON-based FloatOrHalfToFused8BitRowwiseQuantizedSBFloat (#5089) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2098 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5089 Adding NEON translation of FloatOrHalfToFused8BitRowwiseQuantizedSBFloat, used by Ads Performance improves by an order of magnitude: Before: bit_rate, rows, cols, elems_per_usec, GB/Sec 8, 100, 16, 378.68, 1.51 8, 100, 64, 286.91, 1.15 8, 100, 128, 262.06, 1.05 8, 100, 256, 251.34, 1.01 8, 100, 512, 244.92, 0.98 8, 100, 1024, 237.35, 0.95 8, 100, 2048, 230.83, 0.92 8, 120, 16, 378.70, 1.51 8, 120, 64, 286.72, 1.15 8, 120, 128, 263.40, 1.05 8, 120, 256, 251.58, 1.01 8, 120, 512, 245.30, 0.98 8, 120, 1024, 238.17, 0.95 8, 120, 2048, 230.69, 0.92 8, 1000, 16, 392.85, 1.57 8, 1000, 64, 294.35, 1.18 8, 1000, 128, 264.35, 1.06 8, 1000, 256, 252.13, 1.01 8, 1000, 512, 245.50, 0.98 8, 1000, 1024, 241.61, 0.97 8, 1000, 2048, 231.39, 0.93 After: bit_rate, rows, cols, elems_per_usec, GB/Sec 8, 100, 16, 1855.59, 7.42 8, 100, 64, 2615.43, 10.46 8, 100, 128, 3134.34, 12.54 8, 100, 256, 2610.72, 10.44 8, 100, 512, 3065.20, 12.26 8, 100, 1024, 3535.29, 14.14 8, 100, 2048, 3757.66, 15.03 8, 120, 16, 1991.94, 7.97 8, 120, 64, 2971.25, 11.89 8, 120, 128, 3403.37, 13.61 8, 120, 256, 2750.87, 11.00 8, 120, 512, 3272.63, 13.09 8, 120, 1024, 3618.98, 14.48 8, 120, 2048, 3848.59, 15.39 8, 1000, 16, 2329.11, 9.32 8, 1000, 64, 3068.76, 12.28 8, 1000, 128, 3678.86, 14.72 8, 1000, 256, 4440.37, 17.76 8, 1000, 512, 4558.70, 18.23 8, 1000, 1024, 4620.94, 18.48 8, 1000, 2048, 3898.84, 15.60 Reviewed By: mcfi Differential Revision: D86236406 fbshipit-source-id: 12c20cbdbbc9b0674ccca8e1aa598b7de144dea9 --- include/fbgemm/QuantUtilsNeon.h | 7 + src/QuantUtils.cc | 5 + src/QuantUtilsNeon.cc | 243 ++++++++++++++++++++++++++++---- 3 files changed, 226 insertions(+), 29 deletions(-) diff --git a/include/fbgemm/QuantUtilsNeon.h b/include/fbgemm/QuantUtilsNeon.h index 63f108b418..13169c8a05 100644 --- a/include/fbgemm/QuantUtilsNeon.h +++ b/include/fbgemm/QuantUtilsNeon.h @@ -22,6 +22,13 @@ namespace fbgemm { // Utility functions //////////////////////////////////////////////////////////////////////////////// +template +void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( + const InputType* input, + size_t input_rows, + int input_columns, + uint8_t* output); + template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 1c2e58363d..5301909193 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -714,6 +714,10 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( int input_columns, std::uint8_t* output, const InputType* rowwise_min_max) { +#if HAVE_SVE + FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( + input, input_rows, input_columns, output); +#else if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2( @@ -723,6 +727,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef( input, input_rows, input_columns, output); } +#endif } template diff --git a/src/QuantUtilsNeon.cc b/src/QuantUtilsNeon.cc index dfb27fe8f8..8fef86b94f 100644 --- a/src/QuantUtilsNeon.cc +++ b/src/QuantUtilsNeon.cc @@ -11,12 +11,13 @@ #include "fbgemm/Utils.h" #define FBGEMM_EXPORTS +#include // @manual #include // @manual #if HAVE_SVE +#include // @manual #include // @manual #endif -#include // @manual #include //for std::min/std::max #include //for assert #include // for FLT_MAX @@ -32,41 +33,48 @@ namespace fbgemm { using namespace std; //////////////////////////////////////////////////////////////////////////////// // Utility functions - -void FindMinMax(const float* m, float* min, float* max, int64_t len) { - if (__builtin_expect(len <= 0, 0)) { - *min = 0.0f; - *max = 0.0f; - return; - } - +static inline void +FindMinMaxImpl_f32(const float* m, float* min, float* max, uint64_t count) { float first = *m; + float tmp_min_s = first; + float tmp_max_s = first; + float32x4_t temp_min_0 = vdupq_n_f32(first); float32x4_t temp_min_1 = vdupq_n_f32(first); float32x4_t temp_max_0 = vdupq_n_f32(first); float32x4_t temp_max_1 = vdupq_n_f32(first); - uint64_t i = 0; - uint64_t count = static_cast(len); - uint64_t loopBound = count - (count % 8); - - for (; i < loopBound; i += 8) { - float32x4_t v0 = vld1q_f32(m + i); - float32x4_t v1 = vld1q_f32(m + i + 4); - temp_min_0 = vminq_f32(temp_min_0, v0); - temp_min_1 = vminq_f32(temp_min_1, v1); - temp_max_0 = vmaxq_f32(temp_max_0, v0); - temp_max_1 = vmaxq_f32(temp_max_1, v1); + constexpr uint64_t kItemsPerIter = 8; + uint64_t loopIters = count / kItemsPerIter; + uint64_t loopRemainder = count % kItemsPerIter; + + if (__builtin_expect(loopIters > 0, 1)) { + do { + float32x4_t v0 = vld1q_f32(m); + float32x4_t v1 = vld1q_f32(m + 4); + m += kItemsPerIter; + loopIters -= 1; + temp_min_0 = vminq_f32(temp_min_0, v0); + temp_min_1 = vminq_f32(temp_min_1, v1); + temp_max_0 = vmaxq_f32(temp_max_0, v0); + temp_max_1 = vmaxq_f32(temp_max_1, v1); + } while (loopIters > 0); + + temp_min_0 = vminq_f32(temp_min_0, temp_min_1); + temp_max_0 = vmaxq_f32(temp_max_0, temp_max_1); + + tmp_min_s = vminvq_f32(temp_min_0); + tmp_max_s = vmaxvq_f32(temp_max_0); } - temp_min_0 = vminq_f32(temp_min_0, temp_min_1); - temp_max_0 = vmaxq_f32(temp_max_0, temp_max_1); - - float tmp_min_s = vminvq_f32(temp_min_0); - float tmp_max_s = vmaxvq_f32(temp_max_0); - - for (; i < count; i++) { - float tmp = *m; +#ifdef __clang__ +#pragma clang loop vectorize(disable) interleave(disable) unroll(disable) +#elif defined(__GNUC__) +#pragma GCC novector unroll 0 +#endif + while (loopRemainder > 0) { + float tmp = *m++; + loopRemainder -= 1; tmp_min_s = std::min(tmp_min_s, tmp); tmp_max_s = std::max(tmp_max_s, tmp); } @@ -75,8 +83,180 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) { *max = tmp_max_s; } +void FindMinMax(const float* m, float* min, float* max, int64_t len) { + if (__builtin_expect(len <= 0, 0)) { + *min = 0.0f; + *max = 0.0f; + return; + } + + FindMinMaxImpl_f32(m, min, max, static_cast(len)); +} + #if HAVE_SVE +static inline void +FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) { + float16_t first = *m; + + float16_t tmp_min_s = first; + float16_t tmp_max_s = first; + + float16x8_t temp_min_0 = vdupq_n_f16(first); + float16x8_t temp_min_1 = vdupq_n_f16(first); + float16x8_t temp_max_0 = vdupq_n_f16(first); + float16x8_t temp_max_1 = vdupq_n_f16(first); + constexpr uint64_t kItemsPerIter = 16; + uint64_t loopIters = count / kItemsPerIter; + uint64_t loopRemainder = count % kItemsPerIter; + + if (__builtin_expect(loopIters > 0, 1)) { + do { + float16x8_t v0 = vld1q_f16(m); + float16x8_t v1 = vld1q_f16(m + 8); + m += kItemsPerIter; + loopIters -= 1; + temp_min_0 = vminq_f16(temp_min_0, v0); + temp_min_1 = vminq_f16(temp_min_1, v1); + temp_max_0 = vmaxq_f16(temp_max_0, v0); + temp_max_1 = vmaxq_f16(temp_max_1, v1); + } while (loopIters > 0); + + temp_min_0 = vminq_f16(temp_min_0, temp_min_1); + temp_max_0 = vmaxq_f16(temp_max_0, temp_max_1); + + tmp_min_s = vminvq_f16(temp_min_0); + tmp_max_s = vmaxvq_f16(temp_max_0); + } + +#ifdef __clang__ +#pragma clang loop vectorize(disable) interleave(disable) unroll(disable) +#elif defined(__GNUC__) +#pragma GCC novector unroll 0 +#endif + while (loopRemainder > 0) { + float16_t tmp = *m++; + loopRemainder -= 1; + tmp_min_s = vminh_f16(tmp_min_s, tmp); + tmp_max_s = vmaxh_f16(tmp_max_s, tmp); + } + + *min = static_cast(tmp_min_s); + *max = static_cast(tmp_max_s); +} + +template +void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( + const InputType* input, + size_t input_rows, + int input_columns, + uint8_t* output) { + constexpr float kEpsilon = 1e-8f; + + if (input_rows == 0 || input_columns <= 0) { + return; + } + + uint64_t column_count = static_cast(input_columns); + + const uint64_t output_columns = column_count + 2 * sizeof(float); + + for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) { + const InputType* input_row = input + row * column_count; + uint8_t* output_row = output + row * output_columns; + + float* output_row_scale_bias = + reinterpret_cast(output_row + column_count); + + float minimum_element; + float maximum_element; + if constexpr (std::is_same()) { + FindMinMaxImpl_f32( + input_row, &minimum_element, &maximum_element, column_count); + } else { + FindMinMaxImpl_f16( + reinterpret_cast(input_row), + &minimum_element, + &maximum_element, + column_count); + } + float range = maximum_element - minimum_element; + + const auto inverse_scale = 255.0f / (range + kEpsilon); + + float32x4_t inverse_scale_v = vdupq_n_f32(inverse_scale); + float32x4_t min_v = vdupq_n_f32(minimum_element); + + constexpr uint64_t kItemsPerIter = 8; + uint64_t loopIters = column_count / kItemsPerIter; + uint64_t loopRemainder = column_count % kItemsPerIter; + + output_row_scale_bias[0] = range / 255.0f; + output_row_scale_bias[1] = minimum_element; + + while (__builtin_expect(loopIters > 0, 1)) { + float32x4_t v0; + float32x4_t v1; + + if constexpr (std::is_same()) { + v0 = vld1q_f32(input_row); + v1 = vld1q_f32(input_row + 4); + } else { + float16x8_t h0 = + vld1q_f16(reinterpret_cast(input_row)); + v0 = vcvt_f32_f16(vget_low_f16(h0)); + v1 = vcvt_high_f32_f16(h0); + } + + input_row += kItemsPerIter; + loopIters -= 1; + + v0 = vsubq_f32(v0, min_v); + v1 = vsubq_f32(v1, min_v); + + v0 = vmulq_f32(v0, inverse_scale_v); + v1 = vmulq_f32(v1, inverse_scale_v); + + int32x4_t i0 = vcvtnq_s32_f32(v0); + int32x4_t i1 = vcvtnq_s32_f32(v1); + + svst1b_s32( + svptrue_b8(), + reinterpret_cast(output_row), + svset_neonq_s32(svundef_s32(), i0)); + svst1b_s32( + svptrue_b8(), + reinterpret_cast(output_row + 4), + svset_neonq_s32(svundef_s32(), i1)); + + output_row += kItemsPerIter; + } + +#ifdef __clang__ +#pragma clang loop vectorize(disable) interleave(disable) unroll(disable) +#elif defined(__GNUC__) +#pragma GCC novector unroll 0 +#endif + while (loopRemainder > 0) { + float32x4_t v0; + if constexpr (std::is_same()) { + v0[0] = *input_row++; + } else { + v0[0] = + static_cast(*reinterpret_cast(input_row)); + input_row += 1; + } + loopRemainder -= 1; + v0 = vsubq_f32(v0, min_v); + v0 = vmulq_f32(v0, inverse_scale_v); + int32x4_t i0 = vcvtnq_s32_f32(v0); + *output_row = i0[0]; + output_row += 1; + } + + } // for each row +} + template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, @@ -179,7 +359,12 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, \ size_t input_rows, \ int input_columns, \ - type* output); + type* output); \ + template void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( \ + const type* input, \ + size_t input_rows, \ + int input_columns, \ + uint8_t* output); // clang-format off INSTANTIATE_QuantizationNeonFunctions8Bits(float) From a842d88eb7c3c8fcbdb77ceb145d1714021493b5 Mon Sep 17 00:00:00 2001 From: Hao Yan Date: Wed, 5 Nov 2025 10:11:09 -0800 Subject: [PATCH 75/92] Inference test e2e [1/n] (#5091) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5091 X-link: https://github.com/facebookresearch/FBGEMM/pull/2099 In this test, we run following step 1. Create a DramKVInferenceEmbedding with TTL eviction for 1 min 2. Insert 1 embedding with current Unixtime - 2 mins (it is already expired) as timestamp 3. Read from it and check correctness 4. Read for multiple times 5. Evict it 6. Read it --- this time should be inconsistent Reviewed By: emlin Differential Revision: D86268606 fbshipit-source-id: edc2dc24e5327399421d20229a0b1af2ca29ea7a --- .../kv_embedding_inference_test.cpp | 217 ++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 fbgemm_gpu/test/dram_kv_embedding_cache/kv_embedding_inference_test.cpp diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/kv_embedding_inference_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/kv_embedding_inference_test.cpp new file mode 100644 index 0000000000..02f0960a8a --- /dev/null +++ b/fbgemm_gpu/test/dram_kv_embedding_cache/kv_embedding_inference_test.cpp @@ -0,0 +1,217 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h" + +#include +#include +#include +#include +#include +#include + +namespace kv_mem { + +class KVEmbeddingInferenceTest : public ::testing::Test { + protected: + static constexpr int EMBEDDING_DIM = 128; + static constexpr int NUM_SHARDS = 8; + + void SetUp() override { + FLAGS_logtostderr = true; + FLAGS_minloglevel = 0; + FLAGS_v = 1; + + auto feature_evict_config = c10::make_intrusive( + 3, + 4, + std::nullopt, + std::nullopt, + std::vector{1}, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::vector{EMBEDDING_DIM}, + std::nullopt, + std::nullopt, + 0, + 0, + 0); + + auto hash_size_cumsum = at::tensor({0, 100000}, at::kLong); + + backend_ = std::make_unique>( + EMBEDDING_DIM, + -0.1, + 0.1, + feature_evict_config, + NUM_SHARDS, + 32, + 32, + false, + std::nullopt, + hash_size_cumsum, + false); + } + + void TearDown() override { + backend_.reset(); + } + + static std::vector generateEmbedding(int64_t embedding_id) { + std::vector embedding(EMBEDDING_DIM); + + // Use both embedding_id and current time as seed for randomness + auto now = std::chrono::system_clock::now(); + auto time_seed = std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + uint32_t combined_seed = static_cast(embedding_id ^ time_seed); + + std::mt19937 rng(combined_seed); + std::uniform_real_distribution dist(-0.1f, 0.1f); + for (int i = 0; i < EMBEDDING_DIM; ++i) { + embedding[i] = dist(rng); + } + return embedding; + } + + std::unique_ptr> backend_; +}; + +TEST_F(KVEmbeddingInferenceTest, InferenceLifecycleWithMetadata) { + const int64_t embedding_id = 12345; + + auto now = std::chrono::system_clock::now(); + auto now_seconds = + std::chrono::duration_cast(now.time_since_epoch()) + .count(); + const uint32_t snapshot_timestamp = static_cast(now_seconds - 120); + + auto embedding_data = generateEmbedding(embedding_id); + + LOG(INFO) << "STEP 1: Define test embedding"; + LOG(INFO) << "Embedding ID: " << embedding_id; + LOG(INFO) << "Timestamp: " << snapshot_timestamp + << " (current time - 2 minutes)"; + LOG(INFO) << "Dimension: " << EMBEDDING_DIM; + LOG(INFO) << "First 5 elements: [" << embedding_data[0] << ", " + << embedding_data[1] << ", " << embedding_data[2] << ", " + << embedding_data[3] << ", " << embedding_data[4] << "]"; + + auto indices_tensor = at::tensor({embedding_id}, at::kLong); + auto weights_tensor = at::from_blob( + embedding_data.data(), + {1, EMBEDDING_DIM}, + at::TensorOptions().dtype(at::kFloat)); + auto count_tensor = at::tensor({1}, at::kInt); + + LOG(INFO) << "STEP 2: Insert embedding into cache"; + folly::coro::blockingWait(backend_->inference_set_kv_db_async( + indices_tensor, weights_tensor, count_tensor, snapshot_timestamp)); + LOG(INFO) << "Insertion completed"; + + auto retrieved_embedding = at::zeros({1, EMBEDDING_DIM}, at::kFloat); + + LOG(INFO) << "STEP 3: Retrieve embedding from cache"; + folly::coro::blockingWait(backend_->get_kv_db_async( + indices_tensor, retrieved_embedding, count_tensor)); + LOG(INFO) << "Retrieval completed"; + + auto retrieved_ptr = retrieved_embedding.data_ptr(); + bool all_match = true; + int mismatch_count = 0; + + LOG(INFO) << "STEP 4: Verify embedding consistency"; + for (int i = 0; i < EMBEDDING_DIM; ++i) { + if (std::abs(retrieved_ptr[i] - embedding_data[i]) > 1e-5f) { + all_match = false; + mismatch_count++; + } + } + + if (all_match) { + LOG(INFO) << "All " << EMBEDDING_DIM << " dimensions match"; + } else { + LOG(ERROR) << "Found " << mismatch_count << " mismatches out of " + << EMBEDDING_DIM << " dimensions"; + } + + ASSERT_TRUE(all_match) << "Retrieved embedding must match inserted embedding"; + + LOG(INFO) << "STEP 5: Test repeated reads"; + for (int iteration = 1; iteration <= 3; ++iteration) { + auto read_again = at::zeros({1, EMBEDDING_DIM}, at::kFloat); + folly::coro::blockingWait( + backend_->get_kv_db_async(indices_tensor, read_again, count_tensor)); + + auto read_ptr = read_again.data_ptr(); + bool matches = true; + for (int i = 0; i < EMBEDDING_DIM; ++i) { + if (std::abs(read_ptr[i] - embedding_data[i]) > 1e-5f) { + matches = false; + break; + } + } + LOG(INFO) << "Read #" << iteration << ": " + << (matches ? "Match" : "Mismatch"); + } + + LOG(INFO) << "STEP 6: Trigger eviction"; + auto eviction_time = std::chrono::system_clock::now(); + auto eviction_seconds = std::chrono::duration_cast( + eviction_time.time_since_epoch()) + .count(); + uint32_t eviction_threshold = static_cast(eviction_seconds - 60); + + LOG(INFO) << "Eviction threshold: " << eviction_threshold; + backend_->trigger_feature_evict(eviction_threshold); + backend_->wait_until_eviction_done(); + LOG(INFO) << "Eviction completed"; + + auto post_eviction_embedding = at::zeros({1, EMBEDDING_DIM}, at::kFloat); + + LOG(INFO) << "STEP 7: Read embedding after eviction"; + folly::coro::blockingWait(backend_->get_kv_db_async( + indices_tensor, post_eviction_embedding, count_tensor)); + + auto post_eviction_ptr = post_eviction_embedding.data_ptr(); + bool values_changed = false; + int differences = 0; + + for (int i = 0; i < EMBEDDING_DIM; ++i) { + if (std::abs(post_eviction_ptr[i] - embedding_data[i]) > 1e-5f) { + values_changed = true; + differences++; + } + } + + LOG(INFO) << "Differences found: " << differences << "/" << EMBEDDING_DIM; + + if (values_changed) { + LOG(INFO) << "Eviction successful - values changed"; + } else { + LOG(ERROR) << "Eviction may have failed - values unchanged"; + } + + LOG(INFO) << "Original (cached): [" << embedding_data[0] << ", " + << embedding_data[1] << ", " << embedding_data[2] << ", " + << embedding_data[3] << ", " << embedding_data[4] << "]"; + LOG(INFO) << "After eviction: [" << post_eviction_ptr[0] << ", " + << post_eviction_ptr[1] << ", " << post_eviction_ptr[2] << ", " + << post_eviction_ptr[3] << ", " << post_eviction_ptr[4] << "]"; + + ASSERT_TRUE(values_changed) << "Embedding should be different after eviction"; + + LOG(INFO) << "Test completed successfully"; +} + +} // namespace kv_mem From 2dd87764a97f64414c47a0b645c0864f19c9412a Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Wed, 5 Nov 2025 15:32:03 -0800 Subject: [PATCH 76/92] Merge VBE output [backend] reland (#5093) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5093 X-link: https://github.com/facebookresearch/FBGEMM/pull/2100 ---- # Context on the changes: Currently, Torchrec merges the outputs of individual VBE TBE ops to be ordered by ranks using [_merge_variable_batch_embeddings](https://www.internalfb.com/code/fbsource/[3bd69d7fa3534144dcb0162ca59803a6c3ff6e70]/fbcode/torchrec/distributed/embedding_lookup.py?lines=593-604). This function seems to cause ~30% QPS regression compared to baseline (HBM+UVM) for Jupiter V1 model with VBE enabled. To get rid of _merge_variable_batch_embeddings() function, we pre-allocate the `vbe_output` tensor which holds outputs from all VBE ops and calculate `vbe_output_offsets` to allow each individual VBE ops to write to the correct location in the `vbe_output` tensor. By default, `vbe_output` and `vbe_output_offsets` are `None`, which means VBE ops will return individual tensor the way it currently does. The feature is enabled when `vbe_output` and `vbe_output_offsets` are not `None`. --- **NOTE** 1. This feature is currently supported for Sparse TBE. 2. The support is limited for CUDA. 3. For backward compatibility, we append the newly introduced `vbe_output` to the existing API. Hence, we need to make the `vbe_output` tensor as `optional` with default value as `None` (there's no default value for Tensor). 4. We *cannot* annotate `vbe_output` because PyTorch registration does not support annotation of optional tensor. Adding annotation will incur the following error below. This may cause some issues to support this on MTIA, if MTIA relies on tensor annotation. ``` E0903 09:50:32.966235 2850885 ExceptionTracer.cpp:227] exception stack complete terminate called after throwing an instance of 'std::runtime_error' what(): expected ident but found '(' here: split_embedding_codegen_lookup_adagrad_function_pt2( Tensor placeholder_autograd_tensor, Tensor[](a!) weights, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype, Tensor?[](e!) aux_tensor, int[] aux_int, float[] aux_float, bool[] aux_bool, Tensor[](g!) momentum1, Tensor learning_rate_tensor, float[] optim_float, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, Tensor?(t!) vbe_output=None ) -> Tensor ~ <--- HERE ``` See https://docs.google.com/document/d/1h5YyeCjYmmN-CIFB98CrBf1uMksidPbNvM1rl8yZeds/edit?tab=t.0#heading=h.tdfkkc6ujdyl ---- This diff is a reland of D79704318 which all issues have been addressed. ## 1) pyper validation test D79704318 was reverted as it broke pyper validation test (frontend/backend package compatibility issue), which blocks pyper releases. The issue is addressed in this diff. Context: In pyper, changes in python would be included in frontend package (e.g., ads_dper3) and C++ in backend package (e.g., training_platform). If the diff contains both python and C++, there's a chance that some model will use mismatching packages. In other words, frontend package does not include the diff but backend does, and vice versa. D83881544 is only enabling backend support (i.e., no one can actually use this feature, so TBE VBE will work as usual). Due to new Unified API changes, we need to pipeline optional tensor from frontend and requires python change. Denote - #0 as no D83881544 included - #1 as D83881544 included There are 4 scenarios: (1) frontend #0 + old backend #0 - no issue (2) frontend #1 + backend #1 - no issue (3) frontend #0 + backend #1 - handled; TBE VBE will work normally. (4) frontend #1 + backend #0 - no issue; the diff added warning that backend is old There's another diff D79869613 in the stack that will enable frontend support (i.e., allow users to use this feature), which will go into __frontend package only__. Now, the 1)-4) scenarios would remain the same, but new scenarios occur. Denote - #2 as D79869613 included (5) frontend #2 + backend #1 - no issue, same as (2). (6) frontend #2 (no feature enabled) + backend #0 - same as (4). (7) frontend #2 (feature enabled) + backend #0 - **assertion error due to no backend support**, to prevent silent wrong behavior. **To use the feature, this diff stack (D83881544 and D79869613) need to be included in both frontend and backend package.** ## 2) SEV D79704318 caused SEV due to TBE v1 and v2 interfacing compatibility issue on lex_ig_o3_package. Unit tests to ensure v1 compatibility was added D83020965. D83881544 passes the v1 compatibility test. Detail on the root cause and fix: https://docs.google.com/document/d/1XcYNfyiAn4aRFvjV0QG5aLiWKuuOWtJdLOMKNszZRpI/edit?tab=t.0#heading=h.psr4a2qn0mdk ------ Reviewed By: q10, renganxu Differential Revision: D83881544 fbshipit-source-id: 5d63841bbf79a72219903e9d0f77ee3b998bc105 --- .../genscript/generate_backward_split.py | 1 + .../codegen/genscript/optimizer_args.py | 4 +- ...embedding_backward_split_host_template.cpp | 15 ++- .../embedding_forward_split_meta_template.cpp | 8 +- .../embedding_forward_split_template.cu | 29 ++++- ...dding_split_host_pt2_autograd_template.cpp | 119 ++++++++++++++---- ...ng_split_host_pt2_cpu_wrapper_template.cpp | 13 +- ...g_split_host_pt2_cuda_wrapper_template.cpp | 20 +++ .../training/pt2/pt2_arg_utils_template.h | 2 +- ..._embedding_codegen_lookup_invoker.template | 25 ++-- .../fbgemm_gpu/split_embeddings_utils.h | 5 +- .../generate_vbe_metadata.cu | 63 ++++++++-- .../split_embeddings_utils_cpu.cpp | 14 ++- .../split_embeddings_utils_meta.cpp | 3 +- 14 files changed, 255 insertions(+), 66 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index 5acb6f2e7f..50506decb1 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -426,6 +426,7 @@ def generate() -> None: "lxu_cache_locations", # 3 "uvm_cache_stats", # 4 "prev_iter_dev", # 5 + "vbe_output_offsets", # 6 ], "aux_int": [ "iter", # 0 diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py index 9d7235af84..9c8924e49f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizer_args.py +++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py @@ -73,9 +73,7 @@ class OptimizerArgsSetItem: "row_counter_dev": "(q!)", "row_counter_uvm": "(r!)", "optim_tensor": "(s!)", - "delta_weights_host": "(t!)", - "delta_weights_dev": "(u!)", - "delta_weights_uvm": "(v!)", + "vbe_output": "(t!)", } ###################################################################### diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 05b93d9d7e..2ea96a107e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -109,7 +109,12 @@ enum SSDTensor { gwd_lower_bound, {%- endif %} {# /* if is_gwd */ #} {%- endif %} {# /* if not nobag */ #} + {%- if vbe and not dense %} + {{ "is_experimental" if has_experimental else "false" }}, + std::nullopt /* vbe_output */ + {%- else %} {{ "is_experimental" if has_experimental else "false" }} + {%- endif %} ); if (is_annotate_trace_enabled) { @@ -474,7 +479,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cud const int64_t iter, const double gwd_lower_bound, {%- endif %} + {%- if vbe and not dense %} + const bool is_experimental, + std::optional vbe_output = std::nullopt + {%- else %} const bool is_experimental + {%- endif %} ); Tensor @@ -708,7 +718,7 @@ class {{ autograd_func }} : static auto generate_vbe_metadata_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::generate_vbe_metadata", "") - .typed(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const int64_t, const bool, const c10::SymInt, const int64_t, const c10::SymInt)>(); + .typed(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const int64_t, const bool, const c10::SymInt, const int64_t, const c10::SymInt, const std::optional&)>(); auto [ vbe_row_output_offsets, @@ -729,7 +739,8 @@ class {{ autograd_func }} : {%- endif %} max_B_feature_rank, info_B_num_bits, - /*total_B=*/offsets.sym_size(0) - 1 + /*total_B=*/offsets.sym_size(0) - 1, + std::nullopt /* pre-allocated vbe_output is not supported in TBE interface V1 or Dense TBE */ ); {%- endif %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp index 09630b57cf..e2705d16fd 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +// clang-format off {# // @lint-ignore LINTIGNORE // @lint-ignore-every CLANGFORMAT @@ -103,7 +104,12 @@ Tensor const int64_t iter, const double gwd_lower_bound, {%- endif %} + {%- if vbe and not dense %} + const bool is_experimental, + std::optional vbe_output + {%- else %} const bool is_experimental + {%- endif %} ) { // NB: omitted the device tests TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL {%- if not nobag %} @@ -210,4 +216,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- endfor %} {#-/* for is_gwd */#} {%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#} {%- endfor %} {#-/* for nobag */#} - // clang-format on + // clang-format on diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index f4de721bc9..83315d9a13 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -6,10 +6,10 @@ * LICENSE file in the root directory of this source tree. */ -{# // @lint-ignore LINTIGNORE // @lint-ignore-every CLANGFORMAT // clang-format off +{# // Note: clang-format off doesn't work with this templaterized code, // so we need to keep lint-ignore-every. // See https://fburl.com/dw9ljh4h @@ -395,7 +395,12 @@ batch_index_select_dim0_codegen_forward_cuda( const int64_t iter, const double gwd_lower_bound, {%- endif %} + {%- if vbe and not dense %} + const bool is_experimental, + std::optional vbe_output + {%- else %} const bool is_experimental + {%- endif %} {%- endif %} {#- /*if is_index_select*/ #} ) { {%- if not nobag or is_index_select %} @@ -543,11 +548,24 @@ batch_index_select_dim0_codegen_forward_cuda( o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); {%- if vbe %} - // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output + {%- if dense %} output = at::empty( {1, vbe_output_size}, dev_weights.options().dtype(getScalarType(o_dtype)) - ); + ); + {%- else %} + // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(vbe_row_output_offsets, vbe_output); + if (vbe_output.has_value()){ + output = vbe_output.value().reshape({1, -1}); + } + else { + output = at::empty( + {1, vbe_output_size}, + dev_weights.options().dtype(getScalarType(o_dtype)) + ); + } + {%- endif %} {#-/* if dense */#} {%- else %} int64_t total_adjusted_D = total_D; if (o_dtype == SparseType::INT8) { @@ -891,7 +909,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " int iter, " " float gwd_lower_bound, " {%- endif %} + {%- if vbe and not dense %} + " bool is_experimental," + " Tensor? vbe_output" + {%- else %} " bool is_experimental" + {%- endif %} ") -> Tensor" {%- if not dense and not nobag and not vbe %} // only split_embedding_codegen_forward_[un]weighted_cuda diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 789877f69f..661f7b9b45 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -137,7 +137,12 @@ enum SSDTensor { const double /*gwd_lower_bound*/, {%- endif %} const bool /*is_experimental*/, + {%- if vbe and not dense %} + const int64_t /*output_dtype*/, + std::optional /*vbe_output*/ + {%- else %} const int64_t /*output_dtype*/ + {%- endif %} )>(); auto output = embedding_codegen_forward_op.call( @@ -186,7 +191,12 @@ enum SSDTensor { {%- endif %} {# /* if is_gwd */ #} {%- endif %} {# /* if not nobag */ #} is_experimental, + {%- if vbe and not dense %} + output_dtype, + vbe_output + {%- else %} output_dtype + {%- endif %} ); if (is_annotate_trace_enabled) { @@ -259,7 +269,7 @@ enum SSDTensor { const bool /*use_homogeneous_placements*/, {%- if ssd %} const bool /*enable_optimizer_offloading*/, - {%- endif %} + {%- endif %} {%- if is_gwd %} {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} const Tensor& /*prev_iter_dev*/, @@ -359,6 +369,14 @@ enum SSDTensor { // The number of items in the tensorlist differ between devices and is determined at runtime std::vector ret; + {%- if vbe and not dense %} + // To avoid overhead of merging multiple VBE embedding outputs, each embedding ops return + // the same output tensor i.e., vbe_output. To ensure all backward ops are triggered, the embedding + // ops are called in chain. We hence need to pass the grad_outputs to the next embedding op. + // So, if vbe_output is passed, we return the grad_outputs. + Tensor grad_vbe_output = has_vbe_output ? grad_outputs[0] : Variable(); + {%- endif %} + {%- if not dense %} ret.push_back(Variable()); // placeholder autograd tensor {%- endif %} @@ -400,18 +418,21 @@ enum SSDTensor { ret.push_back(Variable()); // max_B ret.push_back(Variable()); // max_B_feature_rank ret.push_back(Variable()); // vbe_output_size + {%- if not dense %} + ret.push_back(grad_vbe_output); // vbe_output + {%- endif %} {# /* if not dense */ #} {%- endif %} {# /* if vbe */ #} {%- if not dense %} ret.push_back(Variable()); // aux_tensor ret.push_back(Variable()); // aux_int ret.push_back(Variable()); // aux_float ret.push_back(Variable()); // aux_bool - {%- endif %} + {%- endif %} {# /* if not dense */ #} {%- if ssd %} {%- for tensor in ssd_tensors %} ret.push_back(Variable()); // {{ tensor }} {%- endfor %} - {%- endif %} + {%- endif %} {# /* if ssd */ #} {{ args_pt2.unified_pt2.split_variables | join("\n") }} return ret; {%- endmacro %} @@ -472,6 +493,9 @@ enum SSDTensor { max_B, max_B_feature_rank, vbe_output_size, + {%- if not dense %} + vbe_output, + {%- endif %} {# /* if not dense */ #} {%- endif %} {# /* if vbe */ #} {%- if not dense %} aux_tensor, @@ -504,7 +528,7 @@ enum SSDTensor { TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[2]); {{ name }}_host = {{ name }}[0]; {{ name }}_placements = {{ name }}[1]; - {{ name }}_offsets = {{ name }}[2]; + {{ name }}_offsets = {{ name }}[2]; } else if ({{ name }}.size() == {{ 5 if name == "weights" else 4 }}) { TENSOR_ON_CUDA_GPU({{ name }}[0]); @@ -514,7 +538,7 @@ enum SSDTensor { {%- if name == "weights" %} TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[4]); {%- endif %} - {{ name }}_dev = {{ name }}[0]; + {{ name }}_dev = {{ name }}[0]; {{ name }}_uvm = {{ name }}[1]; {{ name }}_placements = {{ name }}[2]; {{ name }}_offsets = {{ name }}[3]; @@ -548,7 +572,7 @@ enum SSDTensor { {%- endmacro %} //////////////////////////////////////////////////////////////////////////////// -// MACROS +// MACROS //////////////////////////////////////////////////////////////////////////////// #define GET_OPTIONAL_TENSOR_VALUE(name, empty_tensor) name.has_value() ? name.value() : empty_tensor; @@ -631,6 +655,9 @@ class {{ autograd_func }} : const c10::SymInt max_B, const c10::SymInt max_B_feature_rank, const c10::SymInt vbe_output_size, + {%- if not dense %} + std::optional vbe_output, + {%- endif %} {# /* if not dense */ #} {%- endif %} {# /* if vbe */ #} {%- if not dense %} std::vector> aux_tensor, @@ -662,6 +689,24 @@ class {{ autograd_func }} : const auto vbe_output_offsets_feature_rank_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_VBE_OUTPUT_OFFSETS_FEATURE_RANK], Tensor()); const auto vbe_B_offsets_rank_per_feature_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_VBE_B_OFFSETS_RANK_PER_FEATURE], Tensor()); const c10::SymInt max_B_ = max_B; + {%- if not dense %} + // The pipeline relies on frontend to supply vbe_output_offsets through aux_tensor + // However, if a model uses old frontend package (i.e., does not include frontend changes from this diff) + // with new backend package, aux_tensor will not contain vbe_output_offsets. + // This means old frontend will send aux_tensor of size 6, but the new backend (from this diff) expects 7, + // which accessing aux_tensor[IDX_VBE_OUTPUT_OFFSETS] can cause segmentation fault + const std::optional vbe_output_offsets = aux_tensor.size() == AUX_TENSOR_SIZE ? aux_tensor[IDX_VBE_OUTPUT_OFFSETS] : std::nullopt; + TORCH_CHECK( + vbe_output.has_value() == vbe_output_offsets.has_value(), + "Expected both vbe_output and vbe_output_offsets to either be None or have value. However, vbe_output ", + vbe_output.has_value() ? " has value" : " is None", + " but vbe_output_offsets ", + vbe_output_offsets.has_value() ? " has value." : " is None. ", + "Note: Frontend passes aux_tensor of size ", aux_tensor.size(), + "and backend expects aux_tensor of ", AUX_TENSOR_SIZE, + ". If the aux_tensor size mismatch, please update your frontend/backend package. Contact FBGEMM team for any assistance." + ); + {%- endif %} {%- else %} const auto max_B_ = offsets.sym_size(0) / T; {%- endif %} @@ -720,7 +765,8 @@ class {{ autograd_func }} : const bool, const c10::SymInt, const int64_t, - const c10::SymInt)>(); + const c10::SymInt, + const std::optional&)>(); auto [ vbe_row_output_offsets, vbe_b_t_map @@ -740,7 +786,8 @@ class {{ autograd_func }} : {%- endif %} max_B_feature_rank, info_B_num_bits, - /*total_B=*/offsets.sym_size(0) - 1 + /*total_B=*/offsets.sym_size(0) - 1, + vbe_output_offsets ); {%- endif %} // vbe @@ -756,9 +803,9 @@ class {{ autograd_func }} : const auto indice_weights_value = GET_OPTIONAL_TENSOR_VALUE(indice_weights, Tensor()); {%- endif %} - // Setting learning rate tensor with `.fill_()` breaks apf_dlrm bento kernel with + // Setting learning rate tensor with `.fill_()` breaks apf_dlrm bento kernel with // `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.` - // This is because if a tensor is saved for backward and it is mutated later, this can cause correctness problems. + // This is because if a tensor is saved for backward and it is mutated later, this can cause correctness problems. // Since the forward compute and backward compute see different data values for this tensor. // To work around, we pass the cloned tensor instead the mutated tensor {%- if "learning_rate_tensor" in args_pt2.unified_pt2.split_unpacked_arg_names %} @@ -845,8 +892,12 @@ class {{ autograd_func }} : ctx->saved_data["output_dtype"] = output_dtype; {%- endif %} {%- if vbe %} - ctx->saved_data["max_B"] = max_B_; // for reshaping vbe cpu offsets and grad_output - {%- endif %} + ctx->saved_data["max_B"] = max_B_; // for reshaping vbe cpu offsets and grad_output + // This is needed to determine whether to return grads_output + {%- if not dense %} + ctx->saved_data["has_vbe_output"] = vbe_output.has_value(); + {%- endif %} {# /* if not dense */ #} + {%- endif %} {# /* if vbe */ #} {%- if not dense %} // unpack optim args @@ -979,13 +1030,16 @@ static torch::autograd::variable_list backward( {%- if is_gwd %} const auto gwd_lower_bound = ctx->saved_data["gwd_lower_bound"].toDouble(); {%- endif %} - + {%- if not nobag %} auto output_dtype = ctx->saved_data["output_dtype"].toInt(); {%- endif %} {%- if not dense %} {%- if vbe %} auto max_B = ctx->saved_data["max_B"].toSymInt(); // for reshaping vbe cpu offsets and grad_output + {%- if not dense %} + const auto has_vbe_output = ctx->saved_data["has_vbe_output"].toBool(); // for whether to return grad_output + {%- endif %} {# /* if not dense */ #} {%- endif %} {%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %} @@ -1120,7 +1174,7 @@ static torch::autograd::variable_list backward( feature_requires_grad {%- endif %} ); - + Tensor grad_weights_dev; // weighted if (indice_weights.defined()) @@ -1170,7 +1224,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( {%- else %} const Tensor& placeholder_autograd_tensor, const at::TensorList weights, - {%- endif %} + {%- endif %} {#-/* if dense */#} const Tensor& D_offsets, const c10::SymInt total_D, const c10::SymInt max_D, @@ -1187,20 +1241,32 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( const std::vector& aux_int, const std::vector& aux_float, c10::List aux_bool, - {%- endif %} + {%- endif %} {#-/* if not dense */#} {{ args_pt2.unified_pt2.split_function_args | join(", ") }}, const c10::SymInt max_B = -1, const c10::SymInt max_B_feature_rank = -1, - {%- if ssd %} + {%- if not dense %} const c10::SymInt vbe_output_size = -1, - const std::optional& ssd_tensors = std::nullopt + {%- if ssd %} + const std::optional& ssd_tensors = std::nullopt, + {%- endif %} {#-/* if ssd */#} + std::optional vbe_output = std::nullopt {%- else %} + {#- /* ssd and pre-allocated vbe_output is not yet supported in Dense TBE */ -#} const c10::SymInt vbe_output_size = -1 - {%- endif %} + {%- endif %} {#-/* if not dense */#} ) { {%- if has_gpu_support or has_cpu_support %} + TORCH_WARN(aux_tensor.size() <= AUX_TENSOR_SIZE, + "aux_tensor.size() should not be larger than ", + AUX_TENSOR_SIZE, + "but found to be ", + aux_tensor.size(), + ". This means frontend package does not match with backend package, so some functionalities might be missing. Please contact FBGEMM team for any assistance." + ); + {%- if not dense %} // Load the config value from JK once static auto is_tbev2_enabled = config::is_feature_enabled(config::FeatureGateName::TBE_V2); @@ -1252,7 +1318,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( "{{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail." ); return Tensor(); - {%- endif %} + {%- endif %} } @@ -1282,16 +1348,19 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " int[] aux_int, " " float[] aux_float, " " bool[] aux_bool, " + {%- endif %} {#-/* if not dense */#} " {{ args_pt2.unified_pt2.split_function_schemas | join(", ") }}, " " SymInt max_B=-1, " " SymInt max_B_feature_rank=-1, " - {%- if ssd %} + {%- if not dense %} " SymInt vbe_output_size=-1, " - " Tensor[]? ssd_tensors=None " - {%- else %} - " SymInt vbe_output_size=-1 " - {%- endif %} + {%- if ssd %} + " Tensor[]? ssd_tensors=None, " {%- endif %} + " Tensor? vbe_output=None " + {%- else %} + " SymInt vbe_output_size=-1 " + {%- endif %} {#-/* if not dense */#} ") -> Tensor", {PT2_COMPLIANT_TAG}); // We're playing a funny trick here: we're using the autograd diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index 14deb1af5e..c06dd5efef 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -143,7 +143,13 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu const Tensor& B_offsets, {%- endif %} const bool /*is_experimental = false*/, - const int64_t output_dtype = static_cast(SparseType::FP32)) { + {%- if vbe %} + const int64_t output_dtype = static_cast(SparseType::FP32), + std::optional vbe_output = std::nullopt + {%- else %} + const int64_t output_dtype = static_cast(SparseType::FP32) + {%- endif %} + ){ Tensor offsets_; {%- if vbe %} const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__); @@ -406,7 +412,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor B_offsets, " {%- endif %} " bool is_experimental, " + {%- if vbe %} + " int output_dtype, " + " Tensor? vbe_output " + {%- else %} " int output_dtype " + {%- endif %} ") -> Tensor" {%- if not nobag and not vbe %} // only split_embedding_codegen_forward_[un]weighted_cuda diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp index 1a0cb0fa80..b7070deb83 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp @@ -107,7 +107,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt const double gwd_lower_bound, {%- endif %} const bool is_experimental, + {%- if vbe and not dense %} + const int64_t output_dtype, + std::optional vbe_output + {%- else %} const int64_t output_dtype + {%- endif %} ){ {%- set op = "{}_embedding{}_codegen_forward_{}_cuda".format( fwd_mdesc, ndesc, desc_suffix @@ -155,7 +160,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt const int64_t /*iter*/, const double /*gwd_lower_bound*/, {%- endif %} + {%- if vbe and not dense %} + const bool, + std::optional /*vbe_output*/ + {%- else %} const bool + {%- endif %} )>(); return op.call( @@ -201,7 +211,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt iter, gwd_lower_bound, {%- endif %} {# /* if is_gwd */ #} + {%- if vbe and not dense %} + is_experimental, + vbe_output + {%- else %} is_experimental + {%- endif %} ); }; {%- else %} @@ -561,7 +576,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " float gwd_lower_bound, " {%- endif %} " bool is_experimental, " + {%- if vbe and not dense %} + " int output_dtype, " + " Tensor? vbe_output" + {%- else %} " int output_dtype " + {%- endif %} ") -> Tensor" {%- if not nobag and not vbe %} // only split_embedding_codegen_forward_[un]weighted_cuda diff --git a/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h b/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h index ec033c89d2..675bb7df9b 100644 --- a/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h +++ b/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h @@ -21,7 +21,7 @@ enum ArgIndex_{{ name }} { {%- for var in aux_args[name] %} IDX_{{ var | upper }} = {{ loop.index - 1 }}, {%- endfor %} - {{ name | upper }}_SIZE = {{ name | length }} + {{ name | upper }}_SIZE = {{ aux_args[name] | length }} }; {%- endfor %} diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index 0ecf71bdb5..6fe7292db4 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -42,7 +42,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") {%- endif %} -{# This macro generates a code blob to pack Tensor arguments into a TensorList +{# This macro generates a code blob to pack Tensor arguments into a TensorList as number of arguments for some optimizers exceed 64 #} {%- macro pack_tensors(arg) %} {{ arg }}_list = [ @@ -58,7 +58,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") {%- endmacro %} {# This macro generates a code blob to pack optim optional tensor into an optional TensorList. - All optim optional tensors are packed together into `optim_tensor`. + All optim optional tensors are packed together into `optim_tensor`. This poses challenge to handle unpacking in autograd if we do per device (i.e, 3 for cpu and 4 for cuda). Hence, we pack unified args (i.e., 5 items) for readability and programmability. #} @@ -92,14 +92,14 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") "Please check the frontend and backend version. " ) {{ arg_type }}.append(dict_{{ arg_type }}["{{ var }}"]) - + {%- endfor %} {%- endmacro %} {%- if is_prototype_optimizer %} # Decorate the prototype optimizers which may be deprecated in the future with jit.ignore to avoid -# possible errors from torch.jit.script. +# possible errors from torch.jit.script. # Note that backends can be removed but the lookup invoker is still needed for backward compatibility @torch.jit.ignore {%- endif %} @@ -187,14 +187,15 @@ def invoke( "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature, "lxu_cache_locations": common_args.lxu_cache_locations, "uvm_cache_stats": common_args.uvm_cache_stats, + "vbe_output_offsets" : None, } dict_aux_int: Dict[str, int] = { - "iter": iter, - "info_B_num_bits": common_args.info_B_num_bits, + "iter": iter, + "info_B_num_bits": common_args.info_B_num_bits, "info_B_mask": common_args.info_B_mask, } - + dict_aux_float: Dict[str, float] = { "gwd_lower_bound": gwd_lower_bound, } @@ -219,7 +220,7 @@ def invoke( {%- else %} dict_aux_tensor["prev_iter_dev"] = prev_iter.dev {%- endif %} - + # optimizer_args {%- if optimizer == "none" %} @@ -302,13 +303,13 @@ def invoke( {{ pack_tensors("row_counter") }} {%- endif %} {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %} - + if optimizer_args.use_rowwise_bias_correction and row_counter is not None: row_counter_host = None # not supported on CPU row_counter_dev = row_counter.dev row_counter_uvm = row_counter.uvm row_counter_offsets = row_counter.offsets - row_counter_placements = row_counter.placements + row_counter_placements = row_counter.placements elif optimizer_args.use_rowwise_bias_correction: assert False, "`use_rowwise_bias_correction` is set, `row_counter` cannot be None" else: @@ -316,7 +317,7 @@ def invoke( row_counter_dev = None row_counter_uvm = None row_counter_offsets = None - row_counter_placements = None + row_counter_placements = None {%- endif %} {{ pack_to_list("aux_tensor") }} @@ -358,7 +359,7 @@ def invoke( {%- for name in args_pt2.unified_pt2.split_args_dict["optim_bool"] %} optim_bool.append(dict_optim_bool["{{ name }}"]) {%- endfor %} - {%- endif %} + {%- endif %} return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( # common_args diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h index 108b8eba5e..4334efd4b8 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h @@ -32,6 +32,7 @@ generate_vbe_metadata( const at::Tensor& D_offsets, const int64_t D, const bool nobag, - const int64_t max_B_feature_rank, + const c10::SymInt max_B_feature_rank, const int64_t info_B_num_bits, - const int64_t total_B); + const c10::SymInt total_B, + const std::optional& vbe_output_offsets); diff --git a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu index 17b1fd0edb..ab1ecc7f1d 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu @@ -34,7 +34,9 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( D_offsets, const int32_t D, const bool nobag, - const int32_t info_B_num_bits) { + const int32_t info_B_num_bits, + const pta::PackedTensorAccessor32 + predefined_vbe_output_offsets) { // Relative sample ID in the rank-table matrix const auto b = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; // Rank ID @@ -50,6 +52,8 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( return; } + const bool use_predefined_offsets = predefined_vbe_output_offsets.size(0) > 0; + const auto* __restrict__ output_offsets_feature = &output_offsets_feature_rank[r * T]; @@ -57,8 +61,9 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( const auto b_t = static_cast(B_start_t) + static_cast(B_start_r_t) + b; const auto D_ = nobag ? D : (D_offsets[t + 1] - D_offsets[t]); - row_output_offsets[b_t] = - output_offsets_feature[t] + b * static_cast(D_); + auto offset = use_predefined_offsets ? predefined_vbe_output_offsets[r][t] + : output_offsets_feature[t]; + row_output_offsets[b_t] = offset + b * static_cast(D_); // Relative sample ID in the table const auto b_ = B_start_r_t + b; @@ -114,11 +119,15 @@ generate_vbe_metadata( const Tensor& D_offsets, const int64_t D, const bool nobag, - const int64_t max_B_feature_rank, + const c10::SymInt max_B_feature_rank, const int64_t info_B_num_bits, - const int64_t total_B) { + const c10::SymInt total_B, + const std::optional& vbe_output_offsets = std::nullopt) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( - B_offsets, B_offsets_rank_per_feature, output_offsets_feature_rank); + B_offsets, + B_offsets_rank_per_feature, + output_offsets_feature_rank, + vbe_output_offsets); TENSOR_NDIM_EQUALS(B_offsets, 1); TENSOR_NDIM_EQUALS(B_offsets_rank_per_feature, 2); @@ -132,25 +141,53 @@ generate_vbe_metadata( TORCH_CHECK(D_offsets.numel() == T + 1) } + const int64_t total_B_ = total_B.guard_int(__FILE__, __LINE__); + const int64_t max_B_feature_rank_ = + max_B_feature_rank.guard_int(__FILE__, __LINE__); + const auto num_ranks = B_offsets_rank_per_feature.size(1) - 1; TORCH_CHECK( num_ranks > 0, "generate_vbe_metadata: Invalid num_ranks ", num_ranks); TORCH_CHECK(T > 0, "generate_vbe_metadata: Invalid T ", T); TORCH_CHECK( - max_B_feature_rank > 0, + max_B_feature_rank_ > 0, "generate_vbe_metadata: Invalid max_B_feature_rank ", - max_B_feature_rank); + max_B_feature_rank_); TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T); TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1); + Tensor predefined_vbe_output_offsets; + if (vbe_output_offsets.has_value()) { + predefined_vbe_output_offsets = vbe_output_offsets.value(); + TORCH_CHECK( + predefined_vbe_output_offsets.dim() == 2, + "Expected a tensor of 2 dims: [num_ranks, num_features] but got ", + predefined_vbe_output_offsets.dim()); + TORCH_CHECK( + predefined_vbe_output_offsets.size(0) == num_ranks, + "Expected predefined_vbe_output_offsets.size(0) to be", + num_ranks, + " but got ", + predefined_vbe_output_offsets.size(0)); + TORCH_CHECK( + predefined_vbe_output_offsets.size(1) == T, + "Expected predefined_vbe_output_offsets.size(1) to be", + T, + " but got ", + predefined_vbe_output_offsets.size(1)); + } else { + predefined_vbe_output_offsets = + at::empty({0, 0}, output_offsets_feature_rank.options()); + } + CUDA_DEVICE_GUARD(B_offsets); Tensor row_output_offsets = - at::empty({total_B}, output_offsets_feature_rank.options()); - Tensor b_t_map = at::empty({total_B}, B_offsets.options()); + at::empty({total_B_}, output_offsets_feature_rank.options()); + Tensor b_t_map = at::empty({total_B_}, B_offsets.options()); - const auto grid_dim_x = div_round_up(max_B_feature_rank, kMaxThreads); + const auto grid_dim_x = div_round_up(max_B_feature_rank_, kMaxThreads); const dim3 grid_size(grid_dim_x, num_ranks, T); const auto& [max_grid_x, max_grid_y, max_grid_z] = get_max_grid_size(); TORCH_CHECK( @@ -181,7 +218,9 @@ generate_vbe_metadata( PTA_B(D_offsets, int32_t, 1, 32), D, nobag, - info_B_num_bits); + info_B_num_bits, + MAKE_PTA_WITH_NAME( + func_name, predefined_vbe_output_offsets, int64_t, 2, 32)); return {row_output_offsets, b_t_map}; } diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp index 453b097774..df09f5dbd0 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp @@ -119,7 +119,8 @@ generate_vbe_metadata_cpu( const bool /*nobag*/, const c10::SymInt /*max_B_feature_rank*/, const int64_t info_B_num_bits, - const c10::SymInt total_B) { + const c10::SymInt total_B, + const std::optional& vbe_output_offsets = std::nullopt) { TENSOR_ON_CPU(B_offsets); TENSORS_ON_SAME_DEVICE(B_offsets, B_offsets_rank_per_feature); TENSORS_ON_SAME_DEVICE(B_offsets, output_offsets_feature_rank); @@ -139,6 +140,11 @@ generate_vbe_metadata_cpu( Tensor row_output_offsets = at::empty({total_B_}, output_offsets_feature_rank.options()); TORCH_CHECK(B_offsets.dtype() == at::kInt, "B_offsets should be int32"); + + if (vbe_output_offsets.has_value()) { + TORCH_CHECK(vbe_output_offsets->numel() == total_B, "size mismatch"); + } + Tensor b_t_map = at::empty({total_B_}, B_offsets.options()); auto B_offsets_acc = B_offsets.accessor(); auto D_offsets_acc = D_offsets.accessor(); @@ -166,7 +172,8 @@ generate_vbe_metadata_cpu( } } } - return {row_output_offsets, b_t_map}; + auto row_output_offsets_ = vbe_output_offsets.value_or(row_output_offsets); + return {row_output_offsets_, b_t_map}; } std::tuple @@ -204,7 +211,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " bool nobag, " " SymInt max_B_feature_rank, " " int info_B_num_bits, " - " SymInt total_B" + " SymInt total_B, " + " Tensor? vbe_output_offsets=None" ") -> (Tensor, Tensor)"); DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu); DISPATCH_TO_CPU("get_infos_metadata", get_infos_metadata_cpu); diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp index 38077da03d..18b4fd7bc4 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp @@ -23,7 +23,8 @@ generate_vbe_metadata_meta( const bool /*nobag*/, const c10::SymInt /*max_B_feature_rank*/, const int64_t /*info_B_num_bits*/, - const c10::SymInt total_B) { + const c10::SymInt total_B, + const std::optional& /*vbe_output_offsets*/ = std::nullopt) { Tensor row_output_offsets = at::empty_symint({total_B}, output_offsets_feature_rank.options()); Tensor b_t_map = at::empty_symint({total_B}, B_offsets.options()); From da6dffff59c71051efd44afc4b1fb7acd5189db3 Mon Sep 17 00:00:00 2001 From: jichen Date: Wed, 5 Nov 2025 15:34:44 -0800 Subject: [PATCH 77/92] embedding forward optimization for MI350 (#5064) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2095 optimization on embedding forward for MI350: 1. apply vec4 on embedding vbe forward kernel instead of vec2 2. As there are 64 threads in rocm, optimize subwarp in embedding forward v2 kernel when embedding dim is from 32 to 64. Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5064 Reviewed By: q10 Differential Revision: D85701691 Pulled By: spcyppt fbshipit-source-id: 72f491414f50e53038a4b02f3d555967d34740a7 --- ...embedding_forward_split_kernel_template.cu | 21 ++----------------- ...edding_forward_split_kernel_v2_template.cu | 7 +++++++ .../embedding_forward_split_template.cu | 14 ++++++------- 3 files changed, 16 insertions(+), 26 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index a39d33e391..aada1cdad5 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -84,11 +84,7 @@ using namespace fbgemm_gpu; {#-/* Set the weights row accessor */#} - {%- if is_rocm %} - const auto weights_row = rocm::WeightRowAccessorVec2 - {%- else %} const auto weights_row = WeightRowAccessor - {%- endif %} < {{ 'cache_t' if from_cache else 'emb_t' }}, cache_t @@ -182,11 +178,7 @@ using namespace fbgemm_gpu; {%- endif %} {#-/* Set the weights row accessor */#} - {%- if is_rocm %} - const auto weights_row = rocm::WeightRowAccessorVec2 - {%- else %} const auto weights_row = WeightRowAccessor - {%- endif %} < {{ 'cache_t' if from_cache else 'emb_t' }}, cache_t @@ -319,7 +311,7 @@ using namespace fbgemm_gpu; {%- if is_rocm %} {%- if not nobag %} - rocm::Vec2T vals[kManualUnrollLength * kMaxVecsPerThread]; + Vec4T vals[kManualUnrollLength * kMaxVecsPerThread]; {%- endif %} // Iterate over kThreadGroupSize indices for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength) @@ -633,12 +625,7 @@ batch_index_select_dim0_codegen_forward_kernel( #endif // Elements are processed 4 at a time through fbgemm_gpu::Vec4 (CUDA float4, 16 bytes) - // for CUDA devices and 2 at a time for ROCm - {%- if is_rocm %} - constexpr int VEC_WIDTH = 2; - {%- else %} constexpr int VEC_WIDTH = 4; - {%- endif %} {%- if is_rocm %} // Unroll factor for ROCm devices constexpr int kManualUnrollLength = 4; @@ -743,12 +730,8 @@ batch_index_select_dim0_codegen_forward_kernel( const float inv_L = (mean_pooling && L != 0) ? static_cast(1.0) / L: static_cast(1.0); // Set up the accumulator buffer - {%- if is_rocm %} - rocm::Vec2T accumulators[kMaxVecsPerThread]; - {%- else %} Vec4T accumulators[kMaxVecsPerThread]; {%- endif %} - {%- endif %} {%- if dense %} {{ embedding_pool_or_store("NULL") }} @@ -930,7 +913,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- endmacro %} {%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %} - {%- set max_vecs_per_thread = 2 * kMaxVecsPerThread if is_rocm else kMaxVecsPerThread %} + {%- set max_vecs_per_thread = kMaxVecsPerThread %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu index 42f499c6dd..34ce2c6f13 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu @@ -975,6 +975,13 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel( else if (tail_warp_size <= 16) { INVOKE_PROCESS_ALL_INDICES(large_Ls, 16, 0x55) } +#if defined(USE_ROCM) + // not sure step mask value to set when group size is 32 + // while use_lxu_cache is false step mask makes no sense + else if (tail_warp_size <= 32 && !use_lxu_cache) { + INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf) + } +#endif else { INVOKE_PROCESS_ALL_INDICES(large_Ls, kWarpSize, 0xf) } diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index 83315d9a13..a3edb6b965 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -734,12 +734,7 @@ batch_index_select_dim0_codegen_forward_cuda( // kFixedMaxVecsPerThread instead of kMaxVecsPerThread. But // kMaxVecsPerThread and kFixedMaxVecsPerThread are the same // forward - {%- if is_rocm %} - // Account for Vec2 load for ROCm - constexpr auto kMaxVecsPerThread = 2 * kFixedMaxVecsPerThread; - {%- else %} constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread; - {%- endif %} const auto grid = min( div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize), @@ -813,9 +808,14 @@ batch_index_select_dim0_codegen_forward_cuda( // if (!is_experimental) } else { // Allocate num warps per table based on max_D + const int num_warps_per_table = B * div_round_up(max_D, kWarpSize * 4); - const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize; - + #ifdef USE_ROCM + const uint32_t num_warps_per_threadblock = kForwardMaxThreads / (kWarpSize * 2); + #else + const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize; + #endif + const auto kernel_func = (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel< emb_t, cache_t, output_t, index_t, true> From 64ba2d9d126f8fc5c2ce908bd73f3baa57d64a29 Mon Sep 17 00:00:00 2001 From: Kaustubh Vartak Date: Wed, 5 Nov 2025 18:38:39 -0800 Subject: [PATCH 78/92] Support larger lookup in permute (#5086) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5086 X-link: https://github.com/facebookresearch/FBGEMM/pull/2094 For lengths per shard exceeding 2^31, we avoid overflow resulting in undefined behavior. Reviewed By: spcyppt Differential Revision: D86209662 fbshipit-source-id: 6d51290f3436629571677091c42b76b6f98e5790 --- fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu index abb6d8abd4..cf7e23c17e 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu @@ -20,7 +20,7 @@ template < typename indices_t, typename weights_t> __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel( - int32_t len, + int64_t len, int32_t T, int32_t B, const indices_t* __restrict__ indices, From 924082f3067e0a819a068864cb7c4bec40823e9d Mon Sep 17 00:00:00 2001 From: Daohang Shi Date: Wed, 5 Nov 2025 20:20:28 -0800 Subject: [PATCH 79/92] Deprecate tl.async_task from fbgemm (#5094) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2102 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5094 see D86119952 Reviewed By: htyu Differential Revision: D86319606 fbshipit-source-id: bdf841f0936f1be53b7a07e66b6a64e9e2aaef12 --- .../gemm/triton_gemm/grouped_gemm.py | 307 ++++++++---------- 1 file changed, 135 insertions(+), 172 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py index 5d373ea266..9fe94f8d46 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py @@ -509,14 +509,13 @@ def _fbgemm_grouped_gemm_ws( num_tiles = num_m_tiles * NUM_N_TILES if USE_TMA_STORE: - with tl.async_task([0]): - c_desc_ptr = tl.make_tensor_descriptor( - c_ptr + M_start_offset * N, - shape=[m_size, N], - # pyre-ignore - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], - ) + c_desc_ptr = tl.make_tensor_descriptor( + c_ptr + M_start_offset * N, + shape=[m_size, N], + # pyre-ignore + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) # Move across tiles next_iterated_tiles = iterated_tiles + num_tiles @@ -534,72 +533,59 @@ def _fbgemm_grouped_gemm_ws( m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): - with tl.async_task([0]): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) if USE_TMA_STORE: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - # pyre-ignore - c_desc_ptr.store( - [m_offset, n_offset], - accumulator.to(c_ptr.dtype.element_ty), - ) + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + # pyre-ignore + c_desc_ptr.store( + [m_offset, n_offset], + accumulator.to(c_ptr.dtype.element_ty), + ) elif FUSE_SCATTER_ADD: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + offs_bn[None, :], - c, - mask=mask[:, None], - sem="relaxed", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c, + mask=mask[:, None], + sem="relaxed", + ) else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size, - cache_modifier=".cs", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size, + cache_modifier=".cs", + ) tidx += NUM_SMS iterated_tiles += num_tiles @@ -841,14 +827,13 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws( num_tiles = num_m_tiles * NUM_N_TILES if USE_TMA_STORE: - with tl.async_task([0]): - c_desc_ptr = tl.make_tensor_descriptor( - c_ptr + M_start_offset * N, - shape=[m_size, N], - # pyre-ignore - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], - ) + c_desc_ptr = tl.make_tensor_descriptor( + c_ptr + M_start_offset * N, + shape=[m_size, N], + # pyre-ignore + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) # Move across tiles next_iterated_tiles = iterated_tiles + num_tiles @@ -867,107 +852,85 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws( m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): - with tl.async_task([0]): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) if USE_TMA_LOAD_ON_SCALES: - with tl.async_task([0]): - b_scale = tl._experimental_descriptor_load( - b_scale_desc_ptr, - [n_offset], - [BLOCK_SIZE_N], - tl.float32, - ) - - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - cache_modifier=".ca", - ) - c = accumulator.to(tl.float32) * a_scale * b_scale[None, :] + b_scale = tl._experimental_descriptor_load( + b_scale_desc_ptr, + [n_offset], + [BLOCK_SIZE_N], + tl.float32, + ) + + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + a_scale = tl.load( + a_scale_ptr + M_start_offset + offs_am[:, None], + mask=offs_am[:, None] < m_size, + cache_modifier=".ca", + ) + c = accumulator.to(tl.float32) * a_scale * b_scale[None, :] else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - cache_modifier=".ca", - ) - b_scale = tl.load( - b_scale_ptr + N_start_offset + offs_bn[None, :], - cache_modifier=".ca", - ) - c = accumulator.to(tl.float32) * a_scale * b_scale + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + a_scale = tl.load( + a_scale_ptr + M_start_offset + offs_am[:, None], + mask=offs_am[:, None] < m_size, + cache_modifier=".ca", + ) + b_scale = tl.load( + b_scale_ptr + N_start_offset + offs_bn[None, :], + cache_modifier=".ca", + ) + c = accumulator.to(tl.float32) * a_scale * b_scale if USE_TMA_STORE: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - # pyre-ignore - c_desc_ptr.store( - [m_offset, n_offset], c.to(c_ptr.dtype.element_ty) - ) + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + # pyre-ignore + c_desc_ptr.store( + [m_offset, n_offset], c.to(c_ptr.dtype.element_ty) + ) elif FUSE_SCATTER_ADD: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + offs_bn[None, :], - c, - mask=mask[:, None], - sem="relaxed", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c, + mask=mask[:, None], + sem="relaxed", + ) else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size, - cache_modifier=".cs", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size, + cache_modifier=".cs", + ) tidx += NUM_SMS iterated_tiles += num_tiles From ef408b0c4e793db38e881cf3b902104e5c31e0fc Mon Sep 17 00:00:00 2001 From: Emma Lin Date: Thu, 6 Nov 2025 00:11:20 -0800 Subject: [PATCH 80/92] enable feature score auto collection in EBC (#5031) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5031 X-link: https://github.com/meta-pytorch/torchrec/pull/3475 X-link: https://github.com/facebookresearch/FBGEMM/pull/2044 Enable feature score auto collection for EBC in the similar way of EC. The configuration has no difference in embedding table config: virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( training_id_eviction_trigger_count=260_000_000, # 260M training_id_keep_count=160_000_000, # 160M enable_auto_feature_score_collection=True, feature_score_mapping={ "sparse_public_original_content_creator": 1.0, }, feature_score_default_value=0.5, ), Reviewed By: EddyLXJ Differential Revision: D85017179 fbshipit-source-id: 3d62f8adbe201d6e30c445aaed88710bbbcd6557 --- .../fbgemm_gpu/split_table_batched_embeddings_ops_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 01832dfbc1..27c388d716 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -241,6 +241,7 @@ class KVZCHParams(NamedTuple): backend_return_whole_row: bool = False eviction_policy: EvictionPolicy = EvictionPolicy() embedding_cache_mode: bool = False + feature_score_collection_enabled: bool = False def validate(self) -> None: assert len(self.bucket_offsets) == len(self.bucket_sizes), ( From 8bf19e4fe7bc70c21b9b68c91a5f650d843777fa Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 22 Sep 2025 16:09:05 +0000 Subject: [PATCH 81/92] workgroup tuning and loop unrolled --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index aada1cdad5..69ad8cf8ca 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 4; + constexpr int kManualUnrollLength = 8; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 90c029a71f79cd6109aff869e7894cd06ea39f21 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 13 Oct 2025 20:34:59 +0000 Subject: [PATCH 82/92] revert unroll and wg tuning --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 69ad8cf8ca..acbf4563f3 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 8; + constexpr int kManualUnrollLength = 4; {%- endif %} // Determine the linearized warp ID, and exit early if needed From ae1779133b583843a157748ca6a53ac6770f8204 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Mon, 3 Nov 2025 20:19:56 +0000 Subject: [PATCH 83/92] removed jinj is_rocm on total_L as USE_ROCM is already applied --- .../training/pt2/embedding_split_host_pt2_autograd_template.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 661f7b9b45..2b359ad06e 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1063,9 +1063,7 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. - {% if is_rocm %} int32_t total_L = indices.numel(); - {%- endif %} {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and From bcc4116280925334322aecf643edc1fa3abee460 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:26:50 +0000 Subject: [PATCH 84/92] Change mixed_D default value to false --- .../training/backward/embedding_backward_dense_host_cpu.cpp | 2 +- .../backward/embedding_backward_split_host_template.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..0bc3c5f254 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function( c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 2ea96a107e..3fe516891f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- else %} const c10::SymInt vbe_output_size = -1, {%- endif %} - const bool mixed_D = true + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} From f6249413f43625ea6d42b6ce477aa2b3059767a9 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:30:55 +0000 Subject: [PATCH 85/92] Make const work_group_size for CUDA --- .../embedding_backward_split_template.cu | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index f88b413bdb..f29e32024c 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1224,20 +1224,20 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; {% if is_rocm %} - int32_t total_L = indices.numel(); - int32_t num_cta_per_row_groups; - int32_t work_group_size; - if (total_L/total_B > 1){ - num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; - work_group_size = (kMaxThreads/4); - } - else{ - num_cta_per_row_groups = kMaxThreads / kWarpSize; - work_group_size = kMaxThreads; - } + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } {%- else %} - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; - int32_t work_group_size = kMaxThreads; + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t work_group_size = kMaxThreads; {%- endif %} {%- if enable_optimized_hip_mixed_D_kernel %} auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); From 14cdfdbd6a347339a3330a31c4544a1fe41c022a Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:33:04 +0000 Subject: [PATCH 86/92] Add jinja comments to grad_indice_weights kernel --- .../embedding_backward_split_indice_weights_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index c58ba89f78..57c6804e66 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -333,9 +333,9 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } } } - {%- endif %} + {%- endif %}{#-/* if not ssd and not dense and not use_vec_blocking and not vbe */#} for (; j < kWarpSize && l_start + j < L; ++j) { - {%- else %} // if is_rocm + {%- else %}{#-/* if is_rocm*/#} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); From 4973c86be828da7224ae2529915127ba90a6b833 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:48:15 +0000 Subject: [PATCH 87/92] Remove redundand comment --- .../training/pt2/embedding_split_host_pt2_autograd_template.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 2b359ad06e..a2304b3fb3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1061,8 +1061,6 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; int32_t max_segment_length_per_warp = 64; - // Workaround. Should not be upstreamed in any way. - // Redistribute all cta_per_row work to warp_per_row. int32_t total_L = indices.numel(); {%- if (not nobag) and (optimizer == "rowwise_adagrad") and From 68e45ff7594492a8bb74b221a72ded537175a576 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 11:27:11 +0000 Subject: [PATCH 88/92] Unify cuda and rocm loops --- .../embedding_backward_split_indice_weights_template.cu | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 57c6804e66..9ffaea3a67 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -213,9 +213,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} - {%- if is_rocm %} int32_t j = 0; - {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only if (placement != PlacementType::MANAGED_CACHING) { for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { @@ -333,11 +332,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } } } - {%- endif %}{#-/* if not ssd and not dense and not use_vec_blocking and not vbe */#} + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} for (; j < kWarpSize && l_start + j < L; ++j) { - {%- else %}{#-/* if is_rocm*/#} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { - {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); From c9aceb37e2e8995d8288b56a8bdb399828d68e69 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 22 Sep 2025 16:09:05 +0000 Subject: [PATCH 89/92] workgroup tuning and loop unrolled --- .../training/forward/embedding_forward_split_kernel_template.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index acbf4563f3..5cad567d26 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 4; + constexpr int kManualUnrollLength = 8; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 1f82f3b2447d4147679c1646cc769c2e664872f4 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 13 Oct 2025 20:34:59 +0000 Subject: [PATCH 90/92] revert unroll and wg tuning --- .../training/forward/embedding_forward_split_kernel_template.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 5cad567d26..acbf4563f3 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 8; + constexpr int kManualUnrollLength = 4; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 813265cd738e7fed204345b11894e11b07ed835a Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 13 Nov 2025 07:08:58 +0000 Subject: [PATCH 91/92] eliminate warning of process_block --- .../embedding_backward_split_device_kernel_template.cuh | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index 6e25c40f10..32d61bc1c8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -237,17 +237,13 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( // Process blocks of different sizes with loop unrolling if constexpr (sizeof(grad_t) <= 2) { - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) } - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) } @@ -266,6 +262,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( } } +#undef PROCESS_BLOCK {%- endif %} // clang-format on From c20f05b91bfeb56d156615378bb5ef5df8fc3650 Mon Sep 17 00:00:00 2001 From: Wulley Date: Thu, 13 Nov 2025 14:18:19 +0000 Subject: [PATCH 92/92] add rocm for macro --- ..._backward_split_device_kernel_template.cuh | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index 32d61bc1c8..bb15b24f15 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -14,6 +14,7 @@ using namespace fbgemm_gpu; +{%- if is_rocm %} // Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1) #define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i); #define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i); @@ -105,6 +106,7 @@ using namespace fbgemm_gpu; {%- endif %} } \ } +{%- endif %} {%- if gen_once %} {#- /* @@ -235,6 +237,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( {%- endif %} int32_t j = 0; + {%- if is_rocm %} // Process blocks of different sizes with loop unrolling if constexpr (sizeof(grad_t) <= 2) { PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ @@ -246,6 +249,50 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + +#undef PROCESS_BLOCK + + {%- else %} + for (; j < kThreadGroupSize && sl + j < sl_end; ++j) { + {%- if nobag %} + int32_t l_j = SHFL_SYNC(l, j); + {%- elif vbe %} + const auto grad_offset_j = SHFL_SYNC(grad_offset, j); + {%- else %} + int32_t b_j = SHFL_SYNC(b, j); + int32_t D_start_j = SHFL_SYNC(D_start, j); + {%- endif %} + + {%- if weighted %} + at::acc_type idx_weight_j = SHFL_SYNC(idx_weight, j); + {%- endif %} + + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) { + const int32_t d = {{ d }}; + Vec4TAcc grad_out_vec( + {%- if nobag and is_index_select %} + // grad_output is 1d + &grad_output[grad_offset + l_j * grad_stride + d] + {%- elif nobag %} + &grad_output[l_j][d] + {%- elif vbe %} + &grad_output[0][grad_offset_j + d] + {%- else %} + &grad_output[b_j][0] + D_start_j + d + {%- endif %} // if nobag + ); + + {%- if weighted %} + grad_sum[vec].fma_(grad_out_vec, idx_weight_j); + {%- else %} + grad_sum[vec].add_(grad_out_vec); + {%- endif %} + } + } + {%- endif %} } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} @@ -262,7 +309,6 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( } } -#undef PROCESS_BLOCK {%- endif %} - // clang-format on + // clang-format on \ No newline at end of file