/* $Cambridge: hermes/src/prayer/accountd/iostream.c,v 1.1.1.1 2003/04/15 13:00:02 dpc22 Exp $ */
/************************************************
 *    Prayer - a Webmail Interface              *
 ************************************************/

/* Copyright (c) University of Cambridge 2000 - 2002 */
/* See the file NOTICE for conditions of use and distribution. */

/* Some simple stream IO functions that work with sockets (separate
 * read/write buffers) and provide timeout options on read. Also supply
 * a (more or less) transparent interface to SSL wrapped streams */

#include "accountd.h"

/* iostream_init() *******************************************************
 *
 * Initialise iostream subsystem: sets up the global SSL context.
 *   config: Configuration handed through to ssl_context_init().
 ************************************************************************/

void iostream_init(struct config *config)
{
    ssl_context_init(config);   /* initialize global SSL context */
}

/* iostream_check_rsakey() ***********************************************
 *
 * Thin wrapper: delegates RSA key checking to ssl_check_rsakey().
 *   config: Configuration handed through unchanged.
 ************************************************************************/

void iostream_check_rsakey(struct config *config)
{
    ssl_check_rsakey(config);
}

/* iostream_freshen_rsakey() *********************************************
 *
 * Thin wrapper: delegates RSA key refreshing to ssl_freshen_rsakey().
 *   config: Configuration handed through unchanged.
 ************************************************************************/

void iostream_freshen_rsakey(struct config *config)
{
    ssl_freshen_rsakey(config);
}

/* iostream_exit() *******************************************************
 *
 * Clear up iostream subsystem: releases the global SSL context that
 * iostream_init() set up.
 ************************************************************************/

void iostream_exit(void)
{
    ssl_context_free();         /* free global SSL context */
}

/* ====================================================================== */

/* iostream_create() *****************************************************
 *
 * Create a new iostream and bind it to socket descriptor.
* pool: Target pool for this iostream and its buffers * sockfd: Socket decriptor that we want to bind to * blocksize: Size of I/O buffers on this iostream. * * Returns: New iostream structure ************************************************************************/ struct iostream *iostream_create(struct pool *pool, int sockfd, unsigned long blocksize) { struct iostream *x = pool_alloc(pool, sizeof(struct iostream)); x->pool = pool; x->blocksize = (blocksize) ? blocksize : IOSTREAM_PREFERRED_BLOCK_SIZE; x->fd = sockfd; x->ssl = NIL; x->bio = NIL; x->debug = NIL; x->blocking = T; x->ibuffer = pool_alloc(pool, x->blocksize); x->ibufend = x->ibuffer; x->icurrent = x->ibuffer; x->itimeout = 0; /* Default: No timeout */ x->ierror = NIL; x->ieof = NIL; x->fd = sockfd; x->obuffer = pool_alloc(pool, x->blocksize); x->obufend = x->obuffer + x->blocksize; x->ocurrent = x->obuffer; x->otimeout = 0; x->oerror = NIL; /* Default: No timeout */ return (x); } /* ====================================================================== */ /* iostream_ssl_start_server() ******************************************* * * Start up server side SSL on given iostream * x: iostream. ************************************************************************/ BOOL iostream_ssl_start_server(struct iostream * x) { if (!(x->ssl = ssl_start_server(x->fd, x->itimeout))) { x->ierror = x->oerror = T; return (NIL); } return (T); } /* ====================================================================== */ /* iostream_ssl_start_client() ******************************************* * * Start up client side SSL on given iostream * x: iostream. ************************************************************************/ BOOL iostream_ssl_start_client(struct iostream * x) { if (!(x->ssl = ssl_start_client(x->fd, x->itimeout))) { x->ierror = x->oerror = T; return (NIL); } return (T); } /* iostream_ssl_enabled() ************************************************ * * Check with SSL enabled on this stream. 
* x: iostream. ************************************************************************/ BOOL iostream_ssl_enabled(struct iostream * x) { return ((x->ssl) ? T : NIL); } /* ====================================================================== */ /* iostream_free() ******************************************************* * * Free iostream. ************************************************************************/ void iostream_free(struct iostream *x) { if (x->ssl) ssl_free(x->ssl); if (x->pool) return; if (x->ibuffer) free(x->ibuffer); if (x->obuffer) free(x->obuffer); free(x); } /* iostream_free_buffers() *********************************************** * * Free iostream buffers: switching to raw I/O here ************************************************************************/ void iostream_free_buffers(struct iostream *x) { ioflush(x); if (x->ibuffer) free(x->ibuffer); if (x->obuffer) free(x->obuffer); x->ibuffer = x->obuffer = NIL; } /* iostream_close() ****************************************************** * * Close down and free iostream ************************************************************************/ void iostream_close(struct iostream *x) { int fd = x->fd; if (x->obuffer) iostream_flush(x); /* Flush data in write buffer */ if (x->ssl) ssl_shutdown(x->ssl); iostream_free(x); close(fd); } /* iostream_debug() ****************************************************** * * Enable or disable debugging on this iostream * x: iostream * enable: T => enable. NIL => disable. ************************************************************************/ void iostream_debug(struct iostream *x, BOOL enable) { x->debug = enable; } /* iostream_block() ****************************************************** * * Enable or disable blocking I/O on this iostream * x: iostream * enable: T => enable. NIL => disable. 
************************************************************************/ void iostream_block(struct iostream *x, BOOL enable) { x->blocking = enable; } /* iostream_set_timeout() ************************************************ * * Set timeout on this iostream * x: iostream * timeout: timeout in seconds. 0 => disable timeout ************************************************************************/ void iostream_set_timeout(struct iostream *x, time_t timeout) { x->itimeout = timeout; x->otimeout = timeout; } /* ====================================================================== */ /* iostream_getchar() **************************************************** * * Get a single character from I/O stream. Normally called from via * iogetc() macro. * x: IOstream * * Returns: unsigned character. * EOF => end of file reached * EOS => temporary end of stream if non-blocking I/O enabled. ************************************************************************/ int iostream_getchar(struct iostream *x) { int len; fd_set readfds; struct timeval timeout; if (x->ierror || x->ieof) return (EOF); if (x->icurrent < x->ibufend) return (*(x->icurrent++)); /* Check for data in SSL buffer than select() won't block */ if (x->ssl && ssl_pending(x->ssl)) { while ((len = ssl_read(x->ssl, x->ibuffer, x->blocksize)) < 0) { if (errno != EINTR) { x->ieof = T; return (EOF); } } if (len == 0) { x->ieof = T; return (EOF); } if (x->debug) write(STDERR_FILENO, x->ibuffer, len); x->ibufend = x->ibuffer + len; x->icurrent = x->ibuffer; return (*(x->icurrent++)); } if (x->blocking) { if (x->itimeout > 0) { FD_ZERO(&readfds); FD_SET(x->fd, &readfds); timeout.tv_sec = x->itimeout; timeout.tv_usec = 0; while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) { if (errno != EINTR) log_fatal("iostream_getchar(): select() failed"); } if (!FD_ISSET(x->fd, &readfds)) { x->ierror = T; return (EOF); } } } else { FD_ZERO(&readfds); FD_SET(x->fd, &readfds); /* Non blocking select */ timeout.tv_sec = 0; 
timeout.tv_usec = 0; while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) { if (errno != EINTR) log_fatal("iostream_getchar(): select() failed"); } if (!FD_ISSET(x->fd, &readfds)) return (EOS); } if (x->ssl) { while ((len = ssl_read(x->ssl, x->ibuffer, x->blocksize)) < 0) { if (errno != EINTR) { x->ieof = T; return (EOF); } } if (len == 0) { x->ieof = T; return (EOF); } if (x->debug) write(STDERR_FILENO, x->ibuffer, len); x->ibufend = x->ibuffer + len; x->icurrent = x->ibuffer; return (*(x->icurrent++)); } while ((len = read(x->fd, x->ibuffer, x->blocksize)) < 0) { if (errno != EINTR) { x->ieof = T; return (EOF); } } if (len == 0) { x->ieof = T; return (EOF); } if (x->debug) write(STDERR_FILENO, x->ibuffer, len); x->ibufend = x->ibuffer + len; x->icurrent = x->ibuffer; return (*(x->icurrent++)); } /* ====================================================================== */ /* iostream_ungetchar() ************************************************** * * Unget character from stream. Normally called via ioungetc() macro. ************************************************************************/ void iostream_ungetchar(char c, struct iostream *x) { *(--x->icurrent) = c; } /* ====================================================================== */ /* Check whether more input pending on this iostream */ /* iostream_have_buffered_input() **************************************** * * Check for input sitting in iostream or SSL buffers * x: iostream * ignore_white_space: Only trigger if non-whitespace characters queued. 
* * Returns: T => have buffered input ************************************************************************/ BOOL iostream_have_buffered_input(struct iostream *x, BOOL ignore_whitespace) { unsigned char *s; if (x->icurrent < x->ibufend) { if (!ignore_whitespace) return (T); for (s = x->icurrent; s < x->ibufend; s++) { if ((*s != '\015') && (*s != '\012')) return (T); } } if (x->ssl && ssl_pending(x->ssl)) return (T); return (NIL); } /* iostream_have_buffered_input() **************************************** * * Check for input sitting in iostream or SSL buffers or pending on socket * x: iostream * * Returns: T => have buffered input or read() can return without blocking ************************************************************************/ BOOL iostream_have_input(struct iostream * x) { fd_set readfds; struct timeval timeout; if (x->icurrent < x->ibufend) { unsigned char *s; for (s = x->icurrent; s < x->ibufend; s++) { if ((*s != '\015') && (*s != '\012')) return (T); } } if (x->ssl && ssl_pending(x->ssl)) return (T); FD_ZERO(&readfds); FD_SET(x->fd, &readfds); /* Poll for pending input on readfds */ timeout.tv_sec = 0; timeout.tv_usec = 0; while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) { if (errno != EINTR) log_fatal("iostream_getchar(): select() failed"); } if (FD_ISSET(x->fd, &readfds)) return (T); return (NIL); } /* iostream_is_eof() ***************************************************** * * Check whether iostream has reached EOF on input. ************************************************************************/ BOOL iostream_is_eof(struct iostream * x) { return ((x->ieof) ? T : NIL); } /* ====================================================================== */ /* iostream_flush() ****************************************************** * * Flush buffered data to iostream. 
************************************************************************/ BOOL iostream_flush(struct iostream * x) { unsigned char *current = x->obuffer; int bytes = x->ocurrent - x->obuffer; int count; fd_set writefds; struct timeval timeout; if (x->obuffer == NIL) /* No output buffer on stream */ return (T); /* Reset ptr to start of output buffer */ x->ocurrent = x->obuffer; while ((!x->oerror) && (bytes > 0)) { if (x->otimeout > 0) { FD_ZERO(&writefds); FD_SET(x->fd, &writefds); timeout.tv_sec = x->otimeout; timeout.tv_usec = 0; while (select(x->fd + 1, NIL, &writefds, NIL, &timeout) < 0) { if (errno != EINTR) log_fatal("iostream_flush(): select() failed"); } if (!FD_ISSET(x->fd, &writefds)) { x->oerror = T; return (NIL); } } if (x->ssl) { if ((count = ssl_write(x->ssl, current, bytes)) < 0) { if (errno != EINTR) { log_debug("ssl_write() error: %d %d\n", count, ssl_get_error(x->ssl, count)); x->oerror = T; break; } } if (count == 0) { log_debug("ssl_write() error: %d %d\n", count, ssl_get_error(x->ssl, count)); x->oerror = T; } if (x->oerror) break; if (x->debug) write(STDERR_FILENO, current, count); current += count; bytes -= count; continue; } if ((count = write(x->fd, current, bytes)) < 0) { if (errno != EINTR) { x->oerror = T; break; } } if (count == 0) x->oerror = T; if (x->oerror) break; if (x->debug) write(STDERR_FILENO, current, count); current += count; bytes -= count; } return ((x->oerror) ? NIL : T); } /* ====================================================================== */ /* iostream_putchar() **************************************************** * * Push character through iostream output buffer. Normally called via * ioputc macro. 
* c: Character to push * x: IOstream ************************************************************************/ BOOL iostream_putchar(char c, struct iostream * x) { if ((x->ocurrent == x->obufend) && !iostream_flush(x)) return (NIL); /* Flush failed */ *(x->ocurrent++) = c; return (T); } /* iostream_puts() ******************************************************* * * Push string through iostream output buffer. Normally called via * ioputs macro. * x: IOstream * s: String to print ************************************************************************/ BOOL iostream_puts(struct iostream * x, unsigned char *s) { unsigned char c; if (!s) ioputs(x, "(nil)"); else while ((c = *s++)) ioputc(c, x); return (T); } /* Static support routine for iostream_printf */ static void iostream_print_ulong(struct iostream *x, unsigned long value) { unsigned long tmp, weight; /* All numbers contain at least one digit. * Find weight of most significant digit. */ for (weight = 1, tmp = value / 10; tmp > 0; tmp /= 10) weight *= 10; for (tmp = value; weight > 0; weight /= 10) { if (value >= weight) { /* Strictly speaking redundant... */ ioputc('0' + (value / weight), x); /* Digit other than zero */ value -= weight * (value / weight); /* Calculate remainder */ } else ioputc('0', x); } } /* iostream_printf() ***************************************************** * * Print string through iostream output buffer. Normally called via * ioprintf macro. * x: IOstream * fmt: String to print, followed by arguments. ************************************************************************/ BOOL iostream_printf(struct iostream *x, char *fmt, ...) 
{ va_list ap; char *s; char c; va_start(ap, fmt); while ((c = *fmt++)) { if (c != '%') { ioputc(c, x); } else switch (*fmt++) { case 's': /* string */ if ((s = va_arg(ap, char *))) { while ((c = *s++)) ioputc(c, x); } else ioprintf(x, "(nil)"); break; case 'l': if (*fmt == 'u') { iostream_print_ulong(x, va_arg(ap, unsigned long)); fmt++; } else iostream_print_ulong(x, va_arg(ap, long)); break; case 'd': if (*fmt == 'u') { iostream_print_ulong(x, va_arg(ap, unsigned int)); fmt++; } else iostream_print_ulong(x, va_arg(ap, int)); break; case 'c': ioputc((char) va_arg(ap, int), x); break; case '%': ioputc('%', x); break; default: log_fatal("Bad format string to iostream_printf"); } } va_end(ap); return (T); } /* ====================================================================== */ /* iostream_getline() **************************************************** * * Get line from iostream with known upper bound. Most client routines * dealing with arbitary amounts of data won't use this, instead they * copy data to a temporary buffer and buffer_fetch() from that. * stream: iostream * s: Target buffer * length: Size of target buffer. ************************************************************************/ BOOL iostream_getline(struct iostream * stream, char *s, int length) { int c = EOF; length--; /* Leave space for trailing '\0' */ while ((length > 0) && ((c = iogetc(stream)) != EOF)) { if (c == '\015') continue; if (c == '\012') break; *s++ = c; length--; } *s = '\0'; return (((c != EOF) && (length > 0)) ? T : NIL); } /* iostream_getline_overflow() ******************************************** * * Get line from iostream with known upper bound. Most client routines * dealing with arbitary amounts of data won't use this, instead they * copy data to a temporary buffer and buffer_fetch() from that. * stream: iostream * s: Target buffer * length: Size of target buffer. 
* toolong: Input line was too long ************************************************************************/ BOOL iostream_getline_overflow(struct iostream * stream, char *s, int length, BOOL * overflowp) { int c = EOF; length--; /* Leave space for trailing '\0' */ while ((length > 0) && ((c = iogetc(stream)) != EOF)) { if (c == '\015') continue; if (c == '\012') break; *s++ = c; length--; } *s = '\0'; if (length == 0) { if (overflowp) *overflowp = T; return (NIL); } if (overflowp) *overflowp = NIL; return ((c != EOF) ? T : NIL); }