1010#include < gflags/gflags.h>
1111#include < torch/custom_class.h>
1212#include " deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu:fbgemm_gpu
13+ #include " deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_impl.h"
1314
1415DEFINE_int64 (
1516 dram_kv_embedding_num_shards,
@@ -26,117 +27,88 @@ DramKVEmbeddingInferenceWrapper::DramKVEmbeddingInferenceWrapper(
2627 : num_shards_(num_shards),
2728 uniform_init_lower_ (uniform_init_lower),
2829 uniform_init_upper_(uniform_init_upper),
29- disable_random_init_(disable_random_init) {
30- LOG (INFO)
31- << " DramKVEmbeddingInferenceWrapper created with disable_random_init = "
32- << disable_random_init_ << " , num_shards = " << num_shards_;
30+ disable_random_init_(disable_random_init) {}
31+
32+ void DramKVEmbeddingInferenceWrapper::ensure_impl_initialized () {
33+ if (impl_ == nullptr ) {
34+ LOG (INFO)
35+ << " Lazy-initializing DramKVEmbeddingInferenceImpl with num_shards = "
36+ << num_shards_ << " , uniform_init_lower = " << uniform_init_lower_
37+ << " , uniform_init_upper = " << uniform_init_upper_
38+ << " , disable_random_init = " << disable_random_init_;
39+ impl_ = std::make_shared<DramKVEmbeddingInferenceImpl>(
40+ num_shards_,
41+ uniform_init_lower_,
42+ uniform_init_upper_,
43+ disable_random_init_);
44+ }
3345}
3446
3547void DramKVEmbeddingInferenceWrapper::init (
3648 const std::vector<SerializedSepcType>& specs,
3749 const int64_t row_alignment,
3850 const int64_t scale_bias_size_in_bytes,
3951 const std::optional<at::Tensor>& hash_size_cumsum) {
40- LOG (INFO) << " DramKVEmbeddingInferenceWrapper::init() starts" ;
41- int64_t max_D = 0 ;
42- for (auto i = 0 ; i < specs.size (); ++i) {
43- max_D = std::max (max_D, std::get<1 >(specs[i]));
44- }
45- max_row_bytes_ = nbit::padded_row_size_in_bytes (
46- static_cast <int32_t >(max_D),
47- static_cast <fbgemm_gpu::SparseType>(std::get<2 >(specs[0 ])),
48- static_cast <int32_t >(row_alignment),
49- static_cast <int32_t >(scale_bias_size_in_bytes));
50- LOG (INFO) << " Initialize dram_kv with max_D: " << max_D
51- << " , sparse_type: " << std::get<2 >(specs[0 ])
52- << " , row_alignment: " << row_alignment
53- << " , scale_bias_size_in_bytes: " << scale_bias_size_in_bytes
54- << " , max_row_bytes_: " << max_row_bytes_;
55- if (dram_kv_ != nullptr ) {
56- return ;
57- }
58- dram_kv_ = std::make_shared<kv_mem::DramKVInferenceEmbedding<uint8_t >>(
59- max_row_bytes_,
60- uniform_init_lower_,
61- uniform_init_upper_,
62- c10::make_intrusive<kv_mem::FeatureEvictConfig>(
63- 3 /* EvictTriggerMode.MANUAL */ ,
64- 4 /* EvictTriggerStrategy::BY_TIMESTAMP_THRESHOLD */ ,
65- 0 /* trigger_step_intervals */ ,
66- 0 /* mem_util_threshold_in_GB */ ,
67- std::nullopt /* ttls_in_mins */ ,
68- std::nullopt /* counter_thresholds */ ,
69- std::nullopt /* counter_decay_rates */ ,
70- std::nullopt /* feature_score_counter_decay_rates */ ,
71- std::nullopt /* training_id_eviction_trigger_count */ ,
72- std::nullopt /* training_id_keep_count */ ,
73- std::nullopt /* l2_weight_thresholds */ ,
74- std::nullopt /* embedding_dims */ ,
75- std::nullopt /* threshold_calculation_bucket_stride */ ,
76- std::nullopt /* threshold_calculation_bucket_num */ ,
77- 0 /* interval for insufficient eviction s*/ ,
78- 0 /* interval for sufficient eviction s*/ ,
79- 0 /* interval_for_feature_statistics_decay_s_*/ ),
80- num_shards_ /* num_shards */ ,
81- num_shards_ /* num_threads */ ,
82- 8 /* row_storage_bitwidth */ ,
83- false /* enable_async_update */ ,
84- std::nullopt /* table_dims */ ,
85- hash_size_cumsum,
86- disable_random_init_);
52+ ensure_impl_initialized ();
53+ impl_->init (specs, row_alignment, scale_bias_size_in_bytes, hash_size_cumsum);
8754}
8855
8956std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t >>
9057DramKVEmbeddingInferenceWrapper::get_dram_kv () {
91- return dram_kv_;
58+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
59+ return impl_->get_dram_kv ();
9260}
9361
9462void DramKVEmbeddingInferenceWrapper::set_dram_kv (
9563 std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t >> dram_kv) {
96- dram_kv_ = std::move (dram_kv);
64+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
65+ impl_->set_dram_kv (std::move (dram_kv));
66+ }
67+
68+ void DramKVEmbeddingInferenceWrapper::set_impl (
69+ std::shared_ptr<KVEmbeddingInferenceInterface> impl) {
70+ impl_ = std::move (impl);
71+ }
72+
73+ std::shared_ptr<KVEmbeddingInferenceInterface>
74+ DramKVEmbeddingInferenceWrapper::get_impl () {
75+ return impl_;
76+ }
77+
78+ void DramKVEmbeddingInferenceWrapper::transfer_underlying_storage_from (
79+ const c10::intrusive_ptr<DramKVEmbeddingInferenceWrapper>& other) {
80+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
81+ impl_->transfer_underlying_storage_from (other->impl_ );
9782}
9883
9984void DramKVEmbeddingInferenceWrapper::set_embeddings (
10085 const at::Tensor& indices,
10186 const at::Tensor& weights,
10287 std::optional<int64_t > inplace_update_ts_opt) {
103- const auto count = at::tensor ({indices.numel ()}, at::ScalarType::Long);
104- std::optional<uint32_t > inplacee_update_ts = std::nullopt ;
105- if (inplace_update_ts_opt.has_value ()) {
106- inplacee_update_ts =
107- static_cast <std::uint32_t >(inplace_update_ts_opt.value ());
108- }
109- folly::coro::blockingWait (dram_kv_->inference_set_kv_db_async (
110- indices, weights, count, inplacee_update_ts));
88+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
89+ impl_->set_embeddings (indices, weights, inplace_update_ts_opt);
11190}
11291
11392at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings (
11493 const at::Tensor& indices) {
115- const auto count = at::tensor ({indices.numel ()}, at::ScalarType::Long);
116- auto weights = at::empty (
117- {
118- indices.numel (),
119- max_row_bytes_,
120- },
121- at::kByte );
122- folly::coro::blockingWait (dram_kv_->get_kv_db_async (indices, weights, count));
123- return weights;
94+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
95+ return impl_->get_embeddings (indices);
12496}
12597
12698void DramKVEmbeddingInferenceWrapper::log_inplace_update_stats () {
127- dram_kv_->log_inplace_update_stats ();
99+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
100+ impl_->log_inplace_update_stats ();
128101}
129102
130103void DramKVEmbeddingInferenceWrapper::trigger_evict (
131104 int64_t inplace_update_ts_64b) {
132- uint32_t inplace_update_ts_32b =
133- static_cast <std::uint32_t >(inplace_update_ts_64b);
134- dram_kv_->trigger_feature_evict (inplace_update_ts_32b);
135- dram_kv_->resume_ongoing_eviction ();
105+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
106+ impl_->trigger_evict (inplace_update_ts_64b);
136107}
137108
138109void DramKVEmbeddingInferenceWrapper::wait_evict_completion () {
139- dram_kv_->wait_until_eviction_done ();
110+ TORCH_CHECK (impl_ != nullptr , " impl_ is not initialized. Call init first" );
111+ impl_->wait_evict_completion ();
140112}
141113
142114c10::List<at::Tensor> DramKVEmbeddingInferenceWrapper::serialize () const {
0 commit comments