Lock-free atomic string pool
| #pragma once | |
| #include <stdlib.h> | |
| #include <stdint.h> | |
| #include <string.h> | |
| #include <assert.h> | |
| #include <atomic> | |
| #if defined(_MSC_VER) | |
| #pragma warning(disable:4200) //disable warning on flexible array members | |
| #endif | |
| //Append-only lock-free atomic hash set of strings. | |
| // After addition the caller obtains a unique id (pointer) that represents this string. | |
| // Any further insertions of the same string are guaranteed to receive the same id. | |
| //Lock-free means there is a forward progress guarantee ensuring that at least one thread will | |
| // make progress. It's impossible to get stuck, for example even if some thread gets killed. | |
| //Insertion (assuming no migration) requires just a single CAS and a bunch of loads. | |
| //Probably the biggest problem right now is that each string is separately allocated using malloc, | |
| // which is fast and convenient, but results in super slow deinit. This is usually not a problem as | |
| // string pools are most often used globally (i.e. one per program). This can be solved by | |
| // giving the StringPool its own heap or giving each thread its own arena where it allocates. | |
| // For simplicity neither of those is implemented here. | |
| struct StringPool { | |
| enum { | |
| COUNTERS = 16, | |
| CACHE_LINE = 64, | |
| MIGRATE_AT_ONCE = 64, | |
| INITIAL_CAPACITY = 256, | |
| }; | |
| struct String { | |
| uint32_t hash; | |
| uint32_t size; | |
| char data[]; | |
| }; | |
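| //One Version is one generation of the hash table. When it gets too full, a bigger | |
| // next version is linked behind it and existing entries are migrated into it (see migrate() below). | |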
| struct Version { | |
| uint64_t capacity; | |
| std::atomic<Version*> next; | |
| std::atomic<uint64_t> migrated; | |
| std::atomic<String*> slots[]; | |
| }; | |
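| //Striped insertion counter padded to a full cache line to avoid false sharing. | |
| // The pool keeps COUNTERS of these and each insert picks one based on the hash. | |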
| struct Counter { | |
| std::atomic<uint64_t> val; | |
| uint8_t padding[CACHE_LINE - sizeof(uint64_t)]; | |
| }; | |
| using StringId = const char*; | |
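| //Sentinel stored into empty slots of an old version during migration so that no new | |
| // string can be inserted behind the migration (see migrate() below). | |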
| static inline String* MIGRATED_SLOT = (String*) 1; | |
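| //version is the table where get_or_add starts searching, first_version is kept so that | |
| // free_all can walk and free every generation, counters hold the striped element counts. | |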
| mutable std::atomic<Version*> version = NULL; | |
| mutable std::atomic<Version*> first_version = NULL; | |
| mutable std::atomic<Counter*> counters = NULL; | |
| StringPool() = default; | |
| ~StringPool() { free_all(); } | |
| StringPool(StringPool const&) = delete; | |
| StringPool& operator=(StringPool const&) = delete; | |
| StringPool& operator=(StringPool && other) { | |
| //not thread safe! | |
| this->version = other.version.exchange(this->version); | |
| this->first_version = other.first_version.exchange(this->first_version); | |
| this->counters = other.counters.exchange(this->counters); | |
| return *this; | |
| } | |
| inline uint32_t get_length(StringId id) const { | |
| String* slot = ((String*) (void*) id) - 1; | |
| return slot->size; | |
| } | |
| inline uint32_t get_hash(StringId id) const { | |
| String* slot = ((String*) (void*) id) - 1; | |
| return slot->hash; | |
| } | |
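| //Sums all striped counters. The result is a snapshot and is exact only when no insertions are in flight. | |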
| inline uint64_t count() const { | |
| uint64_t count = 0; | |
| Counter* counters = this->counters.load(); | |
| if(counters) | |
| for(uint32_t i = 0; i < COUNTERS; i++) | |
| count += counters[i].val.load(); | |
| return count; | |
| } | |
| inline StringId get_or_add(const char* key) { | |
| return get_or_add(key, strlen(key)); | |
| } | |
| StringId get_or_add(const char* key, size_t len) { | |
| assert(len <= UINT32_MAX); | |
| uint32_t was_added = 0; | |
| uint32_t hash = hash32_murmur(key, len, 0); | |
| Counter* counters = get_or_init_counters(); | |
| Version* version = get_or_init_version(); | |
| String* out = get_or_add_from_version(version, counters, NULL, key, (uint32_t) len, hash, &was_added); | |
| migrate(version, counters); | |
| return out->data; | |
| } | |
| void free_all() { | |
| //Perform cleanup and quite a lot of asserts | |
| Counter* counters = this->counters.load(); | |
| Version* version = this->version.load(); | |
| Version* first_version = this->first_version.load(); | |
| if(version) { | |
| uint64_t slots_count = 0; | |
| for(uint64_t i = 0; i < version->capacity; i++) { | |
| String* slot = version->slots[i].load(); | |
| if(slot) { | |
| assert(slot != MIGRATED_SLOT); | |
| slots_count += 1; | |
| free(slot); | |
| } | |
| } | |
| for(Version* ver = first_version; ver != version; ver = ver->next.load()) | |
| assert(ver->migrated.load() == ver->capacity); | |
| assert(counters); | |
| uint64_t counter_count = 0; | |
| for(uint64_t i = 0; i < COUNTERS; i++) | |
| counter_count += counters[i].val.load(); | |
| assert(counter_count == slots_count); (void) counter_count; (void) slots_count; | |
| } | |
| free(counters); | |
| for(Version* ver = first_version; ver; ) { | |
| Version* next = ver->next.load(); | |
| free(ver); | |
| ver = next; | |
| } | |
| memset((void*) this, 0, sizeof *this); | |
| } | |
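| //32-bit MurmurHash2. Used both to hash the interned strings and to derive per-thread migration start positions. | |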
| static uint32_t hash32_murmur(const void* key, int64_t size, uint32_t seed) { | |
| assert((key != NULL || size == 0) && size >= 0); | |
| const int r = 24; | |
| const uint32_t magic = 0x5bd1e995; | |
| const uint8_t* data = (const uint8_t*)key; | |
| const uint8_t* end = data + size; | |
| uint32_t hash = seed ^ ((uint32_t) size * magic); | |
| for(; data < end - 3; data += 4) { | |
| uint32_t read = 0; | |
| memcpy(&read, data, sizeof read); | |
| read *= magic; | |
| read ^= read >> r; | |
| read *= magic; | |
| hash *= magic; | |
| hash ^= read; | |
| } | |
| switch(size & 3) { | |
| case 3: hash ^= data[2] << 16; | |
| case 2: hash ^= data[1] << 8; | |
| case 1: hash ^= data[0]; | |
| hash *= magic; | |
| }; | |
| hash ^= hash >> 13; | |
| hash *= magic; | |
| hash ^= hash >> 15; | |
| return hash; | |
| } | |
| private: | |
| String* get_or_add_from_version( | |
| Version* version, Counter* counters, String* provided_slot, | |
| const char* key, uint32_t len, uint32_t hash, uint32_t* added_counter) const { | |
| //Go through the versions starting from the head. | |
| //Try to find or add into each one. | |
| // If it's full, go on to the next one. | |
| // If there is no next one, create one ourselves. | |
| String my_header = {hash, len}; | |
| String* my_slot = provided_slot; | |
| Version* ver = version; | |
| //This function is also called from migration, | |
| // during which we are merely copying slots between different | |
| // versions. Because we are not adding any new slots we don't | |
| // increase counters nor deallocate if we already found an entry. | |
| bool is_called_from_migration = provided_slot != NULL; | |
| for(;;) { | |
| //Quadratic probe the hashtable | |
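| // (at most capacity/COUNTERS slots are probed; if none of them matches or is free we fall through to the next version) | |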
| uint64_t iter_to = ver->capacity/COUNTERS; | |
| uint64_t mask = ver->capacity - 1; | |
| uint64_t i = hash & mask; | |
| for(uint64_t iter = 0; iter < iter_to; i = (i + ++iter) & mask) { | |
| repeat: | |
| String* slot = ver->slots[i].load(); | |
| //If the slot is empty, try to add ourselves. | |
| // If we don't have an allocation yet, make one. | |
| // If someone was faster in claiming this spot, just repeat this probe iteration. | |
| if(slot == NULL) { | |
| if(my_slot == NULL) { | |
| assert(is_called_from_migration == false); | |
| my_slot = allocate_slot(key, len, hash); | |
| } | |
| if(ver->slots[i].compare_exchange_strong(slot, my_slot) == false) | |
| goto repeat; | |
| if(is_called_from_migration == false) { | |
| //Aim for 75% fullness spread amongst multiple counters | |
| // thus we expect a single counter to be at most | |
| // capacity*0.75 / COUNTERS | |
| uint64_t counter_index = hash % COUNTERS; | |
| uint64_t counter_max = ver->capacity*3/4 / COUNTERS; | |
| uint64_t counter_val = counters[counter_index].val.fetch_add(1); | |
| if(counter_val >= counter_max) { | |
| Version* next_gen = ver->next.load(); | |
| if(next_gen == NULL) { | |
| Version* new_gen = create_version(ver->capacity*2); | |
| if(ver->next.compare_exchange_strong(next_gen, new_gen) == false) | |
| free(new_gen); | |
| } | |
| } | |
| } | |
| *added_counter += 1; | |
| return my_slot; | |
| } | |
| //If someone else started migration, abandon the search in this | |
| // version and go to the next one | |
| if(slot == MIGRATED_SLOT) | |
| break; | |
| if(memcmp(slot, &my_header, sizeof(uint32_t)*2) == 0) { | |
| if(memcmp(slot->data, key, len) == 0) { | |
| //backoff from my allocation if I have it | |
| // (and its not called during migration) | |
| if(my_slot && is_called_from_migration == false) | |
| free(my_slot); | |
| return slot; | |
| } | |
| } | |
| } | |
| Version* next_gen = ver->next.load(); | |
| if(next_gen == NULL) { | |
| Version* new_gen = create_version(ver->capacity*2); | |
| if(ver->next.compare_exchange_strong(next_gen, new_gen) == false) | |
| free(new_gen); | |
| next_gen = ver->next.load(); | |
| } | |
| ver = next_gen; | |
| } | |
| } | |
| void migrate(Version* head, Counter* counters) const { | |
| //If there is somewhere to migrate to, help migrate every entry out of this version. | |
| //We migrate MIGRATE_AT_ONCE slots at once to not spam the shared cache lines | |
| // with FAA too much. | |
| //To make sure nobody adds anything to the space we have already migrated from | |
| // (there might be empty slots), we fill the empty slots with MIGRATED_SLOT. | |
| // This slot value has special meaning. Anyone who wants to add something | |
| // will either add it in a space not yet covered by MIGRATED_SLOT | |
| // (thus we will migrate it later) or will encounter it and be forced to | |
| // move to the next version. | |
| Version* nextver = NULL; | |
| for(Version* currver = head; (nextver = currver->next.load()); currver = nextver) { | |
| //fast way out if we are late | |
| if(currver->migrated.load() >= currver->capacity) { | |
| if(this->version.load() == currver) | |
| this->version.compare_exchange_strong(currver, nextver); | |
| continue; | |
| } | |
| //Migrate in random order so that migrating threads don't clash. | |
| // To do this, generate a random starting position using the stack | |
| // pointer as a thread-specific seed (multiplied by MIGRATE_AT_ONCE so that | |
| // two starting positions are at least MIGRATE_AT_ONCE apart) | |
| uint64_t mask = currver->capacity - 1; | |
| uint32_t seed = (uint32_t) (uintptr_t) &mask; | |
| uint32_t migrated_by_us = 0; | |
| uint64_t iterate_from = hash32_murmur(&mask, sizeof mask, seed)*MIGRATE_AT_ONCE; | |
| for(uint64_t iterated = 0;; ) { | |
| uint64_t i = (iterate_from + iterated) & mask; | |
| String* my_slot = currver->slots[i].load(); | |
| //Migrate slots not yet migrated, fill empty slots with MIGRATED_SLOT. | |
| //Count the number of migrated slots by this thread. | |
| if(my_slot != MIGRATED_SLOT) { | |
| if(my_slot) { | |
| String* added = get_or_add_from_version(nextver, counters, my_slot, | |
| my_slot->data, my_slot->size, my_slot->hash, &migrated_by_us); | |
| assert(added == my_slot); (void) added; | |
| } | |
| else { | |
| if(currver->slots[i].compare_exchange_strong(my_slot, MIGRATED_SLOT) == false) | |
| continue; | |
| migrated_by_us += 1; | |
| } | |
| } | |
| //Every once in a while update currver->migrated with the number of slots migrated. | |
| //If together all threads have migrated all values attempt to move the current active | |
| // version forward. | |
| //If a thread gets stuck/dies during migration and never adds its contribution to | |
| // currver->migrated, the other threads would get stuck in this loop. To alleviate this | |
| // problem and make this algorithm fully lock free we provide an alternative check. | |
| // If we have iterated over every value then we either observed everything migrated or | |
| // migrated it ourselves. Either way we can safely move to the next version. | |
| // This check is very slow and in practice never fires. | |
| iterated++; | |
| if(iterated % MIGRATE_AT_ONCE == 0) { | |
| uint64_t migrated = currver->migrated.fetch_add(migrated_by_us) + migrated_by_us; | |
| migrated_by_us = 0; | |
| if(migrated >= currver->capacity || iterated >= currver->capacity) { | |
| // Try to move the version forward. | |
| if(migrated != currver->capacity) | |
| currver->migrated.store(currver->capacity); | |
| if(this->version.load() == currver) | |
| this->version.compare_exchange_strong(currver, nextver); | |
| break; | |
| } | |
| } | |
| } | |
| } | |
| } | |
| Counter* get_or_init_counters() const { | |
| Counter* counters = this->counters.load(); | |
| if(counters == NULL) { | |
| Counter* new_counters = (Counter*) calloc(1, sizeof(Counter)*COUNTERS); | |
| if(this->counters.compare_exchange_strong(counters, new_counters)) | |
| counters = new_counters; | |
| else | |
| counters = this->counters.load(); | |
| } | |
| return counters; | |
| } | |
| Version* get_or_init_version() const { | |
| Version* version = this->version.load(); | |
| if(version == NULL) { | |
| Version* new_version = create_version(INITIAL_CAPACITY); | |
| if(this->version.compare_exchange_strong(version, new_version)) { | |
| this->first_version.store(new_version); | |
| version = new_version; | |
| } | |
| else | |
| version = this->version.load(); | |
| } | |
| return version; | |
| } | |
| inline static Version* create_version(uint64_t capacity) { | |
| Version* new_gen = (Version*) calloc(1, sizeof(Version) + capacity*sizeof(void*)); | |
| assert((capacity & (capacity - 1)) == 0 && "must be power of two"); | |
| assert(new_gen && "malloc shall not fail"); | |
| new_gen->capacity = capacity; | |
| return new_gen; | |
| } | |
| inline static String* allocate_slot(const char* key, uint32_t size, uint32_t hash) { | |
| String* header = (String*) malloc(sizeof(String) + size + 1); | |
| assert(header && "malloc shall not fail"); | |
| header->hash = hash; | |
| header->size = size; | |
| char* my_key = (char*) (header + 1); | |
| memcpy(my_key, key, size); | |
| my_key[size] = '\0'; | |
| return header; | |
| } | |
| }; | |
| #if defined(_MSC_VER) | |
| #pragma warning(default:4200) | |
| #endif |
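A minimal usage sketch (not part of the gist itself). It assumes the listing above is saved as string_pool.h, which is also the name the test program below includes. Equal strings always intern to the same id, so ids can be compared with plain pointer equality.
| #include "string_pool.h" | |
| #include <cstdio> | |
| int main() { | |
| StringPool pool; | |
| StringPool::StringId a = pool.get_or_add("hello"); | |
| StringPool::StringId b = pool.get_or_add("hello"); | |
| //equal strings intern to the same pointer, so ids compare with == | |
| //an id also points to a NUL terminated copy of the string with cached length | |
| std::printf("same id: %d, length: %u, text: %s\n", (int) (a == b), pool.get_length(a), a); | |
| return 0; | |
| } | |
Because the pool is append-only, an id stays valid for the lifetime of the pool, so it can be stored and compared freely.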
| //Testing + benchmarking of the string pool | |
| #include "string_pool.h" | |
| #include <stdio.h> | |
| #include <thread> | |
| #include <vector> | |
| #include <unordered_map> | |
| #include <algorithm> | |
| #define TEST(x) (!(x) ? (printf("TEST(%s) failed!\n", #x), abort()) : (void) 0) | |
| using StringId = StringPool::StringId; | |
| void test_unit() { | |
| //We try to test at least a couple of rehashes so we | |
| // test on quite a bit of data | |
| enum {TEST_COUNT = 1 + StringPool::INITIAL_CAPACITY*4}; | |
| struct TestItem { | |
| char str[16]; | |
| StringId id; | |
| }; | |
| StringPool pool; | |
| std::vector<TestItem> items(TEST_COUNT); | |
| for(uint64_t i = 0; i < items.size(); i++) { | |
| snprintf(items[i].str, sizeof items[i].str, "%llu", (unsigned long long) i); | |
| } | |
| for(TestItem& item : items) { | |
| uint32_t len = (uint32_t) strlen(item.str); | |
| item.id = pool.get_or_add(item.str, len); | |
| TEST(pool.get_length(item.id) == len); | |
| TEST(strcmp(item.id, item.str) == 0); | |
| } | |
| TEST(pool.count() == TEST_COUNT); | |
| for(TestItem& item : items) { | |
| StringId id = pool.get_or_add(item.str); | |
| TEST(id == item.id); | |
| } | |
| TEST(pool.count() == TEST_COUNT); | |
| //Test that ids of distinct strings are distinct | |
| for(uint64_t k = 0; k < items.size(); k++) { | |
| TestItem const& outer = items[k]; | |
| for(uint64_t i = 0; i < items.size(); i++) { | |
| if(i != k) { | |
| TestItem const& inner = items[i]; | |
| TEST(outer.id != inner.id); | |
| } | |
| } | |
| } | |
| pool.free_all(); | |
| TEST(pool.count() == 0); | |
| } | |
| //Test adding from multiple threads with sharing - that is, when sharing = 4 | |
| // the same value will be inserted 4 times (and must always yield the same id). | |
| void test_stress(uint32_t num_threads, uint32_t sharing, double seconds) { | |
| struct Entry { | |
| uint64_t counter; | |
| StringId id; | |
| }; | |
| StringPool pool; | |
| std::atomic<uint32_t> started = 0; | |
| std::atomic<uint32_t> run = 0; | |
| std::atomic<uint32_t> stopped = 0; | |
| std::atomic<uint64_t> shared_counter = 0; | |
| std::vector<std::vector<Entry>> entries(num_threads); | |
| for(uint32_t i = 0; i < num_threads; i++) { | |
| std::vector<Entry>* my_entries = &entries[i]; | |
| std::thread([&, my_entries]{ | |
| { | |
| started.fetch_add(1); | |
| while(run == 0); | |
| while(run == 1) { | |
| uint64_t counter = shared_counter.fetch_add(1); | |
| char buffer[16] = {0}; | |
| snprintf(buffer, sizeof buffer, "%llu", (unsigned long long) counter/sharing); | |
| Entry entry = {0}; | |
| entry.counter = counter; | |
| entry.id = pool.get_or_add(buffer); | |
| TEST(strcmp(entry.id, buffer) == 0); | |
| my_entries->push_back(entry); | |
| } | |
| } | |
| stopped.fetch_add(1); | |
| }).detach(); | |
| } | |
| while(started != num_threads); | |
| run = 1; | |
| std::this_thread::sleep_for(std::chrono::milliseconds((int64_t) (seconds * 1e3))); | |
| run = 2; | |
| while(stopped != num_threads); | |
| //Now test that: | |
| // 1) Each entry was added exactly once | |
| // 2) All counter values are present | |
| // 3) Handles represent what they should and point to valid C strings | |
| // 4) All new values will be found rather than added | |
| std::vector<Entry> all_entries; | |
| for(uint64_t i = 0; i < num_threads; i++) { | |
| std::vector<Entry> curr = std::move(entries[i]); | |
| for(uint64_t j = 0; j < curr.size(); j++) | |
| all_entries.push_back(curr[j]); | |
| } | |
| entries.resize(0); | |
| std::sort(all_entries.begin(), all_entries.end(), [](Entry const& a, Entry const& b){ | |
| return a.counter < b.counter; | |
| }); | |
| TEST(all_entries.size() == shared_counter); | |
| uint64_t till = shared_counter/sharing; | |
| for(uint64_t i = 0; i < till; i ++) { | |
| Entry main_entry = all_entries[i*sharing]; | |
| char buffer[16] = {0}; | |
| snprintf(buffer, sizeof buffer, "%llu", (unsigned long long) i); | |
| TEST(strcmp(main_entry.id, buffer) == 0); | |
| for(uint64_t k = 0; k < sharing; k++) { | |
| uint64_t at = k + i*sharing; | |
| if(at >= shared_counter) | |
| break; | |
| Entry entry = all_entries[at]; | |
| TEST(entry.id == main_entry.id); | |
| TEST(entry.counter == at); | |
| } | |
| } | |
| pool.free_all(); | |
| } | |
| static inline uint64_t random_splitmix(uint64_t* state) | |
| { | |
| uint64_t z = (*state += 0x9e3779b97f4a7c15); | |
| z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9; | |
| z = (z ^ (z >> 27)) * 0x94d049bb133111eb; | |
| return z ^ (z >> 31); | |
| } | |
| void benchmark(uint32_t num_threads, double seconds) { | |
| //insertions | |
| { | |
| StringPool pool; | |
| std::atomic<uint64_t> total_iters = 0; | |
| std::atomic<uint32_t> started = 0; | |
| std::atomic<uint32_t> run = 0; | |
| std::atomic<uint32_t> stopped = 0; | |
| for(uint32_t i = 0; i < num_threads; i++) { | |
| std::thread([&, i]{ | |
| uint64_t iters = 0; | |
| uint64_t splitmix_state = (uint64_t) (i+1)*74174918947891; | |
| started.fetch_add(1); | |
| while(run == 0); | |
| while(run == 1) { | |
| uint64_t key = random_splitmix(&splitmix_state); | |
| pool.get_or_add((char*) (void*) &key, sizeof key); | |
| iters += 1; | |
| } | |
| total_iters.fetch_add(iters); | |
| stopped.fetch_add(1); | |
| }).detach(); | |
| } | |
| while(started != num_threads); | |
| run = 1; | |
| std::this_thread::sleep_for(std::chrono::milliseconds((int64_t) (seconds * 1e3))); | |
| run = 2; | |
| while(stopped != num_threads); | |
| double Mops = (double) total_iters / 1e6; | |
| printf("insertion: num_threads=%u throughput=%.2lfMops/s (Ops:%.2lfM time:%.2lfs)\n", | |
| num_threads, Mops/seconds, Mops, seconds); | |
| //uint64_t cap = pool.version.load()->capacity; | |
| //uint64_t pow2 = 0; | |
| //while((1ull << pow2) < cap) pow2+=1; | |
| //printf("table cap: 2^%llu used:%lf%%\n", pow2, 100.0*total_iters/cap); | |
| } | |
| //lookups | |
| { | |
| StringPool pool; | |
| uint64_t range = 1 << 20; | |
| for(uint64_t key = 0; key < range; key++) | |
| pool.get_or_add((char*) (void*) &key, sizeof key); | |
| std::atomic<uint64_t> total_iters = 0; | |
| std::atomic<uint32_t> started = 0; | |
| std::atomic<uint32_t> run = 0; | |
| std::atomic<uint32_t> stopped = 0; | |
| for(uint32_t i = 0; i < num_threads; i++) { | |
| std::thread([&, i]{ | |
| uint64_t iters = 0; | |
| uint64_t splitmix_state = (uint64_t) (i+1)*74174918947891; | |
| started.fetch_add(1); | |
| while(run == 0); | |
| while(run == 1) { | |
| uint64_t key = random_splitmix(&splitmix_state) & (range - 1); | |
| pool.get_or_add((char*) (void*) &key, sizeof key); | |
| iters += 1; | |
| } | |
| total_iters.fetch_add(iters); | |
| stopped.fetch_add(1); | |
| }).detach(); | |
| } | |
| while(started != num_threads); | |
| run = 1; | |
| std::this_thread::sleep_for(std::chrono::milliseconds((int64_t) (seconds * 1e3))); | |
| run = 2; | |
| while(stopped != num_threads); | |
| double Mops = (double) total_iters / 1e6; | |
| printf("lookup: num_threads=%u throughput=%.2lfMops/s (Ops:%.2lfM time:%.2lfs)\n", | |
| num_threads, Mops/seconds, Mops, seconds); | |
| } | |
| } | |
| int main() { | |
| printf("unit testing ...\n"); | |
| test_unit(); | |
| printf("benchmarking ...\n"); | |
| for(uint32_t threads = 1;; threads += 2) { | |
| if(threads > std::thread::hardware_concurrency()) | |
| threads = std::thread::hardware_concurrency(); | |
| benchmark(threads, 1); | |
| if(threads == std::thread::hardware_concurrency()) | |
| break; | |
| } | |
| printf("stress testing ...\n"); | |
| for(uint32_t sharing = 1; sharing <= std::thread::hardware_concurrency() + 2; sharing += 3) | |
| for(uint32_t threads = 1; threads <= std::thread::hardware_concurrency(); threads += 1) { | |
| printf("testing with threads=%u sharing=%u ...\n", threads, sharing); | |
| test_stress(threads, sharing, 0.5); | |
| } | |
| } |
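Since the pool is most often used as a single global instance (as the header comment notes), a thin wrapper is a common way to expose it. The intern() helper below is illustrative only, not part of the gist.
| #include "string_pool.h" | |
| //one pool for the whole program; its ids stay valid until shutdown | |
| static StringPool g_pool; | |
| static inline StringPool::StringId intern(const char* s) { return g_pool.get_or_add(s); } | |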