/* a throttling transparent proxy. */
#include <string.h>
#include <stdlib.h>
#include "../gsk.h"
#include "../gsklistmacros.h"
#include "../http/gskhttpcontent.h"
typedef struct _GskThrottleProxyConnection GskThrottleProxyConnection;
typedef struct _Side Side;
/* configuration */
/* Per-connection transfer limits, in bytes per second.  Each new
   connection gets base + a random extra in [0, noise) for each
   direction (see pick_rand() and handle_accept()).
   Upload is client->server, download is server->client. */
guint upload_per_second_base = 10*1024;
guint download_per_second_base = 100*1024;
guint upload_per_second_noise = 1*1024;
guint download_per_second_noise = 10*1024;
/* if TRUE, shut down the read and write ends of the connection
   independently.  if FALSE, propagate either a read or a write
   shutdown into a full shutdown of the stream (gsk_io_shutdown). */
gboolean half_shutdowns = TRUE;
/* Addresses given by --bind, --server and --bind-status (see main()). */
GskSocketAddress *bind_addr = NULL;
GskSocketAddress *server_addr = NULL;
GskSocketAddress *bind_status_addr = NULL;
/* Global statistics, shown on the status page. */
static guint n_connections_accepted = 0;
static guint64 n_bytes_read_total = 0;
static guint64 n_bytes_written_total = 0;
/* One direction of a proxied connection: bytes are read from read_side,
   held in buffer, and written to write_side, subject to a per-second
   transfer budget and a maximum buffer size. */
struct _Side
{
/* the connection this Side is embedded in */
GskThrottleProxyConnection *connection;
GskStream *read_side; /* client for upload, server for download */
GskStream *write_side; /* server for upload, client for download */
/* mirror the block state last applied via gsk_io_block_read()/
   gsk_io_block_write(); maintained by update_read_block() and
   update_write_block() */
gboolean read_side_blocked;
gboolean write_side_blocked;
/* sides are in this list if their xferred_in_last_second >= max_xfer_per_second
   but buffer.size < max_buffer */
Side *next_throttled, *prev_throttled;
/* TRUE iff this side is currently linked into the throttled list */
gboolean throttled;
guint max_xfer_per_second;
/* the CURRENT_SECOND() value xferred_in_last_second refers to */
gulong last_xfer_second;
guint xferred_in_last_second;
/* data read from read_side, not yet written to write_side */
GskBuffer buffer;
guint max_buffer;/* should be set to max_xfer_per_second or a bit more */
/* byte counters for the status page */
guint total_read, total_written;
};
/* A proxied client<->server connection.  side_init() adds two references
   per Side (one for each I/O trap it installs); connection_unref() frees
   the structure when ref_count reaches zero. */
struct _GskThrottleProxyConnection
{
/* client -> server */
Side upload;
/* server -> client */
Side download;
guint ref_count;
/* links in the global connection list (first_conn/last_conn) */
GskThrottleProxyConnection *prev, *next;
};
/* All live connections; walked by the status page. */
static GskThrottleProxyConnection *first_conn, *last_conn;
/* Argument pack (type, first, last, prev-field, next-field) for the
   GSK_LIST_* macros over the connection list. */
#define GET_CONNECTION_LIST() \
GskThrottleProxyConnection *, first_conn, last_conn, prev, next
/* Sides whose reading is blocked only by the per-second budget;
   unblock_timer_func() revisits them once per second. */
static Side *first_throttled, *last_throttled;
/* Argument pack for the GSK_LIST_* macros over the throttled list. */
#define GET_THROTTLED_LIST() \
Side *, first_throttled, last_throttled, prev_throttled, next_throttled
/* tv_sec of the default main loop's current_time. */
#define CURRENT_SECOND() (gsk_main_loop_default ()->current_time.tv_sec)
/* Recompute whether SIDE's write stream should be write-blocked.
   The writer is blocked only while nothing is buffered AND the reader is
   still open; once the reader is gone the writer stays unblocked so that
   handle_side_writable() can notice the drained buffer and shut down.
   Must be called whenever side->buffer changes "emptiness". */
static inline void
update_write_block (Side *side)
{
  gboolean old_val = side->write_side_blocked;
  gboolean val = (side->read_side != NULL && side->buffer.size == 0);

  side->write_side_blocked = val;

  /* handle_side_write_destroy() clears write_side; there is then no
     stream left to (un)block, and the gsk_io_* calls must not get NULL. */
  if (side->write_side == NULL)
    return;

  if (old_val && !val)
    gsk_io_unblock_write (side->write_side);
  else if (!old_val && val)
    gsk_io_block_write (side->write_side);
}
/* Recompute the read-side flow-control state of SIDE.
   Reading is blocked when this second's budget is used up
   (xferred_in_last_second >= max_xfer_per_second) or the buffer is full
   (buffer.size >= max_buffer).  A side blocked only by its budget is
   "throttled" and linked into the global throttled list, so that
   unblock_timer_func() can re-enable it when the second rolls over.
   Must be called whenever side->buffer or side->xferred_in_last_second
   changes. */
static inline void
update_read_block (Side *side)
{
  gboolean was_throttled = side->throttled;
  gboolean old_val = side->read_side_blocked;
  gboolean xfer_blocked, buf_blocked, val;

  if (side->read_side == NULL)
    {
      /* handle_side_read_destroy() cleared the reader, possibly while
         buffered data is still draining through handle_side_writable().
         There is no stream to (un)block, and a reader-less side must not
         stay in the throttled list, where unblock_timer_func() would keep
         calling back in here for it. */
      if (was_throttled)
        {
          GSK_LIST_REMOVE (GET_THROTTLED_LIST (), side);
        }
      side->throttled = FALSE;
      side->read_side_blocked = FALSE;
      return;
    }

  xfer_blocked = side->xferred_in_last_second >= side->max_xfer_per_second;
  buf_blocked = side->buffer.size >= side->max_buffer;
  val = xfer_blocked || buf_blocked;
  side->throttled = xfer_blocked && !buf_blocked;
  side->read_side_blocked = val;

  if (side->throttled && !was_throttled)
    {
      /* put in throttled list */
      GSK_LIST_APPEND (GET_THROTTLED_LIST (), side);
    }
  else if (!side->throttled && was_throttled)
    {
      /* remove from throttled list */
      GSK_LIST_REMOVE (GET_THROTTLED_LIST (), side);
    }

  if (old_val && !val)
    gsk_io_unblock_read (side->read_side);
  else if (!old_val && val)
    gsk_io_block_read (side->read_side);
}
static void
connection_unref (GskThrottleProxyConnection *conn)
{
if (--(conn->ref_count) == 0)
{
GSK_LIST_REMOVE (GET_CONNECTION_LIST (), conn);
gsk_buffer_destruct (&conn->upload.buffer);
gsk_buffer_destruct (&conn->download.buffer);
g_free (conn);
}
}
/* Writable callback for SIDE's write stream: flush as much of
   side->buffer as the stream accepts, then refresh both flow-control
   states.  When the reader is gone, the buffer is empty and this call
   wrote nothing, the write end is shut down (half or full, depending on
   half_shutdowns).  Always returns TRUE. */
static gboolean
handle_side_writable (GskStream *stream,
                      gpointer   data)
{
  Side *side = data;
  GError *write_error = NULL;
  guint n_written;

  n_written = gsk_stream_write_buffer (stream, &side->buffer, &write_error);
  if (write_error != NULL)
    {
      g_warning ("error writing to stream %p: %s",
                 stream, write_error->message);
      g_error_free (write_error);
    }

  n_bytes_written_total += n_written;
  side->total_written += n_written;

  update_write_block (side);
  update_read_block (side);

  /* Keep going unless the reader is gone and nothing remains to flush. */
  if (n_written != 0 || side->read_side != NULL || side->buffer.size != 0)
    return TRUE;

  if (half_shutdowns)
    gsk_io_write_shutdown (side->write_side, NULL);
  else
    gsk_io_shutdown (GSK_IO (side->write_side), NULL);
  return TRUE;
}
/* Write-shutdown callback for SIDE's write stream.  Warns if buffered
   data is being dropped, then propagates the shutdown to SIDE's read
   stream (read-only or full shutdown, depending on half_shutdowns) if
   that stream is still open.  Returns FALSE. */
static gboolean
handle_side_write_shutdown (GskStream *stream,
                            gpointer   data)
{
  Side *side = data;
  GskStream *reader = side->read_side;

  if (side->buffer.size > 0)
    g_warning ("write-side shut down while data still pending");

  if (reader == NULL)
    return FALSE;

  if (!half_shutdowns)
    gsk_io_shutdown (GSK_IO (reader), NULL);
  else
    gsk_io_read_shutdown (reader, NULL);
  return FALSE;
}
/* Destroy notification for SIDE's writable trap: drop the write stream
   and release one of the two connection references side_init() took
   for this Side. */
static void
handle_side_write_destroy (gpointer data)
{
  Side *side = data;
  GskThrottleProxyConnection *owner = side->connection;

  g_object_unref (side->write_side);
  side->write_side = NULL;
  connection_unref (owner);
}
/* Readable callback for SIDE's read stream.  Reads at most what is left
   of this second's budget (max_xfer_per_second - xferred_in_last_second),
   further capped so side->buffer never grows past max_buffer, appends
   the bytes to side->buffer and refreshes both flow-control states.
   Always returns TRUE. */
static gboolean
handle_side_readable (GskStream *stream,
                      gpointer   data)
{
  Side *side = data;
  gulong now = CURRENT_SECOND ();
  GError *read_error = NULL;
  guint budget;
  guint n_read;
  char *chunk;

  /* Entering a new second resets the per-second budget. */
  if (now != side->last_xfer_second)
    {
      side->xferred_in_last_second = 0;
      side->last_xfer_second = now;
    }
  budget = side->max_xfer_per_second - side->xferred_in_last_second;

  /* ...and the buffer may never grow past max_buffer. */
  if (budget + side->buffer.size > side->max_buffer)
    budget = (side->buffer.size > side->max_buffer)
             ? 0
             : side->max_buffer - side->buffer.size;

  chunk = g_malloc (budget);
  n_read = gsk_stream_read (stream, chunk, budget, &read_error);
  if (read_error != NULL)
    {
      g_warning ("error reading from stream %p: %s",
                 stream, read_error->message);
      g_error_free (read_error);
    }

  /* TODO: use append_foreign if n_read is big */
  gsk_buffer_append (&side->buffer, chunk, n_read);
  g_free (chunk);

  n_bytes_read_total += n_read;
  side->total_read += n_read;
  side->xferred_in_last_second += n_read;
  g_assert (side->xferred_in_last_second <= side->max_xfer_per_second);

  update_write_block (side);
  update_read_block (side);
  return TRUE;
}
/* Read-shutdown callback for SIDE's read stream.  Does nothing itself;
   cleanup lives in handle_side_read_destroy() (presumably invoked when
   the trap is removed -- confirm against GskIO).  Returns FALSE. */
static gboolean
handle_side_read_shutdown (GskStream *stream,
                           gpointer   data)
{
  (void) stream;
  (void) data;
  return FALSE;
}
/* Destroy notification for SIDE's readable trap.  Drops the read stream.
   If nothing is left in the buffer, the write end is shut down right
   away; otherwise handle_side_writable() does it once the buffer has
   drained.  Finally releases one of the two connection references
   side_init() took for this Side. */
static void
handle_side_read_destroy (gpointer data)
{
  Side *side = data;

  g_object_unref (side->read_side);
  side->read_side = NULL;

  if (side->write_side != NULL && side->buffer.size == 0)
    {
      /* read_side is NULL now, so this leaves the writer unblocked */
      update_write_block (side);
      if (!half_shutdowns)
        gsk_io_shutdown (GSK_IO (side->write_side), NULL);
      else
        gsk_io_write_shutdown (side->write_side, NULL);
    }

  connection_unref (side->connection);
}
/* Initialize SIDE as one direction of CONN: data flows from READ_SIDE to
   WRITE_SIDE at most MAX_XFER_PER_SECOND bytes per second, with a buffer
   cap of the same size.  Takes a reference on each stream and installs
   the readable/writable traps; adds two references to CONN, one per
   trap, released by the corresponding destroy handlers. */
static void
side_init (Side                       *side,
           GskThrottleProxyConnection *conn,
           GskStream                  *read_side,
           GskStream                  *write_side,
           guint                       max_xfer_per_second)
{
  /* wiring */
  side->connection = conn;
  side->read_side = read_side;
  side->write_side = write_side;

  /* flow-control state: nothing blocked, not throttled, unlinked */
  side->read_side_blocked = side->write_side_blocked = FALSE;
  side->throttled = FALSE;
  side->prev_throttled = side->next_throttled = NULL;

  /* rate limiting */
  side->max_xfer_per_second = max_xfer_per_second;
  side->max_buffer = max_xfer_per_second;
  side->last_xfer_second = CURRENT_SECOND ();
  side->xferred_in_last_second = 0;

  /* buffering and statistics */
  gsk_buffer_construct (&side->buffer);
  side->total_read = side->total_written = 0;

  conn->ref_count += 2;

  g_object_ref (read_side);
  gsk_io_trap_readable (read_side,
                        handle_side_readable,
                        handle_side_read_shutdown,
                        side,
                        handle_side_read_destroy);

  g_object_ref (write_side);
  gsk_io_trap_writable (write_side,
                        handle_side_writable,
                        handle_side_write_shutdown,
                        side,
                        handle_side_write_destroy);
}
/* --- handle a new stream --- */

/* Return BASE plus a random extra in [0, NOISE); just BASE when NOISE
   is zero. */
static guint
pick_rand (guint base, guint noise)
{
  guint jitter = 0;

  if (noise != 0)
    jitter = g_random_int_range (0, noise);
  return base + jitter;
}
/* Accept callback: proxy the new client STREAM to server_addr.
   Builds a connection with an upload Side (client -> server) and a
   download Side (server -> client), each limited to its base rate plus
   random noise (pick_rand()).  Drops one reference to STREAM (presumably
   the one handed over by the listener -- confirm against
   GskStreamListener); the traps installed by side_init() hold their own
   references.  Always returns TRUE.  If the upstream connect fails, only
   this client is closed and the proxy keeps running. */
static gboolean
handle_accept (GskStream *stream,
               gpointer   data,
               GError   **error)
{
  GskThrottleProxyConnection *conn;
  GskStream *server;
  GError *e = NULL;

  n_connections_accepted++;

  server = gsk_stream_new_connecting (server_addr, &e);
  if (server == NULL || e != NULL)
    {
      /* Do not g_error() here: a per-connection failure (e.g. running
         out of file descriptors) would otherwise abort the process and
         drop every other proxied connection. */
      g_warning ("gsk_stream_new_connecting failed: %s",
                 e != NULL ? e->message : "unknown error");
      if (e != NULL)
        g_error_free (e);
      if (server != NULL)
        g_object_unref (server);
      gsk_io_shutdown (GSK_IO (stream), NULL);
      g_object_unref (stream);
      return TRUE;
    }

  conn = g_new (GskThrottleProxyConnection, 1);
  /* temporary reference held while the sides are wired up */
  conn->ref_count = 1;
  GSK_LIST_APPEND (GET_CONNECTION_LIST (), conn);
  side_init (&conn->upload, conn, stream, server,
             pick_rand (upload_per_second_base, upload_per_second_noise));
  side_init (&conn->download, conn, server, stream,
             pick_rand (download_per_second_base, download_per_second_noise));
  connection_unref (conn);

  g_object_unref (stream);
  g_object_unref (server);
  return TRUE;
}
/* Listener error callback: any listener error is fatal (g_error). */
static void
handle_listener_error (GError  *error,
                       gpointer data)
{
  const char *msg = error->message;

  (void) data;
  g_error ("handle_listener_error: %s", msg);
}
/* --- unblock throttled streams every second --- */

/* Timer callback: for every throttled side whose budget second has
   passed, reset the per-second counter and let update_read_block()
   unblock it.  Re-arms itself at the start of the next second and
   returns FALSE so this one-shot timer is not repeated. */
static gboolean
unblock_timer_func (gpointer data)
{
  gulong now = CURRENT_SECOND ();
  Side *side;
  Side *following;

  for (side = first_throttled; side != NULL; side = following)
    {
      /* update_read_block() may unlink SIDE; fetch the successor first */
      following = side->next_throttled;
      g_assert (side->throttled);
      g_assert (side->read_side_blocked);
      if (now <= side->last_xfer_second)
        continue;
      side->last_xfer_second = now;
      side->xferred_in_last_second = 0;
      update_read_block (side);
    }

  /* schedule next timeout */
  gsk_main_loop_add_timer_absolute (gsk_main_loop_default (),
                                    unblock_timer_func, NULL, NULL,
                                    now + 1, 0);
  return FALSE;
}
/* Print command-line help to stderr and exit with status 1.
   main() calls this for any unrecognized argument (including --help). */
static void
usage (void)
{
  g_printerr ("usage: %s --bind=LISTEN_ADDR --server=CONNECT_ADDR OPTIONS\n\n",
              g_get_prgname ());
  g_printerr ("Bind to LISTEN_ADDR; whenever we receive a connection,\n"
              "proxy to CONNECT_ADDR, obeying throttling constraints.\n"
              "\n"
              "LISTEN_ADDR and STATUS_ADDR are either a TCP port number\n"
              "(bound on all IPv4 interfaces) or a local socket path.\n"
              "CONNECT_ADDR is either HOST:PORT or a local socket path.\n"
              "\n"
              "Options:\n"
              "  --bind-status=STATUS_ADDR  Serve an HTML status page on this addr.\n"
              "  --upload-rate=BPS          Base client->server rate in bytes/sec\n"
              "                             (default 10240).\n"
              "  --download-rate=BPS        Base server->client rate in bytes/sec\n"
              "                             (default 102400).\n"
              "  --upload-rate-noise=BPS    Each connection adds a random 0..BPS-1\n"
              "                             to its upload rate (default 1024).\n"
              "  --download-rate-noise=BPS  Each connection adds a random 0..BPS-1\n"
              "                             to its download rate (default 10240).\n"
              "  --full-shutdowns           A read or write shutdown closes the\n"
              "                             whole stream.\n"
              "  --half-shutdowns           Shut down the read and write ends\n"
              "                             independently (default).\n"
             );
  exit (1);
}
/* Append one HTML table cell describing SIDE to OUT: read/write stream
   presence, block/throttle state, buffered byte count and totals. */
static void
dump_side_to_buffer (Side *side, GskBuffer *out)
{
  const char *read_prefix = (side->read_side != NULL) ? "" : "NOT ";
  const char *write_prefix = (side->write_side != NULL) ? "" : "NOT ";
  const char *write_state = side->write_side_blocked ? " [blocked]" : "";
  const char *read_state;

  /* throttled takes precedence over a plain block */
  if (side->throttled)
    read_state = " [throttled]";
  else if (side->read_side_blocked)
    read_state = " [blocked]";
  else
    read_state = "";

  gsk_buffer_printf (out, "<td>%sreadable%s, %swritable%s, %u buffered [total read/written=%u/%u]</td>\n",
                     read_prefix, read_state,
                     write_prefix, write_state,
                     side->buffer.size,
                     side->total_read, side->total_written);
}
/* HTTP handler for the status page: renders global statistics and one
   table row per live connection as text/html and sends it in response
   to REQUEST.  Returns GSK_HTTP_CONTENT_OK. */
static GskHttpContentResult
create_status_page (GskHttpContent        *content,
                    GskHttpContentHandler *handler,
                    GskHttpServer         *server,
                    GskHttpRequest        *request,
                    GskStream             *post_data,
                    gpointer               data)
{
  GskBuffer page = GSK_BUFFER_STATIC_INIT;
  GskThrottleProxyConnection *at;
  GskHttpResponse *reply;
  GskStream *body;

  /* document head */
  gsk_buffer_printf (&page, "<html><head>\n");
  gsk_buffer_printf (&page, "<title>GskThrottleProxy Status Page</title>\n");
  gsk_buffer_printf (&page, "</head>\n");
  gsk_buffer_printf (&page, "<body>\n");

  /* global counters */
  gsk_buffer_printf (&page, "<h1>Statistics</h1>\n");
  gsk_buffer_printf (&page, "<br>%u connections accepted.\n",
                     n_connections_accepted);
  gsk_buffer_printf (&page, "<br>%"G_GUINT64_FORMAT" bytes read.\n",
                     n_bytes_read_total);
  gsk_buffer_printf (&page, "<br>%"G_GUINT64_FORMAT" bytes written.\n",
                     n_bytes_written_total);

  /* per-connection table */
  gsk_buffer_printf (&page, "<h1>Connections</h1>\n");
  gsk_buffer_printf (&page, "<table>\n"
                            " <tr><th>Connection Pointer</th>"
                            "<th>RefCount</th>"
                            "<th>Upload</th>"
                            "<th>Download</th>"
                            "</tr>\n");
  at = first_conn;
  while (at != NULL)
    {
      gsk_buffer_printf (&page,
                         " <tr><td>%p</td><td>%u</td>", at, at->ref_count);
      dump_side_to_buffer (&at->upload, &page);
      dump_side_to_buffer (&at->download, &page);
      gsk_buffer_printf (&page, "</tr>\n");
      at = at->next;
    }
  gsk_buffer_printf (&page, "</table>\n</body>\n</html>\n");

  /* send it */
  reply = gsk_http_response_from_request (request, 200, page.size);
  gsk_http_header_set_content_type (reply, "text");
  gsk_http_header_set_content_subtype (reply, "html");
  body = gsk_memory_buffer_source_new (&page);
  gsk_http_server_respond (server, request, reply, body);
  g_object_unref (reply);
  g_object_unref (body);
  return GSK_HTTP_CONTENT_OK;
}
/* --- main --- */
int main(int argc, char **argv)
{
guint i;
GskStreamListener *listener;
GError *error = NULL;
gsk_init_without_threads (&argc, &argv);
for (i = 1; i < (guint) argc; i++)
{
if (g_str_has_prefix (argv[i], "--bind="))
{
const char *bind_str = strchr (argv[i], '=') + 1;
if (bind_addr != NULL)
g_error ("--bind may only be given once");
if (g_ascii_isdigit (bind_str[0]))
{
bind_addr = gsk_socket_address_ipv4_new (gsk_ipv4_ip_address_any,
atoi (bind_str));
}
else
{
bind_addr = gsk_socket_address_local_new (bind_str);
}
}
else if (g_str_has_prefix (argv[i], "--bind-status="))
{
const char *bind_str = strchr (argv[i], '=') + 1;
if (bind_status_addr != NULL)
g_error ("--bind-status may only be given once");
if (g_ascii_isdigit (bind_str[0]))
{
bind_status_addr = gsk_socket_address_ipv4_new (gsk_ipv4_ip_address_any,
atoi (bind_str));
}
else
{
bind_status_addr = gsk_socket_address_local_new (bind_str);
}
}
else if (g_str_has_prefix (argv[i], "--server="))
{
const char *server_str = strchr (argv[i], '=') + 1;
const char *colon = strchr (server_str, ':');
if (server_addr != NULL)
g_error ("--server may only be given once");
if (colon != NULL && strchr (server_str, '/') == NULL)
{
/* host:port */
char *host = g_strndup (server_str, colon - server_str);
guint port = atoi (colon + 1);
server_addr = gsk_socket_address_symbolic_ipv4_new (host, port);
g_free (host);
}
else
{
/* unix */
server_addr = gsk_socket_address_local_new (server_str);
}
}
else if (g_str_has_prefix (argv[i], "--upload-rate="))
upload_per_second_base = atoi (strchr (argv[i], '=') + 1);
else if (g_str_has_prefix (argv[i], "--download-rate="))
download_per_second_base = atoi (strchr (argv[i], '=') + 1);
else if (g_str_has_prefix (argv[i], "--upload-rate-noise="))
upload_per_second_noise = atoi (strchr (argv[i], '=') + 1);
else if (g_str_has_prefix (argv[i], "--download-rate-noise="))
download_per_second_noise = atoi (strchr (argv[i], '=') + 1);
else if (strcmp (argv[i], "--half-shutdowns") == 0)
half_shutdowns = TRUE;
else if (strcmp (argv[i], "--full-shutdowns") == 0)
half_shutdowns = FALSE;
else
usage ();
}
if (server_addr == NULL)
g_error ("missing --server=ADDRESS: try --help");
if (bind_addr == NULL)
g_error ("missing --bind=ADDRESS: try --help");
listener = gsk_stream_listener_socket_new_bind (bind_addr, &error);
if (listener == NULL)
g_error ("bind failed: %s", error->message);
gsk_stream_listener_handle_accept (listener,
handle_accept,
handle_listener_error,
NULL, NULL);
if (bind_status_addr != NULL)
{
GskHttpContentHandler *handler;
GskHttpContent *content = gsk_http_content_new ();
GskHttpContentId id = GSK_HTTP_CONTENT_ID_INIT;
handler = gsk_http_content_handler_new (create_status_page, NULL, NULL);
id.path = "/";
gsk_http_content_add_handler (content, &id, handler, GSK_HTTP_CONTENT_REPLACE);
gsk_http_content_handler_unref (handler);
if (!gsk_http_content_listen (content, bind_status_addr, &error))
g_error ("error listening: %s", error->message);
}
gsk_main_loop_add_timer_absolute (gsk_main_loop_default (),
unblock_timer_func, NULL, NULL,
gsk_main_loop_default ()->current_time.tv_sec + 1, 0);
return gsk_main_run ();
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */