#include <sysdeps.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <unistd.h>
#include <fmt/number.h>
#include <iobuf/iobuf.h>
#include <msg/msg.h>
#include <str/str.h>
#include "ucspi-proxy.h"
/* Setting read by the msg library; presumably makes each message carry
 * the process ID -- TODO confirm against the msg documentation. */
const int msg_show_pid = 1;
/* Traffic counters, reported by exitfn() when the process exits.
 * bytes_client_in / bytes_server_in count data read in handle_fd();
 * bytes_client_out / bytes_server_out count data written through
 * write_client() / write_server(). */
static unsigned long bytes_client_in = 0;
static unsigned long bytes_client_out = 0;
static unsigned long bytes_server_in = 0;
static unsigned long bytes_server_out = 0;
/* Incremented once per -v option; non-zero enables the hangup message. */
int opt_verbose = 0;
/* Value of the -t option, passed to tcp_connect(); presumably seconds --
 * TODO confirm against tcp_connect(). */
static unsigned opt_timeout = 30;
/* Process ID, filled in by main() after argument parsing. */
pid_t pid;
/* Descriptor of the server connection; -1 until parse_args() connects. */
int SERVER_FD = -1;
/* One entry in the list of descriptors watched by the main loop. */
struct filter_node
{
/* Descriptor to read from. */
int fd;
/* Called with each chunk of data read from fd and its length. */
filter_fn filter;
/* Called when a read on fd returns 0 (end of file); when null, the
 * process exits with status 0 instead. */
eof_fn at_eof;
/* Label for this descriptor, used in log and error messages. */
char* name;
/* Next node in the list, or null at the tail. */
struct filter_node* next;
};
/* Head of the filter list; null while no filters are registered. */
struct filter_node* filters = 0;
/* Append a new node for fd to the end of the filter list.
 *
 * The node is labelled "client" when fd is CLIENT_IN, "server" when fd is
 * SERVER_FD, and "FD#<n>" otherwise.  The label is always stored in heap
 * memory because del_filter() releases it with free(); handing it a
 * string literal would be undefined behaviour.
 *
 * Returns true on success, false if memory could not be allocated (in
 * which case the list is left unchanged). */
static bool new_filter(int fd, filter_fn filter, eof_fn at_eof)
{
  struct filter_node* newnode;
  struct filter_node** tail;
  const char* label = 0;
  char* name;

  if (fd == CLIENT_IN)
    label = "client";
  else if (fd == SERVER_FD)
    label = "server";

  if (label != 0) {
    const size_t len = strlen(label) + 1;
    if ((name = malloc(len)) == 0)
      return false;
    memcpy(name, label, len);
  }
  else {
    /* fmt_udec with a null buffer only reports the number of digits. */
    const unsigned digits = fmt_udec(0, fd);
    if ((name = malloc(4 + digits)) == 0)
      return false;
    memcpy(name, "FD#", 3);
    name[3 + fmt_udec(name + 3, fd)] = 0;
  }

  newnode = malloc(sizeof *newnode);
  if (newnode == 0) {
    free(name);
    return false;
  }
  newnode->fd = fd;
  newnode->filter = filter;
  newnode->at_eof = at_eof;
  newnode->name = name;
  newnode->next = 0;

  /* Walk to the terminating null link and hook the node in there. */
  for (tail = &filters; *tail != 0; tail = &(*tail)->next)
    ;
  *tail = newnode;
  return true;
}
/* Install filter and at_eof as the handlers for fd.
 *
 * If a node for fd is already on the list its handlers are replaced and
 * true is returned; otherwise a new node is created through new_filter()
 * and that call's result is returned. */
bool set_filter(int fd, filter_fn filter, eof_fn at_eof)
{
  struct filter_node* cur = filters;

  /* Stop on the first node for this descriptor, or at the end. */
  while (cur != 0 && cur->fd != fd)
    cur = cur->next;

  if (cur == 0)
    return new_filter(fd, filter, at_eof);

  cur->filter = filter;
  cur->at_eof = at_eof;
  return true;
}
/* Remove the first node registered for fd from the filter list and
 * release it.
 *
 * The node's name is passed to free(), so it must point to heap storage
 * (or be null).
 *
 * Returns true if a node was found and removed, false if no node on the
 * list uses fd. */
bool del_filter(int fd)
{
  struct filter_node** link;

  /* link always addresses the pointer that leads to the current node, so
   * unlinking works the same for the head and for interior nodes.  The
   * loop advances on every iteration and so terminates at the list end
   * when fd is not present. */
  for (link = &filters; *link != 0; link = &(*link)->next) {
    struct filter_node* curr = *link;
    if (curr->fd == fd) {
      *link = curr->next;
      free(curr->name);
      free(curr);
      return true;
    }
  }
  return false;
}
/* Service one readable descriptor: read up to BUFSIZE bytes from it and
 * dispatch the result.
 *
 *  - read error: EAGAIN and EINTR are ignored, anything else is fatal;
 *  - end of file: the node's at_eof handler runs, or the process exits
 *    with status 0 when there is none;
 *  - data: the matching input counter is updated and the node's filter
 *    is called with the NUL-terminated buffer and its length. */
static void handle_fd(struct filter_node* filter)
{
  char buf[BUFSIZE+1];
  const int fd = filter->fd;
  const ssize_t rd = read(fd, buf, BUFSIZE);

  if (rd == -1) {
    if (errno != EAGAIN && errno != EINTR) {
      die2sys(1, "Error reading from ", filter->name);
      exit(1);
    }
    return;
  }

  if (rd == 0) {
    if (opt_verbose) msg2(filter->name, " hangup");
    if (!filter->at_eof)
      exit(0);
    filter->at_eof();
    return;
  }

  /* One spare byte is reserved so filters can run string searches. */
  buf[rd] = 0;
  if (fd == CLIENT_IN)
    bytes_client_in += rd;
  else if (fd == SERVER_FD)
    bytes_server_in += rd;
  filter->filter(buf, rd);
}
/* Write exactly size bytes from data to fd, waiting for the descriptor to
 * become writable before each attempt.
 *
 *   data     bytes to send
 *   size     number of bytes; nothing is done when it is <= 0
 *   fd       destination descriptor
 *   name     label for fd used in error messages
 *   counter  incremented by the number of bytes actually written
 *
 * A write() that fails with EINTR or EAGAIN transferred nothing and is
 * simply retried after polling again, mirroring how handle_fd() treats
 * the same conditions on read.  Any other failure terminates the process
 * through the die* macros. */
static void retry_write(const char* data, ssize_t size,
			int fd, const char* name, unsigned long* counter)
{
  iopoll_fd io;

  io.fd = fd;
  while (size > 0) {
    ssize_t wr;

    io.events = IOPOLL_WRITE;
    io.revents = 0;
    switch (iopoll_restart(&io, 1, -1)) {
    case -1:
      die1sys(1, "Poll failed");
    case 0:
      die2(1, "Connection closed during write to ", name);
    }

    wr = write(fd, data, size);
    if (wr == -1) {
      if (errno == EAGAIN || errno == EINTR)
	continue;
      die2sys(1, "Error writing to ", name);
    }
    else if (wr == 0)
      die2(1, "Short write to ", name);
    else {
      data += wr;
      size -= wr;
      *counter += wr;
    }
  }
}
/* Send size bytes of data to the client on CLIENT_OUT via retry_write(),
 * which keeps writing until everything is sent; the number of bytes
 * written is added to bytes_client_out. */
void write_client(const char* data, ssize_t size)
{
retry_write(data, size, CLIENT_OUT, "client", &bytes_client_out);
}
/* Send the NUL-terminated string data to the client; the terminator
 * itself is not sent. */
void writes_client(const char* data)
{
  const ssize_t length = strlen(data);

  write_client(data, length);
}
/* Send size bytes of data to the server on SERVER_FD via retry_write(),
 * which keeps writing until everything is sent; the number of bytes
 * written is added to bytes_server_out. */
void write_server(const char* data, ssize_t size)
{
retry_write(data, size, SERVER_FD, "server", &bytes_server_out);
}
/* Send the NUL-terminated string data to the server; the terminator
 * itself is not sent. */
void writes_server(const char* data)
{
  const ssize_t length = strlen(data);

  write_server(data, length);
}
/* Exit hook (registered with atexit() in main): log the four traffic
 * counters on one line in the form
 *   bytes: client->server IN->OUT server->client IN->OUT
 * and then shut the filter down with filter_deinit(). */
static void exitfn(void)
{
  static const char head[] = "bytes: client->server ";
  static const char middle[] = " server->client ";
  char line[42+FMT_ULONG_LEN*4];
  char* out = line;

  memcpy(out, head, sizeof head - 1);
  out += sizeof head - 1;
  out += fmt_udec(out, bytes_client_in);
  *out++ = '-';
  *out++ = '>';
  out += fmt_udec(out, bytes_server_out);

  memcpy(out, middle, sizeof middle - 1);
  out += sizeof middle - 1;
  out += fmt_udec(out, bytes_server_in);
  *out++ = '-';
  *out++ = '>';
  out += fmt_udec(out, bytes_client_out);

  *out = 0;
  msg1(line);
  filter_deinit();
}
/* Print an optional message followed by the command-line synopsis to
 * errbuf, then exit with status 1.  Pass a null message to print only
 * the synopsis. */
void usage(const char* message)
{
  if (message != 0) {
    msg1(message);
  }
  obuf_put4s(&errbuf,
	     "usage: ",
	     program,
	     " [-v] [-t timeout] host port ",
	     filter_usage);
  obuf_endl(&errbuf);
  exit(1);
}
/* Tell the client that the server connection could not be made and exit
 * with status 0.  The text sent is the filter's prefix, a fixed message,
 * the description of the current errno, and the filter's suffix. */
static void connfail(void)
{
  const char* const reason = strerror(errno);
  str report = {0,0,0};

  str_copy4s(&report,
	     filter_connfail_prefix,
	     "Connection to server failed: ",
	     reason,
	     filter_connfail_suffix);
  write_client(report.s, report.len);
  exit(0);
}
/* Parse the command line, connect to the server and initialise the
 * filter.
 *
 * Options:
 *   -v          increase verbosity (may be repeated)
 *   -t timeout  connection timeout handed to tcp_connect(); must be a
 *               positive decimal number that fits in an unsigned int
 *
 * The first two non-option arguments are the host and port to connect
 * to; any remaining arguments are passed on to filter_init().  Invalid
 * usage ends the process through usage(); a failed connection ends it
 * through connfail(). */
static void parse_args(int argc, char* argv[])
{
  int opt;

  /* getopt() signals the end of the options with -1. */
  while ((opt = getopt(argc, argv, "vt:")) != -1) {
    switch (opt) {
    case 'v':
      opt_verbose++;
      break;
    case 't': {
      char* end;
      unsigned long val;

      errno = 0;
      val = strtoul(optarg, &end, 10);
      /* Reject empty/zero values, trailing junk, overflow of unsigned
       * long, negative input (which strtoul would wrap around) and
       * values that would be truncated when stored in opt_timeout. */
      if (val == 0 || *end != 0 || errno == ERANGE
	  || strchr(optarg, '-') != 0
	  || (unsigned long)(unsigned)val != val)
	usage("Invalid timeout");
      opt_timeout = (unsigned)val;
      break;
    }
    default:
      usage("Unknown option.");
      break;
    }
  }

  if (argc - optind < 2)
    usage("Missing host and port");

  SERVER_FD = tcp_connect(argv[optind], argv[optind+1], opt_timeout);
  if (SERVER_FD == -1)
    connfail();

  optind += 2;
  filter_init(argc-optind, argv+optind);
}
/* Program entry point: set up signals, parse arguments (which also
 * connects to the server and initialises the filter), then loop forever
 * dispatching readable descriptors to their filter nodes.  The loop has
 * no exit of its own; the process ends through exit() calls made by the
 * handlers or the die* macros. */
int main(int argc, char* argv[])
{
  fd_set fds;

  /* With SIGPIPE ignored, writing to a closed peer is reported as a
   * write() error instead of terminating the process. */
  signal(SIGALRM, SIG_IGN);
  signal(SIGHUP, SIG_IGN);
  signal(SIGPIPE, SIG_IGN);

  parse_args(argc, argv);
  atexit(exitfn);
  pid = getpid();

  for (;;) {
    struct filter_node* filter;
    int maxfd = -1;

    /* Rebuild the read set from the current filter list each round. */
    FD_ZERO(&fds);
    for (filter = filters; filter; filter = filter->next) {
      int fd = filter->fd;
      FD_SET(fd, &fds);
      if (fd > maxfd)
	maxfd = fd;
    }

    while (select(maxfd+1, &fds, 0, 0, 0) == -1) {
      /* A failing select() is a runtime error, not a command-line
       * mistake: report it with the system error text. */
      if (errno != EINTR)
	die1sys(1, "select failed");
    }

    /* Only the first ready descriptor is serviced before selecting
     * again -- presumably because a handler may modify the filter list;
     * NOTE(review): confirm. */
    for (filter = filters; filter; filter = filter->next)
      if (FD_ISSET(filter->fd, &fds)) {
	handle_fd(filter);
	break;
      }
  }
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */