/* Socket routines.
*
* IRC Services is copyright (c) 1996-2007 Andrew Church.
* E-mail: <achurch@achurch.org>
* Parts written by Andrew Kempe and others.
* This program is free but copyrighted software; see the file COPYING for
* details.
*/
#include "services.h"
#include <fcntl.h>
#include <sys/socket.h>
#include <netdb.h>
#include <netinet/in.h>
#include <arpa/inet.h>
/*************************************************************************/
/* Socket data structure */
struct socket_ {
Socket *next, *prev; /* Links in the allsockets list (presumably used by
                      * the LIST_* macros) */
int fd; /* Socket's file descriptor, or -1 if not open */
int flags; /* Status flags (SF_*) */
struct sockaddr_in remote; /* Remote address (set by conn() or on accept) */
/* Usage of pointers (each buffer is circular):
 * - Xbuf is the buffer base pointer
 * - Xptr is the address of the next character to retrieve (consumers
 *   such as sgetc()/sread() read here; flush_write_buffer() writes out
 *   from wptr)
 * - Xend is the address of the next character to store (for example,
 *   fill_read_buffer() read()s new data into rend)
 * - Xtop is the address of the last byte of the buffer + 1
 * Xend-Xptr (mod Xbufsize) gives the number of bytes in the buffer.
 * At least one byte is always left unused, so that a full buffer can be
 * distinguished from an empty one (Xend == Xptr).
 */
char *rbuf, *rptr, *rend, *rtop; /* Read buffer and pointers */
char *wbuf, *wptr, *wend, *wtop; /* Write buffer and pointers */
uint32 rbufsize, wbufsize; /* Buffer sizes (0 for listener sockets) */
SocketCallback cb_connect; /* Connect callback */
SocketCallback cb_disconn; /* Disconnect callback */
SocketCallback cb_accept; /* Accept callback */
SocketCallback cb_read; /* Data-available callback */
SocketCallback cb_readline; /* Line-available callback */
uint32 total_read; /* Total number of kilobytes read */
int total_read_B; /* Bytes read beyond total_read kilobytes (0-1023) */
uint32 total_written; /* Total number of kilobytes written */
int total_written_B; /* Bytes written beyond total_written kilobytes
                      * (0-1023) */
};
#define SF_SELFCREATED 0x0001 /* We created this socket ourselves (set for
                               * sockets created by do_accept()) */
#define SF_CONNECTING 0x0002 /* Socket is busy connecting */
#define SF_CONNECTED 0x0004 /* Socket has connected */
#define SF_LISTENER 0x0008 /* Socket is a listener socket */
#define SF_CALLBACK 0x0010 /* Currently calling callbacks (set by
                            * check_sockets() while processing the socket) */
#define SF_WARNED 0x0020 /* Warned about hitting buffer limit */
#define SF_DISCONNECT 0x0040 /* Disconnect when writebuf empty */
#define SF_DELETEME 0x0080 /* Delete socket when convenient (set by
                            * sock_free() on a connected socket) */
#define SF_DISCONN_CB 0x0100 /* Disconnect callback has been called (used
* to prevent duplicate calls) */
#define SF_BLOCKING 0x0200 /* Writes are blocking */
/*************************************************************************/
/* List of all sockets (even unopened ones) */
static Socket *allsockets = NULL;
/* Array of all opened sockets, dynamically allocated and indexed by file
 * descriptor: sockets[fd] is the Socket using that fd, or NULL */
static Socket **sockets = NULL;
/* Highest FD number in use plus 1; also length of sockets[] array.  Passed
 * as the nfds argument to select() in check_sockets(). */
static int max_fd;
/* Set of all connected socket FDs (listener FDs are added here as well);
 * used as the read set for select() */
static fd_set sock_fds;
/* Set of all FDs that need data written (or are connecting); used as the
 * write set for select() */
static fd_set write_fds;
/* Total memory used by socket buffers (the sum of rbufsize + wbufsize;
 * added in sock_new() and subtracted in sock_free()) */
static uint32 total_bufsize;
/*************************************************************************/
/* Internal routine declarations (definitions at bottom of file).  All of
 * these take sockets that have already been validated by the public
 * entry points above them. */
static void do_accept(Socket *s);
static int fill_read_buffer(Socket *s);
static int flush_write_buffer(Socket *s);
static void resize_rbuf(Socket *s, uint32 size);
static void resize_wbuf(Socket *s, uint32 size);
static void resize_buf(char **p_buf, char **p_ptr, char **p_end, char **p_top,
uint32 newsize);
static int reclaim_buffer_space_one(Socket *s);
static int reclaim_buffer_space(void);
static int buffered_write(Socket *s, const char *buf, int len);
static int do_disconn(Socket *s, void *code);
static void sock_closefd(Socket *s);
/*************************************************************************/
/*************************** Global routines *****************************/
/*************************************************************************/
/* Allocate and initialize a new, unopened socket.
 *
 * The socket gets minimum-size (NET_MIN_BUFSIZE) read and write buffers,
 * no callbacks and zeroed statistics, and is linked into the list of all
 * sockets.  If TotalNetBufferSize is set and the two new buffers would
 * push the total over that limit, unused buffer space is reclaimed from
 * other sockets first.  Returns NULL (after logging) if not enough space
 * could be reclaimed.
 */
Socket *sock_new(void)
{
    Socket *sock;

    while (TotalNetBufferSize
           && total_bufsize + NET_MIN_BUFSIZE*2 > TotalNetBufferSize) {
        if (!reclaim_buffer_space()) {
            log("sockets: sock_new(): out of buffer space!");
            return NULL;
        }
    }

    sock = smalloc(sizeof(*sock));
    sock->fd = -1;
    sock->flags = 0;
    memset(&sock->remote, 0, sizeof(sock->remote));

    /* Read buffer: empty, so retrieve and store pointers coincide */
    sock->rbuf = smalloc(NET_MIN_BUFSIZE);
    sock->rptr = sock->rbuf;
    sock->rend = sock->rbuf;
    sock->rtop = sock->rbuf + NET_MIN_BUFSIZE;

    /* Write buffer: likewise */
    sock->wbuf = smalloc(NET_MIN_BUFSIZE);
    sock->wptr = sock->wbuf;
    sock->wend = sock->wbuf;
    sock->wtop = sock->wbuf + NET_MIN_BUFSIZE;

    sock->rbufsize = NET_MIN_BUFSIZE;
    sock->wbufsize = NET_MIN_BUFSIZE;

    sock->cb_connect = NULL;
    sock->cb_disconn = NULL;
    sock->cb_accept = NULL;
    sock->cb_read = NULL;
    sock->cb_readline = NULL;

    sock->total_read = 0;
    sock->total_written = 0;
    sock->total_read_B = 0;
    sock->total_written_B = 0;

    LIST_INSERT(sock, allsockets);
    total_bufsize += sock->rbufsize + sock->wbufsize;
    return sock;
}
/*************************************************************************/
/* Free a socket, first disconnecting/closing it if necessary.
 *
 * A connecting or connected socket is not freed immediately: it is marked
 * SF_DELETEME and disconnected, and do_disconn() calls back into this
 * function once the socket can actually be released.  A listener socket
 * is closed before being freed.  Passing NULL logs a message and sets
 * errno to EINVAL.
 */
void sock_free(Socket *s)
{
    if (s == NULL) {
        log("sockets: sock_free() with NULL socket!");
        errno = EINVAL;
        return;
    }

    if ((s->flags & (SF_CONNECTING | SF_CONNECTED)) != 0) {
        s->flags |= SF_DELETEME;
        do_disconn(s, DISCONN_LOCAL);
        return;  /* do_disconn() will call us again when appropriate */
    }
    if ((s->flags & SF_LISTENER) != 0)
        close_listener(s);

    LIST_REMOVE(s, allsockets);
    total_bufsize -= s->rbufsize + s->wbufsize;
    free(s->rbuf);
    free(s->wbuf);
    free(s);
}
/*************************************************************************/
/* Set one of a socket's callbacks.
 *
 * `which' selects the callback (SCB_CONNECT, SCB_DISCONNECT, SCB_ACCEPT,
 * SCB_READ or SCB_READLINE); `func' replaces it (NULL clears it).  An
 * unknown `which' is logged and ignored.  A NULL socket logs a message
 * and sets errno to EINVAL.
 */
void sock_setcb(Socket *s, int which, SocketCallback func)
{
    SocketCallback *slot;

    if (s == NULL) {
        log("sockets: sock_setcb() with NULL socket!");
        errno = EINVAL;
        return;
    }

    switch (which) {
      case SCB_CONNECT:    slot = &s->cb_connect;  break;
      case SCB_DISCONNECT: slot = &s->cb_disconn;  break;
      case SCB_ACCEPT:     slot = &s->cb_accept;   break;
      case SCB_READ:       slot = &s->cb_read;     break;
      case SCB_READLINE:   slot = &s->cb_readline; break;
      default:
        log("sockets: sock_setcb(): invalid callback ID %d", which);
        return;
    }
    *slot = func;
}
/*************************************************************************/
/* Return 1 if the given socket is currently connected (SF_CONNECTED is
 * set), 0 otherwise.  A NULL socket logs a message, sets errno to EINVAL
 * and returns 0.
 */
int sock_isconn(const Socket *s)
{
    if (s != NULL)
        return (s->flags & SF_CONNECTED) != 0;

    log("sockets: sock_isconn() with NULL socket!");
    errno = EINVAL;
    return 0;
}
/*************************************************************************/
/* Retrieve address of remote end of socket.  Functions the same way as
 * getpeername(): initialize *lenptr to the size of the buffer at `sa';
 * the address is returned in `sa' (truncated to *lenptr bytes if
 * necessary) and the non-truncated length of the address is returned in
 * *lenptr.
 *
 * Returns 0 on success.  Returns -1 with errno == EINVAL if a NULL
 * pointer is passed, if *lenptr is negative, or if the given socket is
 * not connected.  Only the NULL-pointer case is logged.
 */
int sock_remote(const Socket *s, struct sockaddr *sa, int *lenptr)
{
    size_t copylen;

    if (!s || !sa || !lenptr) {
        log("sockets: sock_remote() with NULL %s!",
            !s ? "socket" : !sa ? "sockaddr" : "lenptr");
        errno = EINVAL;
        return -1;
    }
    /* A negative length would become a huge value when compared with the
     * unsigned sizeof, and the whole address would be copied into a
     * buffer of unknown (possibly smaller) size. */
    if (!(s->flags & SF_CONNECTED) || *lenptr < 0) {
        errno = EINVAL;
        return -1;
    }
    copylen = sizeof(s->remote);
    if ((size_t)*lenptr < copylen)
        copylen = (size_t)*lenptr;
    memcpy(sa, &s->remote, copylen);
    *lenptr = sizeof(s->remote);
    return 0;
}
/*************************************************************************/
/* Choose whether writes on this socket block (`blocking' nonzero) or not
 * (`blocking' zero) by setting or clearing SF_BLOCKING.  A NULL socket
 * logs a message and sets errno to EINVAL.
 */
void sock_set_blocking(Socket *s, int blocking)
{
    if (s == NULL) {
        log("sockets: sock_set_blocking() with NULL socket!");
        errno = EINVAL;
        return;
    }
    s->flags = blocking ? (s->flags | SF_BLOCKING)
                        : (s->flags & ~SF_BLOCKING);
}
/*************************************************************************/
/* Return whether socket writes are blocking (return value != 0) or not
 * (return value == 0).  The nonzero value is the SF_BLOCKING bit itself,
 * not necessarily 1.  A NULL socket logs a message, sets errno to EINVAL
 * and returns -1.
 */
int sock_get_blocking(Socket *s)
{
    if (!s) {
        log("sockets: sock_get_blocking() with NULL socket!");
        errno = EINVAL;
        return -1;
    }
    return s->flags & SF_BLOCKING;
}
/*************************************************************************/
/* Return amount of data in read buffer. Assumes socket is valid. */
inline uint32 read_buffer_len(const Socket *s)
{
if (s->rend >= s->rptr)
return s->rend - s->rptr;
else
return (s->rend + s->rbufsize) - s->rptr;
}
/*************************************************************************/
/* Return amount of data in write buffer. Assumes socket is valid. */
inline uint32 write_buffer_len(const Socket *s)
{
if (s->wend >= s->wptr)
return s->wend - s->wptr;
else
return (s->wend + s->wbufsize) - s->wptr;
}
/*************************************************************************/
/* Report how many whole kilobytes have been received and sent on this
 * socket, via *readkb_ret and *writekb_ret respectively.  Either pointer
 * may be NULL to skip that value.  The sent count does not include data
 * that is buffered but not yet written.  A NULL socket logs a message
 * and sets errno to EINVAL.
 */
void sock_rwstat(const Socket *s, uint32 *readkb_ret, uint32 *writekb_ret)
{
    if (s == NULL) {
        log("sockets: sock_rwstat() with NULL socket!");
        errno = EINVAL;
        return;
    }
    if (readkb_ret != NULL)
        *readkb_ret = s->total_read;
    if (writekb_ret != NULL)
        *writekb_ret = s->total_written;
}
/*************************************************************************/
/* Return the larger of (1) the ratio of the given socket's total buffer
 * size (read and write buffers combined) to NetBufferSize and (2) the
 * ratio of the amount of memory used by all sockets' buffers to
 * TotalNetBufferSize, as a percentage rounded up to the next integer.
 * Ratios (1) and (2) are zero if NetBufferSize or TotalNetBufferSize,
 * respectively, are not set.
 *
 * If any of `socksize_ret', `totalsize_ret', `ratio1_ret', and
 * `ratio2_ret' are non-NULL, they are set respectively to the given
 * socket's total buffer size, the amount of memory used by all sockets'
 * buffers, ratio (1) as a percentage, and ratio (2) as a percentage.
 *
 * If `s' is NULL, ratio (1) is set to zero, and *socksize_ret is not
 * modified.
 */
int sock_bufstat(const Socket *s, uint32 *socksize_ret,
                 uint32 *totalsize_ret, int *ratio1_ret, int *ratio2_ret)
{
    int ratio1 = 0, ratio2 = 0;

    if (NetBufferSize && s) {
        uint32 size = s->rbufsize + s->wbufsize;
        /* For small limits, multiply first for precision; for large ones,
         * divide the limit by 100 instead to avoid overflowing size*100. */
        if (NetBufferSize <= 0x7FFFFFFF/100)
            ratio1 = (size*100 + NetBufferSize-1) / NetBufferSize;
        else
            ratio1 = (size + NetBufferSize/100-1) / (NetBufferSize/100);
    }
    if (TotalNetBufferSize) {
        /* The formula must be chosen based on TotalNetBufferSize itself:
         * testing another limit could pick the second form for a
         * TotalNetBufferSize below 100, making the divisor zero. */
        if (TotalNetBufferSize <= 0x7FFFFFFF/100)
            ratio2 = (total_bufsize*100 + TotalNetBufferSize-1)
                   / TotalNetBufferSize;
        else
            ratio2 = (total_bufsize + TotalNetBufferSize/100-1)
                   / (TotalNetBufferSize/100);
    }

    if (socksize_ret && s)
        *socksize_ret = s->rbufsize + s->wbufsize;
    if (totalsize_ret)
        *totalsize_ret = total_bufsize;
    if (ratio1_ret)
        *ratio1_ret = ratio1;
    if (ratio2_ret)
        *ratio2_ret = ratio2;
    return ratio1 > ratio2 ? ratio1 : ratio2;
}
/*************************************************************************/
/*************************************************************************/
/* Check all sockets for activity, and call callbacks as necessary.
 * Returns after activity has been detected on at least one socket or
 * ReadTimeout milliseconds have elapsed.
 *
 * While a socket is being processed, SF_CALLBACK is set on it.  Sockets
 * marked SF_DELETEME during processing are freed once their processing
 * is complete, either inside the loop or in the sweep after it.
 */
void check_sockets(void)
{
    fd_set rfds, wfds;
    struct timeval tv;
    int i;
    int32 res;
    Socket *s, *s2;

    rfds = sock_fds;
    wfds = write_fds;
    tv.tv_sec = ReadTimeout/1000;
    tv.tv_usec = (ReadTimeout%1000) * 1000;
    enable_signals();
    do {
        /* Note: On systems which don't return the time remaining in `tv',
         * this loop could continue forever if Services keeps getting
         * signals, but as that shouldn't ordinarily happen we don't worry
         * about it. */
        res = select(max_fd, &rfds, &wfds, NULL, &tv);
    } while (res < 0 && errno == EINTR);
    disable_signals();
    if (debug >= 3)
        log("debug: sockets: select returned %d", res);
    if (res <= 0) {
        if (res < 0)
            log_perror("sockets: select()");
        return;
    }

    s = NULL;
    for (i = 0; i < max_fd; i++) {
        /* Processing of the previous iteration's socket is finished */
        if (s)
            s->flags &= ~SF_CALLBACK;
        s = sockets[i];
        if (s && s->fd != i) {
            log("sockets: BUG: sockets[%d]->fd = %d (should be equal),"
                " clearing socket from table", i, s->fd);
            sockets[i] = NULL;
            /* That socket does not own fd i, so don't act on fd i through
             * it; treat the descriptor as having no socket. */
            s = NULL;
        }
        if (s)
            s->flags |= SF_CALLBACK;

        if (FD_ISSET(i, &wfds)) {
            if (debug >= 3)
                log("debug: sockets: write ready on fd %d", i);
            if (!s) {
                log("sockets: BUG: got write-ready on fd %d but no socket"
                    " for it!", i);
                continue;
            } else if (s->flags & SF_CONNECTING) {
                /* Connection established (or failed) */
                int val;
                socklen_t vallen = sizeof(val);
                if (debug >= 2)
                    log("debug: sockets: connect on fd %d returned", i);
                if (getsockopt(i, SOL_SOCKET, SO_ERROR, &val, &vallen) < 0) {
                    log_perror("sockets: getsockopt(SO_ERROR) for connect (%d"
                               " -> %s:%u)", i, inet_ntoa(s->remote.sin_addr),
                               ntohs(s->remote.sin_port));
                    do_disconn(s, DISCONN_CONNFAIL);
                    continue;
                }
                if (val != 0) {
                    errno = val;
                    log_perror("sockets: connect(%d -> %s:%u)", i,
                               inet_ntoa(s->remote.sin_addr),
                               ntohs(s->remote.sin_port));
                    do_disconn(s, DISCONN_CONNFAIL);
                    continue;
                }
                if (s->cb_connect)
                    s->cb_connect(s, 0);
                if (s->fd >= 0) { /* the socket might have been closed */
                    s->flags &= ~SF_CONNECTING;
                    s->flags |= SF_CONNECTED;
                    FD_SET(i, &sock_fds);
                }
                /* write_fds was set to detect connect completion.  Keep it
                 * only if output was queued while connecting (or by the
                 * connect callback); flush_write_buffer() does not write
                 * until the socket is connected, so that data would
                 * otherwise sit in the buffer. */
                if (s->fd == i && write_buffer_len(s) > 0)
                    FD_SET(i, &write_fds);
                else
                    FD_CLR(i, &write_fds);
            } else if (!(s->flags & SF_CONNECTED)) {
                log("sockets: BUG: got write-ready on fd %d but socket not"
                    " connected!", i);
            } else {
                flush_write_buffer(s);
            }
        } /* set in write fds */

        if (FD_ISSET(i, &rfds)) {
            if (debug >= 3)
                log("debug: sockets: read ready on fd %d", i);
            if (!s) {
                log("sockets: BUG: got data on fd %d but no socket for it!",i);
                FD_CLR(i, &sock_fds);
                continue;
            }
            if ((s->flags & SF_LISTENER) && s->cb_accept) {
                /* Connection arrived */
                do_accept(s);
                continue;
            } else if (!(s->flags & SF_CONNECTED)) {
                log("sockets: BUG: got data on fd %d but not connected!", i);
                FD_CLR(i, &sock_fds);
                continue;
            }
            /* Normal read */
            if (read_buffer_len(s) >= s->rbufsize-1) {
                /* Buffer is full, try to expand it */
                int newsize = 0;
                if (s->rbufsize < NET_MIN_BUFSIZE)
                    newsize = NET_MIN_BUFSIZE;
                else if (s->rbufsize+NET_MIN_BUFSIZE < NetBufferSize)
                    newsize = s->rbufsize + NET_MIN_BUFSIZE;
                if (newsize > 0)
                    resize_rbuf(s, newsize);
            }
            res = fill_read_buffer(s);
            if (res < 0) {
                /* Connection was closed (or some other error occurred) */
                if (debug)
                    log_perror("debug: sockets: read(%d)", i);
                do_disconn(s, DISCONN_REMOTE);
                continue;
            } else if (res == 0) {
                log_perror("sockets: BUG: fill_read_buffer() returned 0!");
            } else {
                uint32 left = read_buffer_len(s), newleft;
                if (left == 0) {
                    log("sockets: BUG: 0 bytes avail after successful read!");
                    continue;
                }
                /* Call read callback(s) in a loop until no more data is
                 * left or neither callback takes any data, or the socket is
                 * disconnected */
                do {
                    newleft = left;
                    if (s->cb_read) {
                        s->cb_read(s, (void *)(long)newleft);
                        if ((s->flags & SF_DISCONNECT) || s->fd < 0)
                            break;
                        newleft = read_buffer_len(s);
                    }
                    /* Only search for a newline if data remains: with an
                     * empty buffer (rend == rptr) the wrapped-case search
                     * below would scan stale, already-consumed bytes. */
                    if (s->cb_readline && newleft > 0) {
                        char *newline;
                        if (s->rend > s->rptr) {
                            newline = memchr(s->rptr, '\n', newleft);
                        } else {
                            newline = memchr(s->rptr, '\n',
                                             s->rtop - s->rptr);
                            if (!newline)
                                newline = memchr(s->rbuf, '\n',
                                                 s->rend - s->rbuf);
                        }
                        if (newline) {
                            s->cb_readline(s, (void *)(long)newleft);
                            if ((s->flags & SF_DISCONNECT) || s->fd < 0)
                                break;
                            newleft = read_buffer_len(s);
                        }
                    }
                } while (newleft != left && (left = newleft) != 0);
                reclaim_buffer_space_one(s);
            }
        } /* socket ready for reading */

        if (s && (s->flags & SF_DELETEME)) {
            /* sock_free() may release the socket, so finish with it first
             * and then forget the pointer; otherwise clearing SF_CALLBACK
             * on the next iteration (or after the loop) would write to
             * freed memory. */
            s->flags &= ~SF_CALLBACK;
            sock_free(s);
            s = NULL;
        }
    } /* for all sockets */
    if (s) /* clear SF_CALLBACK from last socket */
        s->flags &= ~SF_CALLBACK;

    /* Clear out any remaining ready-to-be-deleted sockets.  Sockets whose
     * processing ended with a `continue' above skip the in-loop check, so
     * they are picked up here. */
    LIST_FOREACH_SAFE (s, allsockets, s2) {
        if (s->flags & SF_DELETEME)
            sock_free(s);
    }
} /* check_sockets() */
/*************************************************************************/
/*************************************************************************/
/* Initiate a connection to the given host and port.  If an error occurs,
 * returns -1, else returns 0.  The connection is not necessarily complete
 * even if 0 is returned, and may later fail; use the SCB_CONNECT and
 * SCB_DISCONNECT callbacks.  If this function fails due to inability to
 * resolve a hostname, errno will be set to the negative of h_errno; pass
 * the negative of this value to hstrerror() to get an appropriate error
 * message.  If the hostname resolves to something other than an IPv4
 * address, errno is set to EAFNOSUPPORT.
 *
 * lhost/lport specify the local side of the connection.  If they are not
 * given (lhost==NULL, lport==0), then they are left to be set by the OS.
 * An lhost that cannot be resolved to an IPv4 address is ignored.
 *
 * If either host or lhost is not a valid IP address and the gethostbyname()
 * function is available, this function may block while the hostname is
 * being resolved.
 *
 * This function may be called from a socket's disconnect callback to
 * establish a new connection using the same socket.  It may not be called,
 * however, if the socket is being freed with sock_free().
 */
int conn(Socket *s, const char *host, int port, const char *lhost, int lport)
{
#if HAVE_GETHOSTBYNAME
    struct hostent *hp;
#endif
    uint8 *addr;
    struct sockaddr_in sa, lsa;
    int fd, i;

    if (!s || !host || port <= 0 || port > 65535) {
        if (port <= 0 || port > 65535)
            log("sockets: conn() with bad port number (%d)!", port);
        else
            log("sockets: conn() with NULL %s!", !s ? "socket" : "hostname");
        errno = EINVAL;
        return -1;
    }
    if (s->flags & SF_DELETEME) {
        log("sockets: conn() called on a freeing socket (%p)", (void *)s);
        errno = EPERM;
        return -1;
    }

    /* Local address (optional) */
    memset(&lsa, 0, sizeof(lsa));
    lsa.sin_family = AF_INET;
    if (lhost) {
        if ((addr = pack_ip(lhost)) != 0)
            memcpy(&lsa.sin_addr, addr, 4);
#if HAVE_GETHOSTBYNAME
        /* sin_addr holds exactly 4 bytes; copying h_length bytes of a
         * longer (e.g. IPv6) address would overrun the structure. */
        else if ((hp = gethostbyname(lhost)) != NULL
                 && hp->h_addrtype == AF_INET
                 && hp->h_length == (int)sizeof(lsa.sin_addr))
            memcpy(&lsa.sin_addr, hp->h_addr, sizeof(lsa.sin_addr));
#endif
        else
            lhost = NULL;
    }
    if (lport)
        lsa.sin_port = htons(lport);

    /* Remote address */
    memset(&sa, 0, sizeof(sa));
    if ((addr = pack_ip(host)) != 0) {
        memcpy(&sa.sin_addr, addr, 4);
        sa.sin_family = AF_INET;
    }
#if HAVE_GETHOSTBYNAME
    else if ((hp = gethostbyname(host)) != NULL) {
        if (hp->h_addrtype != AF_INET
            || hp->h_length != (int)sizeof(sa.sin_addr)) {
            log("sockets: conn(): `%s' did not resolve to an IPv4 address",
                host);
            errno = EAFNOSUPPORT;
            return -1;
        }
        memcpy(&sa.sin_addr, hp->h_addr, sizeof(sa.sin_addr));
        sa.sin_family = AF_INET;
    } else {
        errno = -h_errno;
        return -1;
    }
#else
    else {
        log("sockets: conn(): `%s' is not a valid IP address", host);
        errno = EINVAL;
        return -1;
    }
#endif
    sa.sin_port = htons((uint16)port);

    if ((fd = socket(sa.sin_family, SOCK_STREAM, 0)) < 0)
        return -1;
    if (fcntl(fd, F_SETFL, O_NDELAY) < 0) {
        int errno_save = errno;
        close(fd);
        errno = errno_save;
        return -1;
    }
    if ((lhost || lport)
        && bind(fd, (struct sockaddr *)&lsa, sizeof(lsa)) < 0) {
        int errno_save = errno;
        close(fd);
        errno = errno_save;
        return -1;
    }
    /* Non-blocking connect: EINPROGRESS means completion will be reported
     * later via write-readiness in check_sockets() */
    if ((i = connect(fd, (struct sockaddr *)&sa, sizeof(sa))) < 0
        && errno != EINPROGRESS) {
        int errno_save = errno;
        close(fd);
        errno = errno_save;
        return -1;
    }

    if (max_fd < fd+1) {
        int j;
        sockets = srealloc(sockets, (fd+1) * sizeof(*sockets));
        for (j = max_fd; j < fd; j++)
            sockets[j] = NULL;
        max_fd = fd+1;
    }
    sockets[fd] = s;
    s->remote = sa;
    s->fd = fd;
    if (i == 0) {
        /* Connected immediately */
        s->flags |= SF_CONNECTED;
        FD_SET(fd, &sock_fds);
        if (s->cb_connect)
            s->cb_connect(s, 0);
    } else {
        s->flags |= SF_CONNECTING;
        FD_SET(fd, &write_fds);
    }
    return 0;
}
/*************************************************************************/
/* Disconnect a socket.
 *
 * Returns 0 on success and -1 on error (NULL socket, or whatever
 * do_disconn() reports, e.g. for a listener socket).  Calling this on an
 * already-disconnected socket succeeds without doing anything.  The
 * disconnection may not happen immediately; callers who intend to reuse
 * the socket MUST wait until the disconnect callback is called before
 * doing so.
 */
int disconn(Socket *s)
{
    if (s != NULL)
        return do_disconn(s, DISCONN_LOCAL);

    log("sockets: disconn() with NULL socket!");
    errno = EINVAL;
    return -1;
}
/*************************************************************************/
/* Open a listener socket on the given host and port; returns 0 on success
 * (the socket is set up and listening), -1 on error.  If `host' has
 * multiple addresses, only the first one is used; if `host' is NULL, all
 * addresses are bound to.  As with conn(), a negative errno value
 * indicates a failure to resolve the hostname `host', and EAFNOSUPPORT
 * indicates that it resolved to something other than an IPv4 address.
 * `backlog' is the backlog limit for incoming connections, and is passed
 * directly to the listen() system call.
 *
 * Listener sockets have no read/write buffers; the socket's buffers are
 * released (and removed from the global buffer accounting) on success.
 *
 * Note that even after the socket is successfully created, it will not
 * accept any connections until/unless the SCB_ACCEPT callback is set.
 *
 * If host is not a valid IP address and the gethostbyname() function is
 * available, this function may block while the hostname is being resolved.
 */
int open_listener(Socket *s, const char *host, int port, int backlog)
{
#if HAVE_GETHOSTBYNAME
    struct hostent *hp;
#endif
    uint8 *addr;
    struct sockaddr_in sa;
    int fd, i;

    if (!s || port <= 0 || port > 65535 || backlog < 1) {
        if (port <= 0 || port > 65535)
            log("sockets: open_listener() with bad port number (%d)!", port);
        else if (backlog < 1)
            log("sockets: open_listener() with bad backlog (%d)!", backlog);
        else
            log("sockets: open_listener() with NULL socket!");
        errno = EINVAL;
        return -1;
    }

    memset(&sa, 0, sizeof(sa));
    if (host) {
        if ((addr = pack_ip(host)) != 0) {
            memcpy(&sa.sin_addr, addr, 4);
            sa.sin_family = AF_INET;
        }
#if HAVE_GETHOSTBYNAME
        else if ((hp = gethostbyname(host)) != NULL) {
            /* sin_addr holds exactly 4 bytes; a longer (e.g. IPv6)
             * address would overrun the structure. */
            if (hp->h_addrtype != AF_INET
                || hp->h_length != (int)sizeof(sa.sin_addr)) {
                log("sockets: open_listener(): `%s' did not resolve to an"
                    " IPv4 address", host);
                errno = EAFNOSUPPORT;
                return -1;
            }
            memcpy(&sa.sin_addr, hp->h_addr, sizeof(sa.sin_addr));
            sa.sin_family = AF_INET;
        } else {
            errno = -h_errno;
            return -1;
        }
#else
        else {
            log("sockets: open_listener(): `%s' is not a valid IP address",
                host);
            errno = EINVAL;
            return -1;
        }
#endif
    } else { /* !host */
        sa.sin_family = AF_INET;
        sa.sin_addr.s_addr = INADDR_ANY;
    }
    sa.sin_port = htons((uint16)port);

    if ((fd = socket(sa.sin_family, SOCK_STREAM, 0)) < 0)
        return -1;
    i = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i)) < 0) {
        /* Not fatal; bind() below may still succeed */
        log_perror("sockets: open_listener(): setsockopt(%d, SO_REUSEADDR,"
                   " 1) failed", fd);
    }
    if (fcntl(fd, F_SETFL, O_NDELAY) < 0) {
        int errno_save = errno;
        close(fd);
        errno = errno_save;
        return -1;
    }
    if (bind(fd, (struct sockaddr *)&sa, sizeof(sa)) < 0) {
        int errno_save = errno;
        close(fd);
        errno = errno_save;
        return -1;
    }
    if (listen(fd, backlog) < 0) {
        int errno_save = errno;
        close(fd);
        errno = errno_save;
        return -1;
    }

    /* Listener sockets don't need read/write buffers.  Remove them from
     * total_bufsize as well: sock_free() later subtracts the (then zero)
     * buffer sizes, so the memory would otherwise stay counted forever. */
    total_bufsize -= s->rbufsize + s->wbufsize;
    free(s->rbuf);
    free(s->wbuf);
    s->rbuf = s->rptr = s->rend = s->rtop = NULL;
    s->wbuf = s->wptr = s->wend = s->wtop = NULL;
    s->rbufsize = 0;
    s->wbufsize = 0;

    if (max_fd < fd+1) {
        sockets = srealloc(sockets, (fd+1) * sizeof(*sockets));
        for (i = max_fd; i < fd; i++)
            sockets[i] = NULL;
        max_fd = fd+1;
    }
    sockets[fd] = s;
    FD_SET(fd, &sock_fds);
    s->fd = fd;
    s->flags |= SF_LISTENER;
    return 0;
}
/*************************************************************************/
/* Close a listener socket opened with open_listener() and clear its
 * SF_LISTENER flag.  Returns 0 on success.  Returns -1 with errno set to
 * EINVAL (after logging) if the socket is NULL or is not a listener.
 */
int close_listener(Socket *s)
{
    if (s == NULL) {
        log("sockets: close_listener() with NULL socket!");
        errno = EINVAL;
        return -1;
    }
    if (!(s->flags & SF_LISTENER)) {
        log("sockets: close_listener() with non-listener socket (%d)!",
            s->fd);
        errno = EINVAL;
        return -1;
    }
    sock_closefd(s);
    s->flags &= ~SF_LISTENER;
    return 0;
}
/*************************************************************************/
/* Read raw data from a socket, like read().  Returns number of bytes read
 * (0 if the buffer is empty), or -1 on error.  Only reads from the buffer
 * (does not attempt to fetch more data from the connection).
 */
int32 sread(Socket *s, char *buf, int32 len)
{
    int32 nread = 0;

    if (!s || !buf || len <= 0) {
        log("sockets: sread() with %s!",
            !s ? "NULL socket" : !buf ? "NULL buffer" : "len <= 0");
        errno = EINVAL;
        return -1;
    }

    if (s->rend < s->rptr) {
        /* Buffer data wraps around.  Compare lengths rather than forming
         * rptr + len, which may point beyond the buffer (undefined
         * behavior). */
        int32 to_top = s->rtop - s->rptr;
        if (len <= to_top) {
            /* Only need to read from end of buffer */
            memcpy(buf, s->rptr, len);
            s->rptr += len;
            if (s->rptr >= s->rtop)
                s->rptr = s->rbuf;
            return len;
        }
        /* Need to read from both end and beginning */
        memcpy(buf, s->rptr, to_top);
        nread = to_top;
        s->rptr = s->rbuf;
        len -= to_top;
        /* Continue below */
    }

    /* Read data from beginning of buffer (or the unwrapped region) */
    if (s->rptr < s->rend) {
        int32 avail = s->rend - s->rptr;
        if (len > avail)
            len = avail;
        memcpy(buf+nread, s->rptr, len);
        s->rptr += len;
        nread += len;
    }
    return nread;
}
/*************************************************************************/
/* Write raw data to a socket, like write().  The data is appended to the
 * socket's write buffer via buffered_write().  A `len' of zero is
 * accepted.  Returns number of bytes written, or -1 on error (NULL
 * socket or buffer, or negative length: errno is set to EINVAL).
 */
int32 swrite(Socket *s, const char *buf, int32 len)
{
    if (!s || !buf || len < 0) {
        log("sockets: swrite() with %s!",
            !s ? "NULL socket" : !buf ? "NULL buffer" : "len < 0");
        errno = EINVAL;
        return -1;
    }
    return buffered_write(s, buf, len);
}
/*************************************************************************/
/*************************************************************************/
/* Read a character from a socket, like fgetc().  Returns the next byte as
 * an unsigned char converted to int, or EOF if no data is available in
 * the socket buffer.  Assumes the socket is valid.
 */
int sgetc(Socket *s)
{
    int c;

    /* No paranoia check here, to save time */
    if (s->rptr == s->rend)
        return EOF;
    /* Go through unsigned char as fgetc() does: where plain char is
     * signed, a 0xFF byte would otherwise come back as -1 == EOF. */
    c = (unsigned char)*s->rptr++;
    if (s->rptr >= s->rtop)
        s->rptr = s->rbuf;
    return c;
}
/*************************************************************************/
/* Read a line from a socket, like fgets().  Copies the next buffered line
 * (including its '\n') into `buf', storing at most len-1 bytes followed by
 * a null terminator, and returns `buf'.  The entire line is consumed from
 * the socket buffer even if it had to be truncated to fit.  If not enough
 * buffered data is available to fill a complete line, or another error
 * occurs, returns NULL and consumes nothing.
 */
char *sgets(char *buf, int32 len, Socket *s)
{
    char *ptr, *eol;
    int32 to_top, linelen, copylen, first;

    if (!s || !buf || len <= 0) {
        log("sockets: sgets[2]() with %s!",
            !s ? "NULL socket" : !buf ? "NULL buffer" : "len <= 0");
        return NULL;
    }
    ptr = s->rptr;
    /* An empty buffer (rend == rptr) must be caught here: the wrapped-data
     * search below would otherwise scan stale, already-consumed bytes and
     * could return an old line while moving rptr past rend. */
    if (ptr == s->rend)
        return NULL;
    to_top = s->rtop - ptr;  /* bytes from rptr to the end of the buffer */

    /* Find end of line; the data may wrap around the end of the buffer */
    if (s->rend > ptr) {
        eol = memchr(ptr, '\n', s->rend - ptr);
    } else {
        eol = memchr(ptr, '\n', to_top);
        if (!eol)
            eol = memchr(s->rbuf, '\n', s->rend - s->rbuf);
    }
    if (!eol)
        return NULL;
    eol++;  /* Point 1 byte after \n */

    /* Line length including the \n.  eol lies before ptr only when the
     * newline was found in the wrapped-around part at the buffer start. */
    if (eol > ptr)
        linelen = eol - ptr;
    else
        linelen = to_top + (eol - s->rbuf);

    /* Consume the whole line from the socket buffer */
    s->rptr = (eol >= s->rtop) ? s->rbuf : eol;

    /* Copy at most len-1 bytes, in up to two pieces (end of buffer, then
     * start of buffer) */
    copylen = (linelen < len-1) ? linelen : len-1;
    first = (copylen < to_top) ? copylen : to_top;
    memcpy(buf, ptr, first);
    memcpy(buf + first, s->rbuf, copylen - first);
    buf[copylen] = 0;
    return buf;
}
/*************************************************************************/
/* Reads a line of text from a socket, and strips newline and carriage
 * return characters from the end of the line: a trailing "\n" is removed,
 * then a trailing "\r" if one precedes it (or ends a truncated line).
 * Returns `buf', or NULL if sgets() returns NULL.
 */
char *sgets2(char *buf, int32 len, Socket *s)
{
    size_t n;

    if (!sgets(buf, len, s))
        return NULL;
    n = strlen(buf);
    /* Guard against an empty result (e.g. len == 1), where inspecting
     * buf[n-1] would read before the start of the buffer. */
    if (n > 0 && buf[n-1] == '\n')
        buf[--n] = 0;
    if (n > 0 && buf[n-1] == '\r')
        buf[--n] = 0;
    return buf;
}
/*************************************************************************/
/* Write a string to a socket, like fputs().  Returns the number of bytes
 * written (as reported by buffered_write()), or -1 with errno set to
 * EINVAL if the string or socket is NULL.
 */
int sputs(char *str, Socket *s)
{
    if (!str || !s) {
        log("sockets: sputs() with %s!",
            !s ? "NULL socket" : "NULL string");
        /* Must not fall through: strlen(NULL) is undefined behavior */
        errno = EINVAL;
        return -1;
    }
    return buffered_write(s, str, strlen(str));
}
/*************************************************************************/
/* Write to a socket a la [v]printf().  Returns the number of bytes
 * written; in no case will more than 65535 bytes be written (if the output
 * would be longer than this, it will be truncated).  This is a thin
 * variadic wrapper around vsockprintf().
 */
int sockprintf(Socket *s, const char *fmt, ...)
{
    va_list ap;
    int written;

    va_start(ap, fmt);
    written = vsockprintf(s, fmt, ap);
    va_end(ap);
    return written;
}
/* va_list form of sockprintf(): formats into a 64KiB local buffer and
 * writes the result to the socket.  Output longer than 65535 bytes is
 * truncated.  Returns the number of bytes written, or -1 on error (NULL
 * socket or format string: errno = EINVAL; formatting failure: errno as
 * left by vsnprintf()).
 */
int vsockprintf(Socket *s, const char *fmt, va_list args)
{
    char buf[65536];
    int len;

    if (!s || !fmt) {
        log("sockets: [v]sockprintf() with %s!",
            !s ? "NULL socket" : "NULL format string");
        errno = EINVAL;
        return -1;
    }
    len = vsnprintf(buf, sizeof(buf), fmt, args);
    if (len < 0)
        return -1;
    /* vsnprintf() returns the length the output would have had; when it
     * was truncated, only sizeof(buf)-1 bytes were actually stored.
     * Passing the untruncated length would read past the end of buf. */
    if (len > (int)sizeof(buf) - 1)
        len = (int)sizeof(buf) - 1;
    return buffered_write(s, buf, len);
}
/*************************************************************************/
/************************** Internal routines ****************************/
/*************************************************************************/
/* Accept a connection on the given listener socket.  Called from
 * check_sockets().  On success a new connected Socket (flagged
 * SF_SELFCREATED) is created for the incoming connection, registered in
 * sockets[] and sock_fds, and passed to the listener's accept callback.
 * On any failure the error is logged and the new descriptor, if any, is
 * closed.
 */
static void do_accept(Socket *s)
{
    struct sockaddr_in sin;
    socklen_t sin_len = sizeof(sin);
    int newfd, i;
    Socket *news;

    newfd = accept(s->fd, (struct sockaddr *)&sin, &sin_len);
    if (newfd < 0) {
        if (errno != ECONNRESET)
            log_perror("sockets: accept(%d)", s->fd);
        return;
    }
    if (fcntl(newfd, F_SETFL, O_NDELAY) < 0) {
        log_perror("sockets: fcntl(NDELAY) on accept(%d)", s->fd);
        close(newfd);
        return;
    }
    news = sock_new();
    if (!news) {
        log("sockets: accept(%d): Unable to create socket structure"
            " (out of buffer space?)", s->fd);
        close(newfd);  /* otherwise the descriptor would leak */
        return;
    }

    news->fd = newfd;
    news->flags |= SF_SELFCREATED | SF_CONNECTED;
    /* accept() may report a length larger than the buffer it was given
     * (when the address was truncated), so never copy more than sin. */
    memcpy(&news->remote, &sin,
           sin_len < sizeof(sin) ? sin_len : sizeof(sin));
    FD_SET(newfd, &sock_fds);
    if (max_fd < newfd+1) {
        sockets = srealloc(sockets, (newfd+1) * sizeof(*sockets));
        for (i = max_fd; i < newfd; i++)
            sockets[i] = NULL;
        max_fd = newfd+1;
    }
    sockets[newfd] = news;
    s->cb_accept(s, news);
}
/*************************************************************************/
/* Fill up the read buffer of a socket with any data that may have arrived.
 * Keeps calling read() until the buffer is full or read() fails/returns 0
 * (e.g. EAGAIN on the non-blocking descriptor, or end-of-file).
 *
 * Returns the number of bytes read (nonzero), or -1 on error.  errno is
 * set by the last read() call (ECONNRESET is substituted when read()
 * returns 0 without setting errno, i.e. end-of-file), or to ENOBUFS if the
 * buffer was already full; otherwise it is preserved.
 */
static int fill_read_buffer(Socket *s)
{
    int nread = 0;
    int errno_save = errno;

    if (s->fd < 0) {
        errno = EBADF;
        return -1;
    }
    /* Keep one byte free so that rend never catches up with rptr (a full
     * buffer would otherwise look empty). */
    while (read_buffer_len(s) < s->rbufsize-1) {
        int maxread, res;

        /* Size of the contiguous free region starting at rend */
        if (s->rend < s->rptr) /* wrapped around? */
            maxread = (s->rptr-1) - s->rend;
        else if (s->rptr == s->rbuf)
            maxread = s->rtop - s->rend - 1;
        else
            maxread = s->rtop - s->rend;
        do {
            errno = 0;
            res = read(s->fd, s->rend, maxread);
            if (res <= 0 && errno == 0)
                errno = ECONNRESET; /* make a guess */
        } while (res <= 0 && errno == EINTR);
        errno_save = errno;
        if (debug >= 3)
            log("debug: sockets: fill_read_buffer wanted %d, got %d",
                maxread, res);
        if (res <= 0) {
            if (nread == 0)
                nread = -1;
            break;
        }
        nread += res;
        s->total_read += res / 1024;
        s->total_read_B += res % 1024;
        if (s->total_read_B >= 1024) {
            s->total_read += s->total_read_B / 1024;
            s->total_read_B %= 1024;
        }
        s->rend += res;
        if (s->rend == s->rtop)
            s->rend = s->rbuf;
    }
    if (nread == 0) {
        /* The loop never ran: the buffer was already full */
        nread = -1;
        errno = ENOBUFS;
    } else {
        errno = errno_save;
    }
    return nread;
}
/*************************************************************************/
/* Attempt a single write() of the next contiguous chunk of buffered output
 * (from wptr up to wend, or up to wtop if the data wraps around).
 *
 * Returns the number of bytes written if any were written.  Returns 0 if
 * the socket is not yet connected, nothing could be written right now
 * (EAGAIN/EINTR or a zero-length write), or the buffer is empty.  Returns
 * -1 on a hard write error (the socket is then disconnected with
 * DISCONN_REMOTE and errno reflects the write error) or if the socket is
 * not open (errno = EBADF).
 *
 * When the buffer is found empty, the descriptor is removed from
 * write_fds.  Then either a deferred disconnect (SF_DISCONNECT) is carried
 * out, or spare buffer space is reclaimed.
 */
static int flush_write_buffer(Socket *s)
{
    int chunk, written;

    if (s->fd < 0) {
        errno = EBADF;
        return -1;
    }
    if (!sock_isconn(s))  /* not yet connected */
        return 0;

    if (s->wptr != s->wend) {
        chunk = (s->wptr > s->wend) ? s->wtop - s->wptr   /* wrapped */
                                    : s->wend - s->wptr;
        written = write(s->fd, s->wptr, chunk);
        if (debug >= 3)
            log("debug: sockets: flush_write_buffer wanted %d, got %d",
                chunk, written);
        if (written < 0) {
            if (errno != EAGAIN && errno != EINTR) {
                int saved_errno = errno;
                if (saved_errno != ECONNRESET && saved_errno != EPIPE)
                    log_perror("sockets: flush_write_buffer(%d)", s->fd);
                do_disconn(s, DISCONN_REMOTE);
                errno = saved_errno;
                return -1;
            }
            /* Temporary failure: fall through to the empty check */
        } else if (written > 0) {
            s->flags &= ~SF_WARNED;
            s->wptr += written;
            if (s->wptr >= s->wtop)
                s->wptr = s->wbuf;
            s->total_written += written / 1024;
            s->total_written_B += written % 1024;
            if (s->total_written_B >= 1024) {
                s->total_written += s->total_written_B / 1024;
                s->total_written_B %= 1024;
            }
            return written;
        }
    }

    if (s->wptr != s->wend)
        return 0;

    /* Buffer is empty */
    FD_CLR(s->fd, &write_fds);
    if (!(s->flags & SF_DISCONNECT)) {
        reclaim_buffer_space_one(s);
    } else {
        s->flags &= ~SF_DISCONNECT;
        do_disconn(s, DISCONN_LOCAL);
    }
    return 0;
}
/*************************************************************************/
/* Resize a socket's read or write buffer to `size' bytes (larger or
 * smaller), keeping any buffered data.  Requests for a size that could not
 * hold the currently buffered data are ignored.  total_bufsize is adjusted
 * by the change in size: sock_free() subtracts the sizes current at that
 * time, so total_bufsize must track every resize or it will drift (and
 * can wrap around as an unsigned value).
 */
static void resize_rbuf(Socket *s, uint32 size)
{
    if (size <= read_buffer_len(s))
        return;
    resize_buf(&s->rbuf, &s->rptr, &s->rend, &s->rtop, size);
    total_bufsize = total_bufsize - s->rbufsize + size;
    s->rbufsize = size;
}

static void resize_wbuf(Socket *s, uint32 size)
{
    if (size <= write_buffer_len(s))
        return;
    resize_buf(&s->wbuf, &s->wptr, &s->wend, &s->wtop, size);
    total_bufsize = total_bufsize - s->wbufsize + size;
    s->wbufsize = size;
}
/* Routine that does the actual resizing.  Allocates a new buffer of
 * `newsize' bytes (which may be larger or smaller than the current one),
 * moves the buffered data to its start (unwrapping it if necessary), frees
 * the old buffer and updates all four pointers.  Assumes that newsize is
 * greater than the amount of buffered data, so that at least one byte stays
 * free; resize_rbuf()/resize_wbuf() check this.
 *
 * Shrinking must really happen here: the callers record `newsize' as the
 * buffer size, and the wraparound arithmetic requires that value to equal
 * *p_top - *p_buf.
 */
static void resize_buf(char **p_buf, char **p_ptr, char **p_end, char **p_top,
                       uint32 newsize)
{
    uint32 size = *p_top - *p_buf;
    char *newbuf;
    uint32 len = 0;

    if (newsize == size)
        return;  /* nothing to do */
    newbuf = smalloc(newsize);
    /* Copy old data to new buffer, if any: first the part from ptr to the
     * end of the buffer (if the data wraps), then the part up to end. */
    if (*p_end < *p_ptr) {
        len = *p_top - *p_ptr;
        memcpy(newbuf, *p_ptr, len);
        *p_ptr = *p_buf;
    }
    if (*p_end > *p_ptr) {
        memcpy(newbuf+len, *p_ptr, *p_end - *p_ptr);
        len += *p_end - *p_ptr;
    }
    free(*p_buf);
    *p_buf = newbuf;
    *p_ptr = newbuf;
    *p_end = newbuf + len;
    *p_top = newbuf + newsize;
}
/*************************************************************************/
/* Try to reclaim unused buffer space. Return 1 if some buffer space was
* freed, 0 if not.
*/
static int reclaim_buffer_space_one(Socket *s)
{
uint32 rlen = read_buffer_len(s), wlen = write_buffer_len(s);
int retval = 0;
if (s->rbufsize > NET_MIN_BUFSIZE
&& rlen < s->rbufsize - NET_MIN_BUFSIZE
) {
if (rlen < NET_MIN_BUFSIZE) {
rlen = NET_MIN_BUFSIZE;
} else {
/* Round up to the next multiple of NET_MIN_BUFSIZE, leaving
* at least one byte available */
rlen += NET_MIN_BUFSIZE;
rlen /= NET_MIN_BUFSIZE;
rlen *= NET_MIN_BUFSIZE;
}
resize_rbuf(s, rlen);
retval = 1;
}
if (s->wbufsize > NET_MIN_BUFSIZE
&& wlen < s->wbufsize - NET_MIN_BUFSIZE
) {
if (wlen < NET_MIN_BUFSIZE) {
wlen = NET_MIN_BUFSIZE;
} else {
wlen += NET_MIN_BUFSIZE;
wlen /= NET_MIN_BUFSIZE;
wlen *= NET_MIN_BUFSIZE;
}
resize_wbuf(s, wlen);
retval = 1;
}
return retval;
}
/* Run reclaim_buffer_space_one() on every socket in allsockets.  Return 1
 * if it reported reclaiming space on any socket, 0 otherwise.
 */
static int reclaim_buffer_space(void)
{
    Socket *sock;
    int any_reclaimed = 0;

    LIST_FOREACH (sock, allsockets) {
        /* Call for every socket; don't stop at the first success. */
        if (reclaim_buffer_space_one(sock))
            any_reclaimed = 1;
    }
    return any_reclaimed;
}
/*************************************************************************/
/* Write data to a socket with buffering.  Copies up to `len' bytes from
 * `buf' into the socket's circular write buffer, flushing to the socket
 * as it goes.  When the buffer fills, it is grown by NET_MIN_BUFSIZE,
 * subject to the NetBufferSize (per-connection) and TotalNetBufferSize
 * (global) limits.  When a limit is reached, an SF_BLOCKING socket waits
 * in select() until it is writable; a non-blocking socket stops early
 * with errno set to EAGAIN (logging a warning once per SF_WARNED cycle).
 *
 * Returns the number of bytes accepted (written or buffered), or -1 with
 * errno set if the socket has no descriptor or a write error disconnected
 * it.  In the latter case `s' may already have been freed by the
 * disconnect processing, so the caller must not use it.
 */
static int buffered_write(Socket *s, const char *buf, int len)
{
    int nwritten, left = len;
    int errno_save = errno;

    if (s->fd < 0) {
        errno = EBADF;
        return -1;
    }

    while (left > 0) {
        /* Fill up to the current buffer size. */
        if (write_buffer_len(s) < s->wbufsize-1) {
            int maxwrite;

            /* If buffer is empty, reset pointers to beginning for efficiency */
            if (write_buffer_len(s) == 0)
                s->wptr = s->wend = s->wbuf;
            if (s->wptr == s->wbuf) {
                /* Buffer not wrapped */
                maxwrite = s->wtop - s->wend - 1;
            } else {
                /* Buffer is wrapped.  If this write would reach to or past
                 * the end of the buffer, write it first and reset the end
                 * pointer to the beginning of the buffer.  (Compare lengths
                 * rather than computing wend+left, which could point past
                 * one-past-the-end of the buffer: undefined behavior.) */
                if (left >= s->wtop - s->wend && s->wptr <= s->wend) {
                    nwritten = s->wtop - s->wend;
                    memcpy(s->wend, buf, nwritten);
                    buf += nwritten;
                    left -= nwritten;
                    s->wend = s->wbuf;
                }
                /* Now we can copy a single chunk to wend. */
                if (s->wptr > s->wend)
                    maxwrite = s->wptr - s->wend - 1;
                else
                    maxwrite = left;  /* guaranteed to fit from above code */
            }
            if (left > maxwrite)
                nwritten = maxwrite;
            else
                nwritten = left;
            if (nwritten) {
                memcpy(s->wend, buf, nwritten);
                buf += nwritten;
                left -= nwritten;
                s->wend += nwritten;
            }
        }

        /* Now write to the socket as much as we can. */
        if (flush_write_buffer(s) < 0) {
            /* The socket has no descriptor or a write error disconnected
             * it (flush_write_buffer() calls do_disconn()), which may have
             * freed `s'.  Do not touch `s' again; errno was set by
             * flush_write_buffer(). */
            return -1;
        }
        errno_save = errno;

        if (write_buffer_len(s) >= s->wbufsize-1) {
            /* Write failed on full buffer; try to expand the buffer. */
            int over = 0, over_total = 0;

            if (NetBufferSize && s->rbufsize + s->wbufsize >= NetBufferSize)
                over = 1;
            if (TotalNetBufferSize && total_bufsize >= TotalNetBufferSize)
                over_total = 1;
            if (over || over_total) {
                if (s->flags & SF_BLOCKING) {
                    fd_set fds;
                    FD_ZERO(&fds);
                    FD_SET(s->fd, &fds);
                    if (select(s->fd+1, NULL, &fds, NULL, NULL) < 0) {
                        if (errno == EINTR)
                            continue;  /* interrupted by a signal; retry */
                        errno_save = errno;
                        log("sockets: waiting on blocking socket %d: %s",
                            s->fd, strerror(errno_save));
                        break;
                    }
                    continue;  /* don't expand the buffer, since it's at max */
                } else {
                    if (!(s->flags & SF_WARNED)) {
                        log("sockets: socket %d exceeded %s buffer size"
                            " limit (%d)", s->fd,
                            over ? "per-connection" : "total",
                            over ? NetBufferSize : TotalNetBufferSize);
                        s->flags |= SF_WARNED;
                    }
                    errno_save = EAGAIN;
                    break;
                }
            }
            resize_wbuf(s, s->wbufsize + NET_MIN_BUFSIZE);
        }
    } /* while (left > 0) */

    /* If data is left in the write buffer, tell check_sockets() to try
     * and flush it */
    if (s->fd >= 0 && s->wptr != s->wend)
        FD_SET(s->fd, &write_fds);
    errno = errno_save;
    return len - left;
}
/*************************************************************************/
/* Internal version of disconn(), used to pass a specific code to the
 * disconnect callback.  If code == DISCONN_LOCAL, attempt to first write
 * out any data left in the write buffer, and delay disconnection (by
 * setting SF_DISCONNECT) if we can't.
 *
 * Returns -1 with errno set to EINVAL for a NULL or listener socket, and
 * 0 otherwise, including when the disconnect is deferred or nothing needed
 * doing.  The socket may be freed before this returns (self-created or
 * SF_DELETEME sockets when not inside a callback), so callers must not
 * assume `s' is still valid afterwards.
 */
static int do_disconn(Socket *s, void *code)
{
    int errno_save = errno;  /* for passing to the callback */

    if (s == NULL || (s->flags & SF_LISTENER)) {
        if (s)
            log("sockets: do_disconn(%d) with listener socket (%d)!",
                (int)(long)code, s->fd);
        else
            log("sockets: do_disconn(%d) with NULL socket!", (int)(long)code);
        errno = EINVAL;
        return -1;
    }
    /* Already running this socket's disconnect callback: don't recurse. */
    if (s->flags & SF_DISCONN_CB)
        return 0;
    /* Neither connecting nor connected: nothing to disconnect. */
    if (!(s->flags & (SF_CONNECTING | SF_CONNECTED)))
        return 0;

    if (code == DISCONN_LOCAL && s->wptr != s->wend) {
        /* Write out any buffered data */
        if (flush_write_buffer(s) < 0) {
            /* The flush failed.  flush_write_buffer() has already
             * disconnected the socket with DISCONN_REMOTE (or the socket
             * had no descriptor), and `s' may have been freed in the
             * process, so it must not be touched again. */
            return 0;
        }
        if (s->wptr != s->wend) {
            /* Some data is still buffered; request disconnect after it
             * goes out */
            s->flags |= SF_DISCONNECT;
            /* It's not technically disconnected yet, but it will (should)
             * succeed eventually */
            return 0;
        }
    }

    shutdown(s->fd, SHUT_RDWR);
    sock_closefd(s);
    /* The connection is gone, so any pending "disconnect once the write
     * buffer drains" request is moot; clear it so it cannot fire on a
     * later connection made with this same Socket. */
    s->flags &= ~SF_DISCONNECT;
    s->flags |= SF_DISCONN_CB;
    if (s->cb_disconn) {
        errno = errno_save;
        s->cb_disconn(s, code);
    }
    if (s->fd >= 0) {
        /* The socket was reconnected */
        s->flags &= ~SF_DISCONN_CB;
        return 0;
    }
    s->flags &= ~(SF_CONNECTING | SF_CONNECTED | SF_DISCONN_CB);
    if (s->flags & (SF_SELFCREATED | SF_DELETEME)) {
        /* Defer freeing while callbacks are running on this socket. */
        if (s->flags & SF_CALLBACK)
            s->flags |= SF_DELETEME;
        else
            sock_free(s);
    } else {
        reclaim_buffer_space_one(s);
    }
    return 0;
}
/*************************************************************************/
/* Close a socket's file descriptor, and clear it from all associated
 * structures (s->fd, sockets[], sock_fds, write_fds).  The sockets[] table
 * is then trimmed so that max_fd is one more than the highest descriptor
 * still in use.  Does nothing if the socket has no descriptor.
 */
static void sock_closefd(Socket *s)
{
    int i;

    /* Without a descriptor there is nothing to close, and FD_CLR() or
     * sockets[] indexed with -1 would be out of bounds. */
    if (s->fd < 0)
        return;

    close(s->fd);
    FD_CLR(s->fd, &sock_fds);
    FD_CLR(s->fd, &write_fds);
    sockets[s->fd] = NULL;
    s->fd = -1;

    /* Find the highest remaining entry and shrink the table to fit. */
    i = max_fd;
    while (i > 0 && !sockets[i-1])
        i--;
    if (i < max_fd) {
        /* Skip a zero-byte reallocation when no sockets remain.  srealloc()
         * presumably wraps realloc(), whose size-0 behavior is
         * implementation-defined (it may free the block and return NULL).
         * Keeping the old allocation should be harmless as long as
         * sockets[] is only indexed below max_fd.  NOTE(review): confirm
         * against the table-growth code. */
        if (i > 0)
            sockets = srealloc(sockets, sizeof(*sockets) * i);
        max_fd = i;
    }
}
/*************************************************************************/
/* syntax highlighted by Code2HTML, v. 0.9.1 */