|
/// SPDX-License-Identifier: GPL-2.0-or-later |
|
|
|
// How to compile: |
|
// g++ -O3 -std=c++20 -mwaitpkg -Wall -Wextra umonitor-umwait-semaphore.cc -o umonitor-umwait-semaphore |
|
// OR |
|
// g++ -O3 -std=c++20 -mwaitpkg -Wall -Wextra -fsanitize=undefined,thread umonitor-umwait-semaphore.cc -o umonitor-umwait-semaphore |
|
|
|
#include <sys/types.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <charconv>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <new>
#include <stdexcept>
#include <string_view>
#include <system_error>
#include <thread>
#include <vector>

#ifdef _MSC_VER
#include <intrin.h> // For MSVC intrinsics
#else
#include <immintrin.h> // For GCC/Clang intrinsics
#endif
|
|
|
// Length, in TSC ticks, of a single UMWAIT sleep; the waiter re-checks the
// semaphore count after each one. Tune per hardware.
constexpr std::uint64_t UMWAIT_SLEEP_DURATION_TSC_MAX{ 3000 };
|
|
|
// for debugging |
|
using TimeStamp = std::uint64_t; |
|
using ThreadTimeStamps = struct alignas(std::hardware_destructive_interference_size) { |
|
std::array<TimeStamp, 3> timestamps; |
|
size_t sleep_iter; |
|
}; |
|
|
|
class UMWaitSemaphore { |
|
using CounterType = std::atomic<int>; |
|
static_assert(CounterType::is_always_lock_free); |
|
|
|
public: |
|
explicit UMWaitSemaphore(int initial_count = 0) |
|
: count_{ initial_count } |
|
{ |
|
if (initial_count < 0) |
|
throw std::runtime_error("Semaphore initial count cannot be negative."); |
|
} |
|
|
|
size_t acquire() |
|
{ |
|
size_t iter = 0; |
|
while (true) { |
|
auto expected_count = count_.load(std::memory_order_acquire); |
|
ACQUIRE_IF_POSSIBLE: |
|
if (expected_count > 0) { // Is it possible to acquire the semaphore? |
|
auto desired_count = expected_count - 1; |
|
if (count_.compare_exchange_weak(expected_count, desired_count, |
|
std::memory_order_release, |
|
std::memory_order_relaxed)) // Try to acquire |
|
return iter; // Permit acquired, exit |
|
continue; // The counts didn't match because some other thread |
|
// acquired a permit in the time we read count_ and |
|
// tried to decrement it. So, retry! |
|
} |
|
|
|
// Sleep until count_ becomes non-zero |
|
do { |
|
_umonitor(&count_); |
|
const uint64_t now = __rdtsc(); |
|
const uint64_t wakeup_at = now + UMWAIT_SLEEP_DURATION_TSC_MAX; |
|
_umwait(1, wakeup_at); |
|
iter++; |
|
} while ((expected_count = count_.load(std::memory_order_acquire)) == 0); |
|
// There's a high chance that count_ is still non-zero, so |
|
// do the check again, this time skipping a load |
|
goto ACQUIRE_IF_POSSIBLE; |
|
} |
|
} |
|
|
|
void release() |
|
{ |
|
count_.fetch_add(1, std::memory_order_acq_rel); // Increment count & wakeup umwait |
|
// NOTE: Increment is a read/modify/write operation! |
|
} |
|
|
|
private: |
|
alignas(std::hardware_constructive_interference_size) CounterType count_; |
|
}; |
|
|
|
void worker_thread(UMWaitSemaphore &semaphore, int tid, ThreadTimeStamps ×tamps) |
|
{ |
|
// Instead of cout, record events using tsc and output them later |
|
|
|
auto a = __rdtsc(); |
|
auto sleep_iter = semaphore.acquire(); |
|
auto b = __rdtsc(); |
|
std::this_thread::sleep_for(std::chrono::microseconds(150 + tid * 22)); // Simulate work |
|
auto c = __rdtsc(); |
|
semaphore.release(); |
|
|
|
timestamps = { .timestamps = { a, b, c }, .sleep_iter = sleep_iter }; |
|
} |
|
|
|
static void parse_args(int argc, char const *const *argv, int &numThreads, int &semaphoreCount); |
|
static void print_all_timestamps_sorted(const std::vector<ThreadTimeStamps> &); |
|
|
|
std::uint64_t g_tsc_start; |
|
int main(int argc, char *argv[]) |
|
{ |
|
int numThreads; |
|
int semaphoreCount; |
|
parse_args(argc, argv, numThreads, semaphoreCount); |
|
|
|
// Create a semaphore with an initial count of 2 |
|
UMWaitSemaphore semaphore(semaphoreCount); |
|
std::clog << "Semaphore initialized with count = " << semaphoreCount << ".\n"; |
|
|
|
std::vector<std::thread> threads; |
|
std::vector<ThreadTimeStamps> timestamps; |
|
|
|
threads.reserve(numThreads); |
|
timestamps.resize(numThreads); |
|
|
|
std::clog << "Spawning " << numThreads << " worker threads...\n"; |
|
g_tsc_start = __rdtsc(); |
|
for (int i = 0; i < numThreads; ++i) { |
|
threads.emplace_back(worker_thread, std::ref(semaphore), i + 1, std::ref(timestamps[i])); |
|
} |
|
|
|
// Let threads run for a bit |
|
std::this_thread::sleep_for(std::chrono::seconds(1)); |
|
|
|
std::clog << "\nMain thread: Releasing an additional permit after some time.\n"; |
|
|
|
for (int i = 0; i < semaphoreCount; i++) { // Manually release permits |
|
semaphore.release(); |
|
std::this_thread::sleep_for(std::chrono::milliseconds(100)); |
|
} |
|
semaphore.release(); // Manually release another permit |
|
|
|
for (auto &t : threads) { |
|
t.join(); |
|
} |
|
|
|
std::clog << "\nAll workers finished.\n"; |
|
|
|
std::clog << "\n\nExecution summary:\n"; |
|
print_all_timestamps_sorted(timestamps); |
|
} |
|
|
|
// --------------------------------------------------------------------------------- |
|
/// Parses `<num_threads> <semaphore_count>` from the command line.
///
/// @param argc, argv     program arguments; exactly two user arguments are required.
/// @param numThreads     receives argv[1] as a base-10 int.
/// @param semaphoreCount receives argv[2] as a base-10 int.
/// @throws std::runtime_error on a wrong argument count, or if an argument is
///         empty, out of int range, or has trailing non-digit characters.
void parse_args(int argc, char const *const *argv, int &numThreads, int &semaphoreCount)
{
    if (argc != 3) {
        throw std::runtime_error("bad args");
    }

    // The WHOLE string must be a number: from_chars stops at the first
    // non-digit and still reports success, so "12abc" must be rejected by
    // checking that parsing consumed every character.
    auto parse_intcstr = [](const char *cstr, int &v) {
        const std::string_view sv{ cstr };
        const char *const first = sv.data();
        const char *const last = first + sv.size();
        int parsed = 0;
        const auto [ptr, ec] = std::from_chars(first, last, parsed);
        if (ec != std::errc() || ptr != last)
            throw std::runtime_error("Error parsing argv");
        v = parsed;
    };

    parse_intcstr(argv[1], numThreads);
    parse_intcstr(argv[2], semaphoreCount);
}
|
|
|
void print_all_timestamps_sorted(const std::vector<ThreadTimeStamps> &allTimeStamps) |
|
{ |
|
struct RevStamp { |
|
int tid; |
|
enum RevStampType : char { kAsk, kAcquired, kReleasing } type; |
|
std::uint64_t timestamp; // set only on acquire (type == 1) |
|
}; |
|
std::vector<RevStamp> stamps; |
|
stamps.reserve(allTimeStamps.size() * 3); |
|
for (size_t i = 0; i < allTimeStamps.size(); i++) { |
|
for (size_t j = 0; j < 3; j++) { |
|
stamps.push_back({ .tid = (int)i, |
|
.type = (RevStamp::RevStampType)j, |
|
.timestamp = allTimeStamps[i].timestamps[j] }); |
|
} |
|
} |
|
|
|
std::sort(stamps.begin(), stamps.end(), [](auto &a, auto &b) { return a.timestamp < b.timestamp; }); |
|
|
|
int w[] = { 12, 8, 16, 8, 8 }; |
|
std::cout << std::left << std::setw(w[0]) << "TSC" << std::setw(w[1]) << "Tid" << std::setw(w[2]) |
|
<< "Operation" << std::setw(w[3]) << "Retry #" << std::setw(w[4]) << "Acquired #" |
|
<< std::endl; |
|
ssize_t acq_count = 0; |
|
for (auto &stamp : stamps) { |
|
std::cout << std::setw(w[0]) << stamp.timestamp - g_tsc_start << std::setw(w[1]) << stamp.tid; |
|
switch (stamp.type) { |
|
case RevStamp::kAsk: { |
|
std::cout << std::setw(w[2]) << "Ask"; // thread arrived. |
|
break; |
|
} |
|
case RevStamp::kAcquired: { |
|
std::cout << std::setw(w[2]) << "Acquired" << std::setw(w[3]) << ++acq_count |
|
<< std::setw(w[4]) << std::right << allTimeStamps[stamp.tid].sleep_iter |
|
<< std::left; |
|
break; |
|
} |
|
case RevStamp::kReleasing: { |
|
std::cout << std::setw(w[2]) << "Releasing!" << std::setw(w[3]) << std::setw(w[4]) |
|
<< --acq_count; |
|
break; |
|
} |
|
} |
|
std::cout << '\n'; |
|
} |
|
} |