From 799b0ed1f30b7f65e16f6093a4f529cbb49a5a22 Mon Sep 17 00:00:00 2001 From: patrick-scho Date: Mon, 30 Jun 2025 14:28:04 +0200 Subject: [PATCH] Initial commit --- ws_epoll.zig | 194 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 ws_epoll.zig diff --git a/ws_epoll.zig b/ws_epoll.zig new file mode 100644 index 0000000..6973093 --- /dev/null +++ b/ws_epoll.zig @@ -0,0 +1,194 @@ +const std = @import("std"); +const net = std.net; +const posix = std.posix; +const linux = std.os.linux; +const echo = std.debug.print; + +const html = + \\ + \\ +; + +var g_buffer: [1024]u8 = undefined; + +fn respondHttp(stream: std.net.Stream, msg: []const u8) !void { + _ = msg; + + try std.fmt.format( + stream.writer(), + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{s}", + .{ html.len, html }, + ); +} + +fn respondWs(stream: std.net.Stream, msg: []const u8) !void { + // TODO: handle msg length (< 128) + var response_buffer = try std.BoundedArray(u8, 128).init(0); + var writer = response_buffer.writer(); + _ = try writer.writeByte(0b10000001); + _ = try writer.writeByte(@truncate(msg.len)); + _ = try writer.write(msg); + _ = try stream.write(response_buffer.constSlice()); +} + +fn handleHttp(stream: std.net.Stream, msg: []const u8) !bool { + echo("received: {s}\n", .{msg}); + try respondHttp(stream, msg); + return true; +} + +fn handleWsUpgrade(stream: std.net.Stream, msg: []const u8) !bool { + if (std.mem.indexOf(u8, msg, "Sec-WebSocket-Key")) |idx| { + const end = std.mem.indexOfScalarPos(u8, msg, idx, '\r') orelse return error.InvalidUpgradeRequest; + + const key = msg[idx + 19 .. end]; + echo("key: {s}\n", .{key}); + + var concat = try std.BoundedArray(u8, 128).init(0); + _ = try concat.writer().write(key); + _ = try concat.writer().write("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + + var hashed: [20]u8 = undefined; + std.crypto.hash.Sha1.hash(concat.constSlice(), &hashed, .{}); + + var encoded: [28]u8 = undefined; + _ = std.base64.standard.Encoder.encode(&encoded, &hashed); + + try std.fmt.format( + stream.writer(), + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {s}\r\n\r\n", + .{encoded}, + ); + + return true; + } + return false; +} + +fn handleWs(stream: std.net.Stream, frame: []const u8) !bool { + // TODO: handle ping/pong + // TODO: handle closing + + // zig fmt: off + const FIN = (frame[0] & 0b10000000) >> 7; + const RSV1 = (frame[0] & 0b01000000) >> 6; + const RSV2 = (frame[0] & 0b00100000) >> 5; + const RSV3 = (frame[0] & 0b00010000) >> 4; + const OPCODE = (frame[0] & 0b00001111) >> 0; + const MASK = (frame[1] & 0b10000000) >> 7; + + var msgIndex: usize = 2; + + var payloadLength: usize = frame[1] & 0b01111111 >> 0; + if (payloadLength == 126) { + payloadLength = std.mem.readVarInt(u16, frame[2..4],.big); + msgIndex = 4; + } + else if (payloadLength == 127) { + payloadLength = std.mem.readVarInt(u16, frame[2..6],.big); + msgIndex = 6; + // const msb = payloadLength >> 63; + } + // zig fmt: on + + const mask = frame[msgIndex .. msgIndex + 4]; + + const encoded = frame[msgIndex + 4 .. frame.len]; + var decoded = try std.BoundedArray(u8, 1024).init(0); + + for (0..encoded.len) |i| { + try decoded.append(encoded[i] ^ mask[i % 4]); + } + + echo("tcp: [{}] {s}\n", .{ .{ FIN, RSV1, RSV2, RSV3, OPCODE, MASK, payloadLength, mask }, decoded.constSlice() }); + + try respondWs(stream, "yooooooo"); + // _ = try stream.write(frame); + // std.debug.print("frame: {b}\n", .{frame}); + return true; +} + +fn handle(client: *Client) !bool { + var stream = std.net.Stream{ .handle = client.fd }; + + const bytesReceived = try stream.read(&g_buffer); + if (bytesReceived == 0) { + return false; + } + const msg = g_buffer[0..bytesReceived]; + + if (client.ws) { + return try handleWs(stream, msg); + } else if (std.mem.indexOf(u8, msg, "Sec-WebSocket-Key")) |_| { + const success = try handleWsUpgrade(stream, msg); + if (success) { + client.ws = true; + } + return success; + } else { + return try handleHttp(stream, msg); + } +} + +const Client = struct { + fd: i32, + ws: bool = false, +}; + +pub fn main() !void { + const address = try std.net.Address.parseIp("127.0.0.1", 8080); + + const tpe: u32 = posix.SOCK.STREAM | posix.SOCK.NONBLOCK; + const protocol = posix.IPPROTO.TCP; + const listener = try posix.socket(address.any.family, tpe, protocol); + defer posix.close(listener); + + try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); + try posix.bind(listener, &address.any, address.getOsSockLen()); + try posix.listen(listener, 128); + + // epoll_create1 takes flags. We aren't using any in these examples + const efd = try posix.epoll_create1(0); + defer posix.close(efd); + + const listener_client = Client{ .fd = listener }; + { + // monitor our listening socket + var event = linux.epoll_event{ .events = linux.EPOLL.IN, .data = .{ .ptr = @intFromPtr(&listener_client) } }; + try posix.epoll_ctl(efd, linux.EPOLL.CTL_ADD, listener, &event); + } + + var ready_list: [128]linux.epoll_event = undefined; + var client_list = try std.BoundedArray(Client, 128).init(0); + + while (true) { + const ready_count = posix.epoll_wait(efd, &ready_list, -1); + for (ready_list[0..ready_count]) |ready| { + const client: *Client = @ptrFromInt(ready.data.ptr); + const ready_socket = client.fd; + if (ready_socket == listener) { + const client_socket = try posix.accept(listener, null, null, posix.SOCK.NONBLOCK); + errdefer posix.close(client_socket); + var new_client = try client_list.addOne(); + new_client.fd = client_socket; + var event = linux.epoll_event{ .events = linux.EPOLL.IN, .data = .{ .ptr = @intFromPtr(new_client) } }; + try posix.epoll_ctl(efd, linux.EPOLL.CTL_ADD, client_socket, &event); + } else { + const closed = !(handle(client) catch false); + + if (closed or ready.events & linux.EPOLL.RDHUP == linux.EPOLL.RDHUP) { + posix.close(ready_socket); + } + } + } + } +} -- 2.50.1