sslserver.c (6648B)
1 #include "common.h" 2 3 struct ssl_remote_client { 4 struct connection conn; 5 }; 6 7 int 8 do_recv(struct io_uring *uring, struct connection *conn, int res); 9 10 int 11 do_tls_connect(struct io_uring *uring, struct connection *conn, int res) 12 { 13 DBGLOG("[conn:%d] completed tls handshake\n", conn->socket); 14 15 return do_recv(uring, conn, res); 16 } 17 18 int 19 do_recv(struct io_uring *uring, struct connection *conn, int res) 20 { 21 DBGLOG("[conn:%d] received %d encrypted bytes\n", conn->socket, res); 22 23 conn->len = SSL_read(conn->ssl, conn->buf, conn->cap); 24 25 DBGLOG("[conn:%d] received %zu plaintext bytes\n", conn->socket, conn->len); 26 27 conn->cur = SSL_write(conn->ssl, conn->buf, conn->len);; 28 29 DBGLOG("[conn:%d] sending %zu plaintext bytes\n", conn->socket, conn->cur); 30 31 return conn_prep_send(conn, uring); 32 } 33 34 int 35 do_send(struct io_uring *uring, struct connection *conn, int res) 36 { 37 if (conn->cur < conn->len) { /* have not sent entire message */ 38 conn->cur += SSL_write(conn->ssl, conn->buf + conn->cur, conn->len - conn->cur); 39 return conn_prep_send(conn, uring); 40 } else { /* sent entire message */ 41 return conn_prep_recv(conn, uring); 42 } 43 } 44 45 struct ssl_server { 46 int socket; 47 48 struct io_uring io_uring; 49 50 SSL_CTX *ssl_ctx; 51 }; 52 53 int 54 handle_client(struct ssl_server *server, int fd) 55 { 56 struct ssl_remote_client *client = malloc(sizeof *client); 57 assert(client); 58 59 SSL *ssl = SSL_new(server->ssl_ctx); 60 assert(ssl); 61 62 SSL_set_accept_state(ssl); 63 64 size_t cap = 4096; 65 void *buf = malloc(cap); 66 assert(buf); 67 68 conn_init(&client->conn, fd, ssl, buf, cap); 69 70 return conn_do_tls_handshake(&client->conn, &server->io_uring, 0); 71 } 72 73 void 74 free_client(struct ssl_server *server, struct connection *conn) 75 { 76 conn_free(conn); 77 78 struct ssl_remote_client *client = 79 TO_PARENT_PTR(conn, struct ssl_remote_client, conn); 80 81 free(client->conn.buf); 82 free(client); 83 } 84 85 int 86 serve(struct ssl_server *server) 87 { 88 struct sockaddr_storage client_addr; 89 socklen_t client_addrlen = sizeof client_addr; 90 91 struct ioreq accept_ioreq; 92 accept_ioreq.type = IOREQ_ACCEPT; 93 accept_ioreq.tag.accept.fd = server->socket; 94 accept_ioreq.tag.accept.addr = (struct sockaddr *) &client_addr; 95 accept_ioreq.tag.accept.addrlen = &client_addrlen; 96 accept_ioreq.tag.accept.flags = 0; 97 98 queue_ioreqs(&server->io_uring, &accept_ioreq, 1); 99 100 while (1) { 101 io_uring_submit_and_wait(&server->io_uring, 1); 102 103 struct io_uring_cqe *cqe; 104 unsigned head, seen_cqes = 0; 105 io_uring_for_each_cqe(&server->io_uring, head, cqe) { 106 struct ioreq *ioreq = io_uring_cqe_get_data(cqe); 107 assert(ioreq); 108 109 switch (ioreq->type) { 110 case IOREQ_ACCEPT: { 111 if (cqe->res < 0) { 112 queue_ioreqs(&server->io_uring, 113 &accept_ioreq, 1); 114 goto next_cqe; 115 } 116 117 char host[NI_MAXHOST], serv[NI_MAXSERV]; 118 getnameinfo(ioreq->tag.accept.addr, 119 *ioreq->tag.accept.addrlen, 120 host, sizeof host, 121 serv, sizeof serv, 122 NI_NUMERICSERV); 123 124 printf("Accepted client %d from %s:%s\n", 125 cqe->res, host, serv); 126 127 client_addrlen = sizeof client_addr; 128 queue_ioreqs(&server->io_uring, &accept_ioreq, 1); 129 130 handle_client(server, cqe->res); 131 } break; 132 133 case IOREQ_RECV: { 134 struct connection *conn = 135 TO_PARENT_PTR(ioreq, struct connection, ioreq); 136 137 if (cqe->res <= 0) { 138 conn_prep_close(conn, &server->io_uring); 139 goto next_cqe; 140 } 141 142 conn_finish_recv(conn, cqe->res); 143 144 if (!SSL_is_init_finished(conn->ssl)) { 145 conn_do_tls_handshake(conn, &server->io_uring, cqe->res); 146 } else { 147 do_recv(&server->io_uring, conn, cqe->res); 148 } 149 } break; 150 151 case IOREQ_SEND: { 152 struct connection *conn = 153 TO_PARENT_PTR(ioreq, struct connection, ioreq); 154 155 if (cqe->res <= 0) { 156 conn_prep_close(conn, &server->io_uring); 157 goto next_cqe; 158 } 159 160 conn_finish_send(conn, cqe->res); 161 162 if (!SSL_is_init_finished(conn->ssl)) { 163 conn_do_tls_handshake(conn, &server->io_uring, cqe->res); 164 } else { 165 do_send(&server->io_uring, conn, cqe->res); 166 } 167 } break; 168 169 case IOREQ_CLOSE: { 170 struct connection *conn = 171 TO_PARENT_PTR(ioreq, struct connection, ioreq); 172 173 printf("Closing client %d\n", conn->socket); 174 175 free_client(server, conn); 176 } break; 177 178 default: 179 break; 180 } 181 182 next_cqe: 183 seen_cqes++; 184 } 185 186 io_uring_cq_advance(&server->io_uring, seen_cqes); 187 } 188 189 return EXIT_SUCCESS; 190 } 191 192 int 193 main(int argc, char **argv) 194 { 195 if (argc < 5) { 196 fprintf(stderr, "Usage: %s <host> <port> <cert> <key>\n", argv[0]); 197 exit(EXIT_FAILURE); 198 } 199 200 char *host = argv[1], *port = argv[2], *cert = argv[3], *key = argv[4]; 201 202 struct ssl_server server; 203 memset(&server, 0, sizeof server); 204 205 // create server socket 206 // =================================================================== 207 // 208 209 struct addrinfo hints = { 210 .ai_family = AF_UNSPEC, 211 .ai_socktype = SOCK_STREAM, 212 .ai_protocol = IPPROTO_TCP, 213 .ai_flags = AI_NUMERICSERV, 214 }, *addrinfo, *ptr; 215 216 int res; 217 if ((res = getaddrinfo(host, port, &hints, &addrinfo))) { 218 fprintf(stderr, "Failed to get address info: %s\n", 219 gai_strerror(res)); 220 exit(EXIT_FAILURE); 221 } 222 223 for (ptr = addrinfo; ptr; ptr = ptr->ai_next) { 224 server.socket = socket(ptr->ai_family, 225 ptr->ai_socktype, 226 ptr->ai_protocol); 227 228 if (server.socket < 0) 229 continue; 230 231 if (bind(server.socket, ptr->ai_addr, ptr->ai_addrlen) < 0) { 232 close(server.socket); 233 continue; 234 } 235 236 listen(server.socket, 32); 237 238 break; 239 } 240 241 if (ptr) { 242 char host[NI_MAXHOST], serv[NI_MAXSERV]; 243 getnameinfo(ptr->ai_addr, ptr->ai_addrlen, 244 host, sizeof host, serv, sizeof serv, 245 NI_NUMERICSERV); 246 printf("Bound to %s:%s\n", host, serv); 247 } 248 249 freeaddrinfo(addrinfo); 250 251 if (!ptr) { 252 fprintf(stderr, "Failed to bind to %s:%s\n", host, port); 253 exit(EXIT_FAILURE); 254 } 255 256 setsockopt(server.socket, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int)); 257 258 // create uring 259 // =================================================================== 260 // 261 262 unsigned entries = 32, flags = 0; 263 io_uring_queue_init(entries, &server.io_uring, flags); 264 265 // setup global ssl context 266 // =================================================================== 267 // 268 269 server.ssl_ctx = SSL_CTX_new(TLS_method()); 270 assert(server.ssl_ctx); 271 272 SSL_CTX_set_min_proto_version(server.ssl_ctx, TLS1_2_VERSION); 273 SSL_CTX_use_certificate_file(server.ssl_ctx, cert, SSL_FILETYPE_PEM); 274 SSL_CTX_use_PrivateKey_file(server.ssl_ctx, key, SSL_FILETYPE_PEM); 275 276 if (SSL_CTX_check_private_key(server.ssl_ctx) != 1) { 277 fprintf(stderr, "Invalid certificate and private key pair given\n"); 278 exit(EXIT_FAILURE); 279 } 280 281 // start server 282 // =================================================================== 283 // 284 285 exit(serve(&server)); 286 }