From ff2de7176a6fcaafcf5d2c4501f8224a3269cb42 Mon Sep 17 00:00:00 2001 From: Muki Kiboigo Date: Fri, 18 Oct 2024 03:38:49 -0700 Subject: [PATCH] http/router: asynchronous fs serving --- examples/http/basic/main.zig | 3 +- examples/http/fs/main.zig | 28 +- src/core/lib.zig | 1 - src/core/server.zig | 916 ------------------------ src/core/zprovision.zig | 41 -- src/http/context.zig | 2 +- src/http/lib.zig | 1 - src/http/protocol.zig | 78 -- src/http/provision.zig | 66 ++ src/http/router.zig | 151 +++- src/http/server.zig | 1296 ++++++++++++++++++++++++++++------ 11 files changed, 1306 insertions(+), 1277 deletions(-) delete mode 100644 src/core/server.zig delete mode 100644 src/core/zprovision.zig delete mode 100644 src/http/protocol.zig create mode 100644 src/http/provision.zig diff --git a/examples/http/basic/main.zig b/examples/http/basic/main.zig index 088b134..9181d97 100644 --- a/examples/http/basic/main.zig +++ b/examples/http/basic/main.zig @@ -34,11 +34,12 @@ pub fn main() !void { }.handler_fn)); var server = http.Server(.plain, .auto).init(.{ + .router = &router, .allocator = allocator, .threading = .single, }); defer server.deinit(); try server.bind(host, port); - try server.listen(.{ .router = &router }); + try server.listen(); } diff --git a/examples/http/fs/main.zig b/examples/http/fs/main.zig index 5380ff4..13bf4f3 100644 --- a/examples/http/fs/main.zig +++ b/examples/http/fs/main.zig @@ -7,13 +7,15 @@ pub fn main() !void { const host: []const u8 = "0.0.0.0"; const port: u16 = 9862; - const allocator = std.heap.page_allocator; + var gpa = std.heap.GeneralPurposeAllocator(.{ .thread_safe = true }){ .backing_allocator = std.heap.c_allocator }; + const allocator = gpa.allocator(); + defer _ = gpa.deinit(); var router = http.Router.init(allocator); defer router.deinit(); try router.serve_route("/", http.Route.init().get(struct { - pub fn handler_fn(_: http.Request, response: *http.Response, _: http.Context) void { + pub fn handler_fn(ctx: *http.Context) void { const body = \\ \\ @@ -23,7 +25,7 @@ pub fn main() !void { \\ ; - response.set(.{ + ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = body[0..], @@ -31,9 +33,27 @@ pub fn main() !void { } }.handler_fn)); + try router.serve_route("/kill", http.Route.init().get(struct { + pub fn handler_fn(ctx: *http.Context) void { + ctx.runtime.stop(); + + ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = "", + }); + } + }.handler_fn)); + try router.serve_fs_dir("/static", "./examples/http/fs/static"); - var server = http.Server(.plain, .auto).init(.{ .allocator = allocator }); + var server = http.Server(.plain, .auto).init(.{ + .allocator = allocator, + .threading = .auto, + .size_connections_max = 256, + }); + defer server.deinit(); + try server.bind(host, port); try server.listen(.{ .router = &router }); } diff --git a/src/core/lib.zig b/src/core/lib.zig index e091ee8..f21f77a 100644 --- a/src/core/lib.zig +++ b/src/core/lib.zig @@ -1,3 +1,2 @@ pub const Job = @import("job.zig").Job; pub const Pseudoslice = @import("pseudoslice.zig").Pseudoslice; -pub const Server = @import("server.zig").Server; diff --git a/src/core/server.zig b/src/core/server.zig deleted file mode 100644 index 84d7f65..0000000 --- a/src/core/server.zig +++ /dev/null @@ -1,916 +0,0 @@ -const std = @import("std"); -const builtin = @import("builtin"); -const assert = std.debug.assert; -const log = std.log.scoped(.@"zzz/server"); - -const Pseudoslice = @import("pseudoslice.zig").Pseudoslice; -const ZProvision = @import("zprovision.zig").ZProvision; - -const TLSFileOptions = @import("../tls/lib.zig").TLSFileOptions; -const TLSContext = @import("../tls/lib.zig").TLSContext; -const TLS = @import("../tls/lib.zig").TLS; - -const Pool = @import("tardy").Pool; -pub const Threading = @import("tardy").TardyThreading; -pub const Runtime = @import("tardy").Runtime; -pub const Task = @import("tardy").Task; -const TaskFn = @import("tardy").TaskFn; -pub const AsyncIOType = @import("tardy").AsyncIOType; -const TardyCreator = @import("tardy").Tardy; -const Cross = @import("tardy").Cross; - -pub const RecvStatus = union(enum) { - kill, - recv, - send: Pseudoslice, - spawned, -}; - -/// Security Model to use.chinp acas -/// -/// Default: .plain (plaintext) -pub const Security = union(enum) { - plain, - tls: struct { - cert: TLSFileOptions, - key: TLSFileOptions, - cert_name: []const u8 = "CERTIFICATE", - key_name: []const u8 = "PRIVATE KEY", - }, -}; - -/// These are various general configuration -/// options that are important for the actual framework. -/// -/// This includes various different options and limits -/// for interacting with the underlying network. -pub const zzzConfig = struct { - /// The allocator that server will use. - allocator: std.mem.Allocator, - /// Threading Model to use. - /// - /// Default: .auto - threading: Threading = .auto, - /// Kernel Backlog Value. - size_backlog: u31 = 512, - /// Number of Maximum Concurrent Connections. - /// - /// This is applied PER thread if using multi-threading. - /// zzz will drop/close any connections greater - /// than this. - /// - /// You want to tune this to your expected number - /// of maximum connections. - /// - /// Default: 1024 - size_connections_max: u16 = 1024, - /// Maximum number of completions we can reap - /// with a single call of reap(). - /// - /// Default: 256 - size_completions_reap_max: u16 = 256, - /// Amount of allocated memory retained - /// after an arena is cleared. - /// - /// A higher value will increase memory usage but - /// should make allocators faster.Tardy - /// - /// A lower value will reduce memory usage but - /// will make allocators slower. - /// - /// Default: 1KB - size_connection_arena_retain: u32 = 1024, - /// Size of the buffer (in bytes) used for - /// interacting with the socket. - /// - /// Default: 4 KB. - size_socket_buffer: u32 = 1024 * 4, - /// Maximum size (in bytes) of the Recv buffer. - /// This is mainly a concern when you are reading in - /// large requests before responding. - /// - /// Default: 2MB. - size_recv_buffer_max: u32 = 1024 * 1024 * 2, -}; - -fn RecvFn(comptime ProtocolData: type, comptime ProtocolConfig: type) type { - return *const fn ( - rt: *Runtime, - trigger_task: TaskFn, - provision: *ZProvision(ProtocolData), - p_config: *const ProtocolConfig, - z_config: *const zzzConfig, - recv_buffer: []const u8, - ) RecvStatus; -} - -pub fn Server( - comptime security: Security, - comptime async_type: AsyncIOType, - comptime ProtocolData: type, - comptime ProtocolConfig: type, - comptime recv_fn: RecvFn(ProtocolData, ProtocolConfig), -) type { - const TLSContextType = comptime if (security == .tls) TLSContext else void; - const TLSType = comptime if (security == .tls) ?TLS else void; - const Provision = ZProvision(ProtocolData); - const Tardy = TardyCreator(async_type); - - return struct { - const Self = @This(); - allocator: std.mem.Allocator, - tardy: Tardy, - config: zzzConfig, - addr: std.net.Address, - tls_ctx: TLSContextType, - - pub fn init(config: zzzConfig) Self { - const tls_ctx = switch (comptime security) { - .tls => |inner| TLSContext.init(.{ - .allocator = config.allocator, - .cert = inner.cert, - .cert_name = inner.cert_name, - .key = inner.key, - .key_name = inner.key_name, - .size_tls_buffer_max = config.size_socket_buffer * 2, - }) catch unreachable, - .plain => void{}, - }; - - return Self{ - .allocator = config.allocator, - .tardy = Tardy.init(.{ - .allocator = config.allocator, - .threading = config.threading, - .size_tasks_max = config.size_connections_max, - .size_aio_jobs_max = config.size_connections_max, - .size_aio_reap_max = config.size_completions_reap_max, - }) catch unreachable, - .config = config, - .addr = undefined, - .tls_ctx = tls_ctx, - }; - } - - pub fn deinit(self: *Self) void { - if (comptime security == .tls) { - self.tls_ctx.deinit(); - } - - self.tardy.deinit(); - } - - fn create_socket(self: *const Self) !std.posix.socket_t { - const socket: std.posix.socket_t = blk: { - const socket_flags = std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC | std.posix.SOCK.NONBLOCK; - break :blk try std.posix.socket( - self.addr.any.family, - socket_flags, - std.posix.IPPROTO.TCP, - ); - }; - - log.debug("socket | t: {s} v: {any}", .{ @typeName(std.posix.socket_t), socket }); - - if (@hasDecl(std.posix.SO, "REUSEPORT_LB")) { - try std.posix.setsockopt( - socket, - std.posix.SOL.SOCKET, - std.posix.SO.REUSEPORT_LB, - &std.mem.toBytes(@as(c_int, 1)), - ); - } else if (@hasDecl(std.posix.SO, "REUSEPORT")) { - try std.posix.setsockopt( - socket, - std.posix.SOL.SOCKET, - std.posix.SO.REUSEPORT, - &std.mem.toBytes(@as(c_int, 1)), - ); - } else { - try std.posix.setsockopt( - socket, - std.posix.SOL.SOCKET, - std.posix.SO.REUSEADDR, - &std.mem.toBytes(@as(c_int, 1)), - ); - } - - try std.posix.bind(socket, &self.addr.any, self.addr.getOsSockLen()); - return socket; - } - - /// If you are using a custom implementation that does NOT rely - /// on TCP/IP, you can SKIP calling this method and just set the - /// socket value yourself. - /// - /// This is only allowed on certain targets that do not have TCP/IP - /// support. - pub fn bind(self: *Self, host: []const u8, port: u16) !void { - assert(host.len > 0); - assert(port > 0); - - self.addr = blk: { - switch (comptime builtin.os.tag) { - .windows => break :blk try std.net.Address.parseIp(host, port), - else => break :blk try std.net.Address.resolveIp(host, port), - } - }; - } - - fn close_task(rt: *Runtime, _: *const Task, ctx: ?*anyopaque) !void { - const provision: *Provision = @ptrCast(@alignCast(ctx.?)); - assert(provision.job == .close); - const server_socket = rt.storage.get("server_socket", std.posix.socket_t); - const pool = rt.storage.get_ptr("provision_pool", Pool(Provision)); - const z_config = rt.storage.get_const_ptr("z_config", zzzConfig); - - log.info("{d} - closing connection", .{provision.index}); - - if (comptime security == .tls) { - const tls_slice = rt.storage.get("tls_slice", []TLSType); - - const tls_ptr: *?TLS = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - tls_ptr.*.?.deinit(); - tls_ptr.* = null; - } - - provision.socket = Cross.socket.INVALID_SOCKET; - provision.job = .empty; - _ = provision.arena.reset(.{ .retain_with_limit = z_config.size_connection_arena_retain }); - provision.data.clean(); - provision.recv_buffer.clearRetainingCapacity(); - pool.release(provision.index); - - const accept_queued = rt.storage.get_ptr("accept_queued", bool); - if (!accept_queued.*) { - accept_queued.* = true; - try rt.net.accept(.{ - .socket = server_socket, - .func = accept_task, - }); - } - } - - fn accept_task(rt: *Runtime, t: *const Task, _: ?*anyopaque) !void { - const child_socket = t.result.?.socket; - - const pool = rt.storage.get_ptr("provision_pool", Pool(Provision)); - const accept_queued = rt.storage.get_ptr("accept_queued", bool); - accept_queued.* = false; - - if (rt.scheduler.tasks.clean() >= 2) { - accept_queued.* = true; - const server_socket = rt.storage.get("server_socket", std.posix.socket_t); - try rt.net.accept(.{ - .socket = server_socket, - .func = accept_task, - }); - } - - if (!Cross.socket.is_valid(child_socket)) { - log.err("socket accept failed", .{}); - return error.AcceptFailed; - } - - // This should never fail. It means that we have a dangling item. - assert(pool.clean() > 0); - const borrowed = pool.borrow_hint(t.index) catch unreachable; - - log.info("{d} - accepting connection", .{borrowed.index}); - log.debug( - "empty provision slots: {d}", - .{pool.items.len - pool.dirty.count()}, - ); - assert(borrowed.item.job == .empty); - - try Cross.socket.disable_nagle(child_socket); - try Cross.socket.to_nonblock(child_socket); - - const provision = borrowed.item; - - // Store the index of this item. - provision.index = @intCast(borrowed.index); - provision.socket = child_socket; - - switch (comptime security) { - .tls => |_| { - const tls_ctx = rt.storage.get_const_ptr("tls_ctx", TLSContextType); - const tls_slice = rt.storage.get("tls_slice", []TLSType); - - const tls_ptr: *?TLS = &tls_slice[provision.index]; - assert(tls_ptr.* == null); - - tls_ptr.* = tls_ctx.create(child_socket) catch |e| { - log.err("{d} - tls creation failed={any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSCreationFailed; - }; - - const recv_buf = tls_ptr.*.?.start_handshake() catch |e| { - log.err("{d} - tls start handshake failed={any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSStartHandshakeFailed; - }; - - provision.job = .{ .handshake = .{ .state = .recv, .count = 0 } }; - try rt.net.recv(.{ - .socket = child_socket, - .buffer = recv_buf, - .func = handshake_task, - .ctx = borrowed.item, - }); - }, - .plain => { - provision.job = .{ .recv = .{ .count = 0 } }; - try rt.net.recv(.{ - .socket = child_socket, - .buffer = provision.buffer, - .func = recv_task, - .ctx = borrowed.item, - }); - }, - } - } - - /// This is the task you MUST trigger if the `recv_fn` returns `.spawned`. - fn trigger_task(rt: *Runtime, _: *const Task, ctx: ?*anyopaque) !void { - const provision: *Provision = @ptrCast(@alignCast(ctx.?)); - - switch (provision.job) { - else => unreachable, - .recv => { - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = provision.buffer, - .func = recv_task, - .ctx = provision, - }); - }, - .send => |*send_job| { - const z_config = rt.storage.get_const_ptr("z_config", zzzConfig); - const plain_buffer = send_job.slice.get(0, z_config.size_socket_buffer); - - switch (comptime security) { - .tls => |_| { - const tls_slice = rt.storage.get("tls_slice", []TLSType); - - const tls_ptr: *?TLS = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - - const encrypted_buffer = tls_ptr.*.?.encrypt(plain_buffer) catch |e| { - log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSEncryptFailed; - }; - - send_job.count = plain_buffer.len; - send_job.security = .{ - .tls = .{ - .encrypted = encrypted_buffer, - .encrypted_count = 0, - }, - }; - - try rt.net.send(.{ - .socket = provision.socket, - .buffer = encrypted_buffer, - .func = send_task, - .ctx = provision, - }); - }, - .plain => { - send_job.security = .plain; - - try rt.net.send(.{ - .socket = provision.socket, - .buffer = plain_buffer, - .func = send_task, - .ctx = provision, - }); - }, - } - }, - } - } - - fn recv_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { - const provision: *Provision = @ptrCast(@alignCast(ctx.?)); - assert(provision.job == .recv); - const length: i32 = t.result.?.value; - - const p_config = rt.storage.get_const_ptr("p_config", ProtocolConfig); - const z_config = rt.storage.get_const_ptr("z_config", zzzConfig); - - const recv_job = &provision.job.recv; - - // If the socket is closed. - if (length <= 0) { - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return; - } - - log.debug("{d} - recv triggered", .{provision.index}); - - const recv_count: usize = @intCast(length); - recv_job.count += recv_count; - const pre_recv_buffer = provision.buffer[0..recv_count]; - - const recv_buffer = blk: { - switch (comptime security) { - .tls => |_| { - const tls_slice = rt.storage.get("tls_slice", []TLSType); - - const tls_ptr: *?TLS = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - - break :blk tls_ptr.*.?.decrypt(pre_recv_buffer) catch |e| { - log.err("{d} - decrypt failed: {any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSDecryptFailed; - }; - }, - .plain => break :blk pre_recv_buffer, - } - }; - - var status: RecvStatus = @call(.auto, recv_fn, .{ rt, trigger_task, provision, p_config, z_config, recv_buffer }); - - switch (status) { - .spawned => return, - .kill => { - rt.stop(); - return error.Killed; - }, - .recv => { - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = provision.buffer, - .func = recv_task, - .ctx = provision, - }); - }, - .send => |*pslice| { - const plain_buffer = pslice.get(0, z_config.size_socket_buffer); - - switch (comptime security) { - .tls => |_| { - const tls_slice = rt.storage.get("tls_slice", []TLSType); - - const tls_ptr: *?TLS = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - - const encrypted_buffer = tls_ptr.*.?.encrypt(plain_buffer) catch |e| { - log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSEncryptFailed; - }; - - provision.job = .{ - .send = .{ - .slice = pslice.*, - .count = @intCast(plain_buffer.len), - .security = .{ - .tls = .{ - .encrypted = encrypted_buffer, - .encrypted_count = 0, - }, - }, - }, - }; - - try rt.net.send(.{ - .socket = provision.socket, - .buffer = encrypted_buffer, - .func = send_task, - .ctx = provision, - }); - }, - .plain => { - provision.job = .{ - .send = .{ - .slice = pslice.*, - .count = 0, - .security = .plain, - }, - }; - - try rt.net.send(.{ - .socket = provision.socket, - .buffer = plain_buffer, - .func = send_task, - .ctx = provision, - }); - }, - } - }, - } - } - - fn handshake_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { - log.debug("Handshake Task", .{}); - assert(security == .tls); - const provision: *Provision = @ptrCast(@alignCast(ctx.?)); - const length: i32 = t.result.?.value; - - if (comptime security == .tls) { - const tls_slice = rt.storage.get("tls_slice", []TLSType); - - assert(provision.job == .handshake); - const handshake_job = &provision.job.handshake; - - const tls_ptr: *?TLS = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - log.debug("processing handshake", .{}); - handshake_job.count += 1; - - if (length <= 0) { - log.debug("handshake connection closed", .{}); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSHandshakeClosed; - } - - if (handshake_job.count >= 50) { - log.debug("handshake taken too many cycles", .{}); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSHandshakeTooManyCycles; - } - - const hs_length: usize = @intCast(length); - - switch (handshake_job.state) { - .recv => { - // on recv, we want to read from socket and feed into tls engien - const hstate = tls_ptr.*.?.continue_handshake( - .{ .recv = @intCast(hs_length) }, - ) catch |e| { - log.err("{d} - tls handshake on recv failed={any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSHandshakeRecvFailed; - }; - - switch (hstate) { - .recv => |buf| { - log.debug("requeing recv in handshake", .{}); - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = buf, - .func = handshake_task, - .ctx = provision, - }); - }, - .send => |buf| { - log.debug("queueing send in handshake", .{}); - handshake_job.state = .send; - try rt.net.send(.{ - .socket = provision.socket, - .buffer = buf, - .func = handshake_task, - .ctx = provision, - }); - }, - .complete => { - log.debug("handshake complete", .{}); - provision.job = .{ .recv = .{ .count = 0 } }; - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = provision.buffer, - .func = recv_task, - .ctx = provision, - }); - }, - } - }, - .send => { - // on recv, we want to read from socket and feed into tls engien - const hstate = tls_ptr.*.?.continue_handshake( - .{ .send = @intCast(hs_length) }, - ) catch |e| { - log.err("{d} - tls handshake on send failed={any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSHandshakeSendFailed; - }; - - switch (hstate) { - .recv => |buf| { - handshake_job.state = .recv; - log.debug("queuing recv in handshake", .{}); - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = buf, - .func = handshake_task, - .ctx = provision, - }); - }, - .send => |buf| { - log.debug("requeing send in handshake", .{}); - try rt.net.send(.{ - .socket = provision.socket, - .buffer = buf, - .func = handshake_task, - .ctx = provision, - }); - }, - .complete => { - log.debug("handshake complete", .{}); - provision.job = .{ .recv = .{ .count = 0 } }; - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = provision.buffer, - .func = recv_task, - .ctx = provision, - }); - }, - } - }, - } - } else unreachable; - } - - fn send_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { - const provision: *Provision = @ptrCast(@alignCast(ctx.?)); - assert(provision.job == .send); - const length: i32 = t.result.?.value; - - const z_config = rt.storage.get_const_ptr("z_config", zzzConfig); - - // If the socket is closed. - if (length <= 0) { - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return; - } - - const send_job = &provision.job.send; - - log.debug("{d} - send triggered", .{provision.index}); - const send_count: usize = @intCast(length); - log.debug("{d} - send length: {d}", .{ provision.index, send_count }); - - switch (comptime security) { - .tls => { - assert(send_job.security == .tls); - - const tls_slice = rt.storage.get("tls_slice", []TLSType); - - const job_tls = &send_job.security.tls; - job_tls.encrypted_count += send_count; - - if (job_tls.encrypted_count >= job_tls.encrypted.len) { - if (send_job.count >= send_job.slice.len) { - // All done sending. - log.debug("{d} - queueing a new recv", .{provision.index}); - _ = provision.arena.reset(.{ - .retain_with_limit = z_config.size_connection_arena_retain, - }); - provision.recv_buffer.clearRetainingCapacity(); - provision.job = .{ .recv = .{ .count = 0 } }; - - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = provision.buffer, - .func = recv_task, - .ctx = provision, - }); - } else { - // Queue a new chunk up for sending. - log.debug( - "{d} - sending next chunk starting at index {d}", - .{ provision.index, send_job.count }, - ); - - const inner_slice = send_job.slice.get( - send_job.count, - send_job.count + z_config.size_socket_buffer, - ); - - send_job.count += @intCast(inner_slice.len); - - const tls_ptr: *?TLS = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - - const encrypted = tls_ptr.*.?.encrypt(inner_slice) catch |e| { - log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(.{ - .socket = provision.socket, - .func = close_task, - .ctx = provision, - }); - return error.TLSEncryptFailed; - }; - - job_tls.encrypted = encrypted; - job_tls.encrypted_count = 0; - - try rt.net.send(.{ - .socket = provision.socket, - .buffer = job_tls.encrypted, - .func = send_task, - .ctx = provision, - }); - } - } else { - log.debug( - "{d} - sending next encrypted chunk starting at index {d}", - .{ provision.index, job_tls.encrypted_count }, - ); - - const remainder = job_tls.encrypted[job_tls.encrypted_count..]; - try rt.net.send(.{ - .socket = provision.socket, - .buffer = remainder, - .func = send_task, - .ctx = provision, - }); - } - }, - .plain => { - assert(send_job.security == .plain); - send_job.count += send_count; - - if (send_job.count >= send_job.slice.len) { - log.debug("{d} - queueing a new recv", .{provision.index}); - _ = provision.arena.reset(.{ - .retain_with_limit = z_config.size_connection_arena_retain, - }); - provision.recv_buffer.clearRetainingCapacity(); - provision.job = .{ .recv = .{ .count = 0 } }; - - try rt.net.recv(.{ - .socket = provision.socket, - .buffer = provision.buffer, - .func = recv_task, - .ctx = provision, - }); - } else { - log.debug( - "{d} - sending next chunk starting at index {d}", - .{ provision.index, send_job.count }, - ); - - const plain_buffer = send_job.slice.get( - send_job.count, - send_job.count + z_config.size_socket_buffer, - ); - - log.debug("{d} - chunk ends at: {d}", .{ - provision.index, - plain_buffer.len + send_job.count, - }); - - try rt.net.send(.{ - .socket = provision.socket, - .buffer = plain_buffer, - .func = send_task, - .ctx = provision, - }); - } - }, - } - } - - pub fn listen(self: *Self, protocol_config: ProtocolConfig) !void { - log.info("server listening...", .{}); - log.info("security mode: {s}", .{@tagName(security)}); - - const EntryParams = struct { - zzz: *Self, - p_config: *ProtocolConfig, - }; - - try self.tardy.entry( - struct { - fn rt_start(rt: *Runtime, alloc: std.mem.Allocator, params: EntryParams) !void { - const socket = try params.zzz.create_socket(); - try std.posix.listen(socket, params.zzz.config.size_backlog); - - // use the arena here. - var pool_params = params.zzz.config; - pool_params.allocator = alloc; - - const provision_pool = try alloc.create(Pool(Provision)); - provision_pool.* = try Pool(Provision).init( - alloc, - params.zzz.config.size_connections_max, - Provision.init_hook, - pool_params, - ); - - for (provision_pool.items) |*provision| { - provision.data = ProtocolData.init(alloc, params.p_config); - } - - try rt.storage.store_ptr("provision_pool", provision_pool); - try rt.storage.store_ptr("z_config", ¶ms.zzz.config); - try rt.storage.store_ptr("p_config", params.p_config); - - if (comptime security == .tls) { - const tls_slice = try alloc.alloc( - TLSType, - params.zzz.config.size_connections_max, - ); - if (comptime security == .tls) { - for (tls_slice) |*tls| { - tls.* = null; - } - } - - // since slices are fat pointers... - try rt.storage.store_alloc("tls_slice", tls_slice); - try rt.storage.store_ptr("tls_ctx", ¶ms.zzz.tls_ctx); - } - - try rt.storage.store_alloc("server_socket", socket); - try rt.storage.store_alloc("accept_queued", true); - - try rt.net.accept(.{ - .socket = socket, - .func = accept_task, - }); - } - }.rt_start, - EntryParams{ - .zzz = self, - .p_config = @constCast(&protocol_config), - }, - struct { - fn rt_end(rt: *Runtime, alloc: std.mem.Allocator, _: anytype) void { - // clean up socket. - const server_socket = rt.storage.get("server_socket", std.posix.socket_t); - std.posix.close(server_socket); - - // clean up provision pool. - const provision_pool = rt.storage.get_ptr("provision_pool", Pool(Provision)); - for (provision_pool.items) |*provision| { - provision.data.deinit(alloc); - } - provision_pool.deinit(Provision.deinit_hook, alloc); - alloc.destroy(provision_pool); - - // clean up TLS. - if (comptime security == .tls) { - const tls_slice = rt.storage.get("tls_slice", []TLSType); - alloc.free(tls_slice); - } - } - }.rt_end, - void, - ); - } - }; -} diff --git a/src/core/zprovision.zig b/src/core/zprovision.zig deleted file mode 100644 index 3e2f63e..0000000 --- a/src/core/zprovision.zig +++ /dev/null @@ -1,41 +0,0 @@ -const std = @import("std"); -const panic = std.debug.panic; -const Job = @import("../core/lib.zig").Job; -const TLS = @import("../tls/lib.zig").TLS; - -pub fn ZProvision(comptime ProtocolData: type) type { - return struct { - const Self = @This(); - index: usize, - job: Job, - socket: std.posix.socket_t, - buffer: []u8, - recv_buffer: std.ArrayList(u8), - arena: std.heap.ArenaAllocator, - data: ProtocolData, - - pub fn init_hook(provisions: []Self, ctx: anytype) void { - for (provisions) |*provision| { - provision.job = .empty; - provision.socket = undefined; - provision.data = undefined; - // Create Buffer - provision.buffer = ctx.allocator.alloc(u8, ctx.size_socket_buffer) catch { - panic("attempting to statically allocate more memory than available. (Socket Buffer)", .{}); - }; - // Create Recv Buffer - provision.recv_buffer = std.ArrayList(u8).init(ctx.allocator); - // Create the Context Arena - provision.arena = std.heap.ArenaAllocator.init(ctx.allocator); - } - } - - pub fn deinit_hook(provisions: []Self, allocator: anytype) void { - for (provisions) |*provision| { - allocator.free(provision.buffer); - provision.recv_buffer.deinit(); - provision.arena.deinit(); - } - } - }; -} diff --git a/src/http/context.zig b/src/http/context.zig index d4e5fe0..d3356f0 100644 --- a/src/http/context.zig +++ b/src/http/context.zig @@ -5,7 +5,7 @@ const log = std.log.scoped(.@"zzz/http/context"); const Capture = @import("routing_trie.zig").Capture; const QueryMap = @import("routing_trie.zig").QueryMap; -const Provision = @import("../core/zprovision.zig").ZProvision(@import("protocol.zig").ProtocolData); +const Provision = @import("provision.zig").Provision; const Request = @import("request.zig").Request; const Response = @import("response.zig").Response; diff --git a/src/http/lib.zig b/src/http/lib.zig index 97e43ac..468cdd1 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -9,7 +9,6 @@ pub const Router = @import("router.zig").Router; pub const RouteHandlerFn = @import("route.zig").RouteHandlerFn; pub const Context = @import("context.zig").Context; pub const Headers = @import("headers.zig").Headers; -pub const Protocol = @import("protocol.zig"); pub const Server = @import("server.zig").Server; diff --git a/src/http/protocol.zig b/src/http/protocol.zig deleted file mode 100644 index 23fab20..0000000 --- a/src/http/protocol.zig +++ /dev/null @@ -1,78 +0,0 @@ -const std = @import("std"); -const Job = @import("../core/lib.zig").Job; -const Capture = @import("routing_trie.zig").Capture; -const Query = @import("routing_trie.zig").Query; -const QueryMap = @import("routing_trie.zig").QueryMap; -const Request = @import("request.zig").Request; -const Response = @import("response.zig").Response; -const Stage = @import("stage.zig").Stage; -const Router = @import("router.zig").Router; - -pub const ProtocolConfig = struct { - router: *Router, - num_header_max: u32 = 32, - /// Maximum number of Captures in a Route - /// - /// Default: 8 - num_captures_max: u32 = 8, - /// Maximum number of Queries in a URL - /// - /// Default: 8 - num_queries_max: u32 = 8, - /// Maximum size (in bytes) of the Request. - /// - /// Default: 2MB. - size_request_max: u32 = 1024 * 1024 * 2, - /// Maximum size (in bytes) of the Request URI. - /// - /// Default: 2KB. - size_request_uri_max: u32 = 1024 * 2, -}; - -pub const ProtocolData = struct { - captures: []Capture, - queries: QueryMap, - request: Request, - response: Response, - stage: Stage, - - pub fn init(allocator: std.mem.Allocator, config: *const ProtocolConfig) ProtocolData { - var queries = QueryMap.init(allocator); - queries.ensureTotalCapacity(config.num_queries_max) catch unreachable; - - return ProtocolData{ - .stage = .header, - .captures = allocator.alloc(Capture, config.num_captures_max) catch unreachable, - .queries = queries, - .request = Request.init(allocator, .{ - .num_header_max = config.num_header_max, - .size_request_max = config.size_request_max, - .size_request_uri_max = config.size_request_uri_max, - }) catch unreachable, - .response = Response.init(allocator, .{ - .num_headers_max = config.num_header_max, - }) catch unreachable, - }; - } - - pub fn deinit(self: *ProtocolData, allocator: std.mem.Allocator) void { - self.request.deinit(); - self.response.deinit(); - self.queries.deinit(); - allocator.free(self.captures); - } - - pub fn clean(self: *ProtocolData) void { - self.response.clear(); - } -}; - -const testing = std.testing; - -test "ProtocolData deinit" { - const config: ProtocolConfig = .{ .router = undefined }; - var x = ProtocolData.init(testing.allocator, &config); - defer x.deinit(testing.allocator); - - try testing.expectEqual(x.stage, .header); -} diff --git a/src/http/provision.zig b/src/http/provision.zig new file mode 100644 index 0000000..7802468 --- /dev/null +++ b/src/http/provision.zig @@ -0,0 +1,66 @@ +const std = @import("std"); + +const Job = @import("../core/job.zig").Job; +const Capture = @import("routing_trie.zig").Capture; +const QueryMap = @import("routing_trie.zig").QueryMap; +const Request = @import("request.zig").Request; +const Response = @import("response.zig").Response; +const Stage = @import("stage.zig").Stage; +const ServerConfig = @import("server.zig").ServerConfig; + +pub const Provision = struct { + index: usize, + job: Job, + socket: std.posix.socket_t, + buffer: []u8, + recv_buffer: std.ArrayList(u8), + arena: std.heap.ArenaAllocator, + captures: []Capture, + queries: QueryMap, + request: Request, + response: Response, + stage: Stage, + + pub fn init_hook(provisions: []Provision, config: anytype) void { + for (provisions) |*provision| { + provision.job = .empty; + provision.socket = undefined; + // Create Buffer + provision.buffer = config.allocator.alloc(u8, config.size_socket_buffer) catch { + @panic("attempting to statically allocate more memory than available. (Socket Buffer)"); + }; + // Create Recv Buffer + provision.recv_buffer = std.ArrayList(u8).init(config.allocator); + // Create the Context Arena + provision.arena = std.heap.ArenaAllocator.init(config.allocator); + + provision.stage = .header; + provision.captures = config.allocator.alloc(Capture, config.num_captures_max) catch unreachable; + + var queries = QueryMap.init(config.allocator); + queries.ensureTotalCapacity(config.num_queries_max) catch unreachable; + provision.queries = queries; + + provision.request = Request.init(config.allocator, .{ + .num_header_max = config.num_header_max, + .size_request_max = config.size_request_max, + .size_request_uri_max = config.size_request_uri_max, + }) catch unreachable; + provision.response = Response.init(config.allocator, .{ + .num_headers_max = config.num_header_max, + }) catch unreachable; + } + } + + pub fn deinit_hook(provisions: []Provision, allocator: anytype) void { + for (provisions) |*provision| { + allocator.free(provision.buffer); + provision.recv_buffer.deinit(); + provision.arena.deinit(); + provision.request.deinit(); + provision.response.deinit(); + provision.queries.deinit(); + allocator.free(provision.captures); + } + } +}; diff --git a/src/http/router.zig b/src/http/router.zig index 89f5ca6..4824781 100644 --- a/src/http/router.zig +++ b/src/http/router.zig @@ -13,6 +13,9 @@ const Context = @import("context.zig").Context; const RoutingTrie = @import("routing_trie.zig").RoutingTrie; const QueryMap = @import("routing_trie.zig").QueryMap; +const Runtime = @import("tardy").Runtime; +const Task = @import("tardy").Task; + pub const Router = struct { allocator: std.mem.Allocator, routes: RoutingTrie, @@ -29,16 +32,111 @@ pub const Router = struct { self.routes.deinit(); } + const FileProvision = struct { + mime: Mime, + context: *Context, + fd: std.posix.fd_t, + offset: usize, + list: std.ArrayList(u8), + buffer: []u8, + }; + + fn open_file_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { + const provision: *FileProvision = @ptrCast(@alignCast(ctx.?)); + errdefer { + provision.context.respond(.{ + .status = .@"Internal Server Error", + .mime = Mime.HTML, + .body = "", + }); + } + + const fd = t.result.?.fd; + if (fd <= -1) { + provision.context.respond(.{ + .status = .@"Not Found", + .mime = Mime.HTML, + .body = "File Not Found", + }); + return; + } + provision.fd = fd; + + try rt.fs.read(.{ + .fd = fd, + .buffer = provision.buffer, + .offset = 0, + .func = read_file_task, + .ctx = provision, + }); + } + + fn read_file_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { + const provision: *FileProvision = @ptrCast(@alignCast(ctx.?)); + errdefer { + provision.context.respond(.{ + .status = .@"Internal Server Error", + .mime = Mime.HTML, + .body = "", + }); + } + + const result: i32 = t.result.?.value; + if (result <= 0) { + // If we are done reading... + try rt.fs.close(.{ + .fd = provision.fd, + .func = close_file_task, + .ctx = provision, + }); + return; + } + + const length: usize = @intCast(result); + + try provision.list.appendSlice(provision.buffer[0..length]); + + // TODO: This needs to be a setting you pass in to the router. + // + //if (provision.list.items.len > 1024 * 1024 * 4) { + // provision.context.respond(.{ + // .status = .@"Content Too Large", + // .mime = Mime.HTML, + // .body = "File Too Large", + // }); + // return; + //} + + provision.offset += length; + + try rt.fs.read(.{ + .fd = provision.fd, + .buffer = provision.buffer, + .offset = provision.offset, + .func = read_file_task, + .ctx = provision, + }); + } + + fn close_file_task(_: *Runtime, _: *const Task, ctx: ?*anyopaque) !void { + const provision: *FileProvision = @ptrCast(@alignCast(ctx.?)); + + provision.context.respond(.{ + .status = .OK, + .mime = provision.mime, + .body = provision.list.items[0..], + }); + } + pub fn serve_fs_dir(self: *Router, comptime url_path: []const u8, comptime dir_path: []const u8) !void { assert(!self.locked); const route = Route.init().get(struct { - pub fn handler_fn(request: Request, response: *Response, context: Context) void { - _ = request; + pub fn handler_fn(ctx: *Context) void { + const search_path = ctx.captures[0].remaining; - const search_path = context.captures[0].remaining; - const file_path = std.fmt.allocPrint(context.allocator, "{s}/{s}", .{ dir_path, search_path }) catch { - response.set(.{ + const file_path = std.fmt.allocPrintZ(ctx.allocator, "{s}/{s}", .{ dir_path, search_path }) catch { + ctx.respond(.{ .status = .@"Internal Server Error", .mime = Mime.HTML, .body = "", @@ -46,39 +144,50 @@ pub const Router = struct { return; }; + // TODO: Ensure that paths cannot go out of scope and reference data that they shouldn't be allowed to. + // Very important. + const extension_start = std.mem.lastIndexOfScalar(u8, search_path, '.'); const mime: Mime = blk: { if (extension_start) |start| { break :blk Mime.from_extension(search_path[start..]); } else { - break :blk Mime.HTML; + break :blk Mime.BIN; } }; - const file: std.fs.File = std.fs.cwd().openFile(file_path, .{}) catch { - response.set(.{ - .status = .@"Not Found", + const provision = ctx.allocator.create(FileProvision) catch { + ctx.respond(.{ + .status = .@"Internal Server Error", .mime = Mime.HTML, - .body = "File Not Found", + .body = "", }); return; }; - defer file.close(); - const file_bytes = file.readToEndAlloc(context.allocator, 1024 * 1024 * 4) catch { - response.set(.{ - .status = .@"Content Too Large", + provision.* = .{ + .mime = mime, + .context = ctx, + .fd = -1, + .offset = 0, + .list = std.ArrayList(u8).init(ctx.allocator), + .buffer = ctx.provision.buffer, + }; + + // We also need to support chunked encoding. + // It makes a lot more sense for files atleast. + ctx.runtime.fs.open(.{ + .path = file_path, + .func = open_file_task, + .ctx = provision, + }) catch { + ctx.respond(.{ + .status = .@"Internal Server Error", .mime = Mime.HTML, - .body = "File Too Large", + .body = "", }); return; }; - - response.set(.{ - .status = .OK, - .mime = mime, - .body = file_bytes, - }); } }.handler_fn); diff --git a/src/http/server.zig b/src/http/server.zig index f8dd9dd..78b3088 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -1,35 +1,124 @@ const std = @import("std"); - const builtin = @import("builtin"); const assert = std.debug.assert; -const panic = std.debug.panic; const log = std.log.scoped(.@"zzz/http/server"); -const Runtime = @import("tardy").Runtime; -const AsyncIOType = @import("tardy").AsyncIOType; -const Pool = @import("tardy").Pool; +const Pseudoslice = @import("../core/pseudoslice.zig").Pseudoslice; -const Job = @import("../core/lib.zig").Job; -const Pseudoslice = @import("../core/lib.zig").Pseudoslice; +const TLSFileOptions = @import("../tls/lib.zig").TLSFileOptions; +const TLSContext = @import("../tls/lib.zig").TLSContext; +const TLS = @import("../tls/lib.zig").TLS; +const Provision = @import("provision.zig").Provision; +const Mime = @import("mime.zig").Mime; +const Router = @import("router.zig").Router; +const Context = @import("context.zig").Context; const HTTPError = @import("lib.zig").HTTPError; -const Request = @import("lib.zig").Request; -const Response = @import("lib.zig").Response; -const Mime = @import("lib.zig").Mime; -const Context = @import("lib.zig").Context; -const Router = @import("lib.zig").Router; - -const Capture = @import("routing_trie.zig").Capture; -const ProtocolData = @import("protocol.zig").ProtocolData; -const ProtocolConfig = @import("protocol.zig").ProtocolConfig; -const Security = @import("../core/server.zig").Security; -const zzzConfig = @import("../core/server.zig").zzzConfig; -const Provision = @import("../core/zprovision.zig").ZProvision(ProtocolData); - -const RecvStatus = @import("../core/server.zig").RecvStatus; -const zzzServer = @import("../core/server.zig").Server; +const Pool = @import("tardy").Pool; +pub const Threading = @import("tardy").TardyThreading; +pub const Runtime = @import("tardy").Runtime; +pub const Task = @import("tardy").Task; const TaskFn = @import("tardy").TaskFn; +pub const AsyncIOType = @import("tardy").AsyncIOType; +const TardyCreator = @import("tardy").Tardy; +const Cross = @import("tardy").Cross; + +pub const RecvStatus = union(enum) { + kill, + recv, + send: Pseudoslice, + spawned, +}; + +/// Security Model to use.chinp acas +/// +/// Default: .plain (plaintext) +pub const Security = union(enum) { + plain, + tls: struct { + cert: TLSFileOptions, + key: TLSFileOptions, + cert_name: []const u8 = "CERTIFICATE", + key_name: []const u8 = "PRIVATE KEY", + }, +}; + +/// These are various general configuration +/// options that are important for the actual framework. +/// +/// This includes various different options and limits +/// for interacting with the underlying network. +pub const ServerConfig = struct { + /// The allocator that server will use. + allocator: std.mem.Allocator, + /// HTTP Request Router. + router: *Router, + /// Threading Model to use. + /// + /// Default: .auto + threading: Threading = .auto, + /// Kernel Backlog Value. + size_backlog: u31 = 512, + /// Number of Maximum Concurrent Connections. + /// + /// This is applied PER thread if using multi-threading. + /// zzz will drop/close any connections greater + /// than this. + /// + /// You want to tune this to your expected number + /// of maximum connections. + /// + /// Default: 1024 + size_connections_max: u16 = 1024, + /// Maximum number of completions we can reap + /// with a single call of reap(). + /// + /// Default: 256 + size_completions_reap_max: u16 = 256, + /// Amount of allocated memory retained + /// after an arena is cleared. + /// + /// A higher value will increase memory usage but + /// should make allocators faster.Tardy + /// + /// A lower value will reduce memory usage but + /// will make allocators slower. + /// + /// Default: 1KB + size_connection_arena_retain: u32 = 1024, + /// Size of the buffer (in bytes) used for + /// interacting with the socket. + /// + /// Default: 4 KB. + size_socket_buffer: u32 = 1024 * 4, + /// Maximum size (in bytes) of the Recv buffer. + /// This is mainly a concern when you are reading in + /// large requests before responding. + /// + /// Default: 2MB. + size_recv_buffer_max: u32 = 1024 * 1024 * 2, + /// Maximum number of Headers in a Request/Response + /// + /// Default: 32 + num_header_max: u32 = 32, + /// Maximum number of Captures in a Route + /// + /// Default: 8 + num_captures_max: u32 = 8, + /// Maximum number of Queries in a URL + /// + /// Default: 8 + num_queries_max: u32 = 8, + /// Maximum size (in bytes) of the Request. + /// + /// Default: 2MB. + size_request_max: u32 = 1024 * 1024 * 2, + /// Maximum size (in bytes) of the Request URI. + /// + /// Default: 2KB. + size_request_uri_max: u32 = 1024 * 2, +}; /// Uses the current p.response to generate and queue up the sending /// of a response. This is used when we already know what we want to send. @@ -37,23 +126,23 @@ const TaskFn = @import("tardy").TaskFn; /// See: `route_and_respond` pub inline fn raw_respond(p: *Provision) !RecvStatus { { - const status_code: u16 = if (p.data.response.status) |status| @intFromEnum(status) else 0; - const status_name = if (p.data.response.status) |status| @tagName(status) else "No Status"; + const status_code: u16 = if (p.response.status) |status| @intFromEnum(status) else 0; + const status_name = if (p.response.status) |status| @tagName(status) else "No Status"; log.info("{d} - {d} {s}", .{ p.index, status_code, status_name }); } - const body = p.data.response.body orelse ""; - const header_buffer = try p.data.response.headers_into_buffer(p.buffer, @intCast(body.len)); - p.data.response.headers.clear(); + const body = p.response.body orelse ""; + const header_buffer = try p.response.headers_into_buffer(p.buffer, @intCast(body.len)); + p.response.headers.clear(); const pseudo = Pseudoslice.init(header_buffer, body, p.buffer); return .{ .send = pseudo }; } fn route_and_respond(runtime: *Runtime, trigger: TaskFn, p: *Provision, router: *const Router) !RecvStatus { route: { - const found = router.get_route_from_host(p.data.request.uri, p.data.captures, &p.data.queries); + const found = router.get_route_from_host(p.request.uri, p.captures, &p.queries); if (found) |f| { - const handler = f.route.get_handler(p.data.request.method); + const handler = f.route.get_handler(p.request.method); if (handler) |func| { const context: *Context = try p.arena.allocator().create(Context); @@ -62,9 +151,9 @@ fn route_and_respond(runtime: *Runtime, trigger: TaskFn, p: *Provision, router: trigger, runtime, p, - &p.data.request, - &p.data.response, - p.data.request.uri, + &p.request, + &p.response, + p.request.uri, f.captures, f.queries, ); @@ -73,7 +162,7 @@ fn route_and_respond(runtime: *Runtime, trigger: TaskFn, p: *Provision, router: return .spawned; } else { // If we match the route but not the method. - p.data.response.set(.{ + p.response.set(.{ .status = .@"Method Not Allowed", .mime = Mime.HTML, .body = "405 Method Not Allowed", @@ -82,7 +171,7 @@ fn route_and_respond(runtime: *Runtime, trigger: TaskFn, p: *Provision, router: // We also need to add to Allow header. // This uses the connection's arena to allocate 64 bytes. const allowed = f.route.get_allowed(p.arena.allocator()) catch { - p.data.response.set(.{ + p.response.set(.{ .status = .@"Internal Server Error", .mime = Mime.HTML, .body = "", @@ -91,8 +180,8 @@ fn route_and_respond(runtime: *Runtime, trigger: TaskFn, p: *Provision, router: break :route; }; - p.data.response.headers.add("Allow", allowed) catch { - p.data.response.set(.{ + p.response.headers.add("Allow", allowed) catch { + p.response.set(.{ .status = .@"Internal Server Error", .mime = Mime.HTML, .body = "", @@ -106,7 +195,7 @@ fn route_and_respond(runtime: *Runtime, trigger: TaskFn, p: *Provision, router: } // Didn't match any route. - p.data.response.set(.{ + p.response.set(.{ .status = .@"Not Found", .mime = Mime.HTML, .body = "404 Not Found", @@ -114,221 +203,1002 @@ fn route_and_respond(runtime: *Runtime, trigger: TaskFn, p: *Provision, router: break :route; } - if (p.data.response.status == .Kill) { + if (p.response.status == .Kill) { return .kill; } return try raw_respond(p); } -pub fn recv_fn( - runtime: *Runtime, - trigger: TaskFn, - provision: *Provision, - p_config: *const ProtocolConfig, - z_config: *const zzzConfig, - recv_buffer: []const u8, -) RecvStatus { - _ = z_config; - - var stage = provision.data.stage; - const job = provision.job.recv; - - if (job.count >= p_config.size_request_max) { - provision.data.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "Request was too large", - }); +pub fn Server( + comptime security: Security, + comptime async_type: AsyncIOType, +) type { + const TLSContextType = comptime if (security == .tls) TLSContext else void; + const TLSType = comptime if (security == .tls) ?TLS else void; + const Tardy = TardyCreator(async_type); - return raw_respond(provision) catch unreachable; - } + return struct { + const Self = @This(); + allocator: std.mem.Allocator, + tardy: Tardy, + config: ServerConfig, + addr: std.net.Address, + tls_ctx: TLSContextType, - switch (stage) { - .header => { - const start = provision.recv_buffer.items.len -| 4; - provision.recv_buffer.appendSlice(recv_buffer) catch unreachable; - const header_ends = std.mem.lastIndexOf(u8, provision.recv_buffer.items[start..], "\r\n\r\n"); + pub fn init(config: ServerConfig) Self { + const tls_ctx = switch (comptime security) { + .tls => |inner| TLSContext.init(.{ + .allocator = config.allocator, + .cert = inner.cert, + .cert_name = inner.cert_name, + .key = inner.key, + .key_name = inner.key_name, + .size_tls_buffer_max = config.size_socket_buffer * 2, + }) catch unreachable, + .plain => void{}, + }; - // Basically, this means we haven't finished processing the header. - if (header_ends == null) { - log.debug("{d} - header doesn't end in this chunk, continue", .{provision.index}); - return .recv; + return Self{ + .allocator = config.allocator, + .tardy = Tardy.init(.{ + .allocator = config.allocator, + .threading = config.threading, + .size_tasks_max = config.size_connections_max, + .size_aio_jobs_max = config.size_connections_max, + .size_aio_reap_max = config.size_completions_reap_max, + }) catch unreachable, + .config = config, + .addr = undefined, + .tls_ctx = tls_ctx, + }; + } + + pub fn deinit(self: *Self) void { + if (comptime security == .tls) { + self.tls_ctx.deinit(); } - log.debug("{d} - parsing header", .{provision.index}); - // The +4 is to account for the slice we match. - const header_end: u32 = @intCast(header_ends.? + 4); - provision.data.request.parse_headers(provision.recv_buffer.items[0..header_end]) catch |e| { - switch (e) { - HTTPError.ContentTooLarge => { - provision.data.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "Request was too large", - }); - }, - HTTPError.TooManyHeaders => { - provision.data.response.set(.{ - .status = .@"Request Header Fields Too Large", - .mime = Mime.HTML, - .body = "Too Many Headers", - }); - }, - HTTPError.MalformedRequest => { - provision.data.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "Malformed Request", + self.tardy.deinit(); + } + + fn create_socket(self: *const Self) !std.posix.socket_t { + const socket: std.posix.socket_t = blk: { + const socket_flags = std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC | std.posix.SOCK.NONBLOCK; + break :blk try std.posix.socket( + self.addr.any.family, + socket_flags, + std.posix.IPPROTO.TCP, + ); + }; + + log.debug("socket | t: {s} v: {any}", .{ @typeName(std.posix.socket_t), socket }); + + if (@hasDecl(std.posix.SO, "REUSEPORT_LB")) { + try std.posix.setsockopt( + socket, + std.posix.SOL.SOCKET, + std.posix.SO.REUSEPORT_LB, + &std.mem.toBytes(@as(c_int, 1)), + ); + } else if (@hasDecl(std.posix.SO, "REUSEPORT")) { + try std.posix.setsockopt( + socket, + std.posix.SOL.SOCKET, + std.posix.SO.REUSEPORT, + &std.mem.toBytes(@as(c_int, 1)), + ); + } else { + try std.posix.setsockopt( + socket, + std.posix.SOL.SOCKET, + std.posix.SO.REUSEADDR, + &std.mem.toBytes(@as(c_int, 1)), + ); + } + + try std.posix.bind(socket, &self.addr.any, self.addr.getOsSockLen()); + return socket; + } + + /// If you are using a custom implementation that does NOT rely + /// on TCP/IP, you can SKIP calling this method and just set the + /// socket value yourself. + /// + /// This is only allowed on certain targets that do not have TCP/IP + /// support. + pub fn bind(self: *Self, host: []const u8, port: u16) !void { + assert(host.len > 0); + assert(port > 0); + + self.addr = blk: { + switch (comptime builtin.os.tag) { + .windows => break :blk try std.net.Address.parseIp(host, port), + else => break :blk try std.net.Address.resolveIp(host, port), + } + }; + } + + fn close_task(rt: *Runtime, _: *const Task, ctx: ?*anyopaque) !void { + const provision: *Provision = @ptrCast(@alignCast(ctx.?)); + assert(provision.job == .close); + const server_socket = rt.storage.get("server_socket", std.posix.socket_t); + const pool = rt.storage.get_ptr("provision_pool", Pool(Provision)); + const config = rt.storage.get_const_ptr("config", ServerConfig); + + log.info("{d} - closing connection", .{provision.index}); + + if (comptime security == .tls) { + const tls_slice = rt.storage.get("tls_slice", []TLSType); + + const tls_ptr: *TLSType = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + tls_ptr.*.?.deinit(); + tls_ptr.* = null; + } + + provision.socket = Cross.socket.INVALID_SOCKET; + provision.job = .empty; + _ = provision.arena.reset(.{ .retain_with_limit = config.size_connection_arena_retain }); + provision.response.clear(); + + // TODO: new config setting here! + if (provision.recv_buffer.items.len > 1024) { + provision.recv_buffer.shrinkRetainingCapacity(1024); + } else { + provision.recv_buffer.clearRetainingCapacity(); + } + + pool.release(provision.index); + + const accept_queued = rt.storage.get_ptr("accept_queued", bool); + if (!accept_queued.*) { + accept_queued.* = true; + try rt.net.accept(.{ + .socket = server_socket, + .func = accept_task, + }); + } + } + + fn accept_task(rt: *Runtime, t: *const Task, _: ?*anyopaque) !void { + const child_socket = t.result.?.socket; + + const pool = rt.storage.get_ptr("provision_pool", Pool(Provision)); + const accept_queued = rt.storage.get_ptr("accept_queued", bool); + accept_queued.* = false; + + if (rt.scheduler.tasks.clean() >= 2) { + accept_queued.* = true; + const server_socket = rt.storage.get("server_socket", std.posix.socket_t); + try rt.net.accept(.{ + .socket = server_socket, + .func = accept_task, + }); + } + + if (!Cross.socket.is_valid(child_socket)) { + log.err("socket accept failed", .{}); + return error.AcceptFailed; + } + + // This should never fail. It means that we have a dangling item. + assert(pool.clean() > 0); + const borrowed = pool.borrow_hint(t.index) catch unreachable; + + log.info("{d} - accepting connection", .{borrowed.index}); + log.debug( + "empty provision slots: {d}", + .{pool.items.len - pool.dirty.count()}, + ); + assert(borrowed.item.job == .empty); + + try Cross.socket.disable_nagle(child_socket); + try Cross.socket.to_nonblock(child_socket); + + const provision = borrowed.item; + + // Store the index of this item. + provision.index = @intCast(borrowed.index); + provision.socket = child_socket; + + switch (comptime security) { + .tls => |_| { + const tls_ctx = rt.storage.get_const_ptr("tls_ctx", TLSContextType); + const tls_slice = rt.storage.get("tls_slice", []TLSType); + + const tls_ptr: *TLSType = &tls_slice[provision.index]; + assert(tls_ptr.* == null); + + tls_ptr.* = tls_ctx.create(child_socket) catch |e| { + log.err("{d} - tls creation failed={any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, }); - }, - HTTPError.URITooLong => { - provision.data.response.set(.{ - .status = .@"URI Too Long", - .mime = Mime.HTML, - .body = "URI Too Long", + return error.TLSCreationFailed; + }; + + const recv_buf = tls_ptr.*.?.start_handshake() catch |e| { + log.err("{d} - tls start handshake failed={any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, }); + return error.TLSStartHandshakeFailed; + }; + + provision.job = .{ .handshake = .{ .state = .recv, .count = 0 } }; + try rt.net.recv(.{ + .socket = child_socket, + .buffer = recv_buf, + .func = handshake_task, + .ctx = borrowed.item, + }); + }, + .plain => { + provision.job = .{ .recv = .{ .count = 0 } }; + try rt.net.recv(.{ + .socket = child_socket, + .buffer = provision.buffer, + .func = recv_task, + .ctx = borrowed.item, + }); + }, + } + } + + /// This is the task you MUST trigger if the `recv_fn` returns `.spawned`. + fn trigger_task(rt: *Runtime, _: *const Task, ctx: ?*anyopaque) !void { + const provision: *Provision = @ptrCast(@alignCast(ctx.?)); + + switch (provision.job) { + else => unreachable, + .recv => { + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = provision.buffer, + .func = recv_task, + .ctx = provision, + }); + }, + .send => |*send_job| { + const config = rt.storage.get_const_ptr("config", ServerConfig); + const plain_buffer = send_job.slice.get(0, config.size_socket_buffer); + + switch (comptime security) { + .tls => |_| { + const tls_slice: []TLSType = @as( + [*]TLSType, + @ptrCast(@alignCast(rt.storage.get("tls_slice").?)), + )[0..config.size_connections_max]; + + const tls_ptr: *TLSType = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + + const encrypted_buffer = tls_ptr.*.?.encrypt(plain_buffer) catch |e| { + log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .fd = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSEncryptFailed; + }; + + send_job.count = plain_buffer.len; + send_job.security = .{ + .tls = .{ + .encrypted = encrypted_buffer, + .encrypted_count = 0, + }, + }; + + try rt.net.send(.{ + .socket = provision.socket, + .buffer = encrypted_buffer, + .func = send_task, + .ctx = provision, + }); + }, + .plain => { + send_job.security = .plain; + + try rt.net.send(.{ + .socket = provision.socket, + .buffer = plain_buffer, + .func = send_task, + .ctx = provision, + }); + }, + } + }, + } + } + + fn recv_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { + const provision: *Provision = @ptrCast(@alignCast(ctx.?)); + assert(provision.job == .recv); + const length: i32 = t.result.?.value; + + const config = rt.storage.get_const_ptr("config", ServerConfig); + + const recv_job = &provision.job.recv; + + // If the socket is closed. + if (length <= 0) { + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return; + } + + log.debug("{d} - recv triggered", .{provision.index}); + + const recv_count: usize = @intCast(length); + recv_job.count += recv_count; + const pre_recv_buffer = provision.buffer[0..recv_count]; + + const recv_buffer = blk: { + switch (comptime security) { + .tls => |_| { + const tls_slice = rt.storage.get("tls_slice", []TLSType); + + const tls_ptr: *TLSType = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + + break :blk tls_ptr.*.?.decrypt(pre_recv_buffer) catch |e| { + log.err("{d} - decrypt failed: {any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSDecryptFailed; + }; }, - HTTPError.InvalidMethod => { - provision.data.response.set(.{ - .status = .@"Not Implemented", - .mime = Mime.HTML, - .body = "Not Implemented", + .plain => break :blk pre_recv_buffer, + } + }; + + var status: RecvStatus = status: { + var stage = provision.stage; + const job = provision.job.recv; + + if (job.count >= config.size_request_max) { + provision.response.set(.{ + .status = .@"Content Too Large", + .mime = Mime.HTML, + .body = "Request was too large", + }); + + break :status raw_respond(provision) catch unreachable; + } + + switch (stage) { + .header => { + const start = provision.recv_buffer.items.len -| 4; + provision.recv_buffer.appendSlice(recv_buffer) catch unreachable; + const header_ends = std.mem.lastIndexOf(u8, provision.recv_buffer.items[start..], "\r\n\r\n"); + + // Basically, this means we haven't finished processing the header. + if (header_ends == null) { + log.debug("{d} - header doesn't end in this chunk, continue", .{provision.index}); + break :status .recv; + } + + log.debug("{d} - parsing header", .{provision.index}); + // The +4 is to account for the slice we match. + const header_end: u32 = @intCast(header_ends.? + 4); + provision.request.parse_headers(provision.recv_buffer.items[0..header_end]) catch |e| { + switch (e) { + HTTPError.ContentTooLarge => { + provision.response.set(.{ + .status = .@"Content Too Large", + .mime = Mime.HTML, + .body = "Request was too large", + }); + }, + HTTPError.TooManyHeaders => { + provision.response.set(.{ + .status = .@"Request Header Fields Too Large", + .mime = Mime.HTML, + .body = "Too Many Headers", + }); + }, + HTTPError.MalformedRequest => { + provision.response.set(.{ + .status = .@"Bad Request", + .mime = Mime.HTML, + .body = "Malformed Request", + }); + }, + HTTPError.URITooLong => { + provision.response.set(.{ + .status = .@"URI Too Long", + .mime = Mime.HTML, + .body = "URI Too Long", + }); + }, + HTTPError.InvalidMethod => { + provision.response.set(.{ + .status = .@"Not Implemented", + .mime = Mime.HTML, + .body = "Not Implemented", + }); + }, + HTTPError.HTTPVersionNotSupported => { + provision.response.set(.{ + .status = .@"HTTP Version Not Supported", + .mime = Mime.HTML, + .body = "HTTP Version Not Supported", + }); + }, + } + + break :status raw_respond(provision) catch unreachable; + }; + + // Logging information about Request. + log.info("{d} - \"{s} {s}\" {s}", .{ + provision.index, + @tagName(provision.request.method), + provision.request.uri, + provision.request.headers.get("User-Agent") orelse "N/A", }); + + // HTTP/1.1 REQUIRES a Host header to be present. + const is_http_1_1 = provision.request.version == .@"HTTP/1.1"; + const is_host_present = provision.request.headers.get("Host") != null; + if (is_http_1_1 and !is_host_present) { + provision.response.set(.{ + .status = .@"Bad Request", + .mime = Mime.HTML, + .body = "Missing \"Host\" Header", + }); + + break :status raw_respond(provision) catch unreachable; + } + + if (!provision.request.expect_body()) { + break :status route_and_respond(rt, trigger_task, provision, config.router) catch unreachable; + } + + // Everything after here is a Request that is expecting a body. + const content_length = blk: { + const length_string = provision.request.headers.get("Content-Length") orelse { + break :blk 0; + }; + + break :blk std.fmt.parseInt(u32, length_string, 10) catch { + provision.response.set(.{ + .status = .@"Bad Request", + .mime = Mime.HTML, + .body = "", + }); + + break :status raw_respond(provision) catch unreachable; + }; + }; + + if (header_end < provision.recv_buffer.items.len) { + const difference = provision.recv_buffer.items.len - header_end; + if (difference == content_length) { + // Whole Body + log.debug("{d} - got whole body with header", .{provision.index}); + const body_end = header_end + difference; + provision.request.set_body(provision.recv_buffer.items[header_end..body_end]); + break :status route_and_respond(rt, trigger_task, provision, config.router) catch unreachable; + } else { + // Partial Body + log.debug("{d} - got partial body with header", .{provision.index}); + stage = .{ .body = header_end }; + break :status .recv; + } + } else if (header_end == provision.recv_buffer.items.len) { + // Body of length 0 probably or only got header. + if (content_length == 0) { + log.debug("{d} - got body of length 0", .{provision.index}); + // Body of Length 0. + provision.request.set_body(""); + break :status route_and_respond(rt, trigger_task, provision, config.router) catch unreachable; + } else { + // Got only header. + log.debug("{d} - got all header aka no body", .{provision.index}); + stage = .{ .body = header_end }; + break :status .recv; + } + } else unreachable; }, - HTTPError.HTTPVersionNotSupported => { - provision.data.response.set(.{ - .status = .@"HTTP Version Not Supported", - .mime = Mime.HTML, - .body = "HTTP Version Not Supported", - }); + + .body => |header_end| { + // We should ONLY be here if we expect there to be a body. + assert(provision.request.expect_body()); + log.debug("{d} - body matching trigger_tasked", .{provision.index}); + + const content_length = blk: { + const length_string = provision.request.headers.get("Content-Length") orelse { + provision.response.set(.{ + .status = .@"Length Required", + .mime = Mime.HTML, + .body = "", + }); + + break :status raw_respond(provision) catch unreachable; + }; + + break :blk std.fmt.parseInt(u32, length_string, 10) catch { + provision.response.set(.{ + .status = .@"Bad Request", + .mime = Mime.HTML, + .body = "", + }); + + break :status raw_respond(provision) catch unreachable; + }; + }; + + const request_length = header_end + content_length; + + // If this body will be too long, abort early. + if (request_length > config.size_request_max) { + provision.response.set(.{ + .status = .@"Content Too Large", + .mime = Mime.HTML, + .body = "", + }); + break :status raw_respond(provision) catch unreachable; + } + + if (job.count >= request_length) { + provision.request.set_body(provision.recv_buffer.items[header_end..request_length]); + break :status route_and_respond(rt, trigger_task, provision, config.router) catch unreachable; + } else { + break :status .recv; + } }, } - - return raw_respond(provision) catch unreachable; }; - // Logging information about Request. - log.info("{d} - \"{s} {s}\" {s}", .{ - provision.index, - @tagName(provision.data.request.method), - provision.data.request.uri, - provision.data.request.headers.get("User-Agent") orelse "N/A", - }); - - // HTTP/1.1 REQUIRES a Host header to be present. - const is_http_1_1 = provision.data.request.version == .@"HTTP/1.1"; - const is_host_present = provision.data.request.headers.get("Host") != null; - if (is_http_1_1 and !is_host_present) { - provision.data.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "Missing \"Host\" Header", - }); + switch (status) { + .spawned => return, + .kill => { + rt.stop(); + return error.Killed; + }, + .recv => { + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = provision.buffer, + .func = recv_task, + .ctx = provision, + }); + }, + .send => |*pslice| { + const plain_buffer = pslice.get(0, config.size_socket_buffer); - return raw_respond(provision) catch unreachable; - } + switch (comptime security) { + .tls => |_| { + const tls_slice = rt.storage.get("tls_slice", []TLSType); + + const tls_ptr: *TLSType = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + + const encrypted_buffer = tls_ptr.*.?.encrypt(plain_buffer) catch |e| { + log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSEncryptFailed; + }; - if (!provision.data.request.expect_body()) { - return route_and_respond(runtime, trigger, provision, p_config.router) catch unreachable; + provision.job = .{ + .send = .{ + .slice = pslice.*, + .count = @intCast(plain_buffer.len), + .security = .{ + .tls = .{ + .encrypted = encrypted_buffer, + .encrypted_count = 0, + }, + }, + }, + }; + + try rt.net.send(.{ + .socket = provision.socket, + .buffer = encrypted_buffer, + .func = send_task, + .ctx = provision, + }); + }, + .plain => { + provision.job = .{ + .send = .{ + .slice = pslice.*, + .count = 0, + .security = .plain, + }, + }; + + try rt.net.send(.{ + .socket = provision.socket, + .buffer = plain_buffer, + .func = send_task, + .ctx = provision, + }); + }, + } + }, } + } - // Everything after here is a Request that is expecting a body. - const content_length = blk: { - const length_string = provision.data.request.headers.get("Content-Length") orelse { - break :blk 0; - }; + fn handshake_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { + log.debug("Handshake Task", .{}); + assert(security == .tls); + const provision: *Provision = @ptrCast(@alignCast(ctx.?)); + const length: i32 = t.result.?.value; - break :blk std.fmt.parseInt(u32, length_string, 10) catch { - provision.data.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "", - }); + if (comptime security == .tls) { + const tls_slice = rt.storage.get("tls_slice", []TLSType); - return raw_respond(provision) catch unreachable; - }; - }; + assert(provision.job == .handshake); + const handshake_job = &provision.job.handshake; - if (header_end < provision.recv_buffer.items.len) { - const difference = provision.recv_buffer.items.len - header_end; - if (difference == content_length) { - // Whole Body - log.debug("{d} - got whole body with header", .{provision.index}); - const body_end = header_end + difference; - provision.data.request.set_body(provision.recv_buffer.items[header_end..body_end]); - return route_and_respond(runtime, trigger, provision, p_config.router) catch unreachable; - } else { - // Partial Body - log.debug("{d} - got partial body with header", .{provision.index}); - stage = .{ .body = header_end }; - return .recv; + const tls_ptr: *TLSType = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + log.debug("processing handshake", .{}); + handshake_job.count += 1; + + if (length <= 0) { + log.debug("handshake connection closed", .{}); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSHandshakeClosed; } - } else if (header_end == provision.recv_buffer.items.len) { - // Body of length 0 probably or only got header. - if (content_length == 0) { - log.debug("{d} - got body of length 0", .{provision.index}); - // Body of Length 0. - provision.data.request.set_body(""); - return route_and_respond(runtime, trigger, provision, p_config.router) catch unreachable; - } else { - // Got only header. - log.debug("{d} - got all header aka no body", .{provision.index}); - stage = .{ .body = header_end }; - return .recv; + + if (handshake_job.count >= 50) { + log.debug("handshake taken too many cycles", .{}); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSHandshakeTooManyCycles; } - } else unreachable; - }, - .body => |header_end| { - // We should ONLY be here if we expect there to be a body. - assert(provision.data.request.expect_body()); - log.debug("{d} - body matching triggered", .{provision.index}); + const hs_length: usize = @intCast(length); - const content_length = blk: { - const length_string = provision.data.request.headers.get("Content-Length") orelse { - provision.data.response.set(.{ - .status = .@"Length Required", - .mime = Mime.HTML, - .body = "", - }); + switch (handshake_job.state) { + .recv => { + // on recv, we want to read from socket and feed into tls engien + const hstate = tls_ptr.*.?.continue_handshake( + .{ .recv = @intCast(hs_length) }, + ) catch |e| { + log.err("{d} - tls handshake on recv failed={any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSHandshakeRecvFailed; + }; - return raw_respond(provision) catch unreachable; - }; + switch (hstate) { + .recv => |buf| { + log.debug("requeing recv in handshake", .{}); + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = buf, + .func = handshake_task, + .ctx = provision, + }); + }, + .send => |buf| { + log.debug("queueing send in handshake", .{}); + handshake_job.state = .send; + try rt.net.send(.{ + .socket = provision.socket, + .buffer = buf, + .func = handshake_task, + .ctx = provision, + }); + }, + .complete => { + log.debug("handshake complete", .{}); + provision.job = .{ .recv = .{ .count = 0 } }; + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = provision.buffer, + .func = recv_task, + .ctx = provision, + }); + }, + } + }, + .send => { + // on recv, we want to read from socket and feed into tls engien + const hstate = tls_ptr.*.?.continue_handshake( + .{ .send = @intCast(hs_length) }, + ) catch |e| { + log.err("{d} - tls handshake on send failed={any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSHandshakeSendFailed; + }; - break :blk std.fmt.parseInt(u32, length_string, 10) catch { - provision.data.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "", - }); + switch (hstate) { + .recv => |buf| { + handshake_job.state = .recv; + log.debug("queuing recv in handshake", .{}); + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = buf, + .func = handshake_task, + .ctx = provision, + }); + }, + .send => |buf| { + log.debug("requeing send in handshake", .{}); + try rt.net.send(.{ + .socket = provision.socket, + .buffer = buf, + .func = handshake_task, + .ctx = provision, + }); + }, + .complete => { + log.debug("handshake complete", .{}); + provision.job = .{ .recv = .{ .count = 0 } }; + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = provision.buffer, + .func = recv_task, + .ctx = provision, + }); + }, + } + }, + } + } else unreachable; + } - return raw_respond(provision) catch unreachable; - }; - }; + fn send_task(rt: *Runtime, t: *const Task, ctx: ?*anyopaque) !void { + const provision: *Provision = @ptrCast(@alignCast(ctx.?)); + assert(provision.job == .send); + const length: i32 = t.result.?.value; - const request_length = header_end + content_length; + const config = rt.storage.get_const_ptr("config", ServerConfig); - // If this body will be too long, abort early. - if (request_length > p_config.size_request_max) { - provision.data.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "", + // If the socket is closed. + if (length <= 0) { + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, }); - return raw_respond(provision) catch unreachable; + return; } - if (job.count >= request_length) { - provision.data.request.set_body(provision.recv_buffer.items[header_end..request_length]); - return route_and_respond(runtime, trigger, provision, p_config.router) catch unreachable; - } else { - return .recv; + const send_job = &provision.job.send; + + log.debug("{d} - send triggered", .{provision.index}); + const send_count: usize = @intCast(length); + log.debug("{d} - send length: {d}", .{ provision.index, send_count }); + + switch (comptime security) { + .tls => { + assert(send_job.security == .tls); + + const tls_slice = rt.storage.get("tls_slice", []TLSType); + + const job_tls = &send_job.security.tls; + job_tls.encrypted_count += send_count; + + if (job_tls.encrypted_count >= job_tls.encrypted.len) { + if (send_job.count >= send_job.slice.len) { + // All done sending. + log.debug("{d} - queueing a new recv", .{provision.index}); + _ = provision.arena.reset(.{ + .retain_with_limit = config.size_connection_arena_retain, + }); + provision.recv_buffer.clearRetainingCapacity(); + provision.job = .{ .recv = .{ .count = 0 } }; + + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = provision.buffer, + .func = recv_task, + .ctx = provision, + }); + } else { + // Queue a new chunk up for sending. + log.debug( + "{d} - sending next chunk starting at index {d}", + .{ provision.index, send_job.count }, + ); + + const inner_slice = send_job.slice.get( + send_job.count, + send_job.count + config.size_socket_buffer, + ); + + send_job.count += @intCast(inner_slice.len); + + const tls_ptr: *TLSType = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + + const encrypted = tls_ptr.*.?.encrypt(inner_slice) catch |e| { + log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(.{ + .socket = provision.socket, + .func = close_task, + .ctx = provision, + }); + return error.TLSEncryptFailed; + }; + + job_tls.encrypted = encrypted; + job_tls.encrypted_count = 0; + + try rt.net.send(.{ + .socket = provision.socket, + .buffer = job_tls.encrypted, + .func = send_task, + .ctx = provision, + }); + } + } else { + log.debug( + "{d} - sending next encrypted chunk starting at index {d}", + .{ provision.index, job_tls.encrypted_count }, + ); + + const remainder = job_tls.encrypted[job_tls.encrypted_count..]; + try rt.net.send(.{ + .socket = provision.socket, + .buffer = remainder, + .func = send_task, + .ctx = provision, + }); + } + }, + .plain => { + assert(send_job.security == .plain); + send_job.count += send_count; + + if (send_job.count >= send_job.slice.len) { + log.debug("{d} - queueing a new recv", .{provision.index}); + _ = provision.arena.reset(.{ + .retain_with_limit = config.size_connection_arena_retain, + }); + provision.recv_buffer.clearRetainingCapacity(); + provision.job = .{ .recv = .{ .count = 0 } }; + + try rt.net.recv(.{ + .socket = provision.socket, + .buffer = provision.buffer, + .func = recv_task, + .ctx = provision, + }); + } else { + log.debug( + "{d} - sending next chunk starting at index {d}", + .{ provision.index, send_job.count }, + ); + + const plain_buffer = send_job.slice.get( + send_job.count, + send_job.count + config.size_socket_buffer, + ); + + log.debug("{d} - chunk ends at: {d}", .{ + provision.index, + plain_buffer.len + send_job.count, + }); + + try rt.net.send(.{ + .socket = provision.socket, + .buffer = plain_buffer, + .func = send_task, + .ctx = provision, + }); + } + }, } - }, - } -} + } + + pub fn listen(self: *Self) !void { + log.info("server listening...", .{}); + log.info("security mode: {s}", .{@tagName(security)}); + + try self.tardy.entry( + struct { + fn rt_start(rt: *Runtime, alloc: std.mem.Allocator, zzz: *Self) !void { + const socket = try zzz.create_socket(); + try std.posix.listen(socket, zzz.config.size_backlog); + + const provision_pool = try alloc.create(Pool(Provision)); + provision_pool.* = try Pool(Provision).init( + alloc, + zzz.config.size_connections_max, + Provision.init_hook, + zzz.config, + ); + + try rt.storage.store_ptr("provision_pool", provision_pool); + try rt.storage.store_ptr("config", &zzz.config); + + if (comptime security == .tls) { + const tls_slice = try alloc.alloc( + TLSType, + zzz.config.size_connections_max, + ); + if (comptime security == .tls) { + for (tls_slice) |*tls| { + tls.* = null; + } + } -pub fn Server(comptime security: Security, comptime async_type: AsyncIOType) type { - return zzzServer(security, async_type, ProtocolData, ProtocolConfig, recv_fn); + // since slices are fat pointers... + try rt.storage.store_alloc("tls_slice", tls_slice); + try rt.storage.store_ptr("tls_ctx", zzz.tls_ctx); + } + + try rt.storage.store_alloc("server_socket", socket); + try rt.storage.store_alloc("accept_queued", true); + + try rt.net.accept(.{ + .socket = socket, + .func = accept_task, + }); + } + }.rt_start, + self, + struct { + fn rt_end(rt: *Runtime, alloc: std.mem.Allocator, _: anytype) void { + // clean up socket. + const server_socket = rt.storage.get("server_socket", std.posix.socket_t); + std.posix.close(server_socket); + + // clean up provision pool. + const provision_pool = rt.storage.get_ptr("provision_pool", Pool(Provision)); + provision_pool.deinit(Provision.deinit_hook, alloc); + alloc.destroy(provision_pool); + + // clean up TLS. + if (comptime security == .tls) { + const tls_slice = rt.storage.get("tls_slice", []TLSType); + alloc.free(tls_slice); + } + } + }.rt_end, + void, + ); + } + }; }