/*
 * sockprot.cpp - protocol-independent socket address handling
 * $Id: sockprot.cpp,v 1.4 2004/07/23 14:48:39 rdenisc Exp $
 */

/***********************************************************************
 *  Copyright (C) 2002-2004 Remi Denis-Courmont.                       *
 *  This program is free software; you can redistribute and/or modify  *
 *  it under the terms of the GNU General Public License as published  *
 *  by the Free Software Foundation; version 2 of the license.         *
 *                                                                     *
 *  This program is distributed in the hope that it will be useful,    *
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of     *
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.               *
 *  See the GNU General Public License for more details.               *
 *                                                                     *
 *  You should have received a copy of the GNU General Public License  *
 *  along with this program; if not, you can get it from:              *
 *  http://www.gnu.org/copyleft/gpl.html                               *
 ***********************************************************************/

#if HAVE_CONFIG_H
# include <config.h>
#endif

#include "gettext.h"
#include "secstr.h" // secure_strncpy()

#include <sys/types.h>
#if HAVE_SYS_SOCKET_H
# include <sys/socket.h>
#endif
#if HAVE_SYS_UN_H
# include <sys/un.h> // struct sockaddr_un
#endif

#include <errno.h> // errno
#include <unistd.h> // close(), unlink()
#include <fcntl.h> // fcntl()
#if HAVE_SYS_SELECT_H
# include <sys/select.h> // select()
#endif

#include "solve.h"
#include "sockprot.h"

/*** SocketAddress class implementation ***/
/*
 * Replaces the held address list with the address of socket fd:
 * its local name (getsockname) when side is zero, its peer name
 * (getpeername) otherwise. myhost/myserv are then filled through
 * getnamebyaddr(), to which flags is forwarded.
 * Returns 0 on success, otherwise a non-zero error code that is also
 * recorded in the object. If only the getnamebyaddr() step fails, the
 * address list is kept and myhost/myserv receive placeholder strings.
 */
int SocketAddress::SetFromSocket (int fd, int flags, int side)
{
	if (res != NULL)
		freeai (res);
	ClearError ();

	struct sockaddr_storage ss;
	socklen_t sslen = sizeof (struct sockaddr_storage);
	struct sockaddr *sa = (struct sockaddr *)&ss;

	int failed = side ? getpeername (fd, sa, &sslen)
	                  : getsockname (fd, sa, &sslen);
	if (failed)
	{
		res = NULL;
		return SetError ();
	}

	res = makeai (sa, sslen);
	if (res == NULL)
		return SetError ();

	int name_err = SetError (getnamebyaddr ((struct sockaddr *)res->ai_addr,
						res->ai_addrlen,
						myhost, sizeof (myhost),
						myserv, sizeof (myserv),
						flags));
	if (name_err == 0)
		return 0;

	// reverse lookup failed: keep the address, use placeholder names
	secure_strncpy (myhost, _("unknown_node"), sizeof (myhost));
	secure_strncpy (myserv, _("unknown_service"), sizeof (myserv));
	return GetError ();
}


/*
 * Resolves host/service through getaddrbyname(), replacing any address
 * list previously held, and remembers the requested names in
 * myhost/myserv.
 * host may be NULL (myhost is then emptied); service may be NULL (myserv
 * is then set to "0"). flags, af, type and proto are passed as the
 * ai_flags/ai_family/ai_socktype/ai_protocol hints (0 = unspecified).
 * Returns 0 on success, otherwise the non-zero resolver error code, which
 * is also recorded in the object.
 */
int SocketAddress::SetByName (const char *host, const char *service,
				int flags, int af, int type, int proto)
{
	if (res != NULL)
		freeai (res);
	ClearError ();

	/* strncpy() leaves the buffer unterminated on truncation; use the
	 * same bounded, terminating copy as for myserv below. */
	if (host != NULL)
		secure_strncpy (myhost, host, sizeof (myhost));
	else
		*myhost = 0;
	secure_strncpy (myserv, (service != NULL) ? service : "0",
			sizeof (myserv));

	/* Zero-initialized hints: a zero argument leaves the field unset. */
	struct addrinfo info;
	memset (&info, 0, sizeof (info));
	info.ai_flags = flags;
	info.ai_family = af;
	info.ai_socktype = type;
	info.ai_protocol = proto;

	if (SetError (getaddrbyname (host, service, &info, &res)))
	{
		res = NULL;
		return GetError ();
	}

	//if ((flags & AI_CANONNAME) && (res != NULL)
	//    && (res->ai_canonname != NULL))
	// -- broken if multiple results
	//	secure_strncpy (myhost, res->ai_canonname, sizeof (myhost));

	return 0;
}


/* Destructor: releases the held address list, if there is one. */
SocketAddress::~SocketAddress (void)
{
	if (res)
		freeai (res);
}


/*
 * Copy constructor: copies src's error state and host/service names,
 * then duplicates its address list with copyai(). If that duplication
 * fails, this object holds no address list and the failure is recorded
 * via SetError().
 */
SocketAddress::SocketAddress (const SocketAddress& src)
	: err_ai (src.err_ai), err_sys (src.err_sys)
{
	secure_strncpy (myhost, src.myhost, sizeof (myhost));
	secure_strncpy (myserv, src.myserv, sizeof (myserv));

	res = NULL;
	if (src.res == NULL)
		return; // nothing to duplicate

	res = copyai (src.res);
	if (res == NULL)
		SetError ();
}


/*
 * Tries each entry of the address list in order: creates a socket,
 * sets SO_REUSEADDR on it (result ignored), and binds it to the entry's
 * address. Returns the first successfully bound descriptor, after
 * clearing the error state. Returns -1 if every entry failed; the last
 * failure is then recorded. With an empty list, -1 is returned and the
 * error state is left untouched.
 */
int SocketAddress::Bind (void)
{
	struct addrinfo *ai = res;

	while (ai != NULL)
	{
		int fd = socket (ai->ai_family, ai->ai_socktype,
					ai->ai_protocol);

		if (fd == -1)
			SetError ();
		else
		{
			int yes = 1;
			setsockopt (fd, SOL_SOCKET, SO_REUSEADDR, &yes,
					sizeof (yes));

			if (bind (fd, ai->ai_addr, ai->ai_addrlen) == 0)
			{
				ClearError ();
				return fd; // success!
			}

			// record errno before close() can overwrite it
			SetError ();
			close (fd);
		}

		ai = ai->ai_next;
	}

	return -1;
}


/*
 * Returns a human-readable description of the recorded error.
 * EAI_SYSTEM errors are described by strerror() on the saved errno
 * value; all other codes by gai_strerror().
 */
const char *
SocketAddress::StrError (void) const
{
	if (err_ai != EAI_SYSTEM)
		return gai_strerror (err_ai);

	return strerror (err_sys);
}



/*
 * Records ai as the current error code. For EAI_SYSTEM, the current
 * errno is also saved so that StrError() can describe it later.
 * Returns ai, which lets callers write "return SetError (...);".
 */
int
SocketAddress::SetError (int ai)
{
	err_ai = ai;
	if (ai == EAI_SYSTEM)
		err_sys = errno;
	return err_ai;
}


/*
 * Tries each entry of the address list in order: creates a socket and
 * connects it to the entry's address. Returns the first connected
 * descriptor, after clearing the error state. Returns -1 when every
 * entry failed (the last failure is recorded, see StrError()) or when
 * the list is empty.
 *
 * Where O_NONBLOCK exists, connect() is issued on a non-blocking socket.
 * If it is still in progress and nonblock is non-zero, the (still
 * non-blocking) descriptor is returned at once. Otherwise this waits
 * with select() for completion, checks SO_ERROR, and restores the
 * original file status flags. An interrupted wait (EINTR) aborts the
 * whole attempt and returns -1.
 */
int SocketAddress::Connect (int nonblock)
{
	for (struct addrinfo *ai = res; ai != NULL; ai = ai->ai_next)
	{
		int fd = socket (ai->ai_family, ai->ai_socktype,
				ai->ai_protocol);
		if (fd == -1)
		{
			/* Record socket() failures, as Bind() does, and
			 * never call setsockopt() on an invalid descriptor. */
			SetError ();
			continue;
		}

		int t = 1;
		setsockopt (fd, SOL_SOCKET, SO_REUSEADDR, &t, sizeof (t));

#ifdef O_NONBLOCK
		int flags = fcntl (fd, F_GETFL);
		/* Add O_NONBLOCK to the existing flags instead of
		 * replacing them. */
		fcntl (fd, F_SETFL, ((flags != -1) ? flags : 0) | O_NONBLOCK);
#endif

		if (connect (fd, ai->ai_addr, ai->ai_addrlen) == 0)
		{
			ClearError ();
#ifdef O_NONBLOCK
			if (flags != -1)
				fcntl (fd, F_SETFL, flags);
#endif
			return fd;
		}

		if (errno != EINPROGRESS)
		{
			SetError ();
			close (fd);
			continue;
		}

		if (nonblock)
			return fd;

		/* Waits until connection is established.
		 * NOTE(review): FD_SET() is undefined for fd >= FD_SETSIZE;
		 * consider poll() if large descriptors are possible. */
		fd_set s;
		FD_ZERO (&s);
		FD_SET (fd, &s);

		if (select (fd + 1, NULL, &s, NULL, NULL) != 1)
		{
			SetError ();
			close (fd);

			if (err_sys == EINTR)
				// aborts if interrupted
				return -1;
			continue;
		}

		int err = 0;
		socklen_t len = sizeof (err);

		if (getsockopt (fd, SOL_SOCKET, SO_ERROR, &err, &len))
		{
			SetError ();
			close (fd);
			continue;
		}

		if (len != sizeof (err))
		{
			/* "Impossible" case: still release the descriptor
			 * rather than leaking it. */
			close (fd);
			continue;
		}

		if (err)
		{
			errno = err;
			SetError ();
			close (fd);
			continue;
		}

		ClearError ();
#ifdef O_NONBLOCK
		if (flags != -1)
			fcntl (fd, F_SETFL, flags);
#endif
		return fd;
	}

	return -1;
}


/*
 * Calls unlink() on the sun_path of every PF_LOCAL entry in the address
 * list. unlink() results are ignored. Compiled to a no-op when
 * HAVE_SYS_UN_H is not set.
 */
void
SocketAddress::CleanUp (void) const
{
#if HAVE_SYS_UN_H
	const struct addrinfo *ai = res;

	while (ai != NULL)
	{
		if (ai->ai_family == PF_LOCAL)
		{
			const struct sockaddr_un *local =
				(const struct sockaddr_un *)ai->ai_addr;
			unlink (local->sun_path);
		}
		ai = ai->ai_next;
	}
#endif
}


// syntax highlighted by Code2HTML, v. 0.9.1