/* $Cambridge: hermes/src/prayer/accountd/iostream.c,v 1.1.1.1 2003/04/15 13:00:02 dpc22 Exp $ */
/************************************************
* Prayer - a Webmail Interface *
************************************************/
/* Copyright (c) University of Cambridge 2000 - 2002 */
/* See the file NOTICE for conditions of use and distribution. */
/* Some simple stream IO functions that work with sockets (separate
read/write buffers) and provide timeout options on read.
Also supply a (more or less) transparent interface to SSL wrapped streams*/
#include "accountd.h"
/* iostream_init() *******************************************************
 *
 * Initialise the iostream subsystem. At present this consists solely of
 * handing the configuration to ssl_context_init().
 *   config: Configuration passed through to ssl_context_init()
 ************************************************************************/
void iostream_init(struct config *config)
{
ssl_context_init(config); /* initialize global SSL context */
}
/* iostream_check_rsakey() ***********************************************
 *
 * Thin wrapper: passes the configuration straight to ssl_check_rsakey().
 * NOTE(review): presumably checks the state of the SSL layer's RSA key;
 * confirm against ssl_check_rsakey().
 *   config: Configuration passed through to ssl_check_rsakey()
 ************************************************************************/
void iostream_check_rsakey(struct config *config)
{
ssl_check_rsakey(config);
}
/* iostream_freshen_rsakey() *********************************************
 *
 * Thin wrapper: passes the configuration straight to ssl_freshen_rsakey().
 * NOTE(review): presumably regenerates or reloads the SSL layer's RSA
 * key; confirm against ssl_freshen_rsakey().
 *   config: Configuration passed through to ssl_freshen_rsakey()
 ************************************************************************/
void iostream_freshen_rsakey(struct config *config)
{
ssl_freshen_rsakey(config);
}
/* iostream_exit() *******************************************************
 *
 * Clear up iostream subsystem. Counterpart of iostream_init(): calls
 * ssl_context_free().
 ************************************************************************/
void iostream_exit()
{
ssl_context_free(); /* free global SSL context */
}
/* ====================================================================== */
/* iostream_create() *****************************************************
 *
 * Allocate a new iostream and attach it to a socket descriptor.
 *      pool: Pool that the iostream and both of its buffers come from
 *    sockfd: Socket descriptor that the stream is bound to
 * blocksize: Size of each I/O buffer on this iostream. Zero selects
 *            IOSTREAM_PREFERRED_BLOCK_SIZE.
 *
 * Returns: New iostream structure: blocking, no SSL, debugging off,
 *          no timeouts, both buffers empty.
 ************************************************************************/
struct iostream *iostream_create(struct pool *pool, int sockfd,
                                 unsigned long blocksize)
{
    struct iostream *stream = pool_alloc(pool, sizeof(struct iostream));

    if (!blocksize)
        blocksize = IOSTREAM_PREFERRED_BLOCK_SIZE;

    /* State shared by both directions */
    stream->pool = pool;
    stream->blocksize = blocksize;
    stream->fd = sockfd;
    stream->ssl = NIL;
    stream->bio = NIL;
    stream->debug = NIL;
    stream->blocking = T;

    /* Input side: empty buffer (read cursor == end of valid data) */
    stream->ibuffer = pool_alloc(pool, stream->blocksize);
    stream->icurrent = stream->ibuffer;
    stream->ibufend = stream->ibuffer;
    stream->itimeout = 0;       /* Default: No timeout */
    stream->ierror = NIL;
    stream->ieof = NIL;

    /* Output side: empty buffer, obufend marks the end of the storage */
    stream->obuffer = pool_alloc(pool, stream->blocksize);
    stream->ocurrent = stream->obuffer;
    stream->obufend = stream->obuffer + stream->blocksize;
    stream->otimeout = 0;       /* Default: No timeout */
    stream->oerror = NIL;

    return (stream);
}
/* ====================================================================== */
/* iostream_ssl_start_server() *******************************************
 *
 * Start up server side SSL on given iostream, using the stream's socket
 * descriptor and its input timeout.
 *   x: iostream.
 *
 * Returns: T   => SSL session established and stored in x->ssl
 *          NIL => ssl_start_server() failed; both the input and the
 *                 output side of the stream are flagged as in error.
 ************************************************************************/
BOOL iostream_ssl_start_server(struct iostream * x)
{
    x->ssl = ssl_start_server(x->fd, x->itimeout);

    if (x->ssl)
        return (T);

    /* Negotiation failed: neither direction is usable any more */
    x->oerror = T;
    x->ierror = T;
    return (NIL);
}
/* ====================================================================== */
/* iostream_ssl_start_client() *******************************************
 *
 * Start up client side SSL on given iostream, using the stream's socket
 * descriptor and its input timeout.
 *   x: iostream.
 *
 * Returns: T   => SSL session established and stored in x->ssl
 *          NIL => ssl_start_client() failed; both the input and the
 *                 output side of the stream are flagged as in error.
 ************************************************************************/
BOOL iostream_ssl_start_client(struct iostream * x)
{
    x->ssl = ssl_start_client(x->fd, x->itimeout);

    if (x->ssl)
        return (T);

    /* Negotiation failed: neither direction is usable any more */
    x->oerror = T;
    x->ierror = T;
    return (NIL);
}
/* iostream_ssl_enabled() ************************************************
 *
 * Check whether SSL is enabled on this stream.
 *   x: iostream.
 *
 * Returns: T if the stream carries an SSL session, NIL otherwise.
 ************************************************************************/
BOOL iostream_ssl_enabled(struct iostream * x)
{
    if (x->ssl)
        return (T);

    return (NIL);
}
/* ====================================================================== */
/* iostream_free() *******************************************************
 *
 * Free iostream. The SSL session (if any) is always released. The
 * buffers and the structure itself are only passed to free() when the
 * stream has no pool; a pool-backed stream is left for its pool.
 *   x: iostream
 ************************************************************************/
void iostream_free(struct iostream *x)
{
    if (x->ssl)
        ssl_free(x->ssl);

    if (!x->pool) {
        /* free() ignores null pointers, so no per-buffer test is needed */
        free(x->ibuffer);
        free(x->obuffer);
        free(x);
    }
}
/* iostream_free_buffers() ***********************************************
 *
 * Free iostream buffers: switching to raw I/O here. Pending output is
 * flushed first. Buffered input that has not been consumed is dropped.
 *   x: iostream
 *
 * The buffers are only handed to free() when the stream has no pool,
 * matching iostream_free(), which treats a non-NIL pool as the owner of
 * the stream's memory. All buffer pointers and cursors are cleared so
 * that nothing is left pointing at released storage.
 ************************************************************************/
void iostream_free_buffers(struct iostream *x)
{
    ioflush(x);

    if (!x->pool) {
        /* free() ignores null pointers */
        free(x->ibuffer);
        free(x->obuffer);
    }

    x->ibuffer = x->obuffer = NIL;
    x->icurrent = x->ibufend = NIL;
    x->ocurrent = x->obufend = NIL;
}
/* iostream_close() ******************************************************
 *
 * Close down and free iostream: flush pending output (when the stream
 * still has an output buffer), shut down SSL, release the stream and
 * finally close the socket descriptor.
 *   x: iostream
 ************************************************************************/
void iostream_close(struct iostream *x)
{
    /* Take a copy: x may no longer exist once iostream_free() returns */
    int sockfd = x->fd;

    if (x->obuffer) {
        iostream_flush(x);      /* Flush data in write buffer */
    }

    if (x->ssl) {
        ssl_shutdown(x->ssl);
    }

    iostream_free(x);
    close(sockfd);
}
/* iostream_debug() ******************************************************
 *
 * Enable or disable debugging on this iostream. While the flag is set,
 * iostream_getchar() and iostream_flush() copy the data that they
 * transfer to STDERR_FILENO.
 *      x: iostream
 * enable: T => enable. NIL => disable.
 ************************************************************************/
void iostream_debug(struct iostream *x, BOOL enable)
{
x->debug = enable;
}
/* iostream_block() ******************************************************
 *
 * Enable or disable blocking I/O on this iostream. With blocking
 * disabled, iostream_getchar() polls the descriptor and returns EOS
 * when no input is ready instead of waiting for it.
 *      x: iostream
 * enable: T => enable. NIL => disable.
 ************************************************************************/
void iostream_block(struct iostream *x, BOOL enable)
{
x->blocking = enable;
}
/* iostream_set_timeout() ************************************************
 *
 * Set timeout on this iostream. The same value is applied to the input
 * side and to the output side.
 *       x: iostream
 * timeout: timeout in seconds. 0 => disable timeout
 ************************************************************************/
void iostream_set_timeout(struct iostream *x, time_t timeout)
{
x->itimeout = timeout;
x->otimeout = timeout;
}
/* ====================================================================== */
/* iostream_getchar() ****************************************************
*
* Get a single character from I/O stream. Normally called from via
* iogetc() macro.
* x: IOstream
*
* Returns: unsigned character.
* EOF => end of file reached
* EOS => temporary end of stream if non-blocking I/O enabled.
************************************************************************/
int iostream_getchar(struct iostream *x)
{
int len;
fd_set readfds;
struct timeval timeout;
if (x->ierror || x->ieof)
return (EOF);
if (x->icurrent < x->ibufend)
return (*(x->icurrent++));
/* Check for data in SSL buffer than select() won't block */
if (x->ssl && ssl_pending(x->ssl)) {
while ((len = ssl_read(x->ssl, x->ibuffer, x->blocksize)) < 0) {
if (errno != EINTR) {
x->ieof = T;
return (EOF);
}
}
if (len == 0) {
x->ieof = T;
return (EOF);
}
if (x->debug)
write(STDERR_FILENO, x->ibuffer, len);
x->ibufend = x->ibuffer + len;
x->icurrent = x->ibuffer;
return (*(x->icurrent++));
}
if (x->blocking) {
if (x->itimeout > 0) {
FD_ZERO(&readfds);
FD_SET(x->fd, &readfds);
timeout.tv_sec = x->itimeout;
timeout.tv_usec = 0;
while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) {
if (errno != EINTR)
log_fatal("iostream_getchar(): select() failed");
}
if (!FD_ISSET(x->fd, &readfds)) {
x->ierror = T;
return (EOF);
}
}
} else {
FD_ZERO(&readfds);
FD_SET(x->fd, &readfds);
/* Non blocking select */
timeout.tv_sec = 0;
timeout.tv_usec = 0;
while (select(x->fd + 1, &readfds, NIL, NIL, &timeout) < 0) {
if (errno != EINTR)
log_fatal("iostream_getchar(): select() failed");
}
if (!FD_ISSET(x->fd, &readfds))
return (EOS);
}
if (x->ssl) {
while ((len = ssl_read(x->ssl, x->ibuffer, x->blocksize)) < 0) {
if (errno != EINTR) {
x->ieof = T;
return (EOF);
}
}
if (len == 0) {
x->ieof = T;
return (EOF);
}
if (x->debug)
write(STDERR_FILENO, x->ibuffer, len);
x->ibufend = x->ibuffer + len;
x->icurrent = x->ibuffer;
return (*(x->icurrent++));
}
while ((len = read(x->fd, x->ibuffer, x->blocksize)) < 0) {
if (errno != EINTR) {
x->ieof = T;
return (EOF);
}
}
if (len == 0) {
x->ieof = T;
return (EOF);
}
if (x->debug)
write(STDERR_FILENO, x->ibuffer, len);
x->ibufend = x->ibuffer + len;
x->icurrent = x->ibuffer;
return (*(x->icurrent++));
}
/* ====================================================================== */
/* iostream_ungetchar() **************************************************
 *
 * Unget character from stream. Normally called via ioungetc() macro.
 *   c: Character to push back
 *   x: iostream
 *
 * The character is stored in the slot just before the read cursor. When
 * the stream has no input buffer, or nothing has been consumed from it
 * (cursor at the start), there is no such slot and the call is ignored
 * rather than writing in front of the buffer.
 ************************************************************************/
void iostream_ungetchar(char c, struct iostream *x)
{
    if (!x->ibuffer || (x->icurrent <= x->ibuffer))
        return;

    *(--x->icurrent) = c;
}
/* ====================================================================== */
/* Check whether more input pending on this iostream */
/* iostream_have_buffered_input() ****************************************
 *
 * Check for input sitting in iostream or SSL buffers. The socket itself
 * is not examined.
 *                 x: iostream
 * ignore_whitespace: Only trigger on the iostream buffer if it holds a
 *                    character other than CR ('\015') or LF ('\012').
 *
 * Returns: T => have buffered input
 ************************************************************************/
BOOL
iostream_have_buffered_input(struct iostream *x, BOOL ignore_whitespace)
{
    unsigned char *scan;

    /* Any unread byte counts when line terminators are not being ignored */
    if ((x->icurrent < x->ibufend) && !ignore_whitespace)
        return (T);

    for (scan = x->icurrent; scan < x->ibufend; scan++) {
        if ((*scan == '\015') || (*scan == '\012'))
            continue;
        return (T);
    }

    /* Nothing relevant in our own buffer: ask the SSL layer */
    return ((x->ssl && ssl_pending(x->ssl)) ? T : NIL);
}
/* iostream_have_input() *************************************************
 *
 * Check for input sitting in iostream or SSL buffers or pending on socket
 *   x: iostream
 *
 * Data in the iostream buffer only counts if it contains a character
 * other than CR ('\015') or LF ('\012').
 *
 * Returns: T => have buffered input or read() can return without blocking
 ************************************************************************/
BOOL iostream_have_input(struct iostream * x)
{
    unsigned char *s;
    fd_set readfds;
    struct timeval timeout;

    for (s = x->icurrent; s < x->ibufend; s++) {
        if ((*s != '\015') && (*s != '\012'))
            return (T);
    }

    if (x->ssl && ssl_pending(x->ssl))
        return (T);

    /* Poll the socket. The descriptor set and the timeout are rebuilt for
     * every attempt: their contents cannot be relied on after select()
     * has been interrupted by a signal. */
    for (;;) {
        FD_ZERO(&readfds);
        FD_SET(x->fd, &readfds);
        timeout.tv_sec = 0;
        timeout.tv_usec = 0;

        if (select(x->fd + 1, &readfds, NIL, NIL, &timeout) >= 0)
            break;

        if (errno != EINTR)
            log_fatal("iostream_have_input(): select() failed");
    }

    return (FD_ISSET(x->fd, &readfds) ? T : NIL);
}
/* iostream_is_eof() *****************************************************
 *
 * Check whether iostream has reached EOF on input.
 *   x: iostream
 *
 * Returns: T if the stream's EOF flag is set, NIL otherwise.
 ************************************************************************/
BOOL iostream_is_eof(struct iostream * x)
{
    if (x->ieof)
        return (T);

    return (NIL);
}
/* ====================================================================== */
/* iostream_flush() ******************************************************
 *
 * Flush buffered data to iostream.
 *   x: iostream
 *
 * The output buffer is marked empty before anything is written, so data
 * that cannot be written is discarded rather than kept for a retry.
 * If an output timeout is set, each write is preceded by a select()
 * that waits for the descriptor to become writable.
 *
 * Returns: T   => everything written, or stream has no output buffer
 *          NIL => write error or timeout; x->oerror is set (also
 *                 returned if x->oerror was already set on entry).
 ************************************************************************/
BOOL iostream_flush(struct iostream * x)
{
    unsigned char *current;
    int bytes;
    int count;
    fd_set writefds;
    struct timeval timeout;

    if (x->obuffer == NIL)      /* No output buffer on stream */
        return (T);

    current = x->obuffer;
    bytes = x->ocurrent - x->obuffer;

    /* Reset ptr to start of output buffer */
    x->ocurrent = x->obuffer;

    while ((!x->oerror) && (bytes > 0)) {
        if (x->otimeout > 0) {
            /* Rebuild the select() arguments for every attempt: they
             * cannot be relied on after an interrupted call */
            for (;;) {
                FD_ZERO(&writefds);
                FD_SET(x->fd, &writefds);
                timeout.tv_sec = x->otimeout;
                timeout.tv_usec = 0;

                if (select(x->fd + 1, NIL, &writefds, NIL, &timeout) >= 0)
                    break;

                if (errno != EINTR)
                    log_fatal("iostream_flush(): select() failed");
            }
            if (!FD_ISSET(x->fd, &writefds)) {
                x->oerror = T;
                return (NIL);
            }
        }

        if (x->ssl)
            count = ssl_write(x->ssl, current, bytes);
        else
            count = write(x->fd, current, bytes);

        /* Interrupted before anything was written: retry the same span.
         * The negative count must not reach the pointer arithmetic (or
         * the debug write) below. */
        if ((count < 0) && (errno == EINTR))
            continue;

        if (count <= 0) {
            if (x->ssl)
                log_debug("ssl_write() error: %d %d\n",
                          count, ssl_get_error(x->ssl, count));
            x->oerror = T;
            break;
        }

        if (x->debug)
            write(STDERR_FILENO, current, count);

        current += count;
        bytes -= count;
    }

    return ((x->oerror) ? NIL : T);
}
/* ====================================================================== */
/* iostream_putchar() ****************************************************
 *
 * Push character through iostream output buffer. Normally called via
 * ioputc macro. A full buffer is flushed before the character is stored.
 *   c: Character to push
 *   x: IOstream
 *
 * Returns: T => character stored. NIL => buffer was full and the flush
 *          failed; the character is not stored.
 ************************************************************************/
BOOL iostream_putchar(char c, struct iostream * x)
{
    if (x->ocurrent == x->obufend) {
        /* Buffer full: make room */
        if (!iostream_flush(x))
            return (NIL);       /* Flush failed */
    }

    *x->ocurrent = c;
    x->ocurrent++;
    return (T);
}
/* iostream_puts() *******************************************************
 *
 * Push string through iostream output buffer. Normally called via
 * ioputs macro.
 *   x: IOstream
 *   s: String to print. A null pointer is printed as "(nil)".
 *
 * Returns: T, always. Output failures are not reported here.
 ************************************************************************/
BOOL iostream_puts(struct iostream * x, unsigned char *s)
{
    unsigned char ch;

    if (!s) {
        ioputs(x, "(nil)");
        return (T);
    }

    for (ch = *s; ch; ch = *++s) {
        ioputc(ch, x);
    }
    return (T);
}
/* Static support routine for iostream_printf: write an unsigned long to
 * the stream as a decimal number, most significant digit first. */
static void iostream_print_ulong(struct iostream *x, unsigned long value)
{
    unsigned long scale = 1;
    unsigned long rest;

    /* Every number has at least one digit: raise scale to the weight of
     * the most significant one. */
    for (rest = value / 10; rest > 0; rest /= 10)
        scale *= 10;

    /* Peel off one digit per power of ten. The quotient is zero whenever
     * the remaining value is below the current weight. */
    while (scale > 0) {
        unsigned long digit = value / scale;

        ioputc('0' + digit, x);
        value -= digit * scale;
        scale /= 10;
    }
}
/* Static support routine for iostream_printf: write a signed long to the
 * stream as a decimal number, with a leading '-' for negative values. */
static void iostream_print_long(struct iostream *x, long value)
{
    unsigned long magnitude = (unsigned long) value;

    if (value < 0) {
        ioputc('-', x);
        /* Negate in unsigned arithmetic: defined even for LONG_MIN */
        magnitude = 0UL - magnitude;
    }
    iostream_print_ulong(x, magnitude);
}

/* iostream_printf() *****************************************************
 *
 * Print string through iostream output buffer. Normally called via
 * ioprintf macro.
 *     x: IOstream
 *   fmt: String to print, followed by arguments.
 *
 * Only a small private set of conversions is understood:
 *   %s   string (a null pointer prints as "(nil)")
 *   %l   long             %lu  unsigned long
 *   %d   int              %du  unsigned int
 *   %c   character        %%   literal '%'
 * There are no flags, widths or precisions. Note that "%ld" is "%l"
 * followed by a literal 'd'. Anything else after '%' (including the end
 * of the string) is reported through log_fatal().
 *
 * Returns: T, always. Output failures are not reported here.
 ************************************************************************/
BOOL iostream_printf(struct iostream *x, char *fmt, ...)
{
    va_list ap;
    char *s;
    char c;

    va_start(ap, fmt);

    while ((c = *fmt++)) {
        if (c != '%') {
            ioputc(c, x);
            continue;
        }

        switch (*fmt++) {
        case 's':              /* string */
            if ((s = va_arg(ap, char *))) {
                while ((c = *s++)) {
                    ioputc(c, x);
                }
            } else {
                ioprintf(x, "(nil)");
            }
            break;
        case 'l':
            if (*fmt == 'u') {
                iostream_print_ulong(x, va_arg(ap, unsigned long));
                fmt++;
            } else {
                /* Signed: negative values need a sign, not wrap-around */
                iostream_print_long(x, va_arg(ap, long));
            }
            break;
        case 'd':
            if (*fmt == 'u') {
                iostream_print_ulong(x, va_arg(ap, unsigned int));
                fmt++;
            } else {
                /* Signed: negative values need a sign, not wrap-around */
                iostream_print_long(x, (long) va_arg(ap, int));
            }
            break;
        case 'c':
            ioputc((char) va_arg(ap, int), x);
            break;
        case '%':
            ioputc('%', x);
            break;
        default:
            log_fatal("Bad format string to iostream_printf");
        }
    }

    va_end(ap);
    return (T);
}
/* ====================================================================== */
/* iostream_getline() ****************************************************
 *
 * Get line from iostream with known upper bound. Most client routines
 * dealing with arbitrary amounts of data won't use this, instead they
 * copy data to a temporary buffer and buffer_fetch() from that.
 *  stream: iostream
 *       s: Target buffer
 *  length: Size of target buffer.
 *
 * CR characters are dropped and the terminating LF is not stored. The
 * result is always NUL terminated, unless length <= 0, in which case
 * the buffer has no room for the terminator and is left untouched.
 *
 * Returns: T   => complete line read
 *          NIL => EOF before end of line, or the buffer filled up
 *                 before a LF was seen (this includes a line of exactly
 *                 length-1 characters), or length <= 0.
 *
 * NOTE(review): an EOS result from iogetc() on a non-blocking stream is
 * not treated specially here and would be stored as a character.
 ************************************************************************/
BOOL iostream_getline(struct iostream * stream, char *s, int length)
{
    int c = EOF;

    if (length <= 0)
        return (NIL);           /* No room even for the trailing '\0' */

    length--;                   /* Leave space for trailing '\0' */

    while ((length > 0) && ((c = iogetc(stream)) != EOF)) {
        if (c == '\015')
            continue;

        if (c == '\012')
            break;

        *s++ = c;
        length--;
    }
    *s = '\0';

    return (((c != EOF) && (length > 0)) ? T : NIL);
}
/* iostream_getline_overflow() ********************************************
 *
 * Get line from iostream with known upper bound, reporting separately
 * whether the line was too long for the buffer. Most client routines
 * dealing with arbitrary amounts of data won't use this, instead they
 * copy data to a temporary buffer and buffer_fetch() from that.
 *     stream: iostream
 *          s: Target buffer
 *     length: Size of target buffer.
 *  overflowp: If not null, set to T when the buffer filled up before a
 *             LF was seen (this includes a line of exactly length-1
 *             characters) and to NIL otherwise.
 *
 * CR characters are dropped and the terminating LF is not stored. The
 * result is always NUL terminated, unless length <= 0, in which case
 * the buffer has no room for the terminator and is left untouched
 * (NIL is returned and *overflowp is NIL).
 *
 * Returns: T   => complete line read
 *          NIL => buffer overflow, EOF before end of line, or length <= 0
 *
 * NOTE(review): an EOS result from iogetc() on a non-blocking stream is
 * not treated specially here and would be stored as a character.
 ************************************************************************/
BOOL
iostream_getline_overflow(struct iostream * stream,
                          char *s, int length, BOOL * overflowp)
{
    int c = EOF;

    if (length <= 0) {
        /* No room even for the trailing '\0' */
        if (overflowp)
            *overflowp = NIL;
        return (NIL);
    }

    length--;                   /* Leave space for trailing '\0' */

    while ((length > 0) && ((c = iogetc(stream)) != EOF)) {
        if (c == '\015')
            continue;

        if (c == '\012')
            break;

        *s++ = c;
        length--;
    }
    *s = '\0';

    if (length == 0) {
        if (overflowp)
            *overflowp = T;
        return (NIL);
    }

    if (overflowp)
        *overflowp = NIL;

    return ((c != EOF) ? T : NIL);
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */