/* net6 - Library providing IPv4/IPv6 network access * Copyright (C) 2005, 2006 Armin Burgmeier / 0x539 dev group * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free * Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #include #include "config.hpp" #include "encrypt.hpp" namespace { const unsigned int DH_BITS = 1024; net6::gnutls_session_t create_session(net6::gnutls_connection_end_t end) { net6::gnutls_session_t session; gnutls_init(&session, end); return session; } #ifdef WIN32 // Required to turn WSA error codes into errno. ssize_t net6_win32_send_func(net6::gnutls_transport_ptr_t ptr, const void* data, size_t size) { ssize_t ret = ::send( reinterpret_cast(ptr), static_cast(data), size, 0 ); int error = WSAGetLastError(); if(error == WSAEWOULDBLOCK) errno = EAGAIN; if(error == WSAEINTR) errno = EINTR; // Ensures that a second call to WSAGetLastError // has the same result WSASetLastError(error); return ret; } ssize_t net6_win32_recv_func(net6::gnutls_transport_ptr_t ptr, void* data, size_t size) { ssize_t ret = ::recv( reinterpret_cast(ptr), static_cast(data), size, 0 ); int error = WSAGetLastError(); if(error == WSAEWOULDBLOCK) errno = EAGAIN; if(error == WSAEINTR) errno = EINTR; // Ensures that a second call to WSAGetLastError // has the same result WSASetLastError(error); return ret; } #else // Send data with MSG_NOSIGNAL ssize_t net6_unix_send_func(net6::gnutls_transport_ptr_t ptr, const void* data, size_t size) { // TODO: How to properly get the fd from ptr? return send( static_cast(reinterpret_cast(ptr)), data, size, #ifdef HAVE_MSG_NOSIGNAL // Linux MSG_NOSIGNAL #else // Plain BSD 0 #endif ); } #endif typedef net6::tcp_encrypted_socket_base::size_type io_size_type; typedef net6::tcp_encrypted_socket_base::handshake_state io_handshake_state; template< typename buffer_type, ssize_t(*func)(net6::gnutls_session_t, buffer_type, size_t) > io_size_type io_impl(const net6::gnutls_session_t session, buffer_type buf, io_size_type len, io_handshake_state state) { ssize_t ret; switch(state) { case net6::tcp_encrypted_socket_base::DEFAULT: throw std::logic_error( "net6::encrypt.cpp:io_impl:\n" "Handshake not yet performed" ); break; case net6::tcp_encrypted_socket_base::HANDSHAKING: throw std::logic_error( "net6::encrypt.cpp:io_impl:\n" "IO tried while handshaking" ); break; case net6::tcp_encrypted_socket_base::HANDSHAKED: ret = func(session, buf, len); if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) func(session, NULL, 0); if(ret < 0) throw net6::error(net6::error::GNUTLS, ret); break; } return ret; } } net6::dh_params::dh_params(): params(NULL) { gnutls_dh_params_init(¶ms); gnutls_dh_params_generate2(params, DH_BITS); } net6::dh_params::dh_params(gnutls_dh_params_t initial_params): params(initial_params) { } net6::dh_params::~dh_params() { gnutls_dh_params_deinit(params); } net6::gnutls_dh_params_t net6::dh_params::cobj() { return params; } const net6::gnutls_dh_params_t net6::dh_params::cobj() const { return params; } net6::tcp_encrypted_socket_base:: tcp_encrypted_socket_base(socket_type cobj, gnutls_session_t sess): tcp_client_socket(cobj), session(sess), state(DEFAULT) { const int kx_prio[] = { GNUTLS_KX_ANON_DH, 0 }; gnutls_set_default_priority(session); gnutls_kx_set_priority(session, kx_prio); gnutls_transport_set_ptr( session, reinterpret_cast(cobj) ); #ifdef WIN32 gnutls_transport_set_pull_function( session, net6_win32_recv_func ); gnutls_transport_set_push_function( session, net6_win32_send_func ); #else gnutls_transport_set_push_function( session, net6_unix_send_func ); #endif gnutls_transport_set_lowat(session, 0); } net6::tcp_encrypted_socket_base::~tcp_encrypted_socket_base() { gnutls_bye(session, GNUTLS_SHUT_WR); gnutls_deinit(session); } bool net6::tcp_encrypted_socket_base::handshake() { if(state == HANDSHAKED) { throw std::logic_error( "net6::tcp_encrypted_socket_base::handshake:\n" "Handshake has already been performed" ); } if(state == DEFAULT) { #ifdef WIN32 u_long iMode = 1; if(ioctlsocket(cobj(), FIONBIO, &iMode) == SOCKET_ERROR) throw net6::error(net6::error::SYSTEM); // TODO: How to find out whether the socket is in blocking // mode? was_blocking = false; #else // Make socket nonblocking to allow to call handshake // multiple times int flags = fcntl(cobj(), F_GETFL); if(fcntl(cobj(), F_SETFL, flags | O_NONBLOCK) == -1) throw net6::error(net6::error::SYSTEM); was_blocking = ((flags & O_NONBLOCK) == 0); #endif state = HANDSHAKING; } int ret = gnutls_handshake(session); if(ret == 0) { if(was_blocking) { // Remove nonblocking state for further handling, // so the socket behaves like a nonencrypted tcp // client socket. #ifdef WIN32 u_long iMode = 0; if(ioctlsocket(cobj(), FIONBIO, &iMode) == SOCKET_ERROR) throw net6::error(net6::error::SYSTEM); #else int flags = fcntl(cobj(), F_GETFL); if(fcntl(cobj(), F_SETFL, flags & ~O_NONBLOCK) == -1) throw net6::error(net6::error::SYSTEM); #endif } state = HANDSHAKED; return true; } if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) return false; throw net6::error(net6::error::GNUTLS, ret); } bool net6::tcp_encrypted_socket_base::get_dir() const { return gnutls_record_get_direction(session) == 1; } net6::tcp_encrypted_socket_base::size_type net6::tcp_encrypted_socket_base::get_pending() const { return gnutls_record_check_pending(session); } net6::tcp_encrypted_socket_base::size_type net6::tcp_encrypted_socket_base::send(const void* buf, size_type len) const { return ::io_impl( session, buf, len, state ); } net6::tcp_encrypted_socket_base::size_type net6::tcp_encrypted_socket_base::recv(void* buf, size_type len) const { return ::io_impl( session, buf, len, state ); } net6::tcp_encrypted_socket_client:: tcp_encrypted_socket_client(tcp_client_socket& sock): tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_CLIENT) ) { sock.invalidate(); gnutls_anon_allocate_client_credentials(&anoncred); gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred); gnutls_dh_set_prime_bits(session, DH_BITS); } net6::tcp_encrypted_socket_client::~tcp_encrypted_socket_client() { gnutls_anon_free_client_credentials(anoncred); } net6::tcp_encrypted_socket_server:: tcp_encrypted_socket_server(tcp_client_socket& sock): tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_SERVER) ), own_params(new dh_params) { sock.invalidate(); gnutls_anon_allocate_server_credentials(&anoncred); gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred); gnutls_anon_set_server_dh_params(anoncred, own_params->cobj()); } net6::tcp_encrypted_socket_server:: tcp_encrypted_socket_server(tcp_client_socket& sock, dh_params& params): tcp_encrypted_socket_base(sock.cobj(), create_session(GNUTLS_SERVER) ) { sock.invalidate(); gnutls_anon_allocate_server_credentials(&anoncred); gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred); gnutls_anon_set_server_dh_params(anoncred, params.cobj() ); } net6::tcp_encrypted_socket_server::~tcp_encrypted_socket_server() { gnutls_anon_free_server_credentials(anoncred); }