/*-
 * Copyright (c) 2004 Free (Olivier Beyssac)
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 */
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>
#include <string.h>
#include <errno.h>
#include <time.h>
#include <grp.h>
#include <pwd.h>
#include <sys/time.h>
#include <sys/select.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <sys/resource.h>
#include <unistd.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <syslog.h>
#include "options.h"
#include "parse_args.h"
#include "parse_config.h"
#include "net.h"
#include "cmd.h"
#include "iptree.h"
#include "daemon.h"
#include "netlist.h"
#include "utils.h"


/* Greeting written to every accepted client: reply code 220 followed by
   the program name and version, terminated by CRLF. */
#define BANNER "220 "PROGNAME" "PROGRAM_VERSION"\r\n"

/* Size of the client[] table, i.e. the maximum number of simultaneously
   connected clients.  Capped at 150; when FD_SETSIZE is small it is kept
   10 below FD_SETSIZE, presumably to leave room for the listening socket
   and other descriptors the daemon keeps open (TODO confirm). */
#ifdef FD_SETSIZE
#define MAX_CLIENTS (FD_SETSIZE >= 160 ? 150 : FD_SETSIZE - 10)
#else
#define MAX_CLIENTS (150)
#endif /* FD_SETSIZE */

/* Process-wide state shared by the signal handlers and the main loop. */
struct options opt;  /* runtime options/parameters */
iptree ipt;          /* the IP tree holding submissions and blacklist */
int ncli = 0;        /* number of currently connected clients */
int listen_sd;       /* descriptor of the listening socket */


/* Bookkeeping for one connected client.  A slot is free when sd == -1. */
struct client_info {
  int sd;                  /* client socket, or -1 when the slot is free */
  struct sockaddr_in addr; /* peer address (only sin_addr is filled in) */
  int reply;               /* reply code computed from the client's command */
  int mode;                /* ACL mode from netlist_getmode() */
  time_t last;             /* time of last activity, 0 for a free slot */
};

struct client_info client[MAX_CLIENTS];


/* Close the socket held by slot `i', mark the slot free and decrement
   the connected-client counter. */
static void close_client(const int i)
{
  struct client_info *c = &client[i];

  close(c->sd);
  c->sd = -1;
  c->last = 0;
  ncli--;
}


/* Write the IP tree to its dump files, then record the completion time in
   opt.last_dump (the main loop uses it to schedule the next dump). */
static void dump_data(void)
{
  time_t finished;

  iptree_dump2files(ipt);
  finished = time(NULL);
  opt.last_dump = finished;
}


/* SIGINT/SIGTERM handler: close the listening socket and every open client
   socket, dump the IP tree to disk, remove the PID file when running as a
   daemon, then exit with EXIT_SUCCESS.
   NOTE(review): syslog(), the dump and exit() are not async-signal-safe. */
static void exit_signal_handler(const int signo)
{
  int slot;

  syslog(LOG_INFO, "terminated by signal %d, closing all connections", signo);
  close(listen_sd);

  for (slot = 0; slot < MAX_CLIENTS; slot++) {
    if (client[slot].sd == -1)
      continue;
    if (opt.log_level >= 3)
      syslog(LOG_INFO, "closing slot %d", slot);
    close(client[slot].sd);
  }

  /* Persist collected data before going away */
  dump_data();

  if (opt.daemon) {
    if (unlink(opt.pid_filename) < 0)
      syslog(LOG_ERR, "unlink: %s", strerror(errno));
  }

  exit(EXIT_SUCCESS);
}


/* SIGUSR1 handler: log uptime, request counters, client count, IP tree
   statistics and getrusage() figures to syslog.  When built with
   SHOW_PROC_STATUS, /proc/<pid>/status is also copied to the log.
   errno is saved and restored so the interrupted code sees its own value.
   NOTE(review): syslog() and most calls here are not async-signal-safe;
   deferring this work to the main loop would be more robust. */
static void usr1_signal_handler(const int signo)
{
  const int saved_errno = errno;
  struct rusage u;
  time_t t = time(NULL);

  syslog(LOG_INFO, "caught signal %d", signo);
  syslog(LOG_INFO, "start time: %lu (age: %lu secs)",
         (unsigned long)opt.start_time, (unsigned long)(t - opt.start_time));
  syslog(LOG_INFO, "submissions: %lu", opt.submissions);
  syslog(LOG_INFO, "insertions in BL: %lu", opt.insertqueries);
  syslog(LOG_INFO, "notifies: %lu", opt.notifies);
  syslog(LOG_INFO, "decrements: %lu", opt.decrqueries);
  syslog(LOG_INFO, "positive BL queries: %lu", opt.positive_blqueries);
  syslog(LOG_INFO, "total BL queries: %lu", opt.blqueries);
  syslog(LOG_INFO, "bad requests: %lu", opt.bad_requests);
  syslog(LOG_INFO, "denied requests: %lu", opt.denied_requests);
  syslog(LOG_INFO, "clients: %d/%d", ncli, MAX_CLIENTS);
  iptree_show_stats(ipt, t);
  if (getrusage(RUSAGE_SELF, &u) == -1) {
    syslog(LOG_ERR, "getrusage(): %s", strerror(errno));
    goto out;
  }
  /* tv_usec holds microseconds: zero-pad to 6 digits so that 1 s + 5000 us
     is logged as 1.005000s rather than 1.5000s.  time_t/suseconds_t are
     cast because their exact types vary between platforms. */
  syslog(LOG_INFO, "utime=%ld.%06lds - stime=%ld.%06lds",
         (long)u.ru_utime.tv_sec, (long)u.ru_utime.tv_usec,
         (long)u.ru_stime.tv_sec, (long)u.ru_stime.tv_usec);
  syslog(LOG_INFO, "maxrss=%li", u.ru_maxrss);
  syslog(LOG_INFO, "ixrss=%li - idrss=%li", u.ru_ixrss, u.ru_idrss);
  syslog(LOG_INFO, "minflt=%li - majflt=%li", u.ru_minflt, u.ru_majflt);
  syslog(LOG_INFO, "nswap=%li", u.ru_nswap);
  syslog(LOG_INFO, "inblock=%li - oublock=%li", u.ru_inblock, u.ru_oublock);
  syslog(LOG_INFO, "msgsnd=%li - msgrcv=%li", u.ru_msgsnd, u.ru_msgrcv);
  syslog(LOG_INFO, "nsignals=%li", u.ru_nsignals);
  syslog(LOG_INFO, "nvcsw=%li - nivcsw=%li", u.ru_nvcsw, u.ru_nivcsw);

#ifdef SHOW_PROC_STATUS
  {
    char filename[64];
    char line[256];
    FILE *f;

    snprintf(filename, sizeof(filename), "/proc/%ld/status", (long)getpid());
    if ((f = fopen(filename, "r")) == NULL) {
      syslog(LOG_ERR, "fopen: %s", strerror(errno));
      goto out;
    }
    /* Never use file content as a format string */
    while (fgets(line, sizeof(line), f))
      syslog(LOG_INFO, "%s", line);
    fclose(f);
  }
#endif /* SHOW_PROC_STATUS */

 out:
  errno = saved_errno;
}


/* SIGUSR2 handler: log the signal number and force an immediate dump of
   the IP tree to disk. */
static void usr2_signal_handler(const int signo)
{
  const int caught = signo;

  syslog(LOG_INFO, "caught signal %d", caught);
  dump_data();
}


/* SIGCHLD handler: reap every terminated child without blocking, logging
   each one at log level 3.  The reaping loop always ends with a waitpid()
   that returns 0 or fails (typically ECHILD).  errno is therefore saved and
   restored so the interrupted code still sees the errno of its own failed
   call. */
static void chld_signal_handler(const int signo)
{
  const int saved_errno = errno;
  pid_t pid;
  int status;

  (void)signo; /* required by the handler signature, unused */

  while ((pid = waitpid(-1, &status, WNOHANG)) > 0)
    if (opt.log_level >= 3)
      syslog(LOG_INFO, "child %d terminated", (int)pid);

  errno = saved_errno;
}


/* Print the command-line help to stdout, showing the current value of
   each option (i.e. the defaults or the values loaded so far). */
static void usage(void)
{
  printf("Usage: %s [-hnv] [-f file] [-a ip] [-p port] [-l loglevel] [options]\n",
         opt.progname);
  printf("\nOptions:\n");
  printf("  -h         this help\n");
  printf("  -v         output version information and exit\n");
  printf("  -f file    use a specific configuration file\n");
  printf("             (default: %s)\n", opt.config_file);
  printf("  -n         do not fork as a daemon\n");
  printf("  -a ip      address to bind to (default: %s)\n",
         opt.listening_ip ? opt.listening_ip : "any");
  printf("  -p port    port to listen to (default: %s)\n", opt.port);
  printf("  -t number  min. time interval before blacklisting (default: %lu secs)\n", (unsigned long)opt.interval);
  printf("  -m number  maximum submissions in time interval (default: %d)\n",
         opt.max_req);
  printf("  -i number  IP list size (default: %d)\n", opt.list_size);
  printf("  -b number  blacklist size (default: %d)\n", opt.blacklist_size);
  printf("  -e number  blacklist expiration (default: %lu secs)\n",
         (unsigned long)opt.blacklist_expiration);
  printf("  -l number  log level (0-3, default: %d)\n", opt.log_level);
  printf("  -u user    user to run as\n");
  printf("  -g group   group to run as\n");
  printf("  -T number  client timeout (default: %lu secs)\n",
         (unsigned long)opt.client_timeout);
  printf("  -I file    which file to dump IP list to\n");
  printf("             (default: `%s')\n", opt.wl_filename);
  printf("  -B file    which file to dump blacklist to\n");
  printf("             (default: `%s')\n", opt.bl_filename);
  printf("  -P file    PID filename (default: `%s')\n", opt.pid_filename);
  printf("  -A file    ACL filename (default: `%s')\n", opt.acl_filename);
  printf("  -W file    whitelist filename (default: `%s')\n",
         opt.whitelist_filename);
}


/* Return the index of a free slot in client[] (sd == -1), or -1 when none
   is available.  `max' is the number of currently connected clients
   (ncli); when it has reached MAX_CLIENTS the scan is skipped.  A plain
   linear scan returns the lowest free slot. */
static int get_client_slot(const int max)
{
  int i;

  if (max >= MAX_CLIENTS)
    return -1;

  for (i = 0; i < MAX_CLIENTS; i++)
    if (client[i].sd == -1)
      return i;

  return -1;
}


/* Reload the ACL and whitelist files when their mtime is newer than the
   one recorded in opt.  A new list replaces the current one only if it
   parses successfully; otherwise the old list is kept.  If a previously
   loaded file disappears (ENOENT), the corresponding list is replaced by
   an empty one. */
static void reload_files(void)
{
  struct stat sb;
  netlist nl;
#if 0
  struct options sv_opt;

  /* Untested code. */
  /* Config */
  if (stat(opt.config_file, &sb) == 0) {
    if (S_ISREG(sb.st_mode) && sb.st_mtime > opt.config_mtime) {
      /* File modified */
      opt.config_mtime = sb.st_mtime;

      if (opt.log_level >= 2)
        syslog(LOG_INFO, "%s changed on disk, reloading", opt.config_file);
      memcpy(&sv_opt, &opt, sizeof(struct options));
      if (!parse_config()) {
        syslog(LOG_ERR, "error while parsing %s, restoring old configuration",
               opt.config_file);
        memcpy(&opt, &sv_opt, sizeof(struct options));
      }
    }
  } else if (errno == ENOENT && opt.config_mtime) {
    /* File removed */
    if (opt.log_level >= 2)
      syslog(LOG_INFO, "%s disappeared", opt.config_file);
    opt.config_mtime = 0;
  }
#endif

  /* ACL */
  if (stat(opt.acl_filename, &sb) == 0) {
    if (S_ISREG(sb.st_mode) && sb.st_mtime > opt.acl_mtime) {
      /* File modified; record the mtime even if parsing fails so a broken
         file is not re-parsed (and re-logged) on every loop iteration */
      opt.acl_mtime = sb.st_mtime;

      if (opt.log_level >= 2)
        syslog(LOG_INFO, "%s changed on disk, reloading", opt.acl_filename);
      if ((nl = netlist_init()) == NULL) {
        syslog(LOG_ERR, "netlist_init() failed for ACL, keeping old configuration");
      } else if (netlist_acl_getfromfile(nl, opt.acl_filename) == -1) {
        syslog(LOG_ERR, "error while parsing %s, keeping old configuration",
               opt.acl_filename);
        /* Release the entries parsed before the error, not just the head */
        netlist_free(nl);
      } else {
        netlist_free(opt.acl);
        opt.acl = nl;
      }
    }
  } else if (errno == ENOENT && opt.acl_mtime) {
    /* File removed */
    if (opt.log_level >= 2)
      syslog(LOG_INFO, "%s disappeared, cleaning list", opt.acl_filename);
    if ((nl = netlist_init()) == NULL) {
      /* Never leave opt.acl NULL; mtime stays set so we retry next time */
      syslog(LOG_ERR, "netlist_init() failed for ACL, keeping old configuration");
    } else {
      netlist_free(opt.acl);
      opt.acl = nl;
      opt.acl_mtime = 0;
    }
  }

  /* Whitelist */
  if (stat(opt.whitelist_filename, &sb) == 0) {
    if (S_ISREG(sb.st_mode) && sb.st_mtime > opt.whitelist_mtime) {
      /* File modified */
      opt.whitelist_mtime = sb.st_mtime;

      if (opt.log_level >= 2)
        syslog(LOG_INFO, "%s changed on disk, reloading",
               opt.whitelist_filename);
      if ((nl = netlist_init()) == NULL) {
        syslog(LOG_ERR, "netlist_init() failed for whitelist, keeping old configuration");
      } else if (netlist_whitelist_getfromfile(nl, opt.whitelist_filename) == -1) {
        syslog(LOG_ERR, "error while parsing %s, keeping old configuration",
               opt.whitelist_filename);
        /* Release the entries parsed before the error, not just the head */
        netlist_free(nl);
      } else {
        netlist_free(opt.whitelist);
        opt.whitelist = nl;
      }
    }
  } else if (errno == ENOENT && opt.whitelist_mtime) {
    /* File removed */
    if (opt.log_level >= 2)
      syslog(LOG_INFO, "%s disappeared, cleaning list",
             opt.whitelist_filename);
    if ((nl = netlist_init()) == NULL) {
      /* Never leave opt.whitelist NULL; mtime stays set so we retry */
      syslog(LOG_ERR, "netlist_init() failed for whitelist, keeping old configuration");
    } else {
      netlist_free(opt.whitelist);
      opt.whitelist = nl;
      opt.whitelist_mtime = 0;
    }
  }
}


/* Switch the effective group (plus the supplementary group list) and then
   the effective user to opt.gid / opt.uid, when set.  Each may be a name
   or a numeric id.  On any failure a diagnostic is printed and the process
   exits with EXIT_FAILURE.
   NOTE(review): only effective ids change; real and saved ids are left
   as they were, so privileges are not dropped permanently. */
static void set_eugid(void)
{
  struct group *gr;
  struct passwd *pw;
  int id;

  if (opt.gid != NULL) {
    gr = getgrnam(opt.gid);
    if (gr == NULL) {
      /* Not a group name: try it as a numeric gid */
      id = xstrtol(opt.gid);
      if (id >= 0)
        gr = getgrgid(id);
      if (gr == NULL) {
        fprintf(stderr, "%s: Invalid group: %s\n", opt.progname, opt.gid);
        exit(EXIT_FAILURE);
      }
    }
    if (setegid(gr->gr_gid) == -1) {
      perror("setegid");
      exit(EXIT_FAILURE);
    }
    if (setgroups(1, &gr->gr_gid) == -1) {
      perror("setgroups");
      exit(EXIT_FAILURE);
    }
  }

  if (opt.uid == NULL)
    return;

  pw = getpwnam(opt.uid);
  if (pw == NULL) {
    /* Not a user name: try it as a numeric uid */
    id = xstrtol(opt.uid);
    if (id >= 0)
      pw = getpwuid(id);
    if (pw == NULL) {
      fprintf(stderr, "%s: Invalid user: %s\n", opt.progname, opt.uid);
      exit(EXIT_FAILURE);
    }
  }
  if (seteuid(pw->pw_uid) == -1) {
    perror("seteuid");
    exit(EXIT_FAILURE);
  }
}


/* Format the reply line for `code' into `buffer' (capacity `maxsize' bytes,
   terminating NUL included).  The line is the numeric code, a short text
   for the codes we know (200, 421, 500, 600), then "\r\n".  Room for CRLF
   is reserved first, so a text that does not fit is truncated and the line
   still ends with "\r\n" whenever maxsize >= 3.  With maxsize 1 or 2 the
   buffer is left empty; with maxsize 0 or a NULL buffer nothing is
   written. */
static void buildreply(const int code, char *buffer, const size_t maxsize)
{
  static const char crlf[] = "\r\n";
  const char *text;
  size_t len;

  if (buffer == NULL || maxsize == 0)
    return;
  if (maxsize < sizeof(crlf)) {
    /* Not even CRLF + NUL fits */
    buffer[0] = '\0';
    return;
  }

  switch (code) {
  case 200:
    text = " OK";
    break;
  case 421:
    text = " Blacklisted";
    break;
  case 500:
    text = " Syntax error";
    break;
  case 600:
    text = " Authorization failed";
    break;
  default:
    text = "";
    break;
  }

  /* Leave sizeof(crlf) - 1 bytes free so CRLF never overwrites the text */
  if (snprintf(buffer, maxsize - (sizeof(crlf) - 1), "%d%s", code, text) < 0)
    buffer[0] = '\0';

  len = strlen(buffer);
  memcpy(buffer + len, crlf, sizeof(crlf));
}


/* Entry point: install signal handlers, read the configuration file and
   the command line, and load the ACL and whitelist.  Then switch the
   effective ids, optionally daemonize, and serve clients forever from a
   select() loop.  Each accepted client receives BANNER and sends one
   command; it then gets one reply line (see buildreply()) and is
   disconnected.  Idle clients are dropped after opt.client_timeout
   seconds.  Data is dumped once more than opt.dump_delay seconds have
   passed since the last dump. */
extern int main(int argc, char **argv)
{
  int i, client_sd, maxsd, ret;
  struct sockaddr_in cliaddr;
  socklen_t csize;             /* accept() wants a socklen_t, not size_t */
  ssize_t rd, wd;
  char buf[MAX_CMD_LEN+1];
  fd_set rset, svrset, wset, svwset;
  int slot, errors;
  struct sigaction sa;
  struct stat sb;
  struct timeval tv;

  openlog(PROGNAME, LOG_PID, LOG_DAEMON);

  xsigaction(sa, SIGINT, exit_signal_handler);
  xsigaction(sa, SIGTERM, exit_signal_handler);
  xsigaction(sa, SIGUSR1, usr1_signal_handler);
  xsigaction(sa, SIGUSR2, usr2_signal_handler);
  xsigaction(sa, SIGCHLD, chld_signal_handler);

  /* A client that hangs up before reading its reply would otherwise kill
     the whole daemon with SIGPIPE; make write() fail with EPIPE instead. */
  signal(SIGPIPE, SIG_IGN);

  options_init(argv[0]);
  parse_args_for_config_file(argc, argv);
  if (!parse_config())
    exit(EXIT_FAILURE);

  if ((ret = parse_args(argc, argv)) == -1) {
    usage();
    exit(EXIT_FAILURE);
  } else if (ret == 0) {
    usage();
    exit(EXIT_SUCCESS);
  }

  if (netlist_acl_getfromfile(opt.acl, opt.acl_filename) == -1) {
    fprintf(stderr, "%s: Error while parsing %s\n",
            opt.progname, opt.acl_filename);
    exit(EXIT_FAILURE);
  }

  if (netlist_whitelist_getfromfile(opt.whitelist,
                                    opt.whitelist_filename) == -1) {
    fprintf(stderr, "%s: Error while parsing %s\n",
            opt.progname, opt.whitelist_filename);
    exit(EXIT_FAILURE);
  }

  /* Remember mtimes so reload_files() only reloads files changed later */
  if (stat(opt.config_file, &sb) == 0 && S_ISREG(sb.st_mode))
    opt.config_mtime = sb.st_mtime;
  if (stat(opt.acl_filename, &sb) == 0 && S_ISREG(sb.st_mode))
    opt.acl_mtime = sb.st_mtime;
  if (stat(opt.whitelist_filename, &sb) == 0 && S_ISREG(sb.st_mode))
    opt.whitelist_mtime = sb.st_mtime;

  set_eugid();

  if (opt.log_level >= 3)
    options_log();

  opt.syslog = 1;
  if (opt.daemon)
    daemon_init();

  umask(S_IRGRP | S_IWGRP | S_IROTH | S_IWOTH);

  if ((listen_sd = net_server_sock_create(opt.listening_ip, opt.port)) == -1)
    exit(EXIT_FAILURE);

  net_socket_set_nb(listen_sd);

  ipt = iptree_init();

  /* Initialize all client slots */
  for (i = 0; i < MAX_CLIENTS; i++) {
    client[i].sd = -1;
    client[i].reply = -1;
    client[i].last = 0;
  }

  FD_ZERO(&svrset);
  FD_SET(listen_sd, &svrset);
  FD_ZERO(&svwset);

  syslog(LOG_INFO, "listening on port %s", opt.port);

  errors = 0;

  for (;;) {
    int nsel;
    time_t t = time(NULL);

    if (t - opt.last_dump > opt.dump_delay)
      dump_data();

    /* Close idle connections */
    if (ncli)
      for (i = 0; i < MAX_CLIENTS; i++)
        if (client[i].sd != -1 && client[i].last
            && t - client[i].last >= opt.client_timeout) {
          if (opt.log_level >= 2)
            syslog(LOG_INFO, "connection with %s closed (timeout)",
                   inet_ntoa(client[i].addr.sin_addr));
          FD_CLR(client[i].sd, &svrset);
          FD_CLR(client[i].sd, &svwset);
          close_client(i);
        }

    /* Reload ACL and whitelist files if needed */
    reload_files();

    rset = svrset;
    wset = svwset;

    /* select() needs the highest descriptor present in the sets.  That is
       unrelated to the number of slots: descriptors of clients can exceed
       MAX_CLIENTS once stdio, syslog and the listening socket are open. */
    maxsd = listen_sd;
    for (i = 0; i < MAX_CLIENTS; i++)
      if (client[i].sd > maxsd)
        maxsd = client[i].sd;

    /* Wake up at least every 60 seconds to perform some cleanup */
    tv.tv_sec = 60;
    tv.tv_usec = 0;
    nsel = select(maxsd + 1, &rset, &wset, NULL, &tv);
    if (nsel == -1) {
      /* Interrupted by a signal: the sets carry no readiness information */
      if (errno == EINTR)
        continue;
      syslog(LOG_ERR, "error in select(): %s", strerror(errno));
      if (++errors >= 3) {
        syslog(LOG_ERR, "too many errors, aborting");
        close(listen_sd);
        exit(EXIT_FAILURE);
      }
      continue;
    }

    if (FD_ISSET(listen_sd, &rset)) {
      /* First look for incoming connections */
      nsel--;
      csize = sizeof(cliaddr); /* value-result: reset before each call */
      client_sd = accept(listen_sd, (struct sockaddr *)&cliaddr, &csize);
      if (client_sd < 0) {
        if (errno != EINTR && errno != EWOULDBLOCK)
          syslog(LOG_ERR, "accept: %s (errno=%d)", strerror(errno), errno);
      } else {
        slot = get_client_slot(ncli);
#ifdef FD_SETSIZE
        /* A descriptor select() cannot watch is as good as no free slot */
        if (client_sd >= FD_SETSIZE)
          slot = -1;
#endif
        if (slot == -1)
          close(client_sd);
        else {
          int mode = netlist_getmode(opt.acl, ntohl(cliaddr.sin_addr.s_addr));

          if (!mode) {
            close(client_sd);
            syslog(LOG_INFO, "closed unauthorized connection from %s",
                   inet_ntoa(cliaddr.sin_addr));
          } else {
            ncli++;
            net_socket_set_nb(client_sd);
            client[slot].sd = client_sd;
            client[slot].addr.sin_addr = cliaddr.sin_addr;
            client[slot].mode = mode;
            client[slot].reply = -1;
            client[slot].last = time(NULL);
            if (opt.log_level == 2)
              syslog(LOG_INFO, "connect from %s (%d/%d) [%d]",
                     inet_ntoa(cliaddr.sin_addr), ncli, MAX_CLIENTS, mode);

            write(client_sd, BANNER, strlen(BANNER));
            FD_SET(client_sd, &svrset);
          } /* if (!mode) [...] else */
        } /* if (slot == -1) [...] else */
      } /* if (client_sd < 0) [...] else */
    } /* if (FD_ISSET(listen_sd, &rset) */

    /* Then have a look at open connections (iterate over slots, not fds) */
    for (i = 0; nsel > 0 && i < MAX_CLIENTS; i++) {
      if (client[i].sd != -1) {
        if (FD_ISSET(client[i].sd, &rset)) {
          /* Something to read */
          nsel--;
          if ((rd = read(client[i].sd, buf, MAX_CMD_LEN)) <= 0) {
            if (rd < 0 && (errno == EINTR || errno == EWOULDBLOCK))
              continue;

            if (rd < 0)
              syslog(LOG_ERR, "read: %s", strerror(errno));

            FD_CLR(client[i].sd, &svrset);
            close_client(i);
          } else {
            /* buf has MAX_CMD_LEN + 1 bytes: keep it NUL-terminated */
            buf[rd] = '\0';
            /* Use this client's own address, not the last accepted peer */
            client[i].reply = read_cmd(buf, rd,
                                       inet_ntoa(client[i].addr.sin_addr),
                                       ipt, client[i].mode);
            client[i].last = time(NULL);
            FD_CLR(client[i].sd, &svrset);
            FD_SET(client[i].sd, &svwset);
          }
        } else if (FD_ISSET(client[i].sd, &wset)) {
          /* Something to write */
          nsel--;
          buildreply(client[i].reply, buf, MAX_CMD_LEN + 1);
          wd = write(client[i].sd, buf, strlen(buf));
          if (wd < 0 && (errno == EINTR || errno == EWOULDBLOCK))
            continue;

          if (wd < 0)
            syslog(LOG_ERR, "write: %s", strerror(errno));

          FD_CLR(client[i].sd, &svwset);
          close_client(i);

          /* addr is not cleared by close_client(), so it is still valid */
          if (opt.log_level == 2)
            syslog(LOG_INFO, "connection with %s closed (%d/%d)",
                   inet_ntoa(client[i].addr.sin_addr), ncli, MAX_CLIENTS);

        } /* if (FD_ISSET(client[i].sd, &rset)) [...] else [...] */
      } /* if (client[i].sd != -1) */
    } /* for (i = 0; i < MAX_CLIENTS; i++) */
  } /* for (;;) */

  /* Never reached, but exit with a polite code */
  exit(EXIT_SUCCESS);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */