Skip to content

Commit 37379bf

Browse files
committed
Fix build error on rocm4.5/rocm5 on ubuntu18.04
* declear new rccl api in paddle/fluid * Fix build error:free(): invalid pointer by pick the fix: skarupke/flat_hash_map#26 * Filter to check codestyle for flash_hash_map.h Signed-off-by: jiajuku <jiajuku12@163.com>
1 parent e5d0861 commit 37379bf

File tree

4 files changed

+37
-12
lines changed

4 files changed

+37
-12
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ repos:
102102
exclude: |
103103
(?x)^(
104104
paddle/cinn/.+|
105-
test/cpp/cinn/.+
105+
test/cpp/cinn/.+|
106+
paddle/utils/flat_hash_map.h+
106107
)$
107108
# For CMake files
108109
- repo: local

paddle/fluid/platform/dynload/rccl.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,18 @@ RCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
2626
RCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP)
2727
#endif
2828

29+
#if NCCL_VERSION_CODE >= 2304
30+
RCCL_RAND_ROUTINE_EACH_AFTER_2304(DEFINE_WRAP)
31+
#endif
32+
2933
#if NCCL_VERSION_CODE >= 2703
3034
RCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP)
3135
#endif
3236

37+
#if NCCL_VERSION_CODE >= 21100
38+
RCCL_RAND_ROUTINE_EACH_AFTER_21100(DEFINE_WRAP)
39+
#endif
40+
3341
} // namespace dynload
3442
} // namespace platform
3543
} // namespace paddle

paddle/fluid/platform/dynload/rccl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,25 @@ RCCL_RAND_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
5151
RCCL_RAND_ROUTINE_EACH_AFTER_2212(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
5252
#endif
5353

54+
#if NCCL_VERSION_CODE >= 2304
55+
#define RCCL_RAND_ROUTINE_EACH_AFTER_2304(__macro) __macro(ncclGetVersion);
56+
RCCL_RAND_ROUTINE_EACH_AFTER_2304(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
57+
#endif
58+
5459
#if NCCL_VERSION_CODE >= 2703
5560
#define RCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \
5661
__macro(ncclSend); \
5762
__macro(ncclRecv);
5863
RCCL_RAND_ROUTINE_EACH_AFTER_2703(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
5964
#endif
6065

66+
#if NCCL_VERSION_CODE >= 21100
67+
#define RCCL_RAND_ROUTINE_EACH_AFTER_21100(__macro) \
68+
__macro(ncclRedOpCreatePreMulSum); \
69+
__macro(ncclRedOpDestroy);
70+
RCCL_RAND_ROUTINE_EACH_AFTER_21100(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP)
71+
#endif
72+
6173
} // namespace dynload
6274
} // namespace platform
6375
} // namespace paddle

paddle/utils/flat_hash_map.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,6 @@ struct sherwood_v3_entry {
126126
sherwood_v3_entry(int8_t distance_from_desired)
127127
: distance_from_desired(distance_from_desired) {}
128128
~sherwood_v3_entry() {}
129-
static sherwood_v3_entry *empty_default_table() {
130-
static sherwood_v3_entry result[min_lookups] = {
131-
{}, {}, {}, {special_end_value}};
132-
return result;
133-
}
134129

135130
bool has_value() const { return distance_from_desired >= 0; }
136131
bool is_empty() const { return distance_from_desired < 0; }
@@ -664,13 +659,24 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal {
664659
bool empty() const { return num_elements == 0; }
665660

666661
private:
667-
EntryPointer entries = Entry::empty_default_table();
662+
EntryPointer entries = empty_default_table();
668663
size_t num_slots_minus_one = 0;
669664
typename HashPolicySelector<ArgumentHash>::type hash_policy;
670665
int8_t max_lookups = detailv3::min_lookups - 1;
671666
float _max_load_factor = 0.5f;
672667
size_t num_elements = 0;
673668

669+
EntryPointer empty_default_table() {
670+
EntryPointer result =
671+
AllocatorTraits::allocate(*this, detailv3::min_lookups);
672+
EntryPointer special_end_item =
673+
result + static_cast<ptrdiff_t>(detailv3::min_lookups - 1);
674+
for (EntryPointer it = result; it != special_end_item; ++it)
675+
it->distance_from_desired = -1;
676+
special_end_item->distance_from_desired = Entry::special_end_value;
677+
return result;
678+
}
679+
674680
static int8_t compute_max_lookups(size_t num_buckets) {
675681
int8_t desired = detailv3::log2(num_buckets);
676682
return (std::max)(detailv3::min_lookups, desired);
@@ -743,15 +749,13 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal {
743749
void deallocate_data(EntryPointer begin,
744750
size_t num_slots_minus_one,
745751
int8_t max_lookups) {
746-
if (begin != Entry::empty_default_table()) {
747-
AllocatorTraits::deallocate(
748-
*this, begin, num_slots_minus_one + max_lookups + 1);
749-
}
752+
AllocatorTraits::deallocate(
753+
*this, begin, num_slots_minus_one + max_lookups + 1);
750754
}
751755

752756
void reset_to_empty_state() {
753757
deallocate_data(entries, num_slots_minus_one, max_lookups);
754-
entries = Entry::empty_default_table();
758+
entries = empty_default_table();
755759
num_slots_minus_one = 0;
756760
hash_policy.reset();
757761
max_lookups = detailv3::min_lookups - 1;

0 commit comments

Comments
 (0)