/* net6 - Library providing IPv4/IPv6 network access
 * Copyright (C) 2005, 2006 Armin Burgmeier / 0x539 dev group
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

#include <fcntl.h>

#include "config.hpp"
#include "encrypt.hpp"

namespace
{
	// Size in bits of the Diffie-Hellman prime: dh_params generates
	// parameters of this size and clients request it as the minimum
	// via gnutls_dh_set_prime_bits().
	const unsigned int DH_BITS = 1024;

	// Creates a gnutls session for the given connection end
	// (GNUTLS_CLIENT or GNUTLS_SERVER). Throws net6::error (GNUTLS) if
	// gnutls_init() fails instead of returning an uninitialised handle.
	net6::gnutls_session_t create_session(net6::gnutls_connection_end_t end)
	{
		net6::gnutls_session_t session;
		const int ret = gnutls_init(&session, end);
		if(ret < 0)
			throw net6::error(net6::error::GNUTLS, ret);

		return session;
	}

#ifdef WIN32
	// gnutls looks at errno to tell "try again" (EAGAIN/EINTR) apart
	// from hard transport errors, but Winsock reports errors through
	// WSAGetLastError(). This translates the Winsock error of a failed
	// call into errno. Every other error maps to EIO so that a stale
	// EAGAIN left in errno by an earlier call cannot disguise a real
	// failure as "try again".
	void net6_win32_set_errno()
	{
		const int error = WSAGetLastError();
		if(error == WSAEWOULDBLOCK)
			errno = EAGAIN;
		else if(error == WSAEINTR)
			errno = EINTR;
		else
			errno = EIO;

		// Ensures that a second call to WSAGetLastError
		// has the same result
		WSASetLastError(error);
	}

	// gnutls push function: sends through Winsock. errno is only
	// updated when the call fails (the only case where it is meaningful).
	ssize_t net6_win32_send_func(net6::gnutls_transport_ptr_t ptr,
	                             const void* data,
	                             size_t size)
	{
		ssize_t ret = ::send(
			reinterpret_cast<SOCKET>(ptr),
			static_cast<const char*>(data),
			size,
			0
		);

		if(ret == SOCKET_ERROR)
			net6_win32_set_errno();

		return ret;
	}

	// gnutls pull function: receives through Winsock. errno is only
	// updated when the call fails (the only case where it is meaningful).
	ssize_t net6_win32_recv_func(net6::gnutls_transport_ptr_t ptr,
	                             void* data,
	                             size_t size)
	{
		ssize_t ret = ::recv(
			reinterpret_cast<SOCKET>(ptr),
			static_cast<char*>(data),
			size,
			0
		);

		if(ret == SOCKET_ERROR)
			net6_win32_set_errno();

		return ret;
	}
#else
	// gnutls push function: sends with MSG_NOSIGNAL where available, so
	// that writing to a connection closed by the peer fails with EPIPE
	// instead of raising SIGPIPE. Without MSG_NOSIGNAL (plain BSD) no
	// flag is passed.
	ssize_t net6_unix_send_func(net6::gnutls_transport_ptr_t ptr,
	                            const void* data,
	                            size_t size)
	{
		// ptr carries the socket descriptor that
		// tcp_encrypted_socket_base's constructor stored with
		// gnutls_transport_set_ptr(); cast it back through intptr_t.
		return send(
			static_cast<int>(reinterpret_cast<intptr_t>(ptr)),
			data,
			size,
#ifdef HAVE_MSG_NOSIGNAL
			// Linux
			MSG_NOSIGNAL
#else
			// Plain BSD
			0
#endif
		);
	}
#endif

	typedef net6::tcp_encrypted_socket_base::size_type
		io_size_type;
	typedef net6::tcp_encrypted_socket_base::handshake_state
		io_handshake_state;

	// Shared implementation of tcp_encrypted_socket_base::send() and
	// recv(). It performs the gnutls record operation `func` on
	// `buf`/`len` and returns the number of bytes transferred.
	//
	// Throws std::logic_error unless the handshake has completed.
	// Interrupted calls (GNUTLS_E_INTERRUPTED) are retried with exactly
	// the same arguments, as gnutls requires. Any other negative result,
	// including GNUTLS_E_AGAIN, is thrown as net6::error (GNUTLS). In that
	// case the caller has to repeat the operation later with the same
	// buffer and length.
	template<
		typename buffer_type,
		ssize_t(*func)(net6::gnutls_session_t, buffer_type, size_t)
	> io_size_type io_impl(const net6::gnutls_session_t session,
	                       buffer_type buf,
	                       io_size_type len,
	                       io_handshake_state state)
	{
		switch(state)
		{
		case net6::tcp_encrypted_socket_base::DEFAULT:
			throw std::logic_error(
				"net6::encrypt.cpp:io_impl:\n"
				"Handshake not yet performed"
			);
		case net6::tcp_encrypted_socket_base::HANDSHAKING:
			throw std::logic_error(
				"net6::encrypt.cpp:io_impl:\n"
				"IO tried while handshaking"
			);
		case net6::tcp_encrypted_socket_base::HANDSHAKED:
		{
			// Retry in place with the same arguments. Do not flush with
			// (NULL, 0) here and then report failure anyway: for send
			// that could transmit the record and still make the caller
			// send the same data a second time.
			ssize_t ret;
			do
			{
				ret = func(session, buf, len);
			} while(ret == GNUTLS_E_INTERRUPTED);

			if(ret < 0)
				throw net6::error(net6::error::GNUTLS, ret);

			return static_cast<io_size_type>(ret);
		}
		}

		// Only reached for a value outside the known handshake states.
		throw std::logic_error(
			"net6::encrypt.cpp:io_impl:\n"
			"Invalid handshake state"
		);
	}
}

// Allocates a DH parameter structure and generates a fresh DH_BITS-bit
// prime for it.
// Throws net6::error (GNUTLS) if allocation or generation fails. The
// handle is released before throwing, because the destructor does not
// run for a partially constructed object.
net6::dh_params::dh_params():
	params(NULL)
{
	int ret = gnutls_dh_params_init(&params);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_dh_params_generate2(params, DH_BITS);
	if(ret < 0)
	{
		gnutls_dh_params_deinit(params);
		throw net6::error(net6::error::GNUTLS, ret);
	}
}

// Adopts an already initialised DH parameter handle. This object takes
// ownership: its destructor calls gnutls_dh_params_deinit() on it.
net6::dh_params::dh_params(gnutls_dh_params_t initial_params)
	: params(initial_params)
{
	// Nothing else to set up.
}

// Releases the owned DH parameter handle.
// NOTE(review): copying a dh_params would deinit the same handle twice;
// confirm the class is non-copyable in its header.
net6::dh_params::~dh_params()
{
	gnutls_dh_params_deinit(this->params);
}

// Exposes the underlying gnutls handle. Ownership stays with this object.
net6::gnutls_dh_params_t net6::dh_params::cobj()
{
	return this->params;
}

// Const overload of cobj(): exposes the underlying gnutls handle, which
// remains owned by this object.
const net6::gnutls_dh_params_t net6::dh_params::cobj() const
{
	return this->params;
}


// Wraps socket descriptor `cobj` and takes over the gnutls session `sess`.
// The session is configured for anonymous Diffie-Hellman key exchange
// only, and is given net6's own transport callbacks.
net6::tcp_encrypted_socket_base::tcp_encrypted_socket_base(
	socket_type cobj,
	gnutls_session_t sess
):
	tcp_client_socket(cobj), session(sess), state(DEFAULT)
{
	// Zero-terminated list: only anonymous DH is permitted. It must be
	// applied after the defaults, which would otherwise override it.
	const int anon_kx_only[] = { GNUTLS_KX_ANON_DH, 0 };

	gnutls_set_default_priority(session);
	gnutls_kx_set_priority(session, anon_kx_only);

	// The descriptor itself serves as the transport pointer; the
	// transport functions convert it back to a socket.
	gnutls_transport_ptr_t transport =
		reinterpret_cast<gnutls_transport_ptr_t>(cobj);
	gnutls_transport_set_ptr(session, transport);

#ifdef WIN32
	// Winsock needs wrappers that translate WSA errors into errno.
	gnutls_transport_set_pull_function(session, net6_win32_recv_func);
	gnutls_transport_set_push_function(session, net6_win32_send_func);
#else
	// Only sending is customised (MSG_NOSIGNAL); receiving uses the
	// gnutls default.
	gnutls_transport_set_push_function(session, net6_unix_send_func);
#endif

	// NOTE(review): low-water mark 0 is presumably needed because the
	// socket is driven nonblocking / with custom transport functions;
	// see gnutls_transport_set_lowat documentation.
	gnutls_transport_set_lowat(session, 0);
}

// Tears down the TLS session. A close_notify is attempted with
// GNUTLS_SHUT_WR (no wait for the peer's reply) regardless of the
// handshake state. Its result is deliberately ignored because a
// destructor must not throw. The session is then deinitialised.
net6::tcp_encrypted_socket_base::~tcp_encrypted_socket_base()
{
	static_cast<void>(gnutls_bye(session, GNUTLS_SHUT_WR));
	gnutls_deinit(session);
}

// Performs, or continues, the TLS handshake on this socket.
//
// On the first call the socket is switched to nonblocking mode so the
// handshake can be driven incrementally. The function returns false while
// gnutls reports GNUTLS_E_AGAIN / GNUTLS_E_INTERRUPTED, and must then be
// called again; get_dir() tells which direction gnutls was waiting on.
// Once the handshake completes it returns true. On POSIX, blocking mode
// is restored first if the socket was blocking before; on Windows the
// socket stays nonblocking.
//
// Throws:
// - std::logic_error if the handshake has already been performed;
// - net6::error (SYSTEM) if the socket flags cannot be read or changed;
// - net6::error (GNUTLS) for any other gnutls error.
bool net6::tcp_encrypted_socket_base::handshake()
{
	if(state == HANDSHAKED)
	{
		throw std::logic_error(
			"net6::tcp_encrypted_socket_base::handshake:\n"
			"Handshake has already been performed"
		);
	}

	if(state == DEFAULT)
	{
#ifdef WIN32
		u_long iMode = 1;
		if(ioctlsocket(cobj(), FIONBIO, &iMode) == SOCKET_ERROR)
			throw net6::error(net6::error::SYSTEM);

		// TODO: How to find out whether the socket is in blocking
		// mode? Until then it is left nonblocking after the handshake.
		was_blocking = false;
#else
		// Make socket nonblocking to allow to call handshake
		// multiple times
		const int flags = fcntl(cobj(), F_GETFL);
		// F_GETFL returns -1 on failure. OR-ing O_NONBLOCK into -1
		// would otherwise try to set every flag bit.
		if(flags == -1)
			throw net6::error(net6::error::SYSTEM);

		if(fcntl(cobj(), F_SETFL, flags | O_NONBLOCK) == -1)
			throw net6::error(net6::error::SYSTEM);

		was_blocking = ((flags & O_NONBLOCK) == 0);
#endif

		state = HANDSHAKING;
	}

	const int ret = gnutls_handshake(session);

	if(ret == 0)
	{
		if(was_blocking)
		{
			// Remove nonblocking state for further handling,
			// so the socket behaves like a nonencrypted tcp
			// client socket.
#ifdef WIN32
			u_long iMode = 0;
			if(ioctlsocket(cobj(), FIONBIO, &iMode) == SOCKET_ERROR)
				throw net6::error(net6::error::SYSTEM);
#else
			const int flags = fcntl(cobj(), F_GETFL);
			if(flags == -1)
				throw net6::error(net6::error::SYSTEM);

			if(fcntl(cobj(), F_SETFL, flags & ~O_NONBLOCK) == -1)
				throw net6::error(net6::error::SYSTEM);
#endif
		}

		state = HANDSHAKED;
		return true;
	}

	// Not finished yet: gnutls is waiting on the socket and the caller
	// has to invoke handshake() again later.
	if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
		return false;

	throw net6::error(net6::error::GNUTLS, ret);
}

// Reports which I/O direction the last interrupted gnutls call was
// blocked on. true means it was writing (wait for writability); false
// means it was reading (wait for readability).
bool net6::tcp_encrypted_socket_base::get_dir() const
{
	const int direction = gnutls_record_get_direction(session);
	return direction == 1;
}

// Number of already decrypted bytes buffered inside gnutls. They can be
// read with recv() without further socket I/O.
net6::tcp_encrypted_socket_base::size_type
net6::tcp_encrypted_socket_base::get_pending() const
{
	const size_type buffered = gnutls_record_check_pending(session);
	return buffered;
}

// Encrypts and sends up to `len` bytes from `buf` via gnutls_record_send.
// Returns the number of bytes accepted. Errors are reported by io_impl
// as exceptions.
net6::tcp_encrypted_socket_base::size_type
net6::tcp_encrypted_socket_base::send(const void* buf, size_type len) const
{
	const size_type sent =
		::io_impl<const void*, gnutls_record_send>(session, buf, len, state);
	return sent;
}

// Receives and decrypts up to `len` bytes into `buf` via
// gnutls_record_recv. Returns the number of bytes stored; 0 indicates
// the peer closed the TLS connection. Errors are reported by io_impl as
// exceptions.
net6::tcp_encrypted_socket_base::size_type
net6::tcp_encrypted_socket_base::recv(void* buf, size_type len) const
{
	const size_type received =
		::io_impl<void*, gnutls_record_recv>(session, buf, len, state);
	return received;
}

// Wraps `sock` as the client side of an anonymous-DH TLS session. The
// descriptor is handed to the base class and `sock` is invalidated
// (presumably so it no longer closes the descriptor itself).
// Requests primes of at least DH_BITS bits from the server.
// Throws net6::error (GNUTLS) if the anonymous credentials cannot be
// allocated or attached to the session.
net6::tcp_encrypted_socket_client::
	tcp_encrypted_socket_client(tcp_client_socket& sock):
	tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_CLIENT) )
{
	sock.invalidate();

	int ret = gnutls_anon_allocate_client_credentials(&anoncred);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
	if(ret < 0)
	{
		// The destructor does not run when the constructor throws,
		// so release the credentials here.
		gnutls_anon_free_client_credentials(anoncred);
		throw net6::error(net6::error::GNUTLS, ret);
	}

	gnutls_dh_set_prime_bits(session, DH_BITS);
}

// Frees the anonymous client credentials allocated by the constructor.
// The session itself is torn down afterwards by the base destructor.
// NOTE(review): the credentials are freed while the session still exists;
// confirm gnutls_bye/gnutls_deinit do not touch them.
net6::tcp_encrypted_socket_client::~tcp_encrypted_socket_client()
{
	gnutls_anon_free_client_credentials(this->anoncred);
}

// Wraps `sock` as the server side of an anonymous-DH TLS session.
// Fresh Diffie-Hellman parameters (DH_BITS bits) are generated for this
// socket alone, which is typically expensive. The overload taking a
// dh_params lets several sockets share one set.
// Throws net6::error (GNUTLS) if the anonymous credentials cannot be
// allocated or attached to the session.
// NOTE(review): sock.invalidate() only runs in the body, so if
// create_session or creating own_params throws, both `sock` and the base
// appear to own the descriptor; confirm the socket classes handle that.
net6::tcp_encrypted_socket_server::
	tcp_encrypted_socket_server(tcp_client_socket& sock):
	tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_SERVER) ),
	own_params(new dh_params)
{
	sock.invalidate();

	int ret = gnutls_anon_allocate_server_credentials(&anoncred);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
	if(ret < 0)
	{
		// The destructor does not run when the constructor throws,
		// so release the credentials here.
		gnutls_anon_free_server_credentials(anoncred);
		throw net6::error(net6::error::GNUTLS, ret);
	}

	gnutls_anon_set_server_dh_params(anoncred, own_params->cobj());
}

// Wraps `sock` as the server side of an anonymous-DH TLS session using
// caller-supplied Diffie-Hellman parameters. Only params' handle is
// passed to the credentials, so `params` must presumably outlive this
// socket. Verify against gnutls_anon_set_server_dh_params.
// Throws net6::error (GNUTLS) if the anonymous credentials cannot be
// allocated or attached to the session.
net6::tcp_encrypted_socket_server::
	tcp_encrypted_socket_server(tcp_client_socket& sock,
	                            dh_params& params):
	tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_SERVER) )
{
	sock.invalidate();

	int ret = gnutls_anon_allocate_server_credentials(&anoncred);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
	if(ret < 0)
	{
		// The destructor does not run when the constructor throws,
		// so release the credentials here.
		gnutls_anon_free_server_credentials(anoncred);
		throw net6::error(net6::error::GNUTLS, ret);
	}

	gnutls_anon_set_server_dh_params(anoncred, params.cobj() );
}

// Frees the anonymous server credentials allocated by the constructor.
// The session itself is torn down afterwards by the base destructor.
// NOTE(review): the credentials are freed while the session still exists;
// confirm gnutls_bye/gnutls_deinit do not touch them.
net6::tcp_encrypted_socket_server::~tcp_encrypted_socket_server()
{
	gnutls_anon_free_server_credentials(this->anoncred);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */