Skip to content

Commit 237d91d

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Fix heap misusage for top-k, top-n.
The original code relied on an implementation detail and ignored the precondition that [first, last) be in heap order. The std::push_heap call was assumed to act solely as a front/back-swap and sift-down operation. This now fails in an upcoming version of clang. The standard-adhering operation is ``` std::pop_heap(heap.begin(), heap.end(), comp); heap.back() = v; std::push_heap(heap.begin(), heap.end(), comp); ``` Also adjusted container bounds to no longer include an unused element in the back, since it is no longer required and complicates the above replacement. PiperOrigin-RevId: 436894332
1 parent 3a5e0fc commit 237d91d

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

research/carls/base/top_n.h

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,16 @@ class TopN {
7171
// UNORDERED and a peek_bottom() function call is invoked.
7272
//
7373
// o HEAP_SORTED: in this state, the array is kept as a heap and
74-
// there are exactly (limit_+1) elements in the array. This
74+
// there are exactly limit_ elements in the array. This
7575
// state is reached when at least (limit_+1) elements are
7676
// pushed in.
7777
//
7878
// The state transition graph is at follows:
7979
//
80-
// peek_bottom() (limit_+1) elements
80+
// peek_bottom() (limit_+1) elements pushed
8181
// UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED
8282
// | ^
83-
// | (limit_+1) elements |
83+
// | (limit_+1) elements pushed |
8484
// +-----------------------------------------------------------+
8585

8686
enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED };
@@ -94,14 +94,18 @@ class TopN {
9494

9595
// Number of elements currently held by this TopN object. This
9696
// will be no greater than 'limit' passed to the constructor.
97-
size_t size() const { return std::min(elements_.size(), limit_); }
97+
size_t size() const { return elements_.size(); }
9898

9999
bool empty() const { return size() == 0; }
100100

101101
// If you know how many elements you will push at the time you create the
102102
// TopN object, you can call reserve to preallocate the memory that TopN
103103
// will need to process all 'n' pushes. Calling this method is optional.
104-
void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); }
104+
void reserve(size_t n) {
105+
// We may need limit_+1 for the case where we transition from an unsorted
106+
// set of limit_ elements to a heap.
107+
elements_.reserve(std::min(n, limit_ + 1));
108+
}
105109

106110
// Push 'v'. If the maximum number of elements was exceeded, drop the
107111
// lowest element and return it in 'dropped' (if given). If the maximum is not
@@ -170,7 +174,7 @@ class TopN {
170174
// with no guarantees about the order of iteration. These iterators are
171175
// invalidated by mutation of the data structure.
172176
UnsortedIterator unsorted_begin() const { return elements_.begin(); }
173-
UnsortedIterator unsorted_end() const { return elements_.begin() + size(); }
177+
UnsortedIterator unsorted_end() const { return elements_.end(); }
174178

175179
// Accessor for comparator template argument.
176180
Cmp *comparator() { return &cmp_; }
@@ -184,13 +188,10 @@ class TopN {
184188
void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11)
185189

186190
// elements_ can be in one of two states:
187-
// elements_.size() <= limit_: elements_ is an unsorted vector of elements
188-
// pushed so far.
189-
// elements_.size() > limit_: The last element of elements_ is unused;
190-
// the other elements of elements_ are an stl heap whose size is exactly
191-
// limit_. In this case elements_.size() is exactly one greater than
192-
// limit_, but don't use "elements_.size() == limit_ + 1" to check for
193-
// that because you'll get a false positive if limit_ == size_t(-1).
191+
// elements_.size() <= limit_ && state_ != HEAP_SORTED:
192+
// elements_ is an unsorted vector of elements pushed so far.
193+
// elements_.size() == limit_ && state_ == HEAP_SORTED:
194+
// elements_ is an stl heap.
194195
std::vector<T> elements_;
195196
size_t limit_; // Maximum number of elements to find
196197
Cmp cmp_; // Greater-than comparison function
@@ -208,6 +209,8 @@ void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11)
208209
return;
209210
}
210211
if (state_ != HEAP_SORTED) {
212+
// We may temporarily extend one beyond limit_ elements here. This is
213+
// necessary for finding and removing the smallest element.
211214
elements_.push_back(std::forward<U>(v)); // NOLINT(build/c++11)
212215
if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) {
213216
// Easy case: we just pushed the new element back
@@ -223,25 +226,32 @@ void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11)
223226
if (elements_.size() == limit_ + 1) {
224227
// Transition from unsorted vector to a heap.
225228
std::make_heap(elements_.begin(), elements_.end(), cmp_);
226-
if (dropped) *dropped = std::move(elements_.front());
227229
std::pop_heap(elements_.begin(), elements_.end(), cmp_);
230+
if (dropped) *dropped = std::move(elements_.back());
231+
elements_.pop_back(); // Restore to size limit_.
228232
state_ = HEAP_SORTED;
233+
} else if (state_ == UNORDERED ||
234+
cmp_(elements_.back(), elements_.front())) {
235+
// Easy case: we just push the new element back
236+
} else {
237+
// To maintain the BOTTOM_KNOWN state, we need to make sure that
238+
// the element at position 0 is always the smallest. So we put
239+
// the new element at position 0 and push the original bottom
240+
// element in the back.
241+
// Warning: this code is subtle.
242+
using std::swap;
243+
swap(elements_.front(), elements_.back());
229244
}
245+
230246
} else {
231247
// Only insert the new element if it is greater than the least element.
232248
if (cmp_(v, elements_.front())) {
233-
// Store new element in the last slot of elements_. Remember from the
234-
// comments on elements_ that this last slot is unused, so we don't
235-
// overwrite anything useful.
236-
elements_.back() = std::forward<U>(v); // NOLINT(build/c++11)
237-
238-
// stp::pop_heap() swaps elements_.front() and elements_.back() and
239-
// rearranges elements from [elements_.begin(), elements_.end() - 1) such
240-
// that they are a heap according to cmp_. Net effect: remove
241-
// elements_.front() from the heap, and add the new element instead. For
242-
// more info, see https://en.cppreference.com/w/cpp/algorithm/pop_heap.
249+
// Remove the top (smallest) element of the min heap, then push the new
250+
// value in.
243251
std::pop_heap(elements_.begin(), elements_.end(), cmp_);
244252
if (dropped) *dropped = std::move(elements_.back());
253+
elements_.back() = std::forward<U>(v);
254+
std::push_heap(elements_.begin(), elements_.end(), cmp_);
245255
} else {
246256
if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11)
247257
}
@@ -277,7 +287,6 @@ std::unique_ptr<std::vector<T>> TopN<T, Cmp>::Extract() {
277287
if (state_ != HEAP_SORTED) {
278288
std::sort(out->begin(), out->end(), cmp_);
279289
} else {
280-
out->pop_back();
281290
std::sort_heap(out->begin(), out->end(), cmp_);
282291
}
283292
return out;
@@ -287,10 +296,6 @@ template <class T, class Cmp>
287296
std::unique_ptr<std::vector<T>> TopN<T, Cmp>::ExtractUnsorted() {
288297
std::unique_ptr<std::vector<T>> out(new std::vector<T>);
289298
out->swap(elements_);
290-
if (state_ == HEAP_SORTED) {
291-
// Remove the limit_+1'th element.
292-
out->pop_back();
293-
}
294299
return out;
295300
}
296301

@@ -308,7 +313,6 @@ void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const {
308313
if (state_ != HEAP_SORTED) {
309314
std::sort(output->begin(), output->end(), cmp_);
310315
} else {
311-
output->pop_back();
312316
std::sort_heap(output->begin(), output->end(), cmp_);
313317
}
314318
}
@@ -324,10 +328,6 @@ template <class T, class Cmp>
324328
void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const {
325329
CHECK(output != nullptr);
326330
*output = elements_;
327-
if (state_ == HEAP_SORTED) {
328-
// Remove the limit_+1'th element.
329-
output->pop_back();
330-
}
331331
}
332332

333333
template <class T, class Cmp>

0 commit comments

Comments
 (0)