/*****************************************************************************\
* Copyright (c) 2002 Pelle Johansson.                                         *
* All rights reserved.                                                        *
*                                                                             *
* This file is part of the moftpd package. Use and distribution of            *
* this software is governed by the terms in the file LICENCE, which           *
* should have come with this package.                                         *
\*****************************************************************************/

/* $moftpd: main.c 1264 2005-04-06 13:32:27Z morth $ */

#include "system.h"

#include "main.h"

#include "connection.h"
#include "server.h"
#include "utf8fs/file.h"
#include "confparse.h"
#include "utf8fs/memory.h"
#include "accounter.h"
#include "events.h"

/* Non-zero: fork into the background in daemon mode (first -f clears it). */
static int doFork = 1;

/*
 * NOTE(review): running, urgData, reloadConfig and closeServers are written
 * from signal handlers; strictly they should be volatile sig_atomic_t.
 * Their types are kept because other files declare them extern.
 */

/* Debug level, from -D or MOFTPD_DEBUG. */
int debug = 0;
/* Socket connected to the accounter process, or -1 when not connected. */
int accClientSock = -1;
/* Cleared by signal handlers to make the main loop exit. */
int running = 1;
/* Non-zero when the effective uid was 0 at startup. */
int runningAsRoot;
/* Configuration file path (-c, -C or MOFTPD_CONFIG). */
const char *configFile = ETCDIR "/moftpd.conf";
/* Zero after a second -f (no per-connection fork). */
int forkConnections = 1;
/* Non-zero: do not treat fd 0 as a client connection (-d). */
int skipFd0 = 0;
/* Set by sigint(); makes the main loop close the listening sockets. */
int closeServers = 0;
int selfFailedFork = 1;
/* 10 MiB; presumably an upper bound for mmap()ed regions -- confirm. */
size_t maxMmapSize = 10 * 1024 * 1024;
/* Set by sigurg() and after a RELOAD request from the accounter. */
int urgData = 0;
/* Non-zero when the configuration should be re-read. */
int reloadConfig = 0;
/* Effective uid/gid at startup; drop_privs() switches back to these. */
uid_t unprivUid;
gid_t unprivGid;
/* Presumably where locale message files (with localeSuffix) are looked up. */
const char *localeDir = SHAREDIR "/" PACKAGE_NAME "/locale/";
const char *localeSuffix = ".loc";
/* Pid file written by the daemon and removed on exit. */
const char *pidFile = VARDIR "/run/" PACKAGE_NAME ".pid";
/* Directory searched by -C <template>. */
const char *templateDir = SHAREDIR "/" PACKAGE_NAME "/";
/* Certificate directory, set in main(). */
char *sslCertsPath;
/* Pid of the daemon process, or -1 when serving one inherited connection. */
int forker;
/* Requested fake-chroot mode; forced on by main() when required. */
int doFakeChroot;

extern char *optarg;
extern int optind;
extern int fakeChroot;
extern int accServer;

/*
 * Handler registered for SIGKILL: asks the main loop to stop.
 * NOTE(review): SIGKILL can never be caught, so this handler never runs.
 */
void sigkill(int unused)
{
  (void) unused;
  running = 0;
}

/*
 * SIGQUIT handler: clear `running` so the main loop exits after the
 * current iteration.
 */
void sigquit(int unused)
{
  (void) unused;
  running = 0;
}

/*
 * SIGURG handler: only records that the signal arrived by setting urgData.
 * Presumably signals urgent (out-of-band) socket data -- confirm with the
 * code that reads urgData.
 */
void sigurg(int unused)
{
  (void) unused;
  urgData = 1;
}

/*
 * SIGCHLD handler: reap every child that has already exited.
 *
 * Pending signals of one kind are not queued, so a single SIGCHLD may stand
 * for several exited children; loop with WNOHANG until none is left.
 * errno is saved and restored because waitpid() sets it (e.g. to ECHILD),
 * and the handler may interrupt code that is about to inspect errno.
 */
void sigchld(int unused)
{
  int savedErrno = errno;

  (void) unused;
  while (waitpid (-1, NULL, WNOHANG) > 0)
    ;
  errno = savedErrno;
}

/*
 * SIGINT handler: begin shutting the server down.
 *
 * The accounter link is closed with close_accounter(0). Every client
 * connection is sent "421 Server shutting down.". closeServers is then set,
 * so the main loop calls quit_all_servers().
 *
 * NOTE(review): close_accounter() and quit_all_connections() run inside a
 * signal handler; unless they use only async-signal-safe operations this is
 * unsafe -- consider just setting a flag here and doing the work in main().
 */
void sigint(int unused)
{
  (void) unused;

  close_accounter (0);
  quit_all_connections ("421 Server shutting down.");

  /* Listening sockets are closed later, from the main loop. */
  closeServers = 1;
}

/*
 * SIGHUP handler: shut down exactly like SIGINT, then request a
 * configuration reload.  main() re-reads the configuration once the event
 * loop has ended.
 */
void sighup(int unused)
{
  /* Same shutdown as SIGINT... */
  sigint (unused);

  /* ...then ask main() to jump back to loadConfig. */
  reloadConfig = 1;
}

int main(int argc, char *argv[])
{
  connection_t *conn;
  char startCwd[MAXPATHLEN];
  int fd, wrotePidFile = 0;
#if defined (RLIMIT_CORE) || defined (RLIMIT_FSIZE)
  struct rlimit rl;
#endif
  
  signal(SIGKILL, sigkill);
  signal(SIGQUIT, sigquit);
  signal(SIGURG, sigurg);
  signal(SIGCHLD, sigchld);
  signal(SIGINT, sigint);
  signal(SIGHUP, sighup);
  signal (SIGPIPE, SIG_IGN);
#ifdef SIGXFSZ
  signal (SIGXFSZ, SIG_IGN);
#endif
  
  setprogname (argv[0]);
  unprivUid = geteuid();
  unprivGid = getegid();
  
  runningAsRoot = !unprivUid;
  
  mem_init ();
  events_init ();
  
#ifdef USE_TLS
  sslCertsPath = pstring (tls_get_cert_dir (), NULL);
#else
  sslCertsPath = pstring ("/etc/certs", NULL);
#endif
  
  argc = parse_options(argc, &argv);
  if(argc < 0)
    return 1;
  if(argc && !getenv("MOFTP_CHILD"))
    usage();
  
  /*
   * Use NDELAY since we might chroot.
   */
  openlog (getprogname (), LOG_PID | LOG_NDELAY | (debug? LOG_PERROR : 0),
	LOG_FTP);
  if (!debug)
  {
#ifdef RLIMIT_CORE
    if (getrlimit (RLIMIT_CORE, &rl))
      rl.rlim_max = 0;
    rl.rlim_cur = 0;
    setrlimit (RLIMIT_CORE, &rl);
#endif
    setlogmask (LOG_UPTO (LOG_INFO));
  }
  
#ifdef RLIMIT_FSIZE
  if (getrlimit (RLIMIT_FSIZE, &rl))
    rl.rlim_cur = RLIM_INFINITY;
  rl.rlim_max = RLIM_INFINITY;
  setrlimit (RLIMIT_FSIZE, &rl);
#endif
  
  getcwd (startCwd, sizeof (startCwd));
  
 loadConfig:
  if (reloadConfig)
    chdir (startCwd);
  
  set_access (NULL, 1, -1, -1); // Stop symlink resolving
  if(read_config(configFile))
  {
    if (!reloadConfig)
    {
      if (errno == ENOENT)
      {
	if (strncmp (configFile, templateDir, strlen (templateDir)))
	  fprintf (stderr, "To run a default server, use -C <template>. See moftpd(8).\n");
	else
	  fprintf (stderr, "No such template. See moftpd(8).\n");
      }
      else if (!debug)
	fprintf (stderr, "Config file error: %s. See syslog or start with -D for more information.\n",
	      strerror (errno));
    }
    return 1;
  }
  
  if (!doFakeChroot && (!runningAsRoot || !forkConnections))
  {
    syslog (LOG_WARNING, "Warning: FakeChroot turned on because %s.",
	  runningAsRoot? "ForkConnections is false" : "not running as root");
    doFakeChroot = 1;
  }
  
  if(skipFd0)
  {
    conn = NULL;
    errno = 0;
  }
  else
  {
    /*
     * Check if we have a connection on stdin.
     * Not sure this is foolproof with regards to pipes, but I don't see why
     * you'd start moftpd with a pipe on stdin. find_server() only works on
     * PF_INET and PF_INET6, so you'll get no service available on pipe
     * sockets.
     */
    conn = new_connection(0, NULL, -1);
    skipFd0 = 1;
  }
  if(conn)
  {
    /*
     * Seems we do. Means we've been charged with handling this connection
     * only. new_connection() should've added it to readSet.
     */
    forkConnections = 1; // This is in effect a forked connection.
    quit_all_servers ();
    forker = -1;
    fakeChroot = doFakeChroot;
  }
  else if(errno && errno != ENOTSOCK)
  {
    /* Error! */
    return 1;
  }
  else if (!reloadConfig)
  {
    /* We're in daemon mode. Setup as appropriate. */
    
    if(doFork)
    {
      switch(fork())
      {
      case 0:
	/* Daemon. Keep running. */
	break;
      case -1:
	/* Error */
	syslog(LOG_ERR, "fork: %m");
	return 1;
      default:
	/* Parent. We're done. */
	return 0;
      }
      close (0);
      close (1);
      events_init ();
    }
    
    create_server_sockets ();
    if (!event_channels ())
    {
      fprintf (stderr, "Failed to create any server sockets.\n");
      if (!debug)
	fprintf (stderr, "See syslog or start with -D for more information.\n");
      exit (1);
    }
    if (doFork && !debug)
    {
      close (2);
      setsid ();
    }
    
#ifdef HAVE_SETENV
    setenv ("MOFTP_CHILD", "1", 0); /* For children we launch. */
#else
    putenv (tstring ("MOFTP_CHILD=1"));
#endif
    
    fd = open (pidFile, O_WRONLY | O_CREAT | O_TRUNC, 0666);
    if (fd >= 0)
    {
      char pidBuf[10];
      
      sprintf (pidBuf, "%d\n", (int)getpid ());
      write (fd, pidBuf, strlen (pidBuf));
      close (fd);
      wrotePidFile = getpid ();
    }
    forker = getpid ();
    fakeChroot = 1;
  }
  else
  {
    reloadConfig = 0;
    create_server_sockets ();
  }
  
  /* Main loop. */
  while (running && event_channels ())
  {
    // If there's only the accServer listening, quit.
    if (accClientSock != -1)
    {
      if (event_channels () <= 1)
	break;
      if (event_channels () == 2 && accServer)
	break;
    }
    else
    {
      if (event_channels () == 1 && accServer)
	break;
      
      accClientSock = connect_accounter ();
      if (accClientSock >= 0)
      {
	add_read_fd (accClientSock, accounter_master_reply, NULL);
	accounter (accClientSock, "SET PID %d\n", (int)getpid ());
      }
    }
    
    while (waitpid (-1, NULL, WNOHANG) > 0)
      ;
    drop_privs();
    
    tfree_all();

#ifdef HAVE_SETPROCTITLE
    {
      const int numConns = count_connections ();
      
      if (numConns > 1)
	setproctitle ("%s%d connections", accServer? "accounter, " : "",
	      numConns);
      else if (numConns) // TODO: more info here?
	setproctitle ("%sconnection", accServer? "accounter, " : "");
      else if (accServer)
	setproctitle ("accounter");
      else
	setproctitle (NULL);
    }
#endif
    
    if (run_events ())
    {
      if (wrotePidFile == getpid ())
	unlink (pidFile);
      return 1;
    }
    
    check_idle();
    
    if (reloadConfig)
    {
      if (!super_privs (1))
	break;
      reloadConfig = 0;
    }
    
    if (closeServers)
      quit_all_servers ();
  }
  
  quit_all_servers ();
  
  if(running && reloadConfig && forker == getpid ())
    goto loadConfig;
  
  close_accounter (2);
  
  if (wrotePidFile == getpid ())
    unlink (pidFile);
  
  return 0;
}

/*
 * Read handler for the daemon's link to the accounter.
 * main() registers it with add_read_fd().
 *
 * Reads what is available and splits it into newline-separated commands.
 * Only "RELOAD" is understood. It flags a configuration reload, closes all
 * listeners and connections, and sets urgData (once per pending reload).
 *
 * EINTR/EAGAIN are treated as transient and leave the link open; the
 * handler is presumably called again when data is ready.  On EOF or any
 * other read error, the socket is unregistered and closed, and
 * accClientSock is reset to -1 so main() reconnects.
 *
 * NOTE(review): a command split across two reads is not reassembled.
 *
 * `user` and `urgent` are unused.  Always returns 0.
 */
int accounter_master_reply (int sock, void *user, int urgent)
{
  char buf[4097], *bp, *nbp;
  ssize_t l;
  
  l = read (sock, buf, sizeof (buf) - 1);
  if (l < 0 && (errno == EINTR || errno == EAGAIN))
    return 0; /* Transient: keep the accounter link. */
  if (l <= 0)
  {
    /* EOF or hard error: drop this socket (the one we were called for). */
    remove_read_fd (sock);
    close (sock);
    accClientSock = -1;
    return 0;
  }
  buf[l] = 0;
  
  for (bp = buf; bp; bp = nbp)
  {
    nbp = strchr (bp, '\n');
    if (nbp)
      *nbp++ = 0;
    
    if (!bp[0])
      continue;
    
    if (!strcmp (bp, "RELOAD"))
    {
      if (!reloadConfig)
      {
	reloadConfig = 1;
	quit_all_servers ();
	quit_all_connections ("421 Server shutting down.");
	urgData = 1;
      }
    }
  }
  return 0;
}

/*
 * Parse environment and command-line options.
 *
 * Defaults come from MOFTPD_CONFIG and MOFTPD_DEBUG in the environment. The
 * command line then overrides them:
 *   -c <file>      configuration file
 *   -C <template>  use templateDir/<template>.conf as the configuration file
 *   -D             raise the debug level (repeatable)
 *   -d             fork, and do not treat fd 0 as a client connection
 *   -f             first: stay in the foreground; second: also stop forking
 *                  per connection
 * Any other option prints usage and exits.  Debug levels outside 0..999
 * are set to 999.
 *
 * The final config file and debug level are written back to the
 * environment, so processes started later inherit them.
 *
 * On return *argv points past the options.  Returns the number of remaining
 * arguments, or -1 on allocation failure.
 */
int parse_options(int argc, char ***argv)
{
  char *envValue;
  int ch;
  
  /* Environment defaults; read the same variable that was tested. */
  envValue = getenv ("MOFTPD_CONFIG");
  if (envValue)
  {
    configFile = pstring (envValue, NULL);
    if(!configFile)
    {
      perror ("pstring");
      return -1;
    }
  }
  envValue = getenv ("MOFTPD_DEBUG");
  if (envValue)
    debug = atoi (envValue);
  
  while((ch = getopt(argc, *argv, "c:C:dfD")) != -1)
  {
    switch(ch)
    {
    case 'C':
      /* templateDir + optarg + ".conf" (5 chars) + NUL. */
      configFile = palloc (strlen (templateDir) + strlen (optarg) + 6, NULL, NULL);
      if (!configFile)
      {
	perror ("palloc");
	return -1;
      }
      strcpy ((char*)configFile, templateDir);
      strcat ((char*)configFile, optarg);
      strcat ((char*)configFile, ".conf");
      break;
    case 'c':
      configFile = pstring(optarg, NULL);
      if(!configFile)
      {
	perror ("pstring");
	return -1;
      }
      break;
    case 'D':
      debug++;
      break;
    case 'd':
      doFork = 1;
      skipFd0 = 1;
      break;
    case 'f':
      /* First -f clears doFork; a second one clears forkConnections. */
      if(!doFork)
	forkConnections = 0;
      else
	doFork = 0;
      break;
    default:
      usage();
    }
  }
  
  if(debug < 0 || debug > 999)
    debug = 999;
#ifdef HAVE_SETENV
  {
    char dbuf[16];
    
    setenv("MOFTPD_CONFIG", configFile, 1);
    snprintf(dbuf, sizeof (dbuf), "%d", debug);
    setenv("MOFTPD_DEBUG", dbuf, 1);
  }
#else
  {
    /*
     * putenv() stores the pointer it is given, not a copy.  Each variable
     * therefore needs its own buffer that outlives this function, so use
     * persistent palloc() memory rather than temporary talloc() memory.
     */
    size_t cfgLen = strlen ("MOFTPD_CONFIG=") + strlen (configFile) + 1;
    const size_t dbgLen = 32;
    char *cfgEnv = palloc (cfgLen, NULL, NULL);
    char *dbgEnv = palloc (dbgLen, NULL, NULL);
    
    if (cfgEnv)
    {
      snprintf (cfgEnv, cfgLen, "MOFTPD_CONFIG=%s", configFile);
      putenv (cfgEnv);
    }
    if (dbgEnv)
    {
      snprintf (dbgEnv, dbgLen, "MOFTPD_DEBUG=%d", debug);
      putenv (dbgEnv);
    }
  }
#endif
  
  *argv += optind;
  return argc - optind;
}

/*
 * Print the command synopsis to stderr and exit with status 1.
 * Never returns.
 */
void usage(void)
{
  const char *prog = getprogname ();
  
  fprintf (stderr, "Usage: %s [-dDf] [-c <configfile>] [-C <template>]\n", prog);
  exit (1);
}

/*
 * Regain root privileges.
 *
 * If resetRoot is non-zero, set_root("/") is called first and a zero result
 * makes this return -1.  (NOTE(review): confirm that zero really means
 * failure for set_root().)  When the process started as root, the euid is
 * raised to 0, then setuid(0) and setgid(0) are called.  Only a setuid()
 * failure is reported.
 *
 * Returns 0 on success, -1 on failure.
 */
int super_privs(int resetRoot)
{
  if (resetRoot && !set_root ("/"))
    return -1;
  
  if (!runningAsRoot)
    return 0;
  
  seteuid (0);
  if (setuid (0) != 0)
    return -1;
  setgid (0);
  
  return 0;
}

/*
 * Return to the unprivileged effective identity recorded at startup
 * (unprivUid/unprivGid), then call set_access(NULL, 0, -1, -1).
 *
 * When running as root, the euid is raised to 0 first so that the egid can
 * be changed. The euid is lowered last. The real uid is not changed here.
 * Errors from seteuid()/setegid() are ignored.  Always returns 0.
 */
int drop_privs(void)
{
  const int haveRoot = runningAsRoot;
  
  if (haveRoot)
  {
    seteuid (0);           /* needed to change the egid */
    setegid (unprivGid);
    seteuid (unprivUid);   /* euid last, or setegid() would be refused */
  }
  
  set_access (NULL, 0, -1, -1);
  return 0;
}

/*
 * Switch user id when the process started as root. Otherwise do nothing
 * and return 0.
 *
 * The euid is raised to 0 first. Then, if permanent is non-zero, setuid(uid)
 * is called; otherwise only seteuid(uid). Returns that call's result.
 */
int set_uid(uid_t uid, int permanent)
{
  if (!runningAsRoot)
    return 0;
  
  seteuid (0);
  return permanent ? setuid (uid) : seteuid (uid);
}

/*
 * Switch group id when the process started as root. Otherwise do nothing
 * and return 0.
 *
 * The current euid is saved and raised to 0. Then, if permanent is
 * non-zero, setgid(gid) is called; otherwise setegid(gid). The saved euid
 * is restored afterwards. Returns the setgid()/setegid() result.
 */
int set_gid(gid_t gid, int permanent)
{
  uid_t savedEuid;
  int rc;
  
  if (!runningAsRoot)
    return 0;
  
  savedEuid = geteuid ();
  seteuid (0);
  rc = permanent ? setgid (gid) : setegid (gid);
  seteuid (savedEuid);
  return rc;
}

/*
 * Replace the supplementary group list when the process started as root.
 * Otherwise do nothing and return 0.
 *
 * The euid is temporarily raised to 0 around setgroups() and then
 * restored. Returns setgroups()'s result.
 */
int set_groups(int ngroups, const gid_t *groups)
{
  uid_t savedEuid;
  int rc;
  
  if (!runningAsRoot)
    return 0;
  
  savedEuid = geteuid ();
  seteuid (0);
  rc = setgroups (ngroups, groups);
  seteuid (savedEuid);
  return rc;
}

/*
 * Set the process locale (LC_ALL) from at most the first two characters of
 * loc, e.g. "de" from "de_DE".  If loc is NULL, or that name is not
 * accepted by setlocale(), fall back to "en".  The result of the fallback
 * call is not checked.
 */
void set_locale (const char *loc)
{
  char lang[3];
  
  if (loc == NULL)
  {
    setlocale (LC_ALL, "en");
    return;
  }
  
  /* Copy up to two characters and always NUL-terminate. */
  snprintf (lang, sizeof (lang), "%s", loc);
  if (setlocale (LC_ALL, lang) == NULL)
    setlocale (LC_ALL, "en");
}

/*
 * Compare two socket addresses.
 *
 * Returns 1 when both have the same family (AF_INET or AF_INET6) and the
 * same IP address, and, if checkPort is non-zero, also the same port.
 * Returns 0 otherwise, including for any other address family.
 */
int same_addr(const struct sockaddr *a1, const struct sockaddr *a2,
      int checkPort)
{
  if (a1->sa_family != a2->sa_family)
    return 0;
  
  if (a1->sa_family == AF_INET)
  {
    const struct sockaddr_in *lhs = (const struct sockaddr_in *) a1;
    const struct sockaddr_in *rhs = (const struct sockaddr_in *) a2;
    
    if (checkPort && lhs->sin_port != rhs->sin_port)
      return 0;
    return memcmp (&lhs->sin_addr, &rhs->sin_addr, 4) == 0;
  }
  
  if (a1->sa_family == AF_INET6)
  {
    const struct sockaddr_in6 *lhs = (const struct sockaddr_in6 *) a1;
    const struct sockaddr_in6 *rhs = (const struct sockaddr_in6 *) a2;
    
    if (checkPort && lhs->sin6_port != rhs->sin6_port)
      return 0;
    return memcmp (&lhs->sin6_addr, &rhs->sin6_addr, 16) == 0;
  }
  
  return 0;
}

/*
 * Test whether addr lies inside the network described by net and mask.
 *
 * All three must have the same family, AF_INET or AF_INET6.  The mask is
 * an address of that family, e.g. 255.255.255.0.  Returns 1 if
 * (addr & mask) == (net & mask) over the whole address, and 0 otherwise
 * (including mismatched or unsupported families).
 *
 * The addresses are compared byte by byte through unsigned char pointers.
 * Reading in_addr/in6_addr through int32_t pointers breaks strict-aliasing
 * rules and may be misaligned where in6_addr is a plain byte array.  The
 * result is the same, since a bitwise AND/compare is independent of
 * grouping.
 */
int check_range (const struct sockaddr *addr, const struct sockaddr *net,
      const struct sockaddr *mask)
{
  const unsigned char *a, *n, *m;
  size_t sz, i;
  
  if (addr->sa_family != net->sa_family || net->sa_family != mask->sa_family)
    return 0;
  
  switch (addr->sa_family)
  {
  case AF_INET:
    a = (const unsigned char *) &((const struct sockaddr_in *) addr)->sin_addr;
    n = (const unsigned char *) &((const struct sockaddr_in *) net)->sin_addr;
    m = (const unsigned char *) &((const struct sockaddr_in *) mask)->sin_addr;
    sz = sizeof (struct in_addr);
    break;
  case AF_INET6:
    a = (const unsigned char *) &((const struct sockaddr_in6 *) addr)->sin6_addr;
    n = (const unsigned char *) &((const struct sockaddr_in6 *) net)->sin6_addr;
    m = (const unsigned char *) &((const struct sockaddr_in6 *) mask)->sin6_addr;
    sz = sizeof (struct in6_addr);
    break;
  default:
    return 0;
  }
  for (i = 0; i < sz; i++)
  {
    if ((a[i] & m[i]) != (n[i] & m[i]))
      return 0;
  }
  return 1;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */