From 87246f3a4c37090317c442396256d4eb278a8a3b Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Wed, 8 Feb 2023 08:31:33 +0000 Subject: [PATCH 1/2] fix NLP-Bert model performance loss --- paddle/fluid/framework/op_kernel_type.h | 31 ++++++++++++++++++--- paddle/fluid/imperative/prepared_operator.h | 4 +-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h index eb969a94d8256..1b285a9719e96 100644 --- a/paddle/fluid/framework/op_kernel_type.h +++ b/paddle/fluid/framework/op_kernel_type.h @@ -16,10 +16,12 @@ limitations under the License. */ #include +#include "glog/logging.h" #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/kernel_factory.h" @@ -131,10 +133,31 @@ inline bool backends_are_same_class(const phi::Backend& l, return phi::TransToPhiPlace(l) == phi::TransToPhiPlace(r); } -inline bool NeedTransform(const phi::KernelKey& l, const phi::KernelKey& r) { - return !backends_are_same_class(l.backend(), r.backend()) || - NeedTransformDataType(l, r) || - NeedTransformLayout(l.layout(), r.layout()); +inline bool NeedTransformBackend(const phi::Backend& type_for_var_backend, + const phi::Backend& expected_backend, + const phi::DenseTensor& tensor) { + // NOTE(jiahongyu): KernelKey does not hold place information, so we need to + // explicitly transform CUDAPinnedPlace->CUDAPlace + if (type_for_var_backend != phi::Backend::ALL_BACKEND && + paddle::platform::is_cuda_pinned_place(tensor.place()) && + expected_backend != phi::Backend::CPU) { + VLOG(3) << "Transform Variable " << tensor.name() << " from " + << tensor.place() << " to " + << phi::TransToPhiPlace(expected_backend); + return true; + } + return !backends_are_same_class(type_for_var_backend, expected_backend); +} + +inline bool NeedTransform(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_key, + const phi::DenseTensor& tensor) { + return NeedTransformBackend(kernel_type_for_var.backend(), + expected_kernel_key.backend(), + tensor) || + NeedTransformDataType(kernel_type_for_var, expected_kernel_key) || + NeedTransformLayout(kernel_type_for_var.layout(), + expected_kernel_key.layout()); } } // namespace framework diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index fb36a03e01890..00e059572d204 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -87,8 +87,8 @@ std::shared_ptr> PrepareData( if (tensor && tensor->IsInitialized() && (tensor->memory_size() != 0)) { auto kernel_type_for_var = op.GetKernelTypeForVar( name_pair.first, *tensor, expected_kernel_key); - if (!framework::NeedTransform(kernel_type_for_var, - expected_kernel_key)) { + if (!framework::NeedTransform( + kernel_type_for_var, expected_kernel_key, *tensor)) { continue; } else { VLOG(3) << "Transform Variable " << GetNameFromVar(template_var) From 26e3755febdfad495cab1ad350f9bb03ee5bb4d5 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 9 Feb 2023 02:39:23 +0000 Subject: [PATCH 2/2] fix windows compile error --- paddle/fluid/framework/op_kernel_type.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h index 1b285a9719e96..43b383aecb047 100644 --- a/paddle/fluid/framework/op_kernel_type.h +++ b/paddle/fluid/framework/op_kernel_type.h @@ -16,13 +16,13 @@ limitations under the License. */ #include -#include "glog/logging.h" #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_factory.h" namespace paddle {