#define TCP_BODY

#include "tcp.h"

#include <sys/socket.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <errno.h>
#include "postal.h"
#include "userlist.h"
#include "address.h"
#include "logit.h"

// Construct the primary connection object (the Thread base is
// default-constructed here, unlike the per-thread constructor
// tcp(int, const tcp *)).
//   exitCount  - shared shutdown counter: any non-zero value makes Connect()
//                refuse to start; values > 1 abort sendData()/readLine().
//   addr       - destination host(s), wrapped in a new address object
//                together with default_port (NOTE(review): presumably the
//                port used when addr names none — confirm in address.h).
//   log        - stored in m_log as-is; ~tcp does not free it.
//   ssl        - (USE_SSL builds only) stored in m_useTLS; when non-zero the
//                OpenSSL algorithm tables and error strings are loaded.
//   sourceAddr - optional local address; if given, Connect() binds to one of
//                its addresses before connecting.
//   debug      - optional debug log; ~tcp deletes m_debug, so ownership of
//                this object passes to the tcp instance.
// m_addr and m_sourceAddr are allocated here and freed by ~tcp only when
// getThreadNum() < 1; per-thread copies share the same pointers.
tcp::tcp(int *exitCount, const char *addr, unsigned short default_port
       , Logit *log
#ifdef USE_SSL
       , int ssl
#endif
       , const char *sourceAddr, Logit *debug)
 : m_md5()
 , m_destAffinity(0) // computed in the body once m_addr exists
 , m_log(log)
#ifdef USE_SSL
 , m_canTLS(false)
 , m_useTLS(ssl)
#endif
 , m_exitCount(exitCount)
 , m_fd(-1)
 , m_start(0)
 , m_end(0)
 , m_open(false)
 , m_addr(new address(addr, default_port))
 , m_sourceAddr(NULL) // assigned in the body when sourceAddr is given
 , m_debug(debug)
#ifdef USE_SSL
#ifdef USE_OPENSSL
 , m_sslMeth(NULL)
 , m_sslCtx(NULL)
 , m_ssl(NULL)
#else
 , m_gnutls_session(NULL)
 , m_gnutls_anoncred(NULL)
#endif
 , m_isTLS(false)
#endif
{
  if(sourceAddr)
    m_sourceAddr = new address(sourceAddr);
  // Index of the destination address used by Connect(); presumably spreads
  // threads over the available addresses.
  // NOTE(review): divides by addressCount(); assumes it is never 0.
  m_destAffinity = getThreadNum() % m_addr->addressCount();
#ifdef USE_SSL
  if(m_useTLS)
  {
#ifdef USE_OPENSSL
    // One-time OpenSSL setup: register ciphers/digests and load
    // human-readable error strings.
//don't seem to need this    SSL_library_init();
    SSLeay_add_ssl_algorithms();
    SSL_load_error_strings();
#endif
  }
#endif
}

// Per-thread copy of `parent` for worker thread `threadNum`.
// Shares the parent's exit counter, result log and address objects. These
// are not owned: ~tcp frees them only when getThreadNum() < 1, so presumably
// threadNum is >= 1 for copies — confirm with callers.
// Copies the parent's TLS setting and destination affinity, and gets a
// private Logit built from the parent's debug log (if any).
// Connection state starts closed with an empty line buffer.
tcp::tcp(int threadNum, const tcp *parent)
 : Thread(threadNum, parent)
 , m_md5()
 , m_destAffinity(parent->m_destAffinity)
 , m_log(parent->m_log)
#ifdef USE_SSL
 , m_canTLS(false)
 , m_useTLS(parent->m_useTLS)
#endif
 , m_exitCount(parent->m_exitCount)
 , m_fd(-1)
 , m_start(0)
 , m_end(0)
 , m_open(false)
 , m_addr(parent->m_addr)
 , m_sourceAddr(parent->m_sourceAddr)
 , m_debug(parent->m_debug ? new Logit(*(parent->m_debug), threadNum) : NULL) // owned; deleted in ~tcp
#ifdef USE_SSL
#ifdef USE_OPENSSL
 , m_sslMeth(NULL)
 , m_sslCtx(NULL)
 , m_ssl(NULL)
#else
 , m_gnutls_session(NULL)
 , m_gnutls_anoncred(NULL)
#endif
 , m_isTLS(false)
#endif
{
}

// Close any open connection and release owned resources.
// The destination/source address objects are shared between the primary
// object and its per-thread copies (the copy constructor takes the parent's
// pointers), so only the object with getThreadNum() < 1 deletes them.
// The debug log is always owned by this object.
tcp::~tcp()
{
  disconnect();
  const bool ownsSharedAddresses = (getThreadNum() < 1);
  if(ownsSharedAddresses)
  {
    delete m_addr;
    delete m_sourceAddr;
  }
  delete m_debug; // deleting a NULL pointer is a no-op
}

// Open a TCP (IPv4) connection to the destination address selected by
// m_destAffinity. If a source address was configured, first bind to a
// random one of its addresses. Resets the line buffer and TLS flags for
// the new connection.
//   port - if non-zero, overrides the port stored in the destination address.
// Returns:
//   0 on success (m_open set);
//   1 if shutdown is pending (*m_exitCount != 0), no destination address
//     is available, or connect() failed;
//   2 if the socket could not be created or bound.
// Whenever a socket was created and then abandoned, it is closed and m_fd
// is reset to -1.
int tcp::Connect(short port)
{
  if(*m_exitCount)
    return 1;
#ifdef USE_SSL
  m_canTLS = false;
  m_isTLS = false;
#endif
  m_start = 0;
  m_end = 0;
  sockaddr *sa = m_addr->get_addr(m_destAffinity);
  if(!sa)
    return 1;
  m_fd = socket(PF_INET, SOCK_STREAM, 0);
  if(m_fd < 0)
  {
    fprintf(stderr, "Can't open socket.\n");
    error();
    return 2;
  }
  int rc;
  if(m_sourceAddr)
  {
    sockaddr *source = (sockaddr *)m_sourceAddr->get_rand_addr();
    rc = bind(m_fd, source, sizeof(struct sockaddr_in));
    if(rc)
    {
      fprintf(stderr, "Can't bind to port.\n");
      error();
      close(m_fd);
      m_fd = -1;
      return 2;
    }
  }
  m_poll.fd = m_fd;
  // Work on a private copy. The port override then never touches the shared
  // address list, and error reporting shows the port actually used.
  struct sockaddr_in target;
  memcpy(&target, sa, sizeof(target));
  if(port)
    target.sin_port = htons(port);
  rc = connect(m_fd, (sockaddr *)&target, sizeof(target));
  if(rc)
  {
    // Keep connect()'s errno intact for error(), which may report it.
    int savedErrno = errno;
    // inet_ntop() with a local buffer: inet_ntoa() returns a static buffer
    // that is unsafe when several worker threads report errors at once.
    char addrText[INET_ADDRSTRLEN];
    const char *shown = inet_ntop(AF_INET, &target.sin_addr, addrText, sizeof(addrText));
    if(!shown)
      shown = "?";
    fprintf(stderr, "Can't connect to %s port %d.\n"
                  , shown
                  , int(ntohs(target.sin_port)) );
    errno = savedErrno;
    error();
    close(m_fd);
    m_fd = -1;
    return 1;
  }
  socklen_t namelen = sizeof(m_connectionLocalAddr);
  rc = getsockname(m_fd, (struct sockaddr *)&m_connectionLocalAddr, &namelen);
  if(rc)
    fprintf(stderr, "Can't getsockname!\n");
  if(m_debug)
    m_debug->reopen();
  m_open = true;
  return 0;
}

#ifdef USE_SSL
// Negotiate TLS over the already-connected descriptor m_fd.
// Returns:
//   0 on success (m_isTLS is set, so sendData()/readLine() go through the
//     TLS session and disconnect() performs the TLS shutdown);
//   1 if the handshake failed;
//   2 if the library objects could not be created;
//   3 if shutdown was requested (*m_exitCount > 1) in the meantime.
// On every non-zero return, the TLS objects created here have already been
// released and their pointers cleared. m_isTLS stays false, so disconnect()
// will not touch them.
int tcp::ConnectTLS()
{
#ifdef USE_OPENSSL
  m_sslCtx = NULL;
  m_ssl = NULL;
  // SSLv23_client_method() negotiates the best protocol version both ends
  // support. The former SSLv2-only method is insecure, is refused by any
  // current server, and is no longer provided by OpenSSL 1.1.0 and later.
  m_sslMeth = SSLv23_client_method();
  if(m_sslMeth == NULL)
  {
    fprintf(stderr, "Can't get SSLv23_client_method.\n");
    error();
    return 2;
  }
  m_sslCtx = SSL_CTX_new(m_sslMeth);
  if(m_sslCtx == NULL)
  {
    fprintf(stderr, "Can't SSL_CTX_new\n");
    error();
    return 2;
  }
  if((m_ssl = SSL_new(m_sslCtx)) == NULL)
  {
    fprintf(stderr, "Can't SSL_new\n");
    SSL_CTX_free(m_sslCtx);
    m_sslCtx = NULL;
    error();
    return 2;
  }
  SSL_set_fd(m_ssl, m_fd);
  // SSL_connect() returns 1 on success, 0 for a controlled handshake
  // failure and < 0 for a fatal error; both failure kinds must abort.
  if(SSL_connect(m_ssl) <= 0)
  {
    fprintf(stderr, "Can't SSL_CONNECT\n");
    SSL_free(m_ssl);
    SSL_CTX_free(m_sslCtx);
    m_ssl = NULL;
    m_sslCtx = NULL;
    error();
    return 1;
  }
#else
  gnutls_anon_allocate_client_credentials(&m_gnutls_anoncred);
  // Freed (and deleted) on the failure paths below or in disconnect().
  m_gnutls_session = new gnutls_session_t;
  // Initialize TLS session
  gnutls_init (m_gnutls_session, GNUTLS_CLIENT);
  // Use default priorities
  gnutls_set_default_priority(*m_gnutls_session); // bug in gnutls interface?
  // Need to enable anonymous KX specifically
  const int kx_prio[] = { GNUTLS_KX_ANON_DH, 0 };
  gnutls_kx_set_priority(*m_gnutls_session, kx_prio);
  // put the anonymous credentials to the current session
  gnutls_credentials_set(*m_gnutls_session, GNUTLS_CRD_ANON, m_gnutls_anoncred);
  // Widen through long so the int descriptor converts cleanly to a pointer.
  gnutls_transport_set_ptr(*m_gnutls_session, (gnutls_transport_ptr_t)(long)m_fd);

  // Perform the TLS handshake, retrying on non-fatal errors (e.g. a signal
  // interrupting the underlying I/O).
  int hsResult;
  do
  {
    hsResult = gnutls_handshake(*m_gnutls_session);
  } while(hsResult < 0 && !gnutls_error_is_fatal(hsResult));
  if(hsResult < 0)
  {
    fprintf(stderr, "Can't gnutls_handshake\n");
    gnutls_deinit(*m_gnutls_session);
    delete m_gnutls_session;
    m_gnutls_session = NULL;
    gnutls_anon_free_client_credentials(m_gnutls_anoncred);
    m_gnutls_anoncred = NULL;
    error();
    return 1;
  }
#endif
  if(*m_exitCount > 1)
  {
    // Shutting down. m_isTLS stays false, so disconnect() would never free
    // the TLS state; release it here instead of leaking it.
#ifdef USE_OPENSSL
    SSL_free(m_ssl);
    SSL_CTX_free(m_sslCtx);
    m_ssl = NULL;
    m_sslCtx = NULL;
#else
    gnutls_deinit(*m_gnutls_session);
    delete m_gnutls_session;
    m_gnutls_session = NULL;
    gnutls_anon_free_client_credentials(m_gnutls_anoncred);
    m_gnutls_anoncred = NULL;
#endif
    return 3;
  }
  m_isTLS = true;

#if 0
// openssl debugging code that may be useful to have around
// in a commented-out state.
  /* Following two steps are optional and not required for
     data exchange to be successful. */
 
  /* Get the cipher - opt */
 
  printf ("SSL connection using %s\n", SSL_get_cipher(m_ssl));
 
  /* Get server's certificate (note: beware of dynamic allocation) - opt */
 
  X509 *server_cert;
  server_cert = SSL_get_peer_certificate(m_ssl);
  if(!server_cert)
  {
    fprintf(stderr, "Can't SSL_get_peer_certificate\n");
    return 2;
  }
  printf ("Server certificate:\n");
 
  char *str = X509_NAME_oneline(X509_get_subject_name(server_cert),0,0);
  if(!str)
  {
    fprintf(stderr, "Can't X509_NAME_oneline\n");
    return 2;
  }
  printf ("\t subject: %s\n", str);
  Free (str);
  str = X509_NAME_oneline (X509_get_issuer_name(server_cert),0,0);
  if(!str)
  {
    fprintf(stderr, "Can't X509_get_issuer_name\n");
    return 2;
  }
  printf ("\t issuer: %s\n", str);
  Free (str);
 
  /* We could do all sorts of certificate verification stuff here before
     deallocating the certificate. */
 
  X509_free(server_cert);
#endif 
  return 0;
}
#endif

// Close the connection if it is open. For a TLS connection, first send the
// TLS shutdown, then free the TLS objects (including the heap-allocated
// GnuTLS session handle) and clear their pointers so nothing dangles.
// m_fd is reset to -1 after closing. Safe to call when already closed.
// Always returns 0.
int tcp::disconnect()
{
  if(m_open)
  {
#ifdef USE_SSL
    if(m_isTLS)
    {
#ifdef USE_OPENSSL
      SSL_shutdown(m_ssl);
      close(m_fd);
      SSL_free(m_ssl);
      SSL_CTX_free(m_sslCtx);
      m_ssl = NULL;
      m_sslCtx = NULL;
#else
      gnutls_bye(*m_gnutls_session, GNUTLS_SHUT_RDWR);
      close(m_fd);
      gnutls_deinit(*m_gnutls_session);
      // The session handle itself was allocated with new in ConnectTLS().
      delete m_gnutls_session;
      m_gnutls_session = NULL;
      gnutls_anon_free_client_credentials(m_gnutls_anoncred);
      m_gnutls_anoncred = NULL;
#endif
      m_isTLS = false;
    }
    else
#endif
    {
      close(m_fd);
    }
    m_fd = -1;
  }
  m_open = false;
  return 0;
}

// Write all `size` bytes of buf to the connection, through the TLS session
// when m_isTLS is set. Before each write, wait up to 60 seconds for the
// socket to become writable. Every chunk written is copied to the debug log
// if one is configured, and sentData(size) is called once everything has
// gone out.
// Returns:
//   eNoError on success;
//   eCorrupt if not connected;
//   eCtrl_C when shutdown was requested (*m_exitCount > 1);
//   eTimeout if the socket stayed unwritable for 60 seconds;
//   eSocket on a poll or write failure.
ERROR_TYPE tcp::sendData(CPCCHAR buf, int size)
{
  if(!m_open)
    return eCorrupt;
  int sent = 0;
  m_poll.events = POLLOUT | POLLERR | POLLHUP | POLLNVAL;
  int rc;
  while(sent != size)
  {
    if(*m_exitCount > 1)
      return eCtrl_C;
    rc = poll(&m_poll, 1, 60000);
    if(rc == 0)
    {
      fprintf(stderr, "Server timed out on write.\n");
      error();
      return eTimeout;
    }
    if(rc < 0)
    {
      // Interrupted by a signal: loop so *m_exitCount is re-checked rather
      // than reporting a socket error (readLine() retries the same way).
      if(errno == EINTR)
        continue;
      fprintf(stderr, "Poll error.\n");
      error();
      return eSocket;
    }
#ifdef USE_SSL
    if(m_isTLS)
    {
#ifdef USE_OPENSSL
      rc = SSL_write(m_ssl, &buf[sent], size - sent);
#else
      rc = gnutls_record_send(*m_gnutls_session, &buf[sent], size - sent);
#endif
    }
    else
#endif
    {
      rc = write(m_fd, &buf[sent], size - sent);
      // A blocking write interrupted before sending anything: just retry.
      if(rc < 0 && errno == EINTR)
        continue;
    }
    if(rc < 1)
    {
      fprintf(stderr, "Can't write to socket.\n");
      error();
      return eSocket;
    }
    if(m_debug)
      m_debug->Write(buf, rc);
    sent += rc;
  }
  sentData(size);
  return eNoError;
}

// Read one line from the connection into buf: a hand-rolled fgets(), since
// the data arrives via read()/SSL_read()/gnutls_record_recv() on a raw
// descriptor rather than through a stdio FILE.
// Bytes are staged in m_buf and consumed from m_buf[m_start..m_end).
// Anything after the returned line stays buffered for the next call.
// Copying stops after a '\n' or once bufSize bytes are stored. buf is
// NUL-terminated only when fewer than bufSize bytes were stored.
// Returns the number of bytes stored on success. On failure it returns:
//   eCorrupt - not connected, or poll failure;
//   eSocket  - 60 s timeout, read error, or connection closed by the peer;
//   3        - shutdown requested (*m_exitCount > 1).
// NOTE(review): lengths and error codes share one int return value; confirm
// how callers distinguish a short line from an error code.
int tcp::readLine(char *buf, int bufSize)
{
  if(!m_open)
    return eCorrupt;
  int ind = 0;
  // First drain whatever is left over from a previous read.
  if(m_start < m_end)
  {
    do
    {
      buf[ind] = m_buf[m_start];
      ind++;
      m_start++;
    }
    while(m_start < m_end && m_buf[m_start - 1] != '\n' && ind < bufSize);
  }
  if(ind == bufSize || (ind > 0 && buf[ind - 1] == '\n') )
  {
    if(ind < bufSize)
      buf[ind] = '\0';
    receivedData(ind);
    if(m_debug)
      m_debug->Write(buf, ind);
    return ind;
  }
  // buffer is empty
  m_start = 0;
  m_end = 0;

  // One 60 second budget covers the rest of the line, however many reads
  // it takes.
  time_t now = time(NULL);
  m_poll.events = POLLIN | POLLERR | POLLHUP | POLLNVAL;
  while(1)
  {
    if(*m_exitCount > 1)
      return 3;
    int timeout = 60 - (time(NULL) - now);
    int rc;
    if(timeout < 0 || (rc = poll(&m_poll, 1, timeout * 1000)) == 0)
    {
      fprintf(stderr, "Server timed out on read.\n");
      error();
      return eSocket;
    }
    if(rc < 0)
    {
      if(errno == EINTR)
        continue;
      fprintf(stderr, "Poll error.\n");
      error();
      return eCorrupt;
    }
#ifdef USE_SSL
    if(m_isTLS)
    {
#ifdef USE_OPENSSL
      rc = SSL_read(m_ssl, m_buf, sizeof(m_buf));
#else
      rc = gnutls_record_recv(*m_gnutls_session, m_buf, sizeof(m_buf));
#endif
    }
    else
#endif
    {
      rc = read(m_fd, m_buf, sizeof(m_buf));
      // A signal may interrupt a blocking read; go round again so that
      // *m_exitCount is re-checked.
      if(rc < 0 && errno == EINTR)
        continue;
    }
    if(rc < 0)
    {
      error();
      return eSocket;
    }
    if(rc == 0)
    {
      // The peer closed the connection. Without this check, the copy loop
      // below would copy stale bytes and advance m_start past m_end, so
      // later passes would index beyond m_buf. m_start == m_end == 0 here.
      fprintf(stderr, "Server closed connection.\n");
      error();
      return eSocket;
    }
    m_end = rc;
    do
    {
      buf[ind] = m_buf[m_start];
      ind++;
      m_start++;
    } while(m_start < m_end && m_buf[m_start - 1] != '\n' && ind < bufSize);

    if(ind == bufSize || (ind > 0 && buf[ind - 1] == '\n') )
    {
      if(ind < bufSize)
        buf[ind] = '\0';
      receivedData(ind);
      if(m_debug)
        m_debug->Write(buf, ind);
      return ind;
    }
    if(m_start == m_end)
    {
      m_start = 0;
      m_end = 0;
    }
  }
  return 0;
}

// Send the whole of str and then read the server's reply; a thin adapter
// over sendCommandData() for std::string input.
ERROR_TYPE tcp::sendCommandString(const string &str, bool important)
{
  const char *text = str.c_str();
  const int length = str.size();
  return sendCommandData(text, length, important);
}

// Transmit `size` bytes from buf. If that succeeds, read the server's
// response with readCommandResp(important) and return its status.
// A send failure is returned as-is, without attempting to read.
ERROR_TYPE tcp::sendCommandData(const char *buf, int size, bool important)
{
  const ERROR_TYPE sendStatus = sendData(buf, size);
  if(!sendStatus)
    return readCommandResp(important);
  return sendStatus;
}

// Abruptly drop the connection: close the descriptor and mark the object
// closed, without any TLS shutdown and without touching m_isTLS or the
// TLS objects.
void tcp::endIt()
{
  if(!m_open)
    return;
  close(m_fd);
  m_open = false;
}

// Work-generator loop. Each iteration offers a slice of a 2048-byte buffer
// of filler bytes (all set to 1) to WriteWork(buf, len, 5), then calls
// pollRead().
//   rate - 0: offer the full buffer every iteration.
//          Otherwise the amount offered accrues at `rate` bytes per minute
//          of wall-clock time, starting with RESULTS_LAG seconds of credit.
//          Each offer is capped at one buffer; any backlog beyond that is
//          dropped.
// Returns:
//   -(result) when WriteWork() reports a negative value;
//   -1 once *m_exitCount > 1;
//   otherwise the first non-zero pollRead() result.
int tcp::doAllWork(int rate)
{
  char workBuf[2048];
  const int bufLen = int(sizeof(workBuf));
  for(unsigned int pos = 0; pos < sizeof(workBuf); ++pos)
    workBuf[pos] = 1;

  double pending = 0.0;
  time_t previous = time(NULL) - RESULTS_LAG;

  for(;;)
  {
    int chunk = bufLen;
    if(rate)
    {
      const time_t current = time(NULL);
      pending += double(current - previous) / 60.0 * double(rate);
      previous = current;
      chunk = int(pending);
      if(chunk > bufLen)
      {
        // Too far behind: send one full buffer and forget the rest.
        chunk = bufLen;
        pending = 0.0;
      }
      else
        pending -= double(chunk);
    }
    // NB if data can't be written then it is discarded.
    // Buffer size will be at least 1024 bytes, if worker threads aren't
    // keeping up then we don't really want more than that in the queue.
    if(chunk != 0)
    {
      const int written = WriteWork(workBuf, chunk, 5);
      if(written < 0)
        return -written;
    }

    if(*m_exitCount > 1)
      return -1;
    const int pollStatus = pollRead();
    if(pollStatus != 0)
      return pollStatus;
  }
  return 0;
}



// syntax highlighted by Code2HTML, v. 0.9.1