ws-server.c (5056B)
1 #define _DEFAULT_SOURCE 1 2 #define _GNU_SOURCE 1 3 #define _POSIX_C_SOURCE 1 4 5 #include <signal.h> 6 #include <stdio.h> 7 #include <stdlib.h> 8 #include <string.h> 9 #include <unistd.h> 10 11 #include <arpa/inet.h> 12 #include <sys/socket.h> 13 #include <netdb.h> 14 15 #include <openssl/evp.h> 16 #include <openssl/sha.h> 17 18 #include "sockaddr.h" 19 #include "http.h" 20 #include "ws.h" 21 22 static int 23 do_http_handshake(int client, struct arena *arena, char *protocols, char *extensions) 24 { 25 char buf[4096], *term = NULL; 26 size_t len = 0; 27 28 /* recv handshake request */ 29 if (http_receive_msg(client, buf, sizeof buf, &term) < 0) 30 return -1; 31 32 struct http_msg request; 33 if (http_parse_msg(arena, buf, &request) < 0) 34 return -1; 35 36 http_print_msg(&request); 37 38 /* handle handshake request */ 39 struct http_header *origin, *host, *connection, *upgrade; 40 struct http_header *ws_ver, *ws_key, *ws_proto, *ws_ext; 41 42 origin = http_header_find(request.headers, "origin"); 43 host = http_header_find(request.headers, "host"); 44 connection = http_header_find(request.headers, "connection"); 45 upgrade = http_header_find(request.headers, "upgrade"); 46 47 (void) origin; 48 49 if (!host || !connection || !upgrade) return -1; 50 if (!http_header_has_value(connection, "Upgrade")) return -1; 51 if (!http_header_has_value(upgrade, "websocket")) return -1; 52 53 ws_ver = http_header_find(request.headers, "sec-websocket-version"); 54 ws_key = http_header_find(request.headers, "sec-websocket-key"); 55 ws_proto = http_header_find(request.headers, "sec-websocket-protocol"); 56 ws_ext = http_header_find(request.headers, "sec-websocket-extensions"); 57 58 if (!ws_ver || !ws_key) return -1; 59 if (!http_header_has_value(ws_ver, "13")) return -1; 60 61 unsigned char *keyenc = (unsigned char *) ws_key->val, acceptenc[WS_KEY_LENGTH]; 62 int acceptenc_len = ws_key_digest(keyenc, strlen(ws_key->val), acceptenc); 63 64 /* build handshake response */ 65 size_t cap = snprintf(buf, sizeof buf, 66 "HTTP/1.1 101 Switching Protocols\r\n" 67 "Connection: Upgrade\r\n" 68 "Upgrade: websocket\r\n" 69 "Sec-WebSocket-Accept: %.*s\r\n", acceptenc_len, acceptenc); 70 71 if (protocols) { 72 (void) ws_proto; 73 74 // TODO: optionally enable one or more of the requested protocols 75 cap += snprintf(buf + cap, sizeof buf - cap, 76 "Sec-WebSocket-Protocol: %s\r\n", protocols); 77 } 78 79 if (extensions) { 80 (void) ws_ext; 81 82 // TODO: optionally enable one or more of the requested extensions 83 cap += snprintf(buf + cap, sizeof buf - cap, 84 "Sec-WebSocket-Extensions: %s\r\n", extensions); 85 } 86 87 cap += snprintf(buf + cap, sizeof buf - cap, "\r\n"); 88 89 /* send handshake response */ 90 len = 0; 91 do { 92 ssize_t res = send(client, buf + len, cap - len, 0); 93 if (res < 0) return -1; 94 len += res; 95 } while (len < cap); 96 97 return 0; 98 } 99 100 static int 101 handle(int client, struct sockaddr *addr, socklen_t addrlen) 102 { 103 char buf[8192]; 104 struct arena arena = { .ptr = buf, .cap = sizeof buf, .len = 0, }; 105 106 char host[NI_MAXHOST], port[NI_MAXSERV]; 107 getnameinfo(addr, addrlen, host, sizeof host, port, sizeof port, 108 NI_NUMERICHOST | NI_NUMERICSERV); 109 110 printf("client from %s:%s connected\n", host, port); 111 112 if (do_http_handshake(client, &arena, NULL, NULL) < 0) { 113 printf("[%s:%s] failed http handshake\n", host, port); 114 return -1; 115 } 116 117 printf("client from %s:%s completed http handshake\n", host, port); 118 119 while (1) { 120 struct ws_frame frame; 121 if (ws_frame_recv(client, &frame) < 0) 122 return -1; 123 124 printf("[%s:%s] received websocket frame: fin: %d, mask: %u, len: %lu\n", 125 host, port, frame.fin, frame.mask, frame.len); 126 127 if (frame.opcode == WS_CLOSE) { 128 printf("[%s:%s] closed connection\n", host, port); 129 break; 130 } 131 132 if (frame.len > sizeof buf) { 133 printf("[%s:%s] payload too large!\n", host, port); 134 ws_close(client, WS_ERROR_MESSAGE_TOO_BIG, NULL, 0, 0); 135 break; 136 } 137 138 if (ws_data_recv(client, &frame, (unsigned char *) buf, sizeof buf) < 0) 139 return -1; 140 141 printf("[%s:%s] rceived websocket payload:\n%.*s\n", 142 host, port, (int) frame.len, buf); 143 144 frame.mask = 0; 145 if (ws_msg_send(client, frame.opcode, (unsigned char *) buf, frame.len, 0, 0) < 0) 146 return -1; 147 148 printf("[%s:%s] echoed websocket message\n", host, port); 149 } 150 151 shutdown(client, SHUT_RDWR); 152 close(client); 153 154 printf("client from %s:%s disconnected\n", host, port); 155 156 return 0; 157 } 158 159 int 160 main(void) 161 { 162 int serv = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); 163 if (serv < 0) { 164 perror("socket"); 165 exit(EXIT_FAILURE); 166 } 167 168 int yes = 1; 169 setsockopt(serv, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes); 170 171 struct sockaddr_storage addr; 172 socklen_t addrlen; 173 get_server_addr(&addr, &addrlen); 174 175 if (bind(serv, (struct sockaddr *) &addr, addrlen) < 0) { 176 perror("bind"); 177 exit(EXIT_FAILURE); 178 } 179 180 listen(serv, 1); 181 182 printf("server at %s:%d listening\n", SERVER_ADDR, SERVER_PORT); 183 184 while (1) { 185 struct sockaddr_storage client_addr; 186 socklen_t client_addrlen = sizeof client_addr; 187 int client = accept(serv, (struct sockaddr *) &client_addr, &client_addrlen); 188 handle(client, (struct sockaddr *) &client_addr, client_addrlen); 189 } 190 191 printf("server shutdown\n"); 192 193 close(serv); 194 195 exit(EXIT_SUCCESS); 196 }