Skip to content

Instantly share code, notes, and snippets.

@jacky860226
Created January 2, 2026 16:48
Show Gist options
  • Select an option

  • Save jacky860226/e7bdd387d4f9040f12d119d9ee152bf4 to your computer and use it in GitHub Desktop.

Select an option

Save jacky860226/e7bdd387d4f9040f12d119d9ee152bf4 to your computer and use it in GitHub Desktop.
Concurrent Vector
#include <atomic>
#include <cassert>
#include <cstdint>
#include <new>
#include <stdexcept>
#include <thread>
#include <utility>
#if __cplusplus >= 202002L
#include <bit>
#elif defined(_MSC_VER)
#include <intrin.h>
#endif
// --- Helper: floor(log2(n)) ---
// Precondition: n >= 1 (bit_width(0) - 1 wraps and clz(0) is undefined).
// Callers guarantee this by passing `index | 1`.
constexpr size_t simple_log2(size_t n) {
#if __cplusplus >= 202002L
  // C++20 standard library: bit_width(n) == floor(log2(n)) + 1 for n >= 1.
  return std::bit_width(n) - 1;
#elif defined(__GNUC__) || defined(__clang__)
  // GCC/Clang builtin: Count Leading Zeros.
  // Use the `unsigned long long` flavor so the result is also correct on
  // LLP64 targets (e.g. 64-bit Windows/MinGW), where `unsigned long` is only
  // 32 bits wide while size_t is 64 bits.
  // Example: n = 4 (0...0100) in 64 bits -> clzll = 61, 64 - 1 - 61 = 2.
  return sizeof(unsigned long long) * 8 - 1 -
         __builtin_clzll(static_cast<unsigned long long>(n));
#elif defined(_MSC_VER)
  // MSVC intrinsic: find the position of the highest set bit.
  unsigned long index;
#if defined(_WIN64)
  _BitScanReverse64(&index, n);
#else
  _BitScanReverse(&index, n);
#endif
  return index;
#else
  // Fallback: portable shift loop.
  size_t res = 0;
  while (n > 1) {
    n >>= 1;
    res++;
  }
  return res;
#endif
}
// ==========================================
// SegmentTable: manages the underlying segmented storage
// ==========================================
// Segment k = 0 holds indices [0, 2); segment k >= 1 holds [2^k, 2^(k+1)).
// Segments are allocated lazily and thread-safely on first touch.
template <typename T> class SegmentTable {
  static constexpr size_t MAX_SEGMENTS = 64;
  enum class SegmentState : uint8_t { Unallocated, Allocating, Allocated };
  // Pointers to the individual segments. Plain pointers are fine here:
  // publication is synchronized through the acquire/release accesses on
  // segment_states.
  T *segments[MAX_SEGMENTS];
  // Per-segment allocation state.
  std::atomic<SegmentState> segment_states[MAX_SEGMENTS];

public:
  SegmentTable() {
    for (size_t i = 0; i < MAX_SEGMENTS; ++i) {
      segments[i] = nullptr;
      segment_states[i].store(SegmentState::Unallocated,
                              std::memory_order_relaxed);
    }
  }
  ~SegmentTable() {
    // Note: this only releases the raw memory. Running the element
    // destructors is the responsibility of the owning container.
    for (size_t i = 0; i < MAX_SEGMENTS; ++i) {
      T *seg = segments[i];
      if (seg) {
        // Matches the aligned ::operator new in get_address.
        ::operator delete(static_cast<void *>(seg),
                          std::align_val_t(alignof(T)));
      }
    }
  }
  // Non-copyable: owns raw memory.
  SegmentTable(const SegmentTable &) = delete;
  SegmentTable &operator=(const SegmentTable &) = delete;

  // --- Static index arithmetic ---
  // Segment holding `index`; `index | 1` maps index 0 into segment 0.
  static size_t get_segment_index(size_t index) {
    return simple_log2(index | 1);
  }
  // First global index stored in segment `seg_index` (0 for segment 0).
  static size_t get_segment_base(size_t seg_index) {
    return (size_t(1) << seg_index) & ~size_t(1);
  }
  // Element capacity of segment `seg_index` (segment 0 holds two elements).
  static size_t get_segment_size(size_t seg_index) {
    return seg_index == 0 ? 2 : (size_t(1) << seg_index);
  }

  // --- Core: resolve the address of `index` ---
  // Allocates the owning segment on demand (thread-safe).
  T *get_address(size_t index) {
    size_t seg_idx = get_segment_index(index);
    size_t offset = index - get_segment_base(seg_idx);
    for (;;) {
      // 1. Optimistic fast path: segment already published.
      if (segment_states[seg_idx].load(std::memory_order_acquire) ==
          SegmentState::Allocated) {
        return &segments[seg_idx][offset];
      }
      // 2. Try to win the right to allocate.
      SegmentState expected = SegmentState::Unallocated;
      if (segment_states[seg_idx].compare_exchange_strong(
              expected, SegmentState::Allocating, std::memory_order_acq_rel,
              std::memory_order_acquire)) {
        // We are the allocator.
        try {
          size_t seg_size = get_segment_size(seg_idx);
          // Aligned raw allocation: `new char[]` only guarantees fundamental
          // alignment, which is insufficient for over-aligned T.
          T *segment = static_cast<T *>(::operator new(
              seg_size * sizeof(T), std::align_val_t(alignof(T))));
          segments[seg_idx] = segment;
          segment_states[seg_idx].store(SegmentState::Allocated,
                                        std::memory_order_release);
          return &segment[offset];
        } catch (...) {
          // Allocation failed: roll the state back so another thread (or a
          // retry by this one) can attempt the allocation again.
          segment_states[seg_idx].store(SegmentState::Unallocated,
                                        std::memory_order_release);
          throw;
        }
      }
      // 3. Someone else is allocating: yield until the state leaves
      // Allocating. (std::atomic::wait/notify is C++20-only; this file
      // explicitly supports pre-C++20 toolchains, so spin with yield.)
      while (segment_states[seg_idx].load(std::memory_order_acquire) ==
             SegmentState::Allocating) {
        std::this_thread::yield();
      }
      // State is now Allocated (done) or Unallocated (allocator failed);
      // loop around and re-evaluate from the top.
    }
  }

  // Read-only lookup: returns nullptr if the segment does not exist yet
  // (never allocates).
  T *get_address_if_exists(size_t index) const {
    size_t seg_idx = get_segment_index(index);
    size_t offset = index - get_segment_base(seg_idx);
    if (segment_states[seg_idx].load(std::memory_order_acquire) ==
        SegmentState::Allocated) {
      return &segments[seg_idx][offset];
    }
    return nullptr;
  }
};
// ==========================================
// ConcurrentVector: 提供容器介面
// ==========================================
template <typename T> class ConcurrentVector {
SegmentTable<T> table;
std::atomic<size_t> m_size{0};
public:
ConcurrentVector() = default;
~ConcurrentVector() {
// 負責呼叫物件解構子
size_t current_size = m_size.load(std::memory_order_relaxed);
for (size_t i = 0; i < current_size; ++i) {
// 這裡我們知道記憶體一定存在,所以可以直接拿
T *ptr = table.get_address_if_exists(i);
if (ptr) {
ptr->~T();
}
}
// table 的解構子會自動釋放記憶體
}
void push_back(const T &value) { emplace_back(value); }
void push_back(T &&value) { emplace_back(std::move(value)); }
template <typename... Args> void emplace_back(Args &&...args) {
// 1. 搶佔位置
size_t idx = m_size.fetch_add(1);
// 2. 取得記憶體位址 (如果需要會自動分配 Segment)
T *ptr = table.get_address(idx);
// 3. 原地建構物件
new (ptr) T(std::forward<Args>(args)...);
}
// 原子地增加 delta 個元素空間,並進行預設初始化
// 回傳值:新增範圍的起始索引 (Start Index)
size_t grow_by(size_t delta) {
if (delta == 0)
return m_size.load(std::memory_order_relaxed);
// 1. 原子地預留範圍
size_t start_index = m_size.fetch_add(delta);
size_t end_index = start_index + delta;
// 2. 確保範圍內的 Segment 已分配
// 優化:跳躍式檢查,只檢查每個 Segment 的第一個涉及到的元素
for (size_t i = start_index; i < end_index;) {
table.get_address(i);
// 計算下一個 Segment 的起始位置,直接跳過去
size_t seg_idx = SegmentTable<T>::get_segment_index(i);
size_t next_seg_base = SegmentTable<T>::get_segment_base(seg_idx) +
SegmentTable<T>::get_segment_size(seg_idx);
i = next_seg_base;
}
// 3. 建構元素
for (size_t i = start_index; i < end_index; ++i) {
T *ptr = table.get_address(i);
new (ptr) T();
}
return start_index;
}
// 確保 vector 至少有 n 個元素
void grow_to_at_least(size_t n) {
size_t c = m_size.load(std::memory_order_relaxed);
while (c < n) {
if (m_size.compare_exchange_weak(c, n)) {
// 成功更新 size 為 n,負責初始化 [c, n)
// 1. 確保範圍內的 Segment 已分配
for (size_t i = c; i < n;) {
table.get_address(i);
size_t seg_idx = SegmentTable<T>::get_segment_index(i);
size_t next_seg_base = SegmentTable<T>::get_segment_base(seg_idx) +
SegmentTable<T>::get_segment_size(seg_idx);
i = next_seg_base;
}
// 2. 建構元素
for (size_t i = c; i < n; ++i) {
T *ptr = table.get_address(i);
new (ptr) T();
}
return;
}
// CAS 失敗,c 會被更新為最新的 m_size,迴圈繼續檢查
}
}
T &operator[](size_t index) {
// 1. 取得記憶體位址
T *ptr = table.get_address_if_exists(index);
// 2. Spin Wait
// 只有當 ptr 為 null 且 index 在合法範圍內時,才需要等待
// 如果 index >= m_size,代表這是越界存取,我們不應該等待(否則會死鎖)
if (ptr == nullptr) {
size_t current_size = m_size.load(std::memory_order_acquire);
if (index < current_size) {
while (ptr == nullptr) {
std::this_thread::yield();
ptr = table.get_address_if_exists(index);
}
}
// 如果 index >= current_size,這裡 ptr 還是 nullptr
// 接下來的 *ptr 會導致存取 nullptr (Crash),這符合 C++ operator[]
// 越界的 UB 行為
}
return *ptr;
}
const T &operator[](size_t index) const {
// 轉呼叫 non-const 版本 (邏輯一樣)
return const_cast<ConcurrentVector *>(this)->operator[](index);
}
// 帶有邊界檢查的存取 (唯讀)
const T &at(size_t index) const {
return const_cast<ConcurrentVector *>(this)->at(index);
}
// 帶有邊界檢查的存取 (讀寫)
T &at(size_t index) {
// 1. 檢查是否越界
// 注意:這裡使用 acquire memory order 確保讀到最新的 size
if (index >= m_size.load(std::memory_order_acquire)) {
throw std::out_of_range("ConcurrentVector::at: index out of range");
}
// 2. 取得記憶體位址 (邏輯同 operator[])
T *ptr = table.get_address_if_exists(index);
// Spin Wait: 雖然 index < m_size,但 Segment 可能還在分配中
while (ptr == nullptr) {
std::this_thread::yield();
ptr = table.get_address_if_exists(index);
}
return *ptr;
}
size_t size() const { return m_size.load(std::memory_order_acquire); }
};
#include <atomic>
#include <memory>
#include <type_traits>
// Forward declaration of the primary template with its default
// configuration: level = 3 internal levels above a level-0 leaf, each level
// consuming 16 index bits.
template <typename T, size_t level = 3, size_t each_level_bit_num = 16,
          typename Allocator = std::allocator<T>>
class SegmentTable;
// Specialization for level 0 (leaf node) - holds the actual element storage.
// (Re)declare the primary template so this specialization is well-formed
// even when taken in isolation; this matches the declaration above.
template <typename T, size_t level, size_t each_level_bit_num,
          typename Allocator>
class SegmentTable;

template <typename T, size_t each_level_bit_num, typename Allocator>
class SegmentTable<T, 0, each_level_bit_num, Allocator> {
  // Capped at 63: `1ULL << 64` (shift by the full bit width) is undefined
  // behavior.
  static_assert(each_level_bit_num >= 1 && each_level_bit_num <= 63,
                "Invalid SegmentTable configuration.");
  using AllocTraits = std::allocator_traits<Allocator>;
  T *elements; // raw storage for 2^each_level_bit_num elements
  Allocator alloc;

public:
  SegmentTable() {
    size_t size = 1ULL << each_level_bit_num;
    // Raw memory only: element construction is handled by the caller
    // (e.g. ConcurrentVector placement-news into it).
    elements = AllocTraits::allocate(alloc, size);
  }
  ~SegmentTable() {
    size_t size = 1ULL << each_level_bit_num;
    AllocTraits::deallocate(alloc, elements, size);
  }
  // Non-copyable: owns raw memory, so an implicit copy would double-free.
  SegmentTable(const SegmentTable &) = delete;
  SegmentTable &operator=(const SegmentTable &) = delete;

  // Leaf tables are always fully allocated, so every lookup just masks off
  // the low each_level_bit_num bits and indexes into the storage.
  T *get_address(size_t index) {
    return &elements[index & ((1ULL << each_level_bit_num) - 1)];
  }
  T *get_address_if_exists(size_t index) const {
    return &elements[index & ((1ULL << each_level_bit_num) - 1)];
  }
  // Concurrent variants are trivial at the leaf: nothing is allocated lazily.
  T *concurrent_get_address(size_t index) { return get_address(index); }
  T *concurrent_get_address_if_exists(size_t index) const {
    return get_address_if_exists(index);
  }
};
// Primary template for level > 0 (internal node) - holds pointers to child
// tables one level down.
template <typename T, size_t level, size_t each_level_bit_num,
          typename Allocator>
class SegmentTable {
  // (level + 1) tiers — this node down to the level-0 leaf — each consume
  // each_level_bit_num index bits, and the total must fit in a 64-bit index.
  // This also keeps every `index >> (level * each_level_bit_num)` shift
  // strictly below 64 (shifting by >= the bit width is UB), which the
  // original `level * each_level_bit_num <= 64` bound did not.
  static_assert(level >= 1 && each_level_bit_num >= 1 &&
                    (level + 1) * each_level_bit_num <= 64,
                "Invalid SegmentTable configuration.");
  using ChildTable = SegmentTable<T, level - 1, each_level_bit_num, Allocator>;
  using ChildAllocator = typename std::allocator_traits<
      Allocator>::template rebind_alloc<ChildTable>;
  using ChildAllocTraits = std::allocator_traits<ChildAllocator>;
  using ChildPtrAllocator = typename std::allocator_traits<
      Allocator>::template rebind_alloc<ChildTable *>;
  using ChildPtrAllocTraits = std::allocator_traits<ChildPtrAllocator>;
  enum class SegmentState : uint8_t { Unallocated, Allocating, Allocated };
  using StateAllocator = typename std::allocator_traits<
      Allocator>::template rebind_alloc<std::atomic<SegmentState>>;
  using StateAllocTraits = std::allocator_traits<StateAllocator>;

  ChildTable **segments;                     // child pointers (nullptr = absent)
  std::atomic<SegmentState> *segment_states; // per-slot allocation state
  ChildAllocator child_alloc;
  ChildPtrAllocator ptr_alloc;
  StateAllocator state_alloc;

  // Allocate and fully construct one child table. On constructor failure the
  // raw memory is released before rethrowing, so a pointer to unconstructed
  // memory is never stored in `segments` (the destructor would otherwise
  // call destroy on it).
  ChildTable *make_child() {
    ChildTable *child = ChildAllocTraits::allocate(child_alloc, 1);
    try {
      ChildAllocTraits::construct(child_alloc, child);
    } catch (...) {
      ChildAllocTraits::deallocate(child_alloc, child, 1);
      throw;
    }
    return child;
  }

public:
  SegmentTable() {
    size_t total_segments = 1ULL << each_level_bit_num;
    segments = ChildPtrAllocTraits::allocate(ptr_alloc, total_segments);
    try {
      segment_states = StateAllocTraits::allocate(state_alloc, total_segments);
    } catch (...) {
      // Don't leak the pointer array when the state array allocation fails.
      ChildPtrAllocTraits::deallocate(ptr_alloc, segments, total_segments);
      throw;
    }
    for (size_t i = 0; i < total_segments; ++i) {
      segments[i] = nullptr;
      StateAllocTraits::construct(state_alloc, &segment_states[i],
                                  SegmentState::Unallocated);
    }
  }
  ~SegmentTable() {
    size_t total_segments = 1ULL << each_level_bit_num;
    for (size_t i = 0; i < total_segments; ++i) {
      if (segments[i] != nullptr) {
        // Destroy, then deallocate, each child table.
        ChildAllocTraits::destroy(child_alloc, segments[i]);
        ChildAllocTraits::deallocate(child_alloc, segments[i], 1);
      }
      StateAllocTraits::destroy(state_alloc, &segment_states[i]);
    }
    // Release the pointer and state arrays themselves.
    ChildPtrAllocTraits::deallocate(ptr_alloc, segments, total_segments);
    StateAllocTraits::deallocate(state_alloc, segment_states, total_segments);
  }
  // Non-copyable: owns raw arrays and child tables (matches the leaf).
  SegmentTable(const SegmentTable &) = delete;
  SegmentTable &operator=(const SegmentTable &) = delete;

  // Single-threaded read-only lookup: nullptr when the child is absent.
  T *get_address_if_exists(size_t index) const {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    if (segments[segment_index] == nullptr) {
      return nullptr;
    }
    return segments[segment_index]->get_address_if_exists(
        index & ((1ULL << shift) - 1));
  }

  // Single-threaded lookup: allocates the child table on demand.
  T *get_address(size_t index) {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    if (segments[segment_index] == nullptr) {
      segments[segment_index] = make_child();
    }
    return segments[segment_index]->get_address(index & ((1ULL << shift) - 1));
  }

  // Thread-safe lookup with lazy, state-machine-guarded child allocation.
  T *concurrent_get_address(size_t index) {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    size_t child_index = index & ((1ULL << shift) - 1);
    // Optimistic fast path: child already published.
    if (segment_states[segment_index].load(std::memory_order_acquire) ==
        SegmentState::Allocated) {
      return segments[segment_index]->concurrent_get_address(child_index);
    }
    SegmentState expected = SegmentState::Unallocated;
    if (segment_states[segment_index].compare_exchange_strong(
            expected, SegmentState::Allocating, std::memory_order_acq_rel,
            std::memory_order_acquire)) {
      // We won the race: allocate and publish the child.
      try {
        // make_child only returns a fully constructed child, so `segments`
        // never holds a pointer to raw memory even if the ctor throws.
        segments[segment_index] = make_child();
        segment_states[segment_index].store(SegmentState::Allocated,
                                            std::memory_order_release);
        segment_states[segment_index].notify_all();
      } catch (...) {
        // Allocation failed: roll the state back so others can retry.
        segment_states[segment_index].store(SegmentState::Unallocated,
                                            std::memory_order_release);
        segment_states[segment_index].notify_all();
        throw;
      }
    } else {
      // Lost the race: wait until the winner finishes (or fails).
      while (true) {
        SegmentState current =
            segment_states[segment_index].load(std::memory_order_acquire);
        if (current == SegmentState::Allocated) {
          break;
        }
        if (current == SegmentState::Unallocated) {
          // The previous allocator failed; retry the allocation from scratch.
          return concurrent_get_address(index);
        }
        segment_states[segment_index].wait(SegmentState::Allocating,
                                           std::memory_order_acquire);
      }
    }
    return segments[segment_index]->concurrent_get_address(child_index);
  }

  // Thread-safe read-only lookup: nullptr when the child is absent.
  T *concurrent_get_address_if_exists(size_t index) const {
    size_t shift = level * each_level_bit_num;
    size_t segment_index = (index >> shift);
    if (segment_states[segment_index].load(std::memory_order_acquire) !=
        SegmentState::Allocated) {
      return nullptr;
    }
    return segments[segment_index]->concurrent_get_address_if_exists(
        index & ((1ULL << shift) - 1));
  }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment