/* 
   elmo - ELectronic Mail Operator

   Copyright (C) 2002, 2003, 2004 rzyjontko

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; version 2.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software Foundation,
   Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.  

   ----------------------------------------------------------------------

   These are the primitives used when connecting to a server as a client.
   
*/
/****************************************************************************
 *    IMPLEMENTATION HEADERS
 ****************************************************************************/

#ifdef HAVE_CONFIG_H
# include <config.h>
#endif

#include <stdio.h>
#include <string.h>
#include <stdarg.h>
#include <time.h>
#include <errno.h>
#include <unistd.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <setjmp.h>
#include <signal.h>
#include <regex.h>

#ifdef OPENSSL_SUPPORT
# include <openssl/ssl.h>
#endif

#include "xmalloc.h"
#include "error.h"
#include "networking.h"
#include "gettext.h"
#include "rstring.h"
#include "cmd.h"
#include "debug.h"
#include "clock.h"
#include "str.h"

/****************************************************************************
 *    IMPLEMENTATION PRIVATE DEFINITIONS / ENUMERATIONS / SIMPLE TYPEDEFS
 ****************************************************************************/

/* Number of connection slots available in descriptors[]. */
#define ARRAY_SIZE      5

/* Initial capacity of a receive buffer; data_reader doubles it on demand. */
#define BUFFER_SIZE     (17 * 1024)

/* Largest number of bytes handed to a single send / SSL_write call. */
#define SEND_CHUNK_SIZE 1000

/* Smaller of two values.  Each argument may be evaluated twice. */
#define MIN(x, y)       ((x) < (y) ? (x) : (y))

/* Life cycle of a connection slot. */
enum status {
        NET_DOWN = 0,      /* no connection */
        NET_RESOLVING,     /* gethostbyname in progress */
        NET_CONNECTING,    /* socket / connect / SSL handshake in progress */
        NET_READY,         /* idle, ready for the next request */
        NET_READING_DATA,  /* waiting for the terminator regex to match */
        NET_WRITING_DATA,  /* send buffer being flushed */
        NET_TIMED_OUT
};

/****************************************************************************
 *    IMPLEMENTATION PRIVATE CLASS PROTOTYPES / EXTERNAL CLASS REFERENCES
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION PRIVATE STRUCTURES / UTILITY CLASSES
 ****************************************************************************/

/* One connection slot.  Slots live in the private descriptors[] array and
   callers of the public API refer to them by their index. */
struct net {
        int             index;       /* index in descriptors array */
        int             used;        /* non-zero if this slot is in use */
        
        enum status     status;      /* status of the connection */
  
        char           *server_name; /* host name given to net_open (owned, xfree'd) */
  
        int             sock;        /* socket, -1 when there is none */
        struct in_addr  local_addr;  /* local address (from getsockname) */
        unsigned short  local_port;  /* local port, network byte order */
        struct in_addr  server_addr; /* remote address (from resolve_host) */
        unsigned short  server_port; /* remote port, network byte order */

#ifdef OPENSSL_SUPPORT
        SSL            *ssl;         /* ssl session, NULL for plain TCP */
#else
        void           *ssl;         /* stays NULL without OpenSSL support */
#endif

        int             progress;       /* progress bar id, -1 if none */
        str_t          *progress_desc;  /* progress bar description */
        int             bytes_expected; /* maximal value for progress bar */
        int             bytes_received; /* current value for progress bar */
        
        int             total_sent;  /* total bytes sent */
        int             total_recv;  /* total bytes received */
        time_t          time_start;  /* time of the descriptor creation */

        int             is_compiled; /* if terminator has been compiled */
        regex_t         terminator;  /* regex that ends a read request */
        int             read_size;   /* size of read_buffer */
        int             read_fill;   /* how many bytes have been read */
        char           *read_buffer; /* NUL-terminated received data (owned) */
        void          (*recv_fun)(char *, int); /* function to be called after successful receive */

        int             send_size;   /* size of send_buffer */
        int             send_sent;   /* how many bytes have been sent */
        void           *send_buffer; /* data to be sent (caller-owned, never freed here) */
        void          (*send_fun)(int); /* function to be called after successful send */

        void          (*cleanup)(int); /* called with the slot index when the slot is destroyed */
};

/****************************************************************************
 *    IMPLEMENTATION REQUIRED EXTERNAL REFERENCES (AVOID)
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION PRIVATE DATA
 ****************************************************************************/

/* A single-threaded program cannot abort a connection attempt while it is
   in progress, because gethostbyname and connect block until they finish.

   To work around this, a special SIGINT handler is installed around those
   calls; it performs a non-local transfer of control with siglongjmp.

   The code should look like:

        if (sigsetjmp (discard_connection, 1))
                return 1;
        
        prev_handler = signal (SIGINT, connection_break);
        ret          = do_the_blocking_action ();
        signal (SIGINT, prev_handler);

   The signal handler must restore the previous handler in the same
   manner. */

/* prev_handler holds the SIGINT handler that was active before
   connection_break was installed, so it can be put back afterwards
   (presumably the handler that terminates the program). */
static void (*prev_handler)(int) = NULL;



/* discard_connection is the siglongjmp target established (with
   sigsetjmp) before gethostbyname and connect are called. */
static sigjmp_buf discard_connection;



/* Connections are kept in a private array instead of handing pointers out
   to the user, because read_handler and write_handler have to scan all of
   them to find the one whose socket became ready.  Callers refer to a
   connection by its index in this array. */
static struct net descriptors[ARRAY_SIZE];


#ifdef OPENSSL_SUPPORT
/* OpenSSL context shared by all connections; created in net_init and
   released in net_free_resources. */
static SSL_CTX *ssl_ctx = NULL;
#endif

/****************************************************************************
 *    INTERFACE DATA
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION PRIVATE FUNCTION PROTOTYPES
 ****************************************************************************/
/****************************************************************************
 *    IMPLEMENTATION PRIVATE FUNCTIONS
 ****************************************************************************/

/* SIGINT handler installed around the blocking calls in net_open.
   It reinstates the previously active handler and jumps back to the
   recovery point recorded with sigsetjmp (discard_connection, ...). */
static void
connection_break (int signum)
{
        (void) signum;  /* which signal fired is irrelevant here */

        signal (SIGINT, prev_handler);
        siglongjmp (discard_connection, 1);
}



/* Look up enet->server_name with gethostbyname and store the first
   address of the result in enet->server_addr.

   The descriptor must be non-NULL and in NET_DOWN state.  It is moved to
   NET_RESOLVING; on a failed lookup an error is reported and the status
   falls back to NET_DOWN.  Returns 0 on success, 1 on failure. */
static int
resolve_host (struct net *enet)
{
        struct hostent *host;

        if (enet == NULL)
                return 1;
        if (enet->status != NET_DOWN)
                return 1;

        enet->status = NET_RESOLVING;
        host         = gethostbyname (enet->server_name);

        if (host == NULL){
                error_ (0, "%s: %s", enet->server_name, hstrerror (h_errno));
                enet->status = NET_DOWN;
                return 1;
        }

        enet->server_addr = *(struct in_addr *) host->h_addr_list[0];
        return 0;
}



/* Open a TCP connection to enet->server_addr:enet->server_port and, when
   SECURE is non-zero, negotiate an SSL session on top of it.

   The descriptor must be in NET_RESOLVING state (resolve_host succeeded).
   On success the local endpoint is recorded, the status becomes NET_READY
   and 0 is returned.  On failure an error is reported, any socket or SSL
   object created here is released (sock reset to -1, ssl to NULL) and 1
   is returned; the caller is still expected to reset the descriptor. */
static int
make_connection (struct net *enet, int secure)
{
        struct sockaddr_in  st;
        socklen_t           st_len = sizeof (st);

        if (enet->status != NET_RESOLVING)
                return 1;

#ifndef OPENSSL_SUPPORT
        /* Refuse before opening a socket that would only be leaked. */
        if (secure){
                error_ (0, "%s", _("ssl requested, but compiled without "
                                   "support"));
                return 1;
        }
#endif

        enet->status = NET_CONNECTING;

        enet->sock = socket (PF_INET, SOCK_STREAM, 0);
        if (enet->sock == -1){
                error_ (errno, "%s", inet_ntoa (enet->server_addr));
                return 1;
        }

        /* Zero the whole structure so sin_zero is not left uninitialised. */
        memset (& st, 0, sizeof (st));
        st.sin_family = AF_INET;
        st.sin_addr   = enet->server_addr;
        st.sin_port   = enet->server_port;

        if (connect (enet->sock, (struct sockaddr *) & st, sizeof (st)) == -1){
                /* Report first: close() may overwrite errno. */
                error_ (errno, "%s", inet_ntoa (enet->server_addr));
                close (enet->sock);
                enet->sock = -1;
                return 1;
        }

        if (getsockname (enet->sock, (struct sockaddr *) & st, & st_len)
            == 0){
                enet->local_addr = st.sin_addr;
                enet->local_port = st.sin_port;
        }

#ifdef OPENSSL_SUPPORT
        if (secure){
                /* ssl_ctx is NULL if SSL_CTX_new failed in net_init. */
                if (ssl_ctx)
                        enet->ssl = SSL_new (ssl_ctx);

                /* SSL_set_fd and SSL_connect both return 1 on success;
                   anything else means the session is unusable. */
                if (enet->ssl == NULL
                    || SSL_set_fd (enet->ssl, enet->sock) != 1
                    || SSL_connect (enet->ssl) != 1){
                        error_ (0, "%s: %s", enet->server_name,
                                _("SSL handshake failed"));
                        if (enet->ssl)
                                SSL_free (enet->ssl);
                        enet->ssl = NULL;
                        close (enet->sock);
                        enet->sock = -1;
                        return 1;
                }
        }
#endif

        enet->status = NET_READY;
        return 0;
}



/* Return the index of the first slot in descriptors[] that is not in use,
   or -1 when all ARRAY_SIZE slots are taken. */
static int
find_free (void)
{
        int slot = 0;

        while (slot < ARRAY_SIZE && descriptors[slot].used)
                slot++;

        return (slot == ARRAY_SIZE) ? -1 : slot;
}



static void
destroy_enet (struct net *enet)
{
        if (enet->cleanup)
                enet->cleanup (enet->index);
        
        if (enet->server_name)
                xfree (enet->server_name);

        if (enet->is_compiled)
                regfree (& enet->terminator);

        if (enet->read_buffer)
                xfree (enet->read_buffer);

        if (enet->progress != -1)
                progress_close (enet->progress);
        if (enet->progress_desc)
                str_destroy (enet->progress_desc);

#ifdef OPENSSL_SUPPORT
        if (enet->ssl)
                SSL_free (enet->ssl);
#endif
        
        enet->cleanup        = NULL;
        enet->server_name    = NULL;
        enet->used           = 0;
        enet->sock           = -1;
        enet->ssl            = NULL;
        enet->status         = NET_DOWN;
        enet->local_port     = 0;
        enet->is_compiled    = 0;
        enet->read_buffer    = NULL;
        enet->read_size      = 0;
        enet->read_fill      = 0;
        enet->recv_fun       = NULL;
        enet->send_buffer    = NULL;
        enet->send_size      = 0;
        enet->send_sent      = 0;
        enet->send_fun       = NULL;
        enet->total_sent     = 0;
        enet->total_recv     = 0;
        enet->progress       = -1;
        enet->progress_desc  = NULL;
        enet->bytes_expected = 0;
        enet->bytes_received = 0;

        memset (&enet->local_addr, '\0', sizeof (enet->server_addr));
}



/* Terminate the connection and reset the slot.  An SSL session is
   terminated with SSL_shutdown; a plain socket gets shutdown(SHUT_RDWR).
   The slot is then released with destroy_enet. */
static void
shutdown_connection (struct net *enet)
{
        int plain_socket = 1;

#ifdef OPENSSL_SUPPORT
        if (enet->ssl){
                SSL_shutdown (enet->ssl);
                plain_socket = 0;
        }
#endif
        if (plain_socket)
                shutdown (enet->sock, SHUT_RDWR);

        destroy_enet (enet);
}


/* Claim a slot: wipe whatever it held before, record the creation time
   and mark it as used. */
static void
init_enet (struct net *enet)
{
        destroy_enet (enet);

        enet->time_start = time (NULL);
        enet->used       = 1;
}



/* Check the result RET of a recv/send (or SSL_read/SSL_write) call.

   A negative value is reported together with errno and the context
   string S.  Zero is treated as the server having closed the connection.
   In both cases the connection is shut down and 1 is returned; any
   positive value returns 0. */
static int
check_read_write_ret (struct net *enet, int ret, const char *s)
{
        if (ret < 0){
                error_ (errno, "%s", s);
                shutdown_connection (enet);
                return 1;
        }

        if (ret == 0){
                /* An orderly close does not set errno, so passing it
                   would print an unrelated, stale error text. */
                error_ (0, "%s", _("server closed the connection"));
                shutdown_connection (enet);
                return 1;
        }

        return 0;
}



static void
data_reader (struct net *enet)
{
        int        ret;
        regmatch_t matches[1];
        
        if (enet->read_buffer == NULL){
                enet->read_size   = BUFFER_SIZE;
                enet->read_fill   = 0;
                enet->read_buffer = xmalloc (BUFFER_SIZE);
        }

        if (enet->read_fill + 1 >= enet->read_size){
                enet->read_size   *= 2;
                enet->read_buffer  = xrealloc (enet->read_buffer,
                                               enet->read_size);
        }

#ifdef OPENSSL_SUPPORT
        if (enet->ssl){
                ret = SSL_read (enet->ssl,
                                enet->read_buffer + enet->read_fill,
                                enet->read_size - enet->read_fill - 1);
        }
        else
#endif
                ret = recv (enet->sock, enet->read_buffer + enet->read_fill,
                            enet->read_size - enet->read_fill - 1, 0);


        if (check_read_write_ret (enet, ret, "recieving"))
                return;
        
        if (enet->progress != -1)
                progress_advance (enet->progress, ret);
        
        enet->total_recv += ret;
        enet->read_fill  += ret;
        enet->read_buffer[enet->read_fill] = '\0';

        ret = regexec (& enet->terminator, enet->read_buffer, 1, matches, 0);
        if (ret && ret != REG_NOMATCH){
                error_regex (ret, & enet->terminator, NULL);
                shutdown_connection (enet);
                return;
        }

        if (ret == 0){
                int len           = enet->read_fill;
                enet->read_fill   = 0;
                enet->status      = NET_READY;
                enet->is_compiled = 0;
                regfree (& enet->terminator);

                if (enet->progress != -1)
                        progress_close (enet->progress);
                enet->progress = -1;
                
                cmd_del_readfd_handler (enet->sock);

                /* It is absolutely crucial, that this call is the last
                   thing to do in this function.  It is because recv_fun
                   may (implicitly) modify some important data structures.
                   It may remove some readfd handlers (including this one),
                   or add a new one. It may even destroy the enet object. */
                enet->recv_fun (enet->read_buffer, len);
        }
}



/* Write-handler body for a descriptor in NET_WRITING_DATA state.

   Sends the next chunk (at most SEND_CHUNK_SIZE bytes) of send_buffer.
   After the first successful write, a progress bar is opened if
   progress_desc is set and no bar is open yet.  Once the whole buffer
   has been sent, the write handler is removed, the descriptor returns to
   NET_READY and send_fun is called with the slot index.  A failed write
   shuts the connection down. */
static void
data_sender (struct net *enet)
{
        /* Offsets must be applied to a char pointer: arithmetic on
           void * is a compiler extension, not ISO C. */
        const char *data = (const char *) enet->send_buffer;
        int         ret;
        int         size;

        size = MIN (SEND_CHUNK_SIZE, enet->send_size - enet->send_sent);

#ifdef OPENSSL_SUPPORT
        if (enet->ssl){
                ret = SSL_write (enet->ssl, data + enet->send_sent, size);
        }
        else
#endif
                ret = send (enet->sock, data + enet->send_sent, size, 0);


        if (check_read_write_ret (enet, ret, "sending"))
                return;

        if (enet->progress_desc && enet->progress == -1){
                enet->progress = progress_setup (enet->send_size, "%s",
                                                 enet->progress_desc->str);
        }

        if (enet->progress != -1)
                progress_advance (enet->progress, ret);
        
        enet->total_sent += ret;
        enet->send_sent  += ret;

        if (enet->send_sent == enet->send_size){
                enet->send_buffer = NULL;
                enet->send_size   = 0;
                enet->send_sent   = 0;
                enet->status      = NET_READY;

                cmd_del_writefd_handler (enet->sock);

                /* It is absolutely crucial that this call is the last
                   thing done in this function, because send_fun may
                   (implicitly) modify important data structures: it may
                   remove writefd handlers (including this one), add new
                   handlers, or even destroy the enet object. */
                enet->send_fun (enet->index);
        }
}



static void
read_handler (int fd)
{
        int         i;
        struct net *enet;

        for (i = 0; i < ARRAY_SIZE; i++){
                enet = descriptors + i;

                if (enet->used && enet->sock == fd
                    && enet->status == NET_READING_DATA){

                        data_reader (enet);
                        return;
                }
        }

        error_ (0, "%s", _("no descriptor found: please submit a bug"));
        cmd_del_readfd_handler (fd);
}



/* Writable-fd callback registered with cmd_add_writefd_handler.  Finds the
   used slot whose socket is FD and which is sending data, and hands it to
   data_sender.  If no slot matches, the error is reported and the *write*
   handler for FD is removed, so it does not fire again. */
static void
write_handler (int fd)
{
        int         i;
        struct net *enet;

        for (i = 0; i < ARRAY_SIZE; i++){
                enet = descriptors + i;

                if (enet->used && enet->sock == fd
                    && enet->status == NET_WRITING_DATA){

                        data_sender (enet);
                        return;
                }
        }

        error_ (0, "%s", _("no descriptor found: please submit a bug"));
        /* This is the write handler; removing the read handler here
           would leave this stray callback registered forever. */
        cmd_del_writefd_handler (fd);
}



static void
combo_fun (int nd)
{
        struct net *enet;

        enet         = descriptors + nd;
        enet->status = NET_READING_DATA;

        if (enet->progress != -1)
                progress_close (enet->progress);

        if (enet->bytes_expected != 0){
                enet->progress = progress_setup (enet->bytes_expected, "%s",
                                                 enet->progress_desc->str);
        }

        cmd_add_readfd_handler (enet->sock, read_handler);
}

/****************************************************************************
 *    INTERFACE FUNCTIONS
 ****************************************************************************/



/* Initialise the module: reset every connection slot, ignore SIGPIPE so
   that writing to a closed socket yields an error instead of killing the
   process, and (with OpenSSL) initialise the library and its context. */
void
net_init (void)
{
        int i;
        
        memset (descriptors, 0, sizeof (descriptors));

        /* memset leaves sock and progress at 0, not at the -1 "none"
           sentinel used everywhere else.  Set it explicitly so that
           net_close()/destroy_enet() on a never-used slot cannot touch
           fd 0 or progress bar 0. */
        for (i = 0; i < ARRAY_SIZE; i++){
                descriptors[i].index    = i;
                descriptors[i].sock     = -1;
                descriptors[i].progress = -1;
        }

        signal (SIGPIPE, SIG_IGN);
        
#ifdef OPENSSL_SUPPORT
        SSL_load_error_strings ();
        SSL_library_init ();
        ssl_ctx = SSL_CTX_new (SSLv23_method ());
#endif
}




/* Release every connection slot and, with OpenSSL, the shared context.
   ssl_ctx is reset to NULL so a second call cannot free it twice. */
void
net_free_resources (void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE; i++){
                destroy_enet (descriptors + i);
        }

#ifdef OPENSSL_SUPPORT
        if (ssl_ctx)
                SSL_CTX_free (ssl_ctx);
        ssl_ctx = NULL;
#endif
}



/* Open a connection to HOSTNAME:PORT, over SSL when SECURE is non-zero.

   Name resolution and connect block.  While they run, SIGINT is routed to
   connection_break so the user can abort the attempt.  CLEANUP (may be
   NULL) is stored in the slot and called with the slot index when the slot
   is destroyed.

   Returns the descriptor index (>= 0) on success, or -1 on failure;
   errors are reported through error_. */
int
net_open (const char *hostname, unsigned short port, int secure,
          void (*cleanup)(int))
{
        /* p is modified after sigsetjmp and read after siglongjmp, so it
           must be volatile to have a determinate value there. */
        volatile int  p = -1;
        int           i = find_free ();
        struct net   *enet;

        if (i == -1){
                error_ (0, "%s", _("too many open connections"));
                return -1;
        }

        /* Only form the pointer once i is known to be a valid index. */
        enet = descriptors + i;
        init_enet (enet);
        
        enet->cleanup     = cleanup;
        enet->server_name = xstrdup (hostname);
        enet->server_port = htons (port);

        if (sigsetjmp (discard_connection, 1)){
                /* Reached through connection_break, which has already
                   reinstated the previous SIGINT handler. */
                if (p != -1)
                        progress_close (p);
                error_ (0, "%s", _("connection terminated"));
                destroy_enet (enet);
                return -1;
        }

        p            = progress_setup (1, _("resolving host %s..."), hostname);
        errno        = 0;
        prev_handler = signal (SIGINT, connection_break);

        if (resolve_host (enet))
                goto fail;

        if (p != -1)
                progress_change_desc (p, _("connecting to %s..."),
                                      inet_ntoa (enet->server_addr));
        if (make_connection (enet, secure))
                goto fail;

        signal (SIGINT, prev_handler);
        if (p != -1)
                progress_close (p);

        return i;

fail:
        signal (SIGINT, prev_handler);
        if (p != -1)
                progress_close (p);
        destroy_enet (enet);
        return -1;
}


  
/* Close connection ND.  Its fd handlers are removed, a live connection is
   shut down, and the slot is released.  Out-of-range indices and slots
   that are not in use are ignored. */
void
net_close (int nd)
{
        struct net *enet;

        if (nd >= ARRAY_SIZE || nd < 0)
                return;
        
        enet = descriptors + nd;

        /* A free slot owns nothing.  Its sock/progress may still be 0 from
           net_init's memset, and must not unregister fd 0 or close
           progress bar 0. */
        if (! enet->used)
                return;
        
        if (enet->sock != -1){
                cmd_del_readfd_handler (enet->sock);
                cmd_del_writefd_handler (enet->sock);
        }
        
        if (enet->status != NET_DOWN){
                shutdown_connection (enet);
        }
        else {
                destroy_enet (enet);
        }
}



/* Start reading on connection ND until the extended, case-insensitive
   regex RE matches the accumulated data; FUN then receives the buffer and
   its length.  Nothing happens if ND is invalid, the slot is unused, or
   it is not NET_READY.  An invalid RE is reported and nothing is started. */
void
net_recv_data (int nd, const char *re, void (*fun)(char *buf, int size))
{
        int         ret;
        struct net *enet;

        if (nd >= ARRAY_SIZE || nd < 0)
                return;

        enet = descriptors + nd;
        if (! enet->used || enet->status != NET_READY)
                return;

        /* Drop any terminator left from an earlier request so regcomp
           below does not leak it. */
        if (enet->is_compiled){
                regfree (& enet->terminator);
                enet->is_compiled = 0;
        }

        ret = regcomp (& enet->terminator, re,
                       REG_ICASE | REG_EXTENDED);
        if (ret){
                /* POSIX leaves the regex_t undefined after a failed
                   regcomp, so it must not be passed to regfree. */
                error_regex (ret, & enet->terminator, re);
                return;
        }

        enet->is_compiled = 1;
        enet->status      = NET_READING_DATA;
        enet->recv_fun    = fun;

        cmd_add_readfd_handler (enet->sock, read_handler);
}



/* Queue SIZE bytes at BUF for sending on connection ND; FUN is called
   with ND once everything has been written.  BUF stays owned by the
   caller and must remain valid until then.  Nothing happens if ND is
   invalid, the slot is unused, or it is not NET_READY. */
void
net_send_data (int nd, void *buf, int size, void (*fun)(int))
{
        struct net *enet;

        if (nd < 0 || nd >= ARRAY_SIZE)
                return;

        enet = & descriptors[nd];
        if (enet->used && enet->status == NET_READY){
                enet->send_buffer = buf;
                enet->send_size   = size;
                enet->send_fun    = fun;
                enet->status      = NET_WRITING_DATA;

                cmd_add_writefd_handler (enet->sock, write_handler);
        }
}



void
net_expect (int nd, int size, char *desc, ...)
{
        va_list     ap;
        struct net *enet;
        
        if (nd >= ARRAY_SIZE || nd < 0)
                return;

        enet = descriptors + nd;
        if (! enet->used || enet->status != NET_READY)
                return;

        if (enet->progress_desc)
                str_clear (enet->progress_desc);
        else
                enet->progress_desc = str_create ();
        enet->bytes_expected = size;
        enet->bytes_received = 0;

        va_start (ap, desc);
        str_vsprintf (enet->progress_desc, desc, ap);
        va_end (ap);
}



void
net_send_progress (int nd, char *desc, ...)
{
        va_list     ap;
        struct net *enet;

        if (nd >= ARRAY_SIZE || nd < 0)
                return;

        enet = descriptors + nd;

        if (! enet->used || enet->status != NET_READY)
                return;

        if (enet->progress_desc)
                str_clear (enet->progress_desc);
        else
                enet->progress_desc = str_create ();

        va_start (ap, desc);
        str_vsprintf (enet->progress_desc, desc, ap);
        va_end (ap);
}



/* Send SIZE bytes at BUF on connection ND, then read until the regex RE
   matches and pass the received data to FUN.  RE is compiled as an
   extended, case-insensitive regex with REG_NEWLINE.  Nothing happens if
   ND is invalid, the slot is unused, or it is not NET_READY.  An invalid
   RE is reported and nothing is sent. */
void
net_combo (int nd, void *buf, int size, const char *re,
           void (*fun)(char *, int))
{
        int         ret;
        struct net *enet;
        
        if (nd >= ARRAY_SIZE || nd < 0)
                return;

        enet = descriptors + nd;
        if (! enet->used || enet->status != NET_READY)
                return;

        if (enet->is_compiled){
                regfree (& enet->terminator);
                enet->is_compiled = 0;
        }
        
        ret = regcomp (& enet->terminator, re,
                       REG_ICASE | REG_EXTENDED | REG_NEWLINE);
        if (ret){
                /* POSIX leaves the regex_t undefined after a failed
                   regcomp, so it must not be passed to regfree. */
                error_regex (ret, & enet->terminator, re);
                return;
        }

        enet->is_compiled = 1;
        enet->recv_fun    = fun;

        /* combo_fun switches the slot to reading once the send completes. */
        net_send_data (nd, buf, size, combo_fun);
}



/* Return the host name connection ND was opened with, or NULL if ND is
   out of range or the slot is not in use.  The string is owned by the
   slot and is freed when the slot is destroyed. */
char *
net_server_address (int nd)
{
        if (nd < 0 || nd >= ARRAY_SIZE || ! descriptors[nd].used)
                return NULL;

        return descriptors[nd].server_name;
}

/****************************************************************************
 *    INTERFACE CLASS BODIES
 ****************************************************************************/
/****************************************************************************
 *
 *    END MODULE networking.c
 *
 ****************************************************************************/


/* syntax highlighted by Code2HTML, v. 0.9.1 */