
Commit c62902e

[ONEDNN] fix accuracy issue of fc when the input shapes are dynamic
1 parent: 3bcc91e

2 files changed (+34, −2 lines)

paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc

Lines changed: 9 additions & 2 deletions
@@ -284,7 +284,13 @@ class FCMKLDNNHandler
 
   std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
       const phi::DenseTensor* weights, const std::vector<float>& scale_data) {
-    const std::string weights_key = this->memory_key_ + "@weights";
+    const std::string weights_base_key = this->memory_key_ + "@weights";
+    std::string weights_key;
+    weights_key.reserve(128);
+    weights_key = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(
+        dev_ctx_,
+        phi::funcs::CreateKey(
+            dev_ctx_, weights_base_key, this->fwd_pd_->weights_desc()));
     auto memory_p = std::static_pointer_cast<dnnl::memory>(
         this->dev_ctx_.GetBlob(weights_key));

@@ -410,7 +416,8 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
         phi::funcs::CreateKey(dev_ctx,
                               ctx.InputName("Input"),
                               ctx.InputName("W"),
-                              phi::vectorize(x->dims())));
+                              phi::vectorize(x->dims()),
+                              phi::vectorize(weights->dims())));
 
     auto inner_product_cache =
         std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
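
In plain terms, the diff above makes both cache keys shape-aware: the reordered weights blob is now keyed by the weights memory descriptor expected by the current primitive (fwd_pd_->weights_desc()), and the kernel-level InnerProductCache key folds in the weights dims next to the input dims. A minimal, self-contained sketch of that keying idea follows (plain C++; MakeKey and the std::map cache are hypothetical stand-ins for phi::funcs::CreateKey and the device-context blob store, not the Paddle APIs themselves):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for phi::funcs::CreateKey: fold every shape that
// influences the cached primitive into one string key.
static std::string MakeKey(const std::string& op,
                           const std::vector<int64_t>& input_dims,
                           const std::vector<int64_t>& weights_dims) {
  std::string key = op;
  for (int64_t d : input_dims) key += "_" + std::to_string(d);
  key += "_w";
  for (int64_t d : weights_dims) key += "_" + std::to_string(d);
  return key;
}

int main() {
  // Hypothetical stand-in for the per-device blob cache.
  std::map<std::string, std::string> cache;

  // First run: input 2x32, weights 32x64 -> primitive created and cached.
  cache[MakeKey("fc", {2, 32}, {32, 64})] = "inner_product(2x32, 32x64)";

  // Later run with a different dynamic input shape: the key no longer
  // matches, so a new primitive (and freshly reordered weights) is built
  // instead of silently reusing the stale cached one.
  std::cout << cache.count(MakeKey("fc", {8, 32}, {32, 64})) << "\n";  // 0
  return 0;
}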

paddle/phi/backends/onednn/onednn_helper.h

Lines changed: 25 additions & 0 deletions
@@ -154,6 +154,12 @@ inline void AppendKey(std::string* key, const T& num) {
   key->append(std::to_string(num));
 }
 
+template <>
+inline void AppendKey(std::string* key,
+                      const dnnl::memory::format_kind& format) {
+  key->append(std::to_string(static_cast<int>(format)));
+}
+
 template <>
 inline void AppendKey(std::string* key,
                       const dnnl::memory::format_tag& format) {

@@ -171,6 +177,25 @@ inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
   key->append(std::to_string(static_cast<int>(algorithm)));
 }
 
+template <>
+inline void AppendKey(std::string* key, const dnnl::memory::dims& dims) {
+  for (size_t i = 0; i < dims.size(); i++) {
+    AppendKey(key, static_cast<int64_t>(dims[i]));
+  }
+}
+
+template <>
+inline void AppendKey(std::string* key, const dnnl::memory::desc& md) {
+  AppendKey(key, md.get_dims());
+  AppendKey(key, md.get_data_type());
+  AppendKey(key, md.get_format_kind());
+  AppendKey(key, md.get_inner_blks());
+  AppendKey(key, md.get_inner_idxs());
+  AppendKey(key, md.get_inner_nblks());
+  AppendKey(key, md.get_padded_dims());
+  AppendKey(key, md.get_strides());
+}
+
 template <>
 inline void AppendKey(std::string* key,
                       const dnnl::normalization_flags& flags) {
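
The new AppendKey specialization for dnnl::memory::desc is what lets fwd_pd_->weights_desc() participate in the weights cache key in fc_mkldnn_op.cc: dims, data type, format kind, blocking, padded dims and strides all end up in the key, so two descriptors that differ only in layout never collide. A rough standalone illustration of the same serialization idea (assuming oneDNN v3.x headers, built with -ldnnl; DescToKey and AppendDims are illustrative helpers, not part of Paddle or oneDNN):

#include <iostream>
#include <string>
#include "dnnl.hpp"

// Append a dims vector to the key as "d0xd1x...".
static void AppendDims(std::string* key, const dnnl::memory::dims& dims) {
  for (auto d : dims) key->append(std::to_string(d)).append("x");
}

// Serialize every layout-determining attribute of a memory descriptor.
static std::string DescToKey(const dnnl::memory::desc& md) {
  std::string key;
  AppendDims(&key, md.get_dims());
  key.append(std::to_string(static_cast<int>(md.get_data_type())));
  key.append(std::to_string(static_cast<int>(md.get_format_kind())));
  AppendDims(&key, md.get_padded_dims());
  AppendDims(&key, md.get_strides());
  AppendDims(&key, md.get_inner_blks());
  AppendDims(&key, md.get_inner_idxs());
  key.append(std::to_string(md.get_inner_nblks()));
  return key;
}

int main() {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;
  dnnl::memory::desc plain({32, 64}, dt::f32, tag::ab);
  dnnl::memory::desc transposed({32, 64}, dt::f32, tag::ba);
  // Same dims and data type, but different strides -> different keys.
  std::cout << (DescToKey(plain) != DescToKey(transposed)) << "\n";  // 1
  return 0;
}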
