sslexample

sslexample.git
git clone git://git.lenczewski.org/sslexample.git
Log | Files | Refs

commit b2c083d31cbf71236cc157564a6cc2ad6d8157b4
parent 7d462fcd193af006ab4484bbb85953d026c5bf3e
Author: MikoĊ‚aj Lenczewski <mblenczewski@gmail.com>
Date:   Mon, 16 Jun 2025 01:05:03 +0100

Added example non-blocking ssl server and client

Diffstat:
Mbuild.sh | 2+-
Acommon.h | 269+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Msslclient.c | 251+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
Msslserver.c | 284+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
4 files changed, 794 insertions(+), 12 deletions(-)

diff --git a/build.sh b/build.sh @@ -3,7 +3,7 @@ CC="${CC:-clang}" WARNINGS="-Wall -Wextra -Wpedantic ${WERROR:+-Werror} -Wno-unused-parameter -Wno-format-pedantic" -FLAGS="-std=c99 -Og -g" +FLAGS="-std=c11 -Og -g ${RELEASE:+-DNDEBUG}" DEPS="$(pkg-config --cflags --libs liburing openssl)" diff --git a/common.h b/common.h @@ -0,0 +1,269 @@ +#define _GNU_SOURCE 1 +#define _XOPEN_SOURCE 700 + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <unistd.h> + +#include <sys/socket.h> +#include <netinet/ip.h> +#include <netdb.h> + +#include <fcntl.h> + +#include <liburing.h> + +#include <openssl/ssl.h> +#include <openssl/err.h> + +#ifndef NDEBUG +# define DBGLOG(...) fprintf(stderr, __VA_ARGS__) +#else +# define DBGLOG(...) +#endif + +#define TO_PARENT_PTR(ptr, T, member) \ + ((T *) (((uintptr_t) (ptr)) - offsetof(T, member))) + +// ioreq : a helper to make multiplexing operations on an io_uring easier +// =========================================================================== +// + +enum ioreq_type { + IOREQ_ACCEPT, IOREQ_CONNECT, IOREQ_RECV, IOREQ_SEND, IOREQ_CLOSE, +}; + +struct ioreq_accept { + int fd; + struct sockaddr *addr; + socklen_t *addrlen; + int flags; +}; + +struct ioreq_connect { + int fd; + struct sockaddr *addr; + socklen_t addrlen; +}; + +struct ioreq_recv { + int fd; + void *buf; + size_t len; + int flags; +}; + +struct ioreq_send { + int fd; + void *buf; + size_t len; + int flags; +}; + +struct ioreq_close { + int fd; +}; + +union ioreq_tag { + struct ioreq_accept accept; + struct ioreq_connect connect; + struct ioreq_recv recv; + struct ioreq_send send; + struct ioreq_close close; +}; + +struct ioreq { + enum ioreq_type type; + union ioreq_tag tag; +}; + +static int +queue_ioreqs(struct io_uring *uring, struct ioreq *reqs, size_t len) +{ + for (size_t i = 0; i < len; i++) { + struct ioreq *req = &reqs[i]; + + struct io_uring_sqe *sqe = io_uring_get_sqe(uring); + if (!sqe) { /* out of sqes, submit and try again */ + io_uring_submit(uring); + sqe = io_uring_get_sqe(uring); + } + + assert(sqe); + + io_uring_sqe_set_data(sqe, req); + + switch (req->type) { + case IOREQ_ACCEPT: + io_uring_prep_accept(sqe, + req->tag.accept.fd, + req->tag.accept.addr, + req->tag.accept.addrlen, + req->tag.accept.flags); + break; + + case IOREQ_CONNECT: + io_uring_prep_connect(sqe, + req->tag.connect.fd, + req->tag.connect.addr, + req->tag.connect.addrlen); + break; + + case IOREQ_RECV: + io_uring_prep_recv(sqe, + req->tag.recv.fd, + req->tag.recv.buf, + req->tag.recv.len, + req->tag.recv.flags); + break; + + case IOREQ_SEND: + io_uring_prep_send(sqe, + req->tag.send.fd, + req->tag.send.buf, + req->tag.send.len, + req->tag.send.flags); + break; + + + case IOREQ_CLOSE: + io_uring_prep_close(sqe, + req->tag.close.fd); + break; + } + } + + return 0; +} + +// connection : an abstraction over a socket and an ssl stream +// =========================================================================== +// + +struct connection { + int socket; + + struct ioreq ioreq; + + SSL *ssl; + BIO *ssl_bio, *net_bio; + + size_t cur, len, cap; + unsigned char *buf; +}; + +static void +conn_init(struct connection *conn, int fd, SSL *ssl, void *buf, size_t cap) +{ + conn->socket = fd; + conn->ssl = ssl; + + BIO_new_bio_pair(&conn->ssl_bio, 0, &conn->net_bio, 0); + SSL_set_bio(conn->ssl, conn->ssl_bio, conn->ssl_bio); + + conn->buf = buf; + conn->cap = cap; + conn->cur = conn->len = 0; +} + +static void +conn_free(struct connection *conn) +{ + SSL_free(conn->ssl); + BIO_free(conn->net_bio); +} + +static int +conn_prep_recv(struct connection *conn, struct io_uring *uring) +{ + char *buf; + int len = BIO_nwrite0(conn->net_bio, &buf); + + DBGLOG("[conn:%d] receiving %d bytes\n", conn->socket, len); + + conn->ioreq.type = IOREQ_RECV; + conn->ioreq.tag.recv.fd = conn->socket; + conn->ioreq.tag.recv.buf = buf; + conn->ioreq.tag.recv.len = len; + conn->ioreq.tag.recv.flags = 0; + + return queue_ioreqs(uring, &conn->ioreq, 1); +} + +static void +conn_finish_recv(struct connection *conn, int res) +{ + DBGLOG("[conn:%d] recieved %d bytes\n", conn->socket, res); + BIO_nwrite(conn->net_bio, NULL, res); +} + +static int +conn_prep_send(struct connection *conn, struct io_uring *uring) +{ + char *buf; + int len = BIO_nread0(conn->net_bio, &buf); + + DBGLOG("[conn:%d] sending %d bytes\n", conn->socket, len); + + conn->ioreq.type = IOREQ_SEND; + conn->ioreq.tag.send.fd = conn->socket; + conn->ioreq.tag.send.buf = buf; + conn->ioreq.tag.send.len = len; + conn->ioreq.tag.send.flags = 0; + + return queue_ioreqs(uring, &conn->ioreq, 1); +} + +static void +conn_finish_send(struct connection *conn, int res) +{ + DBGLOG("[conn:%d] sent %d bytes\n", conn->socket, res); + BIO_nread(conn->net_bio, NULL, res); +} + +static int +conn_prep_close(struct connection *conn, struct io_uring *uring) +{ + DBGLOG("[conn:%d] closing connection\n", conn->socket); + + conn->ioreq.type = IOREQ_CLOSE; + conn->ioreq.tag.close.fd = conn->socket; + + return queue_ioreqs(uring, &conn->ioreq, 1); +} + +extern int +do_tls_connect(struct io_uring *uring, struct connection *conn, int res); + +static int +conn_do_tls_handshake(struct connection *conn, struct io_uring *uring, + int transferred) +{ + int res = SSL_do_handshake(conn->ssl); + if (res == 1) /* completed handshake */ + return do_tls_connect(uring, conn, transferred); + + if (res == 0) /* disconnected */ + return -1; + + int err = SSL_get_error(conn->ssl, res); + if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) { + fprintf(stderr, "unexpected ssl error: %d\n", err); + ERR_print_errors_fp(stderr); + return -1; + } + + int pending = BIO_ctrl_pending(conn->net_bio); + if (pending) + return conn_prep_send(conn, uring); + + int expecting = BIO_ctrl_get_read_request(conn->net_bio); + if (expecting) + return conn_prep_recv(conn, uring); + + return -1; +} diff --git a/sslclient.c b/sslclient.c @@ -1,14 +1,255 @@ -#include <stdio.h> -#include <stdlib.h> -#include <unistd.h> +#include "common.h" + +struct ssl_client { + struct connection conn; +}; + +int +do_tls_connect(struct io_uring *uring, struct connection *conn, int res) +{ + DBGLOG("completed tls handshake!\n"); + +#ifndef NDEBUG + conn->len = snprintf((char *) conn->buf, conn->cap, "Hello, World!\n"); +#else + conn->len = conn->cap; +#endif + + printf("sending and expecting to receive %zu byte chunks\n", conn->len); + + conn->cur = SSL_write(conn->ssl, conn->buf, conn->len); + + return conn_prep_send(conn, uring); +} + +int +do_recv(struct io_uring *uring, struct connection *conn, int res) +{ + if (conn->cur < conn->len) { /* have not received entire message */ + conn->cur += SSL_read(conn->ssl, conn->buf + conn->cur, conn->len - conn->cur); + return conn_prep_recv(conn, uring); + } else { /* received entire message */ +#ifndef NDEBUG + printf("received %zu plaintext bytes\n", conn->cur); + + uint64_t nsec_per_msec = 1000000; + struct timespec timeout = { .tv_nsec = nsec_per_msec, }; + while (nanosleep(&timeout, &timeout) < 0); +#endif + + conn->cur = SSL_write(conn->ssl, conn->buf, conn->len); + + return conn_prep_send(conn, uring); + } +} + +int +do_send(struct io_uring *uring, struct connection *conn, int res) +{ + if (conn->cur < conn->len) { /* have not sent entire message */ + conn->cur += SSL_write(conn->ssl, conn->buf + conn->cur, conn->len - conn->cur); + return conn_prep_send(conn, uring); + } else { /* sent entire message */ +#ifndef NDEBUG + printf("sent %zu plaintext bytes\n", conn->cur); +#endif + + conn->cur = SSL_read(conn->ssl, conn->buf, conn->len); + + return conn_prep_recv(conn, uring); + } +} + +struct ssl_app { + char const *host, *port; + + struct io_uring io_uring; + + SSL_CTX *ssl_ctx; +}; + +int +start_client(struct ssl_app *app) +{ + struct addrinfo hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + .ai_protocol = IPPROTO_TCP, + .ai_flags = AI_NUMERICSERV, + }, *addrinfo, *ptr; + + int res; + if ((res = getaddrinfo(app->host, app->port, &hints, &addrinfo))) { + fprintf(stderr, "Failed to get address info: %s\n", + gai_strerror(res)); + exit(EXIT_FAILURE); + } + + int sock; + for (ptr = addrinfo; ptr; ptr = ptr->ai_next) { + sock = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol); + if (sock < 0) + continue; + + if (connect(sock, ptr->ai_addr, ptr->ai_addrlen) < 0) { + close(sock); + continue; + } + + break; + } + + if (ptr) { + char host[NI_MAXHOST], serv[NI_MAXSERV]; + getnameinfo(ptr->ai_addr, ptr->ai_addrlen, + host, sizeof host, serv, sizeof serv, + NI_NUMERICSERV); + printf("Connected to %s:%s\n", host, serv); + } + + freeaddrinfo(addrinfo); + + if (!ptr) { + fprintf(stderr, "Failed to connect to %s:%s\n", + app->host, app->port); + exit(EXIT_FAILURE); + } + + struct ssl_client *client = malloc(sizeof *client); + assert(client); + + SSL *ssl = SSL_new(app->ssl_ctx); + assert(ssl); + + SSL_set_connect_state(ssl); + SSL_set_tlsext_host_name(ssl, app->host); + SSL_set1_host(ssl, app->host); + + size_t cap = 4096; + void *buf = malloc(cap); + assert(buf); + + conn_init(&client->conn, sock, ssl, buf, cap); + + return conn_do_tls_handshake(&client->conn, &app->io_uring, 0); +} + +void +stop_client(struct ssl_app *app, struct connection *conn) +{ + conn_free(conn); + + struct ssl_client *client = TO_PARENT_PTR(conn, struct ssl_client, conn); + + free(client->conn.buf); + free(client); +} + +int +handle(struct ssl_app *app) +{ + start_client(app); + + int quit = 0; + while (!quit) { + io_uring_submit_and_wait(&app->io_uring, 1); + + struct io_uring_cqe *cqe; + unsigned head, seen_cqes = 0; + io_uring_for_each_cqe(&app->io_uring, head, cqe) { + struct ioreq *ioreq = io_uring_cqe_get_data(cqe); + assert(ioreq); + + struct connection *conn = + TO_PARENT_PTR(ioreq, struct connection, ioreq); + + switch (ioreq->type) { + case IOREQ_RECV: { + if (cqe->res <= 0) { + conn_prep_close(conn, &app->io_uring); + goto next_cqe; + } + + conn_finish_recv(conn, cqe->res); + + if (!SSL_is_init_finished(conn->ssl)) { + conn_do_tls_handshake(conn, &app->io_uring, cqe->res); + } else { + do_recv(&app->io_uring, conn, cqe->res); + } + } break; + + case IOREQ_SEND: { + if (cqe->res <= 0) { + conn_prep_close(conn, &app->io_uring); + goto next_cqe; + } + + conn_finish_send(conn, cqe->res); + + if (!SSL_is_init_finished(conn->ssl)) { + conn_do_tls_handshake(conn, &app->io_uring, cqe->res); + } else { + do_send(&app->io_uring, conn, cqe->res); + } + } break; + + case IOREQ_CLOSE: { + stop_client(app, conn); + quit = 1; + } break; + + default: + break; + } + +next_cqe: + seen_cqes++; + } + + io_uring_cq_advance(&app->io_uring, seen_cqes); + } + + return EXIT_SUCCESS; +} int main(int argc, char **argv) { - if (argc < 3) { + if (argc < 4) { fprintf(stderr, "Usage: %s <host> <port> <cert>\n", argv[0]); exit(EXIT_FAILURE); } - exit(EXIT_SUCCESS); + char *host = argv[1], *port = argv[2], *cert = argv[3]; + + struct ssl_app app; + memset(&app, 0, sizeof app); + + app.host = host; + app.port = port; + + // create uring + // =================================================================== + // + + unsigned entries = 32, flags = 0; + io_uring_queue_init(entries, &app.io_uring, flags); + + // setup global ssl context and client ssl context + // =================================================================== + // + + app.ssl_ctx = SSL_CTX_new(TLS_method()); + assert(app.ssl_ctx); + + SSL_CTX_set_min_proto_version(app.ssl_ctx, TLS1_2_VERSION); + SSL_CTX_set_verify(app.ssl_ctx, SSL_VERIFY_PEER, NULL); + SSL_CTX_load_verify_file(app.ssl_ctx, cert); + + // start client + // =================================================================== + // + + exit(handle(&app)); } diff --git a/sslserver.c b/sslserver.c @@ -1,14 +1,286 @@ -#include <stdio.h> -#include <stdlib.h> -#include <unistd.h> +#include "common.h" + +struct ssl_remote_client { + struct connection conn; +}; + +int +do_recv(struct io_uring *uring, struct connection *conn, int res); + +int +do_tls_connect(struct io_uring *uring, struct connection *conn, int res) +{ + DBGLOG("[conn:%d] completed tls handshake\n", conn->socket); + + return do_recv(uring, conn, res); +} + +int +do_recv(struct io_uring *uring, struct connection *conn, int res) +{ + DBGLOG("[conn:%d] received %d encrypted bytes\n", conn->socket, res); + + conn->len = SSL_read(conn->ssl, conn->buf, conn->cap); + + DBGLOG("[conn:%d] received %zu plaintext bytes\n", conn->socket, conn->len); + + conn->cur = SSL_write(conn->ssl, conn->buf, conn->len);; + + DBGLOG("[conn:%d] sending %zu plaintext bytes\n", conn->socket, conn->cur); + + return conn_prep_send(conn, uring); +} + +int +do_send(struct io_uring *uring, struct connection *conn, int res) +{ + if (conn->cur < conn->len) { /* have not sent entire message */ + conn->cur += SSL_write(conn->ssl, conn->buf + conn->cur, conn->len - conn->cur); + return conn_prep_send(conn, uring); + } else { /* sent entire message */ + return conn_prep_recv(conn, uring); + } +} + +struct ssl_server { + int socket; + + struct io_uring io_uring; + + SSL_CTX *ssl_ctx; +}; + +int +handle_client(struct ssl_server *server, int fd) +{ + struct ssl_remote_client *client = malloc(sizeof *client); + assert(client); + + SSL *ssl = SSL_new(server->ssl_ctx); + assert(ssl); + + SSL_set_accept_state(ssl); + + size_t cap = 4096; + void *buf = malloc(cap); + assert(buf); + + conn_init(&client->conn, fd, ssl, buf, cap); + + return conn_do_tls_handshake(&client->conn, &server->io_uring, 0); +} + +void +free_client(struct ssl_server *server, struct connection *conn) +{ + conn_free(conn); + + struct ssl_remote_client *client = + TO_PARENT_PTR(conn, struct ssl_remote_client, conn); + + free(client->conn.buf); + free(client); +} + +int +serve(struct ssl_server *server) +{ + struct sockaddr_storage client_addr; + socklen_t client_addrlen = sizeof client_addr; + + struct ioreq accept_ioreq; + accept_ioreq.type = IOREQ_ACCEPT; + accept_ioreq.tag.accept.fd = server->socket; + accept_ioreq.tag.accept.addr = (struct sockaddr *) &client_addr; + accept_ioreq.tag.accept.addrlen = &client_addrlen; + accept_ioreq.tag.accept.flags = 0; + + queue_ioreqs(&server->io_uring, &accept_ioreq, 1); + + while (1) { + io_uring_submit_and_wait(&server->io_uring, 1); + + struct io_uring_cqe *cqe; + unsigned head, seen_cqes = 0; + io_uring_for_each_cqe(&server->io_uring, head, cqe) { + struct ioreq *ioreq = io_uring_cqe_get_data(cqe); + assert(ioreq); + + switch (ioreq->type) { + case IOREQ_ACCEPT: { + if (cqe->res < 0) { + queue_ioreqs(&server->io_uring, + &accept_ioreq, 1); + goto next_cqe; + } + + char host[NI_MAXHOST], serv[NI_MAXSERV]; + getnameinfo(ioreq->tag.accept.addr, + *ioreq->tag.accept.addrlen, + host, sizeof host, + serv, sizeof serv, + NI_NUMERICSERV); + + printf("Accepted client %d from %s:%s\n", + cqe->res, host, serv); + + client_addrlen = sizeof client_addr; + queue_ioreqs(&server->io_uring, &accept_ioreq, 1); + + handle_client(server, cqe->res); + } break; + + case IOREQ_RECV: { + struct connection *conn = + TO_PARENT_PTR(ioreq, struct connection, ioreq); + + if (cqe->res <= 0) { + conn_prep_close(conn, &server->io_uring); + goto next_cqe; + } + + conn_finish_recv(conn, cqe->res); + + if (!SSL_is_init_finished(conn->ssl)) { + conn_do_tls_handshake(conn, &server->io_uring, cqe->res); + } else { + do_recv(&server->io_uring, conn, cqe->res); + } + } break; + + case IOREQ_SEND: { + struct connection *conn = + TO_PARENT_PTR(ioreq, struct connection, ioreq); + + if (cqe->res <= 0) { + conn_prep_close(conn, &server->io_uring); + goto next_cqe; + } + + conn_finish_send(conn, cqe->res); + + if (!SSL_is_init_finished(conn->ssl)) { + conn_do_tls_handshake(conn, &server->io_uring, cqe->res); + } else { + do_send(&server->io_uring, conn, cqe->res); + } + } break; + + case IOREQ_CLOSE: { + struct connection *conn = + TO_PARENT_PTR(ioreq, struct connection, ioreq); + + printf("Closing client %d\n", conn->socket); + + free_client(server, conn); + } break; + + default: + break; + } + +next_cqe: + seen_cqes++; + } + + io_uring_cq_advance(&server->io_uring, seen_cqes); + } + + return EXIT_SUCCESS; +} int main(int argc, char **argv) { - if (argc < 3) { - fprintf(stderr, "Usage: %s <host> <port> <cert>\n", argv[0]); + if (argc < 5) { + fprintf(stderr, "Usage: %s <host> <port> <cert> <key>\n", argv[0]); + exit(EXIT_FAILURE); + } + + char *host = argv[1], *port = argv[2], *cert = argv[3], *key = argv[4]; + + struct ssl_server server; + memset(&server, 0, sizeof server); + + // create server socket + // =================================================================== + // + + struct addrinfo hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + .ai_protocol = IPPROTO_TCP, + .ai_flags = AI_NUMERICSERV, + }, *addrinfo, *ptr; + + int res; + if ((res = getaddrinfo(host, port, &hints, &addrinfo))) { + fprintf(stderr, "Failed to get address info: %s\n", + gai_strerror(res)); + exit(EXIT_FAILURE); + } + + for (ptr = addrinfo; ptr; ptr = ptr->ai_next) { + server.socket = socket(ptr->ai_family, + ptr->ai_socktype, + ptr->ai_protocol); + + if (server.socket < 0) + continue; + + if (bind(server.socket, ptr->ai_addr, ptr->ai_addrlen) < 0) { + close(server.socket); + continue; + } + + listen(server.socket, 32); + + break; + } + + if (ptr) { + char host[NI_MAXHOST], serv[NI_MAXSERV]; + getnameinfo(ptr->ai_addr, ptr->ai_addrlen, + host, sizeof host, serv, sizeof serv, + NI_NUMERICSERV); + printf("Bound to %s:%s\n", host, serv); + } + + freeaddrinfo(addrinfo); + + if (!ptr) { + fprintf(stderr, "Failed to bind to %s:%s\n", host, port); + exit(EXIT_FAILURE); + } + + setsockopt(server.socket, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int)); + + // create uring + // =================================================================== + // + + unsigned entries = 32, flags = 0; + io_uring_queue_init(entries, &server.io_uring, flags); + + // setup global ssl context + // =================================================================== + // + + server.ssl_ctx = SSL_CTX_new(TLS_method()); + assert(server.ssl_ctx); + + SSL_CTX_set_min_proto_version(server.ssl_ctx, TLS1_2_VERSION); + SSL_CTX_use_certificate_file(server.ssl_ctx, cert, SSL_FILETYPE_PEM); + SSL_CTX_use_PrivateKey_file(server.ssl_ctx, key, SSL_FILETYPE_PEM); + + if (SSL_CTX_check_private_key(server.ssl_ctx) != 1) { + fprintf(stderr, "Invalid certificate and private key pair given\n"); exit(EXIT_FAILURE); } - exit(EXIT_SUCCESS); + // start server + // =================================================================== + // + + exit(serve(&server)); }