@Boostibot
Last active September 8, 2025 14:57
Lock free atomic string pool
#pragma once
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <assert.h>
#include <atomic>
#if defined(_MSC_VER)
#pragma warning(disable:4200) //disable warning on flexible array members
#endif
//Append only lock free atomic hash set of strings.
// After addition the caller obtains a unique id (pointer) that represents this string.
// Any further insertions are guaranteed to receive the same id.
//Lock free means there is a forward progress guarantee ensuring that at least one thread will
// make progress. It's impossible to get stuck, for example even if some thread gets killed.
//Insertion (assuming no migration) requires just a single CAS and a bunch of loads.
//Probably the biggest problem right now is that each string is separately allocated using malloc,
// which is fast and convenient, but results in a super slow deinit. This is usually not a problem as
// string pools are most often used globally (i.e. one per program). This can be solved by
// giving the StringPool its own heap or giving each thread its own arena where it allocates.
// For simplicity neither of those is implemented here.
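//A minimal usage sketch (illustrative only; aside from the StringPool members used,
// the names below are made up):
//
//   static StringPool string_pool;               //typically one global pool per program
//
//   StringPool::StringId id1 = string_pool.get_or_add("hello");
//   StringPool::StringId id2 = string_pool.get_or_add("hello", 5);
//   assert(id1 == id2);                          //same string -> same id, so ids compare with ==
//   assert(strcmp(id1, "hello") == 0);           //an id points at a NUL terminated copy of the string
//   assert(string_pool.get_length(id1) == 5);    //length (and hash) are stored just before the data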
struct StringPool {
enum {
COUNTERS = 16,
CACHE_LINE = 64,
MIGRATE_AT_ONCE = 64,
INITIAL_CAPACITY = 256,
};
struct String {
uint32_t hash;
uint32_t size;
char data[];
};
struct Version {
uint64_t capacity;
std::atomic<Version*> next;
std::atomic<uint64_t> migrated;
std::atomic<String*> slots[];
};
struct Counter {
std::atomic<uint64_t> val;
uint8_t padding[CACHE_LINE - sizeof(uint64_t)];
};
using StringId = const char*;
static inline String* MIGRATED_SLOT = (String*) 1;
mutable std::atomic<Version*> version = NULL;
mutable std::atomic<Version*> first_version = NULL;
mutable std::atomic<Counter*> counters = NULL;
StringPool() = default;
~StringPool() { free_all(); }
StringPool(StringPool const&) = delete;
StringPool& operator=(StringPool const&) = delete;
StringPool& operator=(StringPool && other) {
//not thread safe!
this->version = other.version.exchange(this->version);
this->first_version = other.first_version.exchange(this->first_version);
this->counters = other.counters.exchange(this->counters);
return *this;
}
inline uint32_t get_length(StringId id) const {
String* slot = ((String*) (void*) id) - 1;
return slot->size;
}
inline uint32_t get_hash(StringId id) const {
String* slot = ((String*) (void*) id) - 1;
return slot->hash;
}
inline uint64_t count() const {
uint64_t count = 0;
Counter* counters = this->counters.load();
if(counters)
for(uint32_t i = 0; i < COUNTERS; i++)
count += counters[i].val.load();
return count;
}
inline StringId get_or_add(const char* key) {
return get_or_add(key, strlen(key));
}
StringId get_or_add(const char* key, size_t len) {
assert(len <= UINT32_MAX);
uint32_t was_added = 0;
uint32_t hash = hash32_murmur(key, len, 0);
Counter* counters = get_or_init_counters();
Version* version = get_or_init_version();
String* out = get_or_add_from_version(version, counters, NULL, key, (uint32_t) len, hash, &was_added);
migrate(version, counters);
return out->data;
}
void free_all() {
//Perform cleanup and quite a lot of asserts
Counter* counters = this->counters.load();
Version* version = this->version.load();
Version* first_version = this->first_version.load();
if(version) {
uint64_t slots_count = 0;
for(uint64_t i = 0; i < version->capacity; i++) {
String* slot = version->slots[i].load();
if(slot) {
assert(slot != MIGRATED_SLOT);
slots_count += 1;
free(slot);
}
}
for(Version* ver = first_version; ver != version; ver = ver->next.load())
assert(ver->migrated.load() == ver->capacity);
assert(counters);
uint64_t counter_count = 0;
for(uint64_t i = 0; i < COUNTERS; i++)
counter_count += counters[i].val.load();
assert(counter_count == slots_count); (void) counter_count; (void) slots_count;
}
free(counters);
for(Version* ver = first_version; ver; ) {
Version* next = ver->next.load();
free(ver);
ver = next;
}
memset((void*) this, 0, sizeof *this);
}
static uint32_t hash32_murmur(const void* key, int64_t size, uint32_t seed) {
assert((key != NULL || size == 0) && size >= 0);
const int r = 24;
const uint32_t magic = 0x5bd1e995;
const uint8_t* data = (const uint8_t*)key;
const uint8_t* end = data + size;
uint32_t hash = seed ^ ((uint32_t) size * magic);
for(; data < end - 3; data += 4) {
uint32_t read = 0;
memcpy(&read, data, sizeof read);
read *= magic;
read ^= read >> r;
read *= magic;
hash *= magic;
hash ^= read;
}
switch(size & 3) {
case 3: hash ^= data[2] << 16;
case 2: hash ^= data[1] << 8;
case 1: hash ^= data[0];
hash *= magic;
};
hash ^= hash >> 13;
hash *= magic;
hash ^= hash >> 15;
return hash;
}
private:
String* get_or_add_from_version(
Version* version, Counter* counters, String* provided_slot,
const char* key, uint32_t len, uint32_t hash, uint32_t* added_counter) const {
//Go through the versions starting from head.
//Try to find or add into each one.
// If it's full go on to the next one.
// If there is no next one create one ourselves.
String my_header = {hash, len};
String* my_slot = provided_slot;
Version* ver = version;
//This function is also called from migration,
// during which we are merely copying slots between different
// versions. Because we are not adding any new slots we don't
// increase counters nor deallocate if we already found an entry.
bool is_called_from_migration = provided_slot != NULL;
for(;;) {
//Quadratic probe the hashtable
uint64_t iter_to = ver->capacity/COUNTERS;
uint64_t mask = ver->capacity - 1;
uint64_t i = hash & mask;
for(uint64_t iter = 0; iter < iter_to; i = (i + ++iter) & mask) {
repeat:
String* slot = ver->slots[i].load();
//If empty slot try to add ourselves.
// If we don't have an allocation yet make one.
// If someone was faster in claiming this spot then just repeat this probe iteration.
if(slot == NULL) {
if(my_slot == NULL) {
assert(is_called_from_migration == false);
my_slot = allocate_slot(key, len, hash);
}
if(ver->slots[i].compare_exchange_strong(slot, my_slot) == false)
goto repeat;
if(is_called_from_migration == false) {
//Aim for 75% fullness spread amongst multiple counters
// thus we expect a single counter to be at most
// capacity*0.75 / COUNTERS
uint64_t counter_index = hash % COUNTERS;
uint64_t counter_max = ver->capacity*3/4 / COUNTERS;
uint64_t counter_val = counters[counter_index].val.fetch_add(1);
if(counter_val >= counter_max) {
Version* next_gen = ver->next.load();
if(next_gen == NULL) {
Version* new_gen = create_version(ver->capacity*2);
if(ver->next.compare_exchange_strong(next_gen, new_gen) == false)
free(new_gen);
}
}
}
*added_counter += 1;
return my_slot;
}
//if someone else started migration then abandon search in this
// version and go to the next one
if(slot == MIGRATED_SLOT)
break;
if(memcmp(slot, &my_header, sizeof(uint32_t)*2) == 0) {
if(memcmp(slot->data, key, len) == 0) {
//back off from my allocation if I have one
// (and it's not called during migration)
if(my_slot && is_called_from_migration == false)
free(my_slot);
return slot;
}
}
}
Version* next_gen = ver->next.load();
if(next_gen == NULL) {
Version* new_gen = create_version(ver->capacity*2);
if(ver->next.compare_exchange_strong(next_gen, new_gen) == false)
free(new_gen);
next_gen = ver->next.load();
}
ver = next_gen;
}
}
void migrate(Version* head, Counter* counters) const {
//if there is somewhere to migrate to, help migrate every entry out of it.
//We migrate MIGRATE_AT_ONCE slots at once to not spam the shared cache lines
// with FAA too much.
//To make sure nobody adds anything to the space we have already migrated from
// (there might be empty slots) we fill those empty slots with MIGRATED_SLOT.
// This slot value has a special meaning. Anyone who wants to add something
// will either add it in a slot not yet filled with MIGRATED_SLOT
// (thus we will migrate it later) or will encounter MIGRATED_SLOT and be forced to
// move to a newer version.
Version* nextver = NULL;
for(Version* currver = head; (nextver = currver->next.load()); currver = nextver) {
//fast way out if we are late
if(currver->migrated.load() >= currver->capacity) {
if(this->version.load() == currver)
this->version.compare_exchange_strong(currver, nextver);
continue;
}
//Migrate in random order to not make migrating threads clash.
// To do this generate a random starting position using the stack
// pointer as thread specific seed (and multiply such that two starting
// positions will be MIGRATE_AT_ONCE apart)
uint64_t mask = currver->capacity - 1;
uint32_t seed = (uint32_t) (uintptr_t) &mask;
uint32_t migrated_by_us = 0;
uint64_t iterate_from = hash32_murmur(&mask, sizeof mask, seed)*MIGRATE_AT_ONCE;
for(uint64_t iterated = 0;; ) {
uint64_t i = (iterate_from + iterated) & mask;
String* my_slot = currver->slots[i].load();
//Migrate slots not yet migrated, fill empty slots with MIGRATED_SLOT.
//Count the number of migrated slots by this thread.
if(my_slot != MIGRATED_SLOT) {
if(my_slot) {
String* added = get_or_add_from_version(nextver, counters, my_slot,
my_slot->data, my_slot->size, my_slot->hash, &migrated_by_us);
assert(added == my_slot); (void) added;
}
else {
if(currver->slots[i].compare_exchange_strong(my_slot, MIGRATED_SLOT) == false)
continue;
migrated_by_us += 1;
}
}
//Every once in a while update currver->migrated with the number of slots migrated.
//If together all threads have migrated all values, attempt to move the current active
// version forward.
//If a thread gets stuck/dies during migration and never adds its contribution to
// currver->migrated, the other threads would get stuck in this loop. To alleviate this
// problem and make this algorithm fully lock free we provide an alternative check:
// if we have iterated over every slot then we either observed everything migrated or
// migrated it ourselves. Either way we can safely move to the next version.
// This check is very slow and in practice never fires.
iterated++;
if(iterated % MIGRATE_AT_ONCE == 0) {
uint64_t migrated = currver->migrated.fetch_add(migrated_by_us) + migrated_by_us;
migrated_by_us = 0;
if(migrated >= currver->capacity || iterated >= currver->capacity) {
// Try to move the version forward.
if(migrated != currver->capacity)
currver->migrated.store(currver->capacity);
if(this->version.load() == currver)
this->version.compare_exchange_strong(currver, nextver);
break;
}
}
}
}
}
Counter* get_or_init_counters() const {
Counter* counters = this->counters.load();
if(counters == NULL) {
Counter* new_counters = (Counter*) calloc(1, sizeof(Counter)*COUNTERS);
if(this->counters.compare_exchange_strong(counters, new_counters))
counters = new_counters;
else {
//someone else initialized them first, free ours and use theirs
free(new_counters);
counters = this->counters.load();
}
}
return counters;
}
Version* get_or_init_version() const {
Version* version = this->version.load();
if(version == NULL) {
Version* new_version = create_version(INITIAL_CAPACITY);
if(this->version.compare_exchange_strong(version, new_version)) {
this->first_version.store(new_version);
version = new_version;
}
else {
//someone else initialized it first, free ours and use theirs
free(new_version);
version = this->version.load();
}
}
return version;
}
inline static Version* create_version(uint64_t capacity) {
Version* new_gen = (Version*) calloc(1, sizeof(Version) + capacity*sizeof(void*));
assert((capacity & (capacity - 1)) == 0 && "must be power of two");
assert(new_gen && "malloc shall not fail");
new_gen->capacity = capacity;
return new_gen;
}
inline static String* allocate_slot(const char* key, uint32_t size, uint32_t hash) {
String* header = (String*) malloc(sizeof(String) + size + 1);
assert(header && "malloc shall not fail");
header->hash = hash;
header->size = size;
char* my_key = (char*) (header + 1);
memcpy(my_key, key, size);
my_key[size] = '\0';
return header;
}
};
#if defined(_MSC_VER)
#pragma warning(default:4200)
#endif
//Testing + benchmarking of the string pool
#include "string_pool.h"
#include <stdio.h>
#include <thread>
#include <vector>
#include <unordered_map>
#include <algorithm>
#define TEST(x) (!(x) ? (printf("TEST(%s) failed!\n", #x), abort()) : (void) 0)
using StringId = StringPool::StringId;
void test_unit() {
//We try to test at least a couple of rehashes so we
// test on quite a bit of data
enum {TEST_COUNT = 1 + StringPool::INITIAL_CAPACITY*4};
struct TestItem {
char str[16];
StringId id;
};
StringPool pool;
std::vector<TestItem> items(TEST_COUNT);
for(uint64_t i = 0; i < items.size(); i++) {
snprintf(items[i].str, sizeof items[i].str, "%llu", (unsigned long long) i);
}
for(TestItem& item : items) {
uint32_t len = (uint32_t) strlen(item.str);
item.id = pool.get_or_add(item.str, len);
TEST(pool.get_length(item.id) == len);
TEST(strcmp(item.id, item.str) == 0);
}
TEST(pool.count() == TEST_COUNT);
for(TestItem& item : items) {
StringId id = pool.get_or_add(item.str);
TEST(id == item.id);
}
TEST(pool.count() == TEST_COUNT);
//Test that all ids are mutually distinct
for(uint64_t k = 0; k < items.size(); k++) {
TestItem const& outer = items[k];
for(uint64_t i = 0; i < items.size(); i++) {
if(i != k) {
TestItem const& inner = items[i];
TEST(outer.id != inner.id);
}
}
}
pool.free_all();
TEST(pool.count() == 0);
}
//Test adding from multiple threads with sharing - that is, when sharing = 4
// the same value will be inserted 4 times, potentially from different threads.
void test_stress(uint32_t num_threads, uint32_t sharing, double seconds) {
struct Entry {
uint64_t counter;
StringId id;
};
StringPool pool;
std::atomic<uint32_t> started = 0;
std::atomic<uint32_t> run = 0;
std::atomic<uint32_t> stopped = 0;
std::atomic<uint64_t> shared_counter = 0;
std::vector<std::vector<Entry>> entries(num_threads);
for(uint32_t i = 0; i < num_threads; i++) {
std::vector<Entry>* my_entries = &entries[i];
std::thread([&, my_entries]{
{
started.fetch_add(1);
while(run == 0);
while(run == 1) {
uint64_t counter = shared_counter.fetch_add(1);
char buffer[16] = {0};
snprintf(buffer, sizeof buffer, "%llu", (unsigned long long) counter/sharing);
Entry entry = {0};
entry.counter = counter;
entry.id = pool.get_or_add(buffer);
TEST(strcmp(entry.id, buffer) == 0);
my_entries->push_back(entry);
}
}
stopped.fetch_add(1);
}).detach();
}
while(started != num_threads);
run = 1;
std::this_thread::sleep_for(std::chrono::milliseconds((int64_t) (seconds * 1e3)));
run = 2;
while(stopped != num_threads);
//Now test that:
// 1) Each entry was added exactly once
// 2) All counter values are present
// 3) Handles represent what they should and point to valid C strings
// 4) All new values will be found rather than added
std::vector<Entry> all_entries;
for(uint64_t i = 0; i < num_threads; i++) {
std::vector<Entry> curr = std::move(entries[i]);
for(uint64_t j = 0; j < curr.size(); j++)
all_entries.push_back(curr[j]);
}
entries.resize(0);
std::sort(all_entries.begin(), all_entries.end(), [](Entry const& a, Entry const& b){
return a.counter < b.counter;
});
TEST(all_entries.size() == shared_counter);
uint64_t till = shared_counter/sharing;
for(uint64_t i = 0; i < till; i ++) {
Entry main_entry = all_entries[i*sharing];
char buffer[16] = {0};
snprintf(buffer, sizeof buffer, "%llu", (unsigned long long) i);
TEST(strcmp(main_entry.id, buffer) == 0);
for(uint64_t k = 0; k < sharing; k++) {
uint64_t at = k + i*sharing;
if(at >= shared_counter)
break;
Entry entry = all_entries[at];
TEST(entry.id == main_entry.id);
TEST(entry.counter == at);
}
}
pool.free_all();
}
static inline uint64_t random_splitmix(uint64_t* state)
{
uint64_t z = (*state += 0x9e3779b97f4a7c15);
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
return z ^ (z >> 31);
}
void benchmark(uint32_t num_threads, double seconds) {
//insertions
{
StringPool pool;
std::atomic<uint64_t> total_iters = 0;
std::atomic<uint32_t> started = 0;
std::atomic<uint32_t> run = 0;
std::atomic<uint32_t> stopped = 0;
for(uint32_t i = 0; i < num_threads; i++) {
std::thread([&, i]{
uint64_t iters = 0;
uint64_t splitmix_state = (uint64_t) (i+1)*74174918947891;
started.fetch_add(1);
while(run == 0);
while(run == 1) {
uint64_t key = random_splitmix(&splitmix_state);
pool.get_or_add((char*) (void*) &key, sizeof key);
iters += 1;
}
total_iters.fetch_add(iters);
stopped.fetch_add(1);
}).detach();
}
while(started != num_threads);
run = 1;
std::this_thread::sleep_for(std::chrono::milliseconds((int64_t) (seconds * 1e3)));
run = 2;
while(stopped != num_threads);
double Mops = (double) total_iters / 1e6;
printf("insertion: num_threads=%u throughput=%.2lfMops/s (Ops:%.2lfM time:%.2lfs)\n",
num_threads, Mops/seconds, Mops, seconds);
//uint64_t cap = pool.version.load()->capacity;
//uint64_t pow2 = 0;
//while((1ull << pow2) < cap) pow2+=1;
//printf("table cap: 2^%llu used:%lf%%\n", pow2, 100.0*total_iters/cap);
}
//lookups
{
StringPool pool;
uint64_t range = 1 << 20;
for(uint64_t key = 0; key < range; key++)
pool.get_or_add((char*) (void*) &key, sizeof key);
std::atomic<uint64_t> total_iters = 0;
std::atomic<uint32_t> started = 0;
std::atomic<uint32_t> run = 0;
std::atomic<uint32_t> stopped = 0;
for(uint32_t i = 0; i < num_threads; i++) {
std::thread([&, i]{
uint64_t iters = 0;
uint64_t splitmix_state = (uint64_t) (i+1)*74174918947891;
started.fetch_add(1);
while(run == 0);
while(run == 1) {
uint64_t key = random_splitmix(&splitmix_state) & (range - 1);
pool.get_or_add((char*) (void*) &key, sizeof key);
iters += 1;
}
total_iters.fetch_add(iters);
stopped.fetch_add(1);
}).detach();
}
while(started != num_threads);
run = 1;
std::this_thread::sleep_for(std::chrono::milliseconds((int64_t) (seconds * 1e3)));
run = 2;
while(stopped != num_threads);
double Mops = (double) total_iters / 1e6;
printf("lookup: num_threads=%u throughput=%.2lfMops/s (Ops:%.2lfM time:%.2lfs)\n",
num_threads, Mops/seconds, Mops, seconds);
}
}
int main() {
printf("unit testing ...\n");
test_unit();
printf("benchmarking ...\n");
for(uint32_t threads = 1;; threads += 2) {
if(threads > std::thread::hardware_concurrency())
threads = std::thread::hardware_concurrency();
benchmark(threads, 1);
if(threads == std::thread::hardware_concurrency())
break;
}
printf("stress testing ...\n");
for(uint32_t sharing = 1; sharing <= std::thread::hardware_concurrency() + 2; sharing += 3)
for(uint32_t threads = 1; threads <= std::thread::hardware_concurrency(); threads += 1) {
printf("testing with threads=%u sharing=%u ...\n", threads, sharing);
test_stress(threads, sharing, 0.5);
}
}
@Boostibot (Author)

Updates:

  • Migration is fully lock free. Previously, if a migrating thread died, the migration might never complete. The same mechanism is still used, but a slow fallback now guarantees that migration always eventually completes.
  • More resilient to hash collisions. Previously we would rehash once the number of probes got too high. This worked well in practice but made an (accidental) denial of service attack very easy. Now we properly count filled slots using distributed counters (a standalone sketch of the idea follows below). Another upside is that we can easily get the (approximate) count of strings in the pool.
  • Removed the fragment from the slot pointer. It only marginally improved performance and made the code harder to read and less portable.
  • Smaller size. We now allocate bigger versions of the table on demand and add them to a linked list. Previously we had a big array of pointers on the stack.
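
The distributed counter scheme boils down to the standalone sketch below (illustrative only, not a copy of the pool's code; COUNTERS and the 75% fill target mirror the constants used in the header, the function names are made up):

#include <atomic>
#include <stdint.h>

enum { COUNTERS = 16 };
static std::atomic<uint64_t> counters[COUNTERS];

//Each successful insert bumps one counter chosen by hash, so concurrent inserters
// mostly touch different cache lines. Growth triggers once any single counter
// crosses its share of the 75% fill target.
static bool bump_and_check_growth(uint32_t hash, uint64_t capacity) {
uint64_t counter_max = capacity*3/4 / COUNTERS;
uint64_t counter_val = counters[hash % COUNTERS].fetch_add(1);
return counter_val >= counter_max; //true -> time to allocate the next, bigger version
}

//Summing all counters gives the number of strings (approximate while inserts are racing).
static uint64_t approximate_count() {
uint64_t total = 0;
for(uint32_t i = 0; i < COUNTERS; i++)
total += counters[i].load();
return total;
}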
