// ///////////////////////////////////////////////////////////////////////////
// Copyright (C) 2002 Ultr@VNC Team Members. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
// USA.
//
// Program is based on the
// http://www.imasy.or.jp/~gotoh/ssh/connect.c
// Written By Shun-ichi GOTO <gotoh@taiyo.co.jp>
//
// If the source code for the program is not available from the place
// from
// which you received this file, check
// http://ultravnc.sourceforge.net/
//
// Linux port (C) 2005 Jari Korhonen, jarit1.korhonen@dnainternet.net
// ///////////////////////////////////////////////////////////////////////////
#if defined(__FreeBSD__)
#if __FreeBSD__ < 5
#include <machine/limits.h>
#else
#include <sys/limits.h>
#endif
#endif /* __FreeBSD__ */
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <memory.h>
#include <errno.h>
#include <assert.h>
#include <stdarg.h>
#include <fcntl.h>
#include <signal.h>
#include <time.h>
#include <limits.h>
#include <netdb.h>
#include <unistd.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include "repeaterproc.h"
// Repeater version string; main() prints it at start-up
#define REPEATER_VERSION "0.08"
// Greeting sent to every viewer by acceptConnection() before its id string is read.
// With both version numbers below set to 0 it formats to "RFB 000.000\n", which is
// exactly SIZE_RFBPROTOCOLVERSIONMSG (12) characters.
#define RFB_PROTOCOL_VERSION_FORMAT "RFB %03d.%03d\n"
#define RFB_PROTOCOL_MAJOR_VERSION 0
#define RFB_PROTOCOL_MINOR_VERSION 0
#define SIZE_RFBPROTOCOLVERSIONMSG 12
#define RFB_PORT_OFFSET 5900 //servers 1st display is in this port number
// Waiting (not yet active) servers/viewers older than this are closed by removeOldInactiveConnections()
#define MAX_IDLE_CONNECTION_TIME 600 //Seconds
// Number of id bytes acceptConnection() reads from every new connection; this also bounds
// the Mode 1 host name handed to parseHostAndPort()
#define MAX_HOST_NAME_LEN 250
// Buffer size for printable ip addresses (repeaterInfo.peerIp and Mode 1 server ip)
#define MAX_IP_LEN 50
#define MAX_SESSIONS 100 //Maximum active repeater sessions
// Index returned by the find...List() helpers when no entry matches
#define UNKNOWN_REPINFO_IND 999 //Notice: This should always be bigger than MAX_SESSIONS
//connectionFrom defines for acceptConnection()
#define CONNECTION_FROM_SERVER 0
#define CONNECTION_FROM_VIEWER 1
//connMode defines for acceptConnection()
//Linux repeater version 0.01-0.07 only supported Mode 2.
//Starting from version 0.08, Mode 1 is also supported
//Both types of connections can be running on same repeater at the same time
#define CONN_MODE1 1
#define CONN_MODE2 2
//Use safer openbsd stringfuncs: strlcpy, strlcat
#include "openbsd_stringfuncs.h"
// Buffer for the protocol version greeting formatted by acceptConnection()
typedef char rfbProtocolVersionMsg[SIZE_RFBPROTOCOLVERSIONMSG+1]; /* allow extra byte for null */
// One entry of the servers[] / viewers[] session tables.
// NOTE: updateServerViewerInfo() detects changes by memcmp() over whole entries, so the
// byte layout of this struct (including any padding) takes part in that comparison.
typedef struct _repeaterInfo {
int socket; // socket of this end: the accepted connection, or in Mode 1 the server socket opened by openConnectionToVncServer()
//Code is used for cross-connection between servers and viewers
//In Mode 2, Server/Viewer sends IdCode string "ID:xxxxx", where xxxxx is some positive (1 or bigger) long integer number
//In Mode 1, Repeater "invents" a non-used code (negative number) and assigns that to both Server/Viewer
//code == 0 means that entry in servers[] / viewers[] table is free
long code;
unsigned long timeStamp; // time(NULL) when the entry was added; compared by removeOldInactiveConnections()
char peerIp[MAX_IP_LEN]; //Ip address of the other end (the connected server / viewer itself)
//There are 3 connection levels (using variables "code" and "active"):
//A. code==0,active==false: fully idle, no connection attempt detected
//B. code==non-zero,active==false: server/viewer has connected, waiting for other end to connect
//C. code==non-zero,active=true: doRepeater() running on viewer/server connection, fully active
//-after viewer/server disconnects or some error in doRepeater, returns both to level A
//(and closes respective sockets)
//This logic means, that when one end disconnects, BOTH ends need to reconnect. This is not a bug, it is a feature ;-)
bool active;
} repeaterInfo;
// Session tables: one entry per connected server / viewer. A server and a viewer whose
// entries carry the same code belong to the same session.
static repeaterInfo servers[MAX_SESSIONS];
static repeaterInfo viewers[MAX_SESSIONS];
//mode1ConnCode is used in Mode1 to "invent" code field in repeaterInfo, when new Mode1 connection from
//viewer is accepted. This is just decremented for each new Mode 1 connection to ensure unique number
//for each Mode 1 session
//Values for this are: 0=program has just started, -1....MIN_INVENTED_CONN_CODE: Codes for each session
//(acceptConnection() decrements it before each use and wraps back to -1 when it would
//go below MIN_INVENTED_CONN_CODE)
#define MIN_INVENTED_CONN_CODE -1000000
static long mode1ConnCode;
//These tables are used in function updateServerViewerInfo to compare against current servers/viewers
//and update differences
static repeaterInfo oldServers[MAX_SESSIONS];
static repeaterInfo oldViewers[MAX_SESSIONS];
//This structure (and repeaterProcs[] table) is used for keeping track of child processes running doRepeater
//and cleaning up after they exit
typedef struct _repeaterProcInfo
{
long code; // session code shared with the servers[] / viewers[] entries; 0 marks a free slot
pid_t pid; // child pid returned by fork() in acceptConnection()
} repeaterProcInfo;
static repeaterProcInfo repeaterProcs[MAX_SESSIONS];
//This structure keeps information of ports/socket used when
//routeConnections() listens for new incoming connections
typedef struct _listenPortInfo {
int socket; // listening socket created by startListeningOnPort()
int port; // TCP port to listen on; main() sets it before calling startListeningOnPort()
} listenPortInfo;
//stopped != 0 (true) means that the program should stop: the user has pressed ctrl+c
//(handleSigInt() sets it) or, from version 0.08 onwards, fatal() was called to achieve a
//clean shutdown of routeConnections().
//Because it is written from a signal handler it is declared volatile sig_atomic_t, the
//only object type the C standard allows a handler to write and the interrupted code to read
//safely. Existing code keeps assigning/comparing true/false, which convert to 1/0.
static volatile sig_atomic_t stopped;
static int readExact(int sock, char *buf, int len);
static int findViewerList(long code);
static void cleanUpAfterRepeaterProcs(void);
//Global functions
//Global functions
//Global functions
//Write a timestamped debug message to stderr: "UltraVnc <time>> <message>".
//fmt and the variable arguments are passed on to vfprintf() unchanged.
//The timestamp is ctime() output without its trailing newline, so timestamp and message
//stay on one log line. If the time is unavailable the timestamp is left empty.
void debug(const char *fmt, ...)
{
    char stamp[64] = "";
    time_t now = time(NULL);
    va_list args;

    if (now != (time_t) -1) {
        //ctime() may return NULL (e.g. year out of range); never pass NULL to %s
        const char *text = ctime(&now);
        if (text != NULL) {
            char *newline;
            snprintf(stamp, sizeof(stamp), "%s", text);
            //ctime() ends its result with '\n'; drop it so "> message" follows on the same line
            newline = strchr(stamp, '\n');
            if (newline != NULL)
                *newline = '\0';
        }
    }
    fprintf(stderr, "UltraVnc %s> ", stamp);
    va_start(args, fmt);
    vfprintf(stderr, fmt, args);
    va_end(args);
}
//Write a timestamped error message to stderr: "UltraVnc ERROR <time>> <message>".
//fmt and the variable arguments are passed on to vfprintf() unchanged.
//The timestamp is ctime() output without its trailing newline, so timestamp and message
//stay on one log line. If the time is unavailable the timestamp is left empty.
void error(const char *fmt, ...)
{
    char stamp[64] = "";
    time_t now = time(NULL);
    va_list args;

    if (now != (time_t) -1) {
        //ctime() may return NULL (e.g. year out of range); never pass NULL to %s
        const char *text = ctime(&now);
        if (text != NULL) {
            char *newline;
            snprintf(stamp, sizeof(stamp), "%s", text);
            //ctime() ends its result with '\n'; drop it so "> message" follows on the same line
            newline = strchr(stamp, '\n');
            if (newline != NULL)
                *newline = '\0';
        }
    }
    fprintf(stderr, "UltraVnc ERROR %s> ", stamp);
    va_start(args, fmt);
    vfprintf(stderr, fmt, args);
    va_end(args);
}
//Local functions
//Local functions
//Local functions
//Write a timestamped fatal message to stderr ("UltraVnc FATAL <time>> <message>") and
//request program shutdown by setting stopped, exactly as if the user had pressed ctrl+c;
//the log then shows why the program stopped.
//The timestamp is ctime() output without its trailing newline (empty if unavailable).
static void fatal(const char *fmt, ...)
{
    char stamp[64] = "";
    time_t now = time(NULL);
    va_list args;

    if (now != (time_t) -1) {
        //ctime() may return NULL (e.g. year out of range); never pass NULL to %s
        const char *text = ctime(&now);
        if (text != NULL) {
            char *newline;
            snprintf(stamp, sizeof(stamp), "%s", text);
            //ctime() ends its result with '\n'; drop it so "> message" follows on the same line
            newline = strchr(stamp, '\n');
            if (newline != NULL)
                *newline = '\0';
        }
    }
    fprintf(stderr, "UltraVnc FATAL %s> ", stamp);
    va_start(args, fmt);
    vfprintf(stderr, fmt, args);
    va_end(args);
    //Close program down cleanly: routeConnections() leaves its loop when stopped is set
    stopped = true;
}
static void cleanRepeaterProcList(void)
{
int i;
for (i = 0; i < MAX_SESSIONS; i++) {
repeaterProcs[i].code = 0;
repeaterProcs[i].pid = 0;
}
}
//Remember a forked doRepeater() child: store code/pid in the first free slot
//(code == 0) of repeaterProcs[]. Only logs a warning when every slot is taken.
static void addRepeaterProcList(long code, pid_t pid)
{
    int slot = 0;
    while ((slot < MAX_SESSIONS) && (repeaterProcs[slot].code != 0))
        slot++;
    if (slot == MAX_SESSIONS) {
        debug("addRepeaterProcList(): Warning, no free process slots found\n");
        return;
    }
    debug("addRepeaterProcList(): Added proc to index %d, pid=%d, code=%ld\n", slot, pid, code);
    repeaterProcs[slot].code = code;
    repeaterProcs[slot].pid = pid;
}
//Free the first repeaterProcs[] slot whose pid matches; logs a warning if none does.
static void removeRepeaterProcList(pid_t pid)
{
    int slot = 0;
    while ((slot < MAX_SESSIONS) && (repeaterProcs[slot].pid != pid))
        slot++;
    if (slot == MAX_SESSIONS) {
        debug("removeRepeaterProcList(): Warning, did not find any process to remove\n");
        return;
    }
    debug("removeRepeaterProcList(): Removing proc from index %d, pid=%d\n", slot, pid);
    repeaterProcs[slot].code = 0;
    repeaterProcs[slot].pid = 0;
}
//Return the repeaterProcs[] index of the first slot with this pid,
//or UNKNOWN_REPINFO_IND when there is none.
static int findRepeaterProcList(pid_t pid)
{
    int found = UNKNOWN_REPINFO_IND;
    int i;
    for (i = 0; (i < MAX_SESSIONS) && (found == UNKNOWN_REPINFO_IND); i++) {
        if (repeaterProcs[i].pid == pid)
            found = i;
    }
    if (found != UNKNOWN_REPINFO_IND)
        debug("findRepeaterProcList(): proc found at %d, pid=%d, code = %ld\n", found, pid, repeaterProcs[found].code);
    return found;
}
//Reset servers[] and its oldServers[] snapshot: every entry becomes free (code 0, not active).
static void cleanServerList()
{
    int idx;
    for (idx = MAX_SESSIONS - 1; idx >= 0; idx--) {
        servers[idx].code = 0;
        servers[idx].active = false;
        oldServers[idx].code = 0;
        oldServers[idx].active = false;
    }
}
//Register a server connection in servers[] (connection level B: waiting for its viewer).
//socket: connected server socket; code: session code (Mode 2 id or invented Mode 1 code);
//peerIp: printable ip address of the server (copied, at most MAX_IP_LEN bytes).
//Nothing is stored when an entry with the same code already exists, or when the table is
//full (a warning is logged then); the caller still owns socket in both cases.
static void addServerList(int socket, long code, const char *peerIp)
{
    int i;
    for (i = 0; i < MAX_SESSIONS; i++) {
        if (servers[i].code == code) {
            debug("addServerList(): similar server already there (reconnect ?)\n");
            return;
        }
    }
    for (i = 0; i < MAX_SESSIONS; i++) {
        if (servers[i].code == 0) {
            debug("addServerList(): Server added to list %ld\n", code);
            servers[i].code = code;
            servers[i].socket = socket;
            strlcpy(servers[i].peerIp, peerIp, MAX_IP_LEN);
            servers[i].timeStamp = time(NULL); /* 1 second accuracy is enough ? */
            servers[i].active = false;
            return;
        }
    }
    //Previously a full table dropped the connection without any trace in the log
    debug("addServerList(): Warning, no free server slots found, code %ld not added\n", code);
}
//Free the servers[] entry holding code (back to connection level A).
//The socket is not closed here. Does nothing when code is not in the table.
static void removeServerList(long code)
{
    int slot = 0;
    while ((slot < MAX_SESSIONS) && (servers[slot].code != code))
        slot++;
    if (slot < MAX_SESSIONS) {
        debug("removeServerList(): Server Removed from list %ld\n", code);
        servers[slot].code = 0;
        servers[slot].active = false;
    }
}
static void setServerActive(long code)
{
int i;
for (i = 0; i < MAX_SESSIONS; i++) {
if (servers[i].code == code) {
servers[i].active = true;
debug("setServerActive(): activated server at %d, code = %ld\n", i, servers[i].code);
return;
}
}
}
//Return the servers[] index of the entry holding code, or UNKNOWN_REPINFO_IND if absent.
static int findServerList(long code)
{
    int idx = 0;
    while (idx < MAX_SESSIONS) {
        if (servers[idx].code == code) {
            debug("findServerList(): server found at %d, code = %ld\n", idx, servers[idx].code);
            return idx;
        }
        ++idx;
    }
    return UNKNOWN_REPINFO_IND;
}
//Reset viewers[] and its oldViewers[] snapshot: every entry becomes free (code 0, not active).
static void cleanViewerList()
{
    int i;
    for (i = 0; i < MAX_SESSIONS; i++) {
        viewers[i].code = 0;
        //These two tables were previously servers[] / oldServers[] (copy-paste mistake),
        //so the viewer "active" flags were never reset here
        viewers[i].active = false;
        oldViewers[i].code = 0;
        oldViewers[i].active = false;
    }
}
//Register a viewer connection in viewers[] (connection level B: waiting for its server).
//socket: accepted viewer socket; code: session code (Mode 2 id or invented Mode 1 code);
//peerIp: printable ip address of the viewer (copied, at most MAX_IP_LEN bytes).
//Nothing is stored when an entry with the same code already exists, or when the table is
//full (a warning is logged then); the caller still owns socket in both cases.
static void addViewerList(int socket, long code, const char *peerIp)
{
    int i;
    for (i = 0; i < MAX_SESSIONS; i++) {
        if (viewers[i].code == code) {
            debug("addViewerList(): Similar viewer already there (reconnect ?)\n");
            return;
        }
    }
    for (i = 0; i < MAX_SESSIONS; i++) {
        if (viewers[i].code == 0) {
            //code is a long: %ld (the old %d was undefined behavior and printed garbage on LP64)
            debug("addViewerList(): Viewer added to list %ld\n", code);
            viewers[i].code = code;
            viewers[i].socket = socket;
            strlcpy(viewers[i].peerIp, peerIp, MAX_IP_LEN);
            viewers[i].timeStamp = time(NULL);
            viewers[i].active = false;
            return;
        }
    }
    //Previously a full table dropped the connection without any trace in the log
    debug("addViewerList(): Warning, no free viewer slots found, code %ld not added\n", code);
}
//Free the viewers[] entry holding code (back to connection level A).
//The socket is not closed here. Does nothing when code is not in the table.
static void removeViewerList(long code)
{
    int i;
    for (i = 0; i < MAX_SESSIONS; i++) {
        if (viewers[i].code == code) {
            //code is a long: %ld (the old %d was undefined behavior)
            debug("removeViewerList(): Viewer removed from list %ld\n", code);
            viewers[i].code = 0;
            viewers[i].active = false;
            return;
        }
    }
}
static void setViewerActive(long code)
{
int i;
for (i = 0; i < MAX_SESSIONS; i++) {
if (viewers[i].code == code) {
viewers[i].active = true;
debug("setViewerActive(): activated viewer at %d, code = %ld\n", i, viewers[i].code);
return;
}
}
}
//Return the viewers[] index of the entry holding code, or UNKNOWN_REPINFO_IND if absent.
static int findViewerList(long code)
{
    int idx = 0;
    while (idx < MAX_SESSIONS) {
        if (viewers[idx].code == code) {
            debug("findViewerList(): viewer found at %d, code = %ld\n", idx, viewers[idx].code);
            return idx;
        }
        ++idx;
    }
    return UNKNOWN_REPINFO_IND;
}
//Return true when IdCode starts with the 3 characters "ID:" (a Mode 2 id string),
//otherwise log it and return false. Comparison stops at a '\0' in IdCode, so short
//strings are never read past their end.
static bool checkIdCode(char *IdCode)
{
    bool hasPrefix = (strncmp(IdCode, "ID:", 3) == 0);
    if (!hasPrefix)
        debug("checkIdCode(): %s is not IdCode string\n", IdCode);
    return hasPrefix;
}
//Parse IdCode string of format "ID:xxxxx", where xxxxx is a positive (non-zero) decimal
//long integer number.
//Return xxxxx on success, -1 on error: missing "ID:" prefix, a non-digit character,
//a value of 0, or a value that does not fit below LONG_MAX.
static long parseId(char *IdCode)
{
    const char *digits;
    long value;
    debug("parseId(): IdCode = %s\n", IdCode);
    //Require that 1st 3 characters of IdCode are 'I','D',':'
    if (false == checkIdCode(IdCode)) {
        debug("parseId(): IdCode format error, does not start ""ID:"" \n");
        return -1;
    }
    //Require that all other characters of IdCode are digits.
    //isdigit() needs an unsigned char value: a negative plain char would be undefined behavior.
    for (digits = IdCode + 3; *digits != '\0'; digits++) {
        if (!isdigit((unsigned char) *digits)) {
            debug("parseId(): IdCode format error, code should consist of decimal digits\n");
            return -1;
        }
    }
    //Keep strtol()'s full long result: it used to be stored in an int, which truncated
    //large codes and made the LONG_MAX overflow check below unreachable on LP64 systems
    errno = 0;
    value = strtol(IdCode + 3, NULL, 10);
    if (value <= 0) {
        debug("parseId(): IdCode format error, code should be positive long integer number\n");
        return -1;
    }
    if ((value == LONG_MAX) || (errno == ERANGE)) {
        debug("parseId(): IdCode format error, code is too big\n");
        return -1;
    }
    return value;
}
//Send exactly len bytes from buf on sock.
//Returns 1 when everything was sent, or the negative send() result on error.
//A send() result of 0 is treated as unrecoverable: a message is printed and the process exits(1).
//MSG_NOSIGNAL (where available) turns a write to a peer that has already closed into an
//EPIPE error instead of SIGPIPE, whose default action would terminate the whole repeater
//(this runs in the main process when greeting a new viewer). Interrupted sends are retried.
static int writeExact(int sock, char *buf, int len)
{
#ifdef MSG_NOSIGNAL
    const int flags = MSG_NOSIGNAL;
#else
    const int flags = 0;
#endif
    int n;
    while (len > 0) {
        n = send(sock, buf, len, flags);
        if (n > 0) {
            buf += n;
            len -= n;
        }
        else if (n == 0) {
            fprintf(stderr, "writeExact: write returned 0\n");
            exit(1);
        }
        else if (errno == EINTR) {
            //interrupted before anything was sent: just try again
            continue;
        }
        else {
            return n;
        }
    }
    return 1;
}
//Receive exactly len bytes from sock into buf.
//Returns 1 when all bytes arrived (or len <= 0); otherwise returns the recv() result that
//stopped the transfer: 0 when the peer closed the connection, negative on error.
static int readExact(int sock, char *buf, int len)
{
    char *cursor = buf;
    int remaining = len;
    for (;;) {
        int got;
        if (remaining <= 0)
            return 1;
        got = recv(sock, cursor, remaining, 0);
        if (got <= 0)
            return got;
        cursor += got;
        remaining -= got;
    }
}
//This function is periodically called from routeConnections() to remove
//servers / viewers that did not receive any matching other end connection
static void removeOldInactiveConnections(void)
{
int i;
unsigned long tick = time(NULL);
for (i = 0; i < MAX_SESSIONS; i++) {
//Remove old inactive viewers
if ((tick - viewers[i].timeStamp) > MAX_IDLE_CONNECTION_TIME) {
if ((viewers[i].active == false) && (viewers[i].code != 0)) {
close(viewers[i].socket);
debug("removeOldInactiveConnections(): Removing viewer %ld at index %d \n", viewers[i].code, i);
removeViewerList(viewers[i].code);
}
}
//Remove old inactive servers
if ((tick - servers[i].timeStamp) > MAX_IDLE_CONNECTION_TIME) {
if ((servers[i].active == false) && (servers[i].code != 0)) {
close(servers[i].socket);
debug("removeOldInactiveConnections(): Removing server %ld at index %d\n", servers[i].code, i);
removeServerList(servers[i].code);
}
}
}
}
//Parse [hostname / ip address] / [port number / display number] combination:
//  "host"       -> port RFB_PORT_OFFSET
//  "host::pppp" -> pppp is a port number
//  "host:nn"    -> nn < 100 is a display number (port RFB_PORT_OFFSET + nn), otherwise a port
//id: input string; host: output buffer of hostLen bytes; port: output port number.
//Return true if success, false if error (id does not fit into host, unparsable number,
//or a resulting port outside 1..65535).
static bool parseHostAndPort(char *id, char *host, int hostLen, int *port)
{
    int tmpPort;
    char *colonPos;
    debug("parseHostAndPort() start: id = %s\n", id);
    colonPos = strchr(id, ':');
    //host must also hold the terminating '\0', so id has to be strictly shorter than hostLen
    //(the old "hostLen < strlen(id)" test let exactly hostLen characters through, which
    //strlcpy() then silently truncated)
    if ((int) strlen(id) >= hostLen) {
        debug("parseHostAndPort(): Id string too long\n");
        return false;
    }
    if (colonPos == NULL) {
        // No colon -- use default port number
        tmpPort = RFB_PORT_OFFSET;
        strlcpy(host, id, hostLen);
    }
    else {
        // Host part is everything before the first colon
        strlcpy(host, id, (colonPos-id)+1);
        if (colonPos[1] == ':') {
            // Two colons -- interpret as a port number
            if (sscanf(colonPos + 2, "%d", &tmpPort) != 1) {
                debug("parseHostAndPort(): sscanf error 1\n");
                return false;
            }
        }
        else {
            // One colon -- interpret as a display number or port
            // number
            if (sscanf(colonPos + 1, "%d", &tmpPort) != 1) {
                return false;
            }
            // RealVNC method - If port < 100 interpret as display
            // number else as Port number
            if (tmpPort < 100)
                tmpPort += RFB_PORT_OFFSET;
        }
    }
    //Reject values that are not TCP ports: acceptConnection() casts the result to
    //unsigned short, so e.g. 70000 or a negative number would silently become another port
    if ((tmpPort <= 0) || (tmpPort > 65535)) {
        debug("parseHostAndPort(): port %d out of range\n", tmpPort);
        return false;
    }
    *port = tmpPort;
    debug("parseHostAndPort() end: host = %s, port = %d\n", host, tmpPort);
    return true;
}
//Try to connect to vnc server over IPv4, return connected socket if success, -1 if error.
//host: hostname or dotted IPv4 address; port: TCP port in host byte order.
//serverIp (MAX_IP_LEN bytes) holds the server ip address on return, or "" in case of any error.
//NOTE(review): in Mode 1 host/port come straight from the viewer's id string, so any viewer
//can make the repeater open TCP connections to arbitrary targets.
static int openConnectionToVncServer(const char *host, unsigned short port, char *serverIp)
{
    int s;
    struct sockaddr_in saddr;
    struct hostent *h;
    //Start with "" so the documented error result holds on every failure path
    strlcpy(serverIp, "", MAX_IP_LEN);
    h = gethostbyname(host);
    if (NULL == h) {
        error("open_connection(): can't resolve hostname: %s\n", host);
        return -1;
    }
    //Only an IPv4 result fits into sockaddr_in
    if ((h->h_addrtype != AF_INET) || (h->h_length != (int) sizeof(saddr.sin_addr))
        || (h->h_addr_list[0] == NULL)) {
        error("open_connection(): no IPv4 address for hostname: %s\n", host);
        return -1;
    }
    //Zero everything (sin_zero, and sin_len on BSD) before filling in the fields
    memset(&saddr, 0, sizeof(saddr));
    saddr.sin_family = AF_INET;
    saddr.sin_port = htons(port);
    //h_addr_list[0] is a char pointer into resolver storage; memcpy does not assume alignment
    memcpy(&saddr.sin_addr, h->h_addr_list[0], sizeof(saddr.sin_addr));
    strlcpy(serverIp, inet_ntoa(saddr.sin_addr), MAX_IP_LEN);
    debug("open_connection(): connecting to %s:%u\n", serverIp, port);
    s = socket(AF_INET, SOCK_STREAM, 0);
    if (s < 0) {
        debug("open_connection(): socket() failed, errno=%d (%s)\n", errno, strerror(errno));
        strlcpy(serverIp, "", MAX_IP_LEN);
        return -1;
    }
    if (connect(s, (struct sockaddr *) &saddr, sizeof(saddr)) == -1) {
        debug("open_connection(): connect() failed.\n");
        close(s);
        strlcpy(serverIp, "", MAX_IP_LEN);
        return -1;
    }
    return s;
}
//Accept connections from both servers and viewers
//(connectionFrom == CONNECTIONFROMSERVER means server is connecting,
//connectionFrom==CONNECTIONFROMVIEWER means viewer is connecting)
//Mode 2 connections can come from both sides,
//Mode 1 connections only from viewers
static void acceptConnection(int socket, int connectionFrom)
{
rfbProtocolVersionMsg pv;
int connection;
char id[MAX_HOST_NAME_LEN + 1];
long code;
struct sockaddr_in client;
socklen_t sockLen;
char peerIp[MAX_IP_LEN];
int connMode; //Connection mode: CONN_MODE1 or CONN_MODE2
//These variables are used in Mode 1
char host[MAX_HOST_NAME_LEN+1];
char connMode1ServerIp[MAX_IP_LEN];
int port;
sockLen = sizeof(struct sockaddr_in);
connection = accept(socket, (struct sockaddr *) &client, &sockLen);
if (connection < 0)
debug("acceptConnection(): accept() failed, errno=%d (%s)\n", errno, strerror(errno));
else {
strlcpy(peerIp, inet_ntoa(client.sin_addr), MAX_IP_LEN);
debug("acceptConnection(): connection accepted ok from ip: %s\n", peerIp);
if (connectionFrom == CONNECTION_FROM_VIEWER) {
//We handshake viewers by transmitting rfbProtocolVersion first
snprintf(pv, SIZE_RFBPROTOCOLVERSIONMSG+1, RFB_PROTOCOL_VERSION_FORMAT,
RFB_PROTOCOL_MAJOR_VERSION, RFB_PROTOCOL_MINOR_VERSION);
debug("acceptConnection(): pv = %s", pv);
if (writeExact(connection, pv, SIZE_RFBPROTOCOLVERSIONMSG) < 0) {
debug("acceptConnection(): Writing protocol version error\n");
close(connection);
return;
}
}
//Make sure that id is null-terminated
id[MAX_HOST_NAME_LEN] = '\0';
if (readExact(connection, id, MAX_HOST_NAME_LEN) < 0) {
debug("acceptConnection(): Reading id error\n");
close(connection);
return;
}
//id can be of format:
//Normally in Mode 2:
//"ID:xxxxx", where xxxxx is some positive (non-zero) long integer number.
//
//Normally in Mode 1:
//"xx.yy.zz.nn::pppp" (Ip address, 2 colons, port number)
//"xx.yy.zz.nn:pppp" (Ip address, 1 colons, some number): This is a problematic case.
//It is interpreted in the following way (copied directly from original repeater):
//If pppp is < 100, it is a display number. If >= 100, it is a port number
//"xx.yy.zz.nn" (Only Ip Address: Default port number RFB_PORT_OFFSET is used)
//In mode 1, instead of ip address, also DNS hostname can be used in any combination with
//port / display number
if (checkIdCode(id)) {
//id is an IdCode string, parse it
code = parseId(id);
if (-1 == code) {
debug("acceptConnection(): parseId returned error\n");
close(connection);
return;
}
debug("acceptConnection(): %s sent code %ld \n",
(connectionFrom == CONNECTION_FROM_VIEWER) ? "Viewer" : "Server", code);
connMode = CONN_MODE2;
}
else {
//id is an [hostname / ip address] / [port number / display number] combination of some sort, parse it
if (false == parseHostAndPort(id, host, MAX_HOST_NAME_LEN + 1, &port)) {
debug("acceptConnection(): parseHostAndPort returned error\n");
close(connection);
return;
}
connMode = CONN_MODE1;
}
if (connMode == CONN_MODE1) {
//Only viewers can try Mode 1 connection
if (connectionFrom == CONNECTION_FROM_SERVER) {
debug("acceptConnection(): Mode 1 connection only allowed from viewer\n");
close(connection);
return;
}
else {
//Connection from viewer: ok, continue
int server;
pid_t pid;
server = openConnectionToVncServer(host, (unsigned short) port, connMode1ServerIp);
if (server == -1) {
debug("acceptConnection(): openConnectionToVncServer() failed\n");
close(connection);
return;
}
else {
//Invent new unique connection code
//Minus-side numbers are used for Mode1 sessions
mode1ConnCode--;
if (mode1ConnCode < MIN_INVENTED_CONN_CODE)
mode1ConnCode = -1;
//Add new viewer
addViewerList(connection, mode1ConnCode, peerIp);
setViewerActive(mode1ConnCode);
//Add new server
addServerList(server, mode1ConnCode, connMode1ServerIp);
setServerActive(mode1ConnCode);
//fork repeater
pid = fork();
if (0 == pid) {
//child code
debug("acceptConnection(): Mode 1: forking doRepeater(%d, %d)\n", server, connection);
exit(doRepeater(server, connection));
}
else {
//parent code
//Add necessary information of child to repeaterProcs list so we can
//properly clean up after child has exited
addRepeaterProcList(mode1ConnCode, pid);
}
}
}
}
else if (connMode == CONN_MODE2) {
if (connectionFrom == CONNECTION_FROM_VIEWER) {
int serverInd;
//New viewer, find respective server
addViewerList(connection, code, peerIp);
serverInd = findServerList(code);
if (serverInd != UNKNOWN_REPINFO_IND) {
int server;
pid_t pid;
//found respective server, activate viewer and server
setViewerActive(code);
setServerActive(code);
server = servers[serverInd].socket;
//fork repeater
pid = fork();
if (0 == pid) {
//child code
debug("acceptConnection(): Mode 2: new viewer connected, forking doRepeater(%d, %d)\n",
server, connection);
exit(doRepeater(server, connection));
}
else {
//parent code
//Add necessary information of child to repeaterProcs list so we
//can properly clean up after child has exited
addRepeaterProcList(code, pid);
}
}
else {
debug("acceptConnection(): respective server has not connected yet\n");
}
}
else {
int viewerInd;
//New server, find respective viewer
addServerList(connection, code, peerIp);
viewerInd = findViewerList(code);
if (viewerInd != UNKNOWN_REPINFO_IND) {
int viewer;
pid_t pid;
//found respective viewer, activate server and viewer
setServerActive(code);
setViewerActive(code);
viewer = viewers[viewerInd].socket;
//fork repeater
pid = fork();
if (0 == pid) {
//child code
debug("acceptConnection(): Mode 2: new server connected, forking doRepeater(%d, %d)\n",
connection, viewer);
exit(doRepeater(connection, viewer));
}
else {
//parent code
//Add necessary information of child to repeaterProcs list so we
//can properly clean up after child has exited
addRepeaterProcList(code, pid);
}
}
else {
debug("acceptConnection(): respective viewer has not connected yet\n");
}
}
}
}
}
//Initialize listening on port: create a TCP socket bound to all IPv4 interfaces at
//pInfo->port and store it in pInfo->socket.
//Listening itself happens on function routeConnections.
//On failure fatal() is called (which requests program shutdown) and the function returns
//immediately instead of going on with an unusable socket; pInfo->socket is then -1 (socket()
//failed) or the created but not listening socket, which main() closes on exit.
static void startListeningOnPort(listenPortInfo * pInfo)
{
    int yes = 1;
    struct sockaddr_in name;
    pInfo->socket = socket(PF_INET, SOCK_STREAM, 0);
    if (pInfo->socket < 0) {
        fatal("startListeningOnPort(): socket() failed, errno=%d (%s)\n", errno, strerror(errno));
        return;
    }
    debug("startListeningOnPort(): socket() initialized\n");
    if (setsockopt(pInfo->socket, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(int)) == -1) {
        fatal("startListeningOnPort(): setsockopt() failed, errno=%d (%s)\n", errno, strerror(errno));
        return;
    }
    debug("startListeningOnPort(): setsockopt() success\n");
    //Zero the whole address first so sin_zero (and sin_len on BSD) hold no stack garbage
    memset(&name, 0, sizeof(name));
    name.sin_family = AF_INET;
    name.sin_port = htons(pInfo->port);
    name.sin_addr.s_addr = htonl(INADDR_ANY);
    if (bind(pInfo->socket, (struct sockaddr *) &name, sizeof(name)) < 0) {
        fatal("startListeningOnPort(): bind() failed, errno=%d (%s)\n", errno, strerror(errno));
        return;
    }
    debug("startListeningOnPort(): bind() succeeded to port %d\n", pInfo->port);
    if (listen(pInfo->socket, 1) < 0) {
        fatal("startListeningOnPort(): listen() failed, errno=%d (%s)\n", errno, strerror(errno));
        return;
    }
    debug("startListeningOnPort(): listen() succeeded\n");
}
//This function (called from routeConnections) periodically checks changes in
//servers[] / viewers[] tables.
//This could be used to update changes in web page, database etc.
//Current version only outputs a debug() line
static void updateServerViewerInfo(void)
{
int i;
//Check changes in servers
for (i = 0; i < MAX_SESSIONS; i++) {
if (memcmp( &(servers[i]), &(oldServers[i]), sizeof(repeaterInfo))) {
//Something has changed in index i, update that in database
debug("updateServerViewerInfo(): servers[] has changed at index %d\n", i);
//New server connection ?
if ((servers[i].code != 0) && (oldServers[i].code == 0))
debug("updateServerViewerInfo(): New server connection from : %s\n", servers[i].peerIp);
//Update to current situation
oldServers[i] = servers[i];
}
}
//Check changes in viewers
for (i = 0; i < MAX_SESSIONS; i++) {
if (memcmp(&(viewers[i]), &(oldViewers[i]), sizeof(repeaterInfo))) {
//Something has changed in index i, update that in database
debug("updateServerViewerInfo(): viewers[] has changed at index %d\n", i);
//New viewer connection ?
if ((viewers[i].code != 0) && (oldViewers[i].code == 0))
debug("updateServerViewerInfo(): New viewer connection from : %s\n", viewers[i].peerIp);
//Update to current situation
oldViewers[i] = viewers[i];
}
}
}
//Main accept loop. Waits with select() (SELECT_WAIT_SECONDS timeout) on the viewer and
//server listening sockets and hands readable ones to acceptConnection().
//After every successful select() pass it also reaps finished repeater children
//(cleanUpAfterRepeaterProcs), reports table changes (updateServerViewerInfo), and every
//60 passes calls removeOldInactiveConnections(). Returns once stopped is set.
static void routeConnections(int viewerSocket, int serverSocket)
{
    const int SELECT_WAIT_SECONDS = 1;
    const int highestFd = (serverSocket < viewerSocket) ? viewerSocket : serverSocket;
    int passes = 0;
    debug("routeConnections(): starting select() loop, terminate with ctrl+c\n");
    while (stopped == false) {
        fd_set ready;
        struct timeval timeout;
        FD_ZERO(&ready);
        FD_SET(viewerSocket, &ready);
        FD_SET(serverSocket, &ready);
        timeout.tv_sec = SELECT_WAIT_SECONDS;
        timeout.tv_usec = 0;
        if (select(highestFd + 1, &ready, NULL, NULL, &timeout) == -1) {
            //Not reported when stopped was set meanwhile (e.g. ctrl+c interrupted select())
            if (stopped == false)
                debug("routeConnections(): select() failed, errno=%d (%s)\n", errno, strerror(errno));
            continue;
        }
        if (stopped != false)
            continue;
        //New viewer trying to connect ?
        if (FD_ISSET(viewerSocket, &ready)) {
            debug("routeConnections(): new viewer connecting, accepting...\n");
            acceptConnection(viewerSocket, CONNECTION_FROM_VIEWER);
        }
        //New server trying to connect ?
        if (FD_ISSET(serverSocket, &ready)) {
            debug("routeConnections(): new server connecting, accepting...\n");
            acceptConnection(serverSocket, CONNECTION_FROM_SERVER);
        }
        //Remove old inactive connections
        passes += SELECT_WAIT_SECONDS;
        if (passes >= 60) {
            passes = 0;
            removeOldInactiveConnections();
        }
        //Clean up after children (Repeaterprocs that have exited)
        cleanUpAfterRepeaterProcs();
        //Update external info about servers/viewers
        updateServerViewerInfo();
    }
}
//After doRepeater process has exited, this function reads exit code/pid and clears
//servers[], viewers[] and repeaterProcs[] tables accordingly.
//exitCode: the child's exit status (1 error in select, 2 server disconnected, 3 viewer
//disconnected, 4 error reading from viewer, 5 error reading from server; other values are
//only logged). pid: pid of the reaped child.
static void cleanUpAfterRepeaterProcExit(int exitCode, pid_t pid) {
    long code;
    int index;
    int serverInd;
    int viewerInd;
    debug("cleanUpAfterRepeaterProcExit(): exitCode=%d, pid=%d\n", exitCode, pid);
    index = findRepeaterProcList(pid);
    if (index == UNKNOWN_REPINFO_IND) {
        debug("cleanUpAfterRepeaterProcExit(): proc not found\n");
        return;
    }
    code = repeaterProcs[index].code;
    //The child has been reaped, so always free its slot; previously a failed table lookup
    //below left the slot occupied forever
    removeRepeaterProcList(pid);
    serverInd = findServerList(code);
    viewerInd = findViewerList(code);
    debug("cleanUpAfterRepeaterProcExit(): code=%ld, serverInd=%d, viewerInd=%d\n",
        code, serverInd, viewerInd);
    //Whatever the exit code, the relay for this session is gone: close the parent's copies of
    //the sockets and return both ends to level A. Previously any exit code outside 1..5
    //(including a child killed by a signal) left both entries "active" with open sockets.
    if (viewerInd != UNKNOWN_REPINFO_IND) {
        close(viewers[viewerInd].socket);
        removeViewerList(code);
    }
    if (serverInd != UNKNOWN_REPINFO_IND) {
        close(servers[serverInd].socket);
        removeServerList(code);
    }
    if ((serverInd == UNKNOWN_REPINFO_IND) || (viewerInd == UNKNOWN_REPINFO_IND))
        debug("cleanUpAfterRepeaterProcExit(): illegal viewerInd = %d or serverInd =%d\n", viewerInd, serverInd);
}
//Poll (without blocking) every recorded repeater child; for each one that has exited,
//pass its exit status and pid to cleanUpAfterRepeaterProcExit().
static void cleanUpAfterRepeaterProcs(void)
{
    int slot;
    for (slot = 0; slot < MAX_SESSIONS; slot++) {
        int status;
        pid_t reaped;
        if (repeaterProcs[slot].code == 0)
            continue;
        reaped = waitpid(repeaterProcs[slot].pid, &status, WNOHANG);
        if (reaped <= 0)
            continue;
        cleanUpAfterRepeaterProcExit(WEXITSTATUS(status), reaped);
    }
}
//SIGINT (ctrl+c) handler: only sets stopped so routeConnections() returns and the
//program terminates cleanly. The signal number is not needed.
static void handleSigInt(int s)
{
    (void) s;
    stopped = true;
}
int main(int argc, char **argv)
{
//Ports where we listen
unsigned short viewerPort;
unsigned short serverPort;
//Viewer port listener variable
listenPortInfo viewerListener;
//Server port listener variable
listenPortInfo serverListener;
//ctrl+c signal handler
struct sigaction saInt;
stopped = false;
mode1ConnCode = 0;
fprintf(stderr, "UltraVnc Linux Repeater version %s\n", REPEATER_VERSION);
cleanServerList();
cleanViewerList();
cleanRepeaterProcList();
//Initialize port variables according to command-line parameters
viewerPort = 5900;
serverPort = 5500;
if (argc >= 2)
viewerPort = atoi(argv[1]);
if (argc >= 3)
serverPort = atoi(argv[2]);
//Initialize ctrl+c signal handler
memset(&saInt, 0, sizeof(saInt));
//Restart interrupted system calls after handler returns
saInt.sa_flags = SA_RESTART;
saInt.sa_handler = &handleSigInt;
sigaction(SIGINT, &saInt, NULL);
//Initialize and start listening on viewer port
viewerListener.port = viewerPort;
startListeningOnPort(&viewerListener);
//Initialize and start listening on server port
serverListener.port = serverPort;
startListeningOnPort(&serverListener);
//Accept & Route new connections
routeConnections(viewerListener.socket, serverListener.socket);
debug("main(): relaying done.\n");
close(viewerListener.socket);
close(serverListener.socket);
return 0;
}
// syntax highlighted by Code2HTML, v. 0.9.1