]> Pileus Git - ~andy/lamechat/blob - net.c
Split chat by channels.
[~andy/lamechat] / net.c
1 /*
2  * Copyright (C) 2017 Andy Spencer <andy753421@gmail.com>
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
16  */
17
18 #define _GNU_SOURCE
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <stdarg.h>
23 #include <errno.h>
24 #include <unistd.h>
25 #include <fcntl.h>
26 #include <netdb.h>
27
28 #include <openssl/bio.h>
29 #include <openssl/ssl.h>
30
31 #include "util.h"
32 #include "net.h"
33
34 /* Local functions */
35 static int flush(net_t *net)
36 {
37         static char buf[NET_BUFFER];
38         int len;
39
40         /* State machine */
41         if (net->state == NET_READY) {
42                 if (net->out_len > 0) {
43                         debug("net: flush plain");
44                         len = send(net->poll.fd, &net->out_buf[net->out_pos],
45                                        net->out_len-net->out_pos, 0);
46                         if (len > 0)
47                                 net->out_pos += len;
48                         if (net->out_pos == net->out_len) {
49                                 net->out_pos = 0;
50                                 net->out_len = 0;
51                         }
52                 }
53         }
54         if (net->state == NET_ENCRYPTED) {
55                 if (net->out_len > 0) {
56                         debug("net: flush crypto");
57                         len = SSL_write(net->ssl,
58                                         &net->out_buf[net->out_pos],
59                                         net->out_len-net->out_pos);
60                         if (len > 0)
61                                 net->out_pos += len;
62                         if (net->out_pos == net->out_len) {
63                                 net->out_pos = 0;
64                                 net->out_len = 0;
65                         }
66                 }
67
68                 while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0)
69                         send(net->poll.fd, buf, len, 0);
70         }
71
72         /* Enable output poll */
73         if ((net->out_len > 0) &&
74             (net->state == NET_READY ||
75              net->state == NET_ENCRYPTED))
76                 poll_ctl(&net->poll, 1, 1, 1);
77         else
78                 poll_ctl(&net->poll, 1, 0, 1);
79
80         return 1;
81 }
82
83 static void on_poll(void *_net)
84 {
85         static char buf[NET_BUFFER];
86         net_t *net = _net;
87         int len;
88
89         /* Handle Errors */
90         int err = 0;
91         socklen_t elen = sizeof(err);
92         if (getsockopt(net->poll.fd, SOL_SOCKET, SO_ERROR, &err, &elen))
93                 error("Error getting socket opt");
94         if (err)
95                 net_close(net);
96
97         /* State machine */
98         if (net->state == NET_CONNECT) {
99                 debug("net: connect");
100                 net->state = NET_READY;
101         }
102         if (net->state == NET_READY) {
103                 debug("net: ready");
104                 len = recv(net->poll.fd, buf, sizeof(buf), 0);
105                 if (len == 0)
106                         net_close(net);
107                 net->recv(net->data, buf, len);
108         }
109         if (net->state == NET_HANDSHAKE) {
110                 debug("net: handshake");
111                 len = recv(net->poll.fd, buf, sizeof(buf), 0);
112                 if (len == 0)
113                         net_close(net);
114                 if (len > 0)
115                         BIO_write(net->in, buf, len);
116
117                 SSL_do_handshake(net->ssl);
118
119                 while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0)
120                         send(net->poll.fd, buf, len, 0);
121
122                 if (SSL_is_init_finished(net->ssl))
123                         net->state = NET_ENCRYPTED;
124                 flush(net);
125         }
126         if (net->state == NET_ENCRYPTED) {
127                 debug("net: encrypted");
128                 len = recv(net->poll.fd, buf, sizeof(buf), 0);
129                 if (len == 0)
130                         net_close(net);
131                 if (len > 0)
132                         BIO_write(net->in, buf, len);
133
134                 while ((len = SSL_read(net->ssl, buf, sizeof(buf))) > 0)
135                         net->recv(net->data, buf, len);
136         }
137 }
138
139 /* Networking functions */
140 const char *get_hostname(void)
141 {
142         static char hostname[512];
143
144         if (gethostname(hostname, sizeof(hostname)))
145                 error("Error getting hostname");
146
147         return hostname;
148 }
149
150 void net_init(void)
151 {
152         SSL_library_init();
153         SSL_load_error_strings();
154         ERR_load_BIO_strings();
155         OpenSSL_add_all_algorithms();
156 }
157
158 void net_open(net_t *net, const char *host, int port)
159 {
160         int sock, flags;
161         struct addrinfo *addrs = NULL;
162         struct addrinfo hints = {};
163         char service[16];
164
165         net->host = strcopy(host);
166         net->port = port;
167
168         snprintf(service, sizeof(service), "%d", net->port);
169         hints.ai_family   = AF_INET;
170         hints.ai_socktype = SOCK_STREAM;
171
172         /* Setup address */
173         if (getaddrinfo(net->host, service, &hints, &addrs))
174                 error("Error getting net address info");
175
176         if ((sock = socket(addrs->ai_family,
177                            addrs->ai_socktype,
178                            addrs->ai_protocol)) < 0)
179                 error("Error opening net socket");
180
181         if ((flags = fcntl(sock, F_GETFL, 0)) < 0)
182                 error("Error getting net socket flags");
183
184         if (fcntl(sock, F_SETFL, flags|O_NONBLOCK) < 0)
185                 error("Error setting net socket non-blocking");
186
187         if (connect(sock, addrs->ai_addr, addrs->ai_addrlen) < 0)
188                 if (errno != EINPROGRESS)
189                         error("Error connecting socket");
190
191         freeaddrinfo(addrs);
192
193         /* Setup server */
194         net->state = NET_CONNECT;
195         poll_add(&net->poll, sock, on_poll, net);
196 }
197
198 void net_encrypt(net_t *net)
199 {
200         debug("net: encrypt");
201
202         net->ctx = SSL_CTX_new(TLSv1_2_client_method());
203         net->ssl = SSL_new(net->ctx);
204
205         net->in  = BIO_new(BIO_s_mem());
206         net->out = BIO_new(BIO_s_mem());
207
208         BIO_set_mem_eof_return(net->in,  -1);
209         BIO_set_mem_eof_return(net->out, -1);
210
211         SSL_set_bio(net->ssl, net->in, net->out);
212         SSL_set_connect_state(net->ssl);
213
214         net->state = NET_HANDSHAKE;
215 }
216
217 int net_send(net_t *net, const char *buf, int len)
218 {
219         if (net->out_len)
220                 return 0;
221
222         debug("net: send");
223
224         if (len <= 0)
225                 return 0;
226         if (len > NET_BUFFER)
227                 len = NET_BUFFER;
228         memcpy(net->out_buf, buf, len);
229
230         net->out_len = len;
231         net->out_pos = 0;
232
233         return flush(net);
234 }
235
236 int net_print(net_t *net, const char *fmt, ...)
237 {
238         int len;
239         va_list ap;
240         if (net->out_len)
241                 return 0;
242
243         va_start(ap, fmt);
244         len = vsnprintf(net->out_buf, NET_BUFFER, fmt, ap);
245         va_end(ap);
246         if (len <= 0)
247                 return 0;
248         if (len > NET_BUFFER)
249                 len = NET_BUFFER;
250
251         if (net->out_buf[len-1] == '\n')
252                 debug("net: print [%.*s]", len-1, net->out_buf);
253         else
254                 debug("net: print [%.*s]", len, net->out_buf);
255
256         net->out_len = len;
257         net->out_pos = 0;
258
259         return flush(net);
260 }
261
262 void net_close(net_t *net)
263 {
264         debug("net_close: %s:%d",
265                 net->host, net->port);
266         net->state = NET_CLOSED;
267         poll_del(&net->poll);
268 }