/* Socket routines.
 *
 * IRC Services is copyright (c) 1996-2007 Andrew Church.
 *     E-mail:
 * Parts written by Andrew Kempe and others.
 * This program is free but copyrighted software; see the file COPYING for
 * details.
 */

#include "services.h"
#include <fcntl.h>
#include <netdb.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/*************************************************************************/

/* Socket data structure */

struct socket_ {
    Socket *next, *prev;
    int fd;                     /* Socket's file descriptor */
    int flags;                  /* Status flags (SF_*) */
    struct sockaddr_in remote;  /* Remote address */

    /* Usage of pointers (X is `r' for the read buffer, `w' for write):
     *    - Xbuf is the buffer base pointer
     *    - Xptr is the address of the next character to retrieve
     *    - Xend is the address of the next character to store
     *    - Xtop is the address of the last byte of the buffer + 1
     * Xend-Xptr (mod Xbufsize) gives the number of bytes in the buffer.
     * The buffers are rings, and at least one byte is always left unused
     * so that a full buffer can be told apart from an empty one
     * (Xend == Xptr means empty). */
    char *rbuf, *rptr, *rend, *rtop;    /* Read buffer and pointers */
    char *wbuf, *wptr, *wend, *wtop;    /* Write buffer and pointers */
    uint32 rbufsize, wbufsize;          /* Buffer sizes (Xtop - Xbuf) */

    SocketCallback cb_connect;  /* Connect callback */
    SocketCallback cb_disconn;  /* Disconnect callback */
    SocketCallback cb_accept;   /* Accept callback */
    SocketCallback cb_read;     /* Data-available callback */
    SocketCallback cb_readline; /* Line-available callback */

    uint32 total_read;          /* Total number of kilobytes read */
    int total_read_B;           /* Bytes read beyond the last whole KB */
    uint32 total_written;       /* Total number of kilobytes written */
    int total_written_B;        /* Bytes written beyond the last whole KB */
};

#define SF_SELFCREATED  0x0001  /* We created this socket ourselves */
#define SF_CONNECTING   0x0002  /* Socket is busy connecting */
#define SF_CONNECTED    0x0004  /* Socket has connected */
#define SF_LISTENER     0x0008  /* Socket is a listener socket */
#define SF_CALLBACK     0x0010  /* Currently calling callbacks */
#define SF_WARNED       0x0020  /* Warned about hitting buffer limit */
#define SF_DISCONNECT   0x0040  /* Disconnect when writebuf empty */
#define SF_DELETEME     0x0080  /* Delete socket when convenient */
#define SF_DISCONN_CB   0x0100  /* Disconnect callback has been called (used
                                 * to prevent duplicate calls) */
#define SF_BLOCKING     0x0200  /* Writes are blocking */

/*************************************************************************/

/* List of all sockets (even unopened ones) */
static Socket *allsockets = NULL;

/* Array of all opened sockets, dynamically allocated and indexed by fd */
static Socket **sockets = NULL;

/* Highest FD number in use plus 1; also length of sockets[] array */
static int max_fd;

/* Set of all connected socket FDs */
static fd_set sock_fds;

/* Set of all FDs that need data written (or are connecting) */
static fd_set write_fds;

/* Total memory used by socket buffers */
static uint32 total_bufsize;

/*************************************************************************/

/* Internal routine declarations (definitions at bottom of file) */

static void do_accept(Socket *s);
static int fill_read_buffer(Socket *s);
static int flush_write_buffer(Socket *s);
static void resize_rbuf(Socket *s, uint32 size);
static void resize_wbuf(Socket *s, uint32 size);
static void resize_buf(char **p_buf, char **p_ptr, char **p_end,
                       char **p_top, uint32 newsize);
static int reclaim_buffer_space_one(Socket *s);
static int reclaim_buffer_space(void);
static int buffered_write(Socket *s, const char *buf, int len);
static int do_disconn(Socket *s, void *code);
static void sock_closefd(Socket *s);

/*************************************************************************/
/*************************** Global routines *****************************/
/*************************************************************************/

/* Create and return a new socket.  Returns NULL if unsuccessful (i.e. no
 * more space for buffers).
*/
Socket *sock_new(void)
{
    Socket *sock;
    char *rb, *wb;

    /* With a global limit in force, make room for two minimum-size
     * buffers first, or give up if nothing can be reclaimed. */
    if (TotalNetBufferSize) {
        while (total_bufsize + NET_MIN_BUFSIZE*2 > TotalNetBufferSize) {
            if (reclaim_buffer_space() == 0) {
                log("sockets: sock_new(): out of buffer space!");
                return NULL;
            }
        }
    }

    sock = smalloc(sizeof(*sock));
    sock->fd = -1;
    sock->flags = 0;
    memset(&sock->remote, 0, sizeof(sock->remote));

    /* Both buffers start empty (ptr == end == base). */
    rb = smalloc(NET_MIN_BUFSIZE);
    sock->rbuf = rb;
    sock->rptr = rb;
    sock->rend = rb;
    sock->rtop = rb + NET_MIN_BUFSIZE;

    wb = smalloc(NET_MIN_BUFSIZE);
    sock->wbuf = wb;
    sock->wptr = wb;
    sock->wend = wb;
    sock->wtop = wb + NET_MIN_BUFSIZE;

    sock->wbufsize = NET_MIN_BUFSIZE;
    sock->rbufsize = NET_MIN_BUFSIZE;

    sock->cb_connect = NULL;
    sock->cb_disconn = NULL;
    sock->cb_accept = NULL;
    sock->cb_read = NULL;
    sock->cb_readline = NULL;

    sock->total_read = 0;
    sock->total_written = 0;
    sock->total_read_B = 0;
    sock->total_written_B = 0;

    LIST_INSERT(sock, allsockets);
    total_bufsize += sock->rbufsize + sock->wbufsize;
    return sock;
}

/*************************************************************************/

/* Free a socket, first disconnecting/closing it if necessary.  A connected
 * (or connecting) socket is only marked for deletion here; the actual free
 * happens later. */

void sock_free(Socket *s)
{
    if (s == NULL) {
        log("sockets: sock_free() with NULL socket!");
        errno = EINVAL;
        return;
    }

    if ((s->flags & (SF_CONNECTING | SF_CONNECTED)) != 0) {
        s->flags |= SF_DELETEME;
        do_disconn(s, DISCONN_LOCAL);
        /* do_disconn() will call us again at the appropriate time */
        return;
    }

    if ((s->flags & SF_LISTENER) != 0)
        close_listener(s);

    LIST_REMOVE(s, allsockets);
    total_bufsize -= s->rbufsize + s->wbufsize;
    free(s->rbuf);
    free(s->wbuf);
    free(s);
}

/*************************************************************************/

/* Set a callback on a socket.
*/ void sock_setcb(Socket *s, int which, SocketCallback func) { if (!s) { log("sockets: sock_setcb() with NULL socket!"); errno = EINVAL; return; } switch (which) { case SCB_CONNECT: s->cb_connect = func; break; case SCB_DISCONNECT: s->cb_disconn = func; break; case SCB_ACCEPT: s->cb_accept = func; break; case SCB_READ: s->cb_read = func; break; case SCB_READLINE: s->cb_readline = func; break; default: log("sockets: sock_setcb(): invalid callback ID %d", which); break; } } /*************************************************************************/ /* Return whether the given socket is currently connected. */ int sock_isconn(const Socket *s) { if (!s) { log("sockets: sock_isconn() with NULL socket!"); errno = EINVAL; return 0; } return s->flags & SF_CONNECTED ? 1 : 0; } /*************************************************************************/ /* Retrieve address of remote end of socket. Functions the same way as * getpeername() (initialize *lenptr to sizeof(sa), address returned in sa, * non-truncated length of address returned in *lenptr). Returns -1 with * errno == EINVAL if a NULL pointer is passed or the given socket is not * connected. */ int sock_remote(const Socket *s, struct sockaddr *sa, int *lenptr) { if (!s || !sa || !lenptr || !(s->flags & SF_CONNECTED)) { if (!s || !sa || !lenptr) { log("sockets: sock_remote() with NULL %s!", !s ? "socket" : !sa ? "sockaddr" : "lenptr"); } errno = EINVAL; return -1; } if (sizeof(s->remote) <= *lenptr) memcpy(sa, &s->remote, sizeof(s->remote)); else memcpy(sa, &s->remote, *lenptr); *lenptr = sizeof(s->remote); return 0; } /*************************************************************************/ /* Set whether socket writes should block (blocking != 0) or not * (blocking == 0). 
*/ void sock_set_blocking(Socket *s, int blocking) { if (!s) { log("sockets: sock_set_blocking() with NULL socket!"); errno = EINVAL; return; } if (blocking) s->flags |= SF_BLOCKING; else s->flags &= ~SF_BLOCKING; } /*************************************************************************/ /* Return whether socket writes are blocking (return value != 0) or not * (return value == 0). */ int sock_get_blocking(Socket *s) { if (!s) { log("sockets: socket_get_blocking() with NULL socket!"); errno = EINVAL; return -1; } return s->flags & SF_BLOCKING; } /*************************************************************************/ /* Return amount of data in read buffer. Assumes socket is valid. */ inline uint32 read_buffer_len(const Socket *s) { if (s->rend >= s->rptr) return s->rend - s->rptr; else return (s->rend + s->rbufsize) - s->rptr; } /*************************************************************************/ /* Return amount of data in write buffer. Assumes socket is valid. */ inline uint32 write_buffer_len(const Socket *s) { if (s->wend >= s->wptr) return s->wend - s->wptr; else return (s->wend + s->wbufsize) - s->wptr; } /*************************************************************************/ /* Return total number of kilobytes received and sent on this socket in * *readkb_ret and *writekb_ret respectively. Sent data count does not * include buffered but unsent data. 
*/
void sock_rwstat(const Socket *s, uint32 *readkb_ret, uint32 *writekb_ret)
{
    if (s != NULL) {
        /* Each output pointer is optional. */
        if (readkb_ret != NULL)
            *readkb_ret = s->total_read;
        if (writekb_ret != NULL)
            *writekb_ret = s->total_written;
        return;
    }
    log("sockets: sock_rwstat() with NULL socket!");
    errno = EINVAL;
}

/*************************************************************************/

/* Return the larger of (1) the ratio of the given socket's total buffer
 * size (read and write buffers combined) to NetBufferSize and (2) the
 * ratio of the amount of memory used by all sockets' buffers to
 * TotalNetBufferSize, as a percentage rounded up to the next integer.
 * Ratios (1) and (2) are zero if NetBufferSize or TotalNetBufferSize,
 * respectively, are not set.
 *
 * If any of `socksize_ret', `totalsize_ret', `ratio1_ret', and
 * `ratio2_ret' are non-NULL, they are set respectively to the given
 * socket's total buffer size, the amount of memory used by all sockets'
 * buffers, ratio (1) as a percentage, and ratio (2) as a percentage.
 *
 * If `s' is NULL, ratio (1) is set to zero, and *socksize_ret is not
 * modified.
*/ int sock_bufstat(const Socket *s, uint32 *socksize_ret, uint32 *totalsize_ret, int *ratio1_ret, int *ratio2_ret) { int ratio1 = 0, ratio2 = 0; if (NetBufferSize && s) { uint32 size = s->rbufsize + s->wbufsize; if (NetBufferSize <= 0x7FFFFFFF/100) ratio1 = (size*100 + NetBufferSize-1) / NetBufferSize; else ratio1 = (size + NetBufferSize/100-1) / (NetBufferSize/100); } if (TotalNetBufferSize) { if (NetBufferSize <= 0x7FFFFFFF/100) ratio2 = (total_bufsize*100 + TotalNetBufferSize-1) / TotalNetBufferSize; else ratio2 = (total_bufsize + TotalNetBufferSize/100-1) / (TotalNetBufferSize/100); } if (socksize_ret && s) *socksize_ret = s->rbufsize + s->wbufsize; if (totalsize_ret) *totalsize_ret = total_bufsize; if (ratio1_ret) *ratio1_ret = ratio1; if (ratio2_ret) *ratio2_ret = ratio2; if (ratio1 > ratio2) return ratio1; else return ratio2; } /*************************************************************************/ /*************************************************************************/ /* Check all sockets for activity, and call callbacks as necessary. * Returns after activity has been detected on at least one socket or * ReadTimeout milliseconds have elapsed. */ void check_sockets(void) { fd_set rfds, wfds; struct timeval tv; int i; int32 res; Socket *s, *s2; rfds = sock_fds; wfds = write_fds; tv.tv_sec = ReadTimeout/1000; tv.tv_usec = (ReadTimeout%1000) * 1000; enable_signals(); do { /* Note: On systems which don't return the time remaining in `tv', * this loop could continue forever if Services keeps getting * signals, but as that shouldn't ordinarily happen we don't worry * about it. 
*/ res = select(max_fd, &rfds, &wfds, NULL, &tv); } while (res < 0 && errno == EINTR); disable_signals(); if (debug >= 3) log("debug: sockets: select returned %d", res); if (res <= 0) { if (res < 0) log_perror("sockets: select()"); return; } s = NULL; for (i = 0; i < max_fd; i++) { if (s) s->flags &= ~SF_CALLBACK; s = sockets[i]; if (s && s->fd != i) { log("sockets: BUG: sockets[%d]->fd = %d (should be equal)," " clearing socket from table", i, s->fd); sockets[i] = NULL; } if (s) s->flags |= SF_CALLBACK; if (FD_ISSET(i, &wfds)) { if (debug >= 3) log("debug: sockets: write ready on fd %d", i); if (!s) { log("sockets: BUG: got write-ready on fd %d but no socket" " for it!", i); continue; } else if (s->flags & SF_CONNECTING) { /* Connection established (or failed) */ int val; socklen_t vallen; vallen = sizeof(val); if (debug >= 2) log("debug: sockets: connect on fd %d returned", i); if (getsockopt(i, SOL_SOCKET, SO_ERROR, &val, &vallen) < 0) { log_perror("sockets: getsockopt(SO_ERROR) for connect (%d" " -> %s:%u)", i, inet_ntoa(s->remote.sin_addr), htons(s->remote.sin_port)); do_disconn(s, DISCONN_CONNFAIL); continue; } if (val != 0) { errno = val; log_perror("sockets: connect(%d -> %s:%u)", i, inet_ntoa(s->remote.sin_addr), htons(s->remote.sin_port)); do_disconn(s, DISCONN_CONNFAIL); continue; } else { if (s->cb_connect) s->cb_connect(s, 0); } if (s->fd >= 0) { /* the socket might have been closed */ s->flags &= ~SF_CONNECTING; s->flags |= SF_CONNECTED; FD_SET(i, &sock_fds); } FD_CLR(i, &write_fds); } else if (!(s->flags & SF_CONNECTED)) { log("sockets: BUG: got write-ready on fd %d but socket not" " connected!", i); } else { flush_write_buffer(s); } } /* set in write fds */ if (FD_ISSET(i, &rfds)) { if (debug >= 3) log("debug: sockets: read ready on fd %d", i); if (!s) { log("sockets: BUG: got data on fd %d but no socket for it!",i); FD_CLR(i, &sock_fds); continue; } if ((s->flags & SF_LISTENER) && s->cb_accept) { /* Connection arrived */ do_accept(s); continue; } 
else if (!(s->flags & SF_CONNECTED)) { log("sockets: BUG: got data on fd %d but not connected!", i); FD_CLR(i, &sock_fds); continue; } /* Normal read */ if (read_buffer_len(s) >= s->rbufsize-1) { /* Buffer is full, try to expand it */ int newsize = 0; if (s->rbufsize < NET_MIN_BUFSIZE) newsize = NET_MIN_BUFSIZE; else if (s->rbufsize+NET_MIN_BUFSIZE < NetBufferSize) newsize = s->rbufsize + NET_MIN_BUFSIZE; if (newsize > 0) resize_rbuf(s, newsize); } res = fill_read_buffer(s); if (res < 0) { /* Connection was closed (or some other error occurred) */ if (debug && res < 0) log_perror("debug: sockets: read(%d)", i); do_disconn(s, DISCONN_REMOTE); continue; } else if (res == 0) { log_perror("sockets: BUG: fill_read_buffer() returned 0!"); } else { uint32 left = read_buffer_len(s), newleft; if (left == 0) { log("sockets: BUG: 0 bytes avail after successful read!"); continue; } /* Call read callback(s) in a loop until no more data is * left or neither callback takes any data, or the socket is * disconnected */ do { newleft = left; if (s->cb_read) { s->cb_read(s, (void *)(long)newleft); if ((s->flags & SF_DISCONNECT) || s->fd < 0) break; newleft = read_buffer_len(s); } if (s->cb_readline) { char *newline; if (s->rend > s->rptr) { newline = memchr(s->rptr, '\n', newleft); } else { newline = memchr(s->rptr, '\n', s->rtop - s->rptr); if (!newline) newline = memchr(s->rbuf, '\n', s->rend - s->rbuf); } if (newline) { s->cb_readline(s, (void *)(long)newleft); if ((s->flags & SF_DISCONNECT) || s->fd < 0) break; newleft = read_buffer_len(s); } } } while (newleft != left && (left = newleft) != 0); reclaim_buffer_space_one(s); } } /* socket ready for reading */ if (s && (s->flags & SF_DELETEME)) sock_free(s); } /* for all sockets */ if (s) /* clear SF_CALLBACK from last socket */ s->flags &= ~SF_CALLBACK; /* Clear out any ready-to-be-deleted sockets */ /* FIXME: this should never happen but it seems to anyway */ LIST_FOREACH_SAFE (s, allsockets, s2) { if (s->flags & SF_DELETEME) 
sock_free(s); } } /* check_sockets() */ /*************************************************************************/ /*************************************************************************/ /* Initiate a connection to the given host and port. If an error occurs, * returns -1, else returns 0. The connection is not necessarily complete * even if 0 is returned, and may later fail; use the SCB_CONNECT and * SCB_DISCONNECT callbacks. If this function fails due to inability to * resolve a hostname, errno will be set to the negative of h_errno; pass * the negative of this value to hstrerror() to get an appropriate error * message. * * lhost/lport specify the local side of the connection. If they are not * given (lhost==NULL, lport==0), then they are left to be set by the OS. * * If either host or lhost is not a valid IP address and the gethostbyname() * function is available, this function may block while the hostname is * being resolved. * * This function may be called from a socket's disconnect callback to * establish a new connection using the same socket. It may not be called, * however, if the socket is being freed with sock_free(). */ int conn(Socket *s, const char *host, int port, const char *lhost, int lport) { #if HAVE_GETHOSTBYNAME struct hostent *hp; #endif uint8 *addr; struct sockaddr_in sa, lsa; int fd, i; if (!s || !host || port <= 0 || port > 65535) { if (port <= 0 || port > 65535) log("sockets: conn() with bad port number (%d)!", port); else log("sockets: conn() with NULL %s!", !s ? 
"socket" : "hostname"); errno = EINVAL; return -1; } if (s->flags & SF_DELETEME) { log("sockets: conn() called on a freeing socket (%p)", s); errno = EPERM; return -1; } memset(&lsa, 0, sizeof(lsa)); lsa.sin_family = AF_INET; if (lhost) { if ((addr = pack_ip(lhost)) != 0) memcpy((char *)&lsa.sin_addr, addr, 4); #if HAVE_GETHOSTBYNAME else if ((hp = gethostbyname(lhost)) != NULL) memcpy((char *)&lsa.sin_addr, hp->h_addr, hp->h_length); #endif else lhost = NULL; } if (lport) lsa.sin_port = htons(lport); memset(&sa, 0, sizeof(sa)); if ((addr = pack_ip(host)) != 0) { memcpy((char *)&sa.sin_addr, addr, 4); sa.sin_family = AF_INET; } #if HAVE_GETHOSTBYNAME else if ((hp = gethostbyname(host)) != NULL) { memcpy((char *)&sa.sin_addr, hp->h_addr, hp->h_length); sa.sin_family = hp->h_addrtype; } else { errno = -h_errno; return -1; } #else else { log("sockets: conn(): `%s' is not a valid IP address", host); errno = EINVAL; return -1; } #endif sa.sin_port = htons((uint16)port); if ((fd = socket(sa.sin_family, SOCK_STREAM, 0)) < 0) return -1; if (fcntl(fd, F_SETFL, O_NDELAY) < 0) { int errno_save = errno; close(fd); errno = errno_save; return -1; } if ((lhost || lport) && bind(fd,(struct sockaddr *)&lsa,sizeof(sa)) < 0) { int errno_save = errno; close(fd); errno = errno_save; return -1; } if ((i = connect(fd, (struct sockaddr *)&sa, sizeof(sa))) < 0 && errno != EINPROGRESS ) { int errno_save = errno; close(fd); errno = errno_save; return -1; } if (max_fd < fd+1) { int j; sockets = srealloc(sockets, (fd+1) * sizeof(*sockets)); for (j = max_fd; j < fd; j++) sockets[j] = NULL; max_fd = fd+1; } sockets[fd] = s; s->remote = sa; s->fd = fd; if (i == 0) { s->flags |= SF_CONNECTED; FD_SET(fd, &sock_fds); if (s->cb_connect) s->cb_connect(s, 0); } else { s->flags |= SF_CONNECTING; FD_SET(fd, &write_fds); } return 0; } /*************************************************************************/ /* Disconnect a socket. Returns 0 on success, -1 on error (s == NULL or * listener socket). 
Calling this routine on an already-disconnected socket returns success
 * without doing anything.  Note that the socket may not be disconnected
 * immediately; callers who intend to reuse the socket MUST wait until the
 * disconnect callback is called before doing so.
 */

int disconn(Socket *s)
{
    if (s != NULL)
        return do_disconn(s, DISCONN_LOCAL);

    log("sockets: disconn() with NULL socket!");
    errno = EINVAL;
    return -1;
}

/*************************************************************************/

/* Open a listener socket on the given host and port; returns 0 on success
 * (the socket is set up and listening), -1 on error.  If `host' has
 * multiple addresses, only the first one is used; if `host' is NULL, all
 * addresses are bound to.  As with conn(), a negative errno value
 * indicates a failure to resolve the hostname `host'.  `backlog' is the
 * backlog limit for incoming connections, and is passed directly to the
 * listen() system call.
 *
 * Note that even after the socket is successfully created, it will not
 * accept any connections until/unless the SCB_ACCEPT callback is set.
 *
 * If host is not a valid IP address and the gethostbyname() function is
 * available, this function may block while the hostname is being resolved.
*/ int open_listener(Socket *s, const char *host, int port, int backlog) { #if HAVE_GETHOSTBYNAME struct hostent *hp; #endif uint8 *addr; struct sockaddr_in sa; int fd, i; if (!s || port <= 0 || port > 65535 || backlog < 1) { if (port <= 0 || port > 65535) log("sockets: open_listener() with bad port number (%d)!", port); else if (backlog < 1) log("sockets: open_listener() with bad backlog (%d)!", backlog); else log("sockets: open_listener() with NULL socket!"); errno = EINVAL; return -1; } memset(&sa, 0, sizeof(sa)); if (host) { if ((addr = pack_ip(host)) != 0) { memcpy((char *)&sa.sin_addr, addr, 4); sa.sin_family = AF_INET; } #if HAVE_GETHOSTBYNAME else if ((hp = gethostbyname(host)) != NULL) { memcpy((char *)&sa.sin_addr, hp->h_addr, hp->h_length); sa.sin_family = hp->h_addrtype; } else { errno = -h_errno; return -1; } #else else { log("sockets: open_listener(): `%s' is not a valid IP address", host); errno = EINVAL; return -1; } #endif } else { /* !host */ sa.sin_family = AF_INET; sa.sin_addr.s_addr = INADDR_ANY; } sa.sin_port = htons((uint16)port); if ((fd = socket(sa.sin_family, SOCK_STREAM, 0)) < 0) return -1; i = 1; if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i)) < 0) { log_perror("sockets: open_listener(): setsockopt(%d, SO_REUSEADDR," " 1) failed", fd); } if (fcntl(fd, F_SETFL, O_NDELAY) < 0) { int errno_save = errno; close(fd); errno = errno_save; return -1; } if (bind(fd, (struct sockaddr *)&sa, sizeof(sa)) < 0) { int errno_save = errno; close(fd); errno = errno_save; return -1; } if (listen(fd, backlog) < 0) { int errno_save = errno; close(fd); errno = errno_save; return -1; } /* Listener sockets don't need read/write buffers */ free(s->rbuf); free(s->wbuf); s->rbuf = s->rptr = s->rend = s->rtop = NULL; s->wbuf = s->wptr = s->wend = s->wtop = NULL; s->rbufsize = 0; s->wbufsize = 0; if (max_fd < fd+1) { sockets = srealloc(sockets, (fd+1) * sizeof(*sockets)); for (i = max_fd; i < fd; i++) sockets[i] = NULL; max_fd = fd+1; } sockets[fd] = s; 
FD_SET(fd, &sock_fds); s->fd = fd; s->flags |= SF_LISTENER; return 0; } /*************************************************************************/ /* Close a listener socket. */ int close_listener(Socket *s) { if (s == NULL || !(s->flags & SF_LISTENER)) { if (s) log("sockets: close_listener() with non-listener socket (%d)!", s->fd); else log("sockets: close_listener() with NULL socket!"); errno = EINVAL; return -1; } sock_closefd(s); s->flags &= ~SF_LISTENER; return 0; } /*************************************************************************/ /* Read raw data from a socket, like read(). Returns number of bytes read, * or -1 on error. Only reads from the buffer (does not attempt to fetch * more data from the connection). */ int32 sread(Socket *s, char *buf, int32 len) { int32 nread = 0; if (!s || !buf || len <= 0) { log("sockets: sread() with %s!", !s ? "NULL socket" : !buf ? "NULL buffer" : "len <= 0"); errno = EINVAL; return -1; } if (s->rend < s->rptr) { /* Buffer data wraps around */ if (s->rptr + len <= s->rtop) { /* Only need to read from end of buffer */ memcpy(buf, s->rptr, len); s->rptr += len; if (s->rptr >= s->rtop) s->rptr -= s->rbufsize; return len; } else { /* Need to read from both end and beginning */ nread = s->rtop - s->rptr; memcpy(buf, s->rptr, nread); s->rptr = s->rbuf; len -= nread; /* Continue below */ } } /* Read data from beginning of buffer */ if (s->rptr < s->rend) { if (len > s->rend - s->rptr) len = s->rend - s->rptr; memcpy(buf+nread, s->rptr, len); s->rptr += len; nread += len; } /* Return number of bytes read */ return nread; } /*************************************************************************/ /* Write raw data to a socket, like write(). Returns number of bytes * written, or -1 on error. */ int32 swrite(Socket *s, const char *buf, int32 len) { if (!s || !buf || len < 0) { log("sockets: swrite() with %s!", !s ? "NULL socket" : !buf ? 
"NULL buffer" : "len <= 0"); errno = EINVAL; return -1; } return buffered_write(s, buf, len); } /*************************************************************************/ /*************************************************************************/ /* Read a character from a socket, like fgetc(). Returns EOF if no data * is available in the socket buffer. Assumes the socket is valid. */ int sgetc(Socket *s) { int c; /* No paranoia check here, to save time */ if (s->rptr == s->rend) return EOF; c = *s->rptr++; if (s->rptr >= s->rtop) s->rptr -= s->rbufsize; return c; } /*************************************************************************/ /* Read a line from a socket, like fgets(). If not enough buffered data * is available to fill a complete line, or another error occurs, returns * NULL. */ char *sgets(char *buf, int32 len, Socket *s) { char *ptr = s->rptr, *eol; int32 to_top = s->rtop - ptr; /* used for efficiency */ if (!s || !buf || len <= 0) { log("sockets: sgets[2]() with %s!", !s ? "NULL socket" : !buf ? "NULL buffer" : "len <= 0"); return NULL; } /* Find end of line */ if (s->rend > s->rptr) { eol = memchr(s->rptr, '\n', s->rend - s->rptr); } else { eol = memchr(s->rptr, '\n', to_top); if (!eol) eol = memchr(s->rbuf, '\n', s->rend - s->rbuf); } if (!eol) return NULL; eol++; /* Point 1 byte after \n */ /* Set rptr now; old value is in ptr variable */ s->rptr = eol; if (s->rptr >= s->rtop) /* >rtop is impossible, but just in case */ s->rptr = s->rbuf; /* Note: The greatest possible value for eol is s->rend, so as long as * we ensure that rend doesn't wrap around and reach rptr (i.e. always * leave at least 1 byte in the buffer unused), we can never have * eol == ptr here. 
*/ /* Trim eol to ptr) { if (eol-ptr >= len) eol = ptr + len-1; } else { if (to_top >= len-1) /* we don't mind eol == rtop */ eol = ptr + len-1; else if (to_top + (eol - s->rbuf) >= len) eol = s->rbuf + (len-1 - to_top); } /* Actually copy to buffer and return */ if (eol > ptr) { memcpy(buf, ptr, eol-ptr); buf[eol-ptr] = 0; } else { memcpy(buf, ptr, to_top); memcpy(buf+to_top, s->rbuf, eol - s->rbuf); buf[to_top + (eol - s->rbuf)] = 0; } return buf; } /*************************************************************************/ /* Reads a line of text from a socket, and strips newline and carriage * return characters from the end of the line. */ char *sgets2(char *buf, int32 len, Socket *s) { char *str = sgets(buf, len, s); if (!str) return str; str = buf + strlen(buf)-1; if (*str == '\n') *str-- = 0; if (*str == '\r') *str = 0; return buf; } /*************************************************************************/ /* Write a string to a socket, like fputs(). Returns the number of bytes * written. */ int sputs(char *str, Socket *s) { if (!str || !s) { log("sockets: sputs() with %s!", !s ? "NULL socket" : "NULL string"); } return buffered_write(s, str, strlen(str)); } /*************************************************************************/ /* Write to a socket a la [v]printf(). Returns the number of bytes written; * in no case will more than 65535 bytes be written (if the output would be * be longer than this, it will be truncated). */ int sockprintf(Socket *s, const char *fmt, ...) { va_list args; int ret; va_start(args, fmt); ret = vsockprintf(s, fmt, args); va_end(args); return ret; } int vsockprintf(Socket *s, const char *fmt, va_list args) { char buf[65536]; if (!s || !fmt) { log("sockets: [v]sockprintf() with %s!", !s ? 
"NULL socket" : "NULL format string"); errno = EINVAL; return -1; } return buffered_write(s, buf, vsnprintf(buf, sizeof(buf), fmt, args)); } /*************************************************************************/ /************************** Internal routines ****************************/ /*************************************************************************/ /* Accept a connection on the given socket. Called from check_sockets(). */ static void do_accept(Socket *s) { int i; struct sockaddr_in sin; socklen_t sin_len = sizeof(sin); int newfd; newfd = accept(s->fd, (struct sockaddr *)&sin, &sin_len); if (newfd < 0) { if (errno != ECONNRESET) log_perror("sockets: accept(%d)", s->fd); } else if (fcntl(newfd, F_SETFL, O_NDELAY) < 0) { log_perror("sockets: fcntl(NDELAY) on accept(%d)", s->fd); close(newfd); } else { Socket *news = sock_new(); if (!news) { log("sockets: accept(%d): Unable to create socket structure" " (out of buffer space?)", s->fd); } else { news->fd = newfd; news->flags |= SF_SELFCREATED | SF_CONNECTED; memcpy(&news->remote, &sin, sin_len); FD_SET(newfd, &sock_fds); for (i = newfd; i < max_fd; i++); if (max_fd < newfd+1) { sockets = srealloc(sockets, (newfd+1) * sizeof(*sockets)); for (i = max_fd; i < newfd; i++) sockets[i] = NULL; max_fd = newfd+1; } sockets[newfd] = news; s->cb_accept(s, news); } } } /*************************************************************************/ /* Fill up the read buffer of a socket with any data that may have arrived. * Returns the number of bytes read (nonzero), or -1 on error; errno is set * by read() calls but is otherwise preserved. */ static int fill_read_buffer(Socket *s) { int nread = 0; int errno_save = errno; if (s->fd < 0) { errno = EBADF; return -1; } while (read_buffer_len(s) < s->rbufsize-1) { int maxread, res; if (s->rend < s->rptr) /* wrapped around? 
*/ maxread = (s->rptr-1) - s->rend; else if (s->rptr == s->rbuf) maxread = s->rtop - s->rend - 1; else maxread = s->rtop - s->rend; do { errno = 0; res = read(s->fd, s->rend, maxread); if (res <= 0 && errno == 0) errno = ECONNRESET; /* make a guess */ } while (res <= 0 && errno == EINTR); errno_save = errno; if (debug >= 3) log("debug: sockets: fill_read_buffer wanted %d, got %d", maxread, nread); if (res <= 0) { if (nread == 0) nread = -1; break; } nread += res; s->total_read += res / 1024; s->total_read_B += res % 1024; if (s->total_read_B >= 1024) { s->total_read += s->total_read_B / 1024; s->total_read_B %= 1024; } s->rend += res; if (s->rend == s->rtop) s->rend = s->rbuf; } if (nread == 0) { nread = -1; errno = ENOBUFS; } else { errno = errno_save; } return nread; } /*************************************************************************/ /* Try and write up to one chunk of data from the buffer to the socket. * Return how much was written. */ static int flush_write_buffer(Socket *s) { int maxwrite, nwritten; if (s->fd < 0) { errno = EBADF; return -1; } if (!sock_isconn(s)) /* not yet connected */ return 0; if (s->wend != s->wptr) { if (s->wptr > s->wend) /* wrapped around? 
*/ maxwrite = s->wtop - s->wptr; else maxwrite = s->wend - s->wptr; nwritten = write(s->fd, s->wptr, maxwrite); if (debug >= 3) log("debug: sockets: flush_write_buffer wanted %d, got %d", maxwrite, nwritten); if (nwritten < 0 && errno != EAGAIN && errno != EINTR) { int errno_save = errno; if (errno != ECONNRESET && errno != EPIPE) log_perror("sockets: flush_write_buffer(%d)", s->fd); do_disconn(s, DISCONN_REMOTE); errno = errno_save; return -1; } if (nwritten > 0) { s->flags &= ~SF_WARNED; s->wptr += nwritten; if (s->wptr >= s->wtop) s->wptr = s->wbuf; s->total_written += nwritten / 1024; s->total_written_B += nwritten % 1024; if (s->total_written_B >= 1024) { s->total_written += s->total_written_B / 1024; s->total_written_B %= 1024; } return nwritten; } } if (s->wptr == s->wend) { FD_CLR(s->fd, &write_fds); if (s->flags & SF_DISCONNECT) { s->flags &= ~SF_DISCONNECT; do_disconn(s, DISCONN_LOCAL); } else { reclaim_buffer_space_one(s); } } return 0; } /*************************************************************************/ /* Resize a socket's read or write buffer. */ static void resize_rbuf(Socket *s, uint32 size) { if (size <= read_buffer_len(s)) { /* log("sockets: BUG: resize_rbuf(%d): size (%d) <= rlen (%d)", s->fd, size, read_buffer_len(s)); */ return; } resize_buf(&s->rbuf, &s->rptr, &s->rend, &s->rtop, size); s->rbufsize = size; } static void resize_wbuf(Socket *s, uint32 size) { if (size <= write_buffer_len(s)) { /* log("sockets: BUG: resize_wbuf(%d): size (%d) <= wlen (%d)", s->fd, size, write_buffer_len(s)); */ return; } resize_buf(&s->wbuf, &s->wptr, &s->wend, &s->wtop, size); s->wbufsize = size; } /* Routine that does the actual resizing. Assumes that newsize >= current * size. 
*/ static void resize_buf(char **p_buf, char **p_ptr, char **p_end, char **p_top, uint32 newsize) { uint32 size = *p_top - *p_buf; char *newbuf; uint32 len = 0; if (newsize <= size) return; newbuf = smalloc(newsize); /* Copy old data to new buffer, if any */ if (*p_end < *p_ptr) { len = *p_top - *p_ptr; memcpy(newbuf, *p_ptr, len); *p_ptr = *p_buf; } if (*p_end > *p_ptr) { memcpy(newbuf+len, *p_ptr, *p_end - *p_ptr); len += *p_end - *p_ptr; } free(*p_buf); *p_buf = newbuf; *p_ptr = newbuf; *p_end = newbuf + len; *p_top = newbuf + newsize; } /*************************************************************************/ /* Try to reclaim unused buffer space. Return 1 if some buffer space was * freed, 0 if not. */ static int reclaim_buffer_space_one(Socket *s) { uint32 rlen = read_buffer_len(s), wlen = write_buffer_len(s); int retval = 0; if (s->rbufsize > NET_MIN_BUFSIZE && rlen < s->rbufsize - NET_MIN_BUFSIZE ) { if (rlen < NET_MIN_BUFSIZE) { rlen = NET_MIN_BUFSIZE; } else { /* Round up to the next multiple of NET_MIN_BUFSIZE, leaving * at least one byte available */ rlen += NET_MIN_BUFSIZE; rlen /= NET_MIN_BUFSIZE; rlen *= NET_MIN_BUFSIZE; } resize_rbuf(s, rlen); retval = 1; } if (s->wbufsize > NET_MIN_BUFSIZE && wlen < s->wbufsize - NET_MIN_BUFSIZE ) { if (wlen < NET_MIN_BUFSIZE) { wlen = NET_MIN_BUFSIZE; } else { wlen += NET_MIN_BUFSIZE; wlen /= NET_MIN_BUFSIZE; wlen *= NET_MIN_BUFSIZE; } resize_wbuf(s, wlen); retval = 1; } return retval; } static int reclaim_buffer_space(void) { Socket *s; int retval = 0; LIST_FOREACH (s, allsockets) { retval |= reclaim_buffer_space_one(s); } return retval; } /*************************************************************************/ /* Write data to a socket with buffering. */ static int buffered_write(Socket *s, const char *buf, int len) { int nwritten, left = len; int errno_save = errno; if (s->fd < 0) { errno = EBADF; return -1; } while (left > 0) { /* Fill up to the current buffer size. 
*/ if (write_buffer_len(s) < s->wbufsize-1) { int maxwrite; /* If buffer is empty, reset pointers to beginning for efficiency*/ if (write_buffer_len(s) == 0) s->wptr = s->wend = s->wbuf; if (s->wptr == s->wbuf) { /* Buffer not wrapped */ maxwrite = s->wtop - s->wend - 1; } else { /* Buffer is wrapped. If this write would reach to or past * the end of the buffer, write it first and reset the end * pointer to the beginning of the buffer. */ if (s->wend+left >= s->wtop && s->wptr <= s->wend) { nwritten = s->wtop - s->wend; memcpy(s->wend, buf, nwritten); buf += nwritten; left -= nwritten; s->wend = s->wbuf; } /* Now we can copy a single chunk to wend. */ if (s->wptr > s->wend) maxwrite = s->wptr - s->wend - 1; else maxwrite = left; /* guaranteed to fit from above code */ } if (left > maxwrite) nwritten = maxwrite; else nwritten = left; if (nwritten) { memcpy(s->wend, buf, nwritten); buf += nwritten; left -= nwritten; s->wend += nwritten; } } /* Now write to the socket as much as we can. */ flush_write_buffer(s); errno_save = errno; if (write_buffer_len(s) >= s->wbufsize-1) { /* Write failed on full buffer; try to expand the buffer. */ int over = 0, over_total = 0; if (NetBufferSize && s->rbufsize + s->wbufsize >= NetBufferSize) over = 1; if (TotalNetBufferSize && total_bufsize >= TotalNetBufferSize) over_total = 1; if (over || over_total) { if (s->flags & SF_BLOCKING) { fd_set fds; FD_ZERO(&fds); FD_SET(s->fd, &fds); if (select(s->fd+1, NULL, &fds, NULL, NULL) < 0) { log("sockets: waiting on blocking socket %d: %s", s->fd, strerror(errno)); break; } continue; /* don't expand the buffer, since it's at max */ } else { if (!(s->flags & SF_WARNED)) { log("sockets: socket %d exceeded %s buffer size" " limit (%d)", s->fd, over ? "per-connection" : "total", over ? 
NetBufferSize : TotalNetBufferSize); s->flags |= SF_WARNED; } errno_save = EAGAIN; break; } } resize_wbuf(s, s->wbufsize + NET_MIN_BUFSIZE); } } /* while (left > 0) */ /* If the socket wasn't closed for an error and data is left in the * write buffer, tell check_sockets() to try and flush it */ if (s->fd >= 0 && s->wptr != s->wend) FD_SET(s->fd, &write_fds); errno = errno_save; return len - left; } /*************************************************************************/ /* Internal version of disconn(), used to pass a specific code to the * disconnect callback. If code == DISCONN_LOCAL, attempt to first write * out any data left in the write buffer, and delay disconnection if we * can't. */ static int do_disconn(Socket *s, void *code) { int errno_save = errno; /* for passing to the callback */ if (s == NULL || (s->flags & SF_LISTENER)) { if (s) log("sockets: do_disconn(%d) with listener socket (%d)!", (int)(long)code, s->fd); else log("sockets: do_disconn(%d) with NULL socket!", (int)(long)code); errno = EINVAL; return -1; } if (s->flags & SF_DISCONN_CB) return 0; if (!(s->flags & (SF_CONNECTING | SF_CONNECTED))) return 0; if (code == DISCONN_LOCAL && s->wptr != s->wend) { /* Write out any buffered data */ flush_write_buffer(s); if (s->wptr != s->wend) { /* Some data is still buffered; request disconnect after it * goes out */ s->flags |= SF_DISCONNECT; /* It's not technically disconnected yet, but it will (should) * succeed eventually */ return 0; } } shutdown(s->fd, 2); sock_closefd(s); s->flags |= SF_DISCONN_CB; if (s->cb_disconn) { errno = errno_save; s->cb_disconn(s, code); } if (s->fd >= 0) { /* The socket was reconnected */ s->flags &= ~SF_DISCONN_CB; return 0; } s->flags &= ~(SF_CONNECTING | SF_CONNECTED | SF_DISCONN_CB); if (s->flags & (SF_SELFCREATED | SF_DELETEME)) { if (s->flags & SF_CALLBACK) s->flags |= SF_DELETEME; else sock_free(s); } else { reclaim_buffer_space_one(s); } return 0; } 
/*************************************************************************/

/* Close a socket's file descriptor, and clear it from all associated
 * structures (s->fd, sockets[], sock_fds, write_fds).  Afterwards the
 * sockets[] array is trimmed so that max_fd is again one more than the
 * highest descriptor still in use.  When no descriptors remain, the array
 * is released and the module returns to its initial state
 * (sockets == NULL, max_fd == 0).
 */

static void sock_closefd(Socket *s)
{
    int i;

    close(s->fd);
    FD_CLR(s->fd, &sock_fds);
    FD_CLR(s->fd, &write_fds);
    sockets[s->fd] = NULL;
    s->fd = -1;

    /* Drop trailing unused slots from sockets[]. */
    i = max_fd;
    while (i > 0 && !sockets[i-1])
        i--;
    if (i < max_fd) {
        if (i == 0) {
            /* Free explicitly rather than asking srealloc() for a zero-byte
             * block, whose realloc() semantics are implementation-defined. */
            free(sockets);
            sockets = NULL;
        } else {
            sockets = srealloc(sockets, sizeof(*sockets) * i);
        }
        max_fd = i;
    }
}

/*************************************************************************/