Skip to content

Commit 3020ad0

Browse files
authored
[IR] Fix ir StorageManager bug (#54503)
* fix ir storage bug * refine code * refine code * fix bug * refine code * refine code
1 parent 8dfcf03 commit 3020ad0

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

paddle/ir/core/enforce.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
// there is no equivalent intrinsics in msvc.
2626
#define UNLIKELY(condition) (condition)
2727
#endif
28-
29-
inline bool is_error(bool stat) { return !stat; }
28+
template <typename T>
29+
inline bool is_error(const T& stat) {
30+
return !stat;
31+
}
3032

3133
namespace ir {
3234
class IrNotMetException : public std::exception {
@@ -55,7 +57,7 @@ class IrNotMetException : public std::exception {
5557

5658
#define IR_ENFORCE(COND, ...) \
5759
do { \
58-
auto __cond__ = (COND); \
60+
bool __cond__(COND); \
5961
if (UNLIKELY(is_error(__cond__))) { \
6062
try { \
6163
throw ir::IrNotMetException( \

paddle/ir/core/storage_manager.cc

+9-6
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,20 @@ namespace ir {
2525
struct ParametricStorageManager {
2626
using StorageBase = StorageManager::StorageBase;
2727

28-
ParametricStorageManager() {}
28+
explicit ParametricStorageManager(std::function<void(StorageBase *)> destroy)
29+
: destroy_(destroy) {}
2930

3031
~ParametricStorageManager() {
3132
for (const auto &instance : parametric_instances_) {
32-
delete instance.second;
33+
destroy_(instance.second);
3334
}
3435
parametric_instances_.clear();
3536
}
3637

3738
// Get the storage of parametric type, if not in the cache, create and
3839
// insert the cache.
3940
StorageBase *GetOrCreate(std::size_t hash_value,
40-
std::function<bool(const StorageBase *)> equal_func,
41+
std::function<bool(StorageBase *)> equal_func,
4142
std::function<StorageBase *()> constructor) {
4243
if (parametric_instances_.count(hash_value) != 0) {
4344
auto pr = parametric_instances_.equal_range(hash_value);
@@ -62,6 +63,7 @@ struct ParametricStorageManager {
6263
// In order to prevent hash conflicts, the unordered_multimap data structure
6364
// is used for storage.
6465
std::unordered_multimap<size_t, StorageBase *> parametric_instances_;
66+
std::function<void(StorageBase *)> destroy_;
6567
};
6668

6769
StorageManager::StorageManager() {}
@@ -95,12 +97,13 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
9597
return parameterless_instance;
9698
}
9799

98-
void StorageManager::RegisterParametricStorageImpl(TypeId type_id) {
100+
void StorageManager::RegisterParametricStorageImpl(
101+
TypeId type_id, std::function<void(StorageBase *)> destroy) {
99102
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
100103
VLOG(4) << "Register a parametric storage of: [TypeId_hash="
101104
<< std::hash<ir::TypeId>()(type_id) << "].";
102-
parametric_instance_.emplace(type_id,
103-
std::make_unique<ParametricStorageManager>());
105+
parametric_instance_.emplace(
106+
type_id, std::make_unique<ParametricStorageManager>(destroy));
104107
}
105108

106109
void StorageManager::RegisterParameterlessStorageImpl(

paddle/ir/core/storage_manager.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ class StorageManager {
100100
///
101101
template <typename Storage>
102102
void RegisterParametricStorage(TypeId type_id) {
103-
return RegisterParametricStorageImpl(type_id);
103+
return RegisterParametricStorageImpl(type_id, [](StorageBase *storage) {
104+
delete static_cast<Storage *>(storage);
105+
});
104106
}
105107

106108
///
@@ -129,7 +131,8 @@ class StorageManager {
129131

130132
StorageBase *GetParameterlessStorageImpl(TypeId type_id);
131133

132-
void RegisterParametricStorageImpl(TypeId type_id);
134+
void RegisterParametricStorageImpl(
135+
TypeId type_id, std::function<void(StorageBase *)> destroy);
133136

134137
void RegisterParameterlessStorageImpl(
135138
TypeId type_id, std::function<StorageBase *()> constructor);

0 commit comments

Comments
 (0)