Created
August 24, 2021 19:51
-
-
Save jstefanelli/afe4913219a4d4fab8d73e1b0eee9933 to your computer and use it in GitHub Desktop.
Proof-of-concept lock-free(?) Queue in C++
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <thread>
#include <utility>
#include <vector>
/// Unbounded multi-producer / multi-consumer FIFO queue built from a linked
/// list of fixed-capacity sections ("chunks").
///
/// How it works:
/// - Producers reserve a slot in a section by bumping its `end` counter.
///   They construct the element in that slot, then publish it through a
///   per-slot `ready` flag.
/// - Consumers claim the oldest *published* slot by bumping `start`.
/// - When a section is full, producers continue in its successor.
/// - Once a section is full and fully drained, consumers move
///   `first_section` past it so it can be freed.
///
/// Thread-safety: `Push` and `Pull` may be called concurrently from any number
/// of threads. Construction and destruction must not overlap any other call.
/// Whether the queue is actually lock-free depends on the standard library's
/// atomic `shared_ptr` implementation, which is permitted to use internal locks.
///
/// @tparam T element type; must be copy-constructible (Push copies its
///           argument) and move-constructible (Pull moves the element out).
template<typename T>
struct Concurrent_Queue_t {
protected:
    /// Minimal atomic holder for a `std::shared_ptr<U>`.
    ///
    /// Uses `std::atomic<std::shared_ptr<U>>` when the library provides it
    /// (C++20, `__cpp_lib_atomic_shared_ptr`). Otherwise it falls back to the
    /// pre-C++20 `std::atomic_*` free-function overloads for `shared_ptr`. This
    /// lets the queue build in C++17 and with libraries that lack the C++20 type.
    template<typename U>
    class Atomic_Shared_Ptr_t {
    public:
        Atomic_Shared_Ptr_t() = default;
        explicit Atomic_Shared_Ptr_t(std::shared_ptr<U> initial) : ptr(std::move(initial)) {}
        Atomic_Shared_Ptr_t(const Atomic_Shared_Ptr_t&) = delete;
        Atomic_Shared_Ptr_t& operator=(const Atomic_Shared_Ptr_t&) = delete;

#if defined(__cpp_lib_atomic_shared_ptr)
        std::shared_ptr<U> load() const { return ptr.load(); }
        void store(std::shared_ptr<U> desired) { ptr.store(std::move(desired)); }
        /// On failure, `expected` is overwritten with the currently stored value.
        bool compare_exchange_strong(std::shared_ptr<U>& expected, std::shared_ptr<U> desired) {
            return ptr.compare_exchange_strong(expected, std::move(desired));
        }

    private:
        std::atomic<std::shared_ptr<U>> ptr;
#else
        std::shared_ptr<U> load() const { return std::atomic_load(&ptr); }
        void store(std::shared_ptr<U> desired) { std::atomic_store(&ptr, std::move(desired)); }
        /// On failure, `expected` is overwritten with the currently stored value.
        bool compare_exchange_strong(std::shared_ptr<U>& expected, std::shared_ptr<U> desired) {
            return std::atomic_compare_exchange_strong(&ptr, &expected, std::move(desired));
        }

    private:
        std::shared_ptr<U> ptr;
#endif
    };

    /// One fixed-capacity chunk of the queue.
    struct Queue_Section_t {
    protected:
        /// Storage for one element plus its publication flag.
        struct Slot_t {
            std::optional<T> value;          // constructed by the producer that reserved the slot
            std::atomic<bool> ready{false};  // stored (release) only after `value` is constructed
        };

        size_t size;                         // number of slots, always >= 1
        std::unique_ptr<Slot_t[]> slots;     // value-initialised: every slot starts empty / not ready
        std::atomic_size_t start{0};         // next slot a consumer may claim
        std::atomic_size_t end{0};           // next slot a producer may reserve

    public:
        /// Successor section; empty until some producer creates it.
        Atomic_Shared_Ptr_t<Queue_Section_t> next;

        /// @param capacity number of slots. 0 is treated as 1; otherwise
        ///        producers could never store anything and would loop forever.
        explicit Queue_Section_t(size_t capacity)
            : size(capacity == 0 ? 1 : capacity),
              slots(std::make_unique<Slot_t[]>(size)) {}

        /// Reserves the next free slot and stores a copy of `value` in it.
        /// @return false if every slot is already reserved (caller must move on to `next`).
        /// NOTE: if T's copy constructor throws, the reserved slot is never
        /// published and consumers will stop at it.
        bool Try_Push(const T& value) {
            size_t index = end.load();
            do {
                if (index >= size) {
                    return false;
                }
            } while (!end.compare_exchange_weak(index, index + 1));

            slots[index].value.emplace(value);
            // Publish only after construction. This pairs with the acquire load in Try_Pull.
            slots[index].ready.store(true, std::memory_order_release);

            // Create the successor early, so producers that find this section full usually see it already.
            if (index == size / 2) {
                Ensure_Next();
            }
            return true;
        }

        /// Claims and moves out the oldest published element.
        /// @return the element, or an empty optional if the oldest unclaimed
        ///         slot is not yet published (or the section is exhausted).
        std::optional<T> Try_Pull() {
            size_t index = start.load();
            // The acquire load makes the producer's write of `value` visible
            // before we read it.
            while (index < size && slots[index].ready.load(std::memory_order_acquire)) {
                // Exactly one consumer wins the CAS for `index`; only the winner touches the slot.
                if (start.compare_exchange_weak(index, index + 1)) {
                    return std::move(slots[index].value);
                }
                // On failure `index` was reloaded from `start`; re-check it.
            }
            return std::nullopt;
        }

        /// Returns the successor section, creating it if it does not exist yet.
        /// Never returns null. If several threads race, exactly one allocation is installed.
        std::shared_ptr<Queue_Section_t> Ensure_Next() {
            std::shared_ptr<Queue_Section_t> current = next.load();
            if (current) {
                return current;
            }
            std::shared_ptr<Queue_Section_t> fresh = std::make_shared<Queue_Section_t>(size);
            if (next.compare_exchange_strong(current, fresh)) {
                return fresh;
            }
            return current;  // lost the race: `current` now holds the winner's section
        }

        /// True once every slot has been reserved by a producer (not necessarily published).
        bool Full() const {
            return end.load() >= size;
        }

        /// True once every slot has been reserved and consumed.
        bool Completed() const {
            return Full() && start.load() >= size;
        }
    };

    size_t chunk_size;                                   // declared first: used to build first_section
    Atomic_Shared_Ptr_t<Queue_Section_t> first_section;  // oldest section that may still hold data
    Atomic_Shared_Ptr_t<Queue_Section_t> last_section;   // hint: section producers start from

    /// Clears `next` links one section at a time. Releasing a long
    /// shared_ptr chain in one go would otherwise recurse once per section.
    static void Unlink_Chain(std::shared_ptr<Queue_Section_t> section) {
        while (section) {
            std::shared_ptr<Queue_Section_t> successor = section->next.load();
            section->next.store(nullptr);
            section = std::move(successor);
        }
    }

public:
    /// @param chunk_size slots per section; 0 is treated as 1.
    Concurrent_Queue_t(size_t chunk_size = 64U)
        : chunk_size(chunk_size == 0 ? 1 : chunk_size),
          first_section(std::make_shared<Queue_Section_t>(this->chunk_size)) {
        last_section.store(first_section.load());
    }

    /// Not thread-safe: no other call may be in progress.
    ~Concurrent_Queue_t() {
        // last_section may lag behind first_section or run ahead of it, so unlink from both.
        Unlink_Chain(first_section.load());
        Unlink_Chain(last_section.load());
    }

    /// Appends a copy of `item`. Never fails for lack of space; new sections
    /// are allocated as needed.
    void Push(const T& item) {
        std::shared_ptr<Queue_Section_t> section = last_section.load();
        while (!section->Try_Push(item)) {
            std::shared_ptr<Queue_Section_t> successor = section->Ensure_Next();
            // Help advance the shared tail hint. Failure just means another
            // thread already moved it forward.
            std::shared_ptr<Queue_Section_t> expected = section;
            last_section.compare_exchange_strong(expected, successor);
            section = std::move(successor);
        }
    }

    /// Removes and returns the oldest element, or an empty optional if none
    /// is currently available. Returns empty while the oldest reserved slot
    /// is still being written by a producer, even if later items are already
    /// published.
    std::optional<T> Pull() {
        std::shared_ptr<Queue_Section_t> section = first_section.load();
        for (;;) {
            std::optional<T> value = section->Try_Pull();
            if (value.has_value() || !section->Completed()) {
                return value;
            }
            // This section is fully drained; continue with its successor, if any.
            std::shared_ptr<Queue_Section_t> successor = section->next.load();
            if (!successor) {
                return value;  // empty: nothing can have been pushed past this section yet
            }
            std::shared_ptr<Queue_Section_t> expected = section;
            first_section.compare_exchange_strong(expected, successor);
            section = std::move(successor);
        }
    }
};
| #define THREAD_COUNT 8 | |
| #define NUMBER_COUNT 10 | |
| int main() | |
| { | |
| std::vector<std::thread> threads; | |
| Concurrent_Queue_t<int> queue(32); | |
| for (auto i = 0; i < THREAD_COUNT; i++) { | |
| threads.push_back(std::thread([&queue, i]() { | |
| std::stringstream sstream; | |
| sstream << "Thread " << i << " online" << std::endl; | |
| std::cout << sstream.str(); | |
| for (int j = 0; j < NUMBER_COUNT; j++) { | |
| queue.Push((i * NUMBER_COUNT) + j); | |
| } | |
| sstream.str(""); | |
| sstream << "Thread " << i << " done" << std::endl; | |
| std::cout << sstream.str(); | |
| })); | |
| } | |
| for (auto& t : threads) { | |
| t.join(); | |
| } | |
| std::optional<int> res; | |
| do { | |
| res = queue.Pull(); | |
| if (res.has_value()) { | |
| std::cout << "Val: " << res.value() << std::endl; | |
| } | |
| else { | |
| std::cout << "Invalid value" << std::endl; | |
| } | |
| } while(res.has_value()); | |
| return 0; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment