Created
January 2, 2026 16:48
-
-
Save jacky860226/e7bdd387d4f9040f12d119d9ee152bf4 to your computer and use it in GitHub Desktop.
Concurrent Vector
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <new>
#include <stdexcept>
#include <thread>
#include <utility>
#if __cplusplus >= 202002L
#include <bit>
#elif defined(_MSC_VER)
#include <intrin.h>
#endif
// --- Helper: floor(log2(n)) ---
// Precondition: n >= 1. Callers pass `index | 1`, so n is never 0
// (clz/BitScanReverse are undefined for 0).
constexpr size_t simple_log2(size_t n) {
#if __cplusplus >= 202002L
  // C++20 standard library: bit_width(n) == floor(log2(n)) + 1 for n >= 1.
  return std::bit_width(n) - 1;
#elif defined(__GNUC__) || defined(__clang__)
  // GCC/Clang builtin: count leading zeros.
  // Use the `unsigned long long` variant and its width: on LLP64 targets
  // (e.g. MinGW-w64) size_t is 64-bit but `unsigned long` is 32-bit, so the
  // original __builtin_clzl silently truncated large values.
  // Example: n = 4 (0...0100), clzll = 61, 64 - 1 - 61 = 2.
  return sizeof(unsigned long long) * 8 - 1 - __builtin_clzll(n);
#elif defined(_MSC_VER)
  // MSVC intrinsic: find the position of the highest set bit.
  unsigned long index;
#if defined(_WIN64)
  _BitScanReverse64(&index, n);
#else
  _BitScanReverse(&index, n);
#endif
  return index;
#else
  // Fallback: portable loop implementation.
  size_t res = 0;
  while (n > 1) {
    n >>= 1;
    res++;
  }
  return res;
#endif
}
| // ========================================== | |
| // SegmentTable: 負責底層記憶體分段管理 | |
| // ========================================== | |
| template <typename T> class SegmentTable { | |
| static constexpr size_t MAX_SEGMENTS = 64; | |
| enum class SegmentState : uint8_t { Unallocated, Allocating, Allocated }; | |
| // 儲存指向各個 Segment 的指標 (一般指標,由 segment_states 保護) | |
| T *segments[MAX_SEGMENTS]; | |
| // 狀態旗標 | |
| std::atomic<SegmentState> segment_states[MAX_SEGMENTS]; | |
| public: | |
| SegmentTable() { | |
| for (size_t i = 0; i < MAX_SEGMENTS; ++i) { | |
| segments[i] = nullptr; | |
| segment_states[i].store(SegmentState::Unallocated, | |
| std::memory_order_relaxed); | |
| } | |
| } | |
| ~SegmentTable() { | |
| // 注意:這裡只負責釋放記憶體 (operator delete[]) | |
| // 物件的解構 (Destructor) 應該由上層容器負責 | |
| for (size_t i = 0; i < MAX_SEGMENTS; ++i) { | |
| T *seg = segments[i]; | |
| if (seg) { | |
| delete[] reinterpret_cast<char *>(seg); | |
| } | |
| } | |
| } | |
| // 禁止複製 | |
| SegmentTable(const SegmentTable &) = delete; | |
| SegmentTable &operator=(const SegmentTable &) = delete; | |
| // --- 靜態索引計算 --- | |
| static size_t get_segment_index(size_t index) { | |
| return simple_log2(index | 1); | |
| } | |
| static size_t get_segment_base(size_t seg_index) { | |
| return (size_t(1) << seg_index) & ~size_t(1); | |
| } | |
| static size_t get_segment_size(size_t seg_index) { | |
| return seg_index == 0 ? 2 : (size_t(1) << seg_index); | |
| } | |
| // --- 核心功能:取得指定 index 的記憶體位址 --- | |
| // 如果該 Segment 尚未分配,此函式會負責分配 (Thread-safe) | |
| T *get_address(size_t index) { | |
| size_t seg_idx = get_segment_index(index); | |
| size_t seg_base = get_segment_base(seg_idx); | |
| size_t offset = index - seg_base; | |
| // 1. Optimistic check (Fast path) | |
| if (segment_states[seg_idx].load(std::memory_order_acquire) == | |
| SegmentState::Allocated) { | |
| T *segment = segments[seg_idx]; | |
| return &segment[offset]; | |
| } | |
| // 2. 嘗試取得分配權 | |
| SegmentState expected = SegmentState::Unallocated; | |
| if (segment_states[seg_idx].compare_exchange_strong( | |
| expected, SegmentState::Allocating, std::memory_order_acq_rel, | |
| std::memory_order_acquire)) { | |
| // 我是分配者 | |
| try { | |
| size_t seg_size = get_segment_size(seg_idx); | |
| T *segment = reinterpret_cast<T *>(new char[seg_size * sizeof(T)]); | |
| segments[seg_idx] = segment; | |
| segment_states[seg_idx].store(SegmentState::Allocated, | |
| std::memory_order_release); | |
| segment_states[seg_idx].notify_all(); | |
| return &segment[offset]; | |
| } catch (...) { | |
| // Allocation failed. Rollback state. | |
| segment_states[seg_idx].store(SegmentState::Unallocated, | |
| std::memory_order_release); | |
| segment_states[seg_idx].notify_all(); | |
| throw; | |
| } | |
| } else { | |
| // 別人正在分配,等待直到完成 | |
| while (true) { | |
| SegmentState current = | |
| segment_states[seg_idx].load(std::memory_order_acquire); | |
| if (current == SegmentState::Allocated) { | |
| break; | |
| } | |
| if (current == SegmentState::Unallocated) { | |
| // Previous allocator failed, retry allocation | |
| return get_address(index); | |
| } | |
| segment_states[seg_idx].wait(SegmentState::Allocating, | |
| std::memory_order_acquire); | |
| } | |
| } | |
| T *segment = segments[seg_idx]; | |
| return &segment[offset]; | |
| } | |
| // 唯讀版本:如果 Segment 不存在,回傳 nullptr (不分配) | |
| T *get_address_if_exists(size_t index) const { | |
| size_t seg_idx = get_segment_index(index); | |
| size_t seg_base = get_segment_base(seg_idx); | |
| size_t offset = index - seg_base; | |
| if (segment_states[seg_idx].load(std::memory_order_acquire) == | |
| SegmentState::Allocated) { | |
| T *segment = segments[seg_idx]; | |
| return &segment[offset]; | |
| } | |
| return nullptr; | |
| } | |
| }; | |
// ==========================================
// ConcurrentVector: the user-facing container interface.
// Built on SegmentTable, so element addresses stay stable across growth.
// ==========================================
template <typename T> class ConcurrentVector {
  SegmentTable<T> table;
  std::atomic<size_t> m_size{0};
public:
  ConcurrentVector() = default;
  ~ConcurrentVector() {
    // Runs the destructor of every counted element; SegmentTable's own
    // destructor then frees the raw memory.
    size_t current_size = m_size.load(std::memory_order_relaxed);
    for (size_t i = 0; i < current_size; ++i) {
      // Memory for [0, current_size) is known to exist by now, so the
      // non-allocating lookup can be used directly.
      T *ptr = table.get_address_if_exists(i);
      if (ptr) {
        ptr->~T();
      }
    }
    // SegmentTable's destructor releases the memory itself.
  }
  void push_back(const T &value) { emplace_back(value); }
  void push_back(T &&value) { emplace_back(std::move(value)); }
  // Appends one element constructed in place from args (thread-safe).
  // NOTE(review): if T's constructor throws, the slot reserved in step 1 is
  // never constructed, yet the destructor will still call ~T on it --
  // assumes non-throwing element constructors; confirm with callers.
  template <typename... Args> void emplace_back(Args &&...args) {
    // 1. Claim a slot.
    size_t idx = m_size.fetch_add(1);
    // 2. Resolve its address (allocates the segment on demand).
    T *ptr = table.get_address(idx);
    // 3. Construct the object in place.
    new (ptr) T(std::forward<Args>(args)...);
  }
  // Atomically appends `delta` default-initialized elements.
  // Returns the start index of the newly added range.
  size_t grow_by(size_t delta) {
    if (delta == 0)
      return m_size.load(std::memory_order_relaxed);
    // 1. Atomically reserve the index range.
    size_t start_index = m_size.fetch_add(delta);
    size_t end_index = start_index + delta;
    // 2. Make sure every segment in the range is allocated.
    //    Optimization: probe only the first touched element of each
    //    segment, then jump straight to the next segment.
    for (size_t i = start_index; i < end_index;) {
      table.get_address(i);
      // Compute where the next segment begins and skip to it.
      size_t seg_idx = SegmentTable<T>::get_segment_index(i);
      size_t next_seg_base = SegmentTable<T>::get_segment_base(seg_idx) +
                             SegmentTable<T>::get_segment_size(seg_idx);
      i = next_seg_base;
    }
    // 3. Default-construct the elements.
    for (size_t i = start_index; i < end_index; ++i) {
      T *ptr = table.get_address(i);
      new (ptr) T();
    }
    return start_index;
  }
  // Ensures the vector holds at least n elements.
  void grow_to_at_least(size_t n) {
    size_t c = m_size.load(std::memory_order_relaxed);
    while (c < n) {
      if (m_size.compare_exchange_weak(c, n)) {
        // We raised the size from c to n, so we own initializing [c, n).
        // 1. Make sure every segment in the range is allocated
        //    (same segment-skipping walk as grow_by).
        for (size_t i = c; i < n;) {
          table.get_address(i);
          size_t seg_idx = SegmentTable<T>::get_segment_index(i);
          size_t next_seg_base = SegmentTable<T>::get_segment_base(seg_idx) +
                                 SegmentTable<T>::get_segment_size(seg_idx);
          i = next_seg_base;
        }
        // 2. Default-construct the elements.
        for (size_t i = c; i < n; ++i) {
          T *ptr = table.get_address(i);
          new (ptr) T();
        }
        return;
      }
      // CAS failure refreshed c with the current m_size; loop re-checks.
    }
  }
  // Unchecked element access.
  T &operator[](size_t index) {
    // 1. Look up the address without allocating.
    T *ptr = table.get_address_if_exists(index);
    // 2. Spin-wait only when ptr is null AND index is in range: the slot
    //    was claimed but its segment is still mid-allocation.
    //    If index >= m_size this is an out-of-bounds access and waiting
    //    would deadlock, so we must not wait.
    if (ptr == nullptr) {
      size_t current_size = m_size.load(std::memory_order_acquire);
      if (index < current_size) {
        while (ptr == nullptr) {
          std::this_thread::yield();
          ptr = table.get_address_if_exists(index);
        }
      }
      // If index >= current_size, ptr stays nullptr and the *ptr below
      // dereferences null (a crash) -- consistent with the usual UB
      // contract of out-of-range operator[].
    }
    return *ptr;
  }
  const T &operator[](size_t index) const {
    // Forward to the non-const overload (identical logic).
    return const_cast<ConcurrentVector *>(this)->operator[](index);
  }
  // Bounds-checked access (read-only).
  const T &at(size_t index) const {
    return const_cast<ConcurrentVector *>(this)->at(index);
  }
  // Bounds-checked access (read-write); throws std::out_of_range.
  T &at(size_t index) {
    // 1. Bounds check (acquire so the latest published size is observed).
    if (index >= m_size.load(std::memory_order_acquire)) {
      throw std::out_of_range("ConcurrentVector::at: index out of range");
    }
    // 2. Resolve the address (same spin as operator[]): index < m_size,
    //    but the segment may still be being allocated.
    T *ptr = table.get_address_if_exists(index);
    while (ptr == nullptr) {
      std::this_thread::yield();
      ptr = table.get_address_if_exists(index);
    }
    return *ptr;
  }
  size_t size() const { return m_size.load(std::memory_order_acquire); }
};
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <thread>
#include <type_traits>
// Forward declaration of the segment-table trie.
//   level              : node depth; 0 is the leaf that stores the elements
//   each_level_bit_num : index bits consumed per trie level
//   Allocator          : element allocator, rebound for internal bookkeeping
template <typename T, size_t level = 3, size_t each_level_bit_num = 16,
          typename Allocator = std::allocator<T>>
class SegmentTable;
// Redeclaration of the primary template (defaults are supplied at its
// first declaration); makes this specialization self-contained.
template <typename T, size_t level, size_t each_level_bit_num,
          typename Allocator>
class SegmentTable;
// Specialization for level 0 (leaf node): owns the actual element storage.
// Memory is raw/uninitialized; constructing and destroying the T objects
// is the caller's responsibility (e.g. ConcurrentVector).
template <typename T, size_t each_level_bit_num, typename Allocator>
class SegmentTable<T, 0, each_level_bit_num, Allocator> {
  // The leaf holds 1ULL << each_level_bit_num slots, so the shift count
  // must stay below 64 (the previous bound of 64 made the shift UB).
  static_assert(each_level_bit_num >= 1 && each_level_bit_num <= 63,
                "Invalid SegmentTable configuration.");
  using AllocTraits = std::allocator_traits<Allocator>;
  T *elements; // raw storage for 1 << each_level_bit_num elements
  Allocator alloc;
public:
  SegmentTable() {
    size_t size = 1ULL << each_level_bit_num;
    // Raw allocation only; no T objects are constructed here.
    // Construction is handled by the caller (e.g. ConcurrentVector).
    elements = AllocTraits::allocate(alloc, size);
  }
  ~SegmentTable() {
    size_t size = 1ULL << each_level_bit_num;
    AllocTraits::deallocate(alloc, elements, size);
  }
  // Non-copyable: a shallow copy would double-free `elements`.
  SegmentTable(const SegmentTable &) = delete;
  SegmentTable &operator=(const SegmentTable &) = delete;
  // Address of the slot for `index` (only the low bits are used).
  T *get_address(size_t index) {
    return &elements[index & ((1ULL << each_level_bit_num) - 1)];
  }
  T *get_address_if_exists(size_t index) const {
    return &elements[index & ((1ULL << each_level_bit_num) - 1)];
  }
  // Leaves are allocated as a whole up front, so the concurrent variants
  // are just the plain lookups.
  T *concurrent_get_address(size_t index) { return get_address(index); }
  T *concurrent_get_address_if_exists(size_t index) const {
    return get_address_if_exists(index);
  }
};
// Primary template for level > 0 (internal node): holds pointers to child
// tables plus, for the concurrent API, a per-slot allocation state.
template <typename T, size_t level, size_t each_level_bit_num,
          typename Allocator>
class SegmentTable {
  // The shift amount reaches level * each_level_bit_num at this node and
  // the table size is 1ULL << each_level_bit_num, so both counts must stay
  // below 64 (the previous bound of 64 made those shifts UB).
  static_assert(level >= 1 && each_level_bit_num >= 1 &&
                    each_level_bit_num <= 63 &&
                    level * each_level_bit_num <= 63,
                "Invalid SegmentTable configuration.");
  using ChildTable = SegmentTable<T, level - 1, each_level_bit_num, Allocator>;
  using ChildAllocator = typename std::allocator_traits<
      Allocator>::template rebind_alloc<ChildTable>;
  using ChildAllocTraits = std::allocator_traits<ChildAllocator>;
  using ChildPtrAllocator = typename std::allocator_traits<
      Allocator>::template rebind_alloc<ChildTable *>;
  using ChildPtrAllocTraits = std::allocator_traits<ChildPtrAllocator>;
  enum class SegmentState : uint8_t { Unallocated, Allocating, Allocated };
  using StateAllocator = typename std::allocator_traits<
      Allocator>::template rebind_alloc<std::atomic<SegmentState>>;
  using StateAllocTraits = std::allocator_traits<StateAllocator>;
  ChildTable **segments;
  std::atomic<SegmentState> *segment_states;
  ChildAllocator child_alloc;
  ChildPtrAllocator ptr_alloc;
  StateAllocator state_alloc;

  // Allocates and constructs one child table. If the constructor throws,
  // the raw memory is released again, so nothing leaks and no pointer to
  // an unconstructed object is ever published (the old code stored the
  // pointer before construction finished).
  ChildTable *make_child() {
    ChildTable *child = ChildAllocTraits::allocate(child_alloc, 1);
    try {
      ChildAllocTraits::construct(child_alloc, child);
    } catch (...) {
      ChildAllocTraits::deallocate(child_alloc, child, 1);
      throw;
    }
    return child;
  }

  // std::atomic wait/notify is C++20-only; fall back to a yield-spin on
  // older standards so the file still compiles there.
  void wait_while_allocating(size_t segment_index) {
#if __cplusplus >= 202002L
    segment_states[segment_index].wait(SegmentState::Allocating,
                                       std::memory_order_acquire);
#else
    std::this_thread::yield();
#endif
  }
  void notify_state(size_t segment_index) {
#if __cplusplus >= 202002L
    segment_states[segment_index].notify_all();
#else
    (void)segment_index;
#endif
  }

public:
  SegmentTable() {
    size_t total_segments = 1ULL << each_level_bit_num;
    segments = ChildPtrAllocTraits::allocate(ptr_alloc, total_segments);
    try {
      segment_states = StateAllocTraits::allocate(state_alloc, total_segments);
    } catch (...) {
      // Do not leak the pointer array if the state array fails to allocate.
      ChildPtrAllocTraits::deallocate(ptr_alloc, segments, total_segments);
      throw;
    }
    for (size_t i = 0; i < total_segments; ++i) {
      segments[i] = nullptr;
      StateAllocTraits::construct(state_alloc, &segment_states[i],
                                  SegmentState::Unallocated);
    }
  }
  ~SegmentTable() {
    size_t total_segments = 1ULL << each_level_bit_num;
    for (size_t i = 0; i < total_segments; ++i) {
      if (segments[i] != nullptr) {
        // Destroy the child table object, then release its memory.
        ChildAllocTraits::destroy(child_alloc, segments[i]);
        ChildAllocTraits::deallocate(child_alloc, segments[i], 1);
      }
      StateAllocTraits::destroy(state_alloc, &segment_states[i]);
    }
    // Release the bookkeeping arrays.
    ChildPtrAllocTraits::deallocate(ptr_alloc, segments, total_segments);
    StateAllocTraits::deallocate(state_alloc, segment_states, total_segments);
  }
  // Non-copyable: owns the child tables and the bookkeeping arrays.
  SegmentTable(const SegmentTable &) = delete;
  SegmentTable &operator=(const SegmentTable &) = delete;

  // Single-threaded lookup: nullptr if the child does not exist yet.
  T *get_address_if_exists(size_t index) const {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    if (segments[segment_index] == nullptr) {
      return nullptr;
    }
    return segments[segment_index]->get_address_if_exists(
        index & ((1ULL << shift) - 1));
  }
  // Single-threaded lookup that allocates missing children on demand.
  T *get_address(size_t index) {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    if (segments[segment_index] == nullptr) {
      segments[segment_index] = make_child();
    }
    return segments[segment_index]->get_address(index & ((1ULL << shift) - 1));
  }
  // Thread-safe lookup; exactly one thread allocates a missing child, the
  // others wait for the allocation to finish.
  T *concurrent_get_address(size_t index) {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    // Optimistic fast path: child already published.
    if (segment_states[segment_index].load(std::memory_order_acquire) ==
        SegmentState::Allocated) {
      return segments[segment_index]->concurrent_get_address(
          index & ((1ULL << shift) - 1));
    }
    SegmentState expected = SegmentState::Unallocated;
    if (segment_states[segment_index].compare_exchange_strong(
            expected, SegmentState::Allocating, std::memory_order_acq_rel,
            std::memory_order_acquire)) {
      // We won the race. Allocate.
      try {
        segments[segment_index] = make_child();
        segment_states[segment_index].store(SegmentState::Allocated,
                                            std::memory_order_release);
        notify_state(segment_index);
      } catch (...) {
        // Allocation failed: roll the state back so it can be retried.
        segment_states[segment_index].store(SegmentState::Unallocated,
                                            std::memory_order_release);
        notify_state(segment_index);
        throw;
      }
    } else {
      // CAS failed: wait until the winner publishes the child.
      while (true) {
        SegmentState current =
            segment_states[segment_index].load(std::memory_order_acquire);
        if (current == SegmentState::Allocated) {
          break;
        }
        if (current == SegmentState::Unallocated) {
          // The previous allocator failed; retry the full protocol.
          return concurrent_get_address(index);
        }
        wait_while_allocating(segment_index);
      }
    }
    return segments[segment_index]->concurrent_get_address(
        index & ((1ULL << shift) - 1));
  }
  // Thread-safe lookup: nullptr if the child is not (yet) published.
  T *concurrent_get_address_if_exists(size_t index) const {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    if (segment_states[segment_index].load(std::memory_order_acquire) !=
        SegmentState::Allocated) {
      return nullptr;
    }
    return segments[segment_index]->concurrent_get_address_if_exists(
        index & ((1ULL << shift) - 1));
  }
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment