Skip to content

Commit 1aaee78

Browse files
spcyppt and facebook-github-bot
authored and committed
Add meta functions for cache ops (#4118)
Summary: Pull Request resolved: #4118. X-link: facebookresearch/FBGEMM#1200. Add meta functions for the cache ops `linearize_cache_indices` and `lxu_cache_lookup`. This fixes the opcheck test failures for D73644969. Reviewed By: sryap. Differential Revision: D74520715. fbshipit-source-id: 77f7f9e9df218a56e2d3fcdd11189dbc21af321b
1 parent faddb3a commit 1aaee78

File tree

9 files changed

+47
-516
lines changed

9 files changed

+47
-516
lines changed

fbgemm_gpu/src/split_embeddings_cache/common.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ Tensor linearize_cache_indices_cpu(
3333
const int64_t max_B,
3434
const int64_t indices_base_offset);
3535

/// Meta-dispatch counterpart of `linearize_cache_indices`; its parameter list
/// mirrors `linearize_cache_indices_cpu` declared just above.
///
/// The definition in linearize_cache_indices.cpp reads only `indices` and
/// returns `at::empty_like(indices)` with the dtype overridden to `at::kLong`;
/// the other arguments are ignored there.
36+
Tensor linearize_cache_indices_meta(
37+
const Tensor& cache_hash_size_cumsum,
38+
const Tensor& indices,
39+
const Tensor& offsets,
40+
const std::optional<Tensor>& B_offsets,
41+
const int64_t max_B,
42+
const int64_t indices_base_offset);
43+
3644
Tensor linearize_cache_indices_from_row_idx_cpu(
3745
Tensor cache_hash_size_cumsum,
3846
Tensor update_table_indices,
@@ -96,6 +104,15 @@ Tensor lxu_cache_lookup_cpu(
96104
std::optional<Tensor> num_uniq_cache_indices,
97105
std::optional<Tensor> lxu_cache_locations_output);
98106

/// Meta-dispatch counterpart of `lxu_cache_lookup`; its trailing parameters
/// mirror `lxu_cache_lookup_cpu` declared just above.
///
/// The definition in lxu_cache.cpp returns `lxu_cache_locations_output` when
/// it holds a value, and otherwise `empty_like(linear_cache_indices)` with
/// dtype `at::kInt`; the remaining arguments are ignored there.
107+
Tensor lxu_cache_lookup_meta(
108+
Tensor linear_cache_indices,
109+
Tensor lxu_cache_state,
110+
int64_t invalid_index,
111+
bool gather_cache_stats,
112+
std::optional<Tensor> uvm_cache_stats,
113+
std::optional<Tensor> num_uniq_cache_indices,
114+
std::optional<Tensor> lxu_cache_locations_output);
115+
99116
Tensor direct_mapped_lxu_cache_lookup_cpu(
100117
Tensor linear_cache_indices,
101118
Tensor lxu_cache_state,

fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,14 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cpu(
2929
return at::empty_like(update_row_indices);
3030
}
3131

/// Meta-dispatch stand-in for `linearize_cache_indices` (registered via
/// DISPATCH_TO_META in split_embeddings_cache_ops.cpp).
///
/// No linearization is performed: only `indices` is read, and the result is
/// `at::empty_like(indices)` with the dtype overridden to `at::kLong`. Every
/// other parameter is left unnamed and ignored.
32+
DLL_PUBLIC Tensor linearize_cache_indices_meta(
33+
const Tensor& /*cache_hash_size_cumsum*/,
34+
const Tensor& indices,
35+
const Tensor& /*offsets*/,
36+
const std::optional<Tensor>& /*B_offsets*/,
37+
const int64_t /*max_B*/,
38+
const int64_t /*indices_base_offset*/) {
// Same shape/options as `indices`, but always int64 regardless of its dtype.
39+
return at::empty_like(indices, indices.options().dtype(at::kLong));
40+
}
41+
3242
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,16 @@ DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cpu(
3434
linear_cache_indices, linear_cache_indices.options().dtype(at::kInt));
3535
}
3636

/// Meta-dispatch stand-in for `lxu_cache_lookup` (registered via
/// DISPATCH_TO_META in split_embeddings_cache_ops.cpp).
///
/// No cache lookup is performed. Returns `lxu_cache_locations_output` when it
/// holds a value; otherwise returns `empty_like(linear_cache_indices)` with
/// the dtype overridden to `at::kInt`. The remaining parameters are left
/// unnamed and ignored.
37+
DLL_PUBLIC Tensor lxu_cache_lookup_meta(
38+
Tensor linear_cache_indices,
39+
Tensor /* lxu_cache_state */,
40+
int64_t /* invalid_index */,
41+
bool /* gather_cache_stats */,
42+
std::optional<Tensor> /* uvm_cache_stats */,
43+
std::optional<Tensor> /* num_uniq_cache_indices */,
44+
std::optional<Tensor> lxu_cache_locations_output) {
// NOTE: value_or evaluates its argument eagerly, so the fallback tensor is
// constructed even when the caller supplied an output tensor.
45+
return lxu_cache_locations_output.value_or(empty_like(
46+
linear_cache_indices, linear_cache_indices.options().dtype(at::kInt)));
47+
}
48+
3749
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
6868
DISPATCH_TO_CPU("lxu_cache_lookup", lxu_cache_lookup_cpu);
6969
DISPATCH_TO_CPU(
7070
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cpu);
71+
72+
DISPATCH_TO_META("linearize_cache_indices", linearize_cache_indices_meta);
73+
DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta);
7174
}
7275

7376
} // namespace

fbgemm_gpu/test/tbe/cache/failures_dict_fast.json

Lines changed: 1 addition & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -72,68 +72,7 @@
7272
"status": "xfail"
7373
}
7474
},
75-
"fbgemm::linearize_cache_indices": {
76-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
77-
"comment": "",
78-
"status": "skip"
79-
},
80-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
81-
"comment": "",
82-
"status": "skip"
83-
},
84-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
85-
"comment": "",
86-
"status": "skip"
87-
},
88-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
89-
"comment": "",
90-
"status": "skip"
91-
},
92-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
93-
"comment": "",
94-
"status": "skip"
95-
},
96-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
97-
"comment": "",
98-
"status": "skip"
99-
},
100-
"BackwardSGDTest.test_faketensor__test_backward_sgd": {
101-
"comment": "",
102-
"status": "xfail"
103-
},
104-
"BackwardSGDTest.test_faketensor__test_backward_sgd_really_long_segments": {
105-
"comment": "",
106-
"status": "xfail"
107-
},
108-
"CacheTest.test_faketensor__test_cache_miss_counter": {
109-
"comment": "",
110-
"status": "xfail"
111-
},
112-
"LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": {
113-
"comment": "",
114-
"status": "xfail"
115-
},
116-
"LinearizeCacheIndicesTest.test_faketensor__test_linearize_cache_indices": {
117-
"comment": "",
118-
"status": "xfail"
119-
},
120-
"NBitCacheTest.test_faketensor__test_nbit_cache_miss_counter": {
121-
"comment": "",
122-
"status": "xfail"
123-
},
124-
"NBitForwardTest.test_faketensor__test_nbit_forward_uvm_cache": {
125-
"comment": "",
126-
"status": "xfail"
127-
},
128-
"NBitSplitEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
129-
"comment": "",
130-
"status": "xfail"
131-
},
132-
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": {
133-
"comment": "",
134-
"status": "xfail"
135-
}
136-
},
75+
"fbgemm::linearize_cache_indices": {},
13776
"fbgemm::linearize_cache_indices_from_row_idx": {
13877
"LinearizeCacheIndicesTest.test_faketensor__test_linearize_cache_indices_from_row_idx": {
13978
"comment": "",
@@ -186,69 +125,9 @@
186125
"fbgemm::lxu_cache_flush": {},
187126
"fbgemm::lxu_cache_locking_counter_decrement": {},
188127
"fbgemm::lxu_cache_lookup": {
189-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
190-
"comment": "",
191-
"status": "xfail"
192-
},
193-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
194-
"comment": "",
195-
"status": "xfail"
196-
},
197-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
198-
"comment": "",
199-
"status": "xfail"
200-
},
201-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
202-
"comment": "",
203-
"status": "xfail"
204-
},
205-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
206-
"comment": "",
207-
"status": "xfail"
208-
},
209-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
210-
"comment": "",
211-
"status": "xfail"
212-
},
213-
"BackwardSGDTest.test_faketensor__test_backward_sgd": {
214-
"comment": "",
215-
"status": "xfail"
216-
},
217-
"BackwardSGDTest.test_faketensor__test_backward_sgd_really_long_segments": {
218-
"comment": "",
219-
"status": "xfail"
220-
},
221-
"CacheTest.test_faketensor__test_cache_miss_counter": {
222-
"comment": "",
223-
"status": "xfail"
224-
},
225128
"CacheTest.test_schema__test_cache_miss_counter": {
226129
"comment": "",
227130
"status": "xfail"
228-
},
229-
"LXUCacheTest.test_faketensor__test_lxu_cache_lookup": {
230-
"comment": "",
231-
"status": "xfail"
232-
},
233-
"LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": {
234-
"comment": "",
235-
"status": "xfail"
236-
},
237-
"NBitCacheTest.test_faketensor__test_nbit_cache_miss_counter": {
238-
"comment": "",
239-
"status": "xfail"
240-
},
241-
"NBitForwardTest.test_faketensor__test_nbit_forward_uvm_cache": {
242-
"comment": "",
243-
"status": "xfail"
244-
},
245-
"NBitSplitEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
246-
"comment": "",
247-
"status": "xfail"
248-
},
249-
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": {
250-
"comment": "",
251-
"status": "xfail"
252131
}
253132
},
254133
"fbgemm::new_managed_tensor": {},

fbgemm_gpu/test/tbe/inference/failures_dict_fast.json

Lines changed: 2 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -87,68 +87,7 @@
8787
"status": "xfail"
8888
}
8989
},
90-
"fbgemm::linearize_cache_indices": {
91-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
92-
"comment": "",
93-
"status": "skip"
94-
},
95-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
96-
"comment": "",
97-
"status": "skip"
98-
},
99-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
100-
"comment": "",
101-
"status": "skip"
102-
},
103-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
104-
"comment": "",
105-
"status": "skip"
106-
},
107-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
108-
"comment": "",
109-
"status": "skip"
110-
},
111-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
112-
"comment": "",
113-
"status": "skip"
114-
},
115-
"BackwardSGDTest.test_faketensor__test_backward_sgd": {
116-
"comment": "",
117-
"status": "xfail"
118-
},
119-
"BackwardSGDTest.test_faketensor__test_backward_sgd_really_long_segments": {
120-
"comment": "",
121-
"status": "xfail"
122-
},
123-
"CacheTest.test_faketensor__test_cache_miss_counter": {
124-
"comment": "",
125-
"status": "xfail"
126-
},
127-
"LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": {
128-
"comment": "",
129-
"status": "xfail"
130-
},
131-
"LinearizeCacheIndicesTest.test_faketensor__test_linearize_cache_indices": {
132-
"comment": "",
133-
"status": "xfail"
134-
},
135-
"NBitCacheTest.test_faketensor__test_nbit_cache_miss_counter": {
136-
"comment": "",
137-
"status": "xfail"
138-
},
139-
"NBitForwardTest.test_faketensor__test_nbit_forward_uvm_cache": {
140-
"comment": "",
141-
"status": "xfail"
142-
},
143-
"NBitSplitEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
144-
"comment": "",
145-
"status": "xfail"
146-
},
147-
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": {
148-
"comment": "",
149-
"status": "xfail"
150-
}
151-
},
90+
"fbgemm::linearize_cache_indices": {},
15291
"fbgemm::linearize_cache_indices_from_row_idx": {
15392
"LinearizeCacheIndicesTest.test_faketensor__test_linearize_cache_indices_from_row_idx": {
15493
"comment": "",
@@ -200,68 +139,7 @@
200139
"fbgemm::lru_cache_populate_byte": {},
201140
"fbgemm::lxu_cache_flush": {},
202141
"fbgemm::lxu_cache_locking_counter_decrement": {},
203-
"fbgemm::lxu_cache_lookup": {
204-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
205-
"comment": "",
206-
"status": "xfail"
207-
},
208-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
209-
"comment": "",
210-
"status": "xfail"
211-
},
212-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
213-
"comment": "",
214-
"status": "xfail"
215-
},
216-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
217-
"comment": "",
218-
"status": "xfail"
219-
},
220-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
221-
"comment": "",
222-
"status": "xfail"
223-
},
224-
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
225-
"comment": "",
226-
"status": "xfail"
227-
},
228-
"BackwardSGDTest.test_faketensor__test_backward_sgd": {
229-
"comment": "",
230-
"status": "xfail"
231-
},
232-
"BackwardSGDTest.test_faketensor__test_backward_sgd_really_long_segments": {
233-
"comment": "",
234-
"status": "xfail"
235-
},
236-
"CacheTest.test_faketensor__test_cache_miss_counter": {
237-
"comment": "",
238-
"status": "xfail"
239-
},
240-
"LXUCacheTest.test_faketensor__test_lxu_cache_lookup": {
241-
"comment": "",
242-
"status": "xfail"
243-
},
244-
"LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": {
245-
"comment": "",
246-
"status": "xfail"
247-
},
248-
"NBitCacheTest.test_faketensor__test_nbit_cache_miss_counter": {
249-
"comment": "",
250-
"status": "xfail"
251-
},
252-
"NBitForwardTest.test_faketensor__test_nbit_forward_uvm_cache": {
253-
"comment": "",
254-
"status": "xfail"
255-
},
256-
"NBitSplitEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
257-
"comment": "",
258-
"status": "xfail"
259-
},
260-
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": {
261-
"comment": "",
262-
"status": "xfail"
263-
}
264-
},
142+
"fbgemm::lxu_cache_lookup": {},
265143
"fbgemm::new_managed_tensor": {},
266144
"fbgemm::new_unified_tensor": {
267145
"NBitForwardTest.test_faketensor__test_nbit_forward_gpu_no_cache": {

0 commit comments

Comments
 (0)