]> Pileus Git - ~andy/lamechat/blob - net.c
XMPP whitespace idle
[~andy/lamechat] / net.c
1 /*
2  * Copyright (C) 2017 Andy Spencer <andy753421@gmail.com>
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
16  */
17
18 #define _GNU_SOURCE
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <stdarg.h>
23 #include <errno.h>
24 #include <unistd.h>
25 #include <fcntl.h>
26 #include <netdb.h>
27
28 #include <sys/socket.h>
29 #include <netinet/in.h>
30 #include <netinet/tcp.h>
31
32 #include <openssl/bio.h>
33 #include <openssl/ssl.h>
34
35 #include "util.h"
36 #include "net.h"
37
38 /* Local functions */
39 static int flush(net_t *net)
40 {
41         static char buf[NET_BUFFER];
42         int len;
43
44         /* State machine */
45         if (net->state == NET_READY) {
46                 if (net->out_len > 0) {
47                         debug("net: flush plain");
48                         len = send(net->poll.fd, &net->out_buf[net->out_pos],
49                                        net->out_len-net->out_pos, 0);
50                         if (len > 0)
51                                 net->out_pos += len;
52                         if (net->out_pos == net->out_len) {
53                                 net->out_pos = 0;
54                                 net->out_len = 0;
55                         }
56                 }
57         }
58         if (net->state == NET_ENCRYPTED) {
59                 if (net->out_len > 0) {
60                         debug("net: flush crypto");
61                         len = SSL_write(net->ssl,
62                                         &net->out_buf[net->out_pos],
63                                         net->out_len-net->out_pos);
64                         if (len > 0)
65                                 net->out_pos += len;
66                         if (net->out_pos == net->out_len) {
67                                 net->out_pos = 0;
68                                 net->out_len = 0;
69                         }
70                 }
71
72                 while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0)
73                         send(net->poll.fd, buf, len, 0);
74         }
75
76         return 1;
77 }
78
79 static void on_poll(void *_net)
80 {
81         static char buf[NET_BUFFER];
82         net_t *net = _net;
83         int len, err = 0, done = 0;
84
85         while (!done) {
86                 /* Handle Errors */
87                 socklen_t elen = sizeof(err);
88                 if (getsockopt(net->poll.fd, SOL_SOCKET, SO_ERROR, &err, &elen))
89                         error("Error getting socket opt");
90                 if (err) {
91                         debug("Socket error: %s", strerror(err));
92                         net_close(net);
93                 }
94
95                 /* State machine */
96                 if (net->state == NET_CONNECT) {
97                         debug("net: connect");
98                         net->state = NET_READY;
99                         flush(net);
100                 }
101                 if (net->state == NET_READY) {
102                         debug("net: ready");
103                         len = recv(net->poll.fd, buf, sizeof(buf), 0);
104                         if (len < 0)
105                                 done = 1;
106                         if (len == 0)
107                                 net_close(net);
108                         net->recv(net->data, buf, len);
109                 }
110                 if (net->state == NET_ENCRYPT) {
111                         debug("net: encrypt");
112                         net->state = NET_HANDSHAKE;
113                 }
114                 if (net->state == NET_HANDSHAKE) {
115                         debug("net: handshake");
116                         len = recv(net->poll.fd, buf, sizeof(buf), 0);
117                         if (len < 0)
118                                 done = 1;
119                         if (len == 0)
120                                 net_close(net);
121                         if (len > 0)
122                                 BIO_write(net->in, buf, len);
123
124                         SSL_do_handshake(net->ssl);
125
126                         while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0)
127                                 send(net->poll.fd, buf, len, 0);
128
129                         if (SSL_is_init_finished(net->ssl))
130                                 net->state = NET_ENCRYPTED;
131                         flush(net);
132                 }
133                 if (net->state == NET_ENCRYPTED) {
134                         debug("net: encrypted");
135                         len = recv(net->poll.fd, buf, sizeof(buf), 0);
136                         if (len < 0)
137                                 done = 1;
138                         if (len == 0)
139                                 net_close(net);
140                         if (len > 0)
141                                 BIO_write(net->in, buf, len);
142
143                         while ((len = SSL_read(net->ssl, buf, sizeof(buf))) > 0)
144                                 net->recv(net->data, buf, len);
145                 }
146         }
147 }
148
149 /* Networking functions */
150 const char *get_hostname(void)
151 {
152         static char hostname[512];
153
154         if (gethostname(hostname, sizeof(hostname)))
155                 error("Error getting hostname");
156
157         return hostname;
158 }
159
160 void net_init(void)
161 {
162         SSL_library_init();
163         SSL_load_error_strings();
164         ERR_load_BIO_strings();
165         OpenSSL_add_all_algorithms();
166 }
167
168 void net_open(net_t *net, const char *host, int port)
169 {
170         int sock, flags;
171         struct addrinfo *addrs = NULL;
172         struct addrinfo hints = {};
173         char service[16];
174         int yes = 1, idle = 120;
175
176         net->host = strcopy(host);
177         net->port = port;
178
179         snprintf(service, sizeof(service), "%d", net->port);
180         hints.ai_family   = AF_INET;
181         hints.ai_socktype = SOCK_STREAM;
182
183         /* Setup address */
184         if (getaddrinfo(net->host, service, &hints, &addrs))
185                 error("Error getting net address info");
186
187         if ((sock = socket(addrs->ai_family,
188                            addrs->ai_socktype,
189                            addrs->ai_protocol)) < 0)
190                 error("Error opening net socket");
191
192         if ((flags = fcntl(sock, F_GETFL, 0)) < 0)
193                 error("Error getting net socket flags");
194
195         if (fcntl(sock, F_SETFL, flags|O_NONBLOCK) < 0)
196                 error("Error setting net socket non-blocking");
197
198         if (setsockopt(sock, SOL_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)) < 0)
199                 error("Error setting net socket keepidle");
200
201         if (setsockopt(sock, SOL_TCP, TCP_KEEPINTVL, &yes, sizeof(yes)) < 0)
202                 error("Error setting net socket keepintvl");
203
204         if (setsockopt(sock, SOL_TCP, TCP_KEEPCNT, &yes, sizeof(yes)) < 0)
205                 error("Error setting net socket keepcnt");
206
207         if (setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, &yes, sizeof(yes)) < 0)
208                 error("Error setting net socket keepalive");
209
210         if (connect(sock, addrs->ai_addr, addrs->ai_addrlen) < 0)
211                 if (errno != EINPROGRESS)
212                         error("Error connecting socket");
213
214         freeaddrinfo(addrs);
215
216         /* Setup server */
217         net->state = NET_CONNECT;
218         poll_add(&net->poll, sock, on_poll, net);
219 }
220
221 void net_encrypt(net_t *net)
222 {
223         debug("net: encrypt");
224
225         net->ctx = SSL_CTX_new(TLSv1_2_client_method());
226         net->ssl = SSL_new(net->ctx);
227
228         net->in  = BIO_new(BIO_s_mem());
229         net->out = BIO_new(BIO_s_mem());
230
231         BIO_set_mem_eof_return(net->in,  -1);
232         BIO_set_mem_eof_return(net->out, -1);
233
234         SSL_set_bio(net->ssl, net->in, net->out);
235         SSL_set_connect_state(net->ssl);
236
237         net->state = NET_ENCRYPT;
238 }
239
240 int net_send(net_t *net, const char *buf, int len)
241 {
242         if (net->out_len)
243                 return 0;
244
245         debug("net: send");
246
247         if (len <= 0)
248                 return 0;
249         if (len > NET_BUFFER)
250                 len = NET_BUFFER;
251         memcpy(net->out_buf, buf, len);
252
253         net->out_len = len;
254         net->out_pos = 0;
255
256         return flush(net);
257 }
258
259 int net_print(net_t *net, const char *fmt, ...)
260 {
261         int len;
262         va_list ap;
263         if (net->out_len)
264                 return 0;
265
266         va_start(ap, fmt);
267         len = vsnprintf(net->out_buf, NET_BUFFER, fmt, ap);
268         va_end(ap);
269         if (len <= 0)
270                 return 0;
271         if (len > NET_BUFFER)
272                 len = NET_BUFFER;
273
274         if (net->out_buf[len-1] == '\n')
275                 debug("net: print [%.*s]", len-1, net->out_buf);
276         else
277                 debug("net: print [%.*s]", len, net->out_buf);
278
279         net->out_len = len;
280         net->out_pos = 0;
281
282         return flush(net);
283 }
284
285 void net_close(net_t *net)
286 {
287         debug("net_close: %s:%d",
288                 net->host, net->port);
289         net->state = NET_CLOSED;
290         poll_del(&net->poll);
291 }