sslexample

sslexample.git
git clone git://git.lenczewski.org/sslexample.git
Log | Files | Refs

sslserver.c (6648B)


      1 #include "common.h"
      2 
      3 struct ssl_remote_client {
      4 	struct connection conn;
      5 };
      6 
      7 int
      8 do_recv(struct io_uring *uring, struct connection *conn, int res);
      9 
     10 int
     11 do_tls_connect(struct io_uring *uring, struct connection *conn, int res)
     12 {
     13 	DBGLOG("[conn:%d] completed tls handshake\n", conn->socket);
     14 
     15 	return do_recv(uring, conn, res);
     16 }
     17 
     18 int
     19 do_recv(struct io_uring *uring, struct connection *conn, int res)
     20 {
     21 	DBGLOG("[conn:%d] received %d encrypted bytes\n", conn->socket, res);
     22 
     23 	conn->len = SSL_read(conn->ssl, conn->buf, conn->cap);
     24 
     25 	DBGLOG("[conn:%d] received %zu plaintext bytes\n", conn->socket, conn->len);
     26 
     27 	conn->cur = SSL_write(conn->ssl, conn->buf, conn->len);;
     28 
     29 	DBGLOG("[conn:%d] sending %zu plaintext bytes\n", conn->socket, conn->cur);
     30 
     31 	return conn_prep_send(conn, uring);
     32 }
     33 
     34 int
     35 do_send(struct io_uring *uring, struct connection *conn, int res)
     36 {
     37 	if (conn->cur < conn->len) { /* have not sent entire message */
     38 		conn->cur += SSL_write(conn->ssl, conn->buf + conn->cur, conn->len - conn->cur);
     39 		return conn_prep_send(conn, uring);
     40 	} else { /* sent entire message */
     41 		return conn_prep_recv(conn, uring);
     42 	}
     43 }
     44 
     45 struct ssl_server {
     46 	int socket;
     47 
     48 	struct io_uring io_uring;
     49 
     50 	SSL_CTX *ssl_ctx;
     51 };
     52 
     53 int
     54 handle_client(struct ssl_server *server, int fd)
     55 {
     56 	struct ssl_remote_client *client = malloc(sizeof *client);
     57 	assert(client);
     58 
     59 	SSL *ssl = SSL_new(server->ssl_ctx);
     60 	assert(ssl);
     61 
     62 	SSL_set_accept_state(ssl);
     63 
     64 	size_t cap = 4096;
     65 	void *buf = malloc(cap);
     66 	assert(buf);
     67 
     68 	conn_init(&client->conn, fd, ssl, buf, cap);
     69 
     70 	return conn_do_tls_handshake(&client->conn, &server->io_uring, 0);
     71 }
     72 
     73 void
     74 free_client(struct ssl_server *server, struct connection *conn)
     75 {
     76 	conn_free(conn);
     77 
     78 	struct ssl_remote_client *client =
     79 		TO_PARENT_PTR(conn, struct ssl_remote_client, conn);
     80 
     81 	free(client->conn.buf);
     82 	free(client);
     83 }
     84 
     85 int
     86 serve(struct ssl_server *server)
     87 {
     88 	struct sockaddr_storage client_addr;
     89 	socklen_t client_addrlen = sizeof client_addr;
     90 
     91 	struct ioreq accept_ioreq;
     92 	accept_ioreq.type = IOREQ_ACCEPT;
     93 	accept_ioreq.tag.accept.fd = server->socket;
     94 	accept_ioreq.tag.accept.addr = (struct sockaddr *) &client_addr;
     95 	accept_ioreq.tag.accept.addrlen = &client_addrlen;
     96 	accept_ioreq.tag.accept.flags = 0;
     97 
     98 	queue_ioreqs(&server->io_uring, &accept_ioreq, 1);
     99 
    100 	while (1) {
    101 		io_uring_submit_and_wait(&server->io_uring, 1);
    102 
    103 		struct io_uring_cqe *cqe;
    104 		unsigned head, seen_cqes = 0;
    105 		io_uring_for_each_cqe(&server->io_uring, head, cqe) {
    106 			struct ioreq *ioreq = io_uring_cqe_get_data(cqe);
    107 			assert(ioreq);
    108 
    109 			switch (ioreq->type) {
    110 			case IOREQ_ACCEPT: {
    111 				if (cqe->res < 0) {
    112 					queue_ioreqs(&server->io_uring,
    113 						     &accept_ioreq, 1);
    114 					goto next_cqe;
    115 				}
    116 
    117 				char host[NI_MAXHOST], serv[NI_MAXSERV];
    118 				getnameinfo(ioreq->tag.accept.addr,
    119 					    *ioreq->tag.accept.addrlen,
    120 					    host, sizeof host,
    121 					    serv, sizeof serv,
    122 					    NI_NUMERICSERV);
    123 
    124 				printf("Accepted client %d from %s:%s\n",
    125 						cqe->res, host, serv);
    126 
    127 				client_addrlen = sizeof client_addr;
    128 				queue_ioreqs(&server->io_uring, &accept_ioreq, 1);
    129 
    130 				handle_client(server, cqe->res);
    131 			} break;
    132 
    133 			case IOREQ_RECV: {
    134 				struct connection *conn =
    135 					TO_PARENT_PTR(ioreq, struct connection, ioreq);
    136 
    137 				if (cqe->res <= 0) {
    138 					conn_prep_close(conn, &server->io_uring);
    139 					goto next_cqe;
    140 				}
    141 
    142 				conn_finish_recv(conn, cqe->res);
    143 
    144 				if (!SSL_is_init_finished(conn->ssl)) {
    145 					conn_do_tls_handshake(conn, &server->io_uring, cqe->res);
    146 				} else {
    147 					do_recv(&server->io_uring, conn, cqe->res);
    148 				}
    149 			} break;
    150 
    151 			case IOREQ_SEND: {
    152 				struct connection *conn =
    153 					TO_PARENT_PTR(ioreq, struct connection, ioreq);
    154 
    155 				if (cqe->res <= 0) {
    156 					conn_prep_close(conn, &server->io_uring);
    157 					goto next_cqe;
    158 				}
    159 
    160 				conn_finish_send(conn, cqe->res);
    161 
    162 				if (!SSL_is_init_finished(conn->ssl)) {
    163 					conn_do_tls_handshake(conn, &server->io_uring, cqe->res);
    164 				} else {
    165 					do_send(&server->io_uring, conn, cqe->res);
    166 				}
    167 			} break;
    168 
    169 			case IOREQ_CLOSE: {
    170 				struct connection *conn =
    171 					TO_PARENT_PTR(ioreq, struct connection, ioreq);
    172 
    173 				printf("Closing client %d\n", conn->socket);
    174 
    175 				free_client(server, conn);
    176 			} break;
    177 
    178 			default:
    179 				break;
    180 			}
    181 
    182 next_cqe:
    183 			seen_cqes++;
    184 		}
    185 
    186 		io_uring_cq_advance(&server->io_uring, seen_cqes);
    187 	}
    188 
    189 	return EXIT_SUCCESS;
    190 }
    191 
    192 int
    193 main(int argc, char **argv)
    194 {
    195 	if (argc < 5) {
    196 		fprintf(stderr, "Usage: %s <host> <port> <cert> <key>\n", argv[0]);
    197 		exit(EXIT_FAILURE);
    198 	}
    199 
    200 	char *host = argv[1], *port = argv[2], *cert = argv[3], *key = argv[4];
    201 
    202 	struct ssl_server server;
    203 	memset(&server, 0, sizeof server);
    204 
    205 	// create server socket
    206 	// ===================================================================
    207 	//
    208 
    209 	struct addrinfo hints = {
    210 		.ai_family = AF_UNSPEC,
    211 		.ai_socktype = SOCK_STREAM,
    212 		.ai_protocol = IPPROTO_TCP,
    213 		.ai_flags = AI_NUMERICSERV,
    214 	}, *addrinfo, *ptr;
    215 
    216 	int res;
    217 	if ((res = getaddrinfo(host, port, &hints, &addrinfo))) {
    218 		fprintf(stderr, "Failed to get address info: %s\n",
    219 				gai_strerror(res));
    220 		exit(EXIT_FAILURE);
    221 	}
    222 
    223 	for (ptr = addrinfo; ptr; ptr = ptr->ai_next) {
    224 		server.socket = socket(ptr->ai_family,
    225 				       ptr->ai_socktype,
    226 				       ptr->ai_protocol);
    227 
    228 		if (server.socket < 0)
    229 			continue;
    230 
    231 		if (bind(server.socket, ptr->ai_addr, ptr->ai_addrlen) < 0) {
    232 			close(server.socket);
    233 			continue;
    234 		}
    235 
    236 		listen(server.socket, 32);
    237 
    238 		break;
    239 	}
    240 
    241 	if (ptr) {
    242 		char host[NI_MAXHOST], serv[NI_MAXSERV];
    243 		getnameinfo(ptr->ai_addr, ptr->ai_addrlen,
    244 			    host, sizeof host, serv, sizeof serv,
    245 			    NI_NUMERICSERV);
    246 		printf("Bound to %s:%s\n", host, serv);
    247 	}
    248 
    249 	freeaddrinfo(addrinfo);
    250 
    251 	if (!ptr) {
    252 		fprintf(stderr, "Failed to bind to %s:%s\n", host, port);
    253 		exit(EXIT_FAILURE);
    254 	}
    255 
    256 	setsockopt(server.socket, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int));
    257 
    258 	// create uring
    259 	// ===================================================================
    260 	//
    261 
    262 	unsigned entries = 32, flags = 0;
    263 	io_uring_queue_init(entries, &server.io_uring, flags);
    264 
    265 	// setup global ssl context
    266 	// ===================================================================
    267 	//
    268 
    269 	server.ssl_ctx = SSL_CTX_new(TLS_method());
    270 	assert(server.ssl_ctx);
    271 
    272 	SSL_CTX_set_min_proto_version(server.ssl_ctx, TLS1_2_VERSION);
    273 	SSL_CTX_use_certificate_file(server.ssl_ctx, cert, SSL_FILETYPE_PEM);
    274 	SSL_CTX_use_PrivateKey_file(server.ssl_ctx, key, SSL_FILETYPE_PEM);
    275 
    276 	if (SSL_CTX_check_private_key(server.ssl_ctx) != 1) {
    277 		fprintf(stderr, "Invalid certificate and private key pair given\n");
    278 		exit(EXIT_FAILURE);
    279 	}
    280 
    281 	// start server
    282 	// ===================================================================
    283 	//
    284 
    285 	exit(serve(&server));
    286 }