/* $Cambridge: hermes/src/prayer/accountd/iostream.c,v 1.1.1.1 2003/04/15 13:00:02 dpc22 Exp $ */
/************************************************
 *    Prayer - a Webmail Interface              *
 ************************************************/

/* Copyright (c) University of Cambridge 2000 - 2002 */
/* See the file NOTICE for conditions of use and distribution. */

/* Some simple stream I/O functions that work with sockets (separate
   read/write buffers) and provide optional timeouts on reads and writes.
   Also supply a (more or less) transparent interface to SSL-wrapped streams. */

#include "accountd.h"

/* iostream_init() *******************************************************
 *
 * Initialise iostream subsystem: sets up the global SSL context from
 * the supplied configuration (via ssl_context_init()).
 *    config: Global configuration
 ************************************************************************/

void iostream_init(struct config *config)
{
    ssl_context_init(config);   /* initialize global SSL context */
}

/* iostream_check_rsakey() ***********************************************
 *
 * Thin wrapper that hands the configuration to ssl_check_rsakey().
 * NOTE(review): presumably checks the RSA key used by the SSL layer;
 * confirm against ssl_check_rsakey().
 *    config: Global configuration
 ************************************************************************/

void iostream_check_rsakey(struct config *config)
{
    ssl_check_rsakey(config);
}

/* iostream_freshen_rsakey() *********************************************
 *
 * Thin wrapper that hands the configuration to ssl_freshen_rsakey().
 * NOTE(review): presumably regenerates/refreshes the RSA key used by the
 * SSL layer; confirm against ssl_freshen_rsakey().
 *    config: Global configuration
 ************************************************************************/

void iostream_freshen_rsakey(struct config *config)
{
    ssl_freshen_rsakey(config);
}

/* iostream_exit() *******************************************************
 *
 * Clear up iostream subsystem: releases the global SSL context
 * (via ssl_context_free()).
 ************************************************************************/

void iostream_exit(void)
{
    ssl_context_free();         /* free global SSL context */
}

/* ====================================================================== */

/* iostream_create() *****************************************************
 *
 * Create a new iostream and bind it to socket descriptor.
 *       pool: Target pool for this iostream and its buffers. iostream_free()
 *             only releases memory itself when this is NIL; presumably
 *             pool_alloc() then falls back to the heap (verify).
 *     sockfd: Socket descriptor that we want to bind to
 *  blocksize: Size of each I/O buffer on this iostream.
 *             0 => IOSTREAM_PREFERRED_BLOCK_SIZE.
 *
 * Returns: New iostream structure: blocking, no SSL, no debug, no timeouts.
 ************************************************************************/

struct iostream *iostream_create(struct pool *pool, int sockfd,
                                 unsigned long blocksize)
{
    struct iostream *x = pool_alloc(pool, sizeof(struct iostream));

    x->pool = pool;
    x->blocksize = (blocksize) ? blocksize : IOSTREAM_PREFERRED_BLOCK_SIZE;
    x->fd = sockfd;
    x->ssl = NIL;
    x->bio = NIL;
    x->debug = NIL;
    x->blocking = T;

    /* Input buffer starts empty: icurrent == ibufend */
    x->ibuffer = pool_alloc(pool, x->blocksize);
    x->ibufend = x->ibuffer;
    x->icurrent = x->ibuffer;
    x->itimeout = 0;            /* Default: No timeout */
    x->ierror = NIL;
    x->ieof = NIL;

    /* Output buffer starts empty: obufend marks the end of its capacity */
    x->obuffer = pool_alloc(pool, x->blocksize);
    x->obufend = x->obuffer + x->blocksize;
    x->ocurrent = x->obuffer;
    x->otimeout = 0;            /* Default: No timeout */
    x->oerror = NIL;
    return (x);
}

/* ====================================================================== */

/* iostream_ssl_start_server() *******************************************
 *
 * Begin server side SSL on an existing iostream. The stream's socket
 * and input timeout are handed to ssl_start_server().
 *    x: iostream.
 *
 * Returns: T on success. On failure, both the input and output error
 *          flags are raised and NIL is returned.
 ************************************************************************/

BOOL iostream_ssl_start_server(struct iostream * x)
{
    x->ssl = ssl_start_server(x->fd, x->itimeout);
    if (x->ssl)
        return (T);

    /* No SSL session: stream is unusable in both directions */
    x->oerror = T;
    x->ierror = T;
    return (NIL);
}

/* ====================================================================== */

/* iostream_ssl_start_client() *******************************************
 *
 * Begin client side SSL on an existing iostream. The stream's socket
 * and input timeout are handed to ssl_start_client().
 *    x: iostream.
 *
 * Returns: T on success. On failure, both the input and output error
 *          flags are raised and NIL is returned.
 ************************************************************************/

BOOL iostream_ssl_start_client(struct iostream * x)
{
    x->ssl = ssl_start_client(x->fd, x->itimeout);
    if (x->ssl)
        return (T);

    /* No SSL session: stream is unusable in both directions */
    x->oerror = T;
    x->ierror = T;
    return (NIL);
}

/* iostream_ssl_enabled() ************************************************
 *
 * Report whether this stream has an SSL session attached.
 *    x: iostream.
 *
 * Returns: T if SSL is active on the stream, NIL otherwise.
 ************************************************************************/

BOOL iostream_ssl_enabled(struct iostream * x)
{
    if (!x->ssl)
        return (NIL);

    return (T);
}

/* ====================================================================== */

/* iostream_free() *******************************************************
 *
 * Free iostream.
 ************************************************************************/

void iostream_free(struct iostream *x)
{
    if (x->ssl)
        ssl_free(x->ssl);

    if (x->pool)
        return;

    if (x->ibuffer)
        free(x->ibuffer);

    if (x->obuffer)
        free(x->obuffer);

    free(x);
}

/* iostream_free_buffers() ***********************************************
 *
 * Free iostream buffers: switching to raw I/O here.
 * Pending output is flushed first. Unread buffered input is discarded.
 * Buffers are passed to free() only when the stream has no pool, matching
 * iostream_free(): pool-allocated memory must not be handed to free().
 *    x: iostream
 ************************************************************************/

void iostream_free_buffers(struct iostream *x)
{
    ioflush(x);

    if (!x->pool) {
        free(x->ibuffer);       /* free(NULL) is a no-op */
        free(x->obuffer);
    }

    /* Reset every buffer pointer, not just the base pointers, so that no
     * field of x is left referring to released memory */
    x->ibuffer = x->ibufend = x->icurrent = NIL;
    x->obuffer = x->obufend = x->ocurrent = NIL;
}

/* iostream_close() ******************************************************
 *
 * Shut down an iostream: flush buffered output, shut down SSL (if any),
 * release the iostream via iostream_free() and close the socket.
 *    x: iostream. Must not be used after this call.
 ************************************************************************/

void iostream_close(struct iostream *x)
{
    int sockfd;

    /* Push out anything still sitting in the write buffer */
    if (x->obuffer)
        iostream_flush(x);

    /* Shut down the SSL layer if one is attached */
    if (x->ssl)
        ssl_shutdown(x->ssl);

    /* Save the descriptor: x may be freed by iostream_free() */
    sockfd = x->fd;
    iostream_free(x);
    close(sockfd);
}

/* iostream_debug() ******************************************************
 *
 * Switch debug tracing on or off for an iostream. While tracing is on,
 * data read from or written to the stream is echoed to stderr.
 *       x: iostream
 *  enable: T => trace. NIL => don't trace.
 ************************************************************************/

void iostream_debug(struct iostream *x, BOOL enable)
{
    x->debug = enable;          /* Consulted by getchar and flush */
}

/* iostream_block() ******************************************************
 *
 * Choose blocking or non-blocking input on this iostream. With blocking
 * disabled, iostream_getchar() polls the socket and returns EOS rather
 * than waiting when no data is ready.
 *       x: iostream
 *  enable: T => blocking. NIL => non-blocking.
 ************************************************************************/

void iostream_block(struct iostream *x, BOOL enable)
{
    x->blocking = enable;       /* Consulted by iostream_getchar() */
}

/* iostream_set_timeout() ************************************************
 *
 * Apply one timeout to both input and output on this iostream.
 *         x: iostream
 *   timeout: limit in seconds. 0 => wait indefinitely
 ************************************************************************/

void iostream_set_timeout(struct iostream *x, time_t timeout)
{
    /* Same limit applies to both directions */
    x->otimeout = timeout;
    x->itimeout = timeout;
}

/* ====================================================================== */

/* iostream_getchar() ****************************************************
 *
 * Get a single character from I/O stream. Normally called from via
 * iogetc() macro.
 *     x: IOstream
 *
 * Returns: unsigned character.
 *          EOF  => end of file reached
 *          EOS  => temporary end of stream if non-blocking I/O enabled.
 ************************************************************************/

int iostream_getchar(struct iostream *x)
{
    int len;
    fd_set readfds;
    struct timeval timeout;

    if (x->ierror || x->ieof)
        return (EOF);

    if (x->icurrent < x->ibufend)
        return (*(x->icurrent++));

    /* Check for data in SSL buffer than select() won't block */
    if (x->ssl && ssl_pending(x->ssl)) {
        while ((len = ssl_read(x->ssl, x->ibuffer, x->blocksize)) < 0) {
            if (errno != EINTR) {
                x->ieof = T;
                return (EOF);
            }
        }
        if (len == 0) {
            x->ieof = T;
            return (EOF);
        }

        if (x->debug)
            write(STDERR_FILENO, x->ibuffer, len);

        x->ibufend = x->ibuffer + len;
        x->icurrent = x->ibuffer;

        return (*(x->icurrent++));
    }

    if (x->blocking) {
        if (x->itimeout > 0) {
            FD_ZERO(&readfds);
            FD_SET(x->fd, &readfds);

            timeout.tv_sec = x->itimeout;
            timeout.tv_usec = 0;

            while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) {
                if (errno != EINTR)
                    log_fatal("iostream_getchar(): select() failed");
            }
            if (!FD_ISSET(x->fd, &readfds)) {
                x->ierror = T;
                return (EOF);
            }
        }
    } else {
        FD_ZERO(&readfds);
        FD_SET(x->fd, &readfds);

        /* Non blocking select */
        timeout.tv_sec = 0;
        timeout.tv_usec = 0;

        while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) {
            if (errno != EINTR)
                log_fatal("iostream_getchar(): select() failed");
        }
        if (!FD_ISSET(x->fd, &readfds))
            return (EOS);
    }

    if (x->ssl) {
        while ((len = ssl_read(x->ssl, x->ibuffer, x->blocksize)) < 0) {
            if (errno != EINTR) {
                x->ieof = T;
                return (EOF);
            }
        }
        if (len == 0) {
            x->ieof = T;
            return (EOF);
        }

        if (x->debug)
            write(STDERR_FILENO, x->ibuffer, len);

        x->ibufend = x->ibuffer + len;
        x->icurrent = x->ibuffer;

        return (*(x->icurrent++));
    }

    while ((len = read(x->fd, x->ibuffer, x->blocksize)) < 0) {
        if (errno != EINTR) {
            x->ieof = T;
            return (EOF);
        }
    }
    if (len == 0) {
        x->ieof = T;
        return (EOF);
    }

    if (x->debug)
        write(STDERR_FILENO, x->ibuffer, len);

    x->ibufend = x->ibuffer + len;
    x->icurrent = x->ibuffer;
    return (*(x->icurrent++));
}

/* ====================================================================== */

/* iostream_ungetchar() **************************************************
 *
 * Unget character from stream. Normally called via ioungetc() macro.
 * The character is written back into the input buffer just before the
 * current read position. If there is no room (nothing has been read
 * into the buffer yet, or the buffer was released), the call is ignored
 * rather than writing outside the buffer.
 *     c: Character to push back
 *     x: IOstream
 ************************************************************************/

void iostream_ungetchar(char c, struct iostream *x)
{
    if (x->icurrent <= x->ibuffer)
        return;                 /* Would underflow the input buffer */

    *(--x->icurrent) = c;
}

/* ====================================================================== */

/* iostream_have_buffered_input() ****************************************
 *
 * Check for input already sitting in the iostream buffer or held by the
 * SSL layer (as reported by ssl_pending()). Does not look at the socket.
 *                  x: iostream
 *  ignore_whitespace: Only count buffered bytes other than CR and LF.
 *                     (Other whitespace still counts.)
 *
 * Returns: T => have buffered input
 ************************************************************************/

BOOL
iostream_have_buffered_input(struct iostream *x, BOOL ignore_whitespace)
{
    unsigned char *p = x->icurrent;

    /* Scan bytes already read into the iostream buffer */
    while (p < x->ibufend) {
        if (!ignore_whitespace)
            return (T);
        if ((*p != '\015') && (*p != '\012'))
            return (T);
        p++;
    }

    /* Input pending inside the SSL layer also counts */
    return ((x->ssl && ssl_pending(x->ssl)) ? T : NIL);
}

/* iostream_have_input() *************************************************
 *
 * Check for input sitting in iostream or SSL buffers or pending on socket.
 * Buffered bytes that are only CR or LF characters are ignored.
 * A select() failure other than EINTR is fatal.
 *    x: iostream
 *
 * Returns: T => have buffered input or read() can return without blocking
 ************************************************************************/

BOOL iostream_have_input(struct iostream * x)
{
    fd_set readfds;
    struct timeval timeout;

    if (x->icurrent < x->ibufend) {
        unsigned char *s;

        for (s = x->icurrent; s < x->ibufend; s++) {
            if ((*s != '\015') && (*s != '\012'))
                return (T);
        }
    }

    if (x->ssl && ssl_pending(x->ssl))
        return (T);

    FD_ZERO(&readfds);
    FD_SET(x->fd, &readfds);

    /* Poll for pending input on readfds */
    timeout.tv_sec = 0;
    timeout.tv_usec = 0;

    while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) {
        if (errno != EINTR)
            log_fatal("iostream_have_input(): select() failed");
    }

    if (FD_ISSET(x->fd, &readfds))
        return (T);

    return (NIL);
}

/* iostream_is_eof() *****************************************************
 *
 * Report whether the input side of the iostream has been marked EOF.
 *    x: iostream
 *
 * Returns: T once EOF (or a read failure) has been seen, NIL otherwise.
 ************************************************************************/

BOOL iostream_is_eof(struct iostream * x)
{
    if (x->ieof)
        return (T);

    return (NIL);
}

/* ====================================================================== */

/* iostream_flush() ******************************************************
 *
 * Flush buffered data to iostream. The output buffer is emptied whether
 * or not the write succeeds; on failure the remaining data is dropped
 * and x->oerror is set. Writes interrupted by a signal are retried.
 *    x: iostream
 *
 * Returns: T on success (or if the stream has no output buffer),
 *          NIL on write error, output timeout, or earlier output error.
 ************************************************************************/

BOOL iostream_flush(struct iostream * x)
{
    unsigned char *current;
    int bytes;
    int count;
    fd_set writefds;
    struct timeval timeout;

    if (x->obuffer == NIL)      /* No output buffer on stream */
        return (T);

    current = x->obuffer;
    bytes = x->ocurrent - x->obuffer;

    /* Reset ptr to start of output buffer */
    x->ocurrent = x->obuffer;

    while ((!x->oerror) && (bytes > 0)) {
        if (x->otimeout > 0) {
            FD_ZERO(&writefds);
            FD_SET(x->fd, &writefds);

            timeout.tv_sec = x->otimeout;
            timeout.tv_usec = 0;

            while (select(x->fd + 1, NIL, &writefds, NIL, &timeout) < 0) {
                if (errno != EINTR)
                    log_fatal("iostream_flush(): select() failed");
            }

            if (!FD_ISSET(x->fd, &writefds)) {
                x->oerror = T;  /* Timed out waiting for socket */
                return (NIL);
            }
        }

        if (x->ssl) {
            count = ssl_write(x->ssl, current, bytes);
            /* Interrupted: retry. A negative count is not a length and
             * must never be applied to current/bytes below */
            if ((count < 0) && (errno == EINTR))
                continue;
            if (count <= 0) {
                log_debug("ssl_write() error: %d %d\n",
                          count, ssl_get_error(x->ssl, count));
                x->oerror = T;
                break;
            }
        } else {
            count = write(x->fd, current, bytes);
            if ((count < 0) && (errno == EINTR))
                continue;       /* Interrupted: retry */
            if (count <= 0) {
                x->oerror = T;
                break;
            }
        }

        if (x->debug)
            write(STDERR_FILENO, current, count);

        /* Short writes are possible: advance past what was accepted */
        current += count;
        bytes -= count;
    }

    return ((x->oerror) ? NIL : T);
}

/* ====================================================================== */

/* iostream_putchar() ****************************************************
 *
 * Append one character to the iostream output buffer, flushing first if
 * the buffer is full. Normally called via the ioputc macro.
 *      c: Character to store
 *      x: IOstream
 *
 * Returns: T if c was buffered, NIL if the required flush failed.
 ************************************************************************/

BOOL iostream_putchar(char c, struct iostream * x)
{
    /* Output buffer full: drain it before storing c */
    if (x->ocurrent == x->obufend) {
        if (!iostream_flush(x))
            return (NIL);       /* Flush failed: c is not stored */
    }

    *x->ocurrent = c;
    x->ocurrent++;
    return (T);
}

/* iostream_puts() *******************************************************
 *
 * Write a NUL-terminated string through the iostream output buffer.
 * Normally called via the ioputs macro. A NULL string prints "(nil)".
 *      x: IOstream
 *      s: String to print (may be NULL)
 *
 * Returns: T (per-character output errors are not reported here).
 ************************************************************************/

BOOL iostream_puts(struct iostream * x, unsigned char *s)
{
    unsigned char *p;

    if (!s) {
        ioputs(x, "(nil)");
        return (T);
    }

    for (p = s; *p != '\0'; p++)
        ioputc(*p, x);

    return (T);
}

/* Static support routine for iostream_printf: write value in decimal,
 * most significant digit first ("0" for zero). */

static void iostream_print_ulong(struct iostream *x, unsigned long value)
{
    /* 3 decimal digits per byte is always enough for an unsigned long */
    char digits[3 * sizeof(unsigned long) + 1];
    int n = 0;

    /* Collect digits least significant first */
    do {
        digits[n++] = (char) ('0' + (value % 10));
        value /= 10;
    } while (value > 0);

    /* Emit them in reverse. Index is updated outside ioputc() in case
     * the macro evaluates its argument more than once. */
    while (n > 0) {
        n--;
        ioputc(digits[n], x);
    }
}

/* iostream_printf() *****************************************************
 *
 * Print string through iostream output buffer. Normally called via
 * ioprintf macro.
 *      x: IOstream
 *    fmt: String to print, followed by arguments.
 ************************************************************************/

BOOL iostream_printf(struct iostream *x, char *fmt, ...)
{
    va_list ap;
    char *s;
    char c;

    va_start(ap, fmt);

    while ((c = *fmt++)) {
        if (c != '%') {
            ioputc(c, x);
        } else
            switch (*fmt++) {
            case 's':          /* string */
                if ((s = va_arg(ap, char *))) {
                    while ((c = *s++))
                        ioputc(c, x);
                } else
                    ioprintf(x, "(nil)");
                break;
            case 'l':
                if (*fmt == 'u') {
                    iostream_print_ulong(x, va_arg(ap, unsigned long));
                    fmt++;
                } else
                    iostream_print_ulong(x, va_arg(ap, long));
                break;
            case 'd':
                if (*fmt == 'u') {
                    iostream_print_ulong(x, va_arg(ap, unsigned int));
                    fmt++;
                } else
                    iostream_print_ulong(x, va_arg(ap, int));
                break;
            case 'c':
                ioputc((char) va_arg(ap, int), x);
                break;
            case '%':
                ioputc('%', x);
                break;
            default:
                log_fatal("Bad format string to iostream_printf");
            }
    }
    va_end(ap);

    return (T);
}

/* ====================================================================== */

/* iostream_getline() ****************************************************
 *
 * Get line from iostream with known upper bound. Most client routines
 * dealing with arbitrary amounts of data won't use this, instead they
 * copy data to a temporary buffer and buffer_fetch() from that.
 * CR characters are discarded. The terminating LF is consumed but not
 * stored. The result is always NUL-terminated when length >= 1.
 *    stream: iostream
 *         s: Target buffer
 *    length: Size of target buffer.
 *
 * Returns: T if a complete line was read. NIL on EOF/error before the
 *          LF, if length < 1, or if the buffer filled up. In that last
 *          case the rest of the line stays in the stream. Because the
 *          LF is only looked for while room remains, lines longer than
 *          length-2 characters count as "filled up".
 ************************************************************************/

BOOL iostream_getline(struct iostream * stream, char *s, int length)
{
    int c = EOF;

    if (length < 1)             /* No room even for the trailing '\0' */
        return (NIL);

    length--;                   /* Leave space for trailing '\0' */

    while ((length > 0) && ((c = iogetc(stream)) != EOF)) {
        if (c == '\015')
            continue;

        if (c == '\012')
            break;

        *s++ = c;
        length--;
    }

    *s = '\0';

    return (((c != EOF) && (length > 0)) ? T : NIL);
}

/* iostream_getline_overflow() ********************************************
 *
 * Get line from iostream with known upper bound, reporting whether the
 * line was too long. Most client routines dealing with arbitrary amounts
 * of data won't use this, instead they copy data to a temporary buffer
 * and buffer_fetch() from that. CR characters are discarded. The
 * terminating LF is consumed but not stored.
 *     stream: iostream
 *          s: Target buffer (always NUL-terminated when length >= 1)
 *     length: Size of target buffer.
 *  overflowp: If non-NULL, set to T when the buffer filled up before an
 *             LF was seen (the rest of the line stays in the stream),
 *             NIL otherwise.
 *
 * Returns: T if a complete line was read. NIL on overflow, on EOF/error
 *          before the LF, or if length < 1.
 ************************************************************************/

BOOL
iostream_getline_overflow(struct iostream * stream,
                          char *s, int length, BOOL * overflowp)
{
    int c = EOF;

    if (length < 1) {           /* No room even for the trailing '\0' */
        if (overflowp)
            *overflowp = NIL;
        return (NIL);
    }

    length--;                   /* Leave space for trailing '\0' */

    while ((length > 0) && ((c = iogetc(stream)) != EOF)) {
        if (c == '\015')
            continue;

        if (c == '\012')
            break;

        *s++ = c;
        length--;
    }

    *s = '\0';

    /* Buffer filled before an LF was read: line too long */
    if (length == 0) {
        if (overflowp)
            *overflowp = T;
        return (NIL);
    }

    if (overflowp)
        *overflowp = NIL;
    return ((c != EOF) ? T : NIL);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */