/* * Copyright (C) 2017 Andy Spencer * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #define _GNU_SOURCE #include #include #include #include #include #include #include #include #include #include #include #ifdef USE_OPENSSL #include #include #endif #ifdef USE_GNUTLS #include #include #endif #include "util.h" #include "net.h" /* Crypto types */ struct crypto_t { #ifdef USE_OPENSSL SSL_CTX *ctx; SSL *ssl; BIO *in; BIO *out; #endif #ifdef USE_GNUTLS gnutls_session_t tls; #endif }; /* Local data */ #ifdef USE_GNUTLS static gnutls_certificate_credentials_t xcred; #endif /* Local functions */ static int flush(net_t *net) { #ifdef USE_OPENSSL static char buf[NET_BUFFER]; #endif int len; /* State machine */ if (net->state == NET_READY) { if (net->out_len > 0) { debug("net: flush plain"); len = send(net->poll.fd, &net->out_buf[net->out_pos], net->out_len-net->out_pos, 0); if (len > 0) net->out_pos += len; if (net->out_pos == net->out_len) { net->out_pos = 0; net->out_len = 0; } } } if (net->state == NET_ENCRYPTED) { #ifdef USE_OPENSSL if (net->out_len > 0) { debug("net: flush crypto"); len = SSL_write(net->crypto->ssl, &net->out_buf[net->out_pos], net->out_len-net->out_pos); if (len > 0) net->out_pos += len; if (net->out_pos == net->out_len) { net->out_pos = 0; net->out_len = 0; } } while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0) send(net->poll.fd, buf, len, 0); #endif #ifdef USE_GNUTLS if (net->out_len > 0) { len = gnutls_record_send(net->crypto->tls, &net->out_buf[net->out_pos], net->out_len-net->out_pos); if (len > 0) net->out_pos += len; if (net->out_pos == net->out_len) { net->out_pos = 0; net->out_len = 0; } } #endif } return 1; } static void on_poll(void *_net) { static char buf[NET_BUFFER]; net_t *net = _net; int len, err = 0, done = 0; while (!done) { /* Handle Errors */ socklen_t elen = sizeof(err); if (getsockopt(net->poll.fd, SOL_SOCKET, SO_ERROR, &err, &elen)) error("Error getting socket opt"); if (err) { debug("Socket error: %s", strerror(err)); net_close(net); } /* State machine */ if (net->state == NET_CLOSED) { done = 1; } if (net->state == NET_CONNECT) { debug("net: connect"); net->state = NET_READY; flush(net); } if (net->state == NET_READY) { debug("net: ready"); len = recv(net->poll.fd, buf, sizeof(buf), 0); if (len < 0) done = 1; if (len == 0) net_close(net); net->recv(net->data, buf, len); } if (net->state == NET_ENCRYPT) { debug("net: encrypt"); net->state = NET_HANDSHAKE; } if (net->state == NET_HANDSHAKE) { debug("net: handshake"); #ifdef USE_OPENSSL len = recv(net->poll.fd, buf, sizeof(buf), 0); if (len < 0) done = 1; if (len == 0) net_close(net); if (len > 0) BIO_write(net->in, buf, len); SSL_do_handshake(net->crypto->ssl); while ((len = BIO_read(net->out, buf, sizeof(buf))) > 0) send(net->poll.fd, buf, len, 0); if (SSL_is_init_finished(net->crypto->ssl)) net->state = NET_ENCRYPTED; #endif #ifdef USE_GNUTLS err = gnutls_handshake(net->crypto->tls); if (err == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) { gnutls_datum_t out; int type = gnutls_certificate_type_get(net->crypto->tls); unsigned status = gnutls_session_get_verify_cert_status(net->crypto->tls); gnutls_certificate_verification_status_print(status, type, &out, 0); debug("Cert verify failed: %s", out.data); gnutls_free(out.data); net_close(net); } else if (err < 0) { debug("Handshake failed: %s", gnutls_strerror(err)); net_close(net); } else { char *desc = gnutls_session_get_desc(net->crypto->tls); debug("Session info: %s", desc); gnutls_free(desc); net->state = NET_ENCRYPTED; } #endif flush(net); } if (net->state == NET_ENCRYPTED) { debug("net: encrypted"); #ifdef USE_OPENSSL len = recv(net->poll.fd, buf, sizeof(buf), 0); if (len < 0) done = 1; if (len == 0) net_close(net); if (len > 0) BIO_write(net->crypto->in, buf, len); while ((len = SSL_read(net->crypto->ssl, buf, sizeof(buf))) > 0) net->recv(net->data, buf, len); #endif #ifdef USE_GNUTLS len = gnutls_record_recv(net->crypto->tls, buf, sizeof(buf)); if (len < 0) done = 1; if (len == 0) net_close(net); if (len > 0) net->recv(net->data, buf, len); #endif } } } /* Networking functions */ const char *get_hostname(void) { static char hostname[512]; if (gethostname(hostname, sizeof(hostname))) error("Error getting hostname"); return hostname; } void net_init(void) { #ifdef USE_OPENSSL SSL_library_init(); SSL_load_error_strings(); ERR_load_BIO_strings(); OpenSSL_add_all_algorithms(); #endif #ifdef USE_GNUTLS gnutls_global_init(); gnutls_certificate_allocate_credentials(&xcred); gnutls_certificate_set_x509_trust_file(xcred, "/etc/ssl/certs/ca-certificates.crt", GNUTLS_X509_FMT_PEM); #endif } void net_open(net_t *net, const char *host, int port) { int sock, flags; struct addrinfo *addrs = NULL; struct addrinfo hints = {}; char service[16]; int yes = 1, idle = 120; net->host = strcopy(host); net->port = port; snprintf(service, sizeof(service), "%d", net->port); hints.ai_family = AF_INET; hints.ai_socktype = SOCK_STREAM; /* Setup address */ if (getaddrinfo(net->host, service, &hints, &addrs)) error("Error getting net address info"); if ((sock = socket(addrs->ai_family, addrs->ai_socktype, addrs->ai_protocol)) < 0) error("Error opening net socket"); if ((flags = fcntl(sock, F_GETFL, 0)) < 0) error("Error getting net socket flags"); if (fcntl(sock, F_SETFL, flags|O_NONBLOCK) < 0) error("Error setting net socket non-blocking"); if (setsockopt(sock, SOL_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)) < 0) error("Error setting net socket keepidle"); if (setsockopt(sock, SOL_TCP, TCP_KEEPINTVL, &yes, sizeof(yes)) < 0) error("Error setting net socket keepintvl"); if (setsockopt(sock, SOL_TCP, TCP_KEEPCNT, &yes, sizeof(yes)) < 0) error("Error setting net socket keepcnt"); if (setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, &yes, sizeof(yes)) < 0) error("Error setting net socket keepalive"); if (connect(sock, addrs->ai_addr, addrs->ai_addrlen) < 0) if (errno != EINPROGRESS) error("Error connecting socket"); freeaddrinfo(addrs); /* Setup server */ net->state = NET_CONNECT; poll_add(&net->poll, sock, on_poll, net); } void net_encrypt(net_t *net, int flags) { debug("net: encrypt"); #if defined(USE_OPENSSL) net->crypto = new0(crypto_t); net->crypto->ctx = SSL_CTX_new(TLSv1_2_client_method()); net->crypto->ssl = SSL_new(net->ctx); net->crypto->in = BIO_new(BIO_s_mem()); net->crypto->out = BIO_new(BIO_s_mem()); BIO_set_mem_eof_return(net->crypto->in, -1); BIO_set_mem_eof_return(net->crypto->out, -1); SSL_set_bio(net->crypto->ssl, net->crypto->in, net->crypto->out); SSL_set_connect_state(net->crypto->ssl); #elif defined(USE_GNUTLS) net->crypto = new0(crypto_t); gnutls_init(&net->crypto->tls, GNUTLS_CLIENT); gnutls_set_default_priority(net->crypto->tls); gnutls_server_name_set(net->crypto->tls, GNUTLS_NAME_DNS, net->host, strlen(net->host)); gnutls_credentials_set(net->crypto->tls, GNUTLS_CRD_CERTIFICATE, xcred); gnutls_handshake_set_timeout(net->crypto->tls, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); gnutls_transport_set_int(net->crypto->tls, net->poll.fd); if (!(flags & NET_NOVERIFY)) gnutls_session_set_verify_cert(net->crypto->tls, net->host, 0); #else error("Encryption is not supported"); #endif net->state = NET_ENCRYPT; } int net_send(net_t *net, const char *buf, int len) { if (net->out_len) return 0; debug("net: send"); if (len <= 0) return 0; if (len > NET_BUFFER) len = NET_BUFFER; memcpy(net->out_buf, buf, len); net->out_len = len; net->out_pos = 0; return flush(net); } int net_print(net_t *net, const char *fmt, ...) { int len; va_list ap; if (net->out_len) return 0; va_start(ap, fmt); len = vsnprintf(net->out_buf, NET_BUFFER, fmt, ap); va_end(ap); if (len <= 0) return 0; if (len > NET_BUFFER) len = NET_BUFFER; if (net->out_buf[len-1] == '\n') debug("net: print [%.*s]", len-1, net->out_buf); else debug("net: print [%.*s]", len, net->out_buf); net->out_len = len; net->out_pos = 0; return flush(net); } void net_close(net_t *net) { debug("net_close: %s:%d", net->host, net->port); net->state = NET_CLOSED; poll_del(&net->poll); }