ws

ws.git
git clone git://git.lenczewski.org/ws.git
Log | Files | Refs | LICENSE

ws-server.c (5056B)


      1 #define _DEFAULT_SOURCE 1
      2 #define _GNU_SOURCE 1
      3 #define _POSIX_C_SOURCE 1
      4 
      5 #include <signal.h>
      6 #include <stdio.h>
      7 #include <stdlib.h>
      8 #include <string.h>
      9 #include <unistd.h>
     10 
     11 #include <arpa/inet.h>
     12 #include <sys/socket.h>
     13 #include <netdb.h>
     14 
     15 #include <openssl/evp.h>
     16 #include <openssl/sha.h>
     17 
     18 #include "sockaddr.h"
     19 #include "http.h"
     20 #include "ws.h"
     21 
     22 static int
     23 do_http_handshake(int client, struct arena *arena, char *protocols, char *extensions)
     24 {
     25 	char buf[4096], *term = NULL;
     26 	size_t len = 0;
     27 
     28 	/* recv handshake request */
     29 	if (http_receive_msg(client, buf, sizeof buf, &term) < 0)
     30 		return -1;
     31 
     32 	struct http_msg request;
     33 	if (http_parse_msg(arena, buf, &request) < 0)
     34 		return -1;
     35 
     36 	http_print_msg(&request);
     37 
     38 	/* handle handshake request */
     39 	struct http_header *origin, *host, *connection, *upgrade;
     40 	struct http_header *ws_ver, *ws_key, *ws_proto, *ws_ext;
     41 
     42 	origin = http_header_find(request.headers, "origin");
     43 	host = http_header_find(request.headers, "host");
     44 	connection = http_header_find(request.headers, "connection");
     45 	upgrade = http_header_find(request.headers, "upgrade");
     46 
     47 	(void) origin;
     48 
     49 	if (!host || !connection || !upgrade) return -1;
     50 	if (!http_header_has_value(connection, "Upgrade")) return -1;
     51 	if (!http_header_has_value(upgrade, "websocket")) return -1;
     52 
     53 	ws_ver = http_header_find(request.headers, "sec-websocket-version");
     54 	ws_key = http_header_find(request.headers, "sec-websocket-key");
     55 	ws_proto = http_header_find(request.headers, "sec-websocket-protocol");
     56 	ws_ext = http_header_find(request.headers, "sec-websocket-extensions");
     57 
     58 	if (!ws_ver || !ws_key) return -1;
     59 	if (!http_header_has_value(ws_ver, "13")) return -1;
     60 
     61 	unsigned char *keyenc = (unsigned char *) ws_key->val, acceptenc[WS_KEY_LENGTH];
     62 	int acceptenc_len = ws_key_digest(keyenc, strlen(ws_key->val), acceptenc);
     63 
     64 	/* build handshake response */
     65 	size_t cap = snprintf(buf, sizeof buf,
     66 			"HTTP/1.1 101 Switching Protocols\r\n"
     67 			"Connection: Upgrade\r\n"
     68 			"Upgrade: websocket\r\n"
     69 			"Sec-WebSocket-Accept: %.*s\r\n", acceptenc_len, acceptenc);
     70 
     71 	if (protocols) {
     72 		(void) ws_proto;
     73 
     74 		// TODO: optionally enable one or more of the requested protocols
     75 		cap += snprintf(buf + cap, sizeof buf - cap,
     76 				"Sec-WebSocket-Protocol: %s\r\n", protocols);
     77 	}
     78 
     79 	if (extensions) {
     80 		(void) ws_ext;
     81 
     82 		// TODO: optionally enable one or more of the requested extensions
     83 		cap += snprintf(buf + cap, sizeof buf - cap,
     84 				"Sec-WebSocket-Extensions: %s\r\n", extensions);
     85 	}
     86 
     87 	cap += snprintf(buf + cap, sizeof buf - cap, "\r\n");
     88 
     89 	/* send handshake response */
     90 	len = 0;
     91 	do {
     92 		ssize_t res = send(client, buf + len, cap - len, 0);
     93 		if (res < 0) return -1;
     94 		len += res;
     95 	} while (len < cap);
     96 
     97 	return 0;
     98 }
     99 
    100 static int
    101 handle(int client, struct sockaddr *addr, socklen_t addrlen)
    102 {
    103 	char buf[8192];
    104 	struct arena arena = { .ptr = buf, .cap = sizeof buf, .len = 0, };
    105 
    106 	char host[NI_MAXHOST], port[NI_MAXSERV];
    107 	getnameinfo(addr, addrlen, host, sizeof host, port, sizeof port,
    108 		    NI_NUMERICHOST | NI_NUMERICSERV);
    109 
    110 	printf("client from %s:%s connected\n", host, port);
    111 
    112 	if (do_http_handshake(client, &arena, NULL, NULL) < 0) {
    113 		printf("[%s:%s] failed http handshake\n", host, port);
    114 		return -1;
    115 	}
    116 
    117 	printf("client from %s:%s completed http handshake\n", host, port);
    118 
    119 	while (1) {
    120 		struct ws_frame frame;
    121 		if (ws_frame_recv(client, &frame) < 0)
    122 			return -1;
    123 
    124 		printf("[%s:%s] received websocket frame: fin: %d, mask: %u, len: %lu\n",
    125 				host, port, frame.fin, frame.mask, frame.len);
    126 
    127 		if (frame.opcode == WS_CLOSE) {
    128 			printf("[%s:%s] closed connection\n", host, port);
    129 			break;
    130 		}
    131 
    132 		if (frame.len > sizeof buf) {
    133 			printf("[%s:%s] payload too large!\n", host, port);
    134 			ws_close(client, WS_ERROR_MESSAGE_TOO_BIG, NULL, 0, 0);
    135 			break;
    136 		}
    137 
    138 		if (ws_data_recv(client, &frame, (unsigned char *) buf, sizeof buf) < 0)
    139 			return -1;
    140 
    141 		printf("[%s:%s] rceived websocket payload:\n%.*s\n",
    142 				host, port, (int) frame.len, buf);
    143 
    144 		frame.mask = 0;
    145 		if (ws_msg_send(client, frame.opcode, (unsigned char *) buf, frame.len, 0, 0) < 0)
    146 			return -1;
    147 
    148 		printf("[%s:%s] echoed websocket message\n", host, port);
    149 	}
    150 
    151 	shutdown(client, SHUT_RDWR);
    152 	close(client);
    153 
    154 	printf("client from %s:%s disconnected\n", host, port);
    155 
    156 	return 0;
    157 }
    158 
    159 int
    160 main(void)
    161 {
    162 	int serv = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    163 	if (serv < 0) {
    164 		perror("socket");
    165 		exit(EXIT_FAILURE);
    166 	}
    167 
    168 	int yes = 1;
    169 	setsockopt(serv, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes);
    170 
    171 	struct sockaddr_storage addr;
    172 	socklen_t addrlen;
    173 	get_server_addr(&addr, &addrlen);
    174 
    175 	if (bind(serv, (struct sockaddr *) &addr, addrlen) < 0) {
    176 		perror("bind");
    177 		exit(EXIT_FAILURE);
    178 	}
    179 
    180 	listen(serv, 1);
    181 
    182 	printf("server at %s:%d listening\n", SERVER_ADDR, SERVER_PORT);
    183 
    184 	while (1) {
    185 		struct sockaddr_storage client_addr;
    186 		socklen_t client_addrlen = sizeof client_addr;
    187 		int client = accept(serv, (struct sockaddr *) &client_addr, &client_addrlen);
    188 		handle(client, (struct sockaddr *) &client_addr, client_addrlen);
    189 	}
    190 
    191 	printf("server shutdown\n");
    192 
    193 	close(serv);
    194 
    195 	exit(EXIT_SUCCESS);
    196 }