/* Socket routines.
 *
 * IRC Services is copyright (c) 1996-2007 Andrew Church.
 *     E-mail: <achurch@achurch.org>
 * Parts written by Andrew Kempe and others.
 * This program is free but copyrighted software; see the file COPYING for
 * details.
 */

#include "services.h"
#include <fcntl.h>
#include <sys/socket.h>
#include <netdb.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/*************************************************************************/

/* Socket data structure */

struct socket_ {
    Socket *next, *prev;	/* List links (sockets are kept on allsockets) */
    int fd;			/* Socket's file descriptor */
    int flags;			/* Status flags (SF_*) */
    struct sockaddr_in remote;	/* Remote address */

    /* Usage of pointers (each buffer is used as a ring buffer):
     *    - Xbuf is the buffer base pointer
     *    - Xptr is the address of the next character to retrieve
     *    - Xend is the address of the next character to store
     *    - Xtop is the address of the last byte of the buffer + 1
     * Xend-Xptr (mod Xbufsize) gives the number of bytes in the buffer;
     * Xptr == Xend means the buffer is empty.
     */
    char *rbuf, *rptr, *rend, *rtop;  /* Read buffer and pointers */
    char *wbuf, *wptr, *wend, *wtop;  /* Write buffer and pointers */
    uint32 rbufsize, wbufsize;	/* Buffer sizes */

    SocketCallback cb_connect;	/* Connect callback */
    SocketCallback cb_disconn;	/* Disconnect callback */
    SocketCallback cb_accept;	/* Accept callback */
    SocketCallback cb_read;	/* Data-available callback */
    SocketCallback cb_readline;	/* Line-available callback */

    uint32 total_read;		/* Total number of kilobytes read */
    int total_read_B;		/* Fractional number of kilobytes read */
    uint32 total_written;	/* Total number of kilobytes written */
    int total_written_B;	/* Fractional number of kilobytes written */
};

/* Bit values for the `flags' field of a socket structure. */
#define SF_SELFCREATED	0x0001	/* We created this socket ourselves */
#define SF_CONNECTING	0x0002	/* Socket is busy connecting */
#define SF_CONNECTED	0x0004	/* Socket has connected */
#define SF_LISTENER	0x0008	/* Socket is a listener socket */
#define SF_CALLBACK	0x0010	/* Currently calling callbacks */
#define SF_WARNED	0x0020	/* Warned about hitting buffer limit */
#define SF_DISCONNECT	0x0040	/* Disconnect when writebuf empty */
#define SF_DELETEME	0x0080	/* Delete socket when convenient */
#define SF_DISCONN_CB	0x0100	/* Disconnect callback has been called (used
				 *    to prevent duplicate calls) */
#define SF_BLOCKING	0x0200	/* Writes are blocking */

/*************************************************************************/

/* List of all sockets (even unopened ones) */
static Socket *allsockets = NULL;

/* Array of all opened sockets, dynamically allocated and indexed by file
 * descriptor (sockets[fd] is the socket using `fd', or NULL) */
static Socket **sockets = NULL;

/* Highest FD number in use plus 1; also length of sockets[] array */
static int max_fd;

/* Set of all connected socket FDs (listener FDs are included as well);
 * used as the read set for select() */
static fd_set sock_fds;

/* Set of all FDs that need data written (or are connecting) */
static fd_set write_fds;

/* Total memory used by socket buffers */
static uint32 total_bufsize;

/*************************************************************************/

/* Internal routine declarations (definitions at bottom of file).  All of
 * these are file-local helpers used by the global routines below. */

static void do_accept(Socket *s);
static int fill_read_buffer(Socket *s);
static int flush_write_buffer(Socket *s);
static void resize_rbuf(Socket *s, uint32 size);
static void resize_wbuf(Socket *s, uint32 size);
static void resize_buf(char **p_buf, char **p_ptr, char **p_end, char **p_top,
		       uint32 newsize);
static int reclaim_buffer_space_one(Socket *s);
static int reclaim_buffer_space(void);
static int buffered_write(Socket *s, const char *buf, int len);
static int do_disconn(Socket *s, void *code);
static void sock_closefd(Socket *s);

/*************************************************************************/
/*************************** Global routines *****************************/
/*************************************************************************/

/* Allocate a new, unconnected socket with minimum-size read and write
 * buffers and add it to the list of all sockets.  Returns the new socket,
 * or NULL if a total buffer size limit (TotalNetBufferSize) is set and not
 * enough buffer space could be reclaimed to stay under it.
 */

Socket *sock_new(void)
{
    Socket *sock;

    if (TotalNetBufferSize) {
	/* Keep reclaiming until two minimum-size buffers fit under the
	 * limit, giving up as soon as nothing more can be reclaimed. */
	for (;;) {
	    if (total_bufsize + NET_MIN_BUFSIZE*2 <= TotalNetBufferSize)
		break;
	    if (!reclaim_buffer_space()) {
		log("sockets: sock_new(): out of buffer space!");
		return NULL;
	    }
	}
    }

    sock = smalloc(sizeof(*sock));

    /* Not yet associated with any file descriptor or peer */
    sock->fd = -1;
    sock->flags = 0;
    memset(&sock->remote, 0, sizeof(sock->remote));

    /* Empty read buffer */
    sock->rbuf = smalloc(NET_MIN_BUFSIZE);
    sock->rptr = sock->rbuf;
    sock->rend = sock->rbuf;
    sock->rtop = sock->rbuf + NET_MIN_BUFSIZE;
    sock->rbufsize = NET_MIN_BUFSIZE;

    /* Empty write buffer */
    sock->wbuf = smalloc(NET_MIN_BUFSIZE);
    sock->wptr = sock->wbuf;
    sock->wend = sock->wbuf;
    sock->wtop = sock->wbuf + NET_MIN_BUFSIZE;
    sock->wbufsize = NET_MIN_BUFSIZE;

    /* No callbacks installed */
    sock->cb_connect  = NULL;
    sock->cb_disconn  = NULL;
    sock->cb_accept   = NULL;
    sock->cb_read     = NULL;
    sock->cb_readline = NULL;

    /* Traffic counters start at zero */
    sock->total_read      = 0;
    sock->total_read_B    = 0;
    sock->total_written   = 0;
    sock->total_written_B = 0;

    LIST_INSERT(sock, allsockets);
    total_bufsize += sock->rbufsize + sock->wbufsize;
    return sock;
}

/*************************************************************************/

/* Release a socket and its buffers.  A connected (or connecting) socket is
 * disconnected first and a listener is closed first.  Passing NULL logs an
 * error and sets errno to EINVAL.
 */

void sock_free(Socket *s)
{
    if (s == NULL) {
	log("sockets: sock_free() with NULL socket!");
	errno = EINVAL;
	return;
    }

    if ((s->flags & (SF_CONNECTING | SF_CONNECTED)) != 0) {
	/* Mark the socket for deletion and start the disconnect;
	 * do_disconn() is expected to call sock_free() again once the
	 * socket can actually be released. */
	s->flags |= SF_DELETEME;
	do_disconn(s, DISCONN_LOCAL);
	return;
    }

    if ((s->flags & SF_LISTENER) != 0)
	close_listener(s);

    LIST_REMOVE(s, allsockets);
    total_bufsize -= s->rbufsize + s->wbufsize;
    free(s->rbuf);
    free(s->wbuf);
    free(s);
}

/*************************************************************************/

/* Install `func' as the callback selected by `which' (one of the SCB_*
 * constants) on socket `s'.  An unknown callback ID is logged and ignored;
 * a NULL socket is logged and sets errno to EINVAL.
 */

void sock_setcb(Socket *s, int which, SocketCallback func)
{
    SocketCallback *slot;

    if (s == NULL) {
	log("sockets: sock_setcb() with NULL socket!");
	errno = EINVAL;
	return;
    }

    /* Pick the structure field that corresponds to the callback ID */
    if (which == SCB_CONNECT) {
	slot = &s->cb_connect;
    } else if (which == SCB_DISCONNECT) {
	slot = &s->cb_disconn;
    } else if (which == SCB_ACCEPT) {
	slot = &s->cb_accept;
    } else if (which == SCB_READ) {
	slot = &s->cb_read;
    } else if (which == SCB_READLINE) {
	slot = &s->cb_readline;
    } else {
	log("sockets: sock_setcb(): invalid callback ID %d", which);
	return;
    }
    *slot = func;
}

/*************************************************************************/

/* Return whether the given socket is currently connected. */

int sock_isconn(const Socket *s)
{
    if (!s) {
	log("sockets: sock_isconn() with NULL socket!");
	errno = EINVAL;
	return 0;
    }
    return s->flags & SF_CONNECTED ? 1 : 0;
}

/*************************************************************************/

/* Copy the remote address of a connected socket into `sa', in the manner
 * of getpeername(): on entry *lenptr holds the size of the buffer at `sa',
 * at most that many bytes are stored, and on return *lenptr holds the
 * full (non-truncated) size of the address.  Returns 0 on success, or -1
 * with errno == EINVAL if any pointer is NULL or the socket is not
 * connected.
 */

int sock_remote(const Socket *s, struct sockaddr *sa, int *lenptr)
{
    size_t ncopy;

    if (!s || !sa || !lenptr) {
	log("sockets: sock_remote() with NULL %s!",
	    !s ? "socket" : !sa ? "sockaddr" : "lenptr");
	errno = EINVAL;
	return -1;
    }
    if (!(s->flags & SF_CONNECTED)) {
	errno = EINVAL;
	return -1;
    }

    /* Copy the whole address, or only as much as the caller has room for */
    ncopy = sizeof(s->remote);
    if ((size_t)*lenptr < ncopy)
	ncopy = (size_t)*lenptr;
    memcpy(sa, &s->remote, ncopy);

    *lenptr = sizeof(s->remote);
    return 0;
}

/*************************************************************************/

/* Select blocking (blocking != 0) or non-blocking (blocking == 0) writes
 * for the socket by setting or clearing its SF_BLOCKING flag.
 */

void sock_set_blocking(Socket *s, int blocking)
{
    if (s == NULL) {
	log("sockets: sock_set_blocking() with NULL socket!");
	errno = EINVAL;
	return;
    }
    s->flags = blocking ? (s->flags | SF_BLOCKING)
			: (s->flags & ~SF_BLOCKING);
}

/*************************************************************************/

/* Return whether socket writes are blocking (return value != 0) or not
 * (return value == 0).  If `s' is NULL, logs an error, sets errno to
 * EINVAL and returns -1; note that -1 is itself nonzero, so callers that
 * may pass NULL must check for it explicitly.
 */

int sock_get_blocking(Socket *s)
{
    if (!s) {
	log("sockets: sock_get_blocking() with NULL socket!");
	errno = EINVAL;
	return -1;
    }
    return s->flags & SF_BLOCKING;
}

/*************************************************************************/

/* Return the number of bytes currently held in the socket's read buffer.
 * The socket is not checked for validity.
 */

inline uint32 read_buffer_len(const Socket *s)
{
    /* When the store pointer has wrapped around to before the retrieve
     * pointer, account for one full trip around the buffer. */
    if (s->rend < s->rptr)
	return (s->rend + s->rbufsize) - s->rptr;
    return s->rend - s->rptr;
}


/*************************************************************************/

/* Return the number of bytes currently held in the socket's write buffer.
 * The socket is not checked for validity.
 */

inline uint32 write_buffer_len(const Socket *s)
{
    /* When the store pointer has wrapped around to before the retrieve
     * pointer, account for one full trip around the buffer. */
    if (s->wend < s->wptr)
	return (s->wend + s->wbufsize) - s->wptr;
    return s->wend - s->wptr;
}

/*************************************************************************/

/* Report the socket's traffic totals in whole kilobytes: bytes received go
 * to *readkb_ret and bytes sent to *writekb_ret.  Either pointer may be
 * NULL to skip that value.  Data that is buffered but not yet sent is not
 * counted as sent.
 */

void sock_rwstat(const Socket *s, uint32 *readkb_ret, uint32 *writekb_ret)
{
    if (s == NULL) {
	log("sockets: sock_rwstat() with NULL socket!");
	errno = EINVAL;
	return;
    }
    if (readkb_ret != NULL)
	*readkb_ret = s->total_read;
    if (writekb_ret != NULL)
	*writekb_ret = s->total_written;
}

/*************************************************************************/

/* Return the larger of (1) the ratio of the given socket's total buffer
 * size (read and write buffers combined) to NetBufferSize and (2) the
 * ratio of the amount of memory used by all sockets' buffers to
 * TotalNetBufferSize, as a percentage rounded up to the next integer.
 * Ratios (1) and (2) are zero if NetBufferSize or TotalNetBufferSize,
 * respectively, are not set.
 *
 * If any of `socksize_ret', `totalsize_ret', `ratio1_ret', and
 * `ratio2_ret' are non-NULL, they are set respectively to the given
 * socket's total buffer size, the amount of memory used by all sockets'
 * buffers, ratio (1) as a percentage, and ratio (2) as a percentage.
 *
 * If `s' is NULL, ratio (1) is set to zero, and *socksize_ret is not
 * modified.
 */

int sock_bufstat(const Socket *s, uint32 *socksize_ret,
		 uint32 *totalsize_ret, int *ratio1_ret, int *ratio2_ret)
{
    int ratio1 = 0, ratio2 = 0;

    /* For each ratio, multiplying the size by 100 is only done when the
     * corresponding limit is small enough; for larger limits the limit is
     * divided by 100 instead (slightly less precise, but cannot overflow,
     * and the divisor cannot be zero because the limit is then large). */
    if (NetBufferSize && s) {
	uint32 size = s->rbufsize + s->wbufsize;
	if (NetBufferSize <= 0x7FFFFFFF/100)
	    ratio1 = (size*100 + NetBufferSize-1) / NetBufferSize;
	else
	    ratio1 = (size + NetBufferSize/100-1) / (NetBufferSize/100);
    }
    if (TotalNetBufferSize) {
	/* The guard must look at TotalNetBufferSize, the limit actually
	 * used in this computation, not at NetBufferSize. */
	if (TotalNetBufferSize <= 0x7FFFFFFF/100)
	    ratio2 = (total_bufsize*100 + TotalNetBufferSize-1)
		   / TotalNetBufferSize;
	else
	    ratio2 = (total_bufsize + TotalNetBufferSize/100-1)
		   / (TotalNetBufferSize/100);
    }
    if (socksize_ret && s)
	*socksize_ret = s->rbufsize + s->wbufsize;
    if (totalsize_ret)
	*totalsize_ret = total_bufsize;
    if (ratio1_ret)
	*ratio1_ret = ratio1;
    if (ratio2_ret)
	*ratio2_ret = ratio2;
    if (ratio1 > ratio2)
	return ratio1;
    else
	return ratio2;
}

/*************************************************************************/
/*************************************************************************/

/* Check all sockets for activity, and call callbacks as necessary.
 * Returns after activity has been detected on at least one socket or
 * ReadTimeout milliseconds have elapsed.
 *
 * For every file descriptor reported ready by select():
 *    - write-ready on a connecting socket completes (or fails) the
 *      connection and calls the connect callback;
 *    - write-ready on a connected socket flushes its write buffer;
 *    - read-ready on a listener accepts a connection;
 *    - read-ready on a connected socket fills its read buffer and calls
 *      the read and/or readline callbacks.
 * Sockets flagged SF_DELETEME are freed before returning.
 */

void check_sockets(void)
{
    fd_set rfds, wfds;
    struct timeval tv;
    int i;
    int32 res;
    Socket *s, *s2;

    rfds = sock_fds;
    wfds = write_fds;
    tv.tv_sec = ReadTimeout/1000;
    tv.tv_usec = (ReadTimeout%1000) * 1000;
    enable_signals();
    do {
	/* Note: On systems which don't return the time remaining in `tv',
	 * this loop could continue forever if Services keeps getting
	 * signals, but as that shouldn't ordinarily happen we don't worry
	 * about it. */
	res = select(max_fd, &rfds, &wfds, NULL, &tv);
    } while (res < 0 && errno == EINTR);
    disable_signals();
    if (debug >= 3)
	log("debug: sockets: select returned %d", res);
    if (res <= 0) {
	if (res < 0)
	    log_perror("sockets: select()");
	return;
    }

    s = NULL;
    for (i = 0; i < max_fd; i++) {
	/* SF_CALLBACK is held on the socket being processed; drop it from
	 * the previous iteration's socket before moving on. */
	if (s)
	    s->flags &= ~SF_CALLBACK;
	s = sockets[i];
	if (s && s->fd != i) {
	    log("sockets: BUG: sockets[%d]->fd = %d (should be equal),"
		" clearing socket from table", i, s->fd);
	    sockets[i] = NULL;
	    /* This socket does not belong to fd `i', so it must not be
	     * used to service events reported for `i'. */
	    s = NULL;
	}
	if (s)
	    s->flags |= SF_CALLBACK;

	if (FD_ISSET(i, &wfds)) {
	    if (debug >= 3)
		log("debug: sockets: write ready on fd %d", i);
	    if (!s) {
		log("sockets: BUG: got write-ready on fd %d but no socket"
		    " for it!", i);
		/* Nothing can ever be written here; stop watching the FD
		 * so that select() does not report it again every call. */
		FD_CLR(i, &write_fds);
		continue;
	    } else if (s->flags & SF_CONNECTING) {
		/* Connection established (or failed) */
		int val;
		socklen_t vallen;
		vallen = sizeof(val);
		if (debug >= 2)
		    log("debug: sockets: connect on fd %d returned", i);
		if (getsockopt(i, SOL_SOCKET, SO_ERROR, &val, &vallen) < 0) {
		    log_perror("sockets: getsockopt(SO_ERROR) for connect (%d"
			       " -> %s:%u)", i, inet_ntoa(s->remote.sin_addr),
			       ntohs(s->remote.sin_port));
		    do_disconn(s, DISCONN_CONNFAIL);
		    continue;
		}
		if (val != 0) {
		    /* SO_ERROR holds the errno value of the failed connect */
		    errno = val;
		    log_perror("sockets: connect(%d -> %s:%u)", i,
			       inet_ntoa(s->remote.sin_addr),
			       ntohs(s->remote.sin_port));
		    do_disconn(s, DISCONN_CONNFAIL);
		    continue;
		} else {
		    if (s->cb_connect)
			s->cb_connect(s, 0);
		}
		if (s->fd >= 0) {  /* the socket might have been closed */
		    s->flags &= ~SF_CONNECTING;
		    s->flags |= SF_CONNECTED;
		    FD_SET(i, &sock_fds);
		}
		FD_CLR(i, &write_fds);
	    } else if (!(s->flags & SF_CONNECTED)) {
		log("sockets: BUG: got write-ready on fd %d but socket not"
		    " connected!", i);
	    } else {
		flush_write_buffer(s);
	    }
	} /* set in write fds */

	if (FD_ISSET(i, &rfds)) {
	    if (debug >= 3)
		log("debug: sockets: read ready on fd %d", i);
	    if (!s) {
		log("sockets: BUG: got data on fd %d but no socket for it!",i);
		FD_CLR(i, &sock_fds);
		continue;
	    }
	    if ((s->flags & SF_LISTENER) && s->cb_accept) {
		/* Connection arrived */
		do_accept(s);
		continue;
	    } else if (!(s->flags & SF_CONNECTED)) {
		log("sockets: BUG: got data on fd %d but not connected!", i);
		FD_CLR(i, &sock_fds);
		continue;
	    }
	    /* Normal read */
	    if (read_buffer_len(s) >= s->rbufsize-1) {
		/* Buffer is full, try to expand it */
		int newsize = 0;
		if (s->rbufsize < NET_MIN_BUFSIZE)
		    newsize = NET_MIN_BUFSIZE;
		else if (s->rbufsize+NET_MIN_BUFSIZE < NetBufferSize)
		    newsize = s->rbufsize + NET_MIN_BUFSIZE;
		if (newsize > 0)
		    resize_rbuf(s, newsize);
	    }
	    res = fill_read_buffer(s);
	    if (res < 0) {
		/* Connection was closed (or some other error occurred) */
		if (debug)
		    log_perror("debug: sockets: read(%d)", i);
		do_disconn(s, DISCONN_REMOTE);
		continue;
	    } else if (res == 0) {
		/* Not an errno-related condition, so no perror text */
		log("sockets: BUG: fill_read_buffer() returned 0!");
	    } else {
		uint32 left = read_buffer_len(s), newleft;
		if (left == 0) {
		    log("sockets: BUG: 0 bytes avail after successful read!");
		    continue;
		}
		/* Call read callback(s) in a loop until no more data is
		 * left or neither callback takes any data, or the socket is
		 * disconnected */
		do {
		    newleft = left;
		    if (s->cb_read) {
			s->cb_read(s, (void *)(long)newleft);
			if ((s->flags & SF_DISCONNECT) || s->fd < 0)
			    break;
			newleft = read_buffer_len(s);
		    }
		    /* Only look for a line if data remains; with an empty
		     * buffer (rptr == rend) the search below would scan
		     * stale bytes left over from earlier reads. */
		    if (s->cb_readline && newleft > 0) {
			char *newline;
			if (s->rend > s->rptr) {
			    newline = memchr(s->rptr, '\n', newleft);
			} else {
			    /* Data wraps: search the tail, then the head */
			    newline = memchr(s->rptr, '\n', s->rtop - s->rptr);
			    if (!newline)
				newline = memchr(s->rbuf, '\n',
						 s->rend - s->rbuf);
			}
			if (newline) {
			    s->cb_readline(s, (void *)(long)newleft);
			    if ((s->flags & SF_DISCONNECT) || s->fd < 0)
				break;
			    newleft = read_buffer_len(s);
			}
		    }
		} while (newleft != left && (left = newleft) != 0);
		reclaim_buffer_space_one(s);
	    }
	} /* socket ready for reading */

	if (s && (s->flags & SF_DELETEME))
	    sock_free(s);

    } /* for all sockets */

    if (s)  /* clear SF_CALLBACK from last socket */
	s->flags &= ~SF_CALLBACK;

    /* Clear out any ready-to-be-deleted sockets.  The `continue' paths in
     * the loop above skip the per-socket SF_DELETEME check, so a socket
     * flagged during one of those paths is only picked up here. */
    LIST_FOREACH_SAFE (s, allsockets, s2) {
	if (s->flags & SF_DELETEME)
	    sock_free(s);
    }

} /* check_sockets() */

/*************************************************************************/
/*************************************************************************/

/* Initiate a connection to the given host and port.  If an error occurs,
 * returns -1, else returns 0.  The connection is not necessarily complete
 * even if 0 is returned, and may later fail; use the SCB_CONNECT and
 * SCB_DISCONNECT callbacks.  If this function fails due to inability to
 * resolve a hostname, errno will be set to the negative of h_errno; pass
 * the negative of this value to hstrerror() to get an appropriate error
 * message.
 *
 * lhost/lport specify the local side of the connection.  If they are not
 * given (lhost==NULL, lport==0), then they are left to be set by the OS.
 * An lhost which cannot be resolved to an IPv4 address is ignored.
 *
 * If either host or lhost is not a valid IP address and the gethostbyname()
 * function is available, this function may block while the hostname is
 * being resolved.  Only IPv4 addresses are supported; a host which
 * resolves to any other address type fails with errno == EAFNOSUPPORT.
 *
 * If the new file descriptor is too large to be used with select()
 * (>= FD_SETSIZE), the connection is abandoned with errno == EMFILE.
 *
 * This function may be called from a socket's disconnect callback to
 * establish a new connection using the same socket.  It may not be called,
 * however, if the socket is being freed with sock_free().
 */

int conn(Socket *s, const char *host, int port, const char *lhost, int lport)
{
#if HAVE_GETHOSTBYNAME
    struct hostent *hp;
#endif
    uint8 *addr;
    struct sockaddr_in sa, lsa;
    int fd, i;

    if (!s || !host || port <= 0 || port > 65535) {
	if (port <= 0 || port > 65535)
	    log("sockets: conn() with bad port number (%d)!", port);
	else
	    log("sockets: conn() with NULL %s!", !s ? "socket" : "hostname");
	errno = EINVAL;
	return -1;
    }
    if (s->flags & SF_DELETEME) {
	log("sockets: conn() called on a freeing socket (%p)", s);
	errno = EPERM;
	return -1;
    }

    /* Set up the local address, if one was requested */
    memset(&lsa, 0, sizeof(lsa));
    lsa.sin_family = AF_INET;
    if (lhost) {
	if ((addr = pack_ip(lhost)) != 0)
	    memcpy((char *)&lsa.sin_addr, addr, 4);
#if HAVE_GETHOSTBYNAME
	/* Only accept an IPv4 result: copying h_length bytes of any other
	 * address type would overrun the sockaddr_in structure. */
	else if ((hp = gethostbyname(lhost)) != NULL
		 && hp->h_addrtype == AF_INET
		 && hp->h_length == (int)sizeof(lsa.sin_addr))
	    memcpy((char *)&lsa.sin_addr, hp->h_addr, sizeof(lsa.sin_addr));
#endif
	else
	    lhost = NULL;
    }
    if (lport)
	lsa.sin_port = htons(lport);

    /* Set up the remote address */
    memset(&sa, 0, sizeof(sa));
    if ((addr = pack_ip(host)) != 0) {
	memcpy((char *)&sa.sin_addr, addr, 4);
	sa.sin_family = AF_INET;
    }
#if HAVE_GETHOSTBYNAME
    else if ((hp = gethostbyname(host)) != NULL) {
	if (hp->h_addrtype != AF_INET
	 || hp->h_length != (int)sizeof(sa.sin_addr)
	) {
	    log("sockets: conn(): `%s' does not resolve to an IPv4 address",
		host);
	    errno = EAFNOSUPPORT;
	    return -1;
	}
	memcpy((char *)&sa.sin_addr, hp->h_addr, sizeof(sa.sin_addr));
	sa.sin_family = AF_INET;
    } else {
	errno = -h_errno;
	return -1;
    }
#else
    else {
	log("sockets: conn(): `%s' is not a valid IP address", host);
	errno = EINVAL;
	return -1;
    }
#endif
    sa.sin_port = htons((uint16)port);

    if ((fd = socket(sa.sin_family, SOCK_STREAM, 0)) < 0)
	return -1;

    /* FD_SET() on a descriptor beyond FD_SETSIZE writes outside the
     * fd_set, so such a descriptor cannot be handled by check_sockets(). */
    if (fd >= FD_SETSIZE) {
	log("sockets: conn(): file descriptor %d is too large for select()",
	    fd);
	close(fd);
	errno = EMFILE;
	return -1;
    }

    if (fcntl(fd, F_SETFL, O_NDELAY) < 0) {
	int errno_save = errno;
	close(fd);
	errno = errno_save;
	return -1;
    }

    if ((lhost || lport) && bind(fd,(struct sockaddr *)&lsa,sizeof(lsa)) < 0) {
	int errno_save = errno;
	close(fd);
	errno = errno_save;
	return -1;
    }

    /* The socket is non-blocking, so EINPROGRESS just means that the
     * connection attempt is still underway. */
    if ((i = connect(fd, (struct sockaddr *)&sa, sizeof(sa))) < 0
	&& errno != EINPROGRESS
    ) {
	int errno_save = errno;
	close(fd);
	errno = errno_save;
	return -1;
    }

    /* Grow the FD-indexed socket table if needed and register the socket */
    if (max_fd < fd+1) {
	int j;
	sockets = srealloc(sockets, (fd+1) * sizeof(*sockets));
	for (j = max_fd; j < fd; j++)
	    sockets[j] = NULL;
	max_fd = fd+1;
    }
    sockets[fd] = s;
    s->remote = sa;
    s->fd = fd;
    if (i == 0) {
	/* Connected immediately */
	s->flags |= SF_CONNECTED;
	FD_SET(fd, &sock_fds);
	if (s->cb_connect)
	    s->cb_connect(s, 0);
    } else {
	/* Completion will be reported as write-readiness by select() */
	s->flags |= SF_CONNECTING;
	FD_SET(fd, &write_fds);
    }
    return 0;
}

/*************************************************************************/

/* Start disconnecting a socket on behalf of the local side.  Returns 0 on
 * success, -1 on error (s == NULL or listener socket).  Disconnecting a
 * socket which is already disconnected succeeds without doing anything.
 * The disconnection may not take effect immediately; callers who intend
 * to reuse the socket MUST wait until the disconnect callback is called
 * before doing so.
 */

int disconn(Socket *s)
{
    if (s != NULL)
	return do_disconn(s, DISCONN_LOCAL);

    log("sockets: disconn() with NULL socket!");
    errno = EINVAL;
    return -1;
}

/*************************************************************************/

/* Open a listener socket on the given host and port; returns 0 on success
 * (the socket is set up and listening), -1 on error.  If `host' has
 * multiple addresses, only the first one is used; if `host' is NULL, all
 * addresses are bound to.  As with conn(), a negative errno value
 * indicates a failure to resolve the hostname `host'.  `backlog' is the
 * backlog limit for incoming connections, and is passed directly to the
 * listen() system call.
 *
 * Only IPv4 addresses are supported; a host which resolves to any other
 * address type fails with errno == EAFNOSUPPORT.  If the new file
 * descriptor is too large to be used with select() (>= FD_SETSIZE), the
 * call fails with errno == EMFILE.
 *
 * On success the socket's read and write buffers are released, since a
 * listener does not use them.
 *
 * Note that even after the socket is successfully created, it will not
 * accept any connections until/unless the SCB_ACCEPT callback is set.
 *
 * If host is not a valid IP address and the gethostbyname() function is
 * available, this function may block while the hostname is being resolved.
 */

int open_listener(Socket *s, const char *host, int port, int backlog)
{
#if HAVE_GETHOSTBYNAME
    struct hostent *hp;
#endif
    uint8 *addr;
    struct sockaddr_in sa;
    int fd, i;

    if (!s || port <= 0 || port > 65535 || backlog < 1) {
	if (port <= 0 || port > 65535)
	    log("sockets: open_listener() with bad port number (%d)!", port);
	else if (backlog < 1)
	    log("sockets: open_listener() with bad backlog (%d)!", backlog);
	else
	    log("sockets: open_listener() with NULL socket!");
	errno = EINVAL;
	return -1;
    }
    memset(&sa, 0, sizeof(sa));
    if (host) {
	if ((addr = pack_ip(host)) != 0) {
	    memcpy((char *)&sa.sin_addr, addr, 4);
	    sa.sin_family = AF_INET;
	}
#if HAVE_GETHOSTBYNAME
	else if ((hp = gethostbyname(host)) != NULL) {
	    /* Only accept an IPv4 result: copying h_length bytes of any
	     * other address type would overrun the sockaddr_in. */
	    if (hp->h_addrtype != AF_INET
	     || hp->h_length != (int)sizeof(sa.sin_addr)
	    ) {
		log("sockets: open_listener(): `%s' does not resolve to an"
		    " IPv4 address", host);
		errno = EAFNOSUPPORT;
		return -1;
	    }
	    memcpy((char *)&sa.sin_addr, hp->h_addr, sizeof(sa.sin_addr));
	    sa.sin_family = AF_INET;
	} else {
	    errno = -h_errno;
	    return -1;
	}
#else
	else {
	    log("sockets: open_listener(): `%s' is not a valid IP address",
		host);
	    errno = EINVAL;
	    return -1;
	}
#endif
    } else {  /* !host */
	sa.sin_family = AF_INET;
	sa.sin_addr.s_addr = INADDR_ANY;
    }
    sa.sin_port = htons((uint16)port);

    if ((fd = socket(sa.sin_family, SOCK_STREAM, 0)) < 0)
	return -1;

    /* FD_SET() on a descriptor beyond FD_SETSIZE writes outside the
     * fd_set, so such a descriptor cannot be handled by check_sockets(). */
    if (fd >= FD_SETSIZE) {
	log("sockets: open_listener(): file descriptor %d is too large for"
	    " select()", fd);
	close(fd);
	errno = EMFILE;
	return -1;
    }

    /* Failure to set SO_REUSEADDR is logged but is not fatal */
    i = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i)) < 0) {
	log_perror("sockets: open_listener(): setsockopt(%d, SO_REUSEADDR,"
		   " 1) failed", fd);
    }

    if (fcntl(fd, F_SETFL, O_NDELAY) < 0) {
	int errno_save = errno;
	close(fd);
	errno = errno_save;
	return -1;
    }

    if (bind(fd, (struct sockaddr *)&sa, sizeof(sa)) < 0) {
	int errno_save = errno;
	close(fd);
	errno = errno_save;
	return -1;
    }

    if (listen(fd, backlog) < 0) {
	int errno_save = errno;
	close(fd);
	errno = errno_save;
	return -1;
    }

    /* Listener sockets don't need read/write buffers.  Remove them from
     * the global buffer total as well, since sock_free() will later only
     * subtract the (now zero) sizes recorded in the socket. */
    total_bufsize -= s->rbufsize + s->wbufsize;
    free(s->rbuf);
    free(s->wbuf);
    s->rbuf = s->rptr = s->rend = s->rtop = NULL;
    s->wbuf = s->wptr = s->wend = s->wtop = NULL;
    s->rbufsize = 0;
    s->wbufsize = 0;

    /* Grow the FD-indexed socket table if needed and register the socket */
    if (max_fd < fd+1) {
	sockets = srealloc(sockets, (fd+1) * sizeof(*sockets));
	for (i = max_fd; i < fd; i++)
	    sockets[i] = NULL;
	max_fd = fd+1;
    }
    sockets[fd] = s;
    FD_SET(fd, &sock_fds);
    s->fd = fd;
    s->flags |= SF_LISTENER;
    return 0;
}

/*************************************************************************/

/* Close a listener socket and clear its SF_LISTENER flag.  Returns 0 on
 * success, or -1 with errno == EINVAL if `s' is NULL or is not a listener.
 */

int close_listener(Socket *s)
{
    if (s == NULL) {
	log("sockets: close_listener() with NULL socket!");
	errno = EINVAL;
	return -1;
    }
    if ((s->flags & SF_LISTENER) == 0) {
	log("sockets: close_listener() with non-listener socket (%d)!",
	    s->fd);
	errno = EINVAL;
	return -1;
    }

    sock_closefd(s);
    s->flags &= ~SF_LISTENER;
    return 0;
}

/*************************************************************************/

/* Read raw data from a socket, like read().  Returns number of bytes read
 * (possibly 0 if no data is buffered), or -1 with errno == EINVAL if an
 * argument is invalid.  Only reads from the buffer (does not attempt to
 * fetch more data from the connection).
 */

int32 sread(Socket *s, char *buf, int32 len)
{
    int32 nread = 0;

    if (!s || !buf || len <= 0) {
	log("sockets: sread() with %s!",
	    !s ? "NULL socket" : !buf ? "NULL buffer" : "len <= 0");
	errno = EINVAL;
	return -1;
    }
    if (s->rend < s->rptr) {
	/* Buffer data wraps around.  Compare lengths rather than forming
	 * rptr+len, which could point far beyond the end of the buffer
	 * (or wrap around the address space) for a large `len'. */
	int32 to_top = s->rtop - s->rptr;
	if (len <= to_top) {
	    /* Only need to read from end of buffer */
	    memcpy(buf, s->rptr, len);
	    s->rptr += len;
	    if (s->rptr >= s->rtop)
		s->rptr = s->rbuf;
	    return len;
	} else {
	    /* Need to read from both end and beginning */
	    nread = to_top;
	    memcpy(buf, s->rptr, nread);
	    s->rptr = s->rbuf;
	    len -= nread;
	    /* Continue below */
	}
    }
    /* Read data from beginning of buffer */
    if (s->rptr < s->rend) {
	int32 avail = s->rend - s->rptr;
	if (len > avail)
	    len = avail;
	memcpy(buf+nread, s->rptr, len);
	s->rptr += len;
	nread += len;
    }
    /* Return number of bytes read */
    return nread;
}

/*************************************************************************/

/* Write raw data to a socket, like write().  Returns the result of
 * buffered_write() (number of bytes written), or -1 with errno == EINVAL
 * if `s' or `buf' is NULL or `len' is negative.  A length of zero is
 * accepted.
 */

int32 swrite(Socket *s, const char *buf, int32 len)
{
    if (!s || !buf || len < 0) {
	log("sockets: swrite() with %s!",
	    !s ? "NULL socket" : !buf ? "NULL buffer" : "len < 0");
	errno = EINVAL;
	return -1;
    }
    return buffered_write(s, buf, len);
}

/*************************************************************************/
/*************************************************************************/

/* Read a character from a socket, like fgetc().  Returns the character as
 * an unsigned char converted to int (0..255), or EOF if no data is
 * available in the socket buffer.  Assumes the socket is valid.
 */

int sgetc(Socket *s)
{
    int c;

    /* No paranoia check here, to save time */
    if (s->rptr == s->rend)
	return EOF;
    /* Convert via unsigned char so that a data byte of 0xFF cannot be
     * mistaken for EOF where plain `char' is signed. */
    c = (unsigned char)*s->rptr++;
    if (s->rptr >= s->rtop)
	s->rptr -= s->rbufsize;
    return c;
}

/*************************************************************************/

/* Read a line from a socket, like fgets().  If not enough buffered data
 * is available to fill a complete line, or another error occurs, returns
 * NULL (with errno == EINVAL for invalid arguments).
 *
 * On success, at most len-1 bytes of the line (including the trailing
 * newline, if it fits) are stored in `buf', followed by a null byte.  The
 * whole line is consumed from the socket's buffer even if it had to be
 * truncated to fit in `buf'.
 */

char *sgets(char *buf, int32 len, Socket *s)
{
    char *ptr, *eol;
    int32 to_top;		/* Bytes from line start to top of buffer */
    int32 first, second;	/* Bytes to copy from tail / head of buffer */

    /* Check arguments before touching *s */
    if (!s || !buf || len <= 0) {
	log("sockets: sgets[2]() with %s!",
	    !s ? "NULL socket" : !buf ? "NULL buffer" : "len <= 0");
	errno = EINVAL;
	return NULL;
    }

    /* An empty buffer cannot contain a line; without this check the
     * wrapped-data search below would scan stale bytes. */
    if (s->rptr == s->rend)
	return NULL;

    ptr = s->rptr;
    to_top = s->rtop - ptr;

    /* Find end of line */
    if (s->rend > s->rptr) {
	eol = memchr(s->rptr, '\n', s->rend - s->rptr);
    } else {
	/* Data wraps around: search the tail, then the head */
	eol = memchr(s->rptr, '\n', to_top);
	if (!eol)
	    eol = memchr(s->rbuf, '\n', s->rend - s->rbuf);
    }
    if (!eol)
	return NULL;
    eol++;			/* Point 1 byte after \n */

    /* Set rptr now; old value is in ptr variable */
    s->rptr = eol;
    if (s->rptr >= s->rtop)	/* >rtop is impossible, but just in case */
	s->rptr = s->rbuf;

    /* Work out how much of the line lies at the tail of the buffer (from
     * ptr) and how much at the head (from rbuf).  eol > ptr means the
     * line does not wrap; otherwise eol lies in the head part. */
    if (eol > ptr) {
	first = eol - ptr;
	second = 0;
    } else {
	first = to_top;
	second = eol - s->rbuf;
    }

    /* Trim to at most len-1 bytes, leaving room for the terminator */
    if (first > len-1) {
	first = len-1;
	second = 0;
    } else if (second > len-1 - first) {
	second = len-1 - first;
    }

    /* Actually copy to buffer and return */
    memcpy(buf, ptr, first);
    if (second > 0)
	memcpy(buf+first, s->rbuf, second);
    buf[first+second] = 0;
    return buf;
}

/*************************************************************************/

/* Reads a line of text from a socket with sgets(), and strips one newline
 * and/or one carriage return character from the end of the line.  Returns
 * `buf', or NULL if sgets() returned NULL.
 */

char *sgets2(char *buf, int32 len, Socket *s)
{
    size_t n;

    if (!sgets(buf, len, s))
	return NULL;

    /* Track the length explicitly so that an empty string, or a line
     * consisting only of "\n", never causes an access before buf[0]. */
    n = strlen(buf);
    if (n > 0 && buf[n-1] == '\n')
	buf[--n] = 0;
    if (n > 0 && buf[n-1] == '\r')
	buf[--n] = 0;
    return buf;
}

/*************************************************************************/

/* Write a string to a socket, like fputs().  Returns the number of bytes
 * written (the result of buffered_write()), or -1 with errno == EINVAL if
 * `str' or `s' is NULL.
 */

int sputs(char *str, Socket *s)
{
    if (!str || !s) {
	log("sockets: sputs() with %s!",
	    !s ? "NULL socket" : "NULL string");
	/* Stop here: continuing would call strlen() on a NULL string or
	 * write to a NULL socket. */
	errno = EINVAL;
	return -1;
    }
    return buffered_write(s, str, strlen(str));
}

/*************************************************************************/

/* Write formatted output to a socket in the manner of printf(); this is
 * the variadic front end to vsockprintf(), whose result is returned.
 * Returns the number of bytes written; in no case will more than 65535
 * bytes be written (if the output would be longer than this, it will be
 * truncated).
 */

int sockprintf(Socket *s, const char *fmt, ...)
{
    va_list ap;
    int nwritten;

    va_start(ap, fmt);
    nwritten = vsockprintf(s, fmt, ap);
    va_end(ap);

    return nwritten;
}

/* Write formatted output to a socket in the manner of vprintf().  Returns
 * the number of bytes written (the result of buffered_write()), or -1 if
 * `s' or `fmt' is NULL (errno == EINVAL) or the string could not be
 * formatted.  Output longer than 65535 bytes is truncated to that length.
 */

int vsockprintf(Socket *s, const char *fmt, va_list args)
{
    char buf[65536];
    int len;

    if (!s || !fmt) {
	log("sockets: [v]sockprintf() with %s!",
	    !s ? "NULL socket" : "NULL format string");
	errno = EINVAL;
	return -1;
    }
    len = vsnprintf(buf, sizeof(buf), fmt, args);
    if (len < 0)		/* formatting failed; nothing valid to send */
	return -1;
    /* vsnprintf() returns the length the full output would have had, which
     * can exceed what was actually stored in buf; never send more than
     * buf holds. */
    if (len > (int)sizeof(buf) - 1)
	len = (int)sizeof(buf) - 1;
    return buffered_write(s, buf, len);
}

/*************************************************************************/
/************************** Internal routines ****************************/
/*************************************************************************/

/* Accept a connection on the given socket.  Called from check_sockets().
 *
 * On success a new socket structure is created for the connection, marked
 * as self-created and connected, registered in the socket table and read
 * set, and passed to the listener's accept callback.  On any failure the
 * problem is logged, the accepted file descriptor (if any) is closed, and
 * the accept callback is not called.
 */

static void do_accept(Socket *s)
{
    int i;
    struct sockaddr_in sin;
    socklen_t sin_len = sizeof(sin);
    int newfd;
    Socket *news;

    newfd = accept(s->fd, (struct sockaddr *)&sin, &sin_len);
    if (newfd < 0) {
	if (errno != ECONNRESET)
	    log_perror("sockets: accept(%d)", s->fd);
	return;
    }

    /* FD_SET() on a descriptor beyond FD_SETSIZE writes outside the
     * fd_set, so such a connection cannot be serviced. */
    if (newfd >= FD_SETSIZE) {
	log("sockets: accept(%d): file descriptor %d is too large for"
	    " select()", s->fd, newfd);
	close(newfd);
	return;
    }

    if (fcntl(newfd, F_SETFL, O_NDELAY) < 0) {
	log_perror("sockets: fcntl(NDELAY) on accept(%d)", s->fd);
	close(newfd);
	return;
    }

    news = sock_new();
    if (!news) {
	log("sockets: accept(%d): Unable to create socket structure"
	    " (out of buffer space?)", s->fd);
	close(newfd);	/* don't leak the accepted connection */
	return;
    }

    news->fd = newfd;
    news->flags |= SF_SELFCREATED | SF_CONNECTED;
    /* accept() may report a length larger than the buffer it was given;
     * never copy more than the buffer actually holds. */
    if (sin_len > sizeof(sin))
	sin_len = sizeof(sin);
    memcpy(&news->remote, &sin, sin_len);
    FD_SET(newfd, &sock_fds);

    /* Grow the FD-indexed socket table if needed and register the socket */
    if (max_fd < newfd+1) {
	sockets = srealloc(sockets, (newfd+1) * sizeof(*sockets));
	for (i = max_fd; i < newfd; i++)
	    sockets[i] = NULL;
	max_fd = newfd+1;
    }
    sockets[newfd] = news;
    s->cb_accept(s, news);
}

/*************************************************************************/

/* Fill up the read buffer of a socket with any data that may have arrived.
 * Returns the number of bytes read (nonzero), or -1 on error; errno is set
 * by read() calls but is otherwise preserved.
 *
 * Specific error cases:
 *    - errno == EBADF if the socket has no file descriptor;
 *    - errno == ENOBUFS if the read buffer was already full, so that no
 *      read was attempted;
 *    - otherwise errno is from the failing read(), with ECONNRESET
 *      substituted when read() failed or returned 0 without setting errno.
 * A failure after some data has already been read is not reported as an
 * error; the byte count is returned instead.
 */

static int fill_read_buffer(Socket *s)
{
    int nread = 0;
    int errno_save = errno;

    if (s->fd < 0) {
	errno = EBADF;
	return -1;
    }
    /* One byte of the buffer is always left unused, so that a full buffer
     * can be told apart from an empty one (rptr == rend). */
    while (read_buffer_len(s) < s->rbufsize-1) {
	int maxread, res;
	/* Largest contiguous region that can be filled from rend */
	if (s->rend < s->rptr)	/* wrapped around? */
	    maxread = (s->rptr-1) - s->rend;
	else if (s->rptr == s->rbuf)
	    maxread = s->rtop - s->rend - 1;
	else
	    maxread = s->rtop - s->rend;
	do {
	    errno = 0;
	    res = read(s->fd, s->rend, maxread);
	    if (res <= 0 && errno == 0)
		errno = ECONNRESET;  /* make a guess */
	} while (res <= 0 && errno == EINTR);
	errno_save = errno;
	if (debug >= 3)
	    log("debug: sockets: fill_read_buffer wanted %d, got %d",
		maxread, res);
	if (res <= 0) {
	    if (nread == 0)
		nread = -1;
	    break;
	}
	nread += res;
	/* Keep the running total as whole kilobytes plus leftover bytes */
	s->total_read += res / 1024;
	s->total_read_B += res % 1024;
	if (s->total_read_B >= 1024) {
	    s->total_read += s->total_read_B / 1024;
	    s->total_read_B %= 1024;
	}
	s->rend += res;
	if (s->rend == s->rtop)
	    s->rend = s->rbuf;
    }
    if (nread == 0) {
	/* Loop never ran: the buffer was already full */
	nread = -1;
	errno = ENOBUFS;
    } else {
	errno = errno_save;
    }
    return nread;
}

/*************************************************************************/

/* Try and write up to one chunk of data from the buffer to the socket.
 * Return how much was written:
 *    - a positive byte count if write() accepted data;
 *    - 0 if the socket is not yet connected, the buffer is empty, or the
 *      write could not make progress (EAGAIN/EINTR);
 *    - -1 with errno EBADF if the socket has no file descriptor;
 *    - -1 with errno from write() on any other write error, in which case
 *      the socket has been passed to do_disconn(DISCONN_REMOTE).
 * When the buffer is found empty, the socket is removed from write_fds
 * and either a pending SF_DISCONNECT is carried out or unused buffer
 * space is reclaimed.  Note that this check is not reached on a call
 * that successfully writes data.
 */

static int flush_write_buffer(Socket *s)
{
    int maxwrite, nwritten, write_errno;

    if (s->fd < 0) {
	errno = EBADF;
	return -1;
    }
    if (!sock_isconn(s))  /* not yet connected */
	return 0;
    if (s->wend != s->wptr) {
	/* Only the contiguous run starting at wptr goes out in one call */
	if (s->wptr > s->wend)	/* wrapped around? */
	    maxwrite = s->wtop - s->wptr;
	else
	    maxwrite = s->wend - s->wptr;
	nwritten = write(s->fd, s->wptr, maxwrite);
	/* Capture errno now: the debug log call below may overwrite it
	 * before we get to examine it */
	write_errno = errno;
	if (debug >= 3)
	    log("debug: sockets: flush_write_buffer wanted %d, got %d",
		maxwrite, nwritten);
	errno = write_errno;
	if (nwritten < 0 && errno != EAGAIN && errno != EINTR) {
	    if (errno != ECONNRESET && errno != EPIPE)
		log_perror("sockets: flush_write_buffer(%d)", s->fd);
	    do_disconn(s, DISCONN_REMOTE);
	    errno = write_errno;
	    return -1;
	}
	if (nwritten > 0) {
	    s->flags &= ~SF_WARNED;
	    s->wptr += nwritten;
	    if (s->wptr >= s->wtop)
		s->wptr = s->wbuf;
	    /* Keep the byte counter as whole kilobytes plus a remainder */
	    s->total_written += nwritten / 1024;
	    s->total_written_B += nwritten % 1024;
	    if (s->total_written_B >= 1024) {
		s->total_written += s->total_written_B / 1024;
		s->total_written_B %= 1024;
	    }
	    return nwritten;
	}
    }
    if (s->wptr == s->wend) {
	/* Nothing left to send: stop watching for writability */
	FD_CLR(s->fd, &write_fds);
	if (s->flags & SF_DISCONNECT) {
	    /* A disconnect was waiting for the buffer to drain */
	    s->flags &= ~SF_DISCONNECT;
	    do_disconn(s, DISCONN_LOCAL);
	} else {
	    reclaim_buffer_space_one(s);
	}
    }
    return 0;
}

/*************************************************************************/

/* Resize a socket's read buffer to `size' bytes and record the new size
 * in s->rbufsize.  A request that would not leave room for the data
 * currently buffered (size <= read_buffer_len(s)) is silently ignored.
 */

static void resize_rbuf(Socket *s, uint32 size)
{
    if (size > read_buffer_len(s)) {
	resize_buf(&s->rbuf, &s->rptr, &s->rend, &s->rtop, size);
	s->rbufsize = size;
    }
}


/* Resize a socket's write buffer to `size' bytes and record the new size
 * in s->wbufsize.  A request that would not leave room for the data
 * currently buffered (size <= write_buffer_len(s)) is silently ignored.
 */
static void resize_wbuf(Socket *s, uint32 size)
{
    if (size > write_buffer_len(s)) {
	resize_buf(&s->wbuf, &s->wptr, &s->wend, &s->wtop, size);
	s->wbufsize = size;
    }
}

/* Routine that does the actual resizing.  Moves the contents of a
 * circular buffer into a newly allocated buffer of `newsize' bytes, with
 * the data placed at the start of the new buffer, and updates the four
 * buffer pointers to match.  The buffer may be grown or shrunk; callers
 * record `newsize' as the buffer size afterwards, so a shrink request must
 * really shrink the buffer or the recorded size and the pointers would
 * disagree.  The request is ignored if the size is unchanged or if the
 * buffered data would not fit in the new buffer with one byte to spare.
 */
static void resize_buf(char **p_buf, char **p_ptr, char **p_end, char **p_top,
		       uint32 newsize)
{
    uint32 size = *p_top - *p_buf;
    char *newbuf;
    uint32 len = 0;
    uint32 used = 0;

    /* Work out how much data is currently held */
    if (*p_end < *p_ptr)
	used = (*p_top - *p_ptr) + (*p_end - *p_buf);
    else if (*p_end > *p_ptr)
	used = *p_end - *p_ptr;
    if (newsize == size || newsize <= used)
	return;
    newbuf = smalloc(newsize);
    /* Copy old data to new buffer, if any */
    if (*p_end < *p_ptr) {
	/* Data wraps around: first copy the part up to the top of the old
	 * buffer, then fall through to copy the part at its beginning */
	len = *p_top - *p_ptr;
	memcpy(newbuf, *p_ptr, len);
	*p_ptr = *p_buf;
    }
    if (*p_end > *p_ptr) {
	memcpy(newbuf+len, *p_ptr, *p_end - *p_ptr);
	len += *p_end - *p_ptr;
    }
    free(*p_buf);
    *p_buf = newbuf;
    *p_ptr = newbuf;
    *p_end = newbuf + len;
    *p_top = newbuf + newsize;
}

/*************************************************************************/

/* Try to reclaim unused buffer space.  Return 1 if some buffer space was
 * freed, 0 if not.
 */

static int reclaim_buffer_space_one(Socket *s)
{
    uint32 rlen = read_buffer_len(s), wlen = write_buffer_len(s);
    int retval = 0;

    if (s->rbufsize > NET_MIN_BUFSIZE
     && rlen < s->rbufsize - NET_MIN_BUFSIZE
    ) {
	if (rlen < NET_MIN_BUFSIZE) {
	    rlen = NET_MIN_BUFSIZE;
	} else {
	    /* Round up to the next multiple of NET_MIN_BUFSIZE, leaving
	     * at least one byte available */
	    rlen += NET_MIN_BUFSIZE;
	    rlen /= NET_MIN_BUFSIZE;
	    rlen *= NET_MIN_BUFSIZE;
	}
	resize_rbuf(s, rlen);
	retval = 1;
    }
    if (s->wbufsize > NET_MIN_BUFSIZE
     && wlen < s->wbufsize - NET_MIN_BUFSIZE
    ) {
	if (wlen < NET_MIN_BUFSIZE) {
	    wlen = NET_MIN_BUFSIZE;
	} else {
	    wlen += NET_MIN_BUFSIZE;
	    wlen /= NET_MIN_BUFSIZE;
	    wlen *= NET_MIN_BUFSIZE;
	}
	resize_wbuf(s, wlen);
	retval = 1;
    }
    return retval;
}


/* Run reclaim_buffer_space_one() on every socket in the allsockets list.
 * Returns 1 if it reported success for at least one socket, 0 otherwise.
 */
static int reclaim_buffer_space(void)
{
    Socket *s;
    int any = 0;

    LIST_FOREACH (s, allsockets) {
	/* Always visit every socket; don't stop at the first success */
	if (reclaim_buffer_space_one(s))
	    any = 1;
    }
    return any;
}

/*************************************************************************/

/* Write data to a socket with buffering.
 *
 * Copies up to `len' bytes from `buf' into the socket's circular write
 * buffer, calling flush_write_buffer() after each copy to push data out
 * to the socket.  If the buffer is still full after flushing, it is
 * enlarged by NET_MIN_BUFSIZE and the loop continues, unless a configured
 * limit has been reached (NetBufferSize against rbufsize + wbufsize, or
 * TotalNetBufferSize against total_bufsize; a limit of 0 is not checked).
 * At the limit, a socket with SF_BLOCKING set waits in select() until the
 * descriptor is writable and tries again; any other socket logs a warning
 * (once, tracked with SF_WARNED) and stops with errno set to EAGAIN.
 *
 * Returns the number of bytes taken from `buf' (which may be less than
 * `len'), or -1 with errno set to EBADF if the socket has no file
 * descriptor.  On other returns errno is the value left after the last
 * flush_write_buffer() call, EAGAIN as described above, or its value on
 * entry if `len' was not positive.
 *
 * NOTE(review): the return value of flush_write_buffer() is not checked.
 * It looks like a write error there disconnects the socket via
 * do_disconn(), which may release it; confirm that continuing to use `s'
 * afterwards is safe.
 */

static int buffered_write(Socket *s, const char *buf, int len)
{
    int nwritten, left = len;
    int errno_save = errno;

    if (s->fd < 0) {
	errno = EBADF;
	return -1;
    }

    while (left > 0) {

	/* Fill up to the current buffer size.  One byte is always left
	 * unused so that a full buffer never has wend == wptr. */
	if (write_buffer_len(s) < s->wbufsize-1) {
	    int maxwrite;
	    /* If buffer is empty, reset pointers to beginning for efficiency*/
	    if (write_buffer_len(s) == 0)
		s->wptr = s->wend = s->wbuf;
	    if (s->wptr == s->wbuf) {
		/* Buffer not wrapped; free space runs from wend to one
		 * byte short of the top */
		maxwrite = s->wtop - s->wend - 1;
	    } else {
		/* Buffer is wrapped.  If this write would reach to or past
		 * the end of the buffer, write it first and reset the end
		 * pointer to the beginning of the buffer. */
		if (s->wend+left >= s->wtop && s->wptr <= s->wend) {
		    nwritten = s->wtop - s->wend;
		    memcpy(s->wend, buf, nwritten);
		    buf += nwritten;
		    left -= nwritten;
		    s->wend = s->wbuf;
		}
		/* Now we can copy a single chunk to wend. */
		if (s->wptr > s->wend)
		    maxwrite = s->wptr - s->wend - 1;
		else
		    maxwrite = left;  /* guaranteed to fit from above code */
	    }
	    /* Copy whichever is smaller: what is left or what fits */
	    if (left > maxwrite)
		nwritten = maxwrite;
	    else
		nwritten = left;
	    if (nwritten) {
		memcpy(s->wend, buf, nwritten);
		buf += nwritten;
		left -= nwritten;
		s->wend += nwritten;
	    }
	}

	/* Now write to the socket as much as we can. */
	flush_write_buffer(s);
	errno_save = errno;
	if (write_buffer_len(s) >= s->wbufsize-1) {
	    /* Write failed on full buffer; try to expand the buffer. */
	    int over = 0, over_total = 0;
	    if (NetBufferSize && s->rbufsize + s->wbufsize >= NetBufferSize)
		over = 1;
	    if (TotalNetBufferSize && total_bufsize >= TotalNetBufferSize)
		over_total = 1;
	    if (over || over_total) {
		if (s->flags & SF_BLOCKING) {
		    /* Wait (with no timeout) until the socket can take
		     * more data, then go around the loop again */
		    fd_set fds;
		    FD_ZERO(&fds);
		    FD_SET(s->fd, &fds);
		    if (select(s->fd+1, NULL, &fds, NULL, NULL) < 0) {
			log("sockets: waiting on blocking socket %d: %s",
			    s->fd, strerror(errno));
			break;
		    }
		    continue;  /* don't expand the buffer, since it's at max */
		} else {
		    /* Warn only once until a write succeeds again
		     * (flush_write_buffer() clears SF_WARNED) */
		    if (!(s->flags & SF_WARNED)) {
			log("sockets: socket %d exceeded %s buffer size"
			    " limit (%d)", s->fd,
			    over ? "per-connection" : "total",
			    over ? NetBufferSize : TotalNetBufferSize);
			s->flags |= SF_WARNED;
		    }
		    errno_save = EAGAIN;
		    break;
		}
	    }
	    resize_wbuf(s, s->wbufsize + NET_MIN_BUFSIZE);
	}

    } /* while (left > 0) */

    /* If the socket wasn't closed for an error and data is left in the
     * write buffer, tell check_sockets() to try and flush it */
    if (s->fd >= 0 && s->wptr != s->wend)
	FD_SET(s->fd, &write_fds);

    errno = errno_save;
    return len - left;
}

/*************************************************************************/

/* Internal version of disconn(), taking the code to hand to the socket's
 * disconnect callback.  With code == DISCONN_LOCAL, any data still in the
 * write buffer is flushed first; if some remains, SF_DISCONNECT is set so
 * the disconnection happens once it has gone out, and 0 is returned.
 *
 * Returns -1 with errno set to EINVAL for a NULL or listener socket, and
 * 0 otherwise (including when there is nothing to do because the socket
 * is not connected or its disconnect callback is already running).  The
 * disconnect callback sees errno as it was on entry to this function.
 */

static int do_disconn(Socket *s, void *code)
{
    const int entry_errno = errno;  /* for passing to the callback */

    if (s == NULL) {
	log("sockets: do_disconn(%d) with NULL socket!", (int)(long)code);
	errno = EINVAL;
	return -1;
    }
    if (s->flags & SF_LISTENER) {
	log("sockets: do_disconn(%d) with listener socket (%d)!",
	    (int)(long)code, s->fd);
	errno = EINVAL;
	return -1;
    }

    /* Nothing to do if we're already inside the disconnect callback or
     * the socket isn't connected (or connecting) in the first place */
    if (s->flags & SF_DISCONN_CB)
	return 0;
    if (!(s->flags & (SF_CONNECTING | SF_CONNECTED)))
	return 0;

    if (code == DISCONN_LOCAL && s->wptr != s->wend) {
	/* Give buffered data one chance to go out right now */
	flush_write_buffer(s);
	if (s->wptr != s->wend) {
	    /* Still not empty: defer the disconnect until the buffer
	     * drains.  Not actually disconnected yet, but it should
	     * succeed eventually, so report success. */
	    s->flags |= SF_DISCONNECT;
	    return 0;
	}
    }

    shutdown(s->fd, 2);
    sock_closefd(s);

    /* Flag the callback as running so nested calls return immediately */
    s->flags |= SF_DISCONN_CB;
    if (s->cb_disconn) {
	errno = entry_errno;
	s->cb_disconn(s, code);
    }
    if (s->fd >= 0) {
	/* The socket was reconnected */
	s->flags &= ~SF_DISCONN_CB;
	return 0;
    }

    s->flags &= ~(SF_CONNECTING | SF_CONNECTED | SF_DISCONN_CB);
    if (!(s->flags & (SF_SELFCREATED | SF_DELETEME))) {
	/* The structure stays around; just trim its buffers */
	reclaim_buffer_space_one(s);
    } else if (s->flags & SF_CALLBACK) {
	/* Can't free it while SF_CALLBACK is set; mark it instead */
	s->flags |= SF_DELETEME;
    } else {
	sock_free(s);
    }
    return 0;
}

/*************************************************************************/

/* Close a socket's file descriptor and remove every trace of it: the
 * descriptor is cleared from sock_fds and write_fds, its slot in
 * sockets[] is emptied, and s->fd is set to -1.  If that leaves unused
 * slots at the end of sockets[], the array is cut back and max_fd
 * lowered to match.
 */

static void sock_closefd(Socket *s)
{
    const int fd = s->fd;
    int top;

    close(fd);
    FD_CLR(fd, &sock_fds);
    FD_CLR(fd, &write_fds);
    sockets[fd] = NULL;
    s->fd = -1;

    /* Find one past the highest slot still in use */
    for (top = max_fd; top > 0 && sockets[top-1] == NULL; top--)
	;
    if (top < max_fd) {
	sockets = srealloc(sockets, sizeof(*sockets) * top);
	max_fd = top;
    }
}

/*************************************************************************/


/* syntax highlighted by Code2HTML, v. 0.9.1 */