Skip to content

Commit 728bbaa

Browse files
authored
add cache_update_mutex_ for operator test=develop (#17124)
* add cache_update_mutex_ for operator
1 parent 15453d0 commit 728bbaa

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -884,8 +884,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
884884
// result of HasAttr.
885885
if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext))
886886
enable_cache_runtime_context = true;
887-
if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel))
888-
enable_cache_expected_kernel = true;
889887
if (!all_kernels_must_compute_runtime_shape &&
890888
HasAttr(kAllKernelsMustComputeRuntimeShape))
891889
all_kernels_must_compute_runtime_shape = true;
@@ -894,9 +892,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
894892
RunImpl(scope, place, &ctx);
895893
} else {
896894
const Scope* cur_scope = &scope;
897-
if (!runtime_ctx_ || pre_scope_ != cur_scope) {
898-
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
899-
pre_scope_ = cur_scope;
895+
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
896+
std::lock_guard<std::mutex> lock(cache_update_mutex_);
897+
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
898+
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
899+
pre_scope_ = cur_scope;
900+
}
900901
}
901902
RunImpl(scope, place, runtime_ctx_.get());
902903
}
@@ -908,7 +909,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
908909
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
909910
auto* dev_ctx = pool.Get(place);
910911

911-
if (!enable_cache_expected_kernel || !kernel_type_) {
912+
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
912913
ChooseKernel(*runtime_ctx, scope, place);
913914
}
914915

@@ -996,8 +997,11 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
996997
KernelTypeToString(expected_kernel_key));
997998
}
998999

999-
kernel_type_.reset(new OpKernelType(expected_kernel_key));
1000-
kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
1000+
std::lock_guard<std::mutex> lock(cache_update_mutex_);
1001+
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
1002+
kernel_type_.reset(new OpKernelType(expected_kernel_key));
1003+
kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
1004+
}
10011005
}
10021006

10031007
void OperatorWithKernel::TransferInplaceVarsBack(

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <algorithm>
1818
#include <atomic>
1919
#include <memory>
20+
#include <mutex> // NOLINT
2021
#include <string>
2122
#include <tuple>
2223
#include <unordered_map>
@@ -508,8 +509,8 @@ class OperatorWithKernel : public OperatorBase {
508509
mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
509510
mutable const Scope* pre_scope_ = nullptr;
510511
mutable bool enable_cache_runtime_context = false;
511-
mutable bool enable_cache_expected_kernel = false;
512512
mutable bool all_kernels_must_compute_runtime_shape = false;
513+
mutable std::mutex cache_update_mutex_;
513514
};
514515

515516
extern bool OpSupportGPU(const std::string& op_type);

0 commit comments

Comments
 (0)