Created December 30, 2023 20:07
-
-
Save likern/9df0d97b3551236b456716d95f282f83 to your computer and use it in GitHub Desktop.
PetersonMutex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std"); | |
const log = std.log.info; | |
/// Peterson's mutual-exclusion lock for exactly two threads.
///
/// The thread whose id is passed to `init` uses slot 0; every other thread
/// uses slot 1. A third thread would share slot 1 with the second one and
/// break mutual exclusion, so use an instance from at most two threads.
///
/// Every access to the protocol state (`flag_`, `victim_idx_`) is a
/// sequentially consistent atomic. `volatile` alone is not sufficient. It only
/// keeps the compiler from eliding or reordering those accesses. The CPU may
/// still let the read of the other thread's flag complete before this thread's
/// preceding stores are visible (store->load reordering, allowed even on x86).
/// Both threads could then observe the other as "not interested" and enter the
/// critical section together.
const PetersonMutex = struct {
    /// `flag_[k] == 1` while the thread in slot `k` wants or holds the lock.
    flag_: [2]u64 = [_]u64{ 0, 0 },
    /// Slot of the thread that most recently announced interest; that thread
    /// yields while the other one is interested. Always written before read.
    victim_idx_: u64 = 0, // 0 or 1
    /// Thread id that maps to slot 0.
    main_thread_id_: std.Thread.Id,

    /// Creates an unlocked mutex; `main_thread` is assigned slot 0.
    pub fn init(main_thread: std.Thread.Id) PetersonMutex {
        return .{ .main_thread_id_ = main_thread };
    }

    /// Blocks (busy-waits) until the calling thread holds the lock.
    pub fn lock(self: *PetersonMutex) void {
        const i = self.id_to_index();
        const j = 1 - i;
        // Announce interest, then give priority to the other thread.
        @atomicStore(u64, &self.flag_[i], 1, .seq_cst);
        @atomicStore(u64, &self.victim_idx_, i, .seq_cst);
        // Wait while the other thread is interested and we are still the victim.
        while (@atomicLoad(u64, &self.flag_[j], .seq_cst) == 1 and
            @atomicLoad(u64, &self.victim_idx_, .seq_cst) == i)
        {
            std.atomic.spinLoopHint();
        }
    }

    /// Releases the lock held by the calling thread.
    pub fn unlock(self: *PetersonMutex) void {
        const i = self.id_to_index();
        @atomicStore(u64, &self.flag_[i], 0, .seq_cst);
    }

    /// Maps the calling thread to its slot: 0 for the `init` thread, 1 otherwise.
    fn id_to_index(self: *const PetersonMutex) usize {
        return if (std.Thread.getCurrentId() == self.main_thread_id_) 0 else 1;
    }
};
/// Lock shared by both threads; it is assigned in `main` before the second
/// thread is spawned, so it is never used while still `undefined`.
var mutex: PetersonMutex = undefined;
/// Counter that both threads increment while holding `mutex`.
var global_counter: u64 = 0;
/// Adds one to `global_counter` while holding the global `mutex`.
pub fn add_one() void {
    mutex.lock();
    defer mutex.unlock();
    // `*u64` coerces to `*volatile u64`; the volatile access forces a real
    // load and store of the counter on every call.
    const counter: *volatile u64 = &global_counter;
    counter.* += 1;
}
/// Calls `add_one` 10_000 times from the current thread.
pub fn run() void {
    var remaining: usize = 10_000;
    while (remaining > 0) : (remaining -= 1) {
        add_one();
    }
}
/// Entry point of the spawned worker thread; performs the same work as `main`'s thread.
pub fn thread_function() void {
    return run();
}
/// Runs `run` concurrently on the main thread and one spawned thread, then
/// logs the final value of `global_counter`.
pub fn main() !void {
    // The mutex must be initialised (slot 0 = this thread) before the worker exists.
    mutex = PetersonMutex.init(std.Thread.getCurrentId());
    const worker = try std.Thread.spawn(.{}, thread_function, .{});
    run();
    worker.join();
    log("Expected: 20_000, got: {}", .{global_counter});
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.