@daurnimator
Forked from kprotty/ThreadPool.zig
Last active January 3, 2021 18:09
Simple HTTP server compatible with wrk, for Linux, using epoll.
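Once one of the variants below is listening, it can be load-tested with wrk against the hard-coded address, e.g. wrk -t4 -c128 -d10s http://127.0.0.1:12345/ (the flag values here are only illustrative).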
const std = @import("std");
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
const allocator = &gpa.allocator;
defer _ = gpa.deinit();
var poller = try Poller.init(allocator);
defer poller.deinit();
try Server.start(&poller, 12345);
while (true) {
poller.poll();
}
}
const Poller = struct {
fd: std.os.fd_t,
allocator: *std.mem.Allocator,
fn init(allocator: *std.mem.Allocator) !Poller {
return Poller{
.fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC),
.allocator = allocator,
};
}
fn deinit(self: *Poller) void {
std.os.close(self.fd);
}
const Event = struct {
const Callback = struct {
onEventFn: fn(*Callback, Event) void,
};
is_closable: bool,
is_readable: bool,
is_writable: bool,
};
const Socket = struct {
poller: *Poller,
fd: std.os.socket_t,
callback: Event.Callback,
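// The Callback is embedded in Socket, and Socket is embedded in the concrete
// container (Server or Client). onEvent() below recovers that container with two
// @fieldParentPtr() steps from the Callback pointer stored in epoll's event.data,
// so each fd needs only a single registered pointer.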
fn start(comptime Self: type, poller: *Poller, fd: std.os.socket_t) !*Self {
const Callback = struct {
fn onEvent(callback: *Event.Callback, event: Event) void {
const socket = @fieldParentPtr(Socket, "callback", callback);
const self = @fieldParentPtr(Self, "socket", socket);
handleEvent(self, event) catch {
self.onClose();
std.os.close(self.socket.fd);
self.socket.poller.allocator.destroy(self);
};
}
fn handleEvent(self: *Self, event: Event) !void {
if (event.is_closable)
return error.Closed;
if (event.is_readable)
try self.onRead();
if (event.is_writable)
try self.onWrite();
}
};
const self = try poller.allocator.create(Self);
errdefer poller.allocator.destroy(self);
self.socket = .{
.poller = poller,
.fd = fd,
.callback = .{
.onEventFn = Callback.onEvent,
},
};
// Register with edge-triggering (EPOLLET) so the readiness events re-arm themselves when we do IO.
// This saves the extra epoll_ctl() syscall per wakeup that re-arming with EPOLLONESHOT would need.
try std.os.epoll_ctl(poller.fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{
.events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLET | std.os.EPOLLRDHUP,
.data = .{ .ptr = @ptrToInt(&self.socket.callback) },
});
return self;
}
};
fn poll(self: *Poller) void {
var events: [128]std.os.epoll_event = undefined;
const events_found = std.os.epoll_wait(self.fd, &events, -1);
if (events_found == 0)
return;
for (events[0..events_found]) |ev| {
const callback = @intToPtr(*Event.Callback, ev.data.ptr);
(callback.onEventFn)(callback, Event{
.is_closable = ev.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0,
.is_readable = ev.events & std.os.EPOLLIN != 0,
.is_writable = ev.events & std.os.EPOLLOUT != 0,
});
}
}
};
const Server = struct {
socket: Poller.Socket,
port: u16,
const SOCK_FLAGS = std.os.SOCK_CLOEXEC | std.os.SOCK_NONBLOCK;
fn start(poller: *Poller, comptime port: u16) !void {
const fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | SOCK_FLAGS, std.os.IPPROTO_TCP);
errdefer std.os.close(fd);
// Bind the socket to the port on the local address
const address = "127.0.0.1";
var addr = comptime std.net.Address.parseIp(address, port) catch unreachable;
try std.os.setsockopt(fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(fd, &addr.any, addr.getOsSockLen());
try std.os.listen(fd, 128);
const self = try Poller.Socket.start(Server, poller, fd);
self.port = port;
std.debug.warn("Listening on {}:{}", .{address, port});
}
fn onClose(self: *Server) void {
std.debug.warn("server shutdown for port: {}", .{self.port});
}
fn onWrite(self: *Server) !void {
std.debug.panic("server shouldn't writable", .{});
}
fn onRead(self: *Server) !void {
while (true) {
const client_fd = std.os.accept(self.socket.fd, null, null, SOCK_FLAGS) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
Client.start(self.socket.poller, client_fd) catch |err| {
std.debug.warn("Failed to spawn a client: {}\n", .{err});
continue;
};
}
}
};
const Client = struct {
socket: Poller.Socket,
send_bytes: usize,
send_partial: usize,
recv_bytes: usize,
recv_buffer: [4096]u8,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn start(poller: *Poller, fd: std.os.socket_t) !void {
errdefer std.os.close(fd);
// Enable TCP_NODELAY to send the HTTP responses as soon as possible (no Nagle batching).
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
const self = try Poller.Socket.start(Client, poller, fd);
self.send_bytes = 0;
self.send_partial = 0;
self.recv_bytes = 0;
self.recv_buffer = undefined;
}
fn onClose(self: *Client) void {
// Do nothing
}
fn onRead(self: *Client) !void {
while (true) {
const request_buffer = self.recv_buffer[0..self.recv_bytes];
// Try to parse and consume a request in the request_buffer
// by matching everything until the end of an HTTP request with no body
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len];
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
self.recv_bytes = unparsed_buffer.len;
// If found, count that as a parsed request
// and have the writer write the static HTTP response (eventually).
self.send_bytes += HTTP_RESPONSE.len;
continue;
}
// A complete request wasn't parsed yet.
// Read more data into the buffer and try again.
const readable_buffer = self.recv_buffer[self.recv_bytes..];
if (readable_buffer.len == 0)
return error.HttpRequestTooLarge;
// If the read would normally block,
// then we have to wait for the socket to be readable in the future to try again.
const bytes_read = std.os.read(self.socket.fd, readable_buffer) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
// A read of 0 bytes means the peer closed its end of the connection.
self.recv_bytes += bytes_read;
if (bytes_read == 0)
return error.EndOfStream;
}
}
fn onWrite(self: *Client) !void {
const NUM_RESPONSE_CHUNKS = 128;
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
while (self.send_bytes > 0) {
// Compute the slice of the repeated response to send, starting at offset send_partial
// and bounded by both send_bytes and the end of RESPONSE_CHUNK.
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
const iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
const writable_buffer = iov_base[0..iov_len];
// Perform the actual write.
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
const bytes_written = std.os.sendto(
self.socket.fd,
writable_buffer,
std.os.MSG_NOSIGNAL,
null,
@as(std.os.socklen_t, 0),
) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.send_bytes -= bytes_written;
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
}
}
};
const std = @import("std");
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const allocator = &gpa.allocator;
const num_threads = 6; // std.math.max(1, std.Thread.cpuCount() catch 1);
const worker_fds = try allocator.alloc(std.os.fd_t, num_threads);
defer allocator.free(worker_fds);
for (worker_fds) |*worker_fd|
worker_fd.* = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
for (worker_fds[1..]) |worker_fd|
_ = try std.Thread.spawn(worker_fd, runWorker);
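// Each worker thread blocks on its own epoll instance. The main thread keeps
// worker_fds[0], registers the listening socket there, and hands accepted clients
// to the workers' epoll fds in round-robin order via next_fd below.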
const server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
errdefer std.os.close(server_fd);
const port = 12345;
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(server_fd, &addr.any, addr.getOsSockLen());
try std.os.listen(server_fd, 128);
var next_fd: usize = 0;
const epoll_fd = worker_fds[next_fd];
var events: [256]std.os.epoll_event = undefined;
var server_event = std.os.epoll_event{
.events = std.os.EPOLLIN | std.os.EPOLLET | std.os.EPOLLRDHUP,
.data = .{ .ptr = 0 },
};
try std.os.epoll_ctl(
epoll_fd,
std.os.EPOLL_CTL_ADD,
server_fd,
&server_event,
);
std.debug.warn("Listening on :{}", .{port});
while (true) {
const found = std.os.epoll_wait(epoll_fd, &events, -1);
for (events[0..found]) |event| {
if (event.data.ptr != 0) {
Client.process(event);
continue;
}
if (event.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
unreachable;
if (event.events & std.os.EPOLLIN == 0)
unreachable;
while (true) {
const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
error.WouldBlock => break,
else => |e| return e,
};
if (Client.start(allocator, worker_fds[next_fd], client_fd)) |_| {
next_fd += 1;
if (next_fd >= worker_fds.len)
next_fd = 0;
} else |err| {
std.os.close(client_fd);
std.debug.warn("Failed to start client {}: {}\n", .{client_fd, err});
}
}
}
}
}
fn runWorker(epoll_fd: std.os.fd_t) void {
var events: [256]std.os.epoll_event = undefined;
while (true) {
const found = std.os.epoll_wait(epoll_fd, &events, -1);
for (events[0..found]) |event|
Client.process(event);
}
}
const Client = struct {
fd: std.os.socket_t,
allocator: *std.mem.Allocator,
send_bytes: usize = 0,
send_partial: usize = 0,
recv_bytes: usize = 0,
recv_buffer: [4096]u8 = undefined,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn start(allocator: *std.mem.Allocator, epoll_fd: std.os.fd_t, fd: std.os.socket_t) !void {
const self = try allocator.create(Client);
errdefer allocator.destroy(self);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
self.* = .{
.fd = fd,
.allocator = allocator,
};
try std.os.epoll_ctl(epoll_fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{
.events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLET | std.os.EPOLLRDHUP,
.data = .{ .ptr = @ptrToInt(self) },
});
}
fn process(event: std.os.epoll_event) void {
const self = @intToPtr(*Client, event.data.ptr);
self.processEvent(event.events) catch {
std.os.close(self.fd);
self.allocator.destroy(self);
};
}
fn processEvent(self: *Client, events: u32) !void {
if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
return error.Closed;
if (events & std.os.EPOLLIN != 0)
try self.processRead();
if (events & std.os.EPOLLOUT != 0)
try self.processWrite();
}
fn processRead(self: *Client) !void {
while (true) {
const request_buffer = self.recv_buffer[0..self.recv_bytes];
// Try to parse and consume a request in the request_buffer
// by matching everything until the end of an HTTP request with no body
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len];
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
self.recv_bytes = unparsed_buffer.len;
// If found, count that as a parsed request
// and have the writer write the static HTTP response (eventually).
self.send_bytes += HTTP_RESPONSE.len;
continue;
}
// A complete request wasn't parsed yet.
// Read more data into the buffer and try again.
const readable_buffer = self.recv_buffer[self.recv_bytes..];
if (readable_buffer.len == 0)
return error.HttpRequestTooLarge;
// If the read would normally block,
// then we have to wait for the socket to be readable in the future to try again.
const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
// A read of 0 bytes means the peer closed its end of the connection.
self.recv_bytes += bytes_read;
if (bytes_read == 0)
return error.EndOfStream;
}
}
fn processWrite(self: *Client) !void {
const NUM_RESPONSE_CHUNKS = 128;
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
while (self.send_bytes > 0) {
// Compute the slice of the repeated response to send, starting at offset send_partial and bounded by send_bytes.
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
const iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
const writable_buffer = iov_base[0..iov_len];
// Perform the actual write.
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
const bytes_written = std.os.sendto(
self.fd,
writable_buffer,
std.os.MSG_NOSIGNAL,
null,
@as(std.os.socklen_t, 0),
) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.send_bytes -= bytes_written;
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
}
}
};
const std = @import("std");
var poll_fd: std.os.fd_t = undefined;
var server_fd: std.os.socket_t = undefined;
var allocator: *std.mem.Allocator = undefined;
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
allocator = &gpa.allocator;
defer _ = gpa.deinit();
poll_fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
defer std.os.close(poll_fd);
server_fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
defer std.os.close(server_fd);
const port = 12345;
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(server_fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(server_fd, &addr.any, addr.getOsSockLen());
try std.os.listen(server_fd, 128);
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, server_fd, &std.os.epoll_event{
.events = std.os.EPOLLIN | std.os.EPOLLET | std.os.EPOLLRDHUP,
.data = .{ .ptr = @ptrToInt(&server_fd) },
});
var threads = std.math.max(1, try std.Thread.cpuCount());
while (threads > 1) : (threads -= 1)
_ = try std.Thread.spawn({}, runWorker);
std.debug.warn("Listening on :{}\n", .{port});
runWorker({});
}
fn runWorker(_: void) void {
var events: [256]std.os.epoll_event = undefined;
while (true) {
const found = std.os.epoll_wait(poll_fd, &events, -1);
if (found == 0)
continue;
for (events[0..found]) |event| {
const ptr = event.data.ptr;
const flags = event.events;
if (ptr == @ptrToInt(&server_fd)) {
Client.accept(flags) catch |e| std.debug.warn("failed to accept a client: {}\n", .{e});
continue;
}
const client = @intToPtr(*Client, ptr);
client.process(flags) catch {};
}
}
}
const Client = struct {
fd: std.os.socket_t,
send_bytes: usize = 0,
send_partial: usize = 0,
recv_bytes: usize = 0,
recv_buffer: [4096]u8 = undefined,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn accept(flags: u32) !void {
while (true) {
const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
errdefer std.os.close(client_fd);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(client_fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
const self = try allocator.create(Client);
self.* = Client{ .fd = client_fd };
errdefer allocator.destroy(self);
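// EPOLLONESHOT disarms the fd after each delivered event, so only one worker
// thread handles a given client at a time; process() re-arms it with
// EPOLL_CTL_MOD once its read/write pass is done.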
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, client_fd, &std.os.epoll_event{
.events = std.os.EPOLLIN | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP,
.data = .{ .ptr = @ptrToInt(self) },
});
}
}
fn process(self: *Client, flags: u32) !void {
errdefer {
std.os.close(self.fd);
allocator.destroy(self);
}
if (flags & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
return error.Closed;
var written = false;
if ((flags & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) {
written = true;
try self.processWrite();
}
if (flags & std.os.EPOLLIN != 0)
try self.processRead();
if (!written and (flags & std.os.EPOLLOUT != 0) and (self.send_bytes > 0))
try self.processWrite();
var events: u32 = std.os.EPOLLIN | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP;
if (self.send_bytes > 0)
events |= std.os.EPOLLOUT;
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_MOD, self.fd, &std.os.epoll_event{
.events = events,
.data = .{ .ptr = @ptrToInt(self) },
});
}
fn processRead(self: *Client) !void {
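// Consume as many complete header-only requests (terminated by \r\n\r\n) as are
// buffered, queueing HTTP_RESPONSE.len bytes to send per request, then read more
// data until the read would block.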
while (true) {
const request_buffer = self.recv_buffer[0..self.recv_bytes];
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len];
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
self.recv_bytes = unparsed_buffer.len;
self.send_bytes += HTTP_RESPONSE.len;
continue;
}
const readable_buffer = self.recv_buffer[self.recv_bytes..];
if (readable_buffer.len == 0)
return error.HttpRequestTooLarge;
const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.recv_bytes += bytes_read;
if (bytes_read == 0)
return error.EndOfStream;
}
}
fn processWrite(self: *Client) !void {
const NUM_RESPONSE_CHUNKS = 128;
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
while (self.send_bytes > 0) {
// Compute the slice of the repeated response to send, starting at offset send_partial and bounded by send_bytes.
const send_bytes = self.send_bytes;
if (self.send_partial > RESPONSE_CHUNK.len)
std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len});
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial);
const writable_buffer = iov_base[0..iov_len];
// Perform the actual write.
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
const bytes_written = std.os.sendto(
self.fd,
writable_buffer,
std.os.MSG_NOSIGNAL,
null,
@as(std.os.socklen_t, 0),
) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
self.send_bytes -= bytes_written;
}
}
};
const std = @import("std");
const ThreadPool = @import("./ThreadPoolIO.zig");
pub fn main() !void {
var pool = try ThreadPool.init(.{ .max_threads = 6 });
defer pool.deinit();
var server: Server = undefined;
try server.init(&pool, 12345);
var event = std.StaticResetEvent{};
event.wait();
}
const Poller = struct {
fd: std.os.fd_t,
gpa: std.heap.GeneralPurposeAllocator(.{}) = .{},
fn init(self: *Poller) !void {
const fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
self.* = .{ .fd = fd };
}
fn deinit(self: *Poller) void {
std.os.close(self.fd);
_ = self.gpa.deinit();
}
fn getAllocator(self: *Poller) *std.mem.Allocator {
return &self.gpa.allocator;
}
};
const Server = struct {
fd: std.os.socket_t,
pool: *ThreadPool,
io_runnable: ThreadPool.IoRunnable,
gpa: std.heap.GeneralPurposeAllocator(.{}),
port: u16,
start_runnable: ThreadPool.Runnable,
fn init(self: *Server, pool: *ThreadPool, comptime port: u16) !void {
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
errdefer std.os.close(self.fd);
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen());
try std.os.listen(self.fd, 128);
self.gpa = .{};
self.pool = pool;
self.io_runnable = ThreadPool.IoRunnable{
.is_readable = true,
.runnable = .{ .runFn = Server.run },
};
try pool.waitFor(self.fd, &self.io_runnable);
self.port = port;
self.start_runnable = .{ .runFn = Server.start };
pool.schedule(.{}, &self.start_runnable);
}
fn start(runnable: *ThreadPool.Runnable) void {
const self = @fieldParentPtr(Server, "start_runnable", runnable);
std.debug.warn("Listening on :{}\n", .{self.port});
}
fn run(runnable: *ThreadPool.Runnable) void {
const io_runnable = @fieldParentPtr(ThreadPool.IoRunnable, "runnable", runnable);
const self = @fieldParentPtr(Server, "io_runnable", io_runnable);
self.accept() catch |err| {
std.os.close(self.fd);
std.debug.warn("Server shutdown\n", .{});
};
}
fn accept(self: *Server) !void {
if (self.io_runnable.is_closable)
return error.ServerShutdown;
if (!self.io_runnable.is_readable)
unreachable;
while (true) {
const client_fd = std.os.accept(self.fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
error.WouldBlock => break,
else => |e| return e,
};
Client.init(client_fd, self) catch |err| {
std.os.close(client_fd);
std.debug.warn("Failed to spawn client: {}\n", .{err});
};
}
try self.pool.waitFor(self.fd, &self.io_runnable);
}
};
const Client = struct {
fd: std.os.socket_t,
server: *Server,
send_bytes: usize = 0,
send_partial: usize = 0,
recv_bytes: usize = 0,
recv_buffer: [4096]u8 = undefined,
io_runnable: ThreadPool.IoRunnable,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn init(fd: std.os.socket_t, server: *Server) !void {
const allocator = &server.gpa.allocator;
const self = try allocator.create(Client);
errdefer allocator.destroy(self);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
self.* = .{
.fd = fd,
.server = server,
.io_runnable = .{
.is_readable = true,
.runnable = .{ .runFn = Client.run },
},
};
try self.server.pool.waitFor(self.fd, &self.io_runnable);
}
fn run(runnable: *ThreadPool.Runnable) void {
const io_runnable = @fieldParentPtr(ThreadPool.IoRunnable, "runnable", runnable);
const self = @fieldParentPtr(Client, "io_runnable", io_runnable);
self.process() catch {
std.os.close(self.fd);
self.server.gpa.allocator.destroy(self);
};
}
fn process(self: *Client) !void {
if (self.io_runnable.is_closable)
return error.Closed;
var written = false;
if (self.io_runnable.is_writable and (self.send_bytes > 0)) {
written = true;
try self.processWrite();
}
if (self.io_runnable.is_readable)
try self.processRead();
if (!written and self.io_runnable.is_writable and (self.send_bytes > 0))
try self.processWrite();
self.io_runnable.is_readable = true;
self.io_runnable.is_writable = self.send_bytes > 0;
try self.server.pool.waitFor(self.fd, &self.io_runnable);
}
fn processRead(self: *Client) !void {
while (true) {
const request_buffer = self.recv_buffer[0..self.recv_bytes];
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len];
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
self.recv_bytes = unparsed_buffer.len;
self.send_bytes += HTTP_RESPONSE.len;
continue;
}
const readable_buffer = self.recv_buffer[self.recv_bytes..];
if (readable_buffer.len == 0)
return error.HttpRequestTooLarge;
const bytes_read = std.os.read(self.fd, readable_buffer) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.recv_bytes += bytes_read;
if (bytes_read == 0)
return error.EndOfStream;
}
}
fn processWrite(self: *Client) !void {
const NUM_RESPONSE_CHUNKS = 128;
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
while (self.send_bytes > 0) {
// Compute the slice of the repeated response to send, starting at offset send_partial and bounded by send_bytes.
const send_bytes = self.send_bytes;
if (self.send_partial > RESPONSE_CHUNK.len)
std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len});
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial);
const writable_buffer = iov_base[0..iov_len];
// Perform the actual write.
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
const bytes_written = std.os.sendto(
self.fd,
writable_buffer,
std.os.MSG_NOSIGNAL,
null,
@as(std.os.socklen_t, 0),
) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
self.send_bytes -= bytes_written;
}
}
};
const std = @import("std");
const ThreadPool = @import("./ThreadPool.zig");
pub fn main() !void {
var poller: Poller = undefined;
try poller.init();
defer poller.deinit();
var server: Server = undefined;
try server.init(&poller, 12345);
var pool = ThreadPool.init(.{});
defer pool.deinit();
while (true) {
var events: [1024]std.os.epoll_event = undefined;
const found = std.os.epoll_wait(poller.fd, &events, -1);
var batch = ThreadPool.Batch{};
defer if (!batch.isEmpty())
pool.schedule(.{}, batch);
for (events[0..found]) |event| {
const socket = @intToPtr(*Socket, event.data.ptr);
socket.events = event.events;
batch.push(&socket.runnable);
}
}
}
const Poller = struct {
fd: std.os.fd_t,
gpa: std.heap.GeneralPurposeAllocator(.{}) = .{},
fn init(self: *Poller) !void {
const fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
self.* = .{ .fd = fd };
}
fn deinit(self: *Poller) void {
std.os.close(self.fd);
_ = self.gpa.deinit();
}
fn getAllocator(self: *Poller) *std.mem.Allocator {
return &self.gpa.allocator;
}
};
const Socket = struct {
fd: std.os.socket_t,
poller: *Poller,
events: u32,
runnable: ThreadPool.Runnable,
fn init(self: *Socket, fd: std.os.socket_t, poller: *Poller, comptime Container: type) !void {
const Callback = struct {
fn runFn(runnable: *ThreadPool.Runnable) void {
const socket = @fieldParentPtr(Socket, "runnable", runnable);
const container = @fieldParentPtr(Container, "socket", socket);
container.run() catch {
std.os.close(socket.fd);
container.deinit();
};
}
};
self.* = .{
.fd = fd,
.poller = poller,
.events = undefined,
.runnable = .{ .runFn = Callback.runFn },
};
try std.os.epoll_ctl(poller.fd, std.os.EPOLL_CTL_ADD, fd, &std.os.epoll_event{
.events = std.os.EPOLLIN | std.os.EPOLLOUT | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP,
.data = .{ .ptr = @ptrToInt(self) },
});
}
fn register(self: *Socket, events: u32) !void {
try std.os.epoll_ctl(self.poller.fd, std.os.EPOLL_CTL_MOD, self.fd, &std.os.epoll_event{
.events = events | std.os.EPOLLONESHOT | std.os.EPOLLRDHUP,
.data = .{ .ptr = @ptrToInt(self) },
});
}
};
const Server = struct {
socket: Socket,
fn init(self: *Server, poller: *Poller, comptime port: u16) !void {
const fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
errdefer std.os.close(fd);
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(fd, &addr.any, addr.getOsSockLen());
try std.os.listen(fd, 128);
try self.socket.init(fd, poller, Server);
std.debug.warn("Listening on :{}\n", .{port});
}
fn deinit(self: *Server) void {
std.debug.warn("Server shutdown\n", .{});
}
fn run(self: *Server) !void {
const events = self.socket.events;
const server_fd = self.socket.fd;
if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
return error.ServerShutdown;
if (events & std.os.EPOLLIN == 0)
unreachable;
while (true) {
const client_fd = std.os.accept(server_fd, null, null, std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC) catch |err| switch (err) {
error.WouldBlock => return self.socket.register(std.os.EPOLLIN),
else => |e| return e,
};
Client.init(client_fd, self.socket.poller) catch |err| {
std.os.close(client_fd);
std.debug.warn("Failed to spawn client: {}\n", .{err});
};
}
}
};
const Client = struct {
socket: Socket,
send_bytes: usize = 0,
send_partial: usize = 0,
recv_bytes: usize = 0,
recv_buffer: [4096]u8 = undefined,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn init(fd: std.os.socket_t, poller: *Poller) !void {
const allocator = poller.getAllocator();
const self = try allocator.create(Client);
errdefer allocator.destroy(self);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
self.* = .{ .socket = undefined };
try self.socket.init(fd, poller, Client);
}
fn deinit(self: *Client) void {
const allocator = self.socket.poller.getAllocator();
allocator.destroy(self);
}
fn run(self: *Client) !void {
var events = self.socket.events;
if (events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0)
return error.Closed;
var written = false;
if ((events & std.os.EPOLLOUT != 0) and (self.send_bytes > 0)) {
written = true;
try self.processWrite();
}
if (events & std.os.EPOLLIN != 0)
try self.processRead();
if (!written and (events & std.os.EPOLLOUT != 0) and (self.send_bytes > 0))
try self.processWrite();
events = std.os.EPOLLIN;
if (self.send_bytes > 0)
events |= std.os.EPOLLOUT;
try self.socket.register(events);
}
fn processRead(self: *Client) !void {
while (true) {
const request_buffer = self.recv_buffer[0..self.recv_bytes];
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len) .. request_buffer.len];
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
self.recv_bytes = unparsed_buffer.len;
self.send_bytes += HTTP_RESPONSE.len;
continue;
}
const readable_buffer = self.recv_buffer[self.recv_bytes..];
if (readable_buffer.len == 0)
return error.HttpRequestTooLarge;
const bytes_read = std.os.read(self.socket.fd, readable_buffer) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.recv_bytes += bytes_read;
if (bytes_read == 0)
return error.EndOfStream;
}
}
fn processWrite(self: *Client) !void {
const NUM_RESPONSE_CHUNKS = 128;
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
while (self.send_bytes > 0) {
// Compute the slice of the repeated response to send, starting at offset send_partial and bounded by send_bytes.
const send_bytes = self.send_bytes;
if (self.send_partial > RESPONSE_CHUNK.len)
std.debug.panic("invalid send_partial={} chunk={}\n", .{self.send_partial, RESPONSE_CHUNK.len});
const iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
const iov_len = std.math.min(send_bytes, RESPONSE_CHUNK.len - self.send_partial);
const writable_buffer = iov_base[0..iov_len];
// Perform the actual write.
// Use MSG_NOSIGNAL to get error.BrokenPipe instead of a signal on write-end closing.
const bytes_written = std.os.sendto(
self.socket.fd,
writable_buffer,
std.os.MSG_NOSIGNAL,
null,
@as(std.os.socklen_t, 0),
) catch |err| switch (err) {
error.WouldBlock => return,
else => |e| return e,
};
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
self.send_bytes -= bytes_written;
}
}
};
const std = @import("std");
const system = switch (std.builtin.os.tag) {
.linux => std.os.linux,
else => std.os.system,
};
const ThreadPool = @This();
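// This file is itself the ThreadPool struct: the declarations below are its fields
// and methods, and other files use it as a type via @import("./ThreadPool.zig").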
max_threads: u16,
counter: u32 = 0,
spawned_queue: ?*Worker = null,
run_queue: UnboundedQueue = .{},
idle_semaphore: Semaphore = Semaphore.init(0),
shutdown_event: Event = .{},
pub const InitConfig = struct {
max_threads: ?u16 = null,
};
pub fn init(config: InitConfig) ThreadPool {
return .{
.max_threads = std.math.min(
std.math.maxInt(u14),
std.math.max(1, config.max_threads orelse blk: {
break :blk @intCast(u16, std.Thread.cpuCount() catch 1);
}),
),
};
}
pub fn deinit(self: *ThreadPool) void {
defer self.* = undefined;
self.shutdown();
self.shutdown_event.wait();
while (self.spawned_queue) |worker| {
self.spawned_queue = worker.spawned_next;
const thread = worker.thread;
worker.shutdown_event.notify();
thread.wait();
}
}
pub const ScheduleHints = struct {
priority: Priority = .Normal,
pub const Priority = enum {
High,
Normal,
Low,
};
};
pub fn schedule(self: *ThreadPool, hints: ScheduleHints, batchable: anytype) void {
const batch = Batch.from(batchable);
if (batch.isEmpty())
return;
if (Worker.current) |worker| {
worker.push(hints, batch);
} else {
self.run_queue.push(batch);
}
_ = self.tryNotifyWith(false);
}
pub const SpawnConfig = struct {
allocator: *std.mem.Allocator,
hints: ScheduleHints = .{},
};
pub fn spawn(self: *ThreadPool, config: SpawnConfig, comptime func: anytype, args: anytype) !void {
const Args = @TypeOf(args);
const is_async = @typeInfo(@TypeOf(func)).Fn.calling_convention == .Async;
const Closure = struct {
func_args: Args,
allocator: *std.mem.Allocator,
runnable: Runnable = .{ .runFn = runFn },
frame: if (is_async) @Frame(runAsyncFn) else void = undefined,
fn runFn(runnable: *Runnable) void {
const closure = @fieldParentPtr(@This(), "runnable", runnable);
if (is_async) {
closure.frame = async closure.runAsyncFn();
} else {
const result = @call(.{}, func, closure.func_args);
closure.allocator.destroy(closure);
}
}
fn runAsyncFn(closure: *@This()) void {
const result = @call(.{}, func, closure.func_args);
suspend closure.allocator.destroy(closure);
}
};
const allocator = config.allocator;
const closure = try allocator.create(Closure);
errdefer allocator.destroy(closure);
closure.* = .{
.func_args = args,
.allocator = allocator,
};
const hints = config.hints;
self.schedule(hints, &closure.runnable);
}
const Counter = struct {
state: State = .pending,
idle: u16 = 0,
spawned: u16 = 0,
const State = enum(u4) {
pending = 0,
notified,
waking,
waker_notified,
shutdown,
};
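// Packed into the u32 counter word: bits 0-3 hold the State, bits 4-17 the idle
// worker count (u14), and bits 18-31 the spawned worker count (u14).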
fn pack(self: Counter) u32 {
return (@as(u32, @as(u4, @enumToInt(self.state))) |
(@as(u32, @intCast(u14, self.idle)) << 4) |
(@as(u32, @intCast(u14, self.spawned)) << (4 + 14)));
}
fn unpack(value: u32) Counter {
return Counter{
.state = @intToEnum(State, @truncate(u4, value)),
.idle = @as(u16, @truncate(u14, value >> 4)),
.spawned = @as(u16, @truncate(u14, value >> (4 + 14))),
};
}
};
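// Wake-up protocol: at most one thread is "waking" at a time, so a burst of
// schedule() calls wakes or spawns a single worker instead of the whole pool.
// The notified/waker_notified states record that another wake-up request arrived
// while no wake could be performed, so it is not lost.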
fn tryNotifyWith(self: *ThreadPool, is_caller_waking: bool) bool {
var spawned = false;
var remaining_attempts: u8 = 5;
var is_waking = is_caller_waking;
while (true) : (yieldCpu()) {
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
if (counter.state == .shutdown) {
if (spawned)
self.releaseWorker();
return false;
}
const has_pending = (counter.idle > 0) or (counter.spawned < self.max_threads);
const can_wake = (is_waking and remaining_attempts > 0) or (!is_waking and counter.state == .pending);
if (has_pending and can_wake) {
var new_counter = counter;
new_counter.state = .waking;
if (counter.idle > 0) {
new_counter.idle -= 1;
} else if (!spawned) {
new_counter.spawned += 1;
}
if (@cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Acquire,
.Monotonic,
)) |failed| {
continue;
}
is_waking = true;
if (counter.idle > 0) {
self.idle_semaphore.post(1) catch unreachable;
return true;
}
spawned = true;
if (Worker.spawn(self))
return true;
remaining_attempts -= 1;
continue;
}
var new_counter = counter;
if (is_waking) {
new_counter.state = if (can_wake) .pending else .notified;
if (spawned)
new_counter.spawned -= 1;
} else if (counter.state == .waking) {
new_counter.state = .waker_notified;
} else if (counter.state == .pending) {
new_counter.state = .notified;
} else {
return false;
}
_ = @cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Monotonic,
.Monotonic,
) orelse return true;
}
}
const Wait = enum {
resumed,
notified,
shutdown,
};
fn tryWaitWith(self: *ThreadPool, is_caller_waking: bool) Wait {
var is_waking = is_caller_waking;
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
while (true) {
if (counter.state == .shutdown) {
self.releaseWorker();
return .shutdown;
}
const is_notified = switch (counter.state) {
.waker_notified => is_waking,
.notified => true,
else => false,
};
var new_counter = counter;
if (is_notified) {
new_counter.state = if (is_waking) .waking else .pending;
} else {
new_counter.idle += 1;
if (is_waking)
new_counter.state = .pending;
}
if (@cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Monotonic,
.Monotonic,
)) |updated| {
counter = Counter.unpack(updated);
continue;
}
if (is_notified and is_waking)
return .notified;
if (is_notified)
return .resumed;
self.idle_semaphore.wait(1);
return .notified;
}
}
fn releaseWorker(self: *ThreadPool) void {
const counter_spawned = Counter{ .spawned = 1 };
const counter_value = @atomicRmw(u32, &self.counter, .Sub, counter_spawned.pack(), .AcqRel);
const counter = Counter.unpack(counter_value);
if (counter.state != .shutdown)
std.debug.panic("ThreadPool.releaseWorker() when not shutdown: {}", .{counter});
if (counter.spawned == 1)
self.shutdown_event.notify();
}
pub fn shutdown(self: *ThreadPool) void {
while (true) : (yieldCpu()) {
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
if (counter.state == .shutdown)
return;
var new_counter = counter;
new_counter.state = .shutdown;
new_counter.idle = 0;
if (@cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Acquire,
.Monotonic,
)) |failed| {
continue;
}
self.idle_semaphore.post(self.max_threads) catch unreachable;
return;
}
}
const Worker = struct {
pool: *ThreadPool,
thread: *std.Thread,
spawned_next: ?*Worker = null,
shutdown_event: Event = .{},
run_queue: BoundedQueue = .{},
run_queue_next: ?*Runnable = null,
run_queue_lifo: ?*Runnable = null,
run_queue_overflow: UnboundedQueue = .{},
tick: usize = undefined,
is_waking: bool = true,
next_target: ?*Worker = null,
threadlocal var current: ?*Worker = null;
fn spawn(pool: *ThreadPool) bool {
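// Hand-off dance: the new thread waits for its *std.Thread handle to be published
// (data_put_event) before running, and the spawner waits for the handle to be
// consumed (data_get_event) before letting the stack-allocated Spawner go out of scope.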
const Spawner = struct {
thread: *std.Thread = undefined,
thread_pool: *ThreadPool,
data_put_event: Event = .{},
data_get_event: Event = .{},
fn entry(self: *@This()) void {
self.data_put_event.wait();
const thread = self.thread;
const thread_pool = self.thread_pool;
self.data_get_event.notify();
Worker.run(thread, thread_pool);
}
};
var spawner = Spawner{ .thread_pool = pool };
spawner.thread = std.Thread.spawn(&spawner, Spawner.entry) catch return false;
spawner.data_put_event.notify();
spawner.data_get_event.wait();
return true;
}
fn run(thread: *std.Thread, pool: *ThreadPool) void {
var self = Worker{
.thread = thread,
.pool = pool,
};
self.tick = @ptrToInt(&self);
current = &self;
defer current = null;
var spawned_queue = @atomicLoad(?*Worker, &pool.spawned_queue, .Monotonic);
while (true) {
self.spawned_next = spawned_queue;
spawned_queue = @cmpxchgWeak(
?*Worker,
&pool.spawned_queue,
spawned_queue,
&self,
.Release,
.Monotonic,
) orelse break;
}
while (true) {
if (self.pop()) |runnable| {
if (self.is_waking) {
self.is_waking = false;
_ = pool.tryNotifyWith(true);
}
self.tick +%= 1;
runnable.run();
continue;
}
self.is_waking = switch (pool.tryWaitWith(self.is_waking)) {
.resumed => false,
.notified => true,
.shutdown => {
self.shutdown_event.wait();
break;
},
};
}
}
fn push(self: *Worker, hints: ScheduleHints, batchable: anytype) void {
var batch = Batch.from(batchable);
if (batch.isEmpty())
return;
if (hints.priority == .High) {
const new_lifo = batch.pop();
if (@atomicLoad(?*Runnable, &self.run_queue_lifo, .Monotonic) == null) {
@atomicStore(?*Runnable, &self.run_queue_lifo, new_lifo, .Release);
} else if (@atomicRmw(?*Runnable, &self.run_queue_lifo, .Xchg, new_lifo, .AcqRel)) |old_lifo| {
batch.pushFront(old_lifo);
}
}
if (hints.priority == .Low) {
if (self.run_queue_next) |old_next|
batch.pushFront(old_next);
self.run_queue_next = null;
self.run_queue_next = self.pop() orelse batch.pop();
}
if (self.run_queue.push(batch)) |overflowed|
self.run_queue_overflow.push(overflowed);
}
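// Task lookup order: every so many ticks (prime intervals 127/61/31/13) check the
// other workers, the pool's shared queue, the local overflow queue, and the LIFO
// slot first so that no source is starved; otherwise prefer the local bounded
// queue, then fall through LIFO slot, overflow queue, shared queue, and stealing.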
fn pop(self: *Worker) ?*Runnable {
if (self.tick % 127 == 0) {
if (self.popAndStealFromOthers()) |runnable|
return runnable;
}
if (self.tick % 61 == 0) {
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable|
return runnable;
}
if (self.tick % 31 == 0) {
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable|
return runnable;
}
if (self.tick % 13 == 0) {
if (self.popAndStealLifo(self)) |runnable|
return runnable;
}
if (self.run_queue.pop()) |runnable|
return runnable;
if (self.popAndStealLifo(self)) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable|
return runnable;
if (self.popAndStealFromOthers()) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable|
return runnable;
return null;
}
fn popAndStealLifo(self: *Worker, target: *Worker) ?*Runnable {
var run_queue_lifo = @atomicLoad(?*Runnable, &target.run_queue_lifo, .Monotonic);
while (true) {
if (run_queue_lifo == null)
return null;
run_queue_lifo = @cmpxchgWeak(
?*Runnable,
&target.run_queue_lifo,
run_queue_lifo,
null,
.Acquire,
.Monotonic,
) orelse return run_queue_lifo;
}
}
fn popAndStealFromOthers(self: *Worker) ?*Runnable {
var num_workers = blk: {
const counter_value = @atomicLoad(u32, &self.pool.counter, .Monotonic);
const counter = Counter.unpack(counter_value);
break :blk counter.spawned;
};
while (num_workers > 0) : (num_workers -= 1) {
const target = self.next_target orelse blk: {
break :blk @atomicLoad(?*Worker, &self.pool.spawned_queue, .Acquire) orelse {
std.debug.panic("Worker observed empty spawned queue when work-stealing", .{});
};
};
self.next_target = target.spawned_next;
if (target == self)
continue;
if (self.run_queue.popAndStealBounded(&target.run_queue)) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&target.run_queue_overflow)) |runnable|
return runnable;
if (self.popAndStealLifo(target)) |runnable|
return runnable;
}
return null;
}
};
const UnboundedQueue = struct {
lock: Mutex = .{},
batch: Batch = .{},
shared_size: usize = 0,
fn push(self: *UnboundedQueue, batchable: anytype) void {
const batch = Batch.from(batchable);
if (batch.isEmpty())
return;
const held = self.lock.acquire();
defer held.release();
self.batch.push(batch);
var shared_size = self.shared_size;
shared_size += batch.size;
@atomicStore(usize, &self.shared_size, shared_size, .Release);
}
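// Mutex-protected queue with a lock-free emptiness check: shared_size is read with
// an atomic load so stealers can skip the lock when the queue is empty, and a
// Consumer drains items under the lock, publishing the new size on release().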
fn tryAcquireConsumer(self: *UnboundedQueue) ?Consumer {
var shared_size = @atomicLoad(usize, &self.shared_size, .Acquire);
if (shared_size == 0)
return null;
const held = self.lock.acquire();
shared_size = self.shared_size;
if (shared_size == 0) {
held.release();
return null;
}
return Consumer{
.held = held,
.queue = self,
.size = shared_size,
};
}
const Consumer = struct {
held: Mutex.Held,
queue: *UnboundedQueue,
size: usize,
fn release(self: Consumer) void {
@atomicStore(usize, &self.queue.shared_size, self.size, .Release);
self.held.release();
}
fn pop(self: *Consumer) ?*Runnable {
const runnable = self.queue.batch.pop() orelse return null;
self.size -= 1;
return runnable;
}
};
};
const BoundedQueue = struct {
head: usize = 0,
tail: usize = 0,
buffer: [256]*Runnable = undefined,
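// Fixed-size ring buffer of 256 runnables owned by one worker (single producer);
// other workers may steal from the head. When the buffer is full, push() hands
// back half of it plus the rest of the batch so the caller can move the overflow
// into its UnboundedQueue.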
fn push(self: *BoundedQueue, batchable: anytype) ?Batch {
var batch = Batch.from(batchable);
while (true) : (yieldCpu()) {
if (batch.isEmpty())
return null;
var tail = self.tail;
var head = @atomicLoad(usize, &self.head, .Acquire);
var size = tail -% head;
if (size < self.buffer.len) {
while (size < self.buffer.len) {
const runnable = batch.pop() orelse break;
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered);
tail +%= 1;
size += 1;
}
@atomicStore(usize, &self.tail, tail, .Release);
continue;
}
var migrate = self.buffer.len / 2;
if (@cmpxchgWeak(
usize,
&self.head,
head,
head +% migrate,
.AcqRel,
.Acquire,
)) |failed| {
continue;
}
var overflowed = Batch{};
while (migrate > 0) : (migrate -= 1) {
const runnable = self.buffer[head % self.buffer.len];
overflowed.push(runnable);
head +%= 1;
}
overflowed.push(batch);
return overflowed;
}
}
fn pop(self: *BoundedQueue) ?*Runnable {
while (true) : (yieldCpu()) {
const tail = self.tail;
const head = @atomicLoad(usize, &self.head, .Acquire);
const size = tail -% head;
if (size == 0)
return null;
if (@cmpxchgWeak(
usize,
&self.head,
head,
head +% 1,
.AcqRel,
.Acquire,
)) |failed| {
continue;
}
const runnable = self.buffer[head % self.buffer.len];
return runnable;
}
}
fn popAndStealBounded(self: *BoundedQueue, target: *BoundedQueue) ?*Runnable {
if (target == self)
return self.pop();
const tail = self.tail;
const head = @atomicLoad(usize, &self.head, .Acquire);
const size = tail -% head;
if (size != 0)
return self.pop();
while (true) : (yieldThread()) {
const target_head = @atomicLoad(usize, &target.head, .Acquire);
const target_tail = @atomicLoad(usize, &target.tail, .Acquire);
const target_size = target_tail -% target_head;
var steal = target_size - (target_size / 2);
if (steal == 0)
return null;
if (steal > target.buffer.len / 2)
continue;
const first_runnable = @atomicLoad(*Runnable, &target.buffer[target_head % target.buffer.len], .Unordered);
var new_target_head = target_head +% 1;
var new_tail = tail;
steal -= 1;
while (steal > 0) : (steal -= 1) {
const runnable = @atomicLoad(*Runnable, &target.buffer[new_target_head % target.buffer.len], .Unordered);
new_target_head +%= 1;
@atomicStore(*Runnable, &self.buffer[new_tail % self.buffer.len], runnable, .Unordered);
new_tail +%= 1;
}
if (@cmpxchgWeak(
usize,
&target.head,
target_head,
new_target_head,
.AcqRel,
.Acquire,
)) |failed| {
continue;
}
@atomicStore(usize, &self.tail, new_tail, .Release);
return first_runnable;
}
}
fn popAndStealUnbounded(self: *BoundedQueue, target: *UnboundedQueue) ?*Runnable {
var consumer = target.tryAcquireConsumer() orelse return null;
defer consumer.release();
const first_runnable = consumer.pop() orelse return null;
var tail = self.tail;
var head = @atomicLoad(usize, &self.head, .Acquire);
var size = tail -% head;
while (size < self.buffer.len) {
const runnable = consumer.pop() orelse break;
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered);
tail +%= 1;
size += 1;
}
@atomicStore(usize, &self.tail, tail, .Release);
return first_runnable;
}
};
pub const Runnable = struct {
next: ?*Runnable = null,
runFn: fn (*Runnable) void,
pub fn run(self: *Runnable) void {
return (self.runFn)(self);
}
};
pub const Batch = struct {
head: ?*Runnable = null,
tail: *Runnable = undefined,
size: usize = 0,
pub fn from(batchable: anytype) Batch {
return switch (@TypeOf(batchable)) {
Batch => batchable,
?*Runnable => from(batchable orelse return Batch{}),
*Runnable => {
batchable.next = null;
return Batch{
.head = batchable,
.tail = batchable,
.size = 1,
};
},
else => |typ| @compileError(@typeName(typ) ++
" cannot be converted into " ++
@typeName(Batch)),
};
}
pub fn isEmpty(self: Batch) bool {
return self.head == null;
}
pub const push = pushBack;
pub fn pushBack(self: *Batch, batchable: anytype) void {
const batch = from(batchable);
if (batch.isEmpty())
return;
if (self.isEmpty()) {
self.* = batch;
} else {
self.tail.next = batch.head;
self.tail = batch.tail;
self.size += batch.size;
}
}
pub fn pushFront(self: *Batch, batchable: anytype) void {
const batch = from(batchable);
if (batch.isEmpty())
return;
if (self.isEmpty()) {
self.* = batch;
} else {
batch.tail.next = self.head;
self.head = batch.head;
self.size += batch.size;
}
}
pub const pop = popFront;
pub fn popFront(self: *Batch) ?*Runnable {
const runnable = self.head orelse return null;
self.head = runnable.next;
self.size -= 1;
return runnable;
}
};
const Semaphore = struct {
lock: Mutex = .{},
permits: usize = 0,
waiters: ?*Waiter = null,
const Waiter = struct {
next: ?*Waiter = null,
tail: *Waiter = undefined,
event: Event = .{},
permits: usize,
};
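// Waiters form a FIFO list (head.tail points at the last waiter): wait() blocks on
// the waiter's Event when not enough permits are available, and post() hands
// permits to waiters in order, notifying them after the lock is released.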
fn init(permits: usize) Semaphore {
return .{ .permits = permits };
}
fn wait(self: *Semaphore, permits: usize) void {
const held = self.lock.acquire();
if (self.permits >= permits) {
self.permits -= permits;
held.release();
return;
}
var waiter = Waiter{ .permits = permits };
if (self.waiters) |head| {
head.tail.next = &waiter;
head.tail = &waiter;
} else {
self.waiters = &waiter;
waiter.tail = &waiter;
}
held.release();
waiter.event.wait();
}
fn post(self: *Semaphore, permits: usize) error{Overflow}!void {
var waiters: ?*Waiter = null;
{
const held = self.lock.acquire();
defer held.release();
if (@addWithOverflow(usize, self.permits, permits, &self.permits))
return error.Overflow;
while (self.waiters) |waiter| {
if (waiter.permits > self.permits)
break;
self.waiters = waiter.next;
if (self.waiters) |new_waiter|
new_waiter.tail = waiter.tail;
self.permits -= waiter.permits;
waiter.next = waiters;
waiters = waiter;
}
}
while (waiters) |waiter| {
waiters = waiter.next;
waiter.event.notify();
}
}
};
const Mutex = if (std.builtin.os.tag == .windows)
struct {
srwlock: usize = 0,
pub fn acquire(self: *Mutex) Held {
AcquireSRWLockExclusive(&self.srwlock);
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
ReleaseSRWLockExclusive(&self.mutex.srwlock);
}
};
extern "kernel32" fn AcquireSRWLockExclusive(
srwlock: *?system.PVOID,
) callconv(system.WINAPI) void;
extern "kernel32" fn ReleaseSRWLockExclusive(
srwlock: *?system.PVOID,
) callconv(system.WINAPI) void;
}
else if (comptime std.Target.current.isDarwin())
struct {
lock: u32 = 0,
pub fn acquire(self: *Mutex) Held {
os_unfair_lock_lock(&self.lock);
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
os_unfair_lock_unlock(&self.mutex.lock);
}
};
extern "c" fn os_unfair_lock_lock(
unfair_lock: *u32,
) callconv(.C) void;
extern "c" fn os_unfair_lock_unlock(
unfair_lock: *u32,
) callconv(.C) void;
}
else if (std.builtin.os.tag == .linux)
struct {
state: i32 = UNLOCKED,
const UNLOCKED: i32 = 0;
const LOCKED: i32 = 1;
const WAITING: i32 = 2;
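// Futex-based mutex with three states (unlocked / locked / locked-with-waiters):
// acquire() swaps in LOCKED and only takes the slow path on contention, where it
// spins briefly and then futex-waits with the state set to WAITING; release() only
// issues a futex wake when the previous state was WAITING.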
pub fn acquire(self: *Mutex) Held {
const state = @atomicRmw(i32, &self.state, .Xchg, LOCKED, .Acquire);
if (state != UNLOCKED)
self.acquireSlow(state);
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
switch (@atomicRmw(i32, &self.mutex.state, .Xchg, UNLOCKED, .Release)) {
UNLOCKED => unreachable, // unlocked an unlocked mutex
LOCKED => {},
WAITING => self.mutex.releaseSlow(),
else => unreachable,
}
}
};
fn acquireSlow(self: *Mutex, current_state: i32) void {
@setCold(true);
var wait_state = current_state;
while (true) {
var spin: u8 = 0;
while (spin < 5) : (spin += 1) {
switch (@atomicLoad(i32, &self.state, .Monotonic)) {
UNLOCKED => _ = @cmpxchgWeak(
i32,
&self.state,
UNLOCKED,
wait_state,
.Acquire,
.Monotonic,
) orelse return,
LOCKED => {},
WAITING => break,
else => unreachable,
}
if (spin < 4) {
var pause: u8 = 30;
while (pause > 0) : (pause -= 1)
yieldCpu();
} else {
yieldThread();
}
}
const state = @atomicRmw(i32, &self.state, .Xchg, WAITING, .Acquire);
if (state == UNLOCKED)
return;
wait_state = WAITING;
switch (system.getErrno(system.futex_wait(
&self.state,
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT,
WAITING,
null,
))) {
0 => {},
system.EINTR => {},
system.EAGAIN => {},
else => unreachable,
}
}
}
fn releaseSlow(self: *Mutex) void {
@setCold(true);
while (true) {
return switch (system.getErrno(system.futex_wake(
&self.state,
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE,
@as(i32, 1),
))) {
0 => {},
system.EINTR => continue,
system.EFAULT => {},
else => unreachable,
};
}
}
}
else
struct {
locked: bool = false,
pub fn acquire(self: *Mutex) Held {
while (@atomicRmw(bool, &self.locked, .Xchg, true, .Acquire))
yieldThread();
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
@atomicStore(bool, &self.mutex.locked, false, .Release);
}
};
};
const Event = if (std.builtin.os.tag == .windows)
struct {
key: u32 = undefined,
pub fn wait(self: *Event) void {
const status = NtWaitForKeyedEvent(null, &self.key, system.FALSE, null);
std.debug.assert(status == .SUCCESS);
}
pub fn notify(self: *Event) void {
const status = NtReleaseKeyedEvent(null, &self.key, system.FALSE, null);
std.debug.assert(status == .SUCCESS);
}
extern "NtDll" fn NtWaitForKeyedEvent(
handle: ?system.HANDLE,
key: ?*const u32,
alertable: system.BOOLEAN,
timeout: ?*const system.LARGE_INTEGER,
) callconv(system.WINAPI) system.NTSTATUS;
extern "NtDll" fn NtReleaseKeyedEvent(
handle: ?system.HANDLE,
key: ?*const u32,
alertable: system.BOOLEAN,
timeout: ?*const system.LARGE_INTEGER,
) callconv(system.WINAPI) system.NTSTATUS;
}
else if (comptime std.Target.current.isDarwin())
struct {
state: enum(u32) {
pending = 0,
notified,
} = .pending,
pub fn wait(self: *Event) void {
while (true) {
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) {
.pending => {},
.notified => {
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic);
return;
},
}
const status = __ulock_wait(
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO,
@ptrCast(?*const c_void, &self.state),
@enumToInt(@TypeOf(self.state).pending),
~@as(u32, 0),
);
if (status < 0) {
switch (-status) {
system.EINTR => {},
else => unreachable,
}
}
}
}
pub fn notify(self: *Event) void {
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release);
while (true) {
const status = __ulock_wake(
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO,
@ptrCast(?*const c_void, &self.state),
@as(u32, 0),
);
if (status < 0) {
switch (-status) {
system.ENOENT => {},
system.EINTR => continue,
else => unreachable,
}
}
return;
}
}
const ULF_NO_ERRNO = 0x1000000;
const UL_COMPARE_AND_WAIT = 0x1;
extern "c" fn __ulock_wait(
operation: u32,
address: ?*const c_void,
value: u64,
timeout_us: u32,
) callconv(.C) c_int;
extern "c" fn __ulock_wake(
operation: u32,
address: ?*const c_void,
value: u64,
) callconv(.C) c_int;
}
else if (std.builtin.os.tag == .linux)
struct {
state: enum(i32) {
pending,
notified,
} = .pending,
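// Auto-reset event on top of futex: wait() loops until it observes .notified,
// resets the state to .pending and returns; notify() stores .notified and wakes
// one futex waiter.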
pub fn wait(self: *Event) void {
while (true) {
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) {
.pending => {},
.notified => {
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic);
return;
},
}
switch (system.getErrno(system.futex_wait(
@ptrCast(*const i32, &self.state),
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT,
@enumToInt(@TypeOf(self.state).pending),
null,
))) {
0 => {},
system.EINTR => {},
system.EAGAIN => {},
else => unreachable,
}
}
}
pub fn notify(self: *Event) void {
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release);
while (true) {
return switch (system.getErrno(system.futex_wake(
@ptrCast(*const i32, &self.state),
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE,
@as(i32, 1),
))) {
0 => {},
system.EINTR => continue,
system.EFAULT => {},
else => unreachable,
};
}
}
}
else
struct {
notified: bool = false,
pub fn wait(self: *Event) void {
while (!@atomicLoad(bool, &self.notified, .Acquire))
yieldThread();
@atomicStore(bool, &self.notified, false, .Monotonic);
}
pub fn notify(self: *Event) void {
@atomicStore(bool, &self.notified, true, .Release);
}
};
const yieldThread = if (std.builtin.os.tag == .windows)
struct {
fn yield() void {
system.kernel32.Sleep(0);
}
}.yield
else if (comptime std.Target.current.isDarwin())
struct {
fn yield() void {
_ = thread_switch(MACH_PORT_NULL, SWITCH_OPTION_DEPRESS, 1);
}
const MACH_PORT_NULL = 0;
const SWITCH_OPTION_DEPRESS = 1;
// https://www.gnu.org/software/hurd/gnumach-doc/Hand_002dOff-Scheduling.html
extern "c" fn thread_switch(
thread: usize,
options: c_int,
timeout_ms: c_int,
) callconv(.C) c_int;
}.yield
else if (std.builtin.os.tag == .linux or std.builtin.link_libc)
struct {
fn yield() void {
_ = system.sched_yield();
}
}.yield
else
yieldCpu;
fn yieldCpu() void {
switch (std.builtin.arch) {
.i386, .x86_64 => asm volatile ("pause"),
.arm, .aarch64 => asm volatile ("yield"),
else => {},
}
}
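// A minimal usage sketch of the pool above (an illustration, not one of the gist's
// programs). It assumes the file above is the gist's ThreadPool.zig, saved next to
// this snippet, and it uses only the pub API shown there (init, schedule, Runnable,
// deinit) plus std.StaticResetEvent, which the other snippets in this gist also use.
const std = @import("std");
const ThreadPool = @import("./ThreadPool.zig");

var done = std.StaticResetEvent{};

fn onRun(runnable: *ThreadPool.Runnable) void {
    std.debug.warn("hello from a pool worker\n", .{});
    done.set();
}

pub fn main() void {
    var pool = ThreadPool.init(.{ .max_threads = 2 });
    defer pool.deinit();

    // The pool only stores this pointer (Runnable is intrusive), so the Runnable
    // must stay alive until it has run; done.wait() guarantees that here.
    var runnable = ThreadPool.Runnable{ .runFn = onRun };
    pool.schedule(.{}, &runnable);
    done.wait();
}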
const std = @import("std");
const system = switch (std.builtin.os.tag) {
.linux => std.os.linux,
else => std.os.system,
};
const ThreadPool = @This();
io_driver: IoDriver,
max_threads: u16,
counter: u32 = 0,
spawned_queue: ?*Worker = null,
run_queue: UnboundedQueue = .{},
shutdown_event: Event = .{},
pub const InitConfig = struct {
max_threads: ?u16 = null,
};
pub fn init(config: InitConfig) !ThreadPool {
return ThreadPool{
.io_driver = try IoDriver.init(),
.max_threads = std.math.min(
std.math.maxInt(u14),
std.math.max(1, config.max_threads orelse blk: {
break :blk @intCast(u16, std.Thread.cpuCount() catch 1);
}),
),
};
}
pub fn deinit(self: *ThreadPool) void {
defer self.* = undefined;
defer self.io_driver.deinit();
self.shutdown();
self.shutdown_event.wait();
while (self.spawned_queue) |worker| {
self.spawned_queue = worker.spawned_next;
const thread = worker.thread;
worker.shutdown_event.notify();
thread.wait();
}
}
pub const ScheduleHints = struct {
priority: Priority = .Normal,
pub const Priority = enum {
High,
Normal,
Low,
};
};
pub fn schedule(self: *ThreadPool, hints: ScheduleHints, batchable: anytype) void {
const batch = Batch.from(batchable);
if (batch.isEmpty())
return;
if (Worker.current) |worker| {
worker.push(hints, batch);
} else {
self.run_queue.push(batch);
}
_ = self.tryNotifyWith(false);
}
pub const SpawnConfig = struct {
allocator: *std.mem.Allocator,
hints: ScheduleHints = .{},
};
pub fn spawn(self: *ThreadPool, config: SpawnConfig, comptime func: anytype, args: anytype) !void {
const Args = @TypeOf(args);
const is_async = @typeInfo(@TypeOf(func)).Fn.calling_convention == .Async;
const Closure = struct {
func_args: Args,
allocator: *std.mem.Allocator,
runnable: Runnable = .{ .runFn = runFn },
frame: if (is_async) @Frame(runAsyncFn) else void = undefined,
fn runFn(runnable: *Runnable) void {
const closure = @fieldParentPtr(@This(), "runnable", runnable);
if (is_async) {
closure.frame = async closure.runAsyncFn();
} else {
const result = @call(.{}, func, closure.func_args);
closure.allocator.destroy(closure);
}
}
fn runAsyncFn(closure: *@This()) void {
const result = @call(.{}, func, closure.func_args);
suspend {
closure.allocator.destroy(closure);
}
}
};
const allocator = config.allocator;
const closure = try allocator.create(Closure);
errdefer allocator.destroy(closure);
closure.* = .{
.func_args = args,
.allocator = allocator,
};
const hints = config.hints;
self.schedule(hints, &closure.runnable);
}
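// Pool state packed into one u32 so it can be read and CAS'd atomically:
// bits 0-2 hold State, bit 3 the notified flag, bits 4-17 the idle count,
// and bits 18-31 the spawned count (hence the maxInt(u14) cap on max_threads).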
const Counter = struct {
state: State = .pending,
notified: bool = false,
idle: u16 = 0,
spawned: u16 = 0,
const State = enum(u3) {
pending = 0,
notified,
waking,
waker_notified,
shutdown,
};
fn pack(self: Counter) u32 {
return (@as(u32, @as(u3, @enumToInt(self.state))) |
(@as(u32, @boolToInt(self.notified)) << 3) |
(@as(u32, @intCast(u14, self.idle)) << 4) |
(@as(u32, @intCast(u14, self.spawned)) << (4 + 14)));
}
fn unpack(value: u32) Counter {
return Counter{
.state = @intToEnum(State, @truncate(u3, value)),
.notified = value & (1 << 3) != 0,
.idle = @as(u16, @truncate(u14, value >> 4)),
.spawned = @as(u16, @truncate(u14, value >> (4 + 14))),
};
}
};
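// Try to wake an idle worker or spawn a new one for freshly pushed work.
// Only one worker may be "waking" at a time; concurrent requests collapse
// into the notified/waker_notified states instead of waking extra threads.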
fn tryNotifyWith(self: *ThreadPool, is_caller_waking: bool) bool {
var spawned = false;
var remaining_attempts: u8 = 5;
var is_waking = is_caller_waking;
while (true) : (yieldCpu()) {
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
if (counter.state == .shutdown) {
if (spawned)
self.releaseWorker();
return false;
}
const has_pending = (counter.idle > 0) or (counter.spawned < self.max_threads);
const can_wake = (is_waking and remaining_attempts > 0) or (!is_waking and counter.state == .pending);
if (has_pending and can_wake) {
var new_counter = counter;
new_counter.state = .waking;
if (counter.idle > 0) {
new_counter.idle -= 1;
new_counter.notified = true;
} else if (!spawned) {
new_counter.spawned += 1;
}
if (@cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Acquire,
.Monotonic,
)) |failed| {
continue;
}
is_waking = true;
if (counter.idle > 0) {
self.idleNotify();
return true;
}
spawned = true;
if (Worker.spawn(self))
return true;
remaining_attempts -= 1;
continue;
}
var new_counter = counter;
if (is_waking) {
new_counter.state = if (can_wake) .pending else .notified;
if (spawned)
new_counter.spawned -= 1;
} else if (counter.state == .waking) {
new_counter.state = .waker_notified;
} else if (counter.state == .pending) {
new_counter.state = .notified;
} else {
return false;
}
_ = @cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Monotonic,
.Monotonic,
) orelse return true;
}
}
const Wait = enum {
resumed,
notified,
shutdown,
};
fn tryWaitWith(self: *ThreadPool, worker: *Worker) Wait {
var is_waking = worker.is_waking;
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
while (true) {
if (counter.state == .shutdown) {
self.releaseWorker();
return .shutdown;
}
const is_notified = switch (counter.state) {
.waker_notified => is_waking,
.notified => true,
else => false,
};
var new_counter = counter;
if (is_notified) {
new_counter.state = if (is_waking) .waking else .pending;
} else {
new_counter.idle += 1;
if (is_waking)
new_counter.state = .pending;
}
if (@cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Monotonic,
.Monotonic,
)) |updated| {
counter = Counter.unpack(updated);
continue;
}
if (is_notified and is_waking)
return .notified;
if (is_notified)
return .resumed;
self.idleWait(worker);
return .notified;
}
}
fn releaseWorker(self: *ThreadPool) void {
const counter_spawned = Counter{ .spawned = 1 };
const counter_value = @atomicRmw(u32, &self.counter, .Sub, counter_spawned.pack(), .AcqRel);
const counter = Counter.unpack(counter_value);
if (counter.state != .shutdown)
std.debug.panic("ThreadPool.releaseWorker() when not shutdown: {}", .{counter});
if (counter.spawned == 1)
self.shutdown_event.notify();
}
pub fn shutdown(self: *ThreadPool) void {
while (true) : (yieldCpu()) {
const counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
if (counter.state == .shutdown)
return;
var new_counter = counter;
new_counter.state = .shutdown;
new_counter.idle = 0;
if (@cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Acquire,
.Monotonic,
)) |failed| {
continue;
}
self.idleShutdown();
return;
}
}
fn idleWait(self: *ThreadPool, worker: *Worker) void {
var counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
while (true) {
if (counter.state == .shutdown) {
self.io_driver.notify();
return;
}
if (counter.notified) {
var new_counter = counter;
new_counter.notified = false;
counter = Counter.unpack(@cmpxchgWeak(
u32,
&self.counter,
counter.pack(),
new_counter.pack(),
.Acquire,
.Monotonic,
) orelse return);
continue;
}
const batch = self.io_driver.wait();
self.schedule(.{}, batch);
counter = Counter.unpack(@atomicLoad(u32, &self.counter, .Monotonic));
}
}
fn idleNotify(self: *ThreadPool) void {
self.io_driver.notify();
}
fn idleShutdown(self: *ThreadPool) void {
self.io_driver.notify();
}
pub const IoRunnable = struct {
fd: std.os.fd_t = -1,
is_closable: bool = false,
is_readable: bool = false,
is_writable: bool = false,
runnable: Runnable,
};
pub fn waitFor(self: *ThreadPool, fd: std.os.fd_t, io_runnable: *IoRunnable) !void {
return self.io_driver.register(fd, io_runnable);
}
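// Minimal epoll driver: idle workers park inside wait() on epoll_wait, and
// notify() re-arms a one-shot EPOLLOUT registration on an eventfd (which is
// always writable) to wake one of them without a write()/read() round trip.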
const IoDriver = struct {
poll_fd: std.os.fd_t,
notify_fd: std.os.fd_t,
fn init() !IoDriver {
const poll_fd = try std.os.epoll_create1(std.os.EPOLL_CLOEXEC);
errdefer std.os.close(poll_fd);
const notify_fd = try std.os.eventfd(0, std.os.EFD_CLOEXEC | std.os.EFD_NONBLOCK);
errdefer std.os.close(notify_fd);
var event = std.os.epoll_event{
.events = std.os.EPOLLONESHOT,
.data = .{ .ptr = 0 },
};
try std.os.epoll_ctl(poll_fd, std.os.EPOLL_CTL_ADD, notify_fd, &event);
return IoDriver{
.poll_fd = poll_fd,
.notify_fd = notify_fd,
};
}
fn deinit(self: *IoDriver) void {
std.os.close(self.poll_fd);
std.os.close(self.notify_fd);
}
fn notify(self: *IoDriver) void {
const fd = self.notify_fd;
var event = std.os.epoll_event{
.events = std.os.EPOLLOUT | std.os.EPOLLONESHOT,
.data = .{ .ptr = 0 },
};
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_MOD, fd, &event) catch |err| switch (err) {
error.FileDescriptorNotRegistered => {
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_ADD, fd, &event) catch {};
},
else => {},
};
}
fn register(self: *IoDriver, fd: std.os.fd_t, io_runnable: *IoRunnable) !void {
var events: u32 = std.os.EPOLLONESHOT | std.os.EPOLLRDHUP | std.os.EPOLLHUP;
if (io_runnable.is_readable)
events |= std.os.EPOLLIN;
if (io_runnable.is_writable)
events |= std.os.EPOLLOUT;
io_runnable.fd = fd;
var event = std.os.epoll_event{
.events = events,
.data = .{ .ptr = @ptrToInt(io_runnable) },
};
std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_MOD, fd, &event) catch |err| switch (err) {
error.FileDescriptorNotRegistered => {
try std.os.epoll_ctl(self.poll_fd, std.os.EPOLL_CTL_ADD, fd, &event);
},
else => |e| return e,
};
}
fn wait(self: *IoDriver) Batch {
var batch = Batch{};
var events: [128]std.os.epoll_event = undefined;
const count = std.os.epoll_wait(self.poll_fd, &events, -1);
for (events[0..count]) |event| {
const io_runnable = @intToPtr(?*IoRunnable, event.data.ptr) orelse continue;
io_runnable.is_closable = (event.events & (std.os.EPOLLERR | std.os.EPOLLHUP | std.os.EPOLLRDHUP) != 0);
io_runnable.is_writable = (event.events & (std.os.EPOLLOUT) != 0);
io_runnable.is_readable = (event.events & (std.os.EPOLLIN) != 0);
batch.push(&io_runnable.runnable);
}
return batch;
}
};
const Worker = struct {
pool: *ThreadPool,
thread: *std.Thread,
spawned_next: ?*Worker = null,
shutdown_event: Event = .{},
run_queue: BoundedQueue = .{},
run_queue_next: ?*Runnable = null,
run_queue_lifo: ?*Runnable = null,
run_queue_overflow: UnboundedQueue = .{},
tick: usize = undefined,
is_waking: bool = true,
next_target: ?*Worker = null,
threadlocal var current: ?*Worker = null;
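// The std.Thread handle is handed to the new worker through two Events:
// the spawner publishes it (data_put_event) and the worker acknowledges
// reading it (data_get_event) before the Spawner stack frame goes away.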
fn spawn(pool: *ThreadPool) bool {
const Spawner = struct {
thread: *std.Thread = undefined,
thread_pool: *ThreadPool,
data_put_event: Event = .{},
data_get_event: Event = .{},
fn entry(self: *@This()) void {
self.data_put_event.wait();
const thread = self.thread;
const thread_pool = self.thread_pool;
self.data_get_event.notify();
Worker.run(thread, thread_pool);
}
};
var spawner = Spawner{ .thread_pool = pool };
spawner.thread = std.Thread.spawn(&spawner, Spawner.entry) catch return false;
spawner.data_put_event.notify();
spawner.data_get_event.wait();
return true;
}
fn run(thread: *std.Thread, pool: *ThreadPool) void {
var self = Worker{
.thread = thread,
.pool = pool,
};
self.tick = @ptrToInt(&self);
current = &self;
defer current = null;
var spawned_queue = @atomicLoad(?*Worker, &pool.spawned_queue, .Monotonic);
while (true) {
self.spawned_next = spawned_queue;
spawned_queue = @cmpxchgWeak(
?*Worker,
&pool.spawned_queue,
spawned_queue,
&self,
.Release,
.Monotonic,
) orelse break;
}
while (true) {
if (self.pop()) |runnable| {
if (self.is_waking) {
self.is_waking = false;
_ = pool.tryNotifyWith(true);
}
self.tick +%= 1;
runnable.run();
continue;
}
self.is_waking = switch (pool.tryWaitWith(&self)) {
.resumed => false,
.notified => true,
.shutdown => {
self.shutdown_event.wait();
break;
},
};
}
}
fn push(self: *Worker, hints: ScheduleHints, batchable: anytype) void {
var batch = Batch.from(batchable);
if (batch.isEmpty())
return;
if (hints.priority == .High) {
const new_lifo = batch.pop();
if (@atomicLoad(?*Runnable, &self.run_queue_lifo, .Monotonic) == null) {
@atomicStore(?*Runnable, &self.run_queue_lifo, new_lifo, .Release);
} else if (@atomicRmw(?*Runnable, &self.run_queue_lifo, .Xchg, new_lifo, .AcqRel)) |old_lifo| {
batch.pushFront(old_lifo);
}
}
if (hints.priority == .Low) {
if (self.run_queue_next) |old_next|
batch.pushFront(old_next);
self.run_queue_next = null;
self.run_queue_next = self.pop() orelse batch.pop();
}
if (self.run_queue.push(batch)) |overflowed|
self.run_queue_overflow.push(overflowed);
}
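// Pop in rough priority order, but use the worker's tick to periodically
// check the other workers, the pool queue, the overflow queue, and the
// LIFO slot first so no single source of work starves the others.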
fn pop(self: *Worker) ?*Runnable {
if (self.tick % 127 == 0) {
if (self.popAndStealFromOthers()) |runnable|
return runnable;
}
if (self.tick % 61 == 0) {
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable|
return runnable;
}
if (self.tick % 31 == 0) {
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable|
return runnable;
}
if (self.tick % 13 == 0) {
if (self.popAndStealLifo(self)) |runnable|
return runnable;
}
if (self.run_queue.pop()) |runnable|
return runnable;
if (self.popAndStealLifo(self)) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&self.run_queue_overflow)) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable|
return runnable;
if (self.popAndStealFromOthers()) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&self.pool.run_queue)) |runnable|
return runnable;
return null;
}
fn popAndStealLifo(self: *Worker, target: *Worker) ?*Runnable {
var run_queue_lifo = @atomicLoad(?*Runnable, &target.run_queue_lifo, .Monotonic);
while (true) {
if (run_queue_lifo == null)
return null;
run_queue_lifo = @cmpxchgWeak(
?*Runnable,
&target.run_queue_lifo,
run_queue_lifo,
null,
.Acquire,
.Monotonic,
) orelse return run_queue_lifo;
}
}
fn popAndStealFromOthers(self: *Worker) ?*Runnable {
var num_workers = blk: {
const counter_value = @atomicLoad(u32, &self.pool.counter, .Monotonic);
const counter = Counter.unpack(counter_value);
break :blk counter.spawned;
};
while (num_workers > 0) : (num_workers -= 1) {
const target = self.next_target orelse blk: {
break :blk @atomicLoad(?*Worker, &self.pool.spawned_queue, .Acquire) orelse {
std.debug.panic("Worker observed empty spawned queue when work-stealing", .{});
};
};
self.next_target = target.spawned_next;
if (target == self)
continue;
if (self.run_queue.popAndStealBounded(&target.run_queue)) |runnable|
return runnable;
if (self.run_queue.popAndStealUnbounded(&target.run_queue_overflow)) |runnable|
return runnable;
if (self.popAndStealLifo(target)) |runnable|
return runnable;
}
return null;
}
};
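// Mutex-protected injection/overflow queue. The size is mirrored into an
// atomic (shared_size) so would-be consumers can bail out without taking
// the lock when the queue is empty.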
const UnboundedQueue = struct {
lock: Mutex = .{},
batch: Batch = .{},
shared_size: usize = 0,
fn push(self: *UnboundedQueue, batchable: anytype) void {
const batch = Batch.from(batchable);
if (batch.isEmpty())
return;
const held = self.lock.acquire();
defer held.release();
self.batch.push(batch);
var shared_size = self.shared_size;
shared_size += batch.size;
@atomicStore(usize, &self.shared_size, shared_size, .Release);
}
fn tryAcquireConsumer(self: *UnboundedQueue) ?Consumer {
var shared_size = @atomicLoad(usize, &self.shared_size, .Acquire);
if (shared_size == 0)
return null;
const held = self.lock.acquire();
shared_size = self.shared_size;
if (shared_size == 0) {
held.release();
return null;
}
return Consumer{
.held = held,
.queue = self,
.size = shared_size,
};
}
const Consumer = struct {
held: Mutex.Held,
queue: *UnboundedQueue,
size: usize,
fn release(self: Consumer) void {
@atomicStore(usize, &self.queue.shared_size, self.size, .Release);
self.held.release();
}
fn pop(self: *Consumer) ?*Runnable {
const runnable = self.queue.batch.pop() orelse return null;
self.size -= 1;
return runnable;
}
};
};
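// Per-worker ring buffer of 256 tasks. The owning worker pushes at the tail;
// pops and steals consume from the head via cmpxchg, and slots are touched
// with unordered atomic loads/stores. On overflow, half the buffer is moved
// out and returned so the caller can push it onto the UnboundedQueue.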
const BoundedQueue = struct {
head: usize = 0,
tail: usize = 0,
buffer: [256]*Runnable = undefined,
fn push(self: *BoundedQueue, batchable: anytype) ?Batch {
var batch = Batch.from(batchable);
while (true) : (yieldCpu()) {
if (batch.isEmpty())
return null;
var tail = self.tail;
var head = @atomicLoad(usize, &self.head, .Acquire);
var size = tail -% head;
if (size < self.buffer.len) {
while (size < self.buffer.len) {
const runnable = batch.pop() orelse break;
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered);
tail +%= 1;
size += 1;
}
@atomicStore(usize, &self.tail, tail, .Release);
continue;
}
var migrate = self.buffer.len / 2;
if (@cmpxchgWeak(
usize,
&self.head,
head,
head +% migrate,
.AcqRel,
.Acquire,
)) |failed| {
continue;
}
var overflowed = Batch{};
while (migrate > 0) : (migrate -= 1) {
const runnable = self.buffer[head % self.buffer.len];
overflowed.push(runnable);
head +%= 1;
}
overflowed.push(batch);
return overflowed;
}
}
fn pop(self: *BoundedQueue) ?*Runnable {
while (true) : (yieldCpu()) {
const tail = self.tail;
const head = @atomicLoad(usize, &self.head, .Acquire);
const size = tail -% head;
if (size == 0)
return null;
if (@cmpxchgWeak(
usize,
&self.head,
head,
head +% 1,
.AcqRel,
.Acquire,
)) |failed| {
continue;
}
const runnable = self.buffer[head % self.buffer.len];
return runnable;
}
}
fn popAndStealBounded(self: *BoundedQueue, target: *BoundedQueue) ?*Runnable {
if (target == self)
return self.pop();
const tail = self.tail;
const head = @atomicLoad(usize, &self.head, .Acquire);
const size = tail -% head;
if (size != 0)
return self.pop();
while (true) : (yieldThread()) {
const target_head = @atomicLoad(usize, &target.head, .Acquire);
const target_tail = @atomicLoad(usize, &target.tail, .Acquire);
const target_size = target_tail -% target_head;
var steal = target_size - (target_size / 2);
if (steal == 0)
return null;
if (steal > target.buffer.len / 2)
continue;
const first_runnable = @atomicLoad(*Runnable, &target.buffer[target_head % target.buffer.len], .Unordered);
var new_target_head = target_head +% 1;
var new_tail = tail;
steal -= 1;
while (steal > 0) : (steal -= 1) {
const runnable = @atomicLoad(*Runnable, &target.buffer[new_target_head % target.buffer.len], .Unordered);
new_target_head +%= 1;
@atomicStore(*Runnable, &self.buffer[new_tail % self.buffer.len], runnable, .Unordered);
new_tail +%= 1;
}
if (@cmpxchgWeak(
usize,
&target.head,
target_head,
new_target_head,
.AcqRel,
.Acquire,
)) |failed| {
continue;
}
@atomicStore(usize, &self.tail, new_tail, .Release);
return first_runnable;
}
}
fn popAndStealUnbounded(self: *BoundedQueue, target: *UnboundedQueue) ?*Runnable {
var consumer = target.tryAcquireConsumer() orelse return null;
defer consumer.release();
const first_runnable = consumer.pop() orelse return null;
var tail = self.tail;
var head = @atomicLoad(usize, &self.head, .Acquire);
var size = tail -% head;
while (size < self.buffer.len) {
const runnable = consumer.pop() orelse break;
@atomicStore(*Runnable, &self.buffer[tail % self.buffer.len], runnable, .Unordered);
tail +%= 1;
size += 1;
}
@atomicStore(usize, &self.tail, tail, .Release);
return first_runnable;
}
};
pub const Runnable = struct {
next: ?*Runnable = null,
runFn: fn (*Runnable) void,
pub fn run(self: *Runnable) void {
return (self.runFn)(self);
}
};
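// Intrusive singly-linked list of Runnables with O(1) append/prepend,
// used to move groups of tasks between queues without allocating.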
pub const Batch = struct {
head: ?*Runnable = null,
tail: *Runnable = undefined,
size: usize = 0,
pub fn from(batchable: anytype) Batch {
return switch (@TypeOf(batchable)) {
Batch => batchable,
?*Runnable => from(batchable orelse return Batch{}),
*Runnable => {
batchable.next = null;
return Batch{
.head = batchable,
.tail = batchable,
.size = 1,
};
},
else => |typ| @compileError(@typeName(typ) ++
" cannot be converted into " ++
@typeName(Batch)),
};
}
pub fn isEmpty(self: Batch) bool {
return self.head == null;
}
pub const push = pushBack;
pub fn pushBack(self: *Batch, batchable: anytype) void {
const batch = from(batchable);
if (batch.isEmpty())
return;
if (self.isEmpty()) {
self.* = batch;
} else {
self.tail.next = batch.head;
self.tail = batch.tail;
self.size += batch.size;
}
}
pub fn pushFront(self: *Batch, batchable: anytype) void {
const batch = from(batchable);
if (batch.isEmpty())
return;
if (self.isEmpty()) {
self.* = batch;
} else {
batch.tail.next = self.head;
self.head = batch.head;
self.size += batch.size;
}
}
pub const pop = popFront;
pub fn popFront(self: *Batch) ?*Runnable {
const runnable = self.head orelse return null;
self.head = runnable.next;
self.size -= 1;
return runnable;
}
};
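// Counting semaphore built on the Mutex/Event primitives below: waiters
// queue in FIFO order and any that post() can satisfy are notified after
// the lock is released.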
const Semaphore = struct {
lock: Mutex = .{},
permits: usize = 0,
waiters: ?*Waiter = null,
const Waiter = struct {
next: ?*Waiter = null,
tail: *Waiter = undefined,
event: Event = .{},
permits: usize,
};
fn init(permits: usize) Semaphore {
return .{ .permits = permits };
}
fn wait(self: *Semaphore, permits: usize) void {
const held = self.lock.acquire();
if (self.permits >= permits) {
self.permits -= permits;
held.release();
return;
}
var waiter = Waiter{ .permits = permits };
if (self.waiters) |head| {
head.tail.next = &waiter;
head.tail = &waiter;
} else {
self.waiters = &waiter;
waiter.tail = &waiter;
}
held.release();
waiter.event.wait();
}
fn post(self: *Semaphore, permits: usize) error{Overflow}!void {
var waiters: ?*Waiter = null;
{
const held = self.lock.acquire();
defer held.release();
if (@addWithOverflow(usize, self.permits, permits, &self.permits))
return error.Overflow;
while (self.waiters) |waiter| {
if (waiter.permits > self.permits)
break;
self.waiters = waiter.next;
if (self.waiters) |new_waiter|
new_waiter.tail = waiter.tail;
self.permits -= waiter.permits;
waiter.next = waiters;
waiters = waiter;
}
}
while (waiters) |waiter| {
waiters = waiter.next;
waiter.event.notify();
}
}
};
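// Platform-specific mutex: SRW locks on Windows, os_unfair_lock on Darwin,
// a futex-based lock on Linux, and a yielding spinlock everywhere else.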
const Mutex = if (std.builtin.os.tag == .windows)
struct {
srwlock: ?system.PVOID = null,
pub fn acquire(self: *Mutex) Held {
AcquireSRWLockExclusive(&self.srwlock);
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
ReleaseSRWLockExclusive(&self.mutex.srwlock);
}
};
extern "kernel32" fn AcquireSRWLockExclusive(
srwlock: *?system.PVOID,
) callconv(system.WINAPI) void;
extern "kernel32" fn ReleaseSRWLockExclusive(
srwlock: *?system.PVOID,
) callconv(system.WINAPI) void;
}
else if (comptime std.Target.current.isDarwin())
struct {
lock: u32 = 0,
pub fn acquire(self: *Mutex) Held {
os_unfair_lock_lock(&self.lock);
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
os_unfair_lock_unlock(&self.mutex.lock);
}
};
extern "c" fn os_unfair_lock_lock(
unfair_lock: *u32,
) callconv(.C) void;
extern "c" fn os_unfair_lock_unlock(
unfair_lock: *u32,
) callconv(.C) void;
}
else if (std.builtin.os.tag == .linux)
struct {
state: i32 = UNLOCKED,
const UNLOCKED: i32 = 0;
const LOCKED: i32 = 1;
const WAITING: i32 = 2;
pub fn acquire(self: *Mutex) Held {
const state = @atomicRmw(i32, &self.state, .Xchg, LOCKED, .Acquire);
if (state != UNLOCKED)
self.acquireSlow(state);
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
switch (@atomicRmw(i32, &self.mutex.state, .Xchg, UNLOCKED, .Release)) {
UNLOCKED => unreachable, // unlocked an unlocked mutex
LOCKED => {},
WAITING => self.mutex.releaseSlow(),
else => unreachable,
}
}
};
fn acquireSlow(self: *Mutex, current_state: i32) void {
@setCold(true);
var wait_state = current_state;
while (true) {
var spin: u8 = 0;
while (spin < 5) : (spin += 1) {
switch (@atomicLoad(i32, &self.state, .Monotonic)) {
UNLOCKED => _ = @cmpxchgWeak(
i32,
&self.state,
UNLOCKED,
wait_state,
.Acquire,
.Monotonic,
) orelse return,
LOCKED => {},
WAITING => break,
else => unreachable,
}
if (spin < 4) {
var pause: u8 = 30;
while (pause > 0) : (pause -= 1)
yieldCpu();
} else {
yieldThread();
}
}
const state = @atomicRmw(i32, &self.state, .Xchg, WAITING, .Acquire);
if (state == UNLOCKED)
return;
wait_state = WAITING;
switch (system.getErrno(system.futex_wait(
&self.state,
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT,
WAITING,
null,
))) {
0 => {},
system.EINTR => {},
system.EAGAIN => {},
else => unreachable,
}
}
}
fn releaseSlow(self: *Mutex) void {
@setCold(true);
while (true) {
return switch (system.getErrno(system.futex_wake(
&self.state,
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE,
@as(i32, 1),
))) {
0 => {},
system.EINTR => continue,
system.EFAULT => {},
else => unreachable,
};
}
}
}
else
struct {
locked: bool = false,
pub fn acquire(self: *Mutex) Held {
while (@atomicRmw(bool, &self.locked, .Xchg, true, .Acquire))
yieldThread();
return Held{ .mutex = self };
}
pub const Held = struct {
mutex: *Mutex,
pub fn release(self: Held) void {
@atomicStore(bool, &self.mutex.locked, false, .Release);
}
};
};
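// One-shot, auto-resetting event used for handshakes and shutdown: keyed
// events on Windows, __ulock on Darwin, a futex on Linux, and a yield loop
// as the portable fallback.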
const Event = if (std.builtin.os.tag == .windows)
struct {
key: u32 = undefined,
pub fn wait(self: *Event) void {
const status = NtWaitForKeyedEvent(null, &self.key, system.FALSE, null);
std.debug.assert(status == .SUCCESS);
}
pub fn notify(self: *Event) void {
const status = NtReleaseKeyedEvent(null, &self.key, system.FALSE, null);
std.debug.assert(status == .SUCCESS);
}
extern "NtDll" fn NtWaitForKeyedEvent(
handle: ?system.HANDLE,
key: ?*const u32,
alertable: system.BOOLEAN,
timeout: ?*const system.LARGE_INTEGER,
) callconv(system.WINAPI) system.NTSTATUS;
extern "NtDll" fn NtReleaseKeyedEvent(
handle: ?system.HANDLE,
key: ?*const u32,
alertable: system.BOOLEAN,
timeout: ?*const system.LARGE_INTEGER,
) callconv(system.WINAPI) system.NTSTATUS;
}
else if (comptime std.Target.current.isDarwin())
struct {
state: enum(u32) {
pending = 0,
notified,
} = .pending,
pub fn wait(self: *Event) void {
while (true) {
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) {
.pending => {},
.notified => {
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic);
return;
},
}
const status = __ulock_wait(
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO,
@ptrCast(?*const c_void, &self.state),
@enumToInt(@TypeOf(self.state).pending),
~@as(u32, 0),
);
if (status < 0) {
switch (-status) {
system.EINTR => {},
else => unreachable,
}
}
}
}
pub fn notify(self: *Event) void {
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release);
while (true) {
const status = __ulock_wake(
UL_COMPARE_AND_WAIT | ULF_NO_ERRNO,
@ptrCast(?*const c_void, &self.state),
@as(u32, 0),
);
if (status < 0) {
switch (-status) {
system.ENOENT => {},
system.EINTR => continue,
else => unreachable,
}
}
return;
}
}
const ULF_NO_ERRNO = 0x1000000;
const UL_COMPARE_AND_WAIT = 0x1;
extern "c" fn __ulock_wait(
operation: u32,
address: ?*const c_void,
value: u64,
timeout_us: u32,
) callconv(.C) c_int;
extern "c" fn __ulock_wake(
operation: u32,
address: ?*const c_void,
value: u64,
) callconv(.C) c_int;
}
else if (std.builtin.os.tag == .linux)
struct {
state: enum(i32) {
pending,
notified,
} = .pending,
pub fn wait(self: *Event) void {
while (true) {
switch (@atomicLoad(@TypeOf(self.state), &self.state, .Acquire)) {
.pending => {},
.notified => {
@atomicStore(@TypeOf(self.state), &self.state, .pending, .Monotonic);
return;
},
}
switch (system.getErrno(system.futex_wait(
@ptrCast(*const i32, &self.state),
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAIT,
@enumToInt(@TypeOf(self.state).pending),
null,
))) {
0 => {},
system.EINTR => {},
system.EAGAIN => {},
else => unreachable,
}
}
}
pub fn notify(self: *Event) void {
@atomicStore(@TypeOf(self.state), &self.state, .notified, .Release);
while (true) {
return switch (system.getErrno(system.futex_wake(
@ptrCast(*const i32, &self.state),
system.FUTEX_PRIVATE_FLAG | system.FUTEX_WAKE,
@as(i32, 1),
))) {
0 => {},
system.EINTR => continue,
system.EFAULT => {},
else => unreachable,
};
}
}
}
else
struct {
notified: bool = false,
pub fn wait(self: *Event) void {
while (!@atomicLoad(bool, &self.notified, .Acquire))
yieldThread();
@atomicStore(bool, &self.notified, false, .Monotonic);
}
pub fn notify(self: *Event) void {
@atomicStore(bool, &self.notified, true, .Release);
}
};
const yieldThread = if (std.builtin.os.tag == .windows)
struct {
fn yield() void {
system.kernel32.Sleep(0);
}
}.yield
else if (comptime std.Target.current.isDarwin())
struct {
fn yield() void {
_ = thread_switch(MACH_PORT_NULL, SWITCH_OPTION_DEPRESS, 1);
}
const MACH_PORT_NULL = 0;
const SWITCH_OPTION_DEPRESS = 1;
// https://www.gnu.org/software/hurd/gnumach-doc/Hand_002dOff-Scheduling.html
extern "c" fn thread_switch(
thread: usize,
options: c_int,
timeout_ms: c_int,
) callconv(.C) c_int;
}.yield
else if (std.builtin.os.tag == .linux or std.builtin.link_libc)
struct {
fn yield() void {
_ = system.sched_yield();
}
}.yield
else
yieldCpu;
fn yieldCpu() void {
switch (std.builtin.arch) {
.i386, .x86_64 => asm volatile ("pause"),
.arm, .aarch64 => asm volatile ("yield"),
else => {},
}
}
const std = @import("std");
const linux = std.os.linux;
pub fn main() !void {
var ring: Ring = undefined;
try ring.init();
defer ring.deinit();
var server: Server = undefined;
try server.init(12345);
defer server.deinit();
var frame = async server.run(&ring);
while (true) {
try ring.poll();
}
}
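// io_uring server built on Zig async frames: each operation queues an SQE,
// suspends, and is resumed from Ring.poll() once its CQE arrives, with the
// raw result (negative errno on failure) stored in Completion.result.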
const Ring = struct {
inner: linux.IO_Uring,
queue: std.TailQueue(void),
fn init(self: *Ring) !void {
self.inner = try linux.IO_Uring.init(512, 0);
self.queue = .{};
}
fn deinit(self: *Ring) void {
self.inner.deinit();
}
const Completion = struct {
onComplete: anyframe,
ring_queue: std.TailQueue(void).Node = .{ .data = {} },
result: i32 = undefined,
};
fn flushCompletions(self: *Ring) void {
var chunk: [256]linux.io_uring_cqe = undefined;
while (true) {
const found = self.inner.copy_cqes(&chunk, 0) catch unreachable;
if (found == 0)
break;
for (chunk[0..found]) |cqe| {
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data));
completion.result = cqe.res;
self.queue.append(&completion.ring_queue);
}
}
}
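// If the submission queue is full, park the current frame on the completion
// queue; poll() resumes it on its next pass so get_sqe() can be retried.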
fn getSubmission(self: *Ring) *linux.io_uring_sqe {
while (true) {
return self.inner.get_sqe() catch {
var completion = Ring.Completion{
.onComplete = @frame(),
};
self.queue.append(&completion.ring_queue);
suspend;
continue;
};
}
}
fn poll(self: *Ring) !void {
while (self.queue.popFirst()) |node| {
const completion = @fieldParentPtr(Completion, "ring_queue", node);
resume completion.onComplete;
}
_ = try self.inner.submit_and_wait(1);
self.flushCompletions();
}
};
const Server = struct {
fd: std.os.socket_t,
gpa: std.heap.GeneralPurposeAllocator(.{}),
clients: std.TailQueue(void),
pub fn init(self: *Server, comptime port: u16) !void {
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
errdefer std.os.close(self.fd);
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen());
try std.os.listen(self.fd, 128);
self.gpa = .{};
self.clients = .{};
std.debug.warn("Listening on :{}\n", .{port});
}
pub fn deinit(self: *Server) void {
// TODO: cancel outstanding accept call?
// TODO: kill current clients?
std.os.close(self.fd);
std.debug.warn("Server closed\n", .{});
}
pub fn run(self: *Server, ring: *Ring) !void {
errdefer self.deinit();
while (true) {
const result = self.accept(ring);
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => continue,
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const client_fd: std.os.socket_t = result;
const client = Client.init(self, client_fd) catch |err| {
std.os.close(client_fd);
std.debug.warn("Failed to start client: {}\n", .{client_fd});
continue;
};
client.frame = async client.run(ring);
self.clients.append(&client.server_clients);
}
}
fn accept(self: *Server, ring: *Ring) i32 {
const sqe = ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .ACCEPT;
sqe.fd = self.fd;
sqe.rw_flags = std.os.SOCK_CLOEXEC;
var completion = Ring.Completion{
.onComplete = @frame(),
};
sqe.user_data = @ptrToInt(&completion);
suspend;
return completion.result;
}
};
const Client = struct {
server: *Server,
server_clients: std.TailQueue(void).Node,
fd: std.os.socket_t,
reader: Reader,
writer: Writer,
is_closed: bool,
frame: @Frame(run),
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
pub fn init(server: *Server, fd: std.os.socket_t) !*Client {
const allocator = &server.gpa.allocator;
const self = try allocator.create(Client);
errdefer allocator.destroy(self);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
self.* = .{
.server = server,
.server_clients = .{ .data = {} },
.fd = fd,
.reader = .{ .fd = fd },
.writer = .{ .fd = fd },
.is_closed = false,
.frame = undefined,
};
return self;
}
pub fn deinit(self: *Client) void {
std.os.close(self.fd);
self.server.clients.remove(&self.server_clients);
const allocator = &self.server.gpa.allocator;
suspend {
allocator.destroy(self);
}
}
pub fn run(self: *Client, ring: *Ring) Reader.RunError!void {
Reader.run(self, ring) catch |err| switch (err) {
error.ConnectionResetByPeer => {},
else => return err,
};
}
const Reader = struct {
state: enum { read, stop } = .stop,
buffer: Buffer = Buffer.init(),
fd: i32,
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 });
const ReaderError = std.os.RecvFromError || error{
closed,
Eof,
HttpRequestTooLarge,
};
const RunError = ReaderError;
fn run(client: *Client, ring: *Ring) RunError!void {
const self = &client.reader;
// const client = @fieldParentPtr(Client, "reader", self);
errdefer {
self.state = .stop;
client.is_closed = true;
if (client.writer.state == .idle)
client.writer.state = .stop;
if (client.writer.state == .stop)
client.deinit();
}
while (true) {
const result = self.read(ring);
if (client.is_closed)
return error.closed;
switch (if (result < 0) @intCast(u12, -result) else @as(u12, 0)) {
0 => {},
std.os.EINTR => continue,
std.os.ENOMEM => return error.SystemResources,
std.os.ECONNRESET => return error.ConnectionResetByPeer,
else => |err| return std.os.unexpectedErrno(err),
}
const bytes = @intCast(usize, result);
if (bytes == 0)
return error.Eof;
self.buffer.update(bytes);
while (true) {
self.buffer.realign();
if (std.mem.indexOf(u8, self.buffer.readableSlice(0), HTTP_CLRF)) |parsed| {
self.buffer.discard(bytes);
try Writer.write(client, ring, HTTP_RESPONSE);
continue;
}
if (self.buffer.writableLength() == 0)
return error.HttpRequestTooLarge;
break;
}
}
}
fn read(self: *Reader, ring: *Ring) i32 {
const sqe = ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .RECV;
sqe.fd = self.fd;
// sqe.fd = @fieldParentPtr(Client, "reader", self).fd;
const slice = self.buffer.writableSlice(0);
sqe.addr = @ptrToInt(slice.ptr);
sqe.len = @truncate(u32, slice.len);
self.state = .read;
var completion = Ring.Completion{
.onComplete = @frame(),
};
sqe.user_data = @ptrToInt(&completion);
suspend;
return completion.result;
}
};
const Writer = struct {
state: enum { write, idle, stop } = .idle,
buffer: Buffer = Buffer.init(),
fd: i32,
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 });
fn flush(self: *Writer, ring: *Ring) !void {
switch (self.state) {
.idle => {
self.state = .write;
while (true) {
const result = self.rawWrite(ring);
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => continue,
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const bytes_written = @intCast(usize, result);
self.buffer.discard(bytes_written);
break;
}
self.state = .idle;
},
.write => {
@panic("TODO: wait for existing write to finish");
},
.stop => return error.closed,
}
}
fn write(client: *Client, ring: *Ring, src: []const u8) !void {
const self = &client.writer;
// const client = @fieldParentPtr(Client, "writer", self);
errdefer {
self.state = .stop;
client.is_closed = true;
if (client.reader.state == .stop)
client.deinit();
}
var src_left = src;
while (src_left.len > 0) {
const writable_slice = self.buffer.writableSlice(0);
if (writable_slice.len == 0) {
try self.flush(ring);
continue;
}
const n = std.math.min(writable_slice.len, src_left.len);
std.mem.copy(u8, writable_slice, src_left[0..n]);
self.buffer.update(n);
src_left = src_left[n..];
}
while (self.buffer.readableLength() > 0) {
try self.flush(ring);
}
}
fn rawWrite(self: *Writer, ring: *Ring) i32 {
const sqe = ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .SEND;
sqe.fd = self.fd;
// sqe.fd = @fieldParentPtr(Client, "writer", self).fd;
const slice = self.buffer.readableSlice(0);
sqe.addr = @ptrToInt(slice.ptr);
sqe.len = @truncate(u32, slice.len);
var completion = Ring.Completion{
.onComplete = @frame(),
};
sqe.user_data = @ptrToInt(&completion);
suspend;
return completion.result;
}
};
};
const std = @import("std");
const linux = std.os.linux;
pub fn main() !void {
var ring: Ring = undefined;
try ring.init();
defer ring.deinit();
var server: Server = undefined;
try server.init(&ring, 12345);
defer server.deinit();
while (true) {
try ring.poll();
}
}
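// Callback-based io_uring server: rather than async frames, every pending
// operation owns a Completion carrying an onComplete function pointer, and
// poll() drains an intrusive FIFO of finished completions.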
const Ring = struct {
inner: linux.IO_Uring,
head: ?*Completion,
tail: ?*Completion,
fn init(self: *Ring) !void {
self.inner = try linux.IO_Uring.init(512, 0);
self.head = null;
self.tail = null;
}
fn deinit(self: *Ring) void {
self.inner.deinit();
}
const Completion = struct {
result: i32 = undefined,
next: ?*Completion = null,
onComplete: fn (*Ring, *Completion) void,
};
fn flushCompletions(self: *Ring) !void {
var chunk: [256]linux.io_uring_cqe = undefined;
while (true) {
const found = try self.inner.copy_cqes(&chunk, 0);
if (found == 0)
break;
for (chunk[0..found]) |cqe| {
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data));
completion.result = cqe.res;
if (self.head == null)
self.head = completion;
if (self.tail) |tail|
tail.next = completion;
completion.next = null;
self.tail = completion;
}
}
}
fn getSubmission(self: *Ring) !*linux.io_uring_sqe {
while (true) {
return self.inner.get_sqe() catch {
try self.flushCompletions();
_ = try self.inner.submit();
continue;
};
}
}
fn poll(self: *Ring) !void {
while (self.head) |completion| {
self.head = completion.next;
if (self.head == null)
self.tail = null;
(completion.onComplete)(self, completion);
}
_ = try self.inner.submit_and_wait(1);
try self.flushCompletions();
}
};
const Server = struct {
fd: std.os.socket_t,
completion: Ring.Completion,
gpa: std.heap.GeneralPurposeAllocator(.{}),
fn init(self: *Server, ring: *Ring, comptime port: u16) !void {
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
errdefer std.os.close(self.fd);
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen());
try std.os.listen(self.fd, 128);
self.gpa = .{};
self.completion = .{ .onComplete = Server.onCompletion };
try self.submitAccept(ring);
std.debug.warn("Listening on :{}\n", .{port});
}
fn deinit(self: *Server) void {
std.os.close(self.fd);
std.debug.warn("Server closed\n", .{});
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Server, "completion", completion);
self.process(ring, completion.result) catch self.deinit();
}
fn process(self: *Server, ring: *Ring, result: i32) !void {
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitAccept(ring);
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const client_fd: std.os.socket_t = result;
Client.init(self, ring, client_fd) catch |err| {
std.os.close(client_fd);
std.debug.warn("Failed to start client: {}\n", .{client_fd});
};
try self.submitAccept(ring);
}
fn submitAccept(self: *Server, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .ACCEPT;
sqe.fd = self.fd;
sqe.rw_flags = std.os.SOCK_CLOEXEC;
sqe.user_data = @ptrToInt(&self.completion);
}
};
const Client = struct {
server: *Server,
fd: std.os.socket_t,
reader: Reader,
writer: Writer,
is_closed: bool,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn init(server: *Server, ring: *Ring, fd: std.os.socket_t) !void {
const allocator = &server.gpa.allocator;
const self = try allocator.create(Client);
errdefer allocator.destroy(self);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
self.fd = fd;
self.server = server;
self.is_closed = false;
try self.reader.init(ring);
try self.writer.init(ring);
}
fn deinit(self: *Client) void {
std.os.close(self.fd);
const allocator = &self.server.gpa.allocator;
allocator.destroy(self);
}
const Reader = struct {
state: enum { read, stop } = .stop,
completion: Ring.Completion,
buffer: Buffer = Buffer.init(),
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 });
fn init(self: *Reader, ring: *Ring) !void {
const client = @fieldParentPtr(Client, "reader", self);
self.* = .{ .completion = .{ .onComplete = onCompletion } };
try self.submitRead(ring);
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Reader, "completion", completion);
const client = @fieldParentPtr(Client, "reader", self);
self.process(client, ring, completion.result) catch |err| {
if (self.state == .stop and client.writer.state == .stop)
client.deinit();
};
}
fn process(self: *Reader, client: *Client, ring: *Ring, result: i32) !void {
errdefer {
self.state = .stop;
client.is_closed = true;
if (client.writer.state == .idle)
client.writer.state = .stop;
}
if (client.is_closed)
return error.closed;
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitRead(ring);
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const bytes = @intCast(usize, result);
if (bytes == 0)
return error.Eof;
self.buffer.update(bytes);
while (true) {
self.buffer.realign();
if (std.mem.indexOf(u8, self.buffer.readableSlice(0), HTTP_CLRF)) |parsed| {
self.buffer.discard(bytes);
client.writer.buffer.writeAssumeCapacity(HTTP_RESPONSE);
if (client.writer.state == .idle) {
try client.writer.submitWrite(ring);
}
continue;
}
if (self.buffer.writableLength() == 0)
return error.HttpRequestTooLarge;
try self.submitRead(ring);
return;
}
}
fn submitRead(self: *Reader, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .RECV;
sqe.fd = @fieldParentPtr(Client, "reader", self).fd;
const slice = self.buffer.writableSlice(0);
sqe.addr = @ptrToInt(slice.ptr);
sqe.len = @truncate(u32, slice.len);
sqe.user_data = @ptrToInt(&self.completion);
self.state = .read;
}
};
const Writer = struct {
state: enum { write, idle, stop } = .idle,
completion: Ring.Completion,
buffer: Buffer = Buffer.init(),
const Buffer = std.fifo.LinearFifo(u8, .{ .Static = NUM_RESPONSE_CHUNKS * HTTP_RESPONSE.len });
const NUM_RESPONSE_CHUNKS = 128;
fn init(self: *Writer, ring: *Ring) !void {
self.* = .{ .completion = .{ .onComplete = onCompletion } };
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Writer, "completion", completion);
const client = @fieldParentPtr(Client, "writer", self);
self.process(client, ring, completion.result) catch |err| {
if (self.state == .stop and client.reader.state == .stop)
client.deinit();
};
}
fn process(self: *Writer, client: *Client, ring: *Ring, result: i32) !void {
errdefer {
self.state = .stop;
client.is_closed = true;
}
if (client.is_closed)
return error.closed;
if (self.state == .idle) unreachable;
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitWrite(ring);
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const bytes_written = @intCast(usize, result);
self.buffer.discard(bytes_written);
if (self.buffer.readableLength() == 0) {
self.state = .idle;
return;
}
try self.submitWrite(ring);
}
fn submitWrite(self: *Writer, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .SEND;
sqe.fd = @fieldParentPtr(Client, "writer", self).fd;
const slice = self.buffer.readableSlice(0);
sqe.addr = @ptrToInt(slice.ptr);
sqe.len = @truncate(u32, slice.len);
sqe.user_data = @ptrToInt(&self.completion);
self.state = .write;
}
};
};
const std = @import("std");
const linux = std.os.linux;
pub fn main() !void {
var ring: Ring = undefined;
try ring.init();
defer ring.deinit();
var server: Server = undefined;
try server.init(&ring, 12345);
defer server.deinit();
while (true) {
try ring.poll();
}
}
const Ring = struct {
inner: linux.IO_Uring,
head: ?*Completion,
tail: ?*Completion,
fn init(self: *Ring) !void {
self.inner = try linux.IO_Uring.init(512, 0);
self.head = null;
self.tail = null;
}
fn deinit(self: *Ring) void {
self.inner.deinit();
}
const Completion = struct {
result: i32 = undefined,
next: ?*Completion = null,
onComplete: fn (*Ring, *Completion) void,
};
fn flushCompletions(self: *Ring) !void {
var chunk: [256]linux.io_uring_cqe = undefined;
while (true) {
const found = try self.inner.copy_cqes(&chunk, 0);
if (found == 0)
break;
for (chunk[0..found]) |cqe| {
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data));
completion.result = cqe.res;
if (self.head == null)
self.head = completion;
if (self.tail) |tail|
tail.next = completion;
completion.next = null;
self.tail = completion;
}
}
}
fn getSubmission(self: *Ring) !*linux.io_uring_sqe {
while (true) {
return self.inner.get_sqe() catch {
try self.flushCompletions();
_ = try self.inner.submit();
continue;
};
}
}
fn poll(self: *Ring) !void {
while (self.head) |completion| {
self.head = completion.next;
if (self.head == null)
self.tail = null;
(completion.onComplete)(self, completion);
}
_ = try self.inner.submit_and_wait(1);
try self.flushCompletions();
}
};
const Server = struct {
fd: std.os.socket_t,
completion: Ring.Completion,
gpa: std.heap.GeneralPurposeAllocator(.{}),
fn init(self: *Server, ring: *Ring, comptime port: u16) !void {
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
errdefer std.os.close(self.fd);
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen());
try std.os.listen(self.fd, 128);
self.gpa = .{};
self.completion = .{ .onComplete = Server.onCompletion };
try self.submitAccept(ring);
std.debug.warn("Listening on :{}\n", .{port});
}
fn deinit(self: *Server) void {
std.os.close(self.fd);
std.debug.warn("Server closed\n", .{});
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Server, "completion", completion);
self.process(ring, completion.result) catch self.deinit();
}
fn process(self: *Server, ring: *Ring, result: i32) !void {
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitAccept(ring);
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const client_fd: std.os.socket_t = result;
Client.init(self, ring, client_fd) catch |err| {
std.os.close(client_fd);
std.debug.warn("Failed to start client: {}\n", .{client_fd});
};
try self.submitAccept(ring);
}
fn submitAccept(self: *Server, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .ACCEPT;
sqe.fd = self.fd;
sqe.rw_flags = std.os.SOCK_CLOEXEC;
sqe.user_data = @ptrToInt(&self.completion);
}
};
const Client = struct {
server: *Server,
fd: std.os.socket_t,
reader: Reader,
writer: Writer,
is_closed: bool,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn init(server: *Server, ring: *Ring, fd: std.os.socket_t) !void {
const allocator = &server.gpa.allocator;
const self = try allocator.create(Client);
errdefer allocator.destroy(self);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
self.fd = fd;
self.server = server;
self.is_closed = false;
try self.reader.init(ring);
try self.writer.init(ring);
}
fn deinit(self: *Client) void {
std.os.close(self.fd);
const allocator = &self.server.gpa.allocator;
allocator.destroy(self);
}
const Reader = struct {
state: enum { read, stop } = .stop,
completion: Ring.Completion,
recv_bytes: usize = 0,
recv_buffer: [4096]u8 = undefined,
iovec: std.os.iovec = undefined,
fn init(self: *Reader, ring: *Ring) !void {
const client = @fieldParentPtr(Client, "reader", self);
self.* = .{ .completion = .{ .onComplete = onCompletion } };
self.iovec.iov_base = &self.recv_buffer;
self.iovec.iov_len = self.recv_buffer.len;
try self.submitRead(ring);
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Reader, "completion", completion);
const client = @fieldParentPtr(Client, "reader", self);
self.process(client, ring, completion.result) catch |err| {
if (self.state == .stop and client.writer.state == .stop)
client.deinit();
};
}
fn process(self: *Reader, client: *Client, ring: *Ring, result: i32) !void {
errdefer {
self.state = .stop;
client.is_closed = true;
if (client.writer.state == .idle)
client.writer.state = .stop;
}
if (client.is_closed)
return error.closed;
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitRead(ring);
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const bytes = @intCast(usize, result);
self.recv_bytes += bytes;
if (bytes == 0)
return error.Eof;
while (true) {
const request_buffer = self.recv_buffer[0..self.recv_bytes];
if (std.mem.indexOf(u8, request_buffer, HTTP_CLRF)) |parsed| {
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CLRF.len)..request_buffer.len];
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
self.recv_bytes = unparsed_buffer.len;
client.writer.send_bytes += HTTP_RESPONSE.len;
if (client.writer.state == .idle) {
client.writer.state = .write;
try client.writer.process(client, ring, 0);
}
continue;
}
const readable_buffer = self.recv_buffer[self.recv_bytes..];
if (readable_buffer.len == 0)
return error.HttpRequestTooLarge;
self.iovec.iov_base = readable_buffer.ptr;
self.iovec.iov_len = readable_buffer.len;
try self.submitRead(ring);
return;
}
}
fn submitRead(self: *Reader, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
// sqe.opcode = .READV;
sqe.opcode = .RECV;
sqe.fd = @fieldParentPtr(Client, "reader", self).fd;
// sqe.addr = @ptrToInt(&self.iovec);
// sqe.len = 1;
sqe.addr = @ptrToInt(self.iovec.iov_base);
sqe.len = @truncate(u32, self.iovec.iov_len);
sqe.user_data = @ptrToInt(&self.completion);
self.state = .read;
}
};
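// This Writer keeps no per-connection buffer: it only counts the bytes owed
// in send_bytes and sends slices of a statically repeated HTTP_RESPONSE,
// using send_partial to resume mid-response after a short write.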
const Writer = struct {
state: enum { write, idle, stop } = .idle,
completion: Ring.Completion,
send_bytes: usize = 0,
send_partial: usize = 0,
iovec: std.os.iovec_const = undefined,
fn init(self: *Writer, ring: *Ring) !void {
self.* = .{ .completion = .{ .onComplete = onCompletion } };
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Writer, "completion", completion);
const client = @fieldParentPtr(Client, "writer", self);
self.process(client, ring, completion.result) catch |err| {
if (self.state == .stop and client.reader.state == .stop)
client.deinit();
};
}
fn process(self: *Writer, client: *Client, ring: *Ring, result: i32) !void {
errdefer {
self.state = .stop;
client.is_closed = true;
}
if (client.is_closed)
return error.closed;
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitWrite(ring);
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const bytes_written = @intCast(usize, result);
self.send_bytes -= bytes_written;
self.send_partial = bytes_written % HTTP_RESPONSE.len;
if (self.send_bytes == 0) {
self.state = .idle;
return;
}
const NUM_RESPONSE_CHUNKS = 128;
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
self.iovec.iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
self.iovec.iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
try self.submitWrite(ring);
}
fn submitWrite(self: *Writer, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
// sqe.opcode = .WRITEV;
sqe.opcode = .SEND;
sqe.fd = @fieldParentPtr(Client, "writer", self).fd;
// sqe.addr = @ptrToInt(&self.iovec);
// sqe.len = 1;
sqe.addr = @ptrToInt(self.iovec.iov_base);
sqe.len = @truncate(u32, self.iovec.iov_len);
sqe.user_data = @ptrToInt(&self.completion);
self.state = .write;
}
};
};
const std = @import("std");
const linux = std.os.linux;
pub fn main() !void {
var ring: Ring = undefined;
try ring.init();
defer ring.deinit();
var server: Server = undefined;
try server.init(&ring, 12345);
defer server.deinit();
while (true) {
try ring.poll();
}
}
const Ring = struct {
inner: linux.IO_Uring,
head: ?*Completion,
tail: ?*Completion,
fn init(self: *Ring) !void {
self.inner = try linux.IO_Uring.init(512, 0);
self.head = null;
self.tail = null;
}
fn deinit(self: *Ring) void {
self.inner.deinit();
}
const Completion = struct {
result: i32 = undefined,
next: ?*Completion = null,
onComplete: fn(*Ring, *Completion) void,
};
fn flushCompletions(self: *Ring) !void {
var chunk: [256]linux.io_uring_cqe = undefined;
while (true) {
const found = try self.inner.copy_cqes(&chunk, 0);
if (found == 0)
break;
for (chunk[0..found]) |cqe| {
const completion = @intToPtr(*Ring.Completion, @intCast(usize, cqe.user_data));
completion.result = cqe.res;
if (self.head == null)
self.head = completion;
if (self.tail) |tail|
tail.next = completion;
completion.next = null;
self.tail = completion;
}
}
}
fn getSubmission(self: *Ring) !*linux.io_uring_sqe {
while (true) {
return self.inner.get_sqe() catch {
try self.flushCompletions();
_ = try self.inner.submit();
continue;
};
}
}
fn poll(self: *Ring) !void {
while (self.head) |completion| {
self.head = completion.next;
if (self.head == null)
self.tail = null;
(completion.onComplete)(self, completion);
}
_ = try self.inner.submit_and_wait(1);
try self.flushCompletions();
}
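// Queue a one-shot POLL_ADD on the fd. The accept/read/write paths below
// open sockets non-blocking and fall back to this when a call reports
// EAGAIN, mimicking a readiness-based event loop on top of io_uring.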
fn submitPoll(self: *Ring, fd: std.os.fd_t, flags: u32, completion: *Completion) !void {
const sqe = try self.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .POLL_ADD;
sqe.fd = fd;
sqe.rw_flags = flags | linux.POLLERR | linux.POLLHUP;
sqe.user_data = @ptrToInt(completion);
}
};
const Server = struct {
fd: std.os.socket_t,
completion: Ring.Completion,
gpa: std.heap.GeneralPurposeAllocator(.{}),
state: enum { poll, accept },
fn init(self: *Server, ring: *Ring, comptime port: u16) !void {
self.fd = try std.os.socket(std.os.AF_INET, std.os.SOCK_STREAM | std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC, std.os.IPPROTO_TCP);
errdefer std.os.close(self.fd);
var addr = comptime std.net.Address.parseIp("127.0.0.1", port) catch unreachable;
try std.os.setsockopt(self.fd, std.os.SOL_SOCKET, std.os.SO_REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
try std.os.bind(self.fd, &addr.any, addr.getOsSockLen());
try std.os.listen(self.fd, 128);
self.gpa = .{};
self.state = .poll;
self.completion = .{ .onComplete = Server.onCompletion };
try self.submitAccept(ring);
std.debug.warn("Listening on :{}\n", .{port});
}
fn deinit(self: *Server) void {
std.os.close(self.fd);
std.debug.warn("Server closed\n", .{});
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Server, "completion", completion);
self.process(ring, completion.result) catch self.deinit();
}
fn process(self: *Server, ring: *Ring, result: i32) !void {
switch (self.state) {
.poll => {
if (result & (linux.POLLERR | linux.POLLHUP) != 0)
return error.Closed;
try self.submitAccept(ring);
},
.accept => {
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitAccept(ring);
return;
},
std.os.EAGAIN => {
try ring.submitPoll(self.fd, linux.POLLIN, &self.completion);
self.state = .poll;
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const client_fd: std.os.socket_t = result;
Client.init(self, ring, client_fd) catch |err| {
std.os.close(client_fd);
std.debug.warn("Failed to start client: {}\n", .{client_fd});
};
try self.submitAccept(ring);
},
}
}
fn submitAccept(self: *Server, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .ACCEPT;
sqe.fd = self.fd;
sqe.rw_flags = std.os.SOCK_NONBLOCK | std.os.SOCK_CLOEXEC;
sqe.user_data = @ptrToInt(&self.completion);
self.state = .accept;
}
};
const Client = struct {
server: *Server,
fd: std.os.socket_t,
reader: Reader,
writer: Writer,
is_closed: bool,
const HTTP_CLRF = "\r\n\r\n";
const HTTP_RESPONSE =
"HTTP/1.1 200 Ok\r\n" ++
"Content-Length: 10\r\n" ++
"Content-Type: text/plain; charset=utf8\r\n" ++
"Date: Thu, 19 Nov 2020 14:26:34 GMT\r\n" ++
"Server: fasthttp\r\n" ++
"\r\n" ++
"HelloWorld";
fn init(server: *Server, ring: *Ring, fd: std.os.socket_t) !void {
const allocator = &server.gpa.allocator;
const self = try allocator.create(Client);
errdefer allocator.destroy(self);
const SOL_TCP = 6;
const TCP_NODELAY = 1;
try std.os.setsockopt(fd, SOL_TCP, TCP_NODELAY, &std.mem.toBytes(@as(c_int, 1)));
self.fd = fd;
self.server = server;
self.is_closed = false;
try self.reader.init(ring);
try self.writer.init(ring);
}
fn deinit(self: *Client) void {
std.os.close(self.fd);
const allocator = &self.server.gpa.allocator;
allocator.destroy(self);
}
const Reader = struct {
state: enum { poll, read, stop } = .stop,
completion: Ring.Completion,
recv_bytes: usize = 0,
recv_buffer: [4096]u8 = undefined,
iovec: std.os.iovec = undefined,
fn init(self: *Reader, ring: *Ring) !void {
const client = @fieldParentPtr(Client, "reader", self);
self.* = .{ .completion = .{ .onComplete = onCompletion } };
self.iovec.iov_base = &self.recv_buffer;
self.iovec.iov_len = self.recv_buffer.len;
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion);
self.state = .poll;
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Reader, "completion", completion);
const client = @fieldParentPtr(Client, "reader", self);
self.process(client, ring, completion.result) catch {
if (self.state == .stop and client.writer.state == .stop)
client.deinit();
};
}
fn process(self: *Reader, client: *Client, ring: *Ring, result: i32) !void {
errdefer {
self.state = .stop;
client.is_closed = true;
if (client.writer.state == .idle)
client.writer.state = .stop;
}
if (client.is_closed)
return error.Closed;
switch (self.state) {
.poll => {
if (result & (linux.POLLERR | linux.POLLHUP) != 0)
return error.Closed;
try self.submitRead(ring);
},
.read => {
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitRead(ring);
return;
},
std.os.EAGAIN => {
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion);
self.state = .poll;
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const bytes = @intCast(usize, result);
self.recv_bytes += bytes;
if (bytes == 0)
return error.Eof;
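// Consume every complete request currently buffered, crediting one canned
// response per request to the Writer.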
while (true) {
const request_buffer = self.recv_buffer[0..self.recv_bytes];
if (std.mem.indexOf(u8, request_buffer, HTTP_CRLF)) |parsed| {
const unparsed_buffer = self.recv_buffer[(parsed + HTTP_CRLF.len) .. request_buffer.len];
std.mem.copy(u8, &self.recv_buffer, unparsed_buffer);
self.recv_bytes = unparsed_buffer.len;
client.writer.send_bytes += HTTP_RESPONSE.len;
if (client.writer.state == .idle) {
client.writer.state = .write;
try client.writer.process(client, ring, 0);
}
continue;
}
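// No complete request yet: point the iovec at the unused tail of the buffer
// and go back to polling for more data.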
const readable_buffer = self.recv_buffer[self.recv_bytes..];
if (readable_buffer.len == 0)
return error.HttpRequestTooLarge;
self.iovec.iov_base = readable_buffer.ptr;
self.iovec.iov_len = readable_buffer.len;
try ring.submitPoll(client.fd, linux.POLLIN, &self.completion);
self.state = .poll;
return;
}
},
else => {}
}
}
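// Queue an IORING_OP_READV into whatever window `iovec` currently describes.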
fn submitRead(self: *Reader, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .READV;
sqe.fd = @fieldParentPtr(Client, "reader", self).fd;
sqe.addr = @ptrToInt(&self.iovec);
sqe.len = 1;
sqe.user_data = @ptrToInt(&self.completion);
self.state = .read;
}
};
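// Write side: sits in `idle` until the Reader credits it with bytes, then
// loops poll -> writev until send_bytes drains back to zero.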
const Writer = struct {
state: enum { poll, write, idle, stop } = .idle,
completion: Ring.Completion,
send_bytes: usize = 0,
send_partial: usize = 0,
iovec: std.os.iovec_const = undefined,
fn init(self: *Writer, ring: *Ring) !void {
self.* = .{ .completion = .{ .onComplete = onCompletion } };
}
fn onCompletion(ring: *Ring, completion: *Ring.Completion) void {
const self = @fieldParentPtr(Writer, "completion", completion);
const client = @fieldParentPtr(Client, "writer", self);
self.process(client, ring, completion.result) catch {
if (self.state == .stop and client.reader.state == .stop)
client.deinit();
};
}
fn process(self: *Writer, client: *Client, ring: *Ring, result: i32) !void {
errdefer {
self.state = .stop;
client.is_closed = true;
}
if (client.is_closed)
return error.Closed;
switch (self.state) {
.poll => {
if (result & (linux.POLLERR | linux.POLLHUP) != 0)
return error.Closed;
try self.submitWrite(ring);
},
.write => {
switch (if (result < 0) -result else @as(i32, 0)) {
0 => {},
std.os.EINTR => {
try self.submitWrite(ring);
return;
},
std.os.EAGAIN => {
try ring.submitPoll(client.fd, linux.POLLOUT, &self.completion);
self.state = .poll;
return;
},
else => {
return std.os.unexpectedErrno(@intCast(usize, -result));
},
}
const bytes_written = @intCast(usize, result);
self.send_bytes -= bytes_written;
// Advance the offset into the current response by however much writev managed to send.
self.send_partial = (self.send_partial + bytes_written) % HTTP_RESPONSE.len;
if (self.send_bytes == 0) {
self.state = .idle;
return;
}
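// Serve from a comptime buffer of 128 concatenated responses so several
// queued responses can go out in a single writev; send_partial is the
// offset into the current response after a short write.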
const NUM_RESPONSE_CHUNKS = 128;
const RESPONSE_CHUNK = HTTP_RESPONSE ** NUM_RESPONSE_CHUNKS;
self.iovec.iov_base = @ptrCast([*]const u8, &RESPONSE_CHUNK[0]) + self.send_partial;
self.iovec.iov_len = std.math.min(self.send_bytes, RESPONSE_CHUNK.len - self.send_partial);
try ring.submitPoll(client.fd, linux.POLLOUT, &self.completion);
self.state = .poll;
},
else => {},
}
}
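// Queue an IORING_OP_WRITEV for the response window set up in process().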
fn submitWrite(self: *Writer, ring: *Ring) !void {
const sqe = try ring.getSubmission();
sqe.* = std.mem.zeroes(@TypeOf(sqe.*));
sqe.opcode = .WRITEV;
sqe.fd = @fieldParentPtr(Client, "writer", self).fd;
sqe.addr = @ptrToInt(&self.iovec);
sqe.len = 1;
sqe.user_data = @ptrToInt(&self.completion);
self.state = .write;
}
};
};
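// To exercise the server (hypothetical invocation; substitute whatever port
// main() passes to Server.init), something like:
//   wrk -t4 -c128 -d10s http://127.0.0.1:<port>/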