#define TCP_BODY
#include "basictcp.h"
#include <sys/socket.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <errno.h>
#include <stdarg.h>
#include "postal.h"
#include "userlist.h"
#include "address.h"
#include "logit.h"
#include "results.h"
#ifndef USE_OPENSSL
// Definitions of the static members shared by all base_tcp objects in a
// non-OpenSSL build. m_init_dh_params starts at 0; ConnectTLS() sets it to 1
// the first time it generates m_dh_params, so later calls reuse the same
// parameters instead of regenerating them.
// NOTE(review): the guard is "not OpenSSL" rather than "GnuTLS", so these
// lines are also compiled when USE_SSL is undefined -- confirm the header
// declares the members (and pulls in the GnuTLS types) under the same guard.
int base_tcp::m_init_dh_params = 0;
gnutls_dh_params_t base_tcp::m_dh_params;
#endif
// Builds a connection object around an existing descriptor.
//   fd    - socket descriptor; stored in m_sock and registered in m_poll
//   log   - stored in m_log
//   debug - stored in m_debug; when non-NULL the send/receive paths copy
//           traffic to it
//   res   - stored in m_res; receivedData() reports byte counts to it
//   ssl   - (USE_SSL builds only) stored in m_useTLS
// The object starts in the "open" state with an empty read buffer
// (m_start == m_end == 0) and with TLS not yet active.
base_tcp::base_tcp(int fd, Logit *log, Logit *debug, results *res
#ifdef USE_SSL
                 , int ssl
#endif
                  )
  :
#ifdef USE_SSL
    m_canTLS(false),
    m_useTLS(ssl),
#endif
    m_sock(fd),
    m_start(0),
    m_end(0),
    m_open(true),
    m_log(log),
    m_debug(debug),
    m_res(res)
#ifdef USE_SSL
#ifdef USE_OPENSSL
  , m_sslMeth(NULL)
  , m_sslCtx(NULL)
  , m_ssl(NULL)
#else
  , m_gnutls_session(NULL)
#endif
  , m_isTLS(false)
#endif
{
  m_poll.fd = m_sock;

#if defined(USE_SSL) && defined(USE_OPENSSL)
  // The OpenSSL setup calls are made only when TLS was requested.
  if(m_useTLS)
  {
    //don't seem to need this SSL_library_init();
    SSLeay_add_ssl_algorithms();
    SSL_load_error_strings();
  }
#endif
}
// Destructor. Does nothing itself: the descriptor and any TLS state are
// released by disconnect(), not here.
base_tcp::~base_tcp() {}
#ifdef USE_SSL
#ifdef USE_GNUTLS
// Size in bits of the Diffie-Hellman parameters; passed to both
// gnutls_dh_set_prime_bits() and gnutls_dh_params_generate2() below.
#define DH_BITS 1024
// Creates m_gnutls_session as a GNUTLS_SERVER session that only permits
// anonymous Diffie-Hellman key exchange, attaches the anonymous credentials
// held in m_anoncred and requests DH_BITS-bit primes.
// Return values of the GnuTLS calls are not checked.
void base_tcp::m_initialize_tls_session()
{
  // Zero-terminated list of the key exchange methods that are allowed.
  static const int anonDhOnly[] = { GNUTLS_KX_ANON_DH, 0 };

  gnutls_init(&m_gnutls_session, GNUTLS_SERVER);

  // Start from the library's default priorities and then narrow only the
  // key exchange list; the remaining defaults are adequate.
  gnutls_set_default_priority(m_gnutls_session);
  gnutls_kx_set_priority(m_gnutls_session, anonDhOnly);

  gnutls_credentials_set(m_gnutls_session, GNUTLS_CRD_ANON, m_anoncred);
  gnutls_dh_set_prime_bits(m_gnutls_session, DH_BITS);
}
// Initialises the static m_dh_params and fills it with newly generated
// DH_BITS-bit Diffie-Hellman parameters for the DHE key exchange algorithms.
// Policy note: such parameters should be thrown away and regenerated
// periodically (daily, weekly or monthly, depending on how strict the
// security requirements are).
void base_tcp::m_generate_dh_params()
{
  gnutls_dh_params_init(&m_dh_params);
  gnutls_dh_params_generate2(m_dh_params, DH_BITS);
}
#endif // USE_GNUTLS
int base_tcp::ConnectTLS()
{
#ifdef USE_OPENSSL
m_sslMeth = NULL;
m_sslCtx = NULL;
m_ssl = NULL;
m_sslMeth = SSLv2_client_method();
if(m_sslMeth == NULL)
{
fprintf(stderr, "Can't get SSLv2_client_method.\n");
return 2;
}
m_sslCtx = SSL_CTX_new(m_sslMeth);
if(m_sslCtx == NULL)
{
fprintf(stderr, "Can't SSL_CTX_new\n");
return 2;
}
if((m_ssl = SSL_new(m_sslCtx)) == NULL)
{
fprintf(stderr, "Can't SSL_new\n");
SSL_CTX_free(m_sslCtx);
return 2;
}
SSL_set_fd(m_ssl, m_sock);
if(-1 == SSL_connect(m_ssl))
{
fprintf(stderr, "Can't SSL_CONNECT\n");
SSL_free(m_ssl);
SSL_CTX_free(m_sslCtx);
return 1;
}
m_isTLS = true;
// debugging code that may be useful to have around in a commented-out state.
#if 0
/* Following two steps are optional and not required for
data exchange to be successful. */
/* Get the cipher - opt */
printf ("SSL connection using %s\n", SSL_get_cipher(m_ssl));
/* Get server's certificate (note: beware of dynamic allocation) - opt */
X509 *server_cert;
server_cert = SSL_get_peer_certificate(m_ssl);
if(!server_cert)
{
fprintf(stderr, "Can't SSL_get_peer_certificate\n");
return 2;
}
printf ("Server certificate:\n");
char *str = X509_NAME_oneline(X509_get_subject_name(server_cert),0,0);
if(!str)
{
fprintf(stderr, "Can't X509_NAME_oneline\n");
return 2;
}
printf ("\t subject: %s\n", str);
Free (str);
str = X509_NAME_oneline (X509_get_issuer_name(server_cert),0,0);
if(!str)
{
fprintf(stderr, "Can't X509_get_issuer_name\n");
return 2;
}
printf ("\t issuer: %s\n", str);
Free (str);
/* We could do all sorts of certificate verification stuff here before
deallocating the certificate. */
X509_free(server_cert);
#endif // 0
#else
gnutls_anon_allocate_server_credentials(&m_anoncred);
m_initialize_tls_session();
if(!m_init_dh_params)
{
m_init_dh_params = 1;
m_generate_dh_params();
}
gnutls_anon_set_server_dh_params(m_anoncred, m_dh_params);
gnutls_transport_set_ptr(m_gnutls_session, (gnutls_transport_ptr_t)m_sock);
int rc = gnutls_handshake(m_gnutls_session);
if(rc < 0)
{
gnutls_deinit(m_gnutls_session);
return 2;
}
m_isTLS = 1;
/* request client certificate if any.
*/
gnutls_certificate_server_set_request(m_gnutls_session, GNUTLS_CERT_REQUEST);
#endif // USE_OPENSSL
return 0;
}
#endif // USE_SSL
// Closes the connection if it is still open and marks the object closed.
// When TLS is active the TLS layer is shut down first and the objects created
// by ConnectTLS() are released. Calling it again on a closed object does
// nothing. Always returns 0.
int base_tcp::disconnect()
{
  if(m_open)
  {
#ifdef USE_SSL
    if(m_isTLS)
    {
#ifdef USE_OPENSSL
      SSL_shutdown(m_ssl);
      close(m_sock);
      SSL_free(m_ssl);
      SSL_CTX_free(m_sslCtx);
      m_ssl = NULL;
      m_sslCtx = NULL;
#else
      // This branch used to be empty, which left the descriptor open and
      // leaked the session and credentials whenever TLS was in use.
      // GNUTLS_SHUT_WR sends the close notification without waiting for the
      // peer's reply.
      gnutls_bye(m_gnutls_session, GNUTLS_SHUT_WR);
      close(m_sock);
      gnutls_deinit(m_gnutls_session);
      m_gnutls_session = NULL;
      gnutls_anon_free_server_credentials(m_anoncred);
#endif
      m_isTLS = false;
    }
    else
#endif
    {
      close(m_sock);
    }
  }
  m_open = false;
  return 0;
}
// Formats a message printf-style and sends it with sendData().
//   fmt, ... - format string and arguments as for vsnprintf()
// The formatted text is limited to sizeof(buf) - 1 characters; anything
// longer is truncated. The terminating NUL is never transmitted.
// Returns eCorrupt if formatting fails, otherwise whatever sendData() returns.
ERROR_TYPE base_tcp::printf(CPCCHAR fmt, ...)
{
  char buf[1024];

  va_list argp;
  va_start(argp, fmt);
  int len = vsnprintf(buf, sizeof(buf), fmt, argp);
  va_end(argp);

  // A negative result means the text could not be formatted; passing it on
  // as a size would make sendData() attempt a bogus write.
  if(len < 0)
    return eCorrupt;

  // vsnprintf() reports the length the full text would have had. On
  // truncation only sizeof(buf) - 1 characters are in buf, followed by the
  // NUL, which must not be sent as payload.
  if(len >= (int)sizeof(buf))
    len = (int)sizeof(buf) - 1;

  return sendData(buf, len);
}
// Writes exactly size bytes from buf to the peer, through the TLS layer when
// m_isTLS is set and with write() otherwise.
//   buf  - data to send
//   size - number of bytes to send
// Before every write it waits up to 60 seconds for the descriptor to become
// writable. Every chunk that is written is also copied to m_debug when a
// debug log is present, and sentData(size) is called once all data is out.
// Returns eNoError on success, eCorrupt if the connection is not open,
// eTimeout if the wait expires, eSocket on a poll or write failure.
ERROR_TYPE base_tcp::sendData(CPCCHAR buf, int size)
{
  if(!m_open)
    return eCorrupt;
  int sent = 0;
  m_poll.events = POLLOUT | POLLERR | POLLHUP;
  int rc;
  while(sent != size)
  {
    rc = poll(&m_poll, 1, 60000);
    if(rc == 0)
    {
      fprintf(stderr, "Server timed out on write.\n");
      return eTimeout;
    }
    if(rc < 0)
    {
      fprintf(stderr, "Poll error.\n");
      return eSocket;
    }
#ifdef USE_SSL
    if(m_isTLS)
    {
#ifdef USE_OPENSSL
      rc = SSL_write(m_ssl, &buf[sent], size - sent);
#else
      rc = gnutls_record_send(m_gnutls_session, &buf[sent], size - sent);
#endif
    }
    else
#endif
    {
      rc = write(m_sock, &buf[sent], size - sent);
    }
    if(rc < 1)
    {
//      fprintf(stderr, "Can't write to socket.\n");
      return eSocket;
    }
    // Log the chunk that was just written. Starting at buf[0] here would
    // repeat the beginning of the message after a partial write.
    if(m_debug)
      m_debug->Write(&buf[sent], rc);
    sent += rc;
  }
  sentData(size);
  return eNoError;
}
// Finishes a line that readLine() has copied into buf and returns its length.
// When the line ended before buf filled up (ind < bufSize) the trailing '\n'
// is replaced by a NUL and, if stripCR is set, a preceding '\r' is removed
// too. When buf is completely full nothing is changed and no NUL is added.
static int terminateReceivedLine(char *buf, int ind, int bufSize, bool stripCR)
{
  if(ind < bufSize)
  {
    ind--;
    buf[ind] = '\0';
    // ind > 0 keeps a line consisting of just "\n" from reading buf[-1].
    if(stripCR && ind > 0 && buf[ind - 1] == '\r')
    {
      ind--;
      buf[ind] = '\0';
    }
  }
  return ind;
}

// Reads one line from the connection into buf.
//   buf      - destination buffer
//   bufSize  - capacity of buf in bytes; must be at least 1
//   stripCR  - also remove a '\r' that directly precedes the '\n'
//   timeout  - total number of seconds to wait for data from the peer
// Bytes left in m_buf by an earlier call are used first; more data is read
// from the socket (or TLS layer) only when they do not complete the line.
// Returns the length of the line after stripping. If bufSize bytes arrive
// without a newline the buffer is returned full, NOT NUL-terminated, and the
// return value is bufSize. Failures return an ERROR_TYPE code: eCorrupt
// (connection not open, invalid bufSize, poll failure), eTimeout, or eSocket
// (read failure or connection closed by the peer). A partial line already
// copied into buf is discarded when a failure is returned.
// NOTE(review): callers presumably tell these codes from lengths by sign --
// confirm against the ERROR_TYPE definition.
int base_tcp::readLine(char *buf, int bufSize, bool stripCR, int timeout)
{
  if(!m_open)
    return eCorrupt;
  // The copy below stores at least one byte, so a buffer without room for it
  // cannot be used.
  if(bufSize < 1)
    return eCorrupt;

  int ind = 0;
  bool waiting = false;   // set once the timeout clock has been started
  time_t began = 0;
  while(1)
  {
    // Hand buffered bytes to the caller, stopping right after a newline or
    // when buf is full.
    while(m_start < m_end && ind < bufSize)
    {
      buf[ind] = m_buf[m_start];
      ind++;
      m_start++;
      if(buf[ind - 1] == '\n')
        break;
    }
    if(ind == bufSize || (ind > 0 && buf[ind - 1] == '\n') )
    {
      // Account for and log the raw bytes before the terminator is stripped.
      receivedData(ind);
      if(m_debug)
        m_debug->Write(buf, ind);
      return terminateReceivedLine(buf, ind, bufSize, stripCR);
    }

    // buffer is empty
    m_start = 0;
    m_end = 0;
    if(!waiting)
    {
      // The timeout covers the whole wait for this line, not each poll.
      waiting = true;
      began = time(NULL);
      m_poll.events = POLLIN | POLLERR | POLLHUP;
    }
    int tmo = timeout - (time(NULL) - began);
    int rc;
    if(tmo < 0 || (rc = poll(&m_poll, 1, tmo * 1000)) == 0)
    {
      return eTimeout;
    }
    if(rc < 0)
    {
      fprintf(stderr, "Poll error.\n");
      return eCorrupt;
    }
#ifdef USE_SSL
    if(m_isTLS)
    {
#ifdef USE_OPENSSL
      rc = SSL_read(m_ssl, m_buf, sizeof(m_buf));
#else
      rc = gnutls_record_recv(m_gnutls_session, m_buf, sizeof(m_buf));
#endif
    }
    else
#endif
    {
      rc = read(m_sock, m_buf, sizeof(m_buf));
    }
    // 0 means the peer closed the connection. Treating it as data would copy
    // stale bytes out of m_buf and push m_start past m_end.
    if(rc <= 0)
      return eSocket;
    m_end = rc;
  }
  return 0; // never reached
}
// Called by sendData() with the byte count once a buffer has been written
// completely. This implementation deliberately ignores the count.
void base_tcp::sentData(int /* bytes */) {}
// Called by readLine() with the number of raw bytes in each line it returns;
// forwards that count to the results object supplied at construction.
void base_tcp::receivedData(int bytes)
{
  m_res->dataBytes(bytes);
}
// syntax highlighted by Code2HTML, v. 0.9.1