Skip to content

Commit 9acf74f

Browse files
tomlintb and facebook-github-bot
authored and committed
Abstract out sharable interface of Dram KV wrapper
Summary: - We need an elegant way to dynamically swap registration of KV implementation. - A solution is to abstract the current KV wrappers and make them inherit the same interface. - The original DramKVWrapper is the bridge of Python and C++ code. By delegating the actual functionality through an impl_ member, which has the type of the aforementioned interface, we can keep the registration as is but swap out the detailed impl_ at runtime. V2: - Move set_dram_kv and get_dram_kv into IKVEmbeddingInference - Add lazy initialization in DramImpl::init - Add TORCH_CHECK to throw when impl is not set before accessing - Rename IKVEmbeddingInference to KVEmbeddingInferenceInterface Differential Revision: D85162982
1 parent f849dcd commit 9acf74f

File tree

5 files changed

+335
-82
lines changed

5 files changed

+335
-82
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_impl.h"
10+
#include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu:fbgemm_gpu
11+
12+
namespace fbgemm_gpu {
13+
14+
// Stores configuration only; the underlying DRAM KV store is created lazily
// in init().
DramKVEmbeddingInferenceImpl::DramKVEmbeddingInferenceImpl(
    int64_t num_shards,
    double uniform_init_lower,
    double uniform_init_upper,
    bool disable_random_init)
    : num_shards_(num_shards),
      uniform_init_lower_(uniform_init_lower),
      uniform_init_upper_(uniform_init_upper),
      disable_random_init_(disable_random_init) {
  LOG(INFO)
      << "DramKVEmbeddingInferenceImpl created with disable_random_init = "
      << disable_random_init_ << ", num_shards = " << num_shards_;
}
27+
28+
void DramKVEmbeddingInferenceImpl::init(
29+
const std::vector<SerializedSepcType>& specs,
30+
const int64_t row_alignment,
31+
const int64_t scale_bias_size_in_bytes,
32+
const std::optional<at::Tensor>& hash_size_cumsum) {
33+
LOG(INFO) << "DramKVEmbeddingInferenceWrapperImpl::init() starts";
34+
int64_t max_D = 0;
35+
for (const auto& spec : specs) {
36+
max_D = std::max(max_D, std::get<1>(spec));
37+
}
38+
max_row_bytes_ = nbit::padded_row_size_in_bytes(
39+
static_cast<int32_t>(max_D),
40+
static_cast<fbgemm_gpu::SparseType>(std::get<2>(specs[0])),
41+
static_cast<int32_t>(row_alignment),
42+
static_cast<int32_t>(scale_bias_size_in_bytes));
43+
LOG(INFO) << "Initialize dram_kv with max_D: " << max_D
44+
<< ", sparse_type: " << std::get<2>(specs[0])
45+
<< ", row_alignment: " << row_alignment
46+
<< ", scale_bias_size_in_bytes: " << scale_bias_size_in_bytes
47+
<< ", max_row_bytes_: " << max_row_bytes_;
48+
if (dram_kv_ != nullptr) {
49+
return;
50+
}
51+
dram_kv_ = std::make_shared<kv_mem::DramKVInferenceEmbedding<uint8_t>>(
52+
max_row_bytes_,
53+
uniform_init_lower_,
54+
uniform_init_upper_,
55+
c10::make_intrusive<kv_mem::FeatureEvictConfig>(
56+
3 /* EvictTriggerMode.MANUAL */,
57+
4 /* EvictTriggerStrategy::BY_TIMESTAMP_THRESHOLD */,
58+
0 /* trigger_step_intervals */,
59+
0 /* mem_util_threshold_in_GB */,
60+
std::nullopt /* ttls_in_mins */,
61+
std::nullopt /* counter_thresholds */,
62+
std::nullopt /* counter_decay_rates */,
63+
std::nullopt /* feature_score_counter_decay_rates */,
64+
std::nullopt /* training_id_eviction_trigger_count */,
65+
std::nullopt /* training_id_keep_count */,
66+
std::nullopt /* l2_weight_thresholds */,
67+
std::nullopt /* embedding_dims */,
68+
std::nullopt /* threshold_calculation_bucket_stride */,
69+
std::nullopt /* threshold_calculation_bucket_num */,
70+
0 /* interval for insufficient eviction s*/,
71+
0 /* interval for sufficient eviction s*/,
72+
0 /* interval_for_feature_statistics_decay_s_*/),
73+
num_shards_ /* num_shards */,
74+
num_shards_ /* num_threads */,
75+
8 /* row_storage_bitwidth */,
76+
false /* enable_async_update */,
77+
std::nullopt /* table_dims */,
78+
hash_size_cumsum,
79+
disable_random_init_);
80+
}
81+
82+
// Returns the underlying KV store (null until init() or set_dram_kv() runs).
std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>>
DramKVEmbeddingInferenceImpl::get_dram_kv() {
  return dram_kv_;
}
86+
87+
// Replaces the underlying KV store, e.g. when adopting storage from another
// instance via transfer_underlying_storage_from().
void DramKVEmbeddingInferenceImpl::set_dram_kv(
    std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>> dram_kv) {
  dram_kv_ = std::move(dram_kv);
}
91+
92+
void DramKVEmbeddingInferenceImpl::set_embeddings(
93+
const at::Tensor& indices,
94+
const at::Tensor& weights,
95+
std::optional<int64_t> inplace_update_ts_opt) {
96+
const auto count = at::tensor({indices.numel()}, at::ScalarType::Long);
97+
std::optional<uint32_t> inplacee_update_ts = std::nullopt;
98+
if (inplace_update_ts_opt.has_value()) {
99+
inplacee_update_ts =
100+
static_cast<std::uint32_t>(inplace_update_ts_opt.value());
101+
}
102+
folly::coro::blockingWait(dram_kv_->inference_set_kv_db_async(
103+
indices, weights, count, inplacee_update_ts));
104+
}
105+
106+
// Blocking read: returns a byte tensor of shape
// [indices.numel(), max_row_bytes_] filled from the KV store.
at::Tensor DramKVEmbeddingInferenceImpl::get_embeddings(
    const at::Tensor& indices) {
  // Guard against use before init(); otherwise dram_kv_-> below is a null
  // deref, and max_row_bytes_ would still be 0.
  TORCH_CHECK(
      dram_kv_ != nullptr, "dram_kv_ is not initialized. Call init first");
  const auto count = at::tensor({indices.numel()}, at::ScalarType::Long);
  auto weights = at::empty(
      {
          indices.numel(),
          max_row_bytes_,
      },
      at::kByte);
  folly::coro::blockingWait(dram_kv_->get_kv_db_async(indices, weights, count));
  return weights;
}
118+
119+
// Forwards in-place-update stats logging to the underlying store.
void DramKVEmbeddingInferenceImpl::log_inplace_update_stats() {
  dram_kv_->log_inplace_update_stats();
}
122+
123+
void DramKVEmbeddingInferenceImpl::trigger_evict(
124+
int64_t inplace_update_ts_64b) {
125+
uint32_t inplace_update_ts_32b =
126+
static_cast<std::uint32_t>(inplace_update_ts_64b);
127+
dram_kv_->trigger_feature_evict(inplace_update_ts_32b);
128+
dram_kv_->resume_ongoing_eviction();
129+
}
130+
131+
// Blocks until any in-flight eviction pass on the store finishes.
void DramKVEmbeddingInferenceImpl::wait_evict_completion() {
  dram_kv_->wait_until_eviction_done();
}
134+
135+
// Adopts the underlying KV store of `other`, which must itself be a
// DramKVEmbeddingInferenceImpl (verified via dynamic cast; throws otherwise).
void DramKVEmbeddingInferenceImpl::transfer_underlying_storage_from(
    std::shared_ptr<KVEmbeddingInferenceInterface> other) {
  LOG(INFO)
      << "DramKVEmbeddingInferenceImpl::transfer_underlying_storage_from() starts";
  auto other_dram =
      std::dynamic_pointer_cast<DramKVEmbeddingInferenceImpl>(other);
  TORCH_CHECK(
      other_dram != nullptr,
      "Cannot transfer underlying storage: source is not a DramKVEmbeddingInferenceImpl");
  this->set_dram_kv(other_dram->get_dram_kv());
}
146+
147+
} // namespace fbgemm_gpu
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h"
12+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/kv_embedding_inference_interface.h"
13+
14+
namespace fbgemm_gpu {
15+
16+
// DRAM-backed implementation of KVEmbeddingInferenceInterface. Bridges the
// abstract KV-embedding inference interface to
// kv_mem::DramKVInferenceEmbedding so implementations can be swapped at
// runtime behind the wrapper's registration.
class DramKVEmbeddingInferenceImpl : public KVEmbeddingInferenceInterface {
 public:
  // Serialized per-table spec tuple; spelling ("Sepc") inherited from the
  // interface.
  using SerializedSepcType = KVEmbeddingInferenceInterface::SerializedSepcType;

  // Stores configuration only; the KV store itself is created lazily in
  // init().
  DramKVEmbeddingInferenceImpl(
      int64_t num_shards,
      double uniform_init_lower,
      double uniform_init_upper,
      bool disable_random_init);

  // Creates the underlying store sized for the widest table in `specs`.
  // A second call keeps the existing store.
  void init(
      const std::vector<SerializedSepcType>& specs,
      const int64_t row_alignment,
      const int64_t scale_bias_size_in_bytes,
      const std::optional<at::Tensor>& hash_size_cumsum) override;

  // Blocking write of weight rows for the given indices.
  void set_embeddings(
      const at::Tensor& indices,
      const at::Tensor& weights,
      std::optional<int64_t> inplace_update_ts_opt = std::nullopt) override;

  // Blocking read; returns bytes of shape [indices.numel(), max_row_bytes_].
  at::Tensor get_embeddings(const at::Tensor& indices) override;

  void log_inplace_update_stats() override;

  // Manual timestamp-threshold eviction controls.
  void trigger_evict(int64_t inplace_update_ts_64b) override;

  void wait_evict_completion() override;

  // Adopts the KV store of `other` (must be a DramKVEmbeddingInferenceImpl).
  void transfer_underlying_storage_from(
      std::shared_ptr<KVEmbeddingInferenceInterface> other) override;

  // Direct access to the underlying store, e.g. for storage transfer.
  std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>> get_dram_kv()
      override;

  void set_dram_kv(std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>>
                       dram_kv) override;

 private:
  int64_t num_shards_ = 32;
  double uniform_init_lower_ = 0.0;
  double uniform_init_upper_ = 0.0;
  bool disable_random_init_ = false;

  // Lazily created in init(); null until then.
  std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>> dram_kv_;
  // Padded bytes per row for the widest configured table; set by init().
  int64_t max_row_bytes_ = 0;
};
63+
64+
} // namespace fbgemm_gpu

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp

Lines changed: 48 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <gflags/gflags.h>
1111
#include <torch/custom_class.h>
1212
#include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu:fbgemm_gpu
13+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_impl.h"
1314

1415
DEFINE_int64(
1516
dram_kv_embedding_num_shards,
@@ -26,117 +27,88 @@ DramKVEmbeddingInferenceWrapper::DramKVEmbeddingInferenceWrapper(
2627
: num_shards_(num_shards),
2728
uniform_init_lower_(uniform_init_lower),
2829
uniform_init_upper_(uniform_init_upper),
29-
disable_random_init_(disable_random_init) {
30-
LOG(INFO)
31-
<< "DramKVEmbeddingInferenceWrapper created with disable_random_init = "
32-
<< disable_random_init_ << ", num_shards = " << num_shards_;
30+
disable_random_init_(disable_random_init) {}
31+
32+
void DramKVEmbeddingInferenceWrapper::ensure_impl_initialized() {
33+
if (impl_ == nullptr) {
34+
LOG(INFO)
35+
<< "Lazy-initializing DramKVEmbeddingInferenceImpl with num_shards = "
36+
<< num_shards_ << ", uniform_init_lower = " << uniform_init_lower_
37+
<< ", uniform_init_upper = " << uniform_init_upper_
38+
<< ", disable_random_init = " << disable_random_init_;
39+
impl_ = std::make_shared<DramKVEmbeddingInferenceImpl>(
40+
num_shards_,
41+
uniform_init_lower_,
42+
uniform_init_upper_,
43+
disable_random_init_);
44+
}
3345
}
3446

3547
// Lazily constructs the default implementation if none is installed, then
// delegates initialization to it.
void DramKVEmbeddingInferenceWrapper::init(
    const std::vector<SerializedSepcType>& specs,
    const int64_t row_alignment,
    const int64_t scale_bias_size_in_bytes,
    const std::optional<at::Tensor>& hash_size_cumsum) {
  ensure_impl_initialized();
  impl_->init(specs, row_alignment, scale_bias_size_in_bytes, hash_size_cumsum);
}
8855

8956
// Returns the underlying DRAM KV store of the active implementation.
// Requires an implementation to be installed (via init() or set_impl()).
std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>>
DramKVEmbeddingInferenceWrapper::get_dram_kv() {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  return impl_->get_dram_kv();
}
9361

9462
// Hands a KV store to the active implementation. Requires an implementation
// to be installed (via init() or set_impl()).
void DramKVEmbeddingInferenceWrapper::set_dram_kv(
    std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>> dram_kv) {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  impl_->set_dram_kv(std::move(dram_kv));
}
67+
68+
// Swaps in an alternative KV implementation at runtime; the wrapper keeps its
// existing registration while delegating all calls to `impl`.
void DramKVEmbeddingInferenceWrapper::set_impl(
    std::shared_ptr<KVEmbeddingInferenceInterface> impl) {
  impl_ = std::move(impl);
}
72+
73+
// Returns the currently installed implementation (null until init() or
// set_impl() runs).
std::shared_ptr<KVEmbeddingInferenceInterface>
DramKVEmbeddingInferenceWrapper::get_impl() {
  return impl_;
}
77+
78+
// Adopts the underlying storage of another wrapper's implementation.
// Requires this wrapper's implementation to be installed; the source-side
// type check happens inside the implementation.
void DramKVEmbeddingInferenceWrapper::transfer_underlying_storage_from(
    const c10::intrusive_ptr<DramKVEmbeddingInferenceWrapper>& other) {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  impl_->transfer_underlying_storage_from(other->impl_);
}
9883

9984
// Delegates a blocking embedding write to the active implementation.
void DramKVEmbeddingInferenceWrapper::set_embeddings(
    const at::Tensor& indices,
    const at::Tensor& weights,
    std::optional<int64_t> inplace_update_ts_opt) {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  impl_->set_embeddings(indices, weights, inplace_update_ts_opt);
}
11291

11392
// Delegates a blocking embedding read to the active implementation.
at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings(
    const at::Tensor& indices) {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  return impl_->get_embeddings(indices);
}
12597

12698
// Delegates stats logging to the active implementation.
void DramKVEmbeddingInferenceWrapper::log_inplace_update_stats() {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  impl_->log_inplace_update_stats();
}
129102

130103
// Delegates a manual eviction trigger to the active implementation.
void DramKVEmbeddingInferenceWrapper::trigger_evict(
    int64_t inplace_update_ts_64b) {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  impl_->trigger_evict(inplace_update_ts_64b);
}
137108

138109
// Delegates eviction-completion waiting to the active implementation.
void DramKVEmbeddingInferenceWrapper::wait_evict_completion() {
  TORCH_CHECK(impl_ != nullptr, "impl_ is not initialized. Call init first");
  impl_->wait_evict_completion();
}
141113

142114
c10::List<at::Tensor> DramKVEmbeddingInferenceWrapper::serialize() const {

0 commit comments

Comments
 (0)