@@ -884,8 +884,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
884
884
// result of HasAttr.
885
885
if (!enable_cache_runtime_context && HasAttr (kEnableCacheRuntimeContext ))
886
886
enable_cache_runtime_context = true ;
887
- if (!enable_cache_expected_kernel && HasAttr (kEnableCacheExpectedKernel ))
888
- enable_cache_expected_kernel = true ;
889
887
if (!all_kernels_must_compute_runtime_shape &&
890
888
HasAttr (kAllKernelsMustComputeRuntimeShape ))
891
889
all_kernels_must_compute_runtime_shape = true ;
@@ -894,9 +892,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
894
892
RunImpl (scope, place, &ctx);
895
893
} else {
896
894
const Scope* cur_scope = &scope;
897
- if (!runtime_ctx_ || pre_scope_ != cur_scope) {
898
- runtime_ctx_.reset (new RuntimeContext (Inputs (), Outputs (), scope));
899
- pre_scope_ = cur_scope;
895
+ if (runtime_ctx_.get () == nullptr || pre_scope_ != cur_scope) {
896
+ std::lock_guard<std::mutex> lock (cache_update_mutex_);
897
+ if (runtime_ctx_.get () == nullptr || pre_scope_ != cur_scope) {
898
+ runtime_ctx_.reset (new RuntimeContext (Inputs (), Outputs (), scope));
899
+ pre_scope_ = cur_scope;
900
+ }
900
901
}
901
902
RunImpl (scope, place, runtime_ctx_.get ());
902
903
}
@@ -908,7 +909,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
908
909
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
909
910
auto * dev_ctx = pool.Get (place);
910
911
911
- if (!enable_cache_expected_kernel || !kernel_type_ ) {
912
+ if (kernel_type_. get () == nullptr || kernel_func_. get () == nullptr ) {
912
913
ChooseKernel (*runtime_ctx, scope, place);
913
914
}
914
915
@@ -996,8 +997,11 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
996
997
KernelTypeToString (expected_kernel_key));
997
998
}
998
999
999
- kernel_type_.reset (new OpKernelType (expected_kernel_key));
1000
- kernel_func_.reset (new OpKernelFunc (kernel_iter->second ));
1000
+ std::lock_guard<std::mutex> lock (cache_update_mutex_);
1001
+ if (kernel_type_.get () == nullptr || kernel_func_.get () == nullptr ) {
1002
+ kernel_type_.reset (new OpKernelType (expected_kernel_key));
1003
+ kernel_func_.reset (new OpKernelFunc (kernel_iter->second ));
1004
+ }
1001
1005
}
1002
1006
1003
1007
void OperatorWithKernel::TransferInplaceVarsBack (
0 commit comments