A simple HTTP server for Linux, built on epoll and compatible with wrk for benchmarking.
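The gist collects several standalone Zig variants of the same server: each parses requests only up to the end of the headers (the blank line) inside a 4 KiB buffer and answers every request with a fixed 200 response, which is all wrk needs. The variants differ in how epoll events are dispatched: single-threaded callbacks, one epoll instance per worker thread, one shared epoll with EPOLLONESHOT, a single-threaded async/await event loop, and a work-stealing thread pool. The code targets an older Zig release (note std.os.EPOLL_*, std.debug.warn, @ptrToInt, and the old std.Thread.spawn signature). As a rough usage sketch, with the file name server.zig being illustrative rather than part of the gist:

    zig build-exe server.zig --release-fast   # -O ReleaseFast on newer compilers
    ./server &
    wrk -t2 -c64 -d10s http://127.0.0.1:12345/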
const std = @import("std");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = &gpa.allocator;
    defer _ = gpa.deinit();

    var poller = try Poller.init(allocator);
    defer poller.deinit();

    try Server.start(&poller, 12345);
    while (true) {
        poller.poll();
    }
}

const Poller = struct {
    fd: std.os.fd_t,
    allocator: *std.mem.Allocator,

    fn init(allocator: *std.mem.Allocator) !Poller {
        return Poller{
            .fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC),
            .allocator = allocator,
        };
    }

    fn deinit(self: *Poller) void {
        std.os.close(self.fd);
    }

    const Event = struct {
        const Callback = struct {
            onEventFn: fn(*Callback, Event) void,
        };

        is_closable: bool,
        is_readable: bool,
        is_writable: bool,
    };

    const Socket = struct {
        poller: *Poller,
        fd: std.os.socket_t,
        callback: Event.Callback,

        fn start(comptime Self: type, poller: *Poller, fd: std.os.socket_t) !*Self {
            const Callback = struct {
                fn onEvent(callback: *Event.Callback, event: Event) void {
                    const socket = @fieldParentPtr(Socket, "callback", callback);
                    const self = @fieldParentPtr(Self, "socket", socket);
                    handleEvent(self, event) catch {
                        self.onClose();
                        std.os.close(self.socket.fd);
                        self.socket.poller.allocator.destroy(self);
                    };
                }

                fn handleEvent(self: *Self, event: Event) !void {
                    if (event.is_closable)
                        return error.Closed;
                    if (event.is_readable)
                        try self.onRead();
                    if (event.is_writable)
                        try self.onWrite();
                }
            };

            const self = try poller.allocator.create(Self);
            errdefer poller.allocator.destroy(self);

            self.socket = .{
                .poller = poller,
                .fd = fd,
                .callback = .{
                    .onEventFn = Callback.onEvent,
                },
            };

            // Register with edge-triggering (EPOLLET) so that the events re-arm themselves when we do IO.
            // Saves doing another epoll_ctl() syscall to re-arm, as level-triggering with EPOLLONESHOT would require.
            try std.os.epoll_ctl(poller.fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{
                .events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLET | std.os.EPOLLRDHUP,
                .data = .{ .ptr = @ptrToInt(&self.socket.callback) },
            });
            return self;
        }
    };

    fn poll(self: *Poller) void {
        var events: [128]std.os.epoll_event = undefined;
        const events_found = std.os.epoll_wait(self.fd, &events, -1);
        if (events_found == 0)
            return;
        for (events[0..events_found]) |ev| {
            const callback = @intToPtr(*Event.Callback, ev.data.ptr);
            (callback.onEventFn)(callback, Event{
                .is_closable = ev.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0,
                .is_readable = ev.events & std.os.EPOLLIN != 0,
                .is_writable = ev.events & std.os.EPOLLOUT != 0,
            });
        }
    }
};

const Server = struct {
    socket: Poller.Socket,
    port: u16,

    const SOCK_FLAGS = std.os.SOCK_CLOEXEC | std.os.SOCK_NONBLOCK;

    fn start(poller: *Poller, comptime port: u16) !void {
        const fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | SOCK_FLAGS, std.os.IPPROTO_TCP);
        errdefer std.os.close(fd);

        // Bind the socket to the port on the local address.
        const address = "127.0.0.1";
        var addr = comptime std.net.Address.parseIp(address, port) catch unreachable;
        try std.os.setsockopt(fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
        try std.os.bind(fd, &addr.any, addr.getOsSockLen());
        try std.os.listen(fd, 128);

        const self = try Poller.Socket.start(Server, poller, fd);
        self.port = port;
        std.debug.warn("Listening on {}:{}\n", .{address, port});
    }

    fn onClose(self: *Server) void {
        std.debug.warn("server shutdown for port: {}\n", .{self.port});
    }

    fn onWrite(self: *Server) !void {
        std.debug.panic("server shouldn't be writable", .{});
    }

    fn onRead(self: *Server) !void {
        while (true) {
            const client_fd = std.os.accept(self.socket.fd, null, null, SOCK_FLAGS) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            Client.start(self.socket.poller, client_fd) catch |err| {
                std.debug.warn("Failed to spawn a client: {}\n", .{err});
                continue;
            };
        }
    }
};

const Client = struct {
    socket: Poller.Socket,
    send_bytes: usize,
    send_partial: usize,
    recv_bytes: usize,
    recv_buffer: [4096]u8,

    const HTTP_CLRF = "\r\n\r\n";
    const HTTP_RESPONSE =
        "HTTP/1.1 200 Ok\r\n" ++
        "Content-Length: 10\r\n" ++
        "Content-Type: text/plain; charset=utf8\r\n" ++
        "Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
        "Server: fasthttp\r\n" ++
        "\r\n" ++
        "HelloWorld";

    fn start(poller: *Poller, fd: std.os.socket_t) !void {
        errdefer std.os.close(fd);

        // Enable TCP_NODELAY to send the http responses as fast as possible.
        const SOL_TCP = 6;
        const TCP_NODELAY = 1;
        try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));

        const self = try Poller.Socket.start(Client, poller, fd);
        self.send_bytes = 0;
        self.send_partial = 0;
        self.recv_bytes = 0;
        self.recv_buffer = undefined;
    }

    fn onClose(self: *Client) void {
        // Do nothing.
    }

    fn onRead(self: *Client) !void {
        while (true) {
            const request_buffer = self.recv_buffer[0..self.recv_bytes];
            // Try to parse and consume a request in the request_buffer
            // by matching everything until the end of an HTTP request with no body.
            if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
                const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len];
                std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
                self.recv_bytes = unparsed_buffer.len;
                // If found, count that as a parsed request
                // and have the writer write the static HTTP response (eventually).
                self.send_bytes += HTTP_RESPONSE.len;
                continue;
            }
            // A complete request wasn't parsed yet.
            // Try to read more data into the buffer and try again.
            const readable_buffer = self.recv_buffer[self.recv_bytes..];
            if (readable_buffer.len == 0)
                return error.HttpRequestTooLarge;
            // If the read would normally block,
            // then we have to wait for the socket to be readable in the future to try again.
            const bytes_read = std.os.read(self.socket.fd, readable_buffer) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            // 0 bytes read indicates that the socket can no longer read any data.
            self.recv_bytes += bytes_read;
            if (bytes_read == 0)
                return error.EndOfStream;
        }
    }

    fn onWrite(self: *Client) !void {
        const NUM_RESPONSE_CHUNKS = 128;
        const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
        while (self.send_bytes > 0) {
            // Compute the slice of the response chunk to send, starting at the
            // offset (send_partial) where the previous partial write stopped.
            const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
            const iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
            const writable_buffer = iov_base[0..iov_len];
            // Perform the actual write.
            // Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
            const bytes_written = std.os.sendto(
                self.socket.fd,
                writable_buffer,
                std.os.MSG_NOSIGNAL,
                null,
                @as(std.os.socklen_t, 0),
            ) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            self.send_bytes -= bytes_written;
            // Advance the offset within the current response; a write may stop mid-response.
            self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
        }
    }
};
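Variant 2: one epoll instance per worker thread. The main thread owns the listening socket (registered with data.ptr = 0 so it can be told apart from client pointers), accepts connections, and assigns each new client to a worker's epoll fd round-robin; every thread, the main one included, runs its own epoll_wait loop.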
const std = @import("std");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = &gpa.allocator;

    const num_threads = 6; // std.math.max(1, std.Thread.cpuCount() catch 1);
    const worker_fds = try allocator.alloc(std.os.fd_t, num_threads);
    defer allocator.free(worker_fds);

    // One epoll instance per worker; the main thread keeps worker_fds[0] for itself.
    for (worker_fds) |*worker_fd|
        worker_fd.* = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
    for (worker_fds[1..]) |worker_fd|
        _ = try std.Thread.spawn(worker_fd, runWorker);

    const server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
    errdefer std.os.close(server_fd);

    const port = 12345;
    var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
    try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try std.os.bind(server_fd, &addr.any, addr.getOsSockLen());
    try std.os.listen(server_fd, 128);

    // The listening socket is registered with data.ptr = 0 to distinguish it
    // from client events, whose data.ptr is a *Client.
    var next_fd: usize = 0;
    const epoll_fd = worker_fds[next_fd];
    var events: [256]std.os.epoll_event = undefined;
    var server_event = std.os.epoll_event{
        .events = std.os.EPOLLIN | std.os.EPOLLET | std.os.EPOLLRDHUP,
        .data = .{ .ptr = 0 },
    };
    try std.os.epoll_ctl(
        epoll_fd,
        std.os.EPOLL_CTL_ADD,
        server_fd,
        &server_event,
    );

    std.debug.warn("Listening on :{}\n", .{port});
    while (true) {
        const found = std.os.epoll_wait(epoll_fd, &events, -1);
        for (events[0..found]) |event| {
            if (event.data.ptr != 0) {
                Client.process(event);
                continue;
            }
            if (event.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
                unreachable;
            if (event.events & std.os.EPOLLIN == 0)
                unreachable;
            while (true) {
                const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
                    error.WouldBlock => break,
                    else => |e| return e,
                };
                // Assign new clients to the workers' epoll instances round-robin.
                if (Client.start(allocator, worker_fds[next_fd], client_fd)) |_| {
                    next_fd += 1;
                    if (next_fd >= worker_fds.len)
                        next_fd = 0;
                } else |_| {
                    std.os.close(client_fd);
                    std.debug.warn("Failed to start client fd={}\n", .{client_fd});
                }
            }
        }
    }
}

fn runWorker(epoll_fd: std.os.fd_t) void {
    var events: [256]std.os.epoll_event = undefined;
    while (true) {
        const found = std.os.epoll_wait(epoll_fd, &events, -1);
        for (events[0..found]) |event|
            Client.process(event);
    }
}

const Client = struct {
    fd: std.os.socket_t,
    allocator: *std.mem.Allocator,
    send_bytes: usize = 0,
    send_partial: usize = 0,
    recv_bytes: usize = 0,
    recv_buffer: [4096]u8 = undefined,

    const HTTP_CLRF = "\r\n\r\n";
    const HTTP_RESPONSE =
        "HTTP/1.1 200 Ok\r\n" ++
        "Content-Length: 10\r\n" ++
        "Content-Type: text/plain; charset=utf8\r\n" ++
        "Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
        "Server: fasthttp\r\n" ++
        "\r\n" ++
        "HelloWorld";

    fn start(allocator: *std.mem.Allocator, epoll_fd: std.os.fd_t, fd: std.os.socket_t) !void {
        const self = try allocator.create(Client);
        errdefer allocator.destroy(self);

        const SOL_TCP = 6;
        const TCP_NODELAY = 1;
        try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));

        self.* = .{
            .fd = fd,
            .allocator = allocator,
        };
        try std.os.epoll_ctl(epoll_fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{
            .events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLET | std.os.EPOLLRDHUP,
            .data = .{ .ptr = @ptrToInt(self) },
        });
    }

    fn process(event: std.os.epoll_event) void {
        const self = @intToPtr(*Client, event.data.ptr);
        self.processEvent(event.events) catch {
            std.os.close(self.fd);
            self.allocator.destroy(self);
        };
    }

    fn processEvent(self: *Client, events: u32) !void {
        if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
            return error.Closed;
        if (events & std.os.EPOLLIN != 0)
            try self.processRead();
        if (events & std.os.EPOLLOUT != 0)
            try self.processWrite();
    }

    fn processRead(self: *Client) !void {
        while (true) {
            const request_buffer = self.recv_buffer[0..self.recv_bytes];
            // Try to parse and consume a request in the request_buffer
            // by matching everything until the end of an HTTP request with no body.
            if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
                const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len];
                std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
                self.recv_bytes = unparsed_buffer.len;
                // If found, count that as a parsed request
                // and have the writer write the static HTTP response (eventually).
                self.send_bytes += HTTP_RESPONSE.len;
                continue;
            }
            // A complete request wasn't parsed yet.
            // Try to read more data into the buffer and try again.
            const readable_buffer = self.recv_buffer[self.recv_bytes..];
            if (readable_buffer.len == 0)
                return error.HttpRequestTooLarge;
            // If the read would normally block,
            // then we have to wait for the socket to be readable in the future to try again.
            const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            // 0 bytes read indicates that the socket can no longer read any data.
            self.recv_bytes += bytes_read;
            if (bytes_read == 0)
                return error.EndOfStream;
        }
    }

    fn processWrite(self: *Client) !void {
        const NUM_RESPONSE_CHUNKS = 128;
        const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
        while (self.send_bytes > 0) {
            // Compute the slice of the response chunk to send, starting at the
            // offset (send_partial) where the previous partial write stopped.
            const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
            const iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
            const writable_buffer = iov_base[0..iov_len];
            // Perform the actual write.
            // Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
            const bytes_written = std.os.sendto(
                self.fd,
                writable_buffer,
                std.os.MSG_NOSIGNAL,
                null,
                @as(std.os.socklen_t, 0),
            ) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            self.send_bytes -= bytes_written;
            // Advance the offset within the current response; a write may stop mid-response.
            self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
        }
    }
};
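Variant 3: one global epoll instance shared by every thread. Clients are registered with EPOLLONESHOT instead of EPOLLET, so a ready socket is delivered to exactly one epoll_wait caller at a time and must be re-armed with EPOLL_CTL_MOD after processing; that one-shot delivery is what makes it safe for multiple threads to poll the same epoll fd.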
const std = @import("std");

var poll_fd: std.os.fd_t = undefined;
var server_fd: std.os.socket_t = undefined;
var allocator: *std.mem.Allocator = undefined;

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    allocator = &gpa.allocator;
    defer _ = gpa.deinit();

    poll_fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
    defer std.os.close(poll_fd);

    server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
    defer std.os.close(server_fd);

    const port = 12345;
    var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
    try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try std.os.bind(server_fd, &addr.any, addr.getOsSockLen());
    try std.os.listen(server_fd, 128);

    // The listening socket is identified by its data.ptr being the address
    // of the global server_fd; client events carry a *Client instead.
    try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, server_fd, &std.os.epoll_event{
        .events = std.os.EPOLLIN | std.os.EPOLLET | std.os.EPOLLRDHUP,
        .data = .{ .ptr = @ptrToInt(&server_fd) },
    });

    var threads = std.math.max(1, try std.Thread.cpuCount());
    while (threads > 1) : (threads -= 1)
        _ = try std.Thread.spawn({}, runWorker);

    std.debug.warn("Listening on :{}\n", .{port});
    runWorker({});
}

fn runWorker(_: void) void {
    var events: [256]std.os.epoll_event = undefined;
    while (true) {
        const found = std.os.epoll_wait(poll_fd, &events, -1);
        if (found == 0)
            continue;
        for (events[0..found]) |event| {
            const ptr = event.data.ptr;
            const flags = event.events;
            if (ptr == @ptrToInt(&server_fd)) {
                Client.accept(flags) catch |e| std.debug.warn("failed to accept a client: {}\n", .{e});
                continue;
            }
            const client = @intToPtr(*Client, ptr);
            client.process(flags) catch {};
        }
    }
}

const Client = struct {
    fd: std.os.socket_t,
    send_bytes: usize = 0,
    send_partial: usize = 0,
    recv_bytes: usize = 0,
    recv_buffer: [4096]u8 = undefined,

    const HTTP_CLRF = "\r\n\r\n";
    const HTTP_RESPONSE =
        "HTTP/1.1 200 Ok\r\n" ++
        "Content-Length: 10\r\n" ++
        "Content-Type: text/plain; charset=utf8\r\n" ++
        "Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
        "Server: fasthttp\r\n" ++
        "\r\n" ++
        "HelloWorld";

    fn accept(flags: u32) !void {
        while (true) {
            const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            errdefer std.os.close(client_fd);

            const SOL_TCP = 6;
            const TCP_NODELAY = 1;
            try std.os.setsockopt(client_fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));

            const self = try allocator.create(Client);
            self.* = Client{ .fd = client_fd };
            errdefer allocator.destroy(self);

            try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, client_fd, &std.os.epoll_event{
                .events = std.os.EPOLLIN | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP,
                .data = .{ .ptr = @ptrToInt(self) },
            });
        }
    }

    fn process(self: *Client, flags: u32) !void {
        errdefer {
            std.os.close(self.fd);
            allocator.destroy(self);
        }
        if (flags & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
            return error.Closed;

        var written = false;
        if ((flags & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) {
            written = true;
            try self.processWrite();
        }
        if (flags & std.os.EPOLLIN != 0)
            try self.processRead();
        if (!written and (flags & std.os.EPOLLOUT != 0) and (self.send_bytes > 0))
            try self.processWrite();

        // Re-arm the oneshot registration, asking for writability only if
        // there is still response data to send.
        var events: u32 = std.os.EPOLLIN | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP;
        if (self.send_bytes > 0)
            events |= std.os.EPOLLOUT;
        try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_MOD, self.fd, &std.os.epoll_event{
            .events = events,
            .data = .{ .ptr = @ptrToInt(self) },
        });
    }

    fn processRead(self: *Client) !void {
        while (true) {
            const request_buffer = self.recv_buffer[0..self.recv_bytes];
            if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
                const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len];
                std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
                self.recv_bytes = unparsed_buffer.len;
                self.send_bytes += HTTP_RESPONSE.len;
                continue;
            }
            const readable_buffer = self.recv_buffer[self.recv_bytes..];
            if (readable_buffer.len == 0)
                return error.HttpRequestTooLarge;
            const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            self.recv_bytes += bytes_read;
            if (bytes_read == 0)
                return error.EndOfStream;
        }
    }

    fn processWrite(self: *Client) !void {
        const NUM_RESPONSE_CHUNKS = 128;
        const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
        while (self.send_bytes > 0) {
            // Compute the slice of the response chunk to send, starting at the
            // offset (send_partial) where the previous partial write stopped.
            const send_bytes = self.send_bytes;
            if (self.send_partial > RESPONSE_CHUNK.len)
                std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len});
            const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
            const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial);
            const writable_buffer = iov_base[0..iov_len];
            // Perform the actual write.
            // Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
            const bytes_written = std.os.sendto(
                self.fd,
                writable_buffer,
                std.os.MSG_NOSIGNAL,
                null,
                @as(std.os.socklen_t, 0),
            ) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            // Advance the offset within the current response; a write may stop mid-response.
            self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
            self.send_bytes -= bytes_written;
        }
    }
};
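Variant 4: a single-threaded async/await version. A small Loop schedules suspended frames (Tasks) in FIFO or LIFO order, and its IoDriver wraps an edge-triggered epoll: reader and writer frames suspend in Event.waitFor() when a syscall returns EAGAIN and are rescheduled by Event.notify() when epoll reports readiness. Each client runs a reader frame and a writer frame that hand work off through the writer_notify word.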
const std = @import("std");

pub fn main() !void {
    var loop: Loop = undefined;
    try loop.init();
    defer loop.deinit();

    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    var server_frame = async Server.run(&loop, &gpa.allocator, 12345);

    loop.run();
    std.debug.warn("run end\n", .{});

    try (nosuspend await server_frame);
    if (gpa.deinit())
        return error.LeakDetected;
}

const Server = struct {
    fn run(loop: *Loop, allocator: *std.mem.Allocator, port: u16) !void {
        const server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
        defer std.os.close(server_fd);

        var addr = try std.net.Address.parseIp("127.0.0.1", port);
        try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
        try std.os.bind(server_fd, &addr.any, addr.getOsSockLen());
        try std.os.listen(server_fd, 128);

        var event: Loop.IoDriver.Event = undefined;
        try event.init(server_fd, &loop.io_driver);
        defer event.deinit();

        std.debug.print("Listening on {}\n", .{port});
        while (true) {
            const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
                error.WouldBlock => {
                    event.waitFor(.read);
                    continue;
                },
                else => |e| return e,
            };
            const client_frame = allocator.create(@Frame(Client.run)) catch |err| {
                std.os.close(client_fd);
                std.debug.print("Failed to spawn client: OOM\n", .{});
                continue;
            };
            client_frame.* = async Client.run(client_fd, loop, allocator);
        }
    }
};

const Client = struct {
    fd: std.os.fd_t,
    send_bytes: usize = 0,
    send_partial: usize = 0,
    writer_notify: usize = 0,
    recv_bytes: usize = 0,
    recv_buffer: [4096]u8 = undefined,
    event: Loop.IoDriver.Event = undefined,

    fn run(fd: std.os.fd_t, loop: *Loop, allocator: *std.mem.Allocator) !void {
        var self: Client = .{ .fd = fd };
        self.execute(loop) catch |e| {}; // std.debug.print("Client({}) had error {}\n", .{fd, e});
        std.os.close(fd);
        suspend allocator.destroy(@frame());
    }

    fn execute(self: *Client, loop: *Loop) !void {
        // Enable TCP_NODELAY to send the http responses as fast as possible.
        const SOL_TCP = 6;
        const TCP_NODELAY = 1;
        try std.os.setsockopt(self.fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));

        try self.event.init(self.fd, &loop.io_driver);
        defer self.event.deinit();

        var read_frame = async self.reader(loop);
        var write_frame = async self.writer(loop);

        const read_err = await read_frame;
        const write_err = await write_frame;
        try read_err;
        try write_err;
    }

    const HTTP_CLRF = "\r\n\r\n";
    const HTTP_RESPONSE =
        "HTTP/1.1 200 Ok\r\n" ++
        "Content-Length: 10\r\n" ++
        "Content-Type: text/plain; charset=utf8\r\n" ++
        "Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
        "Server: fasthttp\r\n" ++
        "\r\n" ++
        "HelloWorld";

    fn reader(self: *Client, loop: *Loop) !void {
        suspend {
            var task = Loop.Task{ .frame = @frame(), .name = "Client.read(begin)" };
            loop.schedule(&task, .fifo);
        }
        defer self.writerClose(loop);
        while (true) {
            const request_buffer = self.recv_buffer[0..self.recv_bytes];
            // Try to parse and consume a request in the request_buffer
            // by matching everything until the end of an HTTP request with no body.
            if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
                const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len];
                std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
                self.recv_bytes = unparsed_buffer.len;
                // If found, count that as a parsed request
                // and have the writer write the static HTTP response (eventually).
                self.send_bytes += HTTP_RESPONSE.len;
                self.writerNotify(loop);
                continue;
            }
            // A complete request wasn't parsed yet.
            // Try to read more data into the buffer and try again.
            const readable_buffer = self.recv_buffer[self.recv_bytes..];
            if (readable_buffer.len == 0)
                return error.HttpRequestTooLarge;
            // If the read would normally block,
            // then we have to wait for the socket to be readable in the future to try again.
            const bytes_read = blk: {
                while (true) {
                    break :blk std.os.read(self.fd, readable_buffer) catch |err| switch (err) {
                        error.WouldBlock => {
                            self.event.waitFor(.read);
                            continue;
                        },
                        else => |e| return e,
                    };
                }
            };
            // 0 bytes read indicates that the socket can no longer read any data.
            self.recv_bytes += bytes_read;
            if (bytes_read == 0)
                return error.EndOfStream;
        }
    }

    fn writer(self: *Client, loop: *Loop) !void {
        suspend {
            var task = Loop.Task{ .frame = @frame(), .name = "Client.write(begin)" };
            loop.schedule(&task, .lifo);
        }
        const NUM_RESPONSE_CHUNKS = 128;
        const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
        while (true) {
            // Wait for the reader to hand the writer some data.
            if (self.send_bytes == 0) {
                try self.writerWait();
                continue;
            }
            // Compute the slice of the response chunk to send, starting at the
            // offset (send_partial) where the previous partial write stopped.
            const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
            const iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
            const writable_buffer = iov_base[0..iov_len];
            // Perform the actual write.
            // Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
            const bytes_written = blk: {
                while (true) {
                    break :blk std.os.sendto(
                        self.fd,
                        writable_buffer,
                        std.os.MSG_NOSIGNAL,
                        null,
                        @as(std.os.socklen_t, 0),
                    ) catch |err| switch (err) {
                        error.WouldBlock => {
                            self.event.waitFor(.write);
                            continue;
                        },
                        else => |e| return e,
                    };
                }
            };
            self.send_bytes -= bytes_written;
            // Advance the offset within the current response; a write may stop mid-response.
            self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
        }
    }

    // writer_notify holds either a *Task pointer (a suspended writer frame)
    // or one of the sentinels: 1 = data ready, 2 = reader closed.
    fn writerClose(self: *Client, loop: *Loop) void {
        const task = @intToPtr(?*Loop.Task, self.writer_notify & ~@as(usize, 0b11));
        self.writer_notify = 2;
        loop.schedule(task orelse return, .lifo);
    }

    fn writerNotify(self: *Client, loop: *Loop) void {
        const task = @intToPtr(?*Loop.Task, self.writer_notify & ~@as(usize, 0b11));
        self.writer_notify = 1;
        loop.schedule(task orelse return, .lifo);
    }

    fn writerWait(self: *Client) !void {
        while (true) {
            if (self.writer_notify == 2)
                return error.ReaderClosed;
            if (self.writer_notify == 1) {
                self.writer_notify = 0;
                return;
            }
            std.debug.assert(self.writer_notify == 0);
            var task = Loop.Task{ .frame = @frame(), .name = "Client.writer(wait-for-notify)" };
            suspend self.writer_notify = @ptrToInt(&task);
        }
    }
};

const Loop = struct {
    io_driver: IoDriver,
    run_queue: struct {
        head: ?*Task = null,
        tail: *Task = undefined,
    } = .{},

    fn init(self: *Loop) !void {
        self.* = .{ .io_driver = undefined };
        try self.io_driver.init();
    }

    fn deinit(self: *Loop) void {
        self.io_driver.deinit();
    }

    fn run(self: *Loop) void {
        while (self.poll()) |task| {
            // std.debug.warn("running: {s}\n", .{task.name});
            resume task.frame;
        }
    }

    fn schedule(self: *Loop, task: *Task, order: enum{fifo, lifo}) void {
        task.next = null;
        // std.debug.warn("schedule {s} {}\n", .{task.name, order});
        if (self.run_queue.head) |head| {
            switch (order) {
                .lifo => {
                    task.next = head;
                    self.run_queue.head = task;
                },
                .fifo => {
                    self.run_queue.tail.next = task;
                    self.run_queue.tail = task;
                },
            }
        } else {
            self.run_queue.head = task;
            self.run_queue.tail = task;
        }
    }

    fn poll(self: *Loop) ?*Task {
        while (true) {
            if (self.pollQueue()) |task|
                return task;
            if (!self.io_driver.poll())
                return null;
        }
    }

    fn pollQueue(self: *Loop) ?*Task {
        const task = self.run_queue.head orelse return null;
        self.run_queue.head = task.next;
        return task;
    }

    const Task = struct {
        next: ?*Task = undefined,
        name: []const u8 = "default",
        frame: anyframe,
    };

    const IoDriver = struct {
        epoll_fd: std.os.fd_t,
        pending: usize,

        fn init(self: *IoDriver) !void {
            self.* = .{
                .epoll_fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC),
                .pending = 0,
            };
        }

        fn deinit(self: *IoDriver) void {
            std.os.close(self.epoll_fd);
        }

        fn poll(self: *IoDriver) bool {
            if (self.pending == 0)
                return false;
            var events: [128]std.os.epoll_event = undefined;
            const found = std.os.epoll_wait(self.epoll_fd, &events, -1);
            const loop = @fieldParentPtr(Loop, "io_driver", self);
            for (events[0..found]) |ev| {
                const event = @intToPtr(?*Event, ev.data.ptr) orelse continue;
                if (ev.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLOUT) != 0) {
                    if (event.notify(.write)) |task|
                        loop.schedule(task, .lifo);
                }
                if (ev.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLIN | std.os.EPOLLRDHUP) != 0) {
                    if (event.notify(.read)) |task|
                        loop.schedule(task, .fifo);
                }
            }
            return true;
        }

        const Event = struct {
            fd: std.os.fd_t,
            reader: usize = 0,
            writer: usize = 0,
            driver: *IoDriver,

            fn init(self: *Event, fd: std.os.fd_t, driver: *IoDriver) !void {
                self.* = .{
                    .fd = fd,
                    .driver = driver,
                };
                var event = std.os.epoll_event{
                    .events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLET | std.os.EPOLLRDHUP,
                    .data = .{ .ptr = @ptrToInt(self) },
                };
                try std.os.epoll_ctl(
                    driver.epoll_fd,
                    std.os.EPOLL_CTL_ADD,
                    fd,
                    &event,
                );
            }

            fn deinit(self: *Event) void {
                defer self.* = undefined;
                std.os.epoll_ctl(
                    self.driver.epoll_fd,
                    std.os.EPOLL_CTL_DEL,
                    self.fd,
                    null,
                ) catch unreachable;
            }

            // reader/writer hold either a *Task pointer (a suspended frame)
            // or the sentinel 1, meaning a notification arrived before anyone waited.
            fn waitFor(self: *Event, kind: enum{read, write}) void {
                const queue = switch (kind) {
                    .read => &self.reader,
                    .write => &self.writer,
                };
                if (queue.* == 1) {
                    queue.* = 0;
                } else {
                    std.debug.assert(queue.* == 0);
                    self.driver.pending += 1;
                    var task = Task{ .frame = @frame(), .name = if (kind == .read) "Event::Read" else "Event::Write" };
                    suspend queue.* = @ptrToInt(&task);
                }
            }

            fn notify(self: *Event, kind: enum{read, write}) ?*Task {
                const queue = switch (kind) {
                    .read => &self.reader,
                    .write => &self.writer,
                };
                if (@intToPtr(?*Task, queue.* & ~@as(usize, 1))) |task| {
                    self.driver.pending -= 1;
                    queue.* = 0;
                    return task;
                } else {
                    queue.* = 1;
                    return null;
                }
            }
        };
    };
};
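Variant 5: built on a thread pool with integrated I/O. The imported ThreadPoolIO.zig is not part of this gist; judging from the call sites below, it exposes an IoRunnable (a Runnable plus is_readable/is_writable/is_closable flags filled in before the callback runs) and a waitFor(fd, io_runnable) that registers or re-arms interest in the fd.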
const std = @import("std");
const ThreadPool = @import("./ThreadPoolIO.zig");

pub fn main() !void {
    var pool = try ThreadPool.init(.{ .max_threads = 6 });
    defer pool.deinit();

    var server: Server = undefined;
    try server.init(&pool, 12345);

    // Block the main thread forever; the pool's threads do all the work.
    var event = std.StaticResetEvent{};
    event.wait();
}

const Poller = struct {
    fd: std.os.fd_t,
    gpa: std.heap.GeneralPurposeAllocator(.{}) = .{},

    fn init(self: *Poller) !void {
        const fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
        self.* = .{ .fd = fd };
    }

    fn deinit(self: *Poller) void {
        std.os.close(self.fd);
        _ = self.gpa.deinit();
    }

    fn getAllocator(self: *Poller) *std.mem.Allocator {
        return &self.gpa.allocator;
    }
};

const Server = struct {
    fd: std.os.socket_t,
    pool: *ThreadPool,
    io_runnable: ThreadPool.IoRunnable,
    gpa: std.heap.GeneralPurposeAllocator(.{}),
    port: u16,
    start_runnable: ThreadPool.Runnable,

    fn init(self: *Server, pool: *ThreadPool, comptime port: u16) !void {
        self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
        errdefer std.os.close(self.fd);

        var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
        try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
        try std.os.bind(self.fd, &addr.any, addr.getOsSockLen());
        try std.os.listen(self.fd, 128);

        self.gpa = .{};
        self.pool = pool;
        self.io_runnable = ThreadPool.IoRunnable{
            .is_readable = true,
            .runnable = .{ .runFn = Server.run },
        };
        try pool.waitFor(self.fd, &self.io_runnable);

        self.port = port;
        self.start_runnable = .{ .runFn = Server.start };
        pool.schedule(.{}, &self.start_runnable);
    }

    fn start(runnable: *ThreadPool.Runnable) void {
        const self = @fieldParentPtr(Server, "start_runnable", runnable);
        std.debug.warn("Listening on :{}\n", .{self.port});
    }

    fn run(runnable: *ThreadPool.Runnable) void {
        const io_runnable = @fieldParentPtr(ThreadPool.IoRunnable, "runnable", runnable);
        const self = @fieldParentPtr(Server, "io_runnable", io_runnable);
        self.accept() catch |err| {
            std.os.close(self.fd);
            std.debug.warn("Server shutdown\n", .{});
        };
    }

    fn accept(self: *Server) !void {
        if (self.io_runnable.is_closable)
            return error.ServerShutdown;
        if (!self.io_runnable.is_readable)
            unreachable;
        while (true) {
            const client_fd = std.os.accept(self.fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
                error.WouldBlock => break,
                else => |e| return e,
            };
            Client.init(client_fd, self) catch |err| {
                std.os.close(client_fd);
                std.debug.warn("Failed to spawn client: {}\n", .{err});
            };
        }
        try self.pool.waitFor(self.fd, &self.io_runnable);
    }
};

const Client = struct {
    fd: std.os.socket_t,
    server: *Server,
    send_bytes: usize = 0,
    send_partial: usize = 0,
    recv_bytes: usize = 0,
    recv_buffer: [4096]u8 = undefined,
    io_runnable: ThreadPool.IoRunnable,

    const HTTP_CLRF = "\r\n\r\n";
    const HTTP_RESPONSE =
        "HTTP/1.1 200 Ok\r\n" ++
        "Content-Length: 10\r\n" ++
        "Content-Type: text/plain; charset=utf8\r\n" ++
        "Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
        "Server: fasthttp\r\n" ++
        "\r\n" ++
        "HelloWorld";

    fn init(fd: std.os.socket_t, server: *Server) !void {
        const allocator = &server.gpa.allocator;
        const self = try allocator.create(Client);
        errdefer allocator.destroy(self);

        const SOL_TCP = 6;
        const TCP_NODELAY = 1;
        try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));

        self.* = .{
            .fd = fd,
            .server = server,
            .io_runnable = .{
                .is_readable = true,
                .runnable = .{ .runFn = Client.run },
            },
        };
        try self.server.pool.waitFor(self.fd, &self.io_runnable);
    }

    fn run(runnable: *ThreadPool.Runnable) void {
        const io_runnable = @fieldParentPtr(ThreadPool.IoRunnable, "runnable", runnable);
        const self = @fieldParentPtr(Client, "io_runnable", io_runnable);
        self.process() catch {
            std.os.close(self.fd);
            self.server.gpa.allocator.destroy(self);
        };
    }

    fn process(self: *Client) !void {
        if (self.io_runnable.is_closable)
            return error.Closed;

        var written = false;
        if (self.io_runnable.is_writable and (self.send_bytes > 0)) {
            written = true;
            try self.processWrite();
        }
        if (self.io_runnable.is_readable)
            try self.processRead();
        if (!written and self.io_runnable.is_writable and (self.send_bytes > 0))
            try self.processWrite();

        self.io_runnable.is_readable = true;
        self.io_runnable.is_writable = self.send_bytes > 0;
        try self.server.pool.waitFor(self.fd, &self.io_runnable);
    }

    fn processRead(self: *Client) !void {
        while (true) {
            const request_buffer = self.recv_buffer[0..self.recv_bytes];
            if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
                const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len];
                std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
                self.recv_bytes = unparsed_buffer.len;
                self.send_bytes += HTTP_RESPONSE.len;
                continue;
            }
            const readable_buffer = self.recv_buffer[self.recv_bytes..];
            if (readable_buffer.len == 0)
                return error.HttpRequestTooLarge;
            const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            self.recv_bytes += bytes_read;
            if (bytes_read == 0)
                return error.EndOfStream;
        }
    }

    fn processWrite(self: *Client) !void {
        const NUM_RESPONSE_CHUNKS = 128;
        const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
        while (self.send_bytes > 0) {
            // Compute the slice of the response chunk to send, starting at the
            // offset (send_partial) where the previous partial write stopped.
            const send_bytes = self.send_bytes;
            if (self.send_partial > RESPONSE_CHUNK.len)
                std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len});
            const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
            const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial);
            const writable_buffer = iov_base[0..iov_len];
            // Perform the actual write.
            // Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
            const bytes_written = std.os.sendto(
                self.fd,
                writable_buffer,
                std.os.MSG_NOSIGNAL,
                null,
                @as(std.os.socklen_t, 0),
            ) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            // Advance the offset within the current response; a write may stop mid-response.
            self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
            self.send_bytes -= bytes_written;
        }
    }
};
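Variant 6: a plain work-stealing thread pool (ThreadPool.zig, included below) combined with a dedicated epoll thread. The main thread polls with EPOLLONESHOT, stamps each ready Socket with its event flags, and hands the whole set to the pool as a single Batch; handlers re-register interest via EPOLL_CTL_MOD when they finish.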
const std = @import("std");
const ThreadPool = @import("./ThreadPool.zig");

pub fn main() !void {
    var poller: Poller = undefined;
    try poller.init();
    defer poller.deinit();

    var server: Server = undefined;
    try server.init(&poller, 12345);

    var pool = ThreadPool.init(.{});
    defer pool.deinit();

    while (true) {
        var events: [1024]std.os.epoll_event = undefined;
        const found = std.os.epoll_wait(poller.fd, &events, -1);

        // Collect all ready sockets into one batch and hand it to the pool.
        var batch = ThreadPool.Batch{};
        defer if (!batch.isEmpty())
            pool.schedule(.{}, batch);

        for (events[0..found]) |event| {
            const socket = @intToPtr(*Socket, event.data.ptr);
            socket.events = event.events;
            batch.push(&socket.runnable);
        }
    }
}

const Poller = struct {
    fd: std.os.fd_t,
    gpa: std.heap.GeneralPurposeAllocator(.{}) = .{},

    fn init(self: *Poller) !void {
        const fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
        self.* = .{ .fd = fd };
    }

    fn deinit(self: *Poller) void {
        std.os.close(self.fd);
        _ = self.gpa.deinit();
    }

    fn getAllocator(self: *Poller) *std.mem.Allocator {
        return &self.gpa.allocator;
    }
};

const Socket = struct {
    fd: std.os.socket_t,
    poller: *Poller,
    events: u32,
    runnable: ThreadPool.Runnable,

    fn init(self: *Socket, fd: std.os.socket_t, poller: *Poller, comptime Container: type) !void {
        const Callback = struct {
            fn runFn(runnable: *ThreadPool.Runnable) void {
                const socket = @fieldParentPtr(Socket, "runnable", runnable);
                const container = @fieldParentPtr(Container, "socket", socket);
                container.run() catch {
                    std.os.close(socket.fd);
                    container.deinit();
                };
            }
        };

        self.* = .{
            .fd = fd,
            .poller = poller,
            .events = undefined,
            .runnable = .{ .runFn = Callback.runFn },
        };
        try std.os.epoll_ctl(poller.fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{
            .events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP,
            .data = .{ .ptr = @ptrToInt(self) },
        });
    }

    fn register(self: *Socket, events: u32) !void {
        try std.os.epoll_ctl(self.poller.fd, std.os.EPOLL_CTL_MOD, self.fd, &std.os.epoll_event{
            .events = events | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP,
            .data = .{ .ptr = @ptrToInt(self) },
        });
    }
};

const Server = struct {
    socket: Socket,

    fn init(self: *Server, poller: *Poller, comptime port: u16) !void {
        const fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
        errdefer std.os.close(fd);

        var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
        try std.os.setsockopt(fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
        try std.os.bind(fd, &addr.any, addr.getOsSockLen());
        try std.os.listen(fd, 128);

        try self.socket.init(fd, poller, Server);
        std.debug.warn("Listening on :{}\n", .{port});
    }

    fn deinit(self: *Server) void {
        std.debug.warn("Server shutdown\n", .{});
    }

    fn run(self: *Server) !void {
        const events = self.socket.events;
        const server_fd = self.socket.fd;
        if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
            return error.ServerShutdown;
        if (events & std.os.EPOLLIN == 0)
            unreachable;
        while (true) {
            const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
                error.WouldBlock => return self.socket.register(std.os.EPOLLIN),
                else => |e| return e,
            };
            Client.init(client_fd, self.socket.poller) catch |err| {
                std.os.close(client_fd);
                std.debug.warn("Failed to spawn client: {}\n", .{err});
            };
        }
    }
};

const Client = struct {
    socket: Socket,
    send_bytes: usize = 0,
    send_partial: usize = 0,
    recv_bytes: usize = 0,
    recv_buffer: [4096]u8 = undefined,

    const HTTP_CLRF = "\r\n\r\n";
    const HTTP_RESPONSE =
        "HTTP/1.1 200 Ok\r\n" ++
        "Content-Length: 10\r\n" ++
        "Content-Type: text/plain; charset=utf8\r\n" ++
        "Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
        "Server: fasthttp\r\n" ++
        "\r\n" ++
        "HelloWorld";

    fn init(fd: std.os.socket_t, poller: *Poller) !void {
        const allocator = poller.getAllocator();
        const self = try allocator.create(Client);
        errdefer allocator.destroy(self);

        const SOL_TCP = 6;
        const TCP_NODELAY = 1;
        try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));

        self.* = .{ .socket = undefined };
        try self.socket.init(fd, poller, Client);
    }

    fn deinit(self: *Client) void {
        const allocator = self.socket.poller.getAllocator();
        allocator.destroy(self);
    }

    fn run(self: *Client) !void {
        var events = self.socket.events;
        if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
            return error.Closed;

        var written = false;
        if ((events & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) {
            written = true;
            try self.processWrite();
        }
        if (events & std.os.EPOLLIN != 0)
            try self.processRead();
        if (!written and (events & std.os.EPOLLOUT != 0) and (self.send_bytes > 0))
            try self.processWrite();

        // Re-arm the oneshot registration, asking for writability only if
        // there is still response data to send.
        events = std.os.EPOLLIN;
        if (self.send_bytes > 0)
            events |= std.os.EPOLLOUT;
        try self.socket.register(events);
    }

    fn processRead(self: *Client) !void {
        while (true) {
            const request_buffer = self.recv_buffer[0..self.recv_bytes];
            if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
                const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len];
                std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
                self.recv_bytes = unparsed_buffer.len;
                self.send_bytes += HTTP_RESPONSE.len;
                continue;
            }
            const readable_buffer = self.recv_buffer[self.recv_bytes..];
            if (readable_buffer.len == 0)
                return error.HttpRequestTooLarge;
            const bytes_read = std.os.read(self.socket.fd, readable_buffer) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            self.recv_bytes += bytes_read;
            if (bytes_read == 0)
                return error.EndOfStream;
        }
    }

    fn processWrite(self: *Client) !void {
        const NUM_RESPONSE_CHUNKS = 128;
        const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
        while (self.send_bytes > 0) {
            // Compute the slice of the response chunk to send, starting at the
            // offset (send_partial) where the previous partial write stopped.
            const send_bytes = self.send_bytes;
            if (self.send_partial > RESPONSE_CHUNK.len)
                std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len});
            const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
            const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial);
            const writable_buffer = iov_base[0..iov_len];
            // Perform the actual write.
            // Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
            const bytes_written = std.os.sendto(
                self.socket.fd,
                writable_buffer,
                std.os.MSG_NOSIGNAL,
                null,
                @as(std.os.socklen_t, 0),
            ) catch |err| switch (err) {
                error.WouldBlock => return,
                else => |e| return e,
            };
            // Advance the offset within the current response; a write may stop mid-response.
            self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
            self.send_bytes -= bytes_written;
        }
    }
};
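ThreadPool.zig: the work-stealing pool used by the previous variant. The pool's state lives in a single atomic u32 (a 4-bit state plus 14-bit idle and spawned counts, see Counter below), so waking, sleeping, and shutdown decisions are CAS loops; each worker owns a bounded run queue, a LIFO slot, and an overflow queue, and steals from its peers when idle. A minimal, hypothetical usage sketch, not part of the gist and assuming the same Zig version as the files above:

    const std = @import("std");
    const ThreadPool = @import("./ThreadPool.zig");

    const PrintTask = struct {
        runnable: ThreadPool.Runnable = .{ .runFn = run },

        fn run(runnable: *ThreadPool.Runnable) void {
            _ = @fieldParentPtr(PrintTask, "runnable", runnable);
            std.debug.warn("task ran\n", .{});
        }
    };

    pub fn main() void {
        var pool = ThreadPool.init(.{ .max_threads = 2 });
        defer pool.deinit(); // shutdown() + wait; in this toy the task races with shutdown
        var task = PrintTask{};
        pool.schedule(.{}, &task.runnable);
    }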
const std = @import("std"); | |
const system = switch (std.builtin.os.tag) { | |
.linux => std.os.linux, | |
else => std.os.system, | |
}; | |
const ThreadPool = @This(); | |
max_threads: u16, | |
counter: u32 = 0, | |
spawned_queue: ?*Worker = null, | |
run_queue: UnboundedQueue = .{}, | |
idle_semaphore: Semaphore = Semaphore.init(0), | |
shutdown_event: Event = .{}, | |
pub const InitConfig = struct { | |
max_threads: ?u16 = null, | |
}; | |
pub fn init(config: InitConfig) ThreadPool { | |
return .{ | |
.max_threads = std.math.min( | |
std.math.maxInt(u14), | |
std.math.max(1, config.max_threads orelse blk: { | |
break :blk @intCast(u16, std.Thread.cpuCount() catch 1); | |
}), | |
), | |
}; | |
} | |
pub fn deinit(self: *ThreadPool) void { | |
defer self.* = undefined; | |
self.shutdown(); | |
self.shutdown_event.wait(); | |
while (self.spawned_queue) |worker| { | |
self.spawned_queue = worker.spawned_next; | |
const thread = worker.thread; | |
worker.shutdown_event.notify(); | |
thread.wait(); | |
} | |
} | |
pub const ScheduleHints = struct { | |
priority: Priority = .Normal, | |
pub const Priority = enum { | |
High, | |
Normal, | |
Low, | |
}; | |
}; | |
pub fn schedule(self: *ThreadPool, hints: ScheduleHints, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (Worker.current) |worker| { | |
worker.push(hints, batch); | |
} else { | |
self.run_queue.push(batch); | |
} | |
_ = self.tryNotifyWith(false); | |
} | |
pub const SpawnConfig = struct { | |
allocator: *std.mem.Allocator, | |
hints: ScheduleHints = .{}, | |
}; | |
pub fn spawn(self: *ThreadPool, config: SpawnConfig, comptime func: anytype, args: anytype) !void { | |
const Args = @TypeOf(args); | |
const is_async = @typeInfo(@TypeOf(func)).Fn.calling_convention == .Async; | |
const Closure = struct { | |
func_args: Args, | |
allocator: *std.mem.Allocator, | |
runnable: Runnable = .{ .runFn = runFn }, | |
frame: if (is_async) @Frame(runAsyncFn) else void = undefined, | |
fn runFn(runnable: *Runnable) void { | |
const closure = @fieldParentPtr(@This(), "runnable", runnable); | |
if (is_async) { | |
closure.frame = async closure.runAsyncFn(); | |
} else { | |
const result = @call(.{}, func, closure.func_args); | |
closure.allocator.destroy(closure); | |
} | |
} | |
fn runAsyncFn(closure: *@This()) void { | |
const result = @call(.{}, func, closure.func_args); | |
suspend closure.allocator.destroy(closure); | |
} | |
}; | |
const allocator = config.allocator; | |
const closure = try allocator.create(Closure); | |
errdefer allocator.destroy(closure); | |
closure.* = .{ | |
.func_args = args, | |
.allocator = allocator, | |
}; | |
const hints = config.hints; | |
self.schedule(hints, &closure.runnable); | |
} | |
const Counter = struct { | |
state: State = .pending, | |
idle: u16 = 0, | |
spawned: u16 = 0, | |
const State = enum(u4) { | |
pending = 0, | |
notified, | |
waking, | |
waker_notified, | |
shutdown, | |
}; | |
fn pack(self: Counter) u32 { | |
return (@as(u32, @as(u4, @enumToInt(self.state))) | | |
(@as(u32, @intCast(u14, self.idle)) << 4) | | |
(@as(u32, @intCast(u14, self.spawned)) << (4 + 14))); | |
} | |
fn unpack(value: u32) Counter { | |
return Counter{ | |
.state = @intToEnum(State, @truncate(u4, value)), | |
.idle = @as(u16, @truncate(u14, value >> 4)), | |
.spawned = @as(u16, @truncate(u14, value >> (4 + 14))), | |
}; | |
} | |
}; | |
fn tryNotifyWith(self: *ThreadPool, is_caller_waking: bool) bool { | |
var spawned = false; | |
var remaining_attempts: u8 = 5; | |
var is_waking = is_caller_waking; | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) { | |
if (spawned) | |
self.releaseWorker(); | |
return false; | |
} | |
const has_pending = (counter.idle > 0) or (counter.spawned < self.max_threads); | |
const can_wake = (is_waking and remaining_attempts > 0) or (!is_waking and counter.state == .pending); | |
if (has_pending and can_wake) { | |
var new_counter = counter; | |
new_counter.state = .waking; | |
if (counter.idle > 0) { | |
new_counter.idle -= 1; | |
} else if (!spawned) { | |
new_counter.spawned += 1; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
is_waking = true; | |
if (counter.idle > 0) { | |
self.idle_semaphore.post(1) catch unreachable; | |
return true; | |
} | |
spawned = true; | |
if (Worker.spawn(self)) | |
return true; | |
remaining_attempts -= 1; | |
continue; | |
} | |
var new_counter = counter; | |
if (is_waking) { | |
new_counter.state = if (can_wake) .pending else .notified; | |
if (spawned) | |
new_counter.spawned -= 1; | |
} else if (counter.state == .waking) { | |
new_counter.state = .waker_notified; | |
} else if (counter.state == .pending) { | |
new_counter.state = .notified; | |
} else { | |
return false; | |
} | |
_ = @cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
) orelse return true; | |
} | |
} | |
const Wait = enum { | |
resumed, | |
notified, | |
shutdown, | |
}; | |
fn tryWaitWith(self: *ThreadPool, is_caller_waking: bool) Wait { | |
var is_waking = is_caller_waking; | |
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
while (true) { | |
if (counter.state == .shutdown) { | |
self.releaseWorker(); | |
return .shutdown; | |
} | |
const is_notified = switch (counter.state) { | |
.waker_notified => is_waking, | |
.notified => true, | |
else => false, | |
}; | |
var new_counter = counter; | |
if (is_notified) { | |
new_counter.state = if (is_waking) .waking else .pending; | |
} else { | |
new_counter.idle += 1; | |
if (is_waking) | |
new_counter.state = .pending; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
)) |updated| { | |
counter = Counter.unpack(updated); | |
continue; | |
} | |
if (is_notified and is_waking) | |
return .notified; | |
if (is_notified) | |
return .resumed; | |
self.idle_semaphore.wait(1); | |
return .notified; | |
} | |
} | |
fn releaseWorker(self: *ThreadPool) void { | |
const counter_spawned = Counter{ .spawned = 1 }; | |
const counter_value = @atomicRmw(u32, &self.counter, .Sub, counter_spawned.pack(), .AcqRel); | |
const counter = Counter.unpack(counter_value); | |
if (counter.state != .shutdown) | |
std.debug.panic("ThreadPool.releaseWorker() when not shutdown: {}", .{counter}); | |
if (counter.spawned == 1) | |
self.shutdown_event.notify(); | |
} | |
pub fn shutdown(self: *ThreadPool) void { | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) | |
return; | |
var new_counter = counter; | |
new_counter.state = .shutdown; | |
new_counter.idle = 0; | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
self.idle_semaphore.post(self.max_threads) catch unreachable; | |
return; | |
} | |
} | |
const Worker = struct { | |
pool: *ThreadPool, | |
thread: *std.Thread, | |
spawned_next: ?*Worker = null, | |
shutdown_event: Event = .{}, | |
run_queue: BoundedQueue = .{}, | |
run_queue_next: ?*Runnable = null, | |
run_queue_lifo: ?*Runnable = null, | |
run_queue_overflow: UnboundedQueue = .{}, | |
tick: usize = undefined, | |
is_waking: bool = true, | |
next_target: ?*Worker = null, | |
threadlocal var current: ?*Worker = null; | |
fn spawn(pool: *ThreadPool) bool { | |
const Spawner = struct { | |
thread: *std.Thread = undefined, | |
thread_pool: *ThreadPool, | |
data_put_event: Event = .{}, | |
data_get_event: Event = .{}, | |
fn entry(self: *@This()) void { | |
self.data_put_event.wait(); | |
const thread = self.thread; | |
const thread_pool = self.thread_pool; | |
self.data_get_event.notify(); | |
Worker.run(thread, thread_pool); | |
} | |
}; | |
var spawner = Spawner{ .thread_pool = pool }; | |
spawner.thread = std.Thread.spawn(&spawner, Spawner.entry) catch return false; | |
spawner.data_put_event.notify(); | |
spawner.data_get_event.wait(); | |
return true; | |
} | |
fn run(thread: *std.Thread, pool: *ThreadPool) void { | |
var self = Worker{ | |
.thread = thread, | |
.pool = pool, | |
}; | |
self.tick = @ptrToInt(&self); | |
current = &self; | |
defer current = null; | |
var spawned_queue = @atomicLoad(?*Worker, &pool.spawned_queue, .Monotonic); | |
while (true) { | |
self.spawned_next = spawned_queue; | |
spawned_queue = @cmpxchgWeak( | |
?*Worker, | |
&pool.spawned_queue, | |
spawned_queue, | |
&self, | |
.Release, | |
.Monotonic, | |
) orelse break; | |
} | |
while (true) { | |
if (self.pop()) |runnable| { | |
if (self.is_waking) { | |
self.is_waking = false; | |
_ = pool.tryNotifyWith(true); | |
} | |
self.tick +%= 1; | |
runnable.run(); | |
continue; | |
} | |
self.is_waking = switch (pool.tryWaitWith(self.is_waking)) { | |
.resumed => false, | |
.notified => true, | |
.shutdown => { | |
self.shutdown_event.wait(); | |
break; | |
}, | |
}; | |
} | |
} | |
fn push(self: *Worker, hints: ScheduleHints, batchable: anytype) void { | |
var batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (hints.priority == .High) { | |
const new_lifo = batch.pop(); | |
if (@atomicLoad(?*Runnable, &self.run_queue_lifo, .Monotonic) == null) { | |
@atomicStore(?*Runnable, &self.run_queue_lifo, new_lifo, .Release); | |
} else if (@atomicRmw(?*Runnable, &self.run_queue_lifo, .Xchg, new_lifo, .AcqRel)) |old_lifo| { | |
batch.pushFront(old_lifo); | |
} | |
} | |
if (hints.priority == .Low) { | |
if (self.run_queue_next) |old_next| | |
batch.pushFront(old_next); | |
self.run_queue_next = null; | |
self.run_queue_next = self.pop() orelse batch.pop(); | |
} | |
if (self.run_queue.push(batch)) |overflowed| | |
self.run_queue_overflow.push(overflowed); | |
} | |
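// Popping is tick-staggered: on prime-numbered intervals, check a non-local | |
// source first (other workers, the pool queue, the overflow queue, the LIFO | |
// slot) so no queue is starved while the local bounded queue stays hot. | |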
fn pop(self: *Worker) ?*Runnable { | |
if (self.tick % 127 == 0) { | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
} | |
if (self.tick % 61 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 31 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 13 == 0) { | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
} | |
if (self.run_queue.pop()) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
return null; | |
} | |
fn popAndStealLifo(self: *Worker, target: *Worker) ?*Runnable { | |
var run_queue_lifo = @atomicLoad(?*Runnable, &target.run_queue_lifo, .Monotonic); | |
while (true) { | |
if (run_queue_lifo == null) | |
return null; | |
run_queue_lifo = @cmpxchgWeak( | |
?*Runnable, | |
&target.run_queue_lifo, | |
run_queue_lifo, | |
null, | |
.Acquire, | |
.Monotonic, | |
) orelse return run_queue_lifo; | |
} | |
} | |
fn popAndStealFromOthers(self: *Worker) ?*Runnable { | |
var num_workers = blk: { | |
const counter_value = @atomicLoad(u32, &self.pool.counter, .Monotonic); | |
const counter = Counter.unpack(counter_value); | |
break :blk counter.spawned; | |
}; | |
while (num_workers > 0) : (num_workers -= 1) { | |
const target = self.next_target orelse blk: { | |
break :blk @atomicLoad(?*Worker, &self.pool.spawned_queue, .Acquire) orelse { | |
std.debug.panic("Worker observed empty spawned queue when work-stealing", .{}); | |
}; | |
}; | |
self.next_target = target.spawned_next; | |
if (target == self) | |
continue; | |
if (self.run_queue.popAndStealBounded(&target.run_queue)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&target.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(target)) |runnable| | |
return runnable; | |
} | |
return null; | |
} | |
}; | |
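// Mutex-protected FIFO of Runnables. shared_size is mirrored atomically so | |
// would-be consumers can skip taking the lock when the queue is empty. | |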
const UnboundedQueue = struct { | |
lock: Mutex = .{}, | |
batch: Batch = .{}, | |
shared_size: usize = 0, | |
fn push(self: *UnboundedQueue, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
const held = self.lock.acquire(); | |
defer held.release(); | |
self.batch.push(batch); | |
var shared_size = self.shared_size; | |
shared_size += batch.size; | |
@atomicStore(usize, &self.shared_size, shared_size, .Release); | |
} | |
fn tryAcquireConsumer(self: *UnboundedQueue) ?Consumer { | |
var shared_size = @atomicLoad(usize, &self.shared_size, .Acquire); | |
if (shared_size == 0) | |
return null; | |
const held = self.lock.acquire(); | |
shared_size = self.shared_size; | |
if (shared_size == 0) { | |
held.release(); | |
return null; | |
} | |
return Consumer{ | |
.held = held, | |
.queue = self, | |
.size = shared_size, | |
}; | |
} | |
const Consumer = struct { | |
held: Mutex.Held, | |
queue: *UnboundedQueue, | |
size: usize, | |
fn release(self: Consumer) void { | |
@atomicStore(usize, &self.queue.shared_size, self.size, .Release); | |
self.held.release(); | |
} | |
fn pop(self: *Consumer) ?*Runnable { | |
const runnable = self.queue.batch.pop() orelse return null; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
}; | |
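// Fixed-size ring buffer owned by one worker. head/tail are free-running | |
// counters (wrapped on index); only the owner advances tail, while head is | |
// advanced by CAS since stealing workers may pop from it concurrently. | |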
const BoundedQueue = struct { | |
head: usize = 0, | |
tail: usize = 0, | |
buffer: [256]*Runnable = undefined, | |
fn push(self: *BoundedQueue, batchable: anytype) ?Batch { | |
var batch = Batch.from(batchable); | |
while (true) : (yieldCpu()) { | |
if (batch.isEmpty()) | |
return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
if (size < self.buffer.len) { | |
while (size < self.buffer.len) { | |
const runnable = batch.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
continue; | |
} | |
var migrate = self.buffer.len / 2; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% migrate, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
var overflowed = Batch{}; | |
while (migrate > 0) : (migrate -= 1) { | |
const runnable = self.buffer[head % self.buffer.len]; | |
overflowed.push(runnable); | |
head +%= 1; | |
} | |
overflowed.push(batch); | |
return overflowed; | |
} | |
} | |
fn pop(self: *BoundedQueue) ?*Runnable { | |
while (true) : (yieldCpu()) { | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size == 0) | |
return null; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% 1, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
const runnable = self.buffer[head % self.buffer.len]; | |
return runnable; | |
} | |
} | |
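// Steal roughly half of the target's items with a single CAS on its head, | |
// copying them into our own buffer before publishing the new tail. | |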
fn popAndStealBounded(self: *BoundedQueue, target: *BoundedQueue) ?*Runnable { | |
if (target == self) | |
return self.pop(); | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size != 0) | |
return self.pop(); | |
while (true) : (yieldThread()) { | |
const target_head = @atomicLoad(usize, &target.head, .Acquire); | |
const target_tail = @atomicLoad(usize, &target.tail, .Acquire); | |
const target_size = target_tail -% target_head; | |
var steal = target_size - (target_size / 2); | |
if (steal == 0) | |
return null; | |
if (steal > target.buffer.len / 2) | |
continue; | |
const first_runnable = @atomicLoad(*Runnable, &target.buffer[target_head % target.buffer.len], .Unordered); | |
var new_target_head = target_head +% 1; | |
var new_tail = tail; | |
steal -= 1; | |
while (steal > 0) : (steal -= 1) { | |
const runnable = @atomicLoad(*Runnable, &target.buffer[new_target_head % target.buffer.len], .Unordered); | |
new_target_head +%= 1; | |
@atomicStore(*Runnable, &self.buffer[new_tail % self.buffer.len], runnable, .Unordered); | |
new_tail +%= 1; | |
} | |
if (@cmpxchgWeak( | |
usize, | |
&target.head, | |
target_head, | |
new_target_head, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
@atomicStore(usize, &self.tail, new_tail, .Release); | |
return first_runnable; | |
} | |
} | |
fn popAndStealUnbounded(self: *BoundedQueue, target: *UnboundedQueue) ?*Runnable { | |
var consumer = target.tryAcquireConsumer() orelse return null; | |
defer consumer.release(); | |
const first_runnable = consumer.pop() orelse return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
while (size < self.buffer.len) { | |
const runnable = consumer.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
return first_runnable; | |
} | |
}; | |
pub const Runnable = struct { | |
next: ?*Runnable = null, | |
runFn: fn (*Runnable) void, | |
pub fn run(self: *Runnable) void { | |
return (self.runFn)(self); | |
} | |
}; | |
pub const Batch = struct { | |
head: ?*Runnable = null, | |
tail: *Runnable = undefined, | |
size: usize = 0, | |
pub fn from(batchable: anytype) Batch { | |
return switch (@TypeOf(batchable)) { | |
Batch => batchable, | |
?*Runnable => from(batchable orelse return Batch{}), | |
*Runnable => { | |
batchable.next = null; | |
return Batch{ | |
.head = batchable, | |
.tail = batchable, | |
.size = 1, | |
}; | |
}, | |
else => |typ| @compileError(@typeName(typ) ++ | |
" cannot be converted into " ++ | |
@typeName(Batch)), | |
}; | |
} | |
pub fn isEmpty(self: Batch) bool { | |
return self.head == null; | |
} | |
pub const push = pushBack; | |
pub fn pushBack(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
self.tail.next = batch.head; | |
self.tail = batch.tail; | |
self.size += batch.size; | |
} | |
} | |
pub fn pushFront(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
batch.tail.next = self.head; | |
self.head = batch.head; | |
self.size += batch.size; | |
} | |
} | |
pub const pop = popFront; | |
pub fn popFront(self: *Batch) ?*Runnable { | |
const runnable = self.head orelse return null; | |
self.head = runnable.next; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
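// Counting semaphore with a FIFO waiter list: waiters may request multiple | |
// permits and block on a per-waiter Event until post() can satisfy them. | |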
const Semaphore = struct { | |
lock: Mutex = .{}, | |
permits: usize = 0, | |
waiters: ?*Waiter = null, | |
const Waiter = struct { | |
next: ?*Waiter = null, | |
tail: *Waiter = undefined, | |
event: Event = .{}, | |
permits: usize, | |
}; | |
fn init(permits: usize) Semaphore { | |
return .{ .permits = permits }; | |
} | |
fn wait(self: *Semaphore, permits: usize) void { | |
const held = self.lock.acquire(); | |
if (self.permits >= permits) { | |
self.permits -= permits; | |
held.release(); | |
return; | |
} | |
var waiter = Waiter{ .permits = permits }; | |
if (self.waiters) |head| { | |
head.tail.next = &waiter; | |
head.tail = &waiter; | |
} else { | |
self.waiters = &waiter; | |
waiter.tail = &waiter; | |
} | |
held.release(); | |
waiter.event.wait(); | |
} | |
fn post(self: *Semaphore, permits: usize) error{Overflow}!void { | |
var waiters: ?*Waiter = null; | |
{ | |
const held = self.lock.acquire(); | |
defer held.release(); | |
if (@addWithOverflow(usize, self.permits, permits, &self.permits)) | |
return error.Overflow; | |
while (self.waiters) |waiter| { | |
if (waiter.permits > self.permits) | |
break; | |
self.waiters = waiter.next; | |
if (self.waiters) |new_waiter| | |
new_waiter.tail = waiter.tail; | |
self.permits -= waiter.permits; | |
waiter.next = waiters; | |
waiters = waiter; | |
} | |
} | |
while (waiters) |waiter| { | |
waiters = waiter.next; | |
waiter.event.notify(); | |
} | |
} | |
}; | |
const Mutex = if (std.builtin.os.tag == .windows) | |
struct { | |
srwlock: usize = 0, | |
pub fn acquire(self: *Mutex) Held { | |
AcquireSRWLockExclusive(&self.srwlock); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
ReleaseSRWLockExclusive(&self.mutex.srwlock); | |
} | |
}; | |
extern "kernel32" fn AcquireSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
extern "kernel32" fn ReleaseSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
lock: u32 = 0, | |
pub fn acquire(self: *Mutex) Held { | |
os_unfair_lock_lock(&self.lock); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
os_unfair_lock_unlock(&self.mutex.lock); | |
} | |
}; | |
extern "c" fn os_unfair_lock_lock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
extern "c" fn os_unfair_lock_unlock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: i32 = UNLOCKED, | |
const UNLOCKED: i32 = 0; | |
const LOCKED: i32 = 1; | |
const WAITING: i32 = 2; | |
pub fn acquire(self: *Mutex) Held { | |
const state = @atomicRmw(i32, &self.state, .Xchg, LOCKED, .Acquire); | |
if (state != UNLOCKED) | |
self.acquireSlow(state); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
switch (@atomicRmw(i32, &self.mutex.state, .Xchg, UNLOCKED, .Release)) { | |
UNLOCKED => unreachable, // unlocked an unlocked mutex | |
LOCKED => {}, | |
WAITING => self.mutex.releaseSlow(), | |
else => unreachable, | |
} | |
} | |
}; | |
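// Contended path: spin a few rounds (pause loops, then an OS yield) hoping | |
// the lock frees up, then mark the state WAITING and park on the futex so | |
// release() knows a wake is required. | |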
fn acquireSlow(self: *Mutex, current_state: i32) void { | |
@setCold(true); | |
var wait_state = current_state; | |
while (true) { | |
var spin: u8 = 0; | |
while (spin < 5) : (spin += 1) { | |
switch (@atomicLoad(i32, &self.state, .Monotonic)) { | |
UNLOCKED => _ = @cmpxchgWeak( | |
i32, | |
&self.state, | |
UNLOCKED, | |
wait_state, | |
.Acquire, | |
.Monotonic, | |
) orelse return, | |
LOCKED => {}, | |
WAITING => break, | |
else => unreachable, | |
} | |
if (spin < 4) { | |
var pause: u8 = 30; | |
while (pause > 0) : (pause -= 1) | |
yieldCpu(); | |
} else { | |
yieldThread(); | |
} | |
} | |
const state = @atomicRmw(i32, &self.state, .Xchg, WAITING, .Acquire); | |
if (state == UNLOCKED) | |
return; | |
wait_state = WAITING; | |
switch (system.getErrno(system.futex_wait( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
WAITING, | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
fn releaseSlow(self: *Mutex) void { | |
@setCold(true); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
locked: bool = false, | |
pub fn acquire(self: *Mutex) Held { | |
while (@atomicRmw(bool, &self.locked, .Xchg, true, .Acquire)) | |
yieldThread(); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
@atomicStore(bool, &self.mutex.locked, false, .Release); | |
} | |
}; | |
}; | |
const Event = if (std.builtin.os.tag == .windows) | |
struct { | |
key: u32 = undefined, | |
pub fn wait(self: *Event) void { | |
const status = NtWaitForKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
pub fn notify(self: *Event) void { | |
const status = NtReleaseKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
extern "NtDll" fn NtWaitForKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
extern "NtDll" fn NtReleaseKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
state: enum(u32) { | |
pending = 0, | |
notified, | |
} = .pending, | |
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
const status = __ulock_wait( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@enumToInt(@TypeOf(self.state).pending), | |
~@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.EINTR => {}, | |
else => unreachable, | |
} | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
const status = __ulock_wake( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.ENOENT => {}, | |
system.EINTR => continue, | |
else => unreachable, | |
} | |
} | |
return; | |
} | |
} | |
const ULF_NO_ERRNO = 0x1000000; | |
const UL_COMPARE_AND_WAIT = 0x1; | |
extern "c" fn __ulock_wait( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
timeout_us: u32, | |
) callconv(.C) c_int; | |
extern "c" fn __ulock_wake( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
) callconv(.C) c_int; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: enum(i32) { | |
pending, | |
notified, | |
} = .pending, | |
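// Auto-reset event: wait() parks on the futex until notify() flips the | |
// state to .notified, then resets it to .pending for reuse. | |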
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
switch (system.getErrno(system.futex_wait( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
@enumToInt(@TypeOf(self.state).pending), | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
notified: bool = false, | |
pub fn wait(self: *Event) void { | |
while (!@atomicLoad(bool, &self.notified, .Acquire)) | |
yieldThread(); | |
@atomicStore(bool, &self.notified, false, .Monotonic); | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(bool, &self.notified, true, .Release); | |
} | |
}; | |
const yieldThread = if (std.builtin.os.tag == .windows) | |
struct { | |
fn yield() void { | |
system.kernel32.Sleep(0); | |
} | |
}.yield | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
fn yield() void { | |
_ = thread_switch(MACH_PORT_NULL, SWITCH_OPTION_DEPRESS, 1); | |
} | |
const MACH_PORT_NULL = 0; | |
const SWITCH_OPTION_DEPRESS = 1; | |
// https://www.gnu.org/software/hurd/gnumach-doc/Hand_002dOff-Scheduling.html | |
extern "c" fn thread_switch( | |
thread: usize, | |
options: c_int, | |
timeout_ms: c_int, | |
) callconv(.C) c_int; | |
}.yield | |
else if (std.builtin.os.tag == .linux or std.builtin.link_libc) | |
struct { | |
fn yield() void { | |
_ = system.sched_yield(); | |
} | |
}.yield | |
else | |
yieldCpu; | |
fn yieldCpu() void { | |
switch (std.builtin.arch) { | |
.i386, .x86_64 => asm volatile ("pause"), | |
.arm, .aarch64 => asm volatile ("yield"), | |
else => {}, | |
} | |
} |
// ---------- next gist file: ThreadPool variant with an epoll-based IoDriver ----------
const std = @import("std"); | |
const system = switch (std.builtin.os.tag) { | |
.linux => std.os.linux, | |
else => std.os.system, | |
}; | |
const ThreadPool = @This(); | |
io_driver: IoDriver, | |
max_threads: u16, | |
counter: u32 = 0, | |
spawned_queue: ?*Worker = null, | |
run_queue: UnboundedQueue = .{}, | |
shutdown_event: Event = .{}, | |
pub const InitConfig = struct { | |
max_threads: ?u16 = null, | |
}; | |
pub fn init(config: InitConfig) !ThreadPool { | |
return ThreadPool{ | |
.io_driver = try IoDriver.init(), | |
.max_threads = std.math.min( | |
std.math.maxInt(u14), | |
std.math.max(1, config.max_threads orelse blk: { | |
break :blk @intCast(u16, std.Thread.cpuCount() catch 1); | |
}), | |
), | |
}; | |
} | |
pub fn deinit(self: *ThreadPool) void { | |
defer self.* = undefined; | |
defer self.io_driver.deinit(); | |
self.shutdown(); | |
self.shutdown_event.wait(); | |
while (self.spawned_queue) |worker| { | |
self.spawned_queue = worker.spawned_next; | |
const thread = worker.thread; | |
worker.shutdown_event.notify(); | |
thread.wait(); | |
} | |
} | |
pub const ScheduleHints = struct { | |
priority: Priority = .Normal, | |
pub const Priority = enum { | |
High, | |
Normal, | |
Low, | |
}; | |
}; | |
pub fn schedule(self: *ThreadPool, hints: ScheduleHints, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (Worker.current) |worker| { | |
worker.push(hints, batch); | |
} else { | |
self.run_queue.push(batch); | |
} | |
_ = self.tryNotifyWith(false); | |
} | |
pub const SpawnConfig = struct { | |
allocator: *std.mem.Allocator, | |
hints: ScheduleHints = .{}, | |
}; | |
pub fn spawn(self: *ThreadPool, config: SpawnConfig, comptime func: anytype, args: anytype) !void { | |
const Args = @TypeOf(args); | |
const is_async = @typeInfo(@TypeOf(func)).Fn.calling_convention == .Async; | |
const Closure = struct { | |
func_args: Args, | |
allocator: *std.mem.Allocator, | |
runnable: Runnable = .{ .runFn = runFn }, | |
frame: if (is_async) @Frame(runAsyncFn) else void = undefined, | |
fn runFn(runnable: *Runnable) void { | |
const closure = @fieldParentPtr(@This(), "runnable", runnable); | |
if (is_async) { | |
closure.frame = async closure.runAsyncFn(); | |
} else { | |
const result = @call(.{}, func, closure.func_args); | |
closure.allocator.destroy(closure); | |
} | |
} | |
fn runAsyncFn(closure: *@This()) void { | |
const result = @call(.{}, func, closure.func_args); | |
suspend closure.allocator.destroy(closure); | |
} | |
}; | |
const allocator = config.allocator; | |
const closure = try allocator.create(Closure); | |
errdefer allocator.destroy(closure); | |
closure.* = .{ | |
.func_args = args, | |
.allocator = allocator, | |
}; | |
const hints = config.hints; | |
self.schedule(hints, &closure.runnable); | |
} | |
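// The whole pool state is packed into one u32 so it can be updated with a | |
// single CAS: state (3 bits) | notified (1 bit) | idle (14 bits) | spawned | |
// (14 bits). max_threads is clamped to maxInt(u14) in init() so the two | |
// counts always fit. | |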
const Counter = struct { | |
state: State = .pending, | |
notified: bool = false, | |
idle: u16 = 0, | |
spawned: u16 = 0, | |
const State = enum(u3) { | |
pending = 0, | |
notified, | |
waking, | |
waker_notified, | |
shutdown, | |
}; | |
fn pack(self: Counter) u32 { | |
return (@as(u32, @as(u3, @enumToInt(self.state))) | | |
(@as(u32, @boolToInt(self.notified)) << 3) | | |
(@as(u32, @intCast(u14, self.idle)) << 4) | | |
(@as(u32, @intCast(u14, self.spawned)) << (4 + 14))); | |
} | |
fn unpack(value: u32) Counter { | |
return Counter{ | |
.state = @intToEnum(State, @truncate(u3, value)), | |
.notified = value & (1 << 3) != 0, | |
.idle = @as(u16, @truncate(u14, value >> 4)), | |
.spawned = @as(u16, @truncate(u14, value >> (4 + 14))), | |
}; | |
} | |
}; | |
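// Wake at most one thread at a time: a single "waking" token is passed | |
// around via Counter.state, preferring to resume an idle worker over | |
// spawning a new thread, and folding extra wakeups into notified states. | |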
fn tryNotifyWith(self: *ThreadPool, is_caller_waking: bool) bool { | |
var spawned = false; | |
var remaining_attempts: u8 = 5; | |
var is_waking = is_caller_waking; | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) { | |
if (spawned) | |
self.releaseWorker(); | |
return false; | |
} | |
const has_pending = (counter.idle > 0) or (counter.spawned < self.max_threads); | |
const can_wake = (is_waking and remaining_attempts > 0) or (!is_waking and counter.state == .pending); | |
if (has_pending and can_wake) { | |
var new_counter = counter; | |
new_counter.state = .waking; | |
if (counter.idle > 0) { | |
new_counter.idle -= 1; | |
new_counter.notified = true; | |
} else if (!spawned) { | |
new_counter.spawned += 1; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
is_waking = true; | |
if (counter.idle > 0) { | |
self.idleNotify(); | |
return true; | |
} | |
spawned = true; | |
if (Worker.spawn(self)) | |
return true; | |
remaining_attempts -= 1; | |
continue; | |
} | |
var new_counter = counter; | |
if (is_waking) { | |
new_counter.state = if (can_wake) .pending else .notified; | |
if (spawned) | |
new_counter.spawned -= 1; | |
} else if (counter.state == .waking) { | |
new_counter.state = .waker_notified; | |
} else if (counter.state == .pending) { | |
new_counter.state = .notified; | |
} else { | |
return false; | |
} | |
_ = @cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
) orelse return true; | |
} | |
} | |
const Wait = enum { | |
resumed, | |
notified, | |
shutdown, | |
}; | |
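// Either consume a pending notification with a CAS, or mark this worker | |
// idle and block in idleWait(), where it also drives the IO driver. | |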
fn tryWaitWith(self: *ThreadPool, worker: *Worker) Wait { | |
var is_waking = worker.is_waking; | |
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
while (true) { | |
if (counter.state == .shutdown) { | |
self.releaseWorker(); | |
return .shutdown; | |
} | |
const is_notified = switch (counter.state) { | |
.waker_notified => is_waking, | |
.notified => true, | |
else => false, | |
}; | |
var new_counter = counter; | |
if (is_notified) { | |
new_counter.state = if (is_waking) .waking else .pending; | |
} else { | |
new_counter.idle += 1; | |
if (is_waking) | |
new_counter.state = .pending; | |
} | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Monotonic, | |
.Monotonic, | |
)) |updated| { | |
counter = Counter.unpack(updated); | |
continue; | |
} | |
if (is_notified and is_waking) | |
return .notified; | |
if (is_notified) | |
return .resumed; | |
self.idleWait(worker); | |
return .notified; | |
} | |
} | |
fn releaseWorker(self: *ThreadPool) void { | |
const counter_spawned = Counter{ .spawned = 1 }; | |
const counter_value = @atomicRmw(u32, &self.counter, .Sub, counter_spawned.pack(), .AcqRel); | |
const counter = Counter.unpack(counter_value); | |
if (counter.state != .shutdown) | |
std.debug.panic("ThreadPool.releaseWorker() when not shutdown: {}", .{counter}); | |
if (counter.spawned == 1) | |
self.shutdown_event.notify(); | |
} | |
pub fn shutdown(self: *ThreadPool) void { | |
while (true) : (yieldCpu()) { | |
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
if (counter.state == .shutdown) | |
return; | |
var new_counter = counter; | |
new_counter.state = .shutdown; | |
new_counter.idle = 0; | |
if (@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
)) |failed| { | |
continue; | |
} | |
self.idleShutdown(); | |
return; | |
} | |
} | |
fn idleWait(self: *ThreadPool, worker: *Worker) void { | |
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
while (true) { | |
if (counter.state == .shutdown) { | |
self.io_driver.notify(); | |
return; | |
} | |
if (counter.notified) { | |
var new_counter = counter; | |
new_counter.notified = false; | |
counter = Counter.unpack(@cmpxchgWeak( | |
u32, | |
&self.counter, | |
counter.pack(), | |
new_counter.pack(), | |
.Acquire, | |
.Monotonic, | |
) orelse return); | |
continue; | |
} | |
const batch = self.io_driver.wait(); | |
self.schedule(.{}, batch); | |
counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic)); | |
} | |
} | |
fn idleNotify(self: *ThreadPool) void { | |
self.io_driver.notify(); | |
} | |
fn idleShutdown(self: *ThreadPool) void { | |
self.io_driver.notify(); | |
} | |
pub const IoRunnable = struct { | |
fd: std.os.fd_t = -1, | |
is_closable: bool = false, | |
is_readable: bool = false, | |
is_writable: bool = false, | |
runnable: Runnable, | |
}; | |
pub fn waitFor(self: *ThreadPool, fd: std.os.fd_t, io_runnable: *IoRunnable) !void { | |
return self.io_driver.register(fd, io_runnable); | |
} | |
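// epoll driver where a nonblocking eventfd serves purely as a wakeup: | |
// notify() re-arms it with EPOLLOUT|EPOLLONESHOT (an eventfd is always | |
// writable, so the event fires immediately), and wait() skips it because | |
// its user data pointer is zero. | |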
const IoDriver = struct { | |
poll_fd: std.os.fd_t, | |
notify_fd: std.os.fd_t, | |
fn init() !IoDriver { | |
const poll_fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC); | |
errdefer std.os.close(poll_fd); | |
const notify_fd = try std.os.eventfd(0, std.os.EFD_CLOEXEC | std.os.EFD_NONBLOCK); | |
errdefer std.os.close(notify_fd); | |
var event = std.os.epoll_event{ | |
.events = std.os.EPOLLONESHOT, | |
.data = .{ .ptr = 0 }, | |
}; | |
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, notify_fd, &event); | |
return IoDriver{ | |
.poll_fd = poll_fd, | |
.notify_fd = notify_fd, | |
}; | |
} | |
fn deinit(self: *IoDriver) void { | |
std.os.close(self.poll_fd); | |
std.os.close(self.notify_fd); | |
} | |
fn notify(self: *IoDriver) void { | |
const fd = self.notify_fd; | |
var event = std.os.epoll_event{ | |
.events = std.os.EPOLLOUT | std.os.EPOLLONESHOT, | |
.data = .{ .ptr = 0 }, | |
}; | |
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_MOD, fd, &event) catch |err| switch (err) { | |
error.FileDescriptorNotRegistered => { | |
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_ADD, fd, &event) catch {}; | |
}, | |
else => {}, | |
}; | |
} | |
fn register(self: *IoDriver, fd: std.os.fd_t, io_runnable: *IoRunnable) !void { | |
var events: u32 = std.os.EPOLLONESHOT | std.os.EPOLLRDHUP | std.os.EPOLLHUP; | |
if (io_runnable.is_readable) | |
events |= std.os.EPOLLIN; | |
if (io_runnable.is_writable) | |
events |= std.os.EPOLLOUT; | |
io_runnable.fd = fd; | |
var event = std.os.epoll_event{ | |
.events = events, | |
.data = .{ .ptr = @ptrToInt(io_runnable) }, | |
}; | |
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_MOD, fd, &event) catch |err| switch (err) { | |
error.FileDescriptorNotRegistered => { | |
try std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_ADD, fd, &event); | |
}, | |
else => |e| return e, | |
}; | |
} | |
fn wait(self: *IoDriver) Batch { | |
var batch = Batch{}; | |
var events: [128]std.os.epoll_event = undefined; | |
const count = std.os.epoll_wait(self.poll_fd, &events, -1); | |
for (events[0..count]) |event| { | |
const io_runnable = @intToPtr(?*IoRunnable, event.data.ptr) orelse continue; | |
io_runnable.is_closable = (event.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0); | |
io_runnable.is_writable = (event.events & (std.os.EPOLLOUT) != 0); | |
io_runnable.is_readable = (event.events & (std.os.EPOLLIN) != 0); | |
batch.push(&io_runnable.runnable); | |
} | |
return batch; | |
} | |
}; | |
const Worker = struct { | |
pool: *ThreadPool, | |
thread: *std.Thread, | |
spawned_next: ?*Worker = null, | |
shutdown_event: Event = .{}, | |
run_queue: BoundedQueue = .{}, | |
run_queue_next: ?*Runnable = null, | |
run_queue_lifo: ?*Runnable = null, | |
run_queue_overflow: UnboundedQueue = .{}, | |
tick: usize = undefined, | |
is_waking: bool = true, | |
next_target: ?*Worker = null, | |
threadlocal var current: ?*Worker = null; | |
fn spawn(pool: *ThreadPool) bool { | |
const Spawner = struct { | |
thread: *std.Thread = undefined, | |
thread_pool: *ThreadPool, | |
data_put_event: Event = .{}, | |
data_get_event: Event = .{}, | |
fn entry(self: *@This()) void { | |
self.data_put_event.wait(); | |
const thread = self.thread; | |
const thread_pool = self.thread_pool; | |
self.data_get_event.notify(); | |
Worker.run(thread, thread_pool); | |
} | |
}; | |
var spawner = Spawner{ .thread_pool = pool }; | |
spawner.thread = std.Thread.spawn(&spawner, Spawner.entry) catch return false; | |
spawner.data_put_event.notify(); | |
spawner.data_get_event.wait(); | |
return true; | |
} | |
fn run(thread: *std.Thread, pool: *ThreadPool) void { | |
var self = Worker{ | |
.thread = thread, | |
.pool = pool, | |
}; | |
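// Seed the tick with this worker's stack address so the interval checks | |
// in pop() are staggered across workers rather than firing in lockstep. | |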
self.tick = @ptrToInt(&self); | |
current = &self; | |
defer current = null; | |
var spawned_queue = @atomicLoad(?*Worker, &pool.spawned_queue, .Monotonic); | |
while (true) { | |
self.spawned_next = spawned_queue; | |
spawned_queue = @cmpxchgWeak( | |
?*Worker, | |
&pool.spawned_queue, | |
spawned_queue, | |
&self, | |
.Release, | |
.Monotonic, | |
) orelse break; | |
} | |
while (true) { | |
if (self.pop()) |runnable| { | |
if (self.is_waking) { | |
self.is_waking = false; | |
_ = pool.tryNotifyWith(true); | |
} | |
self.tick +%= 1; | |
runnable.run(); | |
continue; | |
} | |
self.is_waking = switch (pool.tryWaitWith(&self)) { | |
.resumed => false, | |
.notified => true, | |
.shutdown => { | |
self.shutdown_event.wait(); | |
break; | |
}, | |
}; | |
} | |
} | |
fn push(self: *Worker, hints: ScheduleHints, batchable: anytype) void { | |
var batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (hints.priority == .High) { | |
const new_lifo = batch.pop(); | |
if (@atomicLoad(?*Runnable, &self.run_queue_lifo, .Monotonic) == null) { | |
@atomicStore(?*Runnable, &self.run_queue_lifo, new_lifo, .Release); | |
} else if (@atomicRmw(?*Runnable, &self.run_queue_lifo, .Xchg, new_lifo, .AcqRel)) |old_lifo| { | |
batch.pushFront(old_lifo); | |
} | |
} | |
if (hints.priority == .Low) { | |
if (self.run_queue_next) |old_next| | |
batch.pushFront(old_next); | |
self.run_queue_next = null; | |
self.run_queue_next = self.pop() orelse batch.pop(); | |
} | |
if (self.run_queue.push(batch)) |overflowed| | |
self.run_queue_overflow.push(overflowed); | |
} | |
fn pop(self: *Worker) ?*Runnable { | |
if (self.tick % 127 == 0) { | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
} | |
if (self.tick % 61 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 31 == 0) { | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
} | |
if (self.tick % 13 == 0) { | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
} | |
if (self.run_queue.pop()) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(self)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
if (self.popAndStealFromOthers()) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable| | |
return runnable; | |
return null; | |
} | |
fn popAndStealLifo(self: *Worker, target: *Worker) ?*Runnable { | |
var run_queue_lifo = @atomicLoad(?*Runnable, &target.run_queue_lifo, .Monotonic); | |
while (true) { | |
if (run_queue_lifo == null) | |
return null; | |
run_queue_lifo = @cmpxchgWeak( | |
?*Runnable, | |
&target.run_queue_lifo, | |
run_queue_lifo, | |
null, | |
.Acquire, | |
.Monotonic, | |
) orelse return run_queue_lifo; | |
} | |
} | |
fn popAndStealFromOthers(self: *Worker) ?*Runnable { | |
var num_workers = blk: { | |
const counter_value = @atomicLoad(u32, &self.pool.counter, .Monotonic); | |
const counter = Counter.unpack(counter_value); | |
break :blk counter.spawned; | |
}; | |
while (num_workers > 0) : (num_workers -= 1) { | |
const target = self.next_target orelse blk: { | |
break :blk @atomicLoad(?*Worker, &self.pool.spawned_queue, .Acquire) orelse { | |
std.debug.panic("Worker observed empty spawned queue when work-stealing", .{}); | |
}; | |
}; | |
self.next_target = target.spawned_next; | |
if (target == self) | |
continue; | |
if (self.run_queue.popAndStealBounded(&target.run_queue)) |runnable| | |
return runnable; | |
if (self.run_queue.popAndStealUnbounded(&target.run_queue_overflow)) |runnable| | |
return runnable; | |
if (self.popAndStealLifo(target)) |runnable| | |
return runnable; | |
} | |
return null; | |
} | |
}; | |
const UnboundedQueue = struct { | |
lock: Mutex = .{}, | |
batch: Batch = .{}, | |
shared_size: usize = 0, | |
fn push(self: *UnboundedQueue, batchable: anytype) void { | |
const batch = Batch.from(batchable); | |
if (batch.isEmpty()) | |
return; | |
const held = self.lock.acquire(); | |
defer held.release(); | |
self.batch.push(batch); | |
var shared_size = self.shared_size; | |
shared_size += batch.size; | |
@atomicStore(usize, &self.shared_size, shared_size, .Release); | |
} | |
fn tryAcquireConsumer(self: *UnboundedQueue) ?Consumer { | |
var shared_size = @atomicLoad(usize, &self.shared_size, .Acquire); | |
if (shared_size == 0) | |
return null; | |
const held = self.lock.acquire(); | |
shared_size = self.shared_size; | |
if (shared_size == 0) { | |
held.release(); | |
return null; | |
} | |
return Consumer{ | |
.held = held, | |
.queue = self, | |
.size = shared_size, | |
}; | |
} | |
const Consumer = struct { | |
held: Mutex.Held, | |
queue: *UnboundedQueue, | |
size: usize, | |
fn release(self: Consumer) void { | |
@atomicStore(usize, &self.queue.shared_size, self.size, .Release); | |
self.held.release(); | |
} | |
fn pop(self: *Consumer) ?*Runnable { | |
const runnable = self.queue.batch.pop() orelse return null; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
}; | |
const BoundedQueue = struct { | |
head: usize = 0, | |
tail: usize = 0, | |
buffer: [256]*Runnable = undefined, | |
fn push(self: *BoundedQueue, batchable: anytype) ?Batch { | |
var batch = Batch.from(batchable); | |
while (true) : (yieldCpu()) { | |
if (batch.isEmpty()) | |
return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
if (size < self.buffer.len) { | |
while (size < self.buffer.len) { | |
const runnable = batch.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
continue; | |
} | |
var migrate = self.buffer.len / 2; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% migrate, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
var overflowed = Batch{}; | |
while (migrate > 0) : (migrate -= 1) { | |
const runnable = self.buffer[head % self.buffer.len]; | |
overflowed.push(runnable); | |
head +%= 1; | |
} | |
overflowed.push(batch); | |
return overflowed; | |
} | |
} | |
fn pop(self: *BoundedQueue) ?*Runnable { | |
while (true) : (yieldCpu()) { | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size == 0) | |
return null; | |
if (@cmpxchgWeak( | |
usize, | |
&self.head, | |
head, | |
head +% 1, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
const runnable = self.buffer[head % self.buffer.len]; | |
return runnable; | |
} | |
} | |
fn popAndStealBounded(self: *BoundedQueue, target: *BoundedQueue) ?*Runnable { | |
if (target == self) | |
return self.pop(); | |
const tail = self.tail; | |
const head = @atomicLoad(usize, &self.head, .Acquire); | |
const size = tail -% head; | |
if (size != 0) | |
return self.pop(); | |
while (true) : (yieldThread()) { | |
const target_head = @atomicLoad(usize, &target.head, .Acquire); | |
const target_tail = @atomicLoad(usize, &target.tail, .Acquire); | |
const target_size = target_tail -% target_head; | |
var steal = target_size - (target_size / 2); | |
if (steal == 0) | |
return null; | |
if (steal > target.buffer.len / 2) | |
continue; | |
const first_runnable = @atomicLoad(*Runnable, &target.buffer[target_head % target.buffer.len], .Unordered); | |
var new_target_head = target_head +% 1; | |
var new_tail = tail; | |
steal -= 1; | |
while (steal > 0) : (steal -= 1) { | |
const runnable = @atomicLoad(*Runnable, &target.buffer[new_target_head % target.buffer.len], .Unordered); | |
new_target_head +%= 1; | |
@atomicStore(*Runnable, &self.buffer[new_tail % self.buffer.len], runnable, .Unordered); | |
new_tail +%= 1; | |
} | |
if (@cmpxchgWeak( | |
usize, | |
&target.head, | |
target_head, | |
new_target_head, | |
.AcqRel, | |
.Acquire, | |
)) |failed| { | |
continue; | |
} | |
@atomicStore(usize, &self.tail, new_tail, .Release); | |
return first_runnable; | |
} | |
} | |
fn popAndStealUnbounded(self: *BoundedQueue, target: *UnboundedQueue) ?*Runnable { | |
var consumer = target.tryAcquireConsumer() orelse return null; | |
defer consumer.release(); | |
const first_runnable = consumer.pop() orelse return null; | |
var tail = self.tail; | |
var head = @atomicLoad(usize, &self.head, .Acquire); | |
var size = tail -% head; | |
while (size < self.buffer.len) { | |
const runnable = consumer.pop() orelse break; | |
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered); | |
tail +%= 1; | |
size += 1; | |
} | |
@atomicStore(usize, &self.tail, tail, .Release); | |
return first_runnable; | |
} | |
}; | |
pub const Runnable = struct { | |
next: ?*Runnable = null, | |
runFn: fn (*Runnable) void, | |
pub fn run(self: *Runnable) void { | |
return (self.runFn)(self); | |
} | |
}; | |
pub const Batch = struct { | |
head: ?*Runnable = null, | |
tail: *Runnable = undefined, | |
size: usize = 0, | |
pub fn from(batchable: anytype) Batch { | |
return switch (@TypeOf(batchable)) { | |
Batch => batchable, | |
?*Runnable => from(batchable orelse return Batch{}), | |
*Runnable => { | |
batchable.next = null; | |
return Batch{ | |
.head = batchable, | |
.tail = batchable, | |
.size = 1, | |
}; | |
}, | |
else => |typ| @compileError(@typeName(typ) ++ | |
" cannot be converted into " ++ | |
@typeName(Batch)), | |
}; | |
} | |
pub fn isEmpty(self: Batch) bool { | |
return self.head == null; | |
} | |
pub const push = pushBack; | |
pub fn pushBack(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
self.tail.next = batch.head; | |
self.tail = batch.tail; | |
self.size += batch.size; | |
} | |
} | |
pub fn pushFront(self: *Batch, batchable: anytype) void { | |
const batch = from(batchable); | |
if (batch.isEmpty()) | |
return; | |
if (self.isEmpty()) { | |
self.* = batch; | |
} else { | |
batch.tail.next = self.head; | |
self.head = batch.head; | |
self.size += batch.size; | |
} | |
} | |
pub const pop = popFront; | |
pub fn popFront(self: *Batch) ?*Runnable { | |
const runnable = self.head orelse return null; | |
self.head = runnable.next; | |
self.size -= 1; | |
return runnable; | |
} | |
}; | |
const Semaphore = struct { | |
lock: Mutex = .{}, | |
permits: usize = 0, | |
waiters: ?*Waiter = null, | |
const Waiter = struct { | |
next: ?*Waiter = null, | |
tail: *Waiter = undefined, | |
event: Event = .{}, | |
permits: usize, | |
}; | |
fn init(permits: usize) Semaphore { | |
return .{ .permits = permits }; | |
} | |
fn wait(self: *Semaphore, permits: usize) void { | |
const held = self.lock.acquire(); | |
if (self.permits >= permits) { | |
self.permits -= permits; | |
held.release(); | |
return; | |
} | |
var waiter = Waiter{ .permits = permits }; | |
if (self.waiters) |head| { | |
head.tail.next = &waiter; | |
head.tail = &waiter; | |
} else { | |
self.waiters = &waiter; | |
waiter.tail = &waiter; | |
} | |
held.release(); | |
waiter.event.wait(); | |
} | |
fn post(self: *Semaphore, permits: usize) error{Overflow}!void { | |
var waiters: ?*Waiter = null; | |
{ | |
const held = self.lock.acquire(); | |
defer held.release(); | |
if (@addWithOverflow(usize, self.permits, permits, &self.permits)) | |
return error.Overflow; | |
while (self.waiters) |waiter| { | |
if (waiter.permits > self.permits) | |
break; | |
self.waiters = waiter.next; | |
if (self.waiters) |new_waiter| | |
new_waiter.tail = waiter.tail; | |
self.permits -= waiter.permits; | |
waiter.next = waiters; | |
waiters = waiter; | |
} | |
} | |
while (waiters) |waiter| { | |
waiters = waiter.next; | |
waiter.event.notify(); | |
} | |
} | |
}; | |
const Mutex = if (std.builtin.os.tag == .windows) | |
struct { | |
srwlock: usize = 0, | |
pub fn acquire(self: *Mutex) Held { | |
AcquireSRWLockExclusive(&self.srwlock); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
ReleaseSRWLockExclusive(&self.mutex.srwlock); | |
} | |
}; | |
extern "kernel32" fn AcquireSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
extern "kernel32" fn ReleaseSRWLockExclusive( | |
srwlock: *?system.PVOID, | |
) callconv(system.WINAPI) void; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
lock: u32 = 0, | |
pub fn acquire(self: *Mutex) Held { | |
os_unfair_lock_lock(&self.lock); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
os_unfair_lock_unlock(&self.mutex.lock); | |
} | |
}; | |
extern "c" fn os_unfair_lock_lock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
extern "c" fn os_unfair_lock_unlock( | |
unfair_lock: *u32, | |
) callconv(.C) void; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: i32 = UNLOCKED, | |
const UNLOCKED: i32 = 0; | |
const LOCKED: i32 = 1; | |
const WAITING: i32 = 2; | |
pub fn acquire(self: *Mutex) Held { | |
const state = @atomicRmw(i32, &self.state, .Xchg, LOCKED, .Acquire); | |
if (state != UNLOCKED) | |
self.acquireSlow(state); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
switch (@atomicRmw(i32, &self.mutex.state, .Xchg, UNLOCKED, .Release)) { | |
UNLOCKED => unreachable, // unlocked an unlocked mutex | |
LOCKED => {}, | |
WAITING => self.mutex.releaseSlow(), | |
else => unreachable, | |
} | |
} | |
}; | |
fn acquireSlow(self: *Mutex, current_state: i32) void { | |
@setCold(true); | |
var wait_state = current_state; | |
while (true) { | |
var spin: u8 = 0; | |
while (spin < 5) : (spin += 1) { | |
switch (@atomicLoad(i32, &self.state, .Monotonic)) { | |
UNLOCKED => _ = @cmpxchgWeak( | |
i32, | |
&self.state, | |
UNLOCKED, | |
wait_state, | |
.Acquire, | |
.Monotonic, | |
) orelse return, | |
LOCKED => {}, | |
WAITING => break, | |
else => unreachable, | |
} | |
if (spin < 4) { | |
var pause: u8 = 30; | |
while (pause > 0) : (pause -= 1) | |
yieldCpu(); | |
} else { | |
yieldThread(); | |
} | |
} | |
const state = @atomicRmw(i32, &self.state, .Xchg, WAITING, .Acquire); | |
if (state == UNLOCKED) | |
return; | |
wait_state = WAITING; | |
switch (system.getErrno(system.futex_wait( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
WAITING, | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
fn releaseSlow(self: *Mutex) void { | |
@setCold(true); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
&self.state, | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
locked: bool = false, | |
pub fn acquire(self: *Mutex) Held { | |
while (@atomicRmw(bool, &self.locked, .Xchg, true, .Acquire)) | |
yieldThread(); | |
return Held{ .mutex = self }; | |
} | |
pub const Held = struct { | |
mutex: *Mutex, | |
pub fn release(self: Held) void { | |
@atomicStore(bool, &self.mutex.locked, false, .Release); | |
} | |
}; | |
}; | |
const Event = if (std.builtin.os.tag == .windows) | |
struct { | |
key: u32 = undefined, | |
pub fn wait(self: *Event) void { | |
const status = NtWaitForKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
pub fn notify(self: *Event) void { | |
const status = NtReleaseKeyedEvent(null, &self.key, system.FALSE, null); | |
std.debug.assert(status == .SUCCESS); | |
} | |
extern "NtDll" fn NtWaitForKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
extern "NtDll" fn NtReleaseKeyedEvent( | |
handle: ?system.HANDLE, | |
key: ?*const u32, | |
alertable: system.BOOLEAN, | |
timeout: ?*const system.LARGE_INTEGER, | |
) callconv(system.WINAPI) system.NTSTATUS; | |
} | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
state: enum(u32) { | |
pending = 0, | |
notified, | |
} = .pending, | |
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
const status = __ulock_wait( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@enumToInt(@TypeOf(self.state).pending), | |
~@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.EINTR => {}, | |
else => unreachable, | |
} | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
const status = __ulock_wake( | |
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO, | |
@ptrCast(?*const c_void, &self.state), | |
@as(u32, 0), | |
); | |
if (status < 0) { | |
switch (-status) { | |
system.ENOENT => {}, | |
system.EINTR => continue, | |
else => unreachable, | |
} | |
} | |
return; | |
} | |
} | |
const ULF_NO_ERRNO = 0x1000000; | |
const UL_COMPARE_AND_WAIT = 0x1; | |
extern "c" fn __ulock_wait( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
timeout_us: u32, | |
) callconv(.C) c_int; | |
extern "c" fn __ulock_wake( | |
operation: u32, | |
address: ?*const c_void, | |
value: u64, | |
) callconv(.C) c_int; | |
} | |
else if (std.builtin.os.tag == .linux) | |
struct { | |
state: enum(i32) { | |
pending, | |
notified, | |
} = .pending, | |
pub fn wait(self: *Event) void { | |
while (true) { | |
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) { | |
.pending => {}, | |
.notified => { | |
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic); | |
return; | |
}, | |
} | |
switch (system.getErrno(system.futex_wait( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT, | |
@enumToInt(@TypeOf(self.state).pending), | |
null, | |
))) { | |
0 => {}, | |
system.EINTR => {}, | |
system.EAGAIN => {}, | |
else => unreachable, | |
} | |
} | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release); | |
while (true) { | |
return switch (system.getErrno(system.futex_wake( | |
@ptrCast(*const i32, &self.state), | |
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE, | |
@as(i32, 1), | |
))) { | |
0 => {}, | |
system.EINTR => continue, | |
system.EFAULT => {}, | |
else => unreachable, | |
}; | |
} | |
} | |
} | |
else | |
struct { | |
notified: bool = false, | |
pub fn wait(self: *Event) void { | |
while (!@atomicLoad(bool, &self.notified, .Acquire)) | |
yieldThread(); | |
@atomicStore(bool, &self.notified, false, .Monotonic); | |
} | |
pub fn notify(self: *Event) void { | |
@atomicStore(bool, &self.notified, true, .Release); | |
} | |
}; | |
const yieldThread = if (std.builtin.os.tag == .windows) | |
struct { | |
fn yield() void { | |
system.kernel32.Sleep(0); | |
} | |
}.yield | |
else if (comptime std.Target.current.isDarwin()) | |
struct { | |
fn yield() void { | |
_ = thread_switch(MACH_PORT_NULL, SWITCH_OPTION_DEPRESS, 1); | |
} | |
const MACH_PORT_NULL = 0; | |
const SWITCH_OPTION_DEPRESS = 1; | |
// https://www.gnu.org/software/hurd/gnumach-doc/Hand_002dOff-Scheduling.html | |
extern "c" fn thread_switch( | |
thread: usize, | |
options: c_int, | |
timeout_ms: c_int, | |
) callconv(.C) c_int; | |
}.yield | |
else if (std.builtin.os.tag == .linux or std.builtin.link_libc) | |
struct { | |
fn yield() void { | |
_ = system.sched_yield(); | |
} | |
}.yield | |
else | |
yieldCpu; | |
fn yieldCpu() void { | |
switch (std.builtin.arch) { | |
.i386, .x86_64 => asm volatile ("pause"), | |
.arm, .aarch64 => asm volatile ("yield"), | |
else => {}, | |
} | |
} |
// ---------- next gist file: http server using io_uring ----------
const std = @import("std"); | |
const linux = std.os.linux; | |
pub fn main() !void { | |
var ring: Ring = undefined; | |
try ring.init(); | |
defer ring.deinit(); | |
var server: Server = undefined; | |
try server.init(&ring, 12345); | |
defer server.deinit(); | |
while (true) { | |
try ring.poll(); | |
} | |
} | |
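// Thin wrapper around IO_Uring that drains CQEs into an intrusive, singly | |
// linked list of Completions, so their callbacks run in poll() rather than | |
// while copying from the completion queue. | |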
const Ring = struct { | |
inner: linux.IO_Uring, | |
head: ?*Completion, | |
tail: ?*Completion, | |
fn init(self: *Ring) !void { | |
self.inner = try linux.IO_Uring.init(512, 0); | |
self.head = null; | |
self.tail = null; | |
} | |
fn deinit(self: *Ring) void { | |
self.inner.deinit(); | |
} | |
const Completion = struct { | |
result: i32 = undefined, | |
next: ?*Completion = null, | |
onComplete: fn(*Ring, *Completion) void, | |
}; | |
fn flushCompletions(self: *Ring) !void { | |
var chunk: [256]linux.io_uring_cqe = undefined; | |
while (true) { | |
const found = try self.inner.copy_cqes(&chunk, 0); | |
if (found == 0) | |
break; | |
for (chunk[0..found]) |cqe| { | |
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data)); | |
completion.result = cqe.res; | |
if (self.head == null) | |
self.head = completion; | |
if (self.tail) |tail| | |
tail.next = completion; | |
completion.next = null; | |
self.tail = completion; | |
} | |
} | |
} | |
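// If the submission queue is full, flush completions and submit what is | |
// already queued, then retry until an SQE slot frees up. | |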
fn getSubmission(self: *Ring) !*linux.io_uring_sqe { | |
while (true) { | |
return self.inner.get_sqe() catch { | |
try self.flushCompletions(); | |
_ = try self.inner.submit(); | |
continue; | |
}; | |
} | |
} | |
fn poll(self: *Ring) !void { | |
while (self.head) |completion| { | |
self.head = completion.next; | |
if (self.head == null) | |
self.tail = null; | |
(completion.onComplete)(self, completion); | |
} | |
_ = try self.inner.submit_and_wait(1); | |
try self.flushCompletions(); | |
} | |
fn submitPoll(self: *Ring, fd: std.os.fd_t, flags: u32, completion: *Completion) !void { | |
const sqe = try self.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .POLL_ADD; | |
sqe.fd = fd; | |
sqe.rw_flags = flags | linux.POLLERR | linux.POLLHUP; | |
sqe.user_data = @ptrToInt(completion); | |
} | |
}; | |
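// Accept loop as a two-state machine: ACCEPT submitted directly through | |
// io_uring, falling back to POLL_ADD when accept reports EAGAIN. | |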
const Server = struct { | |
fd: std.os.socket_t, | |
completion: Ring.Completion, | |
gpa: std.heap.GeneralPurposeAllocator(.{}), | |
state: enum { poll, accept }, | |
fn init(self: *Server, ring: *Ring, comptime port: u16) !void { | |
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP); | |
errdefer std.os.close(self.fd); | |
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable; | |
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(self.fd, 128); | |
self.gpa = .{}; | |
self.state = .poll; | |
self.completion = .{ .onComplete = Server.onCompletion }; | |
try self.submitAccept(ring); | |
std.debug.warn("Listening on :{}\n", .{port}); | |
} | |
fn deinit(self: *Server) void { | |
std.os.close(self.fd); | |
std.debug.warn("Server closed\n", .{}); | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Server, "completion", completion); | |
self.process(ring, completion.result) catch self.deinit(); | |
} | |
fn process(self: *Server, ring: *Ring, result: i32) !void { | |
switch (self.state) { | |
.poll => { | |
if (result & (linux.POLLERR | linux.POLLHUP) != 0) | |
return error.Closed; | |
try self.submitAccept(ring); | |
}, | |
.accept => { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitAccept(ring); | |
return; | |
}, | |
std.os.EAGAIN => { | |
try ring.submitPoll(self.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const client_fd: std.os.socket_t = result; | |
Client.init(self, ring, client_fd) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to start client {}: {}\n", .{ client_fd, err }); | |
}; | |
try self.submitAccept(ring); | |
}, | |
} | |
} | |
fn submitAccept(self: *Server, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .ACCEPT; | |
sqe.fd = self.fd; | |
sqe.rw_flags = std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC; | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .accept; | |
} | |
}; | |
const Client = struct { | |
server: *Server, | |
fd: std.os.socket_t, | |
reader: Reader, | |
writer: Writer, | |
is_closed: bool, | |
const HTTP_CLRF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn init(server: *Server, ring: *Ring, fd: std.os.socket_t) !void { | |
const allocator = &server.gpa.allocator; | |
const self = try allocator.create(Client); | |
errdefer allocator.destroy(self); | |
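// Values from the Linux headers, hard-coded presumably because the std | |
// of the day didn't expose them; TCP_NODELAY disables Nagle so the | |
// tiny responses aren't delayed. | |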
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
self.fd = fd; | |
self.server = server; | |
self.is_closed = false; | |
try self.reader.init(ring); | |
try self.writer.init(ring); | |
} | |
fn deinit(self: *Client) void { | |
std.os.close(self.fd); | |
const allocator = &self.server.gpa.allocator; | |
allocator.destroy(self); | |
} | |
const Reader = struct { | |
state: enum { poll, read, stop } = .stop, | |
completion: Ring.Completion, | |
recv_bytes: usize = 0, | |
recv_buffer: [4096]u8 = undefined, | |
iovec: std.os.iovec = undefined, | |
fn init(self: *Reader, ring: *Ring) !void { | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
self.iovec.iov_base = &self.recv_buffer; | |
self.iovec.iov_len = self.recv_buffer.len; | |
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Reader, "completion", completion); | |
const client = @fieldParentPtr(Client, "reader", self); | |
self.process(client, ring, completion.result) catch { | |
if (self.state == .stop and client.writer.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Reader, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
if (client.writer.state == .idle) | |
client.writer.state = .stop; | |
} | |
if (client.is_closed) | |
return error.Closed; | |
switch (self.state) { | |
.poll => { | |
if (result & (linux.POLLERR | linux.POLLHUP) != 0) | |
return error.Closed; | |
try self.submitRead(ring); | |
}, | |
.read => { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitRead(ring); | |
return; | |
}, | |
std.os.EAGAIN => { | |
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes = @intCast(usize, result); | |
self.recv_bytes += bytes; | |
if (bytes == 0) | |
return error.Eof; | |
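// Parse as many pipelined requests as the buffer holds: each header | |
// terminator found credits one response to the writer, and unparsed | |
// bytes are shifted back to the front of the buffer. | |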
while (true) { | |
const request_buffer = self.recv_buffer[0..self.recv_bytes]; | |
if (std.mem.indexOf(u8, request_buffer, HTTP_CRLF)) |parsed| { | |
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CRLF.len) .. request_buffer.len]; | |
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer); | |
self.recv_bytes = unparsed_buffer.len; | |
client.writer.send_bytes += HTTP_RESPONSE.len; | |
if (client.writer.state == .idle) { | |
client.writer.state = .write; | |
try client.writer.process(client, ring, 0); | |
} | |
continue; | |
} | |
const readable_buffer = self.recv_buffer[self.recv_bytes..]; | |
if (readable_buffer.len == 0) | |
return error.HttpRequestTooLarge; | |
self.iovec.iov_base = readable_buffer.ptr; | |
self.iovec.iov_len = readable_buffer.len; | |
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion); | |
self.state = .poll; | |
return; | |
} | |
}, | |
else => {} | |
} | |
} | |
fn submitRead(self: *Reader, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .READV; | |
sqe.fd = @fieldParentPtr(Client, "reader", self).fd; | |
sqe.addr = @ptrToInt(&self.iovec); | |
sqe.len = 1; | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .read; | |
} | |
}; | |
const Writer = struct { | |
state: enum { poll, write, idle, stop } = .idle, | |
completion: Ring.Completion, | |
send_bytes: usize = 0, | |
send_partial: usize = 0, | |
iovec: std.os.iovec_const = undefined, | |
fn init(self: *Writer, ring: *Ring) !void { | |
self.* = .{ .completion = .{ .onComplete = onCompletion } }; | |
} | |
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void { | |
const self = @fieldParentPtr(Writer, "completion", completion); | |
const client = @fieldParentPtr(Client, "writer", self); | |
self.process(client, ring, completion.result) catch { | |
if (self.state == .stop and client.reader.state == .stop) | |
client.deinit(); | |
}; | |
} | |
fn process(self: *Writer, client: *Client, ring: *Ring, result: i32) !void { | |
errdefer { | |
self.state = .stop; | |
client.is_closed = true; | |
} | |
if (client.is_closed) | |
return error.Closed; | |
switch (self.state) { | |
.poll => { | |
if (result & (linux.POLLERR | linux.POLLHUP) != 0) | |
return error.Closed; | |
try self.submitWrite(ring); | |
}, | |
.write => { | |
switch (if (result < 0) -result else @as(i32, 0)) { | |
0 => {}, | |
std.os.EINTR => { | |
try self.submitWrite(ring); | |
return; | |
}, | |
std.os.EAGAIN => { | |
try ring.submitPoll(client.fd, linux.POLLOUT, &self.completion); | |
self.state = .poll; | |
return; | |
}, | |
else => { | |
return std.os.unexpectedErrno(@intCast(usize, -result)); | |
}, | |
} | |
const bytes_written = @intCast(usize, result); | |
self.send_bytes -= bytes_written; | |
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len; | |
if (self.send_bytes == 0) { | |
self.state = .idle; | |
return; | |
} | |
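// Concatenating the response 128 times lets a single writev() flush | |
// many pipelined responses at once; send_partial is the offset of a | |
// partially written response within that repeating cycle. | |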
const NUM_RESPONSE_CHUNKS = 128; | |
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS; | |
self.iovec.iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial; | |
self.iovec.iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial); | |
try ring.submitPoll(client.fd, linux.POLLOUT, &self.completion); | |
self.state = .poll; | |
}, | |
else => {}, | |
} | |
} | |
fn submitWrite(self: *Writer, ring: *Ring) !void { | |
const sqe = try ring.getSubmission(); | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = .WRITEV; | |
sqe.fd = @fieldParentPtr(Client, "writer", self).fd; | |
sqe.addr = @ptrToInt(&self.iovec); | |
sqe.len = 1; | |
sqe.user_data = @ptrToInt(&self.completion); | |
self.state = .write; | |
} | |
}; | |
}; |
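// Alternative implementation: the same wrk-compatible HTTP server, but | |
// built on Zig async frames over io_uring instead of explicit callback | |
// state machines. | |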
const std = @import("std"); | |
const linux = std.os.linux; | |
const Allocator = std.mem.Allocator; | |
pub fn main() !void { | |
var ring = try Ring.init(512); | |
defer ring.deinit(); | |
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; | |
defer if (gpa.deinit()) | |
unreachable; | |
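// Run the server as a detached async frame: every suspend inside it | |
// parks on an io_uring completion that the poll loop below resumes. | |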
_ = async Server.start(.{ | |
.ring = &ring, | |
.allocator = &gpa.allocator, | |
.port = 12345, | |
}); | |
while (true) { | |
try ring.poll(); | |
} | |
} | |
const Ring = struct { | |
io_uring: linux.IO_Uring, | |
cq_head: ?*Completion = null, | |
cq_tail: *Completion = undefined, | |
fn init(max_submissions: u12) !Ring { | |
const io_uring = try linux.IO_Uring.init(max_submissions, 0); | |
return Ring{ .io_uring = io_uring }; | |
} | |
fn deinit(self: *Ring) void { | |
self.io_uring.deinit(); | |
} | |
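// Intrusive completion node: the SQE's user_data points here, and the | |
// stored frame is resumed once the CQE result has been copied in. | |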
const Completion = struct { | |
frame: usize, | |
result: i32 = undefined, | |
next: ?*Completion = undefined, | |
}; | |
fn poll(self: *Ring) !void { | |
while (self.cq_head) |completion| { | |
self.cq_head = completion.next; | |
resume @intToPtr(anyframe, completion.frame); | |
} | |
_ = try self.io_uring.submit_and_wait(1); | |
try self.flushCompletions(); | |
} | |
fn flushCompletions(self: *Ring) !void { | |
var chunk: [256]linux.io_uring_cqe = undefined; | |
while (true) { | |
const found = try self.io_uring.copy_cqes(&chunk, 0); | |
if (found == 0) | |
break; | |
for (chunk[0..found]) |cqe| { | |
const completion = @intToPtr(*Completion, @intCast(usize, cqe.user_data)); | |
completion.result = cqe.res; | |
completion.next = null; | |
defer self.cq_tail = completion; | |
if (self.cq_head == null) { | |
self.cq_head = completion; | |
} else { | |
self.cq_tail.next = completion; | |
} | |
} | |
} | |
} | |
const Result = union(enum) { | |
Ok: u32, | |
Err: u16, | |
}; | |
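// Wraps one io_uring op as a suspend point: the caller parks itself in | |
// a Completion living in its own frame and is resumed by poll() with | |
// the CQE result; EAGAIN is retried transparently. | |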
fn submit(self: *Ring, op: linux.IORING_OP, fd: std.os.fd_t, addr: usize, len: u32, flags: u32) Result { | |
while (true) { | |
const sqe = blk: { | |
while (true) { | |
break :blk self.io_uring.get_sqe() catch { | |
self.flushCompletions() catch return .{ .Err = linux.ENOMEM }; | |
_ = self.io_uring.submit() catch return .{ .Err = linux.ENOMEM }; | |
continue; | |
}; | |
} | |
}; | |
sqe.* = std.mem.zeroes(@TypeOf(sqe.*)); | |
sqe.opcode = op; | |
sqe.fd = fd; | |
sqe.addr = addr; | |
sqe.len = len; | |
sqe.rw_flags = flags; | |
var completion = Completion{ .frame = @ptrToInt(@frame()) }; | |
suspend sqe.user_data = @ptrToInt(&completion); | |
const result = completion.result; | |
if (result == -linux.EAGAIN) | |
continue; | |
if (result < 0) | |
return .{ .Err = @intCast(u16, -result) }; | |
return .{ .Ok = @intCast(u32, result) }; | |
} | |
} | |
}; | |
const Server = struct { | |
fn start(args: anytype) void { | |
Server.run(args) catch |err| { | |
std.debug.warn("Server shutdown with {}\n", .{err}); | |
}; | |
} | |
fn run(args: anytype) !void { | |
const server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM, std.os.IPPROTO_TCP); | |
defer std.os.close(server_fd); | |
var addr = try std.net.Address.parseIp("127.0.0.1", args.port); | |
try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); | |
try std.os.bind(server_fd, &addr.any, addr.getOsSockLen()); | |
try std.os.listen(server_fd, 128); | |
std.debug.warn("Listening on :{}\n", .{args.port}); | |
while (true) { | |
const client_fd = switch (args.ring.submit(.ACCEPT, server_fd, 0, 0, 0)) { | |
.Ok => |fd| @intCast(std.os.socket_t, fd), | |
.Err => return error.AcceptError, | |
}; | |
Client.start(.{ | |
.fd = client_fd, | |
.ring = args.ring, | |
.allocator = args.allocator, | |
}) catch |err| { | |
std.os.close(client_fd); | |
std.debug.warn("Failed to start client: {}\n", .{err}); | |
}; | |
} | |
} | |
}; | |
const Client = struct { | |
ring: *Ring, | |
fd: std.os.socket_t, | |
send_bytes: usize = 0, | |
writer_state: usize = WRITER_EMPTY, | |
const HTTP_CRLF = "\r\n\r\n"; | |
const HTTP_RESPONSE = | |
"HTTP/1.1 200 Ok\r\n" ++ | |
"Content-Length: 10\r\n" ++ | |
"Content-Type: text/plain; charset=utf8\r\n" ++ | |
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++ | |
"Server: fasthttp\r\n" ++ | |
"\r\n" ++ | |
"HelloWorld"; | |
fn start(args: anytype) !void { | |
const SOL_TCP = 6; | |
const TCP_NODELAY = 1; | |
try std.os.setsockopt(args.fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1))); | |
const frame = try args.allocator.create(@Frame(Client.run)); | |
errdefer args.allocator.destroy(frame); | |
frame.* = async Client.run( | |
args.fd, | |
args.ring, | |
args.allocator, | |
); | |
} | |
fn run(fd: std.os.socket_t, ring: *Ring, allocator: *std.mem.Allocator) void { | |
var self = Client{ | |
.fd = fd, | |
.ring = ring, | |
}; | |
var writer = async self.runWriter(); | |
var reader = async self.runReader(); | |
await writer catch {}; | |
await reader catch {}; | |
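// Final suspend: nothing ever resumes this frame again, so it can | |
// safely close the fd and free its own memory here. | |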
suspend { | |
std.os.close(fd); | |
allocator.destroy(@frame()); | |
} | |
} | |
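// Tagged state word for waking the writer: the low bits encode | |
// EMPTY/NOTIFY/CLOSED, and any other value is the @frame() pointer of | |
// the suspended writer (frame alignment keeps those bits clear). | |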
const WRITER_EMPTY = 0; | |
const WRITER_NOTIFY = 1; | |
const WRITER_CLOSED = 2; | |
const WRITER_FRAME = ~@as(usize, WRITER_NOTIFY | WRITER_CLOSED); | |
fn writerWait(self: *Client) void { | |
switch (self.writer_state) { | |
WRITER_EMPTY => { | |
suspend self.writer_state = @ptrToInt(@frame()); | |
}, | |
WRITER_NOTIFY => self.writer_state = WRITER_EMPTY, | |
WRITER_CLOSED => {}, | |
else => unreachable, | |
} | |
} | |
fn writerSet(self: *Client, state: usize) void { | |
var writer_state = state; | |
std.mem.swap(usize, &self.writer_state, &writer_state); | |
if (@intToPtr(?anyframe, writer_state & WRITER_FRAME)) |frame| | |
resume frame; | |
} | |
fn writerIsSet(self: *Client, state: usize) bool { | |
return self.writer_state == state; | |
} | |
fn runReader(self: *Client) !void { | |
defer self.writerSet(WRITER_CLOSED); | |
var buffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init(); | |
while (true) { | |
if (self.writerIsSet(WRITER_CLOSED)) | |
return error.Closed; | |
if (std.mem.indexOf(u8, buffer.readableSlice(0), HTTP_CRLF)) |parsed| { | |
buffer.discard(parsed + HTTP_CRLF.len); | |
buffer.realign(); | |
self.send_bytes += HTTP_RESPONSE.len; | |
self.writerSet(WRITER_NOTIFY); | |
continue; | |
} | |
const read_buf = buffer.writableSlice(0); | |
if (read_buf.len == 0) | |
return error.HttpRequestTooLarge; | |
const bytes = switch (self.ring.submit( | |
.RECV, | |
self.fd, | |
@ptrToInt(read_buf.ptr), | |
@intCast(u32, read_buf.len), | |
0, | |
)) { | |
.Ok => |bytes| @as(usize, bytes), | |
.Err => return error.ReadError, | |
}; | |
buffer.update(bytes); | |
if (bytes == 0) | |
return error.Eof; | |
} | |
} | |
fn runWriter(self: *Client) !void { | |
defer self.writerSet(WRITER_CLOSED); | |
var send_partial: usize = 0; | |
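// Same batching trick as the callback version above: send from a | |
// 128-response chunk so pipelined responses go out in as few SEND | |
// submissions as possible. | |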
const HTTP_CHUNKS = 128; | |
const HTTP_RESPONSE_CHUNK = HTTP_RESPONSE ** HTTP_CHUNKS; | |
while (true) { | |
if (self.writerIsSet(WRITER_CLOSED)) | |
return error.Closed; | |
const send_bytes = self.send_bytes; | |
if (send_bytes == 0) { | |
self.writerWait(); | |
continue; | |
} | |
const bytes = switch (self.ring.submit( | |
.SEND, | |
self.fd, | |
@ptrToInt(&HTTP_RESPONSE_CHUNK[0]) + send_partial, | |
@intCast(u32, std.math.min(send_bytes, HTTP_RESPONSE_CHUNK.len - send_partial)), | |
std.os.MSG_NOSIGNAL, | |
)) { | |
.Ok => |bytes| @as(usize, bytes), | |
.Err => return error.WriteError, | |
}; | |
self.send_bytes -= bytes; | |
send_partial = (send_partial + bytes) % HTTP_RESPONSE.len; | |
} | |
} | |
}; |