]> Pileus Git - ~andy/lamechat/blob - net.c
Support GnuTLS
[~andy/lamechat] / net.c
1 /*
2  * Copyright (C) 2017 Andy Spencer <andy753421@gmail.com>
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
16  */
17
18 #define _GNU_SOURCE
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <stdarg.h>
23 #include <string.h>
24 #include <errno.h>
25 #include <unistd.h>
26 #include <fcntl.h>
27 #include <netdb.h>
28
29 #include <sys/socket.h>
30 #include <netinet/in.h>
31 #include <netinet/tcp.h>
32
33 #ifdef USE_OPENSSL
34 #include <openssl/bio.h>
35 #include <openssl/ssl.h>
36 #endif
37 #ifdef USE_GNUTLS
38 #include <gnutls/gnutls.h>
39 #include <gnutls/x509.h>
40 #endif
41
42 #include "util.h"
43 #include "net.h"
44
/* Crypto types */

/* Per-connection TLS state, reached through net->crypto. Only the members
 * belonging to the backend selected at build time are present. */
struct crypto_t {
#ifdef USE_OPENSSL
        SSL_CTX *ctx;   /* context created by net_encrypt() */
        SSL *ssl;       /* client-side session */
        BIO *in;        /* memory BIO fed with bytes received on the socket */
        BIO *out;       /* memory BIO holding TLS output to be sent */
#endif
#ifdef USE_GNUTLS
        gnutls_session_t tls;   /* client session using the socket fd as transport */
#endif
};

/* Local data */
#ifdef USE_GNUTLS
/* Certificate credentials set up in net_init() and attached to every
 * session created by net_encrypt(). */
static gnutls_certificate_credentials_t xcred;
#endif
62
63 /* Local functions */
64 static int flush(net_t *net)
65 {
66 #ifdef USE_OPENSSL
67         static char buf[NET_BUFFER];
68 #endif
69         int len;
70
71         /* State machine */
72         if (net->state == NET_READY) {
73                 if (net->out_len > 0) {
74                         debug("net: flush plain");
75                         len = send(net->poll.fd, &net->out_buf[net->out_pos],
76                                        net->out_len-net->out_pos, 0);
77                         if (len > 0)
78                                 net->out_pos += len;
79                         if (net->out_pos == net->out_len) {
80                                 net->out_pos = 0;
81                                 net->out_len = 0;
82                         }
83                 }
84         }
85         if (net->state == NET_ENCRYPTED) {
86 #ifdef USE_OPENSSL
87                 if (net->out_len > 0) {
88                         debug("net: flush crypto");
89                         len = SSL_write(net->crypto->ssl,
90                                         &net->out_buf[net->out_pos],
91                                         net->out_len-net->out_pos);
92                         if (len > 0)
93                                 net->out_pos += len;
94                         if (net->out_pos == net->out_len) {
95                                 net->out_pos = 0;
96                                 net->out_len = 0;
97                         }
98                 }
99
100                 while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0)
101                         send(net->poll.fd, buf, len, 0);
102 #endif
103 #ifdef USE_GNUTLS
104                 if (net->out_len > 0) {
105                         len = gnutls_record_send(net->crypto->tls,
106                                            &net->out_buf[net->out_pos],
107                                            net->out_len-net->out_pos);
108                         if (len > 0)
109                                 net->out_pos += len;
110                         if (net->out_pos == net->out_len) {
111                                 net->out_pos = 0;
112                                 net->out_len = 0;
113                         }
114                 }
115 #endif
116         }
117
118         return 1;
119 }
120
121 static void on_poll(void *_net)
122 {
123         static char buf[NET_BUFFER];
124         net_t *net = _net;
125         int len, err = 0, done = 0;
126
127         while (!done) {
128                 /* Handle Errors */
129                 socklen_t elen = sizeof(err);
130                 if (getsockopt(net->poll.fd, SOL_SOCKET, SO_ERROR, &err, &elen))
131                         error("Error getting socket opt");
132                 if (err) {
133                         debug("Socket error: %s", strerror(err));
134                         net_close(net);
135                 }
136
137                 /* State machine */
138                 if (net->state == NET_CLOSED) {
139                         done = 1;
140                 }
141                 if (net->state == NET_CONNECT) {
142                         debug("net: connect");
143                         net->state = NET_READY;
144                         flush(net);
145                 }
146                 if (net->state == NET_READY) {
147                         debug("net: ready");
148                         len = recv(net->poll.fd, buf, sizeof(buf), 0);
149                         if (len < 0)
150                                 done = 1;
151                         if (len == 0)
152                                 net_close(net);
153                         net->recv(net->data, buf, len);
154                 }
155                 if (net->state == NET_ENCRYPT) {
156                         debug("net: encrypt");
157                         net->state = NET_HANDSHAKE;
158                 }
159                 if (net->state == NET_HANDSHAKE) {
160                         debug("net: handshake");
161 #ifdef USE_OPENSSL
162                         len = recv(net->poll.fd, buf, sizeof(buf), 0);
163                         if (len < 0)
164                                 done = 1;
165                         if (len == 0)
166                                 net_close(net);
167                         if (len > 0)
168                                 BIO_write(net->in, buf, len);
169
170                         SSL_do_handshake(net->crypto->ssl);
171
172                         while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0)
173                                 send(net->poll.fd, buf, len, 0);
174
175                         if (SSL_is_init_finished(net->crypto->ssl))
176                                 net->state = NET_ENCRYPTED;
177 #endif
178 #ifdef USE_GNUTLS
179                         err = gnutls_handshake(net->crypto->tls);
180                         if (err == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
181                                 gnutls_datum_t out;
182                                 int type = gnutls_certificate_type_get(net->crypto->tls);
183                                 unsigned status = gnutls_session_get_verify_cert_status(net->crypto->tls);
184                                 gnutls_certificate_verification_status_print(status, type, &out, 0);
185                                 debug("Cert verify failed: %s", out.data);
186                                 gnutls_free(out.data);
187                                 net_close(net);
188                         } else if (err < 0) {
189                                 debug("Handshake failed: %s", gnutls_strerror(err));
190                                 net_close(net);
191                         } else {
192                                 char *desc = gnutls_session_get_desc(net->crypto->tls);
193                                 debug("Session info: %s", desc);
194                                 gnutls_free(desc);
195                                 net->state = NET_ENCRYPTED;
196                         }
197 #endif
198                         flush(net);
199                 }
200                 if (net->state == NET_ENCRYPTED) {
201                         debug("net: encrypted");
202 #ifdef USE_OPENSSL
203                         len = recv(net->poll.fd, buf, sizeof(buf), 0);
204                         if (len < 0)
205                                 done = 1;
206                         if (len == 0)
207                                 net_close(net);
208                         if (len > 0)
209                                 BIO_write(net->crypto->in, buf, len);
210
211                         while ((len = SSL_read(net->crypto->ssl, buf, sizeof(buf))) > 0)
212                                 net->recv(net->data, buf, len);
213 #endif
214 #ifdef USE_GNUTLS
215                         len = gnutls_record_recv(net->crypto->tls, buf, sizeof(buf));
216                         if (len < 0)
217                                 done = 1;
218                         if (len == 0)
219                                 net_close(net);
220                         if (len > 0)
221                                 net->recv(net->data, buf, len);
222 #endif
223                 }
224         }
225 }
226
227 /* Networking functions */
/* Return this machine's host name.
 *
 * The name lives in a static buffer that each call overwrites; a name
 * longer than the buffer is truncated. A gethostname() failure is reported
 * through error(). */
const char *get_hostname(void)
{
        static char hostname[512];

        if (gethostname(hostname, sizeof(hostname)))
                error("Error getting hostname");

        /* POSIX leaves it unspecified whether a truncated name is
         * NUL-terminated, so terminate it explicitly. */
        hostname[sizeof(hostname) - 1] = '\0';

        return hostname;
}
237
/* One-time initialization of the TLS backend selected at build time.
 *
 * With GnuTLS this also allocates the shared certificate credentials
 * (xcred) and loads trusted CA certificates into them: the system trust
 * store first, then the fixed bundle path as a fallback. Failure to
 * initialize the library or allocate credentials is reported through
 * error(); failure to load any CA is only logged, since connections opened
 * with NET_NOVERIFY can still work. */
void net_init(void)
{
#ifdef USE_OPENSSL
        SSL_library_init();
        SSL_load_error_strings();
        ERR_load_BIO_strings();
        OpenSSL_add_all_algorithms();
#endif
#ifdef USE_GNUTLS
        int rc;

        if (gnutls_global_init() < 0)
                error("Error initializing GnuTLS");
        if (gnutls_certificate_allocate_credentials(&xcred) < 0)
                error("Error allocating TLS credentials");

        /* Both calls return the number of certificates loaded, or a
         * negative error code. */
        rc = gnutls_certificate_set_x509_system_trust(xcred);
        if (rc <= 0)
                rc = gnutls_certificate_set_x509_trust_file(xcred,
                                "/etc/ssl/certs/ca-certificates.crt",
                                GNUTLS_X509_FMT_PEM);
        if (rc <= 0)
                debug("net: no trusted CA certificates loaded");
#endif
}
254
255 void net_open(net_t *net, const char *host, int port)
256 {
257         int sock, flags;
258         struct addrinfo *addrs = NULL;
259         struct addrinfo hints = {};
260         char service[16];
261         int yes = 1, idle = 120;
262
263         net->host = strcopy(host);
264         net->port = port;
265
266         snprintf(service, sizeof(service), "%d", net->port);
267         hints.ai_family   = AF_INET;
268         hints.ai_socktype = SOCK_STREAM;
269
270         /* Setup address */
271         if (getaddrinfo(net->host, service, &hints, &addrs))
272                 error("Error getting net address info");
273
274         if ((sock = socket(addrs->ai_family,
275                            addrs->ai_socktype,
276                            addrs->ai_protocol)) < 0)
277                 error("Error opening net socket");
278
279         if ((flags = fcntl(sock, F_GETFL, 0)) < 0)
280                 error("Error getting net socket flags");
281
282         if (fcntl(sock, F_SETFL, flags|O_NONBLOCK) < 0)
283                 error("Error setting net socket non-blocking");
284
285         if (setsockopt(sock, SOL_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)) < 0)
286                 error("Error setting net socket keepidle");
287
288         if (setsockopt(sock, SOL_TCP, TCP_KEEPINTVL, &yes, sizeof(yes)) < 0)
289                 error("Error setting net socket keepintvl");
290
291         if (setsockopt(sock, SOL_TCP, TCP_KEEPCNT, &yes, sizeof(yes)) < 0)
292                 error("Error setting net socket keepcnt");
293
294         if (setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, &yes, sizeof(yes)) < 0)
295                 error("Error setting net socket keepalive");
296
297         if (connect(sock, addrs->ai_addr, addrs->ai_addrlen) < 0)
298                 if (errno != EINPROGRESS)
299                         error("Error connecting socket");
300
301         freeaddrinfo(addrs);
302
303         /* Setup server */
304         net->state = NET_CONNECT;
305         poll_add(&net->poll, sock, on_poll, net);
306 }
307
308 void net_encrypt(net_t *net, int flags)
309 {
310         debug("net: encrypt");
311
312 #if defined(USE_OPENSSL)
313         net->crypto = new0(crypto_t);
314         net->crypto->ctx = SSL_CTX_new(TLSv1_2_client_method());
315         net->crypto->ssl = SSL_new(net->ctx);
316
317         net->crypto->in  = BIO_new(BIO_s_mem());
318         net->crypto->out = BIO_new(BIO_s_mem());
319
320         BIO_set_mem_eof_return(net->crypto->in,  -1);
321         BIO_set_mem_eof_return(net->crypto->out, -1);
322
323         SSL_set_bio(net->crypto->ssl, net->crypto->in, net->crypto->out);
324         SSL_set_connect_state(net->crypto->ssl);
325 #elif defined(USE_GNUTLS)
326         net->crypto = new0(crypto_t);
327         gnutls_init(&net->crypto->tls, GNUTLS_CLIENT);
328         gnutls_set_default_priority(net->crypto->tls);
329         gnutls_server_name_set(net->crypto->tls, GNUTLS_NAME_DNS,
330                                net->host, strlen(net->host));
331         gnutls_credentials_set(net->crypto->tls, GNUTLS_CRD_CERTIFICATE, xcred);
332         gnutls_handshake_set_timeout(net->crypto->tls, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
333         gnutls_transport_set_int(net->crypto->tls, net->poll.fd);
334         if (!(flags & NET_NOVERIFY))
335                 gnutls_session_set_verify_cert(net->crypto->tls, net->host, 0);
336 #else
337         error("Encryption is not supported");
338 #endif
339
340         net->state = NET_ENCRYPT;
341 }
342
343 int net_send(net_t *net, const char *buf, int len)
344 {
345         if (net->out_len)
346                 return 0;
347
348         debug("net: send");
349
350         if (len <= 0)
351                 return 0;
352         if (len > NET_BUFFER)
353                 len = NET_BUFFER;
354         memcpy(net->out_buf, buf, len);
355
356         net->out_len = len;
357         net->out_pos = 0;
358
359         return flush(net);
360 }
361
362 int net_print(net_t *net, const char *fmt, ...)
363 {
364         int len;
365         va_list ap;
366         if (net->out_len)
367                 return 0;
368
369         va_start(ap, fmt);
370         len = vsnprintf(net->out_buf, NET_BUFFER, fmt, ap);
371         va_end(ap);
372         if (len <= 0)
373                 return 0;
374         if (len > NET_BUFFER)
375                 len = NET_BUFFER;
376
377         if (net->out_buf[len-1] == '\n')
378                 debug("net: print [%.*s]", len-1, net->out_buf);
379         else
380                 debug("net: print [%.*s]", len, net->out_buf);
381
382         net->out_len = len;
383         net->out_pos = 0;
384
385         return flush(net);
386 }
387
388 void net_close(net_t *net)
389 {
390         debug("net_close: %s:%d",
391                 net->host, net->port);
392         net->state = NET_CLOSED;
393         poll_del(&net->poll);
394 }