]> gitweb.ps.run Git - zigws/blob - ws_epoll.zig
697309319e27b090b65c20f5eb5133a880d2b5c4
[zigws] / ws_epoll.zig
1 const std = @import("std");
2 const net = std.net;
3 const posix = std.posix;
4 const linux = std.os.linux;
5 const echo = std.debug.print;
6
7 const html =
8     \\<script>
9     \\let s = new WebSocket("ws://localhost:8080/");
10     \\s.addEventListener("open", (event) => {
11     \\  console.log("connected");
12     \\  s.send("Hello Server!");
13     \\});
14     \\s.addEventListener("message", (event) => {
15     \\  console.log("Message from server ", event.data);
16     \\});
17     \\</script>
18     \\<button onclick="s.send('hallo')">ws</button>
19 ;
20
21 var g_buffer: [1024]u8 = undefined;
22
23 fn respondHttp(stream: std.net.Stream, msg: []const u8) !void {
24     _ = msg;
25
26     try std.fmt.format(
27         stream.writer(),
28         "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{s}",
29         .{ html.len, html },
30     );
31 }
32
33 fn respondWs(stream: std.net.Stream, msg: []const u8) !void {
34     // TODO: handle msg length (< 128)
35     var response_buffer = try std.BoundedArray(u8, 128).init(0);
36     var writer = response_buffer.writer();
37     _ = try writer.writeByte(0b10000001);
38     _ = try writer.writeByte(@truncate(msg.len));
39     _ = try writer.write(msg);
40     _ = try stream.write(response_buffer.constSlice());
41 }
42
43 fn handleHttp(stream: std.net.Stream, msg: []const u8) !bool {
44     echo("received: {s}\n", .{msg});
45     try respondHttp(stream, msg);
46     return true;
47 }
48
49 fn handleWsUpgrade(stream: std.net.Stream, msg: []const u8) !bool {
50     if (std.mem.indexOf(u8, msg, "Sec-WebSocket-Key")) |idx| {
51         const end = std.mem.indexOfScalarPos(u8, msg, idx, '\r') orelse return error.InvalidUpgradeRequest;
52
53         const key = msg[idx + 19 .. end];
54         echo("key: {s}\n", .{key});
55
56         var concat = try std.BoundedArray(u8, 128).init(0);
57         _ = try concat.writer().write(key);
58         _ = try concat.writer().write("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
59
60         var hashed: [20]u8 = undefined;
61         std.crypto.hash.Sha1.hash(concat.constSlice(), &hashed, .{});
62
63         var encoded: [28]u8 = undefined;
64         _ = std.base64.standard.Encoder.encode(&encoded, &hashed);
65
66         try std.fmt.format(
67             stream.writer(),
68             "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {s}\r\n\r\n",
69             .{encoded},
70         );
71
72         return true;
73     }
74     return false;
75 }
76
77 fn handleWs(stream: std.net.Stream, frame: []const u8) !bool {
78     // TODO: handle ping/pong
79     // TODO: handle closing
80
81     // zig fmt: off
82     const FIN    = (frame[0] & 0b10000000) >> 7;
83     const RSV1   = (frame[0] & 0b01000000) >> 6;
84     const RSV2   = (frame[0] & 0b00100000) >> 5;
85     const RSV3   = (frame[0] & 0b00010000) >> 4;
86     const OPCODE = (frame[0] & 0b00001111) >> 0;
87     const MASK   = (frame[1] & 0b10000000) >> 7;
88
89     var msgIndex: usize = 2;
90
91     var payloadLength: usize = frame[1] & 0b01111111 >> 0;
92     if (payloadLength == 126) {
93         payloadLength = std.mem.readVarInt(u16, frame[2..4],.big);
94         msgIndex = 4;
95     }
96     else if (payloadLength == 127) {
97         payloadLength = std.mem.readVarInt(u16, frame[2..6],.big);
98         msgIndex = 6;
99         // const msb = payloadLength >> 63;
100     }
101     // zig fmt: on
102
103     const mask = frame[msgIndex .. msgIndex + 4];
104
105     const encoded = frame[msgIndex + 4 .. frame.len];
106     var decoded = try std.BoundedArray(u8, 1024).init(0);
107
108     for (0..encoded.len) |i| {
109         try decoded.append(encoded[i] ^ mask[i % 4]);
110     }
111
112     echo("tcp: [{}] {s}\n", .{ .{ FIN, RSV1, RSV2, RSV3, OPCODE, MASK, payloadLength, mask }, decoded.constSlice() });
113
114     try respondWs(stream, "yooooooo");
115     // _ = try stream.write(frame);
116     // std.debug.print("frame: {b}\n", .{frame});
117     return true;
118 }
119
120 fn handle(client: *Client) !bool {
121     var stream = std.net.Stream{ .handle = client.fd };
122
123     const bytesReceived = try stream.read(&g_buffer);
124     if (bytesReceived == 0) {
125         return false;
126     }
127     const msg = g_buffer[0..bytesReceived];
128
129     if (client.ws) {
130         return try handleWs(stream, msg);
131     } else if (std.mem.indexOf(u8, msg, "Sec-WebSocket-Key")) |_| {
132         const success = try handleWsUpgrade(stream, msg);
133         if (success) {
134             client.ws = true;
135         }
136         return success;
137     } else {
138         return try handleHttp(stream, msg);
139     }
140 }
141
142 const Client = struct {
143     fd: i32,
144     ws: bool = false,
145 };
146
147 pub fn main() !void {
148     const address = try std.net.Address.parseIp("127.0.0.1", 8080);
149
150     const tpe: u32 = posix.SOCK.STREAM | posix.SOCK.NONBLOCK;
151     const protocol = posix.IPPROTO.TCP;
152     const listener = try posix.socket(address.any.family, tpe, protocol);
153     defer posix.close(listener);
154
155     try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
156     try posix.bind(listener, &address.any, address.getOsSockLen());
157     try posix.listen(listener, 128);
158
159     // epoll_create1 takes flags. We aren't using any in these examples
160     const efd = try posix.epoll_create1(0);
161     defer posix.close(efd);
162
163     const listener_client = Client{ .fd = listener };
164     {
165         // monitor our listening socket
166         var event = linux.epoll_event{ .events = linux.EPOLL.IN, .data = .{ .ptr = @intFromPtr(&listener_client) } };
167         try posix.epoll_ctl(efd, linux.EPOLL.CTL_ADD, listener, &event);
168     }
169
170     var ready_list: [128]linux.epoll_event = undefined;
171     var client_list = try std.BoundedArray(Client, 128).init(0);
172
173     while (true) {
174         const ready_count = posix.epoll_wait(efd, &ready_list, -1);
175         for (ready_list[0..ready_count]) |ready| {
176             const client: *Client = @ptrFromInt(ready.data.ptr);
177             const ready_socket = client.fd;
178             if (ready_socket == listener) {
179                 const client_socket = try posix.accept(listener, null, null, posix.SOCK.NONBLOCK);
180                 errdefer posix.close(client_socket);
181                 var new_client = try client_list.addOne();
182                 new_client.fd = client_socket;
183                 var event = linux.epoll_event{ .events = linux.EPOLL.IN, .data = .{ .ptr = @intFromPtr(new_client) } };
184                 try posix.epoll_ctl(efd, linux.EPOLL.CTL_ADD, client_socket, &event);
185             } else {
186                 const closed = !(handle(client) catch false);
187
188                 if (closed or ready.events & linux.EPOLL.RDHUP == linux.EPOLL.RDHUP) {
189                     posix.close(ready_socket);
190                 }
191             }
192         }
193     }
194 }