ws

ws.git
git clone git://git.lenczewski.org/ws.git
Log | Files | Refs | LICENSE

commit 0a02c0ea0f5059371d408578b47c15f5e12d3f2a
Author: Mikołaj Lenczewski <mblenczewski@gmail.com>
Date:   Sat,  7 Jun 2025 16:04:11 +0000

Initial commit

Diffstat:
A.editorconfig | 14++++++++++++++
A.gitignore | 3+++
ALICENSE | 18++++++++++++++++++
Aarena.h | 48++++++++++++++++++++++++++++++++++++++++++++++++
Abuild.sh | 15+++++++++++++++
Aclean.sh | 5+++++
Ahttp.h | 138+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Asockaddr.h | 24++++++++++++++++++++++++
Aws-client.c | 163+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aws-client.html | 48++++++++++++++++++++++++++++++++++++++++++++++++
Aws-server.c | 196+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aws.h | 319+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
12 files changed, 991 insertions(+), 0 deletions(-)

diff --git a/.editorconfig b/.editorconfig @@ -0,0 +1,14 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +charset = utf-8 + +[*.{c,h}] +indent_style = tab +indent_size = 8 + +[*.{html,css,js}] +indent_style = space +indent_size = 2 diff --git a/.gitignore b/.gitignore @@ -0,0 +1,3 @@ +bin + +**/.*.swp diff --git a/LICENSE b/LICENSE @@ -0,0 +1,18 @@ +The MIT-Zero License + +Copyright (c) 2025 Mikołaj Lenczewski <mikolaj@lenczewski.org> + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/arena.h b/arena.h @@ -0,0 +1,48 @@ +#ifndef ARENA_H +#define ARENA_H + +#include <assert.h> +#include <stdalign.h> +#include <stddef.h> +#include <stdint.h> + +#define ISPOW2(v) (((v) & ((v) - 1)) == 0) +#define IS_ALIGNED(v, align) ((v) & ((align) - 1)) +#define ALIGN_PREV(v, align) ((v) & ~((align) - 1)) +#define ALIGN_NEXT(v, align) ALIGN_PREV(((v) + ((align) - 1)), (align)) + +struct arena { + void *ptr; + size_t cap, len; +}; + +static inline void +arena_reset(struct arena *arena) +{ + arena->len = 0; +} + +static inline void * +arena_alloc(struct arena *arena, size_t size, size_t align) +{ + assert(size); + assert(align); + assert(ISPOW2(align)); + + uintptr_t ptr = (uintptr_t) arena->ptr; + uintptr_t aligned_ptr = ALIGN_NEXT(ptr + arena->len, align); + if (ptr + arena->cap < aligned_ptr + size) + return NULL; + + arena->len = (aligned_ptr - ptr) + size; + + return (void *) aligned_ptr; +} + +#define ARENA_ALLOC_ARRAY(arena, T, n) \ + arena_alloc((arena), sizeof(T) * (n), alignof(T)) + +#define ARENA_ALLOC_SIZED(arena, T) \ + ARENA_ALLOC_ARRAY((arena), T, 1) + +#endif /* ARENA_H */ diff --git a/build.sh b/build.sh @@ -0,0 +1,15 @@ +#!/bin/sh + +CC="${CC:-clang}" + +FLAGS="-Wall -Wextra -Wpedantic ${WERROR:+-Werror} -std=c11 -O0 -g" +DEPS="$(pkg-config --cflags --libs openssl)" + +set -ex + +mkdir -p bin + +$CC -o bin/ws-client ws-client.c $FLAGS $DEPS +$CC -o bin/ws-server ws-server.c $FLAGS $DEPS +#$CC -o bin/sse-client sse-client.c $FLAGS $DEPS +#$CC -o bin/sse-server sse-server.c $FLAGS $DEPS diff --git a/clean.sh b/clean.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +set -ex + +rm -rf bin diff --git a/http.h b/http.h @@ -0,0 +1,138 @@ +#ifndef HTTP_H +#define HTTP_H + +#ifndef _POSIX_C_SOURCE +# define _POSIX_C_SOURCE 1 +#endif + +#include <assert.h> +#include <ctype.h> +#include <string.h> +#include <strings.h> + +#include "arena.h" + +#define HTTP_TERMINATOR "\r\n\r\n" + +static inline int +http_receive_msg(int sock, char *buf, size_t cap, char **terminator) +{ + char *term = NULL; + size_t len = 0, cur = 0; + + do { + ssize_t res = recv(sock, buf + len, cap - len, 0); + if (res < 0) return -1; + len += res; + + size_t lower_bound = (cur < 4) ? 0 : (cur - 4); + term = memmem(buf + lower_bound, len - lower_bound, + HTTP_TERMINATOR, 4); + cur += res; + } while (len < cap && !term); + + *term = 0; + + *terminator = term; + + return 0; +} + +struct http_header { + char *key, *val; + struct http_header *next; +}; + +static inline struct http_header * +http_header_find(struct http_header *headers, char const *key) +{ + while (headers) { + if (strcasecmp(headers->key, key) == 0) + return headers; + + headers = headers->next; + } + + return NULL; +} + +static inline int +http_header_has_value(struct http_header const *header, char const *value) +{ + char *cur = header->val, *tok; + do { + // header values are comma-delimited + tok = strchrnul(cur, ','); + + while (isspace(*cur)) cur++; // trim leading whitespace + + char *end = tok, save; + if (isspace(*(end - 1))) end--; // if character preceeding ',' is whitespace + while (isspace(*end)) end--; // trim trailing whitespace + + save = *end; + *end = '\0'; + + if (strstr(cur, value)) + return 1; + + *end = save; + cur = ++tok; + } while (*tok); + + return 0; +} + +struct http_msg { + char *leader; + struct http_header *headers; +}; + +static inline int +http_parse_msg(struct arena *arena, char *buf, struct http_msg *out) +{ + char *saveptr = NULL; + out->leader = strtok_r(buf, "\r\n", &saveptr); + + out->headers = NULL; + + char *hdr = NULL; + while ((hdr = strtok_r(NULL, "\r\n", &saveptr))) { + struct http_header *header = ARENA_ALLOC_SIZED(arena, struct http_header); + if (!header) return -1; + + char *val; + header->key = strtok_r(hdr, ":", &val); + + while (isspace(*val)) val++; // trim leading whitespace + + header->val = val; + + header->next = out->headers; + out->headers = header; + } + + return 0; +} + +static inline void +http_print_msg(struct http_msg const *msg) +{ + printf("http msg:\n"); + + printf("=====\n"); + + printf("LEADER: %s\n", msg->leader); + + printf("=====\n"); + + printf("HEADERS:\n"); + + struct http_header *headers = msg->headers; + while (headers) { + printf("\t%s = %s\n", headers->key, headers->val); + headers = headers->next; + } +} + +#endif /* HTTP_H */ diff --git a/sockaddr.h b/sockaddr.h @@ -0,0 +1,24 @@ +#ifndef SERVERADDR_H +#define SERVERADDR_H + +#define SERVER_ADDR "127.0.0.1" +#define SERVER_PORT 8080 + +static inline int +get_server_addr(struct sockaddr_storage *storage, socklen_t *storagelen) +{ + struct sockaddr_in *addr = (struct sockaddr_in *) storage; + addr->sin_family = AF_INET; + addr->sin_port = htons(SERVER_PORT); + + if (inet_pton(addr->sin_family, SERVER_ADDR, &addr->sin_addr) < 1) { + perror("inet_pton"); + return -1; + } + + *storagelen = sizeof *addr; + + return 0; +} + +#endif /* SERVERADDR_H */ diff --git a/ws-client.c b/ws-client.c @@ -0,0 +1,163 @@ +#define _GNU_SOURCE 1 +#define _POSIX_C_SOURCE 1 + +#include <signal.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <unistd.h> + +#include <arpa/inet.h> +#include <sys/socket.h> + +#include <openssl/evp.h> +#include <openssl/sha.h> + +#include "sockaddr.h" +#include "http.h" +#include "ws.h" + +static int +do_http_handshake(int client, struct arena *arena, char *protocols, char *extensions) +{ + char buf[4096], *term = NULL; + size_t len = 0; + + /* build request */ + unsigned char key[16] = "0123456789abcdef", keyenc[WS_KEY_LENGTH]; + int keyenc_len = ws_key_digest(key, sizeof key, keyenc); + + size_t cap = snprintf(buf, sizeof buf, + "GET /home HTTP/1.1\r\n" + "Origin: localhost\r\n" + "Host: localhost\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Key: %.*s\r\n", keyenc_len, keyenc); + + if (protocols) + cap += snprintf(buf + cap, sizeof buf - cap, + "Sec-WebSocket-Protocol: %s\r\n", protocols); + + if (extensions) + cap += snprintf(buf + cap, sizeof buf - cap, + "Sec-WebSocket-Extensions: %s\r\n", extensions); + + cap += snprintf(buf + cap, sizeof buf - cap, "\r\n"); + + /* send request */ + do { + ssize_t res = send(client, buf + len, cap - len, 0); + if (res < 0) return -1; + len += cap; + } while (len < cap); + + /* recv response */ + if (http_receive_msg(client, buf, sizeof buf, &term) < 0) + return -1; + + /* parse response */ + struct http_msg response; + if (http_parse_msg(arena, buf, &response) < 0) + return -1; + + http_print_msg(&response); + + /* handle response */ + struct http_header *connection, *upgrade; + struct http_header *ws_accept, *ws_proto, *ws_ext; + + connection = http_header_find(response.headers, "connection"); + upgrade = http_header_find(response.headers, "upgrade"); + + if (!connection || !upgrade) return -1; + if (!http_header_has_value(connection, "Upgrade")) return -1; + if (!http_header_has_value(upgrade, "websocket")) return -1; + + ws_accept = http_header_find(response.headers, "sec-websocket-accept"); + ws_proto = http_header_find(response.headers, "sec-websocket-protocol"); + ws_ext = http_header_find(response.headers, "sec-websocket-extensions"); + + if (!ws_accept) return -1; + + unsigned char acceptenc[WS_KEY_LENGTH]; + ws_key_digest(keyenc, keyenc_len, acceptenc); + + if (!http_header_has_value(ws_accept, (char *) acceptenc)) return -1; + if (ws_proto && !http_header_has_value(ws_proto, protocols)) return -1; + if (ws_ext && !http_header_has_value(ws_ext, extensions)) return -1; + + return 0; +} + +static int +handle(int client) +{ + char buf[8192]; + struct arena arena = { .ptr = buf, .cap = sizeof buf, .len = 0, }; + + if (do_http_handshake(client, &arena, NULL, NULL) < 0) + return -1; + + printf("client completed http handshake\n"); + + char msg[] = "Hello, World!\n"; + uint32_t mask = rand(); + int fragment = 0; + + if (ws_msg_send(client, WS_DATA_UTF8, (unsigned char *) msg, sizeof msg, mask, fragment) < 0) + return -1; + + printf("client sent websocket message\n"); + + struct ws_frame frame; + if (ws_frame_recv(client, &frame) < 0) + return -1; + + printf("client received websocket frame: fin: %d, mask: %u, len: %lu\n", + frame.fin, frame.mask, frame.len); + + if (ws_data_recv(client, &frame, (unsigned char *) buf, sizeof buf) < 0) + return -1; + + printf("client reeived websocket payload:\n%.*s\n", + (int) frame.len, buf); + + if (ws_close(client, WS_ERROR_OK, NULL, 0, rand()) < 0) + return -1; + + printf("client closed websocket\n"); + + return 0; +} + +int +main(void) +{ + int client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (client < 0) { + perror("socket"); + exit(EXIT_FAILURE); + } + + struct sockaddr_storage addr; + socklen_t addrlen; + get_server_addr(&addr, &addrlen); + + if (connect(client, (struct sockaddr *) &addr, sizeof addr) < 0) { + perror("connect"); + exit(EXIT_FAILURE); + } + + printf("client connected to %s:%d\n", SERVER_ADDR, SERVER_PORT); + + handle(client); + + shutdown(client, SHUT_RDWR); + close(client); + + printf("client disconnected\n"); + + exit(EXIT_SUCCESS); +} diff --git a/ws-client.html b/ws-client.html @@ -0,0 +1,48 @@ +<!DOCTYPE html> +<html lang="en"> + <body> + <h1>Controls</h1> + <label for="msg">Message:</label> + <input id="msg" type="text" /> + <input id="sendbtn" type="button" value="Send" /> + <input id="closebtn" type="button" value="Close" /> + + <h1>History</h1> + <div id="log"> + </div> + </body> + <script> + let input = document.getElementById("msg"); + let sendbtn = document.getElementById("sendbtn"); + let closebtn = document.getElementById("closebtn"); + let log = document.getElementById("log"); + + let ws = new WebSocket("ws://localhost:8080"); + + sendbtn.onclick = (e) => { + console.log("sendbtn.click"); + console.log(e); + ws.send(input.value); + log.innerHTML += `<p><span>Sent: ${input.value}</span></p>\n`; + }; + + closebtn.onclick = (e) => { + console.log("closebtn.click"); + console.log(e); + ws.close(); + log.innerHTML += `<p><span>Closed websocket</span></p>\n`; + }; + + ws.onopen = (e) => { + console.log("ws.onopen:"); + console.log(e); + log.innerHTML += `<p><span>Websocket connected</span></p>\n`; + }; + + ws.onmessage = (e) => { + console.log("ws.onmessage:"); + console.log(e); + log.innerHTML += `<p><span>Received: ${e.data}</span></p>\n`; + }; + </script> +</html> diff --git a/ws-server.c b/ws-server.c @@ -0,0 +1,196 @@ +#define _DEFAULT_SOURCE 1 +#define _GNU_SOURCE 1 +#define _POSIX_C_SOURCE 1 + +#include <signal.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <unistd.h> + +#include <arpa/inet.h> +#include <sys/socket.h> +#include <netdb.h> + +#include <openssl/evp.h> +#include <openssl/sha.h> + +#include "sockaddr.h" +#include "http.h" +#include "ws.h" + +static int +do_http_handshake(int client, struct arena *arena, char *protocols, char *extensions) +{ + char buf[4096], *term = NULL; + size_t len = 0; + + /* recv handshake request */ + if (http_receive_msg(client, buf, sizeof buf, &term) < 0) + return -1; + + struct http_msg request; + if (http_parse_msg(arena, buf, &request) < 0) + return -1; + + http_print_msg(&request); + + /* handle handshake request */ + struct http_header *origin, *host, *connection, *upgrade; + struct http_header *ws_ver, *ws_key, *ws_proto, *ws_ext; + + origin = http_header_find(request.headers, "origin"); + host = http_header_find(request.headers, "host"); + connection = http_header_find(request.headers, "connection"); + upgrade = http_header_find(request.headers, "upgrade"); + + (void) origin; + + if (!host || !connection || !upgrade) return -1; + if (!http_header_has_value(connection, "Upgrade")) return -1; + if (!http_header_has_value(upgrade, "websocket")) return -1; + + ws_ver = http_header_find(request.headers, "sec-websocket-version"); + ws_key = http_header_find(request.headers, "sec-websocket-key"); + ws_proto = http_header_find(request.headers, "sec-websocket-protocol"); + ws_ext = http_header_find(request.headers, "sec-websocket-extensions"); + + if (!ws_ver || !ws_key) return -1; + if (!http_header_has_value(ws_ver, "13")) return -1; + + unsigned char *keyenc = (unsigned char *) ws_key->val, acceptenc[WS_KEY_LENGTH]; + int acceptenc_len = ws_key_digest(keyenc, strlen(ws_key->val), acceptenc); + + /* build handshake response */ + size_t cap = snprintf(buf, sizeof buf, + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: %.*s\r\n", acceptenc_len, acceptenc); + + if (protocols) { + (void) ws_proto; + + // TODO: optionally enable one or more of the requested protocols + cap += snprintf(buf + cap, sizeof buf - cap, + "Sec-WebSocket-Protocol: %s\r\n", protocols); + } + + if (extensions) { + (void) ws_ext; + + // TODO: optionally enable one or more of the requested extensions + cap += snprintf(buf + cap, sizeof buf - cap, + "Sec-WebSocket-Extensions: %s\r\n", extensions); + } + + cap += snprintf(buf + cap, sizeof buf - cap, "\r\n"); + + /* send handshake response */ + len = 0; + do { + ssize_t res = send(client, buf + len, cap - len, 0); + if (res < 0) return -1; + len += res; + } while (len < cap); + + return 0; +} + +static int +handle(int client, struct sockaddr *addr, socklen_t addrlen) +{ + char buf[8192]; + struct arena arena = { .ptr = buf, .cap = sizeof buf, .len = 0, }; + + char host[NI_MAXHOST], port[NI_MAXSERV]; + getnameinfo(addr, addrlen, host, sizeof host, port, sizeof port, + NI_NUMERICHOST | NI_NUMERICSERV); + + printf("client from %s:%s connected\n", host, port); + + if (do_http_handshake(client, &arena, NULL, NULL) < 0) { + printf("[%s:%s] failed http handshake\n", host, port); + return -1; + } + + printf("client from %s:%s completed http handshake\n", host, port); + + while (1) { + struct ws_frame frame; + if (ws_frame_recv(client, &frame) < 0) + return -1; + + printf("[%s:%s] received websocket frame: fin: %d, mask: %u, len: %lu\n", + host, port, frame.fin, frame.mask, frame.len); + + if (frame.opcode == WS_CLOSE) { + printf("[%s:%s] closed connection\n", host, port); + break; + } + + if (frame.len > sizeof buf) { + printf("[%s:%s] payload too large!\n", host, port); + ws_close(client, WS_ERROR_MESSAGE_TOO_BIG, NULL, 0, 0); + break; + } + + if (ws_data_recv(client, &frame, (unsigned char *) buf, sizeof buf) < 0) + return -1; + + printf("[%s:%s] rceived websocket payload:\n%.*s\n", + host, port, (int) frame.len, buf); + + frame.mask = 0; + if (ws_msg_send(client, frame.opcode, (unsigned char *) buf, frame.len, 0, 0) < 0) + return -1; + + printf("[%s:%s] echoed websocket message\n", host, port); + } + + shutdown(client, SHUT_RDWR); + close(client); + + printf("client from %s:%s disconnected\n", host, port); + + return 0; +} + +int +main(void) +{ + int serv = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (serv < 0) { + perror("socket"); + exit(EXIT_FAILURE); + } + + int yes = 1; + setsockopt(serv, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes); + + struct sockaddr_storage addr; + socklen_t addrlen; + get_server_addr(&addr, &addrlen); + + if (bind(serv, (struct sockaddr *) &addr, addrlen) < 0) { + perror("bind"); + exit(EXIT_FAILURE); + } + + listen(serv, 1); + + printf("server at %s:%d listening\n", SERVER_ADDR, SERVER_PORT); + + while (1) { + struct sockaddr_storage client_addr; + socklen_t client_addrlen = sizeof client_addr; + int client = accept(serv, (struct sockaddr *) &client_addr, &client_addrlen); + handle(client, (struct sockaddr *) &client_addr, client_addrlen); + } + + printf("server shutdown\n"); + + close(serv); + + exit(EXIT_SUCCESS); +} diff --git a/ws.h b/ws.h @@ -0,0 +1,319 @@ +#ifndef WS_H +#define WS_H + +#include <openssl/evp.h> +#include <openssl/sha.h> + +#define WS_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +#define WS_KEY_LENGTH ((SHA_DIGEST_LENGTH * 4) / 3) + +static inline int +ws_key_digest(unsigned char *buf, size_t len, unsigned char out[static WS_KEY_LENGTH]) +{ + unsigned char concat[128], sha1[SHA_DIGEST_LENGTH]; + size_t concat_len = snprintf((char *) concat, sizeof concat, "%.*s%s", + (int) len, buf, WS_GUID); + SHA1(concat, concat_len, sha1); + return EVP_EncodeBlock(out, sha1, sizeof sha1); +} + +static inline uint16_t +ws_hton16(uint16_t v) +{ + return htons(v); +} + +static inline uint32_t +ws_hton32(uint32_t v) +{ + return htonl(v); +} + +static inline uint64_t +ws_hton64(uint64_t v) +{ + if (1 == htons(1)) return v; + + uint32_t lower = (v & 0xffffffff), upper = (v >> 32); + uint64_t res = (((uint64_t) htonl(lower)) << 32) | htonl(upper); + return res; +} + +static inline uint16_t +ws_ntoh16(uint16_t v) +{ + return ntohs(v); +} + +static inline uint32_t +ws_ntoh32(uint32_t v) +{ + return ntohl(v); +} + +static inline uint64_t +ws_ntoh64(uint64_t v) +{ + if (1 == htons(1)) return v; + + uint32_t lower = (v & 0xffffffff), upper = (v >> 32); + uint64_t res = (((uint64_t) ntohl(lower)) << 32) | htonl(upper); + return res; +} + +enum ws_error { + /* 0 - 999 unused */ + + WS_ERROR_OK = 1000, + WS_ERROR_GOING_AWAY = 1001, + WS_ERROR_PROTOCOL_ERROR = 1002, + WS_ERROR_UNSUPPORTED_DATA = 1003, + + /* 1004 reserved */ + + WS_ERROR_NO_CODE_RECIEVED = 1005, + WS_ERROR_CLOSED_ABNORMALLY = 1006, + + WS_ERROR_INVALID_PAYLOAD_DATA = 1007, + WS_ERROR_POLICY_VIOLATED = 1008, + WS_ERROR_MESSAGE_TOO_BIG = 1009, + WS_ERROR_UNSUPPORTED_EXTENSION = 1010, + WS_ERROR_INTERNAL_SERVER_ERROR = 1011, + + /* 1012 - 1014 reserved */ + + WS_ERROR_TLS_HANDSHAKE_FAILURE = 1015, + + /* 3000 - 3999 reserved */ + + /* 4000 - 4999 reserved for application */ +}; + +enum ws_opcode { + WS_CONT = 0x0, + WS_DATA_UTF8 = 0x1, + WS_DATA_BINARY = 0x2, + + /* 0x3 - 0x7 reserved */ + + WS_CLOSE = 0x8, + WS_PING = 0x9, + WS_PONG = 0xa, + + /* 0xb - 0xf reserved */ +}; + +struct ws_frame { + enum ws_opcode opcode; + uint8_t fin; + uint8_t res; + uint64_t len; + uint32_t mask; +}; + +static inline int +ws_frame_recv(int sock, struct ws_frame *out) +{ + int res; + + uint8_t hdr[8]; + + if ((res = recv(sock, hdr, 2, 0)) < 0) + return -1; + + out->fin = (hdr[0] >> 7) & 0x1; + out->res = (hdr[0] >> 4) & 0x7; + out->opcode = hdr[0] & 0xf; + + int masked = (hdr[1] >> 7), len = hdr[1] & 0x7f; + + if (len == 127) { + if ((res = recv(sock, hdr, 8, 0)) < 0) + return -1; + + uint64_t raw = *((uint64_t *) hdr); + out->len = ws_ntoh64(raw); + } else if (len == 126) { + if ((res = recv(sock, hdr, 2, 0)) < 0) + return -1; + + uint16_t raw = *((uint16_t *) hdr); + out->len = ws_ntoh16(raw); + } else /* len < 126 */ { + out->len = len; + } + + if (masked) { + if ((res = recv(sock, hdr, 4, 0)) < 0) + return -1; + + uint32_t raw = *((uint32_t *) hdr); + out->mask = ws_ntoh32(raw); + } else { + out->mask = 0; + } + + return 0; +} + +static inline int +ws_frame_send(int sock, struct ws_frame const *frame) +{ + int res; + + /* serialise frame header */ + uint8_t len; + if (frame->len < 126) { + len = frame->len; + } else if (frame->len <= UINT16_MAX) { + len = 126; + } else { + len = 127; + } + + uint8_t ptr = 0; + uint8_t hdr[14]; + hdr[ptr++] = ((!!(frame->fin)) << 7) | ((frame->res & 0x7) << 4) | (frame->opcode & 0xf); + hdr[ptr++] = ((!!(frame->mask)) << 7) | len; + + if (len == 126) { + uint16_t raw = ws_hton16(frame->len); + memcpy(hdr + ptr, &raw, 2); + ptr += 2; + } + + if (len == 127) { + uint64_t raw = ws_hton64(frame->len); + memcpy(hdr + ptr, &raw, 8); + ptr += 8; + } + + if (frame->mask) { + uint32_t raw = ws_hton32(frame->mask); + memcpy(hdr + ptr, &raw, 4); + ptr += 4; + } + + /* send frame header */ + if ((res = send(sock, hdr, ptr, 0)) < 0) + return -1; + + return 0; +} + +static inline int +ws_data_recv(int sock, struct ws_frame const *frame, unsigned char *buf, size_t cap) +{ + assert(frame->len <= cap); + + int res; + + uint64_t remaining = frame->len; + + unsigned char payload[1024]; + unsigned char key[4] = { + (frame->mask & 0xff000000) >> 24, + (frame->mask & 0x00ff0000) >> 16, + (frame->mask & 0x0000ff00) >> 8, + (frame->mask & 0x000000ff), + }; + + do { + size_t masked = (sizeof payload < remaining) ? sizeof payload : remaining; + + /* receive masked payload */ + size_t masked_recv = 0; + do { + if ((res = recv(sock, payload + masked_recv, masked - masked_recv, 0)) < 0) + return -1; + masked_recv += res; + } while (masked_recv < masked); + + /* unmask payload */ + for (size_t i = 0; i < masked; i++) + *buf++ = payload[i] ^ key[i % 4]; + + remaining -= masked; + } while (remaining); + + return 0; +} + +static inline int +ws_data_send(int sock, struct ws_frame const *frame, unsigned char const *buf) +{ + int res; + + /* send frame payload */ + uint64_t remaining = frame->len; + + unsigned char payload[1024]; + unsigned char key[4] = { + (frame->mask & 0xff000000) >> 24, + (frame->mask & 0x00ff0000) >> 16, + (frame->mask & 0x0000ff00) >> 8, + (frame->mask & 0x000000ff), + }; + + do { + size_t masked = (sizeof payload < remaining) ? sizeof payload : remaining; + + /* mask payload */ + for (size_t i = 0; i < masked; i++) + payload[i] = *buf++ ^ key[i % 4]; + + /* send masked payload */ + size_t masked_send = 0; + do { + if ((res = send(sock, payload + masked_send, masked - masked_send, 0)) < 0) return -1; + masked_send += res; + } while (masked_send < masked); + + remaining -= masked; + } while (remaining); + + return 0; +} + +static inline int +ws_msg_send(int sock, enum ws_opcode opcode, unsigned char *buf, size_t len, uint32_t mask, int fragment) +{ + struct ws_frame frame; + frame.opcode = opcode; + frame.fin = !fragment; + frame.res = 0; + frame.len = len; + frame.mask = mask; + + if (ws_frame_send(sock, &frame) < 0) + return -1; + + return ws_data_send(sock, &frame, buf); +} + +static inline int +ws_close(int sock, enum ws_error err, unsigned char *msg, size_t len, uint32_t mask) +{ + assert(len <= 123); + + struct ws_frame frame; + frame.opcode = WS_CLOSE; + frame.fin = 1; + frame.res = 0; + frame.len = len; + frame.mask = mask; + + if (ws_frame_send(sock, &frame) < 0) + return -1; + + /* send optional error code */ + int res; + uint16_t raw = ws_hton16(err); + if ((res = send(sock, &raw, sizeof raw, 0)) < 0) + return -1; + + return ws_data_send(sock, &frame, msg); +} + +#endif /* WS_H */