Commit 28bafd3

Optimize insertion in dynamic_bitset

1 parent 7865c47 · commit 28bafd3
1 file changed: include/sparrow/buffer/dynamic_bitset/dynamic_bitset_base.hpp (+252 −47)
```diff
@@ -17,9 +17,11 @@
 
 #include <algorithm>
 #include <bit>
+#include <iterator>
 #include <stdexcept>
 #include <string>
 #include <type_traits>
+#include <vector>
 
 #include "sparrow/buffer/dynamic_bitset/bitset_iterator.hpp"
 #include "sparrow/buffer/dynamic_bitset/bitset_reference.hpp"
```
```diff
@@ -161,6 +163,13 @@ namespace sparrow
         [[nodiscard]] constexpr size_type count_extra_bits() const noexcept;
         constexpr void update_null_count(bool old_value, bool new_value);
 
+        // Efficient bit manipulation helpers for insert operations
+        constexpr void shift_bits_right(size_type start_pos, size_type bit_count, size_type shift_amount);
+        constexpr void fill_bits_range(size_type start_pos, size_type bit_count, value_type value);
+        template <std::random_access_iterator InputIt>
+        constexpr iterator
+        insert_range_random_access(const_iterator pos, InputIt first, size_type count, size_type index);
+
         storage_type m_buffer;
         size_type m_size;
         size_type m_null_count;
```
```diff
@@ -592,30 +601,31 @@ namespace sparrow
         SPARROW_ASSERT_TRUE(cbegin() <= pos);
         SPARROW_ASSERT_TRUE(pos <= cend());
         const auto index = static_cast<size_type>(std::distance(cbegin(), pos));
+
         if (data() == nullptr && value)
         {
             m_size += count;
+            return iterator(this, index);
         }
-        else
-        {
-            const size_type old_size = size();
-            const size_type new_size = old_size + count;
 
-            // TODO: The current implementation is not efficient. It can be improved.
+        if (count == 0)
+        {
+            return iterator(this, index);
+        }
 
-            resize(new_size);
+        const size_type old_size = size();
+        const size_type new_size = old_size + count;
+        const size_type bits_to_move = old_size - index;
 
-            for (size_type i = old_size + count - 1; i >= index + count; --i)
-            {
-                set(i, test(i - count));
-            }
+        resize(new_size);
 
-            for (size_type i = 0; i < count; ++i)
-            {
-                set(index + i, value);
-            }
+        if (bits_to_move > 0)
+        {
+            shift_bits_right(index, bits_to_move, count);
         }
 
+        fill_bits_range(index, count, value);
+
         return iterator(this, index);
     }
 
```
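The scalar overload now reduces insertion to two bulk steps: shift the existing tail right by `count` positions (back to front, so no bit is overwritten before it is read), then fill the freed gap with `value`. Below is a minimal standalone sketch of that shift-then-fill strategy on an LSB-first block buffer; the `test`/`set` lambdas are local stand-ins for the bitset's accessors, not sparrow code:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Toy model of the shift-then-fill insert strategy on an LSB-first
// bit buffer. Names and layout are illustrative only; sparrow's real
// implementation lives in dynamic_bitset_base.
int main()
{
    std::vector<std::uint8_t> blocks = {0b1011'0101};  // 8 valid bits
    std::size_t size = 8;

    const std::size_t index = 3;  // insertion point
    const std::size_t count = 2;  // insert two set bits

    auto test = [&](std::size_t i) { return (blocks[i / 8] >> (i % 8)) & 1u; };
    auto set = [&](std::size_t i, bool v)
    {
        if (v) blocks[i / 8] |= std::uint8_t(1u << (i % 8));
        else blocks[i / 8] &= std::uint8_t(~(1u << (i % 8)));
    };

    blocks.resize((size + count + 7) / 8);
    // Step 1: shift the tail [index, size) right by `count` bits,
    // walking backwards so nothing is overwritten before it is read.
    for (std::size_t i = size; i > index; --i)
    {
        set(i - 1 + count, test(i - 1));
    }
    // Step 2: fill the freed gap [index, index + count) with the value.
    for (std::size_t i = 0; i < count; ++i)
    {
        set(index + i, true);
    }
    size += count;

    for (std::size_t i = 0; i < size; ++i) std::cout << test(i);
    std::cout << '\n';  // prints 1011101101 (bit 0 first)
}
```

The range overload is rewritten along the same lines: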
```diff
@@ -626,43 +636,19 @@ namespace sparrow
     dynamic_bitset_base<B>::insert(const_iterator pos, InputIt first, InputIt last)
     {
         const auto index = static_cast<size_type>(std::distance(cbegin(), pos));
-        const auto count = static_cast<size_type>(std::distance(first, last));
-        if (data() == nullptr)
-        {
-            if (std::all_of(
-                    first,
-                    last,
-                    [](auto v)
-                    {
-                        return bool(v);
-                    }
-                ))
-            {
-                m_size += count;
-            }
-            return iterator(this, index);
-        }
-        SPARROW_ASSERT_TRUE(cbegin() <= pos);
-        SPARROW_ASSERT_TRUE(pos <= cend());
-
-        const size_type old_size = size();
-        const size_type new_size = old_size + count;
 
-        resize(new_size);
-
-        // TODO: The current implementation is not efficient. It can be improved.
-
-        for (size_type i = old_size + count - 1; i >= index + count; --i)
+        if constexpr (std::random_access_iterator<InputIt>)
         {
-            set(i, test(i - count));
+            // Fast path for random access iterators
+            const auto count = static_cast<size_type>(std::distance(first, last));
+            return insert_range_random_access(pos, first, count, index);
         }
-
-        for (size_type i = 0; i < count; ++i)
+        else
         {
-            set(index + i, *first++);
+            // Slower path for input iterators - collect values first
+            std::vector<value_type> values(first, last);
+            return insert_range_random_access(pos, values.begin(), values.size(), index);
         }
-
-        return iterator(this, index);
     }
 
     template <typename B>
```
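The range insert now dispatches at compile time: random access iterators go straight to the new helper, while weaker iterators are buffered into a `std::vector` first so the helper's requirements still hold. Here is a self-contained sketch of this `if constexpr` dispatch pattern; `process` and `process_random_access` are hypothetical stand-ins for `insert` and `insert_range_random_access`:

```cpp
#include <cstddef>
#include <iostream>
#include <iterator>
#include <list>
#include <vector>

// Fast path: requires random access, so first[i] is valid for i < count.
template <std::random_access_iterator It>
std::size_t process_random_access(It first, std::size_t count)
{
    std::size_t set_bits = 0;
    for (std::size_t i = 0; i < count; ++i)
    {
        set_bits += bool(first[i]);
    }
    return set_bits;
}

template <std::input_iterator It>
std::size_t process(It first, It last)
{
    if constexpr (std::random_access_iterator<It>)
    {
        // std::distance is O(1) here, and the fast path can index freely.
        const auto count = static_cast<std::size_t>(std::distance(first, last));
        return process_random_access(first, count);
    }
    else
    {
        // One pass to materialize the range, then reuse the fast path.
        std::vector<char> values(first, last);
        return process_random_access(values.begin(), values.size());
    }
}

int main()
{
    std::vector<int> v = {1, 0, 1};  // random access: direct path
    std::list<int> l = {1, 0, 1};    // bidirectional: buffered path
    std::cout << process(v.begin(), v.end()) << ' '
              << process(l.begin(), l.end()) << '\n';  // prints "2 2"
}
```

The remaining hunks touch `erase` and add the three new helpers: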
```diff
@@ -711,7 +697,7 @@ namespace sparrow
         // TODO: The current implementation is not efficient. It can be improved.
 
         const size_type bit_to_move = size() - last_index;
-        for (size_type i = 0; i < bit_to_move; ++i)
+        for (size_t i = 0; i < bit_to_move; ++i)
        {
             set(first_index + i, test(last_index + i));
         }
```
```diff
@@ -738,4 +724,223 @@ namespace sparrow
         }
         resize(size() - 1);
     }
+
+    // Efficient helper functions for insert operations
+
+    template <typename B>
+        requires std::ranges::random_access_range<std::remove_pointer_t<B>>
+    constexpr void
+    dynamic_bitset_base<B>::shift_bits_right(size_type start_pos, size_type bit_count, size_type shift_amount)
+    {
+        if (bit_count == 0 || shift_amount == 0 || data() == nullptr)
+        {
+            return;
+        }
+
+        const size_type end_pos = start_pos + bit_count;
+
+        // Calculate block boundaries
+        const size_type start_block = block_index(start_pos);
+        const size_type end_block = block_index(end_pos - 1);
+        const size_type target_start_block = block_index(start_pos + shift_amount);
+        const size_type target_end_block = block_index(end_pos + shift_amount - 1);
+
+        // If the shift spans multiple blocks, use block-level operations
+        if (shift_amount >= s_bits_per_block && start_block != end_block)
+        {
+            const size_type block_shift = shift_amount / s_bits_per_block;
+            const size_type bit_shift = shift_amount % s_bits_per_block;
+
+            // Move whole blocks first
+            for (size_type i = end_block; i >= start_block && i != SIZE_MAX; --i)
+            {
+                const size_type target_block = i + block_shift;
+                if (target_block < buffer().size())
+                {
+                    buffer().data()[target_block] = buffer().data()[i];
+                }
+            }
+
+            // Handle remaining bit shift within blocks
+            if (bit_shift > 0)
+            {
+                for (size_type i = target_end_block; i > target_start_block && i != SIZE_MAX; --i)
+                {
+                    const block_type current = buffer().data()[i];
+                    const block_type previous = (i > 0) ? buffer().data()[i - 1] : block_type(0);
+                    buffer().data()[i] = static_cast<block_type>(
+                        (current << bit_shift) | (previous >> (s_bits_per_block - bit_shift))
+                    );
+                }
+                if (target_start_block < buffer().size())
+                {
+                    buffer().data()[target_start_block] = static_cast<block_type>(
+                        buffer().data()[target_start_block] << bit_shift
+                    );
+                }
+            }
+        }
+        else
+        {
+            // For smaller shifts, use bit-level operations optimized for the shift amount
+            for (size_type i = bit_count; i > 0; --i)
+            {
+                const size_t src_pos = start_pos + i - 1;
+                const size_t dst_pos = src_pos + shift_amount;
+                set(dst_pos, test(src_pos));
+            }
+        }
+    }
+
```
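For large shifts, the helper splits `shift_amount` into a whole-block part and a sub-block part, moves data block-wise, and carries `bit_shift` bits across each block boundary. A worked example of that decomposition, assuming 8-bit blocks (the real block width comes from the storage type); as in the helper, the vacated low bits are left stale for the caller to overwrite afterwards:

```cpp
#include <cstdint>
#include <cstdio>

// A shift of 11 bits over 8-bit blocks decomposes into
// block_shift = 11 / 8 = 1 whole block plus bit_shift = 11 % 8 = 3 bits
// carried across block boundaries.
int main()
{
    constexpr unsigned bits_per_block = 8;
    std::uint8_t blocks[4] = {0xAB, 0xCD, 0x00, 0x00};  // value 0xCDAB, LSB-first

    constexpr unsigned shift = 11;
    constexpr unsigned block_shift = shift / bits_per_block;  // 1
    constexpr unsigned bit_shift = shift % bits_per_block;    // 3

    // Move whole blocks first, highest index down so sources survive.
    for (int i = 1; i >= 0; --i)
    {
        blocks[i + block_shift] = blocks[i];
    }
    // Then shift within blocks, pulling carry bits from the previous
    // block: (current << n) | (previous >> (width - n)).
    for (int i = 3; i > 1; --i)
    {
        blocks[i] = std::uint8_t((blocks[i] << bit_shift) | (blocks[i - 1] >> (bits_per_block - bit_shift)));
    }
    blocks[1] = std::uint8_t(blocks[1] << bit_shift);

    // 0xCDAB placed 11 bits higher is 0x066D5800; blocks[0] keeps its
    // stale value, just as the helper leaves the gap to be filled next.
    std::printf("%02X %02X %02X %02X\n", blocks[0], blocks[1], blocks[2], blocks[3]);
    // prints: AB 58 6D 06
}
```

The same hunk continues with `fill_bits_range`: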
```diff
+    template <typename B>
+        requires std::ranges::random_access_range<std::remove_pointer_t<B>>
+    constexpr void
+    dynamic_bitset_base<B>::fill_bits_range(size_type start_pos, size_type bit_count, value_type value)
+    {
+        if (bit_count == 0 || data() == nullptr)
+        {
+            return;
+        }
+
+        const size_type end_pos = start_pos + bit_count;
+        const size_type start_block = block_index(start_pos);
+        const size_type end_block = block_index(end_pos - 1);
+
+        const block_type fill_value = value ? block_type(~block_type(0)) : block_type(0);
+
+        if (start_block == end_block)
+        {
+            // All bits are in the same block - use efficient bit masking
+            const size_type start_bit = bit_index(start_pos);
+            const size_type end_bit = bit_index(end_pos - 1);
+            const size_type mask_width = end_bit - start_bit + 1;
+            const block_type mask = static_cast<block_type>(((block_type(1) << mask_width) - 1) << start_bit);
+
+            if (value)
+            {
+                buffer().data()[start_block] |= mask;
+            }
+            else
+            {
+                buffer().data()[start_block] &= ~mask;
+            }
+        }
+        else
+        {
+            // Handle first partial block
+            const size_type start_bit = bit_index(start_pos);
+            if (start_bit != 0)
+            {
+                const block_type mask = static_cast<block_type>(~block_type(0) << start_bit);
+                if (value)
+                {
+                    buffer().data()[start_block] |= mask;
+                }
+                else
+                {
+                    buffer().data()[start_block] &= ~mask;
+                }
+            }
+            else
+            {
+                buffer().data()[start_block] = fill_value;
+            }
+
+            // Handle full blocks in between
+            for (size_type block = start_block + 1; block < end_block; ++block)
+            {
+                buffer().data()[block] = fill_value;
+            }
+
+            // Handle last partial block
+            const size_type end_bit = bit_index(end_pos - 1);
+            const block_type mask = static_cast<block_type>((block_type(1) << (end_bit + 1)) - 1);
+            if (value)
+            {
+                buffer().data()[end_block] |= mask;
+            }
+            else
+            {
+                buffer().data()[end_block] &= ~mask;
+            }
+        }
+
+        m_null_count = m_size - count_non_null();
+    }
+
```
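The single-block case builds its mask as `((1 << width) - 1) << start_bit`. Spelled out for an 8-bit block, with compile-time checks (`range_mask` is a local illustration, not a sparrow function):

```cpp
#include <cstdint>

// Mask construction from fill_bits_range, for an 8-bit block.
// Filling bits [3, 6] means start_bit = 3 and mask_width = 4:
//   ((1 << 4) - 1) = 0b0000'1111, shifted left by 3 -> 0b0111'1000.
constexpr std::uint8_t range_mask(unsigned start_bit, unsigned width)
{
    return std::uint8_t(((1u << width) - 1u) << start_bit);
}

static_assert(range_mask(3, 4) == 0b0111'1000);
// Setting ORs the mask in, clearing ANDs its complement:
static_assert((std::uint8_t(0b1000'0001) | range_mask(3, 4)) == 0b1111'1001);
static_assert((std::uint8_t(0b1111'1111) & std::uint8_t(~range_mask(3, 4))) == 0b1000'0111);
```

The hunk closes with the batched range-insert helper itself: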
```diff
+    template <typename B>
+        requires std::ranges::random_access_range<std::remove_pointer_t<B>>
+    template <std::random_access_iterator InputIt>
+    constexpr auto dynamic_bitset_base<B>::insert_range_random_access(
+        const_iterator /* pos */,
+        InputIt first,
+        size_type count,
+        size_type index
+    ) -> iterator
+    {
+        if (data() == nullptr)
+        {
+            if (std::all_of(
+                    first,
+                    std::next(first, static_cast<std::ptrdiff_t>(count)),
+                    [](auto v)
+                    {
+                        return bool(v);
+                    }
+                ))
+            {
+                m_size += count;
+            }
+            return iterator(this, index);
+        }
+
+        if (count == 0)
+        {
+            return iterator(this, index);
+        }
+
+        const size_type old_size = size();
+        const size_type new_size = old_size + count;
+        const size_type bits_to_move = old_size - index;
+
+        resize(new_size);
+
+        if (bits_to_move > 0)
+        {
+            shift_bits_right(index, bits_to_move, count);
+        }
+
+        // Set bits efficiently in batches
+        constexpr size_type batch_size = s_bits_per_block;
+        for (size_type i = 0; i < count; i += batch_size)
+        {
+            const size_type current_batch_size = std::min(batch_size, count - i);
+            const size_type batch_start = index + i;
+
+            // Process bits in the current batch
+            if (current_batch_size == s_bits_per_block && bit_index(batch_start) == 0)
+            {
+                // Optimized path: entire block can be set at once
+                block_type block_value = 0;
+                for (size_type j = 0; j < s_bits_per_block; ++j)
+                {
+                    if (bool(*(std::next(first, static_cast<std::ptrdiff_t>(i + j)))))
+                    {
+                        block_value |= static_cast<block_type>(block_type(1) << j);
+                    }
+                }
+                buffer().data()[block_index(batch_start)] = block_value;
+            }
+            else
+            {
+                // Fallback to bit-by-bit setting for partial blocks
+                for (size_type j = 0; j < current_batch_size; ++j)
+                {
+                    set(batch_start + j, bool(*(std::next(first, static_cast<std::ptrdiff_t>(i + j)))));
+                }
+            }
+        }
+
+        return iterator(this, index);
+    }
 }
```
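When a full, block-aligned batch is available (`bit_index(batch_start) == 0`), the helper packs `s_bits_per_block` incoming values into one block and writes it with a single store instead of per-bit `set` calls. A standalone sketch of that LSB-first packing, assuming 8-bit blocks:

```cpp
#include <cstdint>
#include <cstdio>

// Pack one block's worth of incoming values into a single block before
// writing it, instead of calling set() once per bit.
int main()
{
    const bool values[8] = {true, false, true, true, false, false, true, false};

    std::uint8_t block_value = 0;
    for (unsigned j = 0; j < 8; ++j)
    {
        if (values[j])
        {
            // Bit j of the block holds the j-th value (LSB-first layout).
            block_value |= std::uint8_t(1u << j);
        }
    }
    std::printf("0x%02X\n", block_value);  // prints 0x4D = 0b0100'1101
}
```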
