/*-GNU-GPL-BEGIN-*
nepim - network pipemeter
Copyright (C) 2005 Everton da Silva Marques

nepim is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.

nepim is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with nepim; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
Boston, MA 02111-1307, USA.
*-GNU-GPL-END-*/


/* $Id: sock.c,v 1.25 2005/09/20 19:59:50 evertonm Exp $ */


#include <assert.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>

#include "sock.h"

/*
 * Fallbacks for socket-option constants missing from some platforms'
 * headers.  The option levels use the IPPROTO_* protocol numbers, which
 * are the portable setsockopt() levels (Linux's SOL_IP / SOL_IPV6 /
 * SOL_TCP aliases carry the same values).  The remaining numeric
 * fallbacks are Linux values.
 */
#ifndef SOL_IP
#define SOL_IP IPPROTO_IP
#endif
#ifndef SOL_IPV6
#define SOL_IPV6 IPPROTO_IPV6
#endif
#ifndef SOL_TCP
#define SOL_TCP IPPROTO_TCP
#endif
#ifndef IP_MTU
/* NOTE(review): Linux value; on other systems 14 may name another option. */
#define IP_MTU 14
#endif
#ifndef IP_ADD_MEMBERSHIP
#define IP_ADD_MEMBERSHIP 35    /* Linux value */
#endif
#ifndef IPV6_ADD_MEMBERSHIP
/* RFC 3493 spells this option IPV6_JOIN_GROUP (e.g. the BSDs). */
# ifdef IPV6_JOIN_GROUP
#  define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP
# else
#  define IPV6_ADD_MEMBERSHIP 20 /* Linux value */
# endif
#endif

#ifndef HAVE_IP_MREQN
/*
 * Fallback definition of the IPv4 multicast membership request, used
 * only when HAVE_IP_MREQN is not defined.  socket_mcast_join() fills it
 * in for IP_ADD_MEMBERSHIP.
 * NOTE(review): the layout must match what the kernel expects; do not
 * reorder or retype the fields.
 */
struct ip_mreqn {
  struct in_addr imr_multiaddr;  /* multicast group address */
  struct in_addr imr_address;    /* local interface address */
  int            imr_ifindex;    /* interface index */
};
#endif

#ifndef HAVE_IPV6_MREQ
/*
 * Fallback definition of the IPv6 multicast membership request (RFC 3493),
 * used only when HAVE_IPV6_MREQ is not defined.  The interface field must
 * be named ipv6mr_interface: that is the standard name, and it is the one
 * socket_mcast_join() assigns.
 */
struct ipv6_mreq {
  struct in6_addr ipv6mr_multiaddr; /* IPv6 multicast group address */
  unsigned int    ipv6mr_interface; /* interface index */
};
#endif

/*
 * Status codes returned by the helpers in this file.  NEPIM_SOCK_ERR_NONE
 * (0) means success and every failure is a distinct negative value, so the
 * socket-creating functions can return either a descriptor (>= 0) or one
 * of these codes through the same int.
 */
#define NEPIM_SOCK_ERR_NONE           (0)
#define NEPIM_SOCK_ERR_UNSPEC         (-1)  /* NOTE(review): not referenced by name here; -1 is also what the raw setsockopt() pass-throughs return on failure */
#define NEPIM_SOCK_ERR_SOCKET         (-2)  /* socket() failed */
#define NEPIM_SOCK_ERR_BIND           (-3)  /* bind() failed */
#define NEPIM_SOCK_ERR_LISTEN         (-4)  /* listen() failed */
#define NEPIM_SOCK_ERR_CONNECT        (-5)  /* connect() failed */
#define NEPIM_SOCK_ERR_BLOCK          (-6)  /* clearing O_NONBLOCK failed */
#define NEPIM_SOCK_ERR_UNBLOCK        (-7)  /* setting O_NONBLOCK failed */
#define NEPIM_SOCK_ERR_UNLINGER       (-8)  /* SO_LINGER setsockopt failed */
#define NEPIM_SOCK_ERR_REUSE          (-9)  /* SO_REUSEADDR setsockopt failed */
#define NEPIM_SOCK_ERR_NODELAY        (-10) /* TCP_NODELAY setsockopt failed */
#define NEPIM_SOCK_ERR_PMTU           (-11) /* IP_MTU_DISCOVER / IP_MTU failed */
#define NEPIM_SOCK_ERR_TTL            (-12) /* IP_TTL set or query failed */
#define NEPIM_SOCK_ERR_MCAST_TTL      (-13) /* IP_MULTICAST_TTL query failed */
#define NEPIM_SOCK_ERR_MCAST_JOIN     (-14) /* multicast group join failed */
#define NEPIM_SOCK_ERR_WIN_RECV       (-15) /* SO_RCVBUF set or query failed */
#define NEPIM_SOCK_ERR_WIN_SEND       (-16) /* SO_SNDBUF set or query failed */

/*
 * Return the port (host byte order) of an IPv4 or IPv6 socket address.
 * sin_port and sin6_port are expected to share one offset; debug builds
 * verify that before trusting the IPv4 view.
 */
int nepim_sock_get_port(const struct sockaddr *addr)
{
  const union {
    struct sockaddr_in inet;
    struct sockaddr_in6 inet6;
  } *u = (const void *) addr;
  in_port_t port_net;

  assert(&u->inet.sin_port == &u->inet6.sin6_port);

  port_net = u->inet.sin_port;
  assert(port_net == u->inet6.sin6_port);

  return ntohs(port_net);
}

/*
 * Return the address family stored in an IPv4 or IPv6 socket address.
 * Debug builds check that sin_family and sin6_family overlap.
 */
int nepim_sock_family(const struct sockaddr *addr)
{
  const union {
    struct sockaddr_in inet;
    struct sockaddr_in6 inet6;
  } *u = (const void *) addr;
  int fam = u->inet.sin_family;

  assert(&u->inet.sin_family == &u->inet6.sin6_family);
  assert(fam == u->inet6.sin6_family);

  return fam;
}

/*
 * Format the IP address of an AF_INET or AF_INET6 socket address into
 * `buf` as a NUL-terminated string.  The port is not included.
 *
 * buf_size must be at least INET_ADDRSTRLEN (IPv4) or INET6_ADDRSTRLEN
 * (IPv6); debug builds assert this.  Debug builds also abort on an
 * unsupported family or an inet_ntop() failure.  With NDEBUG, `buf`
 * instead receives "?" in those cases, so callers that print it never
 * read an uninitialized buffer.
 */
void nepim_sock_dump_addr(char *buf, int buf_size, const struct sockaddr *addr)
{
  const union {
    struct sockaddr_in inet;
    struct sockaddr_in6 inet6;
  } *sa = (const void *) addr;
  const char *dst = NULL; /* stays NULL for an unsupported family */
  int family;

  family = addr->sa_family;

  switch (family) {
  case AF_INET:
    assert(buf_size >= INET_ADDRSTRLEN);
    dst = inet_ntop(family, &sa->inet.sin_addr, buf, buf_size);
    break;
  case AF_INET6:
    assert(buf_size >= INET6_ADDRSTRLEN);
    dst = inet_ntop(family, &sa->inet6.sin6_addr, buf, buf_size);
    break;
  default:
    assert(0);
    break;
  }

  if (!dst) {
#ifndef NDEBUG
    fprintf(stderr,
	    "%s %s: inet_ntop() failure: errno=%d %s\n",
	    __FILE__, __PRETTY_FUNCTION__,
	    errno, strerror(errno));
#endif /* NDEBUG */
    /* Never hand the caller an unterminated buffer. */
    if (buf_size > 0)
      snprintf(buf, buf_size, "?");
  }

  assert(dst);
  assert(dst == buf);
}

/*
 * Put descriptor `sd` into blocking mode by clearing O_NONBLOCK.
 *
 * Returns NEPIM_SOCK_ERR_NONE on success, or NEPIM_SOCK_ERR_BLOCK if
 * either fcntl() call fails.
 */
int nepim_socket_block(int sd)
{
  /* int, not long: F_GETFL returns int and F_SETFL reads its variadic
     argument as int, so passing a long is undefined on LP64 targets. */
  int flags;

  flags = fcntl(sd, F_GETFL, 0);
  if (flags == -1)
    return NEPIM_SOCK_ERR_BLOCK;
  assert(flags >= 0);
  if (fcntl(sd, F_SETFL, flags & ~O_NONBLOCK))
    return NEPIM_SOCK_ERR_BLOCK;

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Put descriptor `sd` into non-blocking mode by setting O_NONBLOCK.
 *
 * Returns NEPIM_SOCK_ERR_NONE on success, or NEPIM_SOCK_ERR_UNBLOCK if
 * either fcntl() call fails.
 */
int nepim_socket_nonblock(int sd)
{
  /* int, not long: F_GETFL returns int and F_SETFL reads its variadic
     argument as int, so passing a long is undefined on LP64 targets. */
  int flags;

  flags = fcntl(sd, F_GETFL, 0);
  if (flags == -1)
    return NEPIM_SOCK_ERR_UNBLOCK;
  assert(flags >= 0);
  if (fcntl(sd, F_SETFL, flags | O_NONBLOCK))
    return NEPIM_SOCK_ERR_UNBLOCK;

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Set the IP_MTU_DISCOVER mode of `sd`.  A negative pmtu_mode means
 * "leave unchanged".  When the platform lacks IP_MTU_DISCOVER this is a
 * no-op.  Otherwise it returns setsockopt()'s result (0 on success).
 */
int nepim_socket_pmtu(int sd, int pmtu_mode)
{
#ifdef IP_MTU_DISCOVER
  if (pmtu_mode >= 0)
    return setsockopt(sd, SOL_IP, IP_MTU_DISCOVER, &pmtu_mode, sizeof(pmtu_mode));
#else
  (void) sd;
  (void) pmtu_mode;
#endif

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Set the unicast IP_TTL of `sd`.  A negative ttl means "leave unchanged".
 * Otherwise it returns setsockopt()'s result (0 on success).
 */
int nepim_socket_ttl(int sd, int ttl)
{
  return (ttl < 0)
    ? NEPIM_SOCK_ERR_NONE
    : setsockopt(sd, SOL_IP, IP_TTL, &ttl, sizeof(ttl));
}

/*
 * Set IP_MULTICAST_TTL on `sd`.  A negative mc_ttl means "leave unchanged";
 * otherwise it must be below 256 (asserted).  The option value is an
 * unsigned char when HAVE_UCHAR_MCAST_TTL is defined, an int otherwise.
 * Returns setsockopt()'s result (0 on success).
 */
int nepim_socket_mcast_ttl(int sd, int mc_ttl)
{
#ifdef HAVE_UCHAR_MCAST_TTL
  unsigned char hops;
#else
  int hops;
#endif

  if (mc_ttl >= 0) {
    assert(mc_ttl < 256);
    hops = mc_ttl;
    return setsockopt(sd, SOL_IP, IP_MULTICAST_TTL, &hops, sizeof(hops));
  }

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Request a receive buffer (SO_RCVBUF) of win_recv bytes.  A negative
 * value means "keep the system default".  Returns NEPIM_SOCK_ERR_WIN_RECV
 * on failure.
 */
static int socket_set_win_recv(int sd, int win_recv)
{
  if (win_recv >= 0 &&
      setsockopt(sd, SOL_SOCKET, SO_RCVBUF, &win_recv, sizeof(win_recv)) != 0)
    return NEPIM_SOCK_ERR_WIN_RECV;

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Request a send buffer (SO_SNDBUF) of win_send bytes.  A negative value
 * means "keep the system default".  Returns NEPIM_SOCK_ERR_WIN_SEND on
 * failure.
 */
static int socket_set_win_send(int sd, int win_send)
{
  if (win_send >= 0 &&
      setsockopt(sd, SOL_SOCKET, SO_SNDBUF, &win_send, sizeof(win_send)) != 0)
    return NEPIM_SOCK_ERR_WIN_SEND;

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Create a socket and apply nepim's standard options, in this order:
 * TCP options (stream sockets only), then nepim_socket_opt() (linger,
 * reuse, PMTU, TTL), then the receive and send buffer sizes.
 *
 * Returns the descriptor, or a negative NEPIM_SOCK_ERR_* code.  If any
 * step after socket() fails, the descriptor is closed first.
 */
static int create_socket(int domain, int type, int protocol, int pmtu_mode, int ttl,
			 int win_recv, int win_send)
{
  int fd;
  int err;

  fd = socket(domain, type, protocol);
  if (fd < 0)
    return NEPIM_SOCK_ERR_SOCKET;

  if (type == SOCK_STREAM) {
    err = nepim_socket_tcp_opt(fd);
    if (err)
      goto fail;
  }

  err = nepim_socket_opt(fd, pmtu_mode, ttl);
  if (err)
    goto fail;

  err = socket_set_win_recv(fd, win_recv);
  if (err)
    goto fail;

  err = socket_set_win_send(fd, win_send);
  if (err)
    goto fail;

  return fd;

 fail:
  close(fd);
  return err;
}

/*
 * Join the multicast group whose address is held in `addr` on socket `sd`,
 * letting the interface index be 0 (see ip(7) / ipv6(7)).
 *
 * `family` selects the sockaddr layout (PF_INET or PF_INET6); `addr_len`
 * is currently unused.  Returns setsockopt()'s result (0 on success).  An
 * unsupported family asserts in debug builds and returns
 * NEPIM_SOCK_ERR_MCAST_JOIN otherwise.
 */
static int socket_mcast_join(int sd, int family, struct sockaddr *addr, int addr_len)
{
  const union {
    struct sockaddr_in inet;
    struct sockaddr_in6 inet6;
  } *sa = (const void *) addr;

  (void) addr_len;

  switch (family) {
  case PF_INET:
    {
      struct ip_mreqn opt;

      /* Zero everything, including padding and any platform-specific
         extra fields, before filling in the request. */
      memset(&opt, 0, sizeof(opt));
      opt.imr_multiaddr = sa->inet.sin_addr;
      opt.imr_address.s_addr = htonl(INADDR_ANY); /* 32-bit field: htonl */
      opt.imr_ifindex = 0;

      return setsockopt(sd, SOL_IP, IP_ADD_MEMBERSHIP, &opt, sizeof(opt));
    }
  case PF_INET6:
    {
      struct ipv6_mreq opt;

      memset(&opt, 0, sizeof(opt));
      opt.ipv6mr_multiaddr = sa->inet6.sin6_addr;
      opt.ipv6mr_interface = 0;

      return setsockopt(sd, SOL_IPV6, IPV6_ADD_MEMBERSHIP, &opt, sizeof(opt));
    }
  default:
    assert(0);
    break;
  }

  return NEPIM_SOCK_ERR_MCAST_JOIN;
}

/*
 * Create a non-blocking socket bound to `addr`.
 *
 * The socket is built by create_socket() (options and buffer sizes).  When
 * `mcast_join` is non-zero, the multicast group in `addr` is joined before
 * binding; this is valid only for UDP datagram sockets (asserted).  A
 * failure of the group join or of bind() is reported on stderr.
 *
 * Returns the descriptor, or a negative NEPIM_SOCK_ERR_* code.  If any
 * step after creation fails, the descriptor is closed first.
 */
int nepim_create_socket(struct sockaddr *addr,
			int addr_len,
			int family,
			int type,
			int protocol,
			int pmtu_mode,
			int ttl,
			int mcast_join,
			int win_recv,
			int win_send)
{
  char addr_str[INET6_ADDRSTRLEN];
  int saved_errno;
  int status;
  int fd;

  fd = create_socket(family, type, protocol, pmtu_mode, ttl,
		     win_recv, win_send);
  if (fd < 0)
    return fd;

  if (mcast_join) {
    assert(type == SOCK_DGRAM);
    assert(protocol == IPPROTO_UDP);

    if (socket_mcast_join(fd, family, addr, addr_len)) {
      saved_errno = errno; /* capture before later calls can overwrite it */
      nepim_sock_dump_addr(addr_str, sizeof(addr_str), addr);
      fprintf(stderr, "%s: %s: mcast_join(%d,%s,%d): errno=%d: %s\n",
	      __FILE__, __PRETTY_FUNCTION__,
	      fd, addr_str, nepim_sock_get_port(addr),
	      saved_errno, strerror(saved_errno));
      status = NEPIM_SOCK_ERR_MCAST_JOIN;
      goto fail;
    }
  }

  status = nepim_socket_nonblock(fd);
  if (status)
    goto fail;

  if (bind(fd, addr, addr_len)) {
    saved_errno = errno;
    nepim_sock_dump_addr(addr_str, sizeof(addr_str), addr);
    fprintf(stderr, "%s: %s: bind(%d,%s,%d): errno=%d: %s\n",
	    __FILE__, __PRETTY_FUNCTION__,
	    fd, addr_str, nepim_sock_get_port(addr),
	    saved_errno, strerror(saved_errno));
    status = NEPIM_SOCK_ERR_BIND;
    goto fail;
  }

  return fd;

 fail:
  close(fd);
  return status;
}

/*
 * Create a bound, non-blocking socket via nepim_create_socket() (without
 * a multicast join) and put it into the listening state.
 *
 * Returns the descriptor, the error code from nepim_create_socket(), or
 * NEPIM_SOCK_ERR_LISTEN if listen() fails (the descriptor is then closed).
 */
int nepim_create_listener_socket(struct sockaddr *addr,
				 int addr_len,
				 int family,
				 int type,
				 int protocol,
				 int backlog,
				 int pmtu_mode,
				 int ttl,
				 int win_recv,
				 int win_send)
{
  int fd = nepim_create_socket(addr, addr_len, family, type, protocol,
			       pmtu_mode, ttl, 0 /* no mcast join */,
			       win_recv, win_send);

  if (fd >= 0 && listen(fd, backlog)) {
    close(fd);
    fd = NEPIM_SOCK_ERR_LISTEN;
  }

  return fd;
}

/*
 * Explicitly turn SO_LINGER off (l_onoff = 0, l_linger = 0) on `sd`.
 * Returns setsockopt()'s result (0 on success).
 */
static int unlinger(int sd)
{
  struct linger lg;

  lg.l_linger = 0; /* seconds; irrelevant while lingering is off */
  lg.l_onoff = 0;  /* lingering disabled */

  return setsockopt(sd, SOL_SOCKET, SO_LINGER, &lg, sizeof(lg));
}

/*
 * Enable SO_REUSEADDR on `sd`.  Returns setsockopt()'s result
 * (0 on success).
 */
static int reuse(int sd)
{
  const int enable = 1;

  return setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
}

/*
 * Enable TCP_NODELAY on `sd`.  Returns setsockopt()'s result
 * (0 on success).
 */
static int tcp_nodelay(int sd)
{
  const int enable = 1;

  return setsockopt(sd, SOL_TCP, TCP_NODELAY, &enable, sizeof(enable));
}

/*
 * Apply the protocol-independent socket options to `sd`, stopping at the
 * first failure: SO_LINGER off, SO_REUSEADDR on, PMTU discovery mode and
 * unicast TTL (negative pmtu_mode / ttl leave those unchanged).
 *
 * Returns NEPIM_SOCK_ERR_NONE, or the NEPIM_SOCK_ERR_* code of the step
 * that failed.
 */
int nepim_socket_opt(int sd, int pmtu_mode, int ttl)
{
  int err = NEPIM_SOCK_ERR_NONE;

  if (unlinger(sd))
    err = NEPIM_SOCK_ERR_UNLINGER;
  else if (reuse(sd))
    err = NEPIM_SOCK_ERR_REUSE;
  else if (nepim_socket_pmtu(sd, pmtu_mode))
    err = NEPIM_SOCK_ERR_PMTU;
  else if (nepim_socket_ttl(sd, ttl))
    err = NEPIM_SOCK_ERR_TTL;

  return err;
}

/*
 * Apply TCP-specific options to `sd` (currently only TCP_NODELAY).
 * Returns NEPIM_SOCK_ERR_NONE or NEPIM_SOCK_ERR_NODELAY.
 */
int nepim_socket_tcp_opt(int sd)
{
  return tcp_nodelay(sd) ? NEPIM_SOCK_ERR_NODELAY : NEPIM_SOCK_ERR_NONE;
}

/*
 * Create a socket with create_socket() and connect it synchronously to
 * `addr`.
 *
 * The socket is switched to blocking mode for connect() and back to
 * non-blocking mode once connected.  When SO_BSDCOMPAT is defined, UDP
 * sockets also get that option set.  A connect() failure is reported on
 * stderr.
 *
 * Returns the connected descriptor, or a negative NEPIM_SOCK_ERR_* code.
 * If any step after creation fails, the descriptor is closed first.
 */
int nepim_connect_client_socket(struct sockaddr *addr,
				int addr_len,
				int family,
				int type,
				int protocol,
				int pmtu_mode,
				int ttl,
				int win_recv,
				int win_send)
{
  int sd;
  int result;

  sd = create_socket(family, type, protocol, pmtu_mode, ttl,
		     win_recv, win_send);
  if (sd < 0)
    return sd;

#ifdef SO_BSDCOMPAT
  /*
   * We don't want Linux ECONNREFUSED on UDP sockets
   */
  if (protocol == IPPROTO_UDP) {
    int one = 1;
    if (setsockopt(sd, SOL_SOCKET, SO_BSDCOMPAT, &one, sizeof(one))) {
      close(sd);                     /* do not leak the descriptor */
      return NEPIM_SOCK_ERR_UNSPEC;  /* (-1) */
    }
  }
#endif /* Linux SO_BSDCOMPAT */

  result = nepim_socket_block(sd);
  if (result) {
    close(sd);
    return result;
  }

  fprintf(stderr, 
	  "DEBUG FIXME %s %s slow synchronous connect(port=%d)\n",
	  __FILE__, __PRETTY_FUNCTION__, nepim_sock_get_port(addr));

  if (connect(sd, addr, addr_len)) {
    char buf[INET6_ADDRSTRLEN];
    int e = errno; /* capture before later calls can overwrite it */

    nepim_sock_dump_addr(buf, sizeof(buf), addr);

    fprintf(stderr, "%s: %s: connect(%d,%s,%d): errno=%d: %s\n",
	    __FILE__, __PRETTY_FUNCTION__,
	    sd, buf, nepim_sock_get_port(addr),
	    e, strerror(e));

    close(sd);
    return NEPIM_SOCK_ERR_CONNECT;
  }

  result = nepim_socket_nonblock(sd);
  if (result) {
    close(sd);
    return result;
  }

  return sd;
}

/*
 * Return the current IP_MTU_DISCOVER mode of `sd`.  Returns
 * NEPIM_SOCK_ERR_PMTU if the query fails or if the platform lacks
 * IP_MTU_DISCOVER.
 */
int nepim_socket_pmtu_get_mode(int sd)
{
#ifndef IP_MTU_DISCOVER
  (void) sd;
  return NEPIM_SOCK_ERR_PMTU;
#else
  int val;
  socklen_t len = sizeof(val);

  if (getsockopt(sd, SOL_IP, IP_MTU_DISCOVER, &val, &len) != 0)
    return NEPIM_SOCK_ERR_PMTU;

  assert(len == sizeof(val));

  return val;
#endif
}

/*
 * Return the path MTU that the IP_MTU option reports for `sd`, or
 * NEPIM_SOCK_ERR_PMTU if the query fails.
 */
int nepim_socket_pmtu_get_mtu(int sd)
{
  int path_mtu;
  socklen_t len = sizeof(path_mtu);

  if (getsockopt(sd, SOL_IP, IP_MTU, &path_mtu, &len) != 0)
    return NEPIM_SOCK_ERR_PMTU;

  assert(len == sizeof(path_mtu));

  return path_mtu;
}

/*
 * Return the unicast IP_TTL of `sd`, or NEPIM_SOCK_ERR_TTL if the query
 * fails.
 */
int nepim_socket_get_ttl(int sd)
{
  int hops;
  socklen_t len = sizeof(hops);

  if (getsockopt(sd, SOL_IP, IP_TTL, &hops, &len) != 0)
    return NEPIM_SOCK_ERR_TTL;

  assert(len == sizeof(hops));

  return hops;
}

/*
 * Return the IP_MULTICAST_TTL of `sd`, or NEPIM_SOCK_ERR_MCAST_TTL if the
 * query fails.  The option is read as an unsigned char when
 * HAVE_UCHAR_MCAST_TTL is defined, otherwise as an int.
 */
static int socket_mcast_get_ttl(int sd)
{
#ifdef HAVE_UCHAR_MCAST_TTL
  unsigned char hops;
#else
  int hops;
#endif
  socklen_t hops_len = sizeof(hops);

  if (getsockopt(sd, SOL_IP, IP_MULTICAST_TTL, &hops, &hops_len) != 0)
    return NEPIM_SOCK_ERR_MCAST_TTL;

  assert(hops_len == sizeof(hops));

  return hops;
}

/*
 * Return the SO_RCVBUF size of `sd`, or NEPIM_SOCK_ERR_WIN_RECV if the
 * query fails.
 */
static int socket_get_win_recv(int sd)
{
  int size;
  socklen_t len = sizeof(size);

  if (getsockopt(sd, SOL_SOCKET, SO_RCVBUF, &size, &len) != 0)
    return NEPIM_SOCK_ERR_WIN_RECV;

  assert(len == sizeof(size));

  return size;
}

/*
 * Return the SO_SNDBUF size of `sd`, or NEPIM_SOCK_ERR_WIN_SEND if the
 * query fails.
 */
static int socket_get_win_send(int sd)
{
  int size;
  socklen_t len = sizeof(size);

  if (getsockopt(sd, SOL_SOCKET, SO_SNDBUF, &size, &len) != 0)
    return NEPIM_SOCK_ERR_WIN_SEND;

  assert(len == sizeof(size));

  return size;
}

/*
 * Print one line to `out` summarizing the options of `sd`: PMTU discovery
 * mode, path MTU, unicast TTL, multicast TTL and the receive/send buffer
 * sizes.  A query that fails shows up as its negative NEPIM_SOCK_ERR_* code.
 */
void nepim_sock_show_opt(FILE *out, int sd)
{
  const int pmtud_mode = nepim_socket_pmtu_get_mode(sd);
  const int mtu        = nepim_socket_pmtu_get_mtu(sd);
  const int ttl        = nepim_socket_get_ttl(sd);
  const int mcast_ttl  = socket_mcast_get_ttl(sd);
  const int win_recv   = socket_get_win_recv(sd);
  const int win_send   = socket_get_win_send(sd);

  fprintf(out, 
	  "%d: pmtud_mode=%d path_mtu=%d ttl=%d mcast_ttl=%d win_recv=%d win_send=%d\n",
	  sd, pmtud_mode, mtu, ttl, mcast_ttl, win_recv, win_send);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */