/* net6 - Library providing IPv4/IPv6 network access
* Copyright (C) 2005, 2006 Armin Burgmeier / 0x539 dev group
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
#include <fcntl.h>
#include "config.hpp"
#include "encrypt.hpp"
namespace
{
	// Diffie-Hellman prime size in bits: used when a server generates its
	// own parameters and as the client's prime-bit setting.
	const unsigned int DH_BITS = 1024;

	// Creates a new GnuTLS session for the given connection end
	// (GNUTLS_CLIENT or GNUTLS_SERVER).
	//
	// Throws net6::error (GNUTLS) if gnutls_init fails, instead of handing
	// an uninitialised session to the caller.
	net6::gnutls_session_t create_session(net6::gnutls_connection_end_t end)
	{
		net6::gnutls_session_t session;
		int ret = gnutls_init(&session, end);
		if(ret < 0)
			throw net6::error(net6::error::GNUTLS, ret);

		return session;
	}

#ifdef WIN32
	// Required to turn WSA error codes into errno. Called after every
	// send()/recv() with that call's return value.
	//
	// Only a failed call (negative return) updates errno. Any failure other
	// than WSAEWOULDBLOCK/WSAEINTR sets errno to EIO, so a stale EAGAIN or
	// EINTR left over from an earlier call cannot make a real socket error
	// look retryable.
	void net6_win32_translate_error(ssize_t ret)
	{
		int error = WSAGetLastError();

		if(ret < 0)
		{
			if(error == WSAEWOULDBLOCK) errno = EAGAIN;
			else if(error == WSAEINTR) errno = EINTR;
			else errno = EIO;
		}

		// Ensures that a second call to WSAGetLastError
		// has the same result
		WSASetLastError(error);
	}

	// GnuTLS push function: sends on the SOCKET stored as transport pointer.
	ssize_t net6_win32_send_func(net6::gnutls_transport_ptr_t ptr,
	                             const void* data,
	                             size_t size)
	{
		ssize_t ret = ::send(
			reinterpret_cast<SOCKET>(ptr),
			static_cast<const char*>(data),
			size,
			0
		);

		net6_win32_translate_error(ret);
		return ret;
	}

	// GnuTLS pull function: receives from the SOCKET stored as transport
	// pointer.
	ssize_t net6_win32_recv_func(net6::gnutls_transport_ptr_t ptr,
	                             void* data,
	                             size_t size)
	{
		ssize_t ret = ::recv(
			reinterpret_cast<SOCKET>(ptr),
			static_cast<char*>(data),
			size,
			0
		);

		net6_win32_translate_error(ret);
		return ret;
	}
#else
	// GnuTLS push function: sends with MSG_NOSIGNAL where available, so a
	// closed peer does not raise SIGPIPE.
	ssize_t net6_unix_send_func(net6::gnutls_transport_ptr_t ptr,
	                            const void* data,
	                            size_t size)
	{
		// TODO: How to properly get the fd from ptr?
		return send(
			static_cast<int>(reinterpret_cast<intptr_t>(ptr)),
			data,
			size,
#ifdef HAVE_MSG_NOSIGNAL
			// Linux
			MSG_NOSIGNAL
#else
			// Plain BSD
			0
#endif
		);
	}
#endif

	typedef net6::tcp_encrypted_socket_base::size_type
		io_size_type;
	typedef net6::tcp_encrypted_socket_base::handshake_state
		io_handshake_state;

	// Performs one record-layer I/O call (func is gnutls_record_send or
	// gnutls_record_recv), but only once the handshake has completed.
	//
	// Throws std::logic_error if called before or during the handshake.
	// Throws net6::error (GNUTLS) if func returns a negative value,
	// including GNUTLS_E_AGAIN and GNUTLS_E_INTERRUPTED.
	// Otherwise returns func's non-negative result.
	template<
		typename buffer_type,
		ssize_t(*func)(net6::gnutls_session_t, buffer_type, size_t)
	> io_size_type io_impl(const net6::gnutls_session_t session,
	                       buffer_type buf,
	                       io_size_type len,
	                       io_handshake_state state)
	{
		switch(state)
		{
		case net6::tcp_encrypted_socket_base::DEFAULT:
			throw std::logic_error(
				"net6::encrypt.cpp:io_impl:\n"
				"Handshake not yet performed"
			);
		case net6::tcp_encrypted_socket_base::HANDSHAKING:
			throw std::logic_error(
				"net6::encrypt.cpp:io_impl:\n"
				"IO tried while handshaking"
			);
		case net6::tcp_encrypted_socket_base::HANDSHAKED:
			break;
		}

		ssize_t ret = func(session, buf, len);
		if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
		{
			// Retry the interrupted operation with NULL/0.
			// NOTE(review): the result of this retry is discarded, so the
			// original AGAIN/INTERRUPTED error is still thrown below even
			// if the retry succeeded -- confirm this is intended.
			func(session, NULL, 0);
		}

		if(ret < 0)
			throw net6::error(net6::error::GNUTLS, ret);

		return ret;
	}
}
// Generates a fresh set of DH_BITS-bit Diffie-Hellman parameters.
//
// Throws net6::error (GNUTLS) if the parameter structure cannot be
// allocated or the generation fails. In the latter case the partially
// initialised structure is released first, because the destructor does not
// run for an object whose constructor threw.
net6::dh_params::dh_params():
	params(NULL)
{
	int ret = gnutls_dh_params_init(&params);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_dh_params_generate2(params, DH_BITS);
	if(ret < 0)
	{
		gnutls_dh_params_deinit(params);
		throw net6::error(net6::error::GNUTLS, ret);
	}
}
// Wraps an already initialised parameter set. The wrapper takes ownership:
// ~dh_params() passes the handle to gnutls_dh_params_deinit().
net6::dh_params::dh_params(gnutls_dh_params_t initial_params)
	: params(initial_params)
{}
// Releases the wrapped parameters, whichever constructor supplied them.
net6::dh_params::~dh_params()
{ gnutls_dh_params_deinit(params); }
// Exposes the underlying GnuTLS handle; ownership stays with this object.
net6::gnutls_dh_params_t net6::dh_params::cobj()
{ return params; }
// Read-only variant of cobj(); returns the same handle.
const net6::gnutls_dh_params_t net6::dh_params::cobj() const
{ return params; }
// Binds the GnuTLS session sess to the socket cobj and configures it for
// anonymous Diffie-Hellman. No handshake is performed here; the object
// starts in the DEFAULT state.
net6::tcp_encrypted_socket_base::
	tcp_encrypted_socket_base(socket_type cobj,
	                          gnutls_session_t sess):
	tcp_client_socket(cobj), session(sess), state(DEFAULT)
{
	// Start from the GnuTLS default priorities, then restrict key exchange
	// to anonymous DH (the list is zero-terminated).
	const int anon_only_kx[] = { GNUTLS_KX_ANON_DH, 0 };
	gnutls_set_default_priority(session);
	gnutls_kx_set_priority(session, anon_only_kx);

	// The raw socket handle doubles as the GnuTLS transport pointer.
	gnutls_transport_ptr_t transport =
		reinterpret_cast<gnutls_transport_ptr_t>(cobj);
	gnutls_transport_set_ptr(session, transport);

#ifdef WIN32
	// Both directions go through wrappers that turn WSA error codes into
	// errno values.
	gnutls_transport_set_pull_function(session, net6_win32_recv_func);
	gnutls_transport_set_push_function(session, net6_win32_send_func);
#else
	// Only sending is customised (MSG_NOSIGNAL where available); receiving
	// keeps the GnuTLS default pull function.
	gnutls_transport_set_push_function(session, net6_unix_send_func);
#endif

	// NOTE(review): low-water mark of 0; the rationale is not visible
	// here, presumably related to the custom transport functions -- confirm.
	gnutls_transport_set_lowat(session, 0);
}
// Shuts down the sending direction of the TLS session, then frees the
// session. The result of gnutls_bye() is not checked.
net6::tcp_encrypted_socket_base::~tcp_encrypted_socket_base()
{
	gnutls_bye(session, GNUTLS_SHUT_WR);

	gnutls_deinit(session);
}
// Drives the TLS handshake. It may be called repeatedly until it returns
// true.
//
// On the first call the socket is switched to non-blocking mode, so
// gnutls_handshake() can return early with GNUTLS_E_AGAIN/INTERRUPTED
// (reported as false).
// On success the state becomes HANDSHAKED, and the socket is switched back
// to blocking mode if it was blocking before (POSIX only; on Windows the
// previous mode is unknown, so the socket stays non-blocking).
//
// Throws std::logic_error if the handshake was already completed.
// Throws net6::error (SYSTEM) if the socket mode cannot be queried/changed.
// Throws net6::error (GNUTLS) for fatal handshake errors.
bool net6::tcp_encrypted_socket_base::handshake()
{
	if(state == HANDSHAKED)
	{
		throw std::logic_error(
			"net6::tcp_encrypted_socket_base::handshake:\n"
			"Handshake has already been performed"
		);
	}

	if(state == DEFAULT)
	{
#ifdef WIN32
		u_long iMode = 1;
		if(ioctlsocket(cobj(), FIONBIO, &iMode) == SOCKET_ERROR)
			throw net6::error(net6::error::SYSTEM);

		// TODO: How to find out whether the socket is in blocking
		// mode?
		was_blocking = false;
#else
		// Make socket nonblocking to allow to call handshake
		// multiple times. F_GETFL failure must be caught before
		// its -1 result is OR'ed into a new flag set.
		int flags = fcntl(cobj(), F_GETFL);
		if(flags == -1 ||
		   fcntl(cobj(), F_SETFL, flags | O_NONBLOCK) == -1)
			throw net6::error(net6::error::SYSTEM);

		was_blocking = ((flags & O_NONBLOCK) == 0);
#endif
		state = HANDSHAKING;
	}

	int ret = gnutls_handshake(session);
	if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
		return false;

	if(ret != 0)
		throw net6::error(net6::error::GNUTLS, ret);

	// The TLS session is established. Record that before touching the socket
	// mode, so a failure below cannot leave us in HANDSHAKING and cause a
	// later call to run gnutls_handshake() on a completed session.
	state = HANDSHAKED;

	if(was_blocking)
	{
		// Remove nonblocking state for further handling,
		// so the socket behaves like a nonencrypted tcp
		// client socket.
#ifdef WIN32
		u_long iMode = 0;
		if(ioctlsocket(cobj(), FIONBIO, &iMode) == SOCKET_ERROR)
			throw net6::error(net6::error::SYSTEM);
#else
		int flags = fcntl(cobj(), F_GETFL);
		if(flags == -1 ||
		   fcntl(cobj(), F_SETFL, flags & ~O_NONBLOCK) == -1)
			throw net6::error(net6::error::SYSTEM);
#endif
	}

	return true;
}
// Returns true when the last interrupted GnuTLS operation was waiting to
// write (gnutls_record_get_direction() == 1), false otherwise.
bool net6::tcp_encrypted_socket_base::get_dir() const
{
	const int direction = gnutls_record_get_direction(session);
	return direction == 1;
}
// Number of bytes GnuTLS already holds in its receive buffer, as reported
// by gnutls_record_check_pending().
net6::tcp_encrypted_socket_base::size_type
net6::tcp_encrypted_socket_base::get_pending() const
{
	const size_type buffered = gnutls_record_check_pending(session);
	return buffered;
}
// Encrypts and sends up to len bytes from buf. State checks and error
// translation are done by io_impl.
net6::tcp_encrypted_socket_base::size_type
net6::tcp_encrypted_socket_base::send(const void* buf, size_type len) const
{
	const size_type written =
		::io_impl<const void*, gnutls_record_send>(session, buf, len, state);
	return written;
}
// Receives and decrypts up to len bytes into buf. State checks and error
// translation are done by io_impl.
net6::tcp_encrypted_socket_base::size_type
net6::tcp_encrypted_socket_base::recv(void* buf, size_type len) const
{
	const size_type received =
		::io_impl<void*, gnutls_record_recv>(session, buf, len, state);
	return received;
}
// Wraps the connected socket sock in a client-side TLS session using
// anonymous Diffie-Hellman. sock is invalidated, presumably so that it no
// longer owns the descriptor now held by the base class -- confirm.
//
// Throws net6::error (GNUTLS) if the anonymous credentials cannot be
// allocated or attached to the session.
net6::tcp_encrypted_socket_client::
	tcp_encrypted_socket_client(tcp_client_socket& sock):
	tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_CLIENT) )
{
	sock.invalidate();

	int ret = gnutls_anon_allocate_client_credentials(&anoncred);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
	if(ret < 0)
	{
		// Our destructor does not run when the constructor throws,
		// so free the credentials here.
		gnutls_anon_free_client_credentials(anoncred);
		throw net6::error(net6::error::GNUTLS, ret);
	}

	gnutls_dh_set_prime_bits(session, DH_BITS);
}
// Frees the anonymous credentials allocated by the constructor. The session
// itself is released afterwards by the base class destructor.
net6::tcp_encrypted_socket_client::~tcp_encrypted_socket_client()
{ gnutls_anon_free_client_credentials(anoncred); }
// Wraps the connected socket sock in a server-side TLS session using
// anonymous Diffie-Hellman. Fresh DH parameters are generated for this
// socket (own_params). sock is invalidated, presumably so that it no longer
// owns the descriptor now held by the base class -- confirm.
//
// Throws net6::error (GNUTLS) if the anonymous credentials cannot be
// allocated or attached to the session.
net6::tcp_encrypted_socket_server::
	tcp_encrypted_socket_server(tcp_client_socket& sock):
	tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_SERVER) ),
	own_params(new dh_params)
{
	sock.invalidate();

	int ret = gnutls_anon_allocate_server_credentials(&anoncred);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
	if(ret < 0)
	{
		// Our destructor does not run when the constructor throws,
		// so free the credentials here.
		gnutls_anon_free_server_credentials(anoncred);
		throw net6::error(net6::error::GNUTLS, ret);
	}

	gnutls_anon_set_server_dh_params(anoncred, own_params->cobj());
}
// Like the single-argument constructor, but uses caller-supplied DH
// parameters instead of generating new ones. NOTE(review): only the raw
// handle from params is stored, so params presumably must outlive this
// socket -- confirm.
//
// Throws net6::error (GNUTLS) if the anonymous credentials cannot be
// allocated or attached to the session.
net6::tcp_encrypted_socket_server::
	tcp_encrypted_socket_server(tcp_client_socket& sock,
	                            dh_params& params):
	tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_SERVER) )
{
	sock.invalidate();

	int ret = gnutls_anon_allocate_server_credentials(&anoncred);
	if(ret < 0)
		throw net6::error(net6::error::GNUTLS, ret);

	ret = gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
	if(ret < 0)
	{
		// Our destructor does not run when the constructor throws,
		// so free the credentials here.
		gnutls_anon_free_server_credentials(anoncred);
		throw net6::error(net6::error::GNUTLS, ret);
	}

	gnutls_anon_set_server_dh_params(anoncred, params.cobj() );
}
// Frees the anonymous server credentials allocated by the constructor. The
// session itself is released afterwards by the base class destructor.
net6::tcp_encrypted_socket_server::~tcp_encrypted_socket_server()
{ gnutls_anon_free_server_credentials(anoncred); }
// syntax highlighted by Code2HTML, v. 0.9.1