/** \file TcpSocket.cpp ** \date 2004-02-13 ** \author grymse@alhem.net **/ /* Copyright (C) 2004-2007 Anders Hedstrom This library is made available under the terms of the GNU GPL. If you would like to use this library in a closed-source application, a separate license agreement is available. For information about the closed-source license agreement for the C++ sockets library, please visit http://www.alhem.net/Sockets/license.html and/or email license@alhem.net. This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. */ #ifdef _WIN32 #pragma warning(disable:4786) #include #else #include #endif #include "ISocketHandler.h" #include #include #include #ifdef HAVE_OPENSSL #include #include #endif #include "TcpSocket.h" #include "Utility.h" #include "Ipv4Address.h" #include "Ipv6Address.h" #include "Mutex.h" #include "Uid.h" #include "IFile.h" #ifdef SOCKETS_NAMESPACE namespace SOCKETS_NAMESPACE { #endif #ifdef _DEBUG #define DEB(x) x #else #define DEB(x) #endif // statics #ifdef HAVE_OPENSSL SSLInitializer TcpSocket::m_ssl_init; #endif // thanks, q #ifdef _WIN32 #pragma warning(disable:4355) #endif TcpSocket::TcpSocket(ISocketHandler& h) : Socket(h) ,ibuf(*this, TCP_BUFSIZE_READ) ,obuf(*this, 32768) ,m_b_input_buffer_disabled(false) ,m_bytes_sent(0) ,m_bytes_received(0) ,m_skip_c(false) #ifdef SOCKETS_DYNAMIC_TEMP ,m_buf(new char[TCP_BUFSIZE_READ + 1]) #endif #ifdef HAVE_OPENSSL ,m_ssl_ctx(NULL) ,m_ssl(NULL) ,m_sbio(NULL) #endif #ifdef ENABLE_SOCKS4 ,m_socks4_state(0) #endif #ifdef ENABLE_RESOLVER ,m_resolver_id(0) #endif #ifdef ENABLE_RECONNECT ,m_b_reconnect(false) ,m_b_is_reconnect(false) #endif { } #ifdef _WIN32 #pragma warning(default:4355) #endif #ifdef _WIN32 #pragma warning(disable:4355) #endif TcpSocket::TcpSocket(ISocketHandler& h,size_t isize,size_t osize) : Socket(h) ,ibuf(*this, isize) ,obuf(*this, osize) ,m_b_input_buffer_disabled(false) ,m_bytes_sent(0) ,m_bytes_received(0) ,m_skip_c(false) #ifdef SOCKETS_DYNAMIC_TEMP ,m_buf(new char[TCP_BUFSIZE_READ + 1]) #endif #ifdef HAVE_OPENSSL ,m_ssl_ctx(NULL) ,m_ssl(NULL) ,m_sbio(NULL) #endif #ifdef ENABLE_SOCKS4 ,m_socks4_state(0) #endif #ifdef ENABLE_RESOLVER ,m_resolver_id(0) #endif #ifdef ENABLE_RECONNECT ,m_b_reconnect(false) ,m_b_is_reconnect(false) #endif { } #ifdef _WIN32 #pragma warning(default:4355) #endif TcpSocket::~TcpSocket() { if (!m_mes.empty()) { Handler().LogError(this, "TcpSocket destructor", 0, "Output buffer not empty", LOG_LEVEL_WARNING); } while (!m_mes.empty()) { MES *p = m_mes.front(); m_mes.pop_front(); delete p; } #ifdef SOCKETS_DYNAMIC_TEMP delete[] m_buf; #endif #ifdef HAVE_OPENSSL if (m_ssl) { SSL_free(m_ssl); } #endif } bool TcpSocket::Open(ipaddr_t ip,port_t port,bool skip_socks) { Ipv4Address ad(ip, port); Ipv4Address local; return Open(ad, local, skip_socks); } #ifdef ENABLE_IPV6 #ifdef IPPROTO_IPV6 bool TcpSocket::Open(in6_addr ip,port_t port,bool skip_socks) { Ipv6Address ad(ip, port); return Open(ad, skip_socks); } #endif #endif bool TcpSocket::Open(SocketAddress& ad,bool skip_socks) { Ipv4Address bind_ad("0.0.0.0", 0); return Open(ad, bind_ad, skip_socks); } bool TcpSocket::Open(SocketAddress& ad,SocketAddress& bind_ad,bool skip_socks) { if (!ad.IsValid()) { Handler().LogError(this, "Open", 0, "Invalid SocketAddress", LOG_LEVEL_FATAL); SetCloseAndDelete(); return false; } if (Handler().GetCount() >= FD_SETSIZE) { Handler().LogError(this, "Open", 0, "no space left in fd_set", LOG_LEVEL_FATAL); SetCloseAndDelete(); return false; } SetConnecting(false); #ifdef ENABLE_SOCKS4 SetSocks4(false); #endif // check for pooling #ifdef ENABLE_POOL if (Handler().PoolEnabled()) { ISocketHandler::PoolSocket *pools = Handler().FindConnection(SOCK_STREAM, "tcp", ad); if (pools) { CopyConnection( pools ); delete pools; SetIsClient(); SetCallOnConnect(); // ISocketHandler must call OnConnect Handler().LogError(this, "SetCallOnConnect", 0, "Found pooled connection", LOG_LEVEL_INFO); return true; } } #endif // if not, create new connection SOCKET s = CreateSocket(ad.GetFamily(), SOCK_STREAM, "tcp"); if (s == INVALID_SOCKET) { return false; } // socket must be nonblocking for async connect if (!SetNonblocking(true, s)) { SetCloseAndDelete(); closesocket(s); return false; } #ifdef ENABLE_POOL SetIsClient(); // client because we connect #endif SetClientRemoteAddress(ad); int n = 0; if (bind_ad.GetPort() != 0) { bind(s, bind_ad, bind_ad); } #ifdef ENABLE_SOCKS4 if (!skip_socks && GetSocks4Host() && GetSocks4Port()) { Ipv4Address sa(GetSocks4Host(), GetSocks4Port()); { std::string sockshost; Utility::l2ip(GetSocks4Host(), sockshost); Handler().LogError(this, "Open", 0, "Connecting to socks4 server @ " + sockshost + ":" + Utility::l2string(GetSocks4Port()), LOG_LEVEL_INFO); } SetSocks4(); n = connect(s, sa, sa); SetRemoteAddress(sa); } else #endif { n = connect(s, ad, ad); SetRemoteAddress(ad); } if (n == -1) { // check error code that means a connect is in progress #ifdef _WIN32 if (Errno == WSAEWOULDBLOCK) #else if (Errno == EINPROGRESS) #endif { Attach(s); SetConnecting( true ); // this flag will control fd_set's } else #ifdef ENABLE_SOCKS4 if (Socks4() && Handler().Socks4TryDirect() ) // retry { closesocket(s); return Open(ad, true); } else #endif #ifdef ENABLE_RECONNECT if (Reconnect()) { Handler().LogError(this, "connect: failed, reconnect pending", Errno, StrError(Errno), LOG_LEVEL_INFO); Attach(s); SetConnecting( true ); // this flag will control fd_set's } else #endif { Handler().LogError(this, "connect: failed", Errno, StrError(Errno), LOG_LEVEL_FATAL); SetCloseAndDelete(); closesocket(s); return false; } } else { Attach(s); SetCallOnConnect(); // ISocketHandler must call OnConnect } // 'true' means connected or connecting(not yet connected) // 'false' means something failed return true; //!Connecting(); } bool TcpSocket::Open(const std::string &host,port_t port) { #ifdef ENABLE_IPV6 #ifdef IPPROTO_IPV6 if (IsIpv6()) { #ifdef ENABLE_RESOLVER if (!Handler().ResolvedEnabled() || Utility::isipv4(host) ) { #endif in6_addr a; if (!Utility::u2ip(host, a)) { SetCloseAndDelete(); return false; } Ipv6Address ad(a, port); Ipv6Address local; return Open(ad, local); #ifdef ENABLE_RESOLVER } m_resolver_id = Resolve6(host, port); return true; #endif } #endif #endif #ifdef ENABLE_RESOLVER if (!Handler().ResolverEnabled() || Utility::isipv4(host) ) { #endif ipaddr_t l; if (!Utility::u2ip(host,l)) { SetCloseAndDelete(); return false; } Ipv4Address ad(l, port); Ipv4Address local; return Open(ad, local); #ifdef ENABLE_RESOLVER } // resolve using async resolver thread m_resolver_id = Resolve(host, port); return true; #endif } #ifdef ENABLE_RESOLVER void TcpSocket::OnResolved(int id,ipaddr_t a,port_t port) { if (id == m_resolver_id) { if (a && port) { Ipv4Address ad(a, port); Ipv4Address local; if (Open(ad, local)) { if (!Handler().Valid(this)) { Handler().Add(this); } } } else { Handler().LogError(this, "OnResolved", 0, "Resolver failed", LOG_LEVEL_FATAL); SetCloseAndDelete(); } } else { Handler().LogError(this, "OnResolved", id, "Resolver returned wrong job id", LOG_LEVEL_FATAL); SetCloseAndDelete(); } } #ifdef ENABLE_IPV6 void TcpSocket::OnResolved(int id,in6_addr& a,port_t port) { if (id == m_resolver_id) { Ipv6Address ad(a, port); if (ad.IsValid()) { Ipv6Address local; if (Open(ad, local)) { if (!Handler().Valid(this)) { Handler().Add(this); } } } } else { Handler().LogError(this, "OnResolved", id, "Resolver returned wrong job id", LOG_LEVEL_FATAL); SetCloseAndDelete(); } } #endif #endif void TcpSocket::OnRead() { int n = 0; #ifdef SOCKETS_DYNAMIC_TEMP char *buf = m_buf; #else char buf[TCP_BUFSIZE_READ]; #endif #ifdef HAVE_OPENSSL if (IsSSL()) { if (!Ready()) return; n = SSL_read(m_ssl, buf, TCP_BUFSIZE_READ); if (n == -1) { n = SSL_get_error(m_ssl, n); switch (n) { case SSL_ERROR_NONE: case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: break; case SSL_ERROR_ZERO_RETURN: DEB( fprintf(stderr, "SSL_read() returns zero - closing socket\n");) SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif break; default: DEB( fprintf(stderr, "SSL read problem, errcode = %d\n",n);) SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif } return; } else if (!n) { Handler().LogError(this, "SSL_read", 0, "read returns 0", LOG_LEVEL_FATAL); SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif SetShutdown(SHUT_WR); return; } else if (n > 0 && n <= TCP_BUFSIZE_READ) { m_bytes_received += n; if (GetTrafficMonitor()) { GetTrafficMonitor() -> fwrite(buf, 1, n); } if (!m_b_input_buffer_disabled && !ibuf.Write(buf,n)) { Handler().LogError(this, "OnRead(ssl)", 0, "ibuf overflow", LOG_LEVEL_WARNING); } } else { Handler().LogError(this, "OnRead(ssl)", n, "abnormal value from SSL_read", LOG_LEVEL_ERROR); } } else #endif // HAVE_OPENSSL { n = recv(GetSocket(), buf, TCP_BUFSIZE_READ, MSG_NOSIGNAL); if (n == -1) { Handler().LogError(this, "read", Errno, StrError(Errno), LOG_LEVEL_FATAL); SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif return; } else if (!n) { SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif SetShutdown(SHUT_WR); return; } else if (n > 0 && n <= TCP_BUFSIZE_READ) { m_bytes_received += n; if (GetTrafficMonitor()) { GetTrafficMonitor() -> fwrite(buf, 1, n); } if (!m_b_input_buffer_disabled && !ibuf.Write(buf,n)) { Handler().LogError(this, "OnRead", 0, "ibuf overflow", LOG_LEVEL_WARNING); } } else { Handler().LogError(this, "OnRead", n, "abnormal value from recv", LOG_LEVEL_ERROR); } } // unbuffered if (n > 0 && n <= TCP_BUFSIZE_READ) { if (LineProtocol()) { size_t x = 0; buf[n] = 0; int i = 0; if (m_skip_c && (buf[i] == 13 || buf[i] == 10) && buf[i] != m_c) { m_skip_c = false; i++; } for (; i < n && LineProtocol(); i++) { while ((buf[i] == 13 || buf[i] == 10) && LineProtocol()) { char c = buf[i]; buf[i] = 0; if (buf[x]) { m_line += (buf + x); } OnLine( m_line ); i++; m_skip_c = true; m_c = c; if (i < n && (buf[i] == 13 || buf[i] == 10) && buf[i] != c) { m_skip_c = false; i++; } x = i; m_line = ""; } if (!LineProtocol()) { break; } } if (!LineProtocol()) { if (i < n) { OnRawData(buf + i, n - i); } } else if (buf[x]) { m_line += (buf + x); } } else { OnRawData(buf, n); } } if (m_b_input_buffer_disabled) { return; } // further processing: socks4 #ifdef ENABLE_SOCKS4 if (Socks4()) { bool need_more = false; while (GetInputLength() && !need_more && !CloseAndDelete()) { need_more = OnSocks4Read(); } } #endif } void TcpSocket::OnWrite() { if (Connecting()) { if (CheckConnect()) { SetCallOnConnect(); return; } // failed #ifdef ENABLE_SOCKS4 if (Socks4()) { OnSocks4ConnectFailed(); return; } #endif if (GetConnectionRetry() == -1 || (GetConnectionRetry() && GetConnectionRetries() < GetConnectionRetry()) ) { // even though the connection failed at once, only retry after // the connection timeout. // should we even try to connect again, when CheckConnect returns // false it's because of a connection error - not a timeout... return; } SetConnecting(false); SetCloseAndDelete( true ); /// \todo state reason why connect failed OnConnectFailed(); return; } #ifdef HAVE_OPENSSL if (IsSSL()) { int n = SSL_write(m_ssl,obuf.GetStart(),(int)obuf.GetLength()); int errnr = n < 1 ? SSL_get_error(m_ssl, n) : 0; if (n == -1) { if ( errnr != SSL_ERROR_WANT_READ && errnr != SSL_ERROR_WANT_WRITE ) { SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif const char *errbuf = ERR_error_string(errnr, NULL); Handler().LogError(this, "OnWrite/SSL_write", errnr, errbuf, LOG_LEVEL_FATAL); } } else if (!n) { SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif DEB( const char *errbuf = ERR_error_string(errnr, NULL); fprintf(stderr, "SSL_write() returns 0: %d : %s\n",errnr, errbuf);) } else { m_bytes_sent += n; if (GetTrafficMonitor()) { GetTrafficMonitor() -> fwrite(obuf.GetStart(), 1, n); } obuf.Remove(n); // move data from m_mes to immediate output buffer while (obuf.Space() && !m_mes.empty()) { ucharp_v::iterator it = m_mes.begin(); MES *p = *it; if (obuf.Space() > p -> left()) { obuf.Write(p -> curbuf(),p -> left()); delete p; m_mes.erase(it); } else { size_t sz = obuf.Space(); obuf.Write(p -> curbuf(),sz); p -> ptr += sz; } } } // { bool br; bool bw; bool bx; Handler().Get(GetSocket(), br, bw, bx); if (obuf.GetLength() || !m_mes.empty()) Set(br, true); else Set(br, false); } return; } #endif // HAVE_OPENSSL int n = send(GetSocket(),obuf.GetStart(),(int)obuf.GetLength(),MSG_NOSIGNAL); /* When writing onto a connection-oriented socket that has been shut down (by the local or the remote end) SIGPIPE is sent to the writing process and EPIPE is returned. The signal is not sent when the write call specified the MSG_NOSIGNAL flag. */ if (n == -1) { // normal error codes: // WSAEWOULDBLOCK // EAGAIN or EWOULDBLOCK #ifdef _WIN32 if (Errno != WSAEWOULDBLOCK) #else if (Errno != EWOULDBLOCK) #endif { Handler().LogError(this, "send", Errno, StrError(Errno), LOG_LEVEL_FATAL); SetCloseAndDelete(true); SetFlushBeforeClose(false); #ifdef ENABLE_POOL SetLost(); #endif } } else if (!n) { // SetCloseAndDelete(true); } else { m_bytes_sent += n; if (GetTrafficMonitor()) { GetTrafficMonitor() -> fwrite(obuf.GetStart(), 1, n); } obuf.Remove(n); // move data from m_mes to immediate output buffer while (obuf.Space() && !m_mes.empty()) { ucharp_v::iterator it = m_mes.begin(); MES *p = *it; if (obuf.Space() > p -> left()) { obuf.Write(p -> curbuf(),p -> left()); delete p; m_mes.erase(it); } else { size_t sz = obuf.Space(); obuf.Write(p -> curbuf(),sz); p -> ptr += sz; } } } // set write flag if data still left to write { bool br; bool bw; bool bx; Handler().Get(GetSocket(), br, bw, bx); if (obuf.GetLength() || !m_mes.empty()) Set(br, true); else Set(br, false); } } void TcpSocket::Send(const std::string &str,int i) { SendBuf(str.c_str(),str.size(),i); } void TcpSocket::SendBuf(const char *buf,size_t len,int) { size_t n = obuf.GetLength(); if (!Ready() && !Connecting()) { Handler().LogError(this, "SendBuf", -1, "Attempt to write to a non-ready socket" ); // warning if (GetSocket() == INVALID_SOCKET) Handler().LogError(this, "SendBuf", 0, " * GetSocket() == INVALID_SOCKET", LOG_LEVEL_INFO); if (Connecting()) Handler().LogError(this, "SendBuf", 0, " * Connecting()", LOG_LEVEL_INFO); if (CloseAndDelete()) Handler().LogError(this, "SendBuf", 0, " * CloseAndDelete()", LOG_LEVEL_INFO); return; } if (!IsConnected()) { Handler().LogError(this, "SendBuf", -1, "Attempt to write to a non-connected socket, will be sent on connect" ); // warning } // size_t ptr = 0; if (m_mes.empty() && obuf.Space()) { if (len <= obuf.Space()) // entire block of data fits { ptr = len; obuf.Write(buf, len); } else // a part of the block fits { ptr = obuf.Space(); obuf.Write(buf, ptr); } } /// \todo check good value for max blocksize #define MAX_BLOCKSIZE 1024000 while (ptr < len) // data still left to buffer { size_t sz = len - ptr; // size of data block that has not been buffered if (sz > MAX_BLOCKSIZE) sz = MAX_BLOCKSIZE; m_mes.push_back(new MES(buf + ptr, sz)); ptr += sz; } // kick a real write if (!n && IsConnected()) { OnWrite(); } } void TcpSocket::OnLine(const std::string& ) { } #ifdef _WIN32 #pragma warning(disable:4355) #endif TcpSocket::TcpSocket(const TcpSocket& s) :Socket(s) ,ibuf(*this,0) ,obuf(*this,0) { } #ifdef _WIN32 #pragma warning(default:4355) #endif #ifdef ENABLE_SOCKS4 void TcpSocket::OnSocks4Connect() { char request[1000]; memset(request, 0, sizeof(request)); request[0] = 4; // socks v4 request[1] = 1; // command code: CONNECT { std::auto_ptr ad = GetClientRemoteAddress(); if (ad.get()) { struct sockaddr *p0 = (struct sockaddr *)*ad; struct sockaddr_in *p = (struct sockaddr_in *)p0; if (p -> sin_family == AF_INET) { memcpy(request + 2, &p -> sin_port, 2); // nwbo is ok here memcpy(request + 4, &p -> sin_addr, sizeof(struct in_addr)); } else { /// \todo warn } } else { /// \todo warn } } strcpy(request + 8, GetSocks4Userid().c_str()); size_t length = GetSocks4Userid().size() + 8 + 1; SendBuf(request, length); m_socks4_state = 0; } void TcpSocket::OnSocks4ConnectFailed() { Handler().LogError(this,"OnSocks4ConnectFailed",0,"connection to socks4 server failed, trying direct connection",LOG_LEVEL_WARNING); if (!Handler().Socks4TryDirect()) { SetConnecting(false); SetCloseAndDelete(); OnConnectFailed(); // just in case } else { SetRetryClientConnect(); } } bool TcpSocket::OnSocks4Read() { switch (m_socks4_state) { case 0: ibuf.Read(&m_socks4_vn, 1); m_socks4_state = 1; break; case 1: ibuf.Read(&m_socks4_cd, 1); m_socks4_state = 2; break; case 2: if (GetInputLength() > 1) { ibuf.Read( (char *)&m_socks4_dstport, 2); m_socks4_state = 3; } else { return true; } break; case 3: if (GetInputLength() > 3) { ibuf.Read( (char *)&m_socks4_dstip, 4); SetSocks4(false); switch (m_socks4_cd) { case 90: OnConnect(); Handler().LogError(this, "OnSocks4Read", 0, "Connection established", LOG_LEVEL_INFO); break; case 91: case 92: case 93: Handler().LogError(this,"OnSocks4Read",m_socks4_cd,"socks4 server reports connect failed",LOG_LEVEL_FATAL); SetConnecting(false); SetCloseAndDelete(); OnConnectFailed(); break; default: Handler().LogError(this,"OnSocks4Read",m_socks4_cd,"socks4 server unrecognized response",LOG_LEVEL_FATAL); SetCloseAndDelete(); break; } } else { return true; } break; } return false; } #endif void TcpSocket::Sendf(const char *format, ...) { va_list ap; va_start(ap, format); char slask[5000]; // vsprintf / vsnprintf temporary #ifdef _WIN32 vsprintf(slask, format, ap); #else vsnprintf(slask, 5000, format, ap); #endif va_end(ap); Send( slask ); } #ifdef HAVE_OPENSSL void TcpSocket::OnSSLConnect() { SetNonblocking(true); { if (m_ssl_ctx) { DEB( fprintf(stderr, "SSL Context already initialized - closing socket\n");) SetCloseAndDelete(true); return; } InitSSLClient(); } if (m_ssl_ctx) { /* Connect the SSL socket */ m_ssl = SSL_new(m_ssl_ctx); if (!m_ssl) { DEB( fprintf(stderr, " m_ssl is NULL\n");) SetCloseAndDelete(true); return; } SSL_set_mode(m_ssl, SSL_MODE_AUTO_RETRY); m_sbio = BIO_new_socket((int)GetSocket(), BIO_NOCLOSE); if (!m_sbio) { DEB( fprintf(stderr, " m_sbio is NULL\n");) SetCloseAndDelete(true); return; } SSL_set_bio(m_ssl, m_sbio, m_sbio); // if (!SSLNegotiate()) { SetSSLNegotiate(); } } else { SetCloseAndDelete(); } } void TcpSocket::OnSSLAccept() { SetNonblocking(true); { if (m_ssl_ctx) { DEB( fprintf(stderr, "SSL Context already initialized - closing socket\n");) SetCloseAndDelete(true); return; } InitSSLServer(); SetSSLServer(); } if (m_ssl_ctx) { m_ssl = SSL_new(m_ssl_ctx); if (!m_ssl) { DEB( fprintf(stderr, " m_ssl is NULL\n");) SetCloseAndDelete(true); return; } SSL_set_mode(m_ssl, SSL_MODE_AUTO_RETRY); m_sbio = BIO_new_socket((int)GetSocket(), BIO_NOCLOSE); if (!m_sbio) { DEB( fprintf(stderr, " m_sbio is NULL\n");) SetCloseAndDelete(true); return; } SSL_set_bio(m_ssl, m_sbio, m_sbio); // if (!SSLNegotiate()) { SetSSLNegotiate(); } } } bool TcpSocket::SSLNegotiate() { if (!IsSSLServer()) // client { int r = SSL_connect(m_ssl); if (r > 0) { SetSSLNegotiate(false); /// \todo: resurrect certificate check... client // CheckCertificateChain( "");//ServerHOST); SetNonblocking(false); // { SetConnected(); if (GetOutputLength()) { OnWrite(); } } #ifdef ENABLE_RECONNECT if (IsReconnect()) OnReconnect(); else #endif { OnConnect(); } Handler().LogError(this, "SSLNegotiate/SSL_connect", 0, "Connection established", LOG_LEVEL_INFO); return true; } else if (!r) { Handler().LogError(this, "SSLNegotiate/SSL_connect", 0, "Connection failed", LOG_LEVEL_INFO); SetSSLNegotiate(false); SetCloseAndDelete(); OnSSLConnectFailed(); } else { r = SSL_get_error(m_ssl, r); if (r != SSL_ERROR_WANT_READ && r != SSL_ERROR_WANT_WRITE) { Handler().LogError(this, "SSLNegotiate/SSL_connect", -1, "Connection failed", LOG_LEVEL_INFO); DEB( fprintf(stderr, "SSL_connect() failed - closing socket, return code: %d\n",r);) SetSSLNegotiate(false); SetCloseAndDelete(true); OnSSLConnectFailed(); } } } else // server { int r = SSL_accept(m_ssl); if (r > 0) { SetSSLNegotiate(false); /// \todo: resurrect certificate check... server // CheckCertificateChain( "");//ClientHOST); SetNonblocking(false); // { SetConnected(); if (GetOutputLength()) { OnWrite(); } } OnAccept(); Handler().LogError(this, "SSLNegotiate/SSL_accept", 0, "Connection established", LOG_LEVEL_INFO); return true; } else if (!r) { Handler().LogError(this, "SSLNegotiate/SSL_accept", 0, "Connection failed", LOG_LEVEL_INFO); SetSSLNegotiate(false); SetCloseAndDelete(); OnSSLAcceptFailed(); } else { r = SSL_get_error(m_ssl, r); if (r != SSL_ERROR_WANT_READ && r != SSL_ERROR_WANT_WRITE) { Handler().LogError(this, "SSLNegotiate/SSL_accept", -1, "Connection failed", LOG_LEVEL_INFO); DEB( fprintf(stderr, "SSL_accept() failed - closing socket, return code: %d\n",r);) SetSSLNegotiate(false); SetCloseAndDelete(true); OnSSLAcceptFailed(); } } } return false; } void TcpSocket::InitSSLClient() { InitializeContext("", SSLv23_method()); } void TcpSocket::InitSSLServer() { Handler().LogError(this, "InitSSLServer", 0, "You MUST implement your own InitSSLServer method", LOG_LEVEL_FATAL); SetCloseAndDelete(); } void TcpSocket::InitializeContext(const std::string& context, SSL_METHOD *meth_in) { /* Create our context*/ static std::map client_contexts; if (client_contexts.find(context) == client_contexts.end()) { SSL_METHOD *meth = meth_in ? meth_in : SSLv3_method(); m_ssl_ctx = client_contexts[context] = SSL_CTX_new(meth); SSL_CTX_set_mode(m_ssl_ctx, SSL_MODE_AUTO_RETRY); } else { m_ssl_ctx = client_contexts[context]; } } void TcpSocket::InitializeContext(const std::string& context,const std::string& keyfile,const std::string& password,SSL_METHOD *meth_in) { /* Create our context*/ static std::map server_contexts; if (server_contexts.find(context) == server_contexts.end()) { SSL_METHOD *meth = meth_in ? meth_in : SSLv3_method(); m_ssl_ctx = server_contexts[context] = SSL_CTX_new(meth); SSL_CTX_set_mode(m_ssl_ctx, SSL_MODE_AUTO_RETRY); // session id if (!context.empty()) SSL_CTX_set_session_id_context(m_ssl_ctx, (const unsigned char *)context.c_str(), (unsigned int)context.size()); else SSL_CTX_set_session_id_context(m_ssl_ctx, (const unsigned char *)"--empty--", 9); } else { m_ssl_ctx = server_contexts[context]; } /* Load our keys and certificates*/ if (!(SSL_CTX_use_certificate_file(m_ssl_ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) { Handler().LogError(this, "TcpSocket InitializeContext", 0, "Couldn't read certificate file " + keyfile, LOG_LEVEL_FATAL); } m_password = password; SSL_CTX_set_default_passwd_cb(m_ssl_ctx, SSL_password_cb); SSL_CTX_set_default_passwd_cb_userdata(m_ssl_ctx, this); if (!(SSL_CTX_use_PrivateKey_file(m_ssl_ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) { Handler().LogError(this, "TcpSocket InitializeContext", 0, "Couldn't read private key file " + keyfile, LOG_LEVEL_FATAL); } } int TcpSocket::SSL_password_cb(char *buf,int num,int rwflag,void *userdata) { Socket *p0 = static_cast(userdata); TcpSocket *p = dynamic_cast(p0); std::string pw = p ? p -> GetPassword() : ""; if ( (size_t)num < pw.size() + 1) { return 0; } strcpy(buf,pw.c_str()); return (int)pw.size(); } #endif // HAVE_OPENSSL int TcpSocket::Close() { #ifdef HAVE_OPENSSL if (IsSSL() && m_ssl) SSL_shutdown(m_ssl); if (m_ssl) { SSL_free(m_ssl); m_ssl = NULL; } #endif return Socket::Close(); } #ifdef HAVE_OPENSSL SSL_CTX *TcpSocket::GetSslContext() { if (!m_ssl_ctx) Handler().LogError(this, "GetSslContext", 0, "SSL Context is NULL; check InitSSLServer/InitSSLClient", LOG_LEVEL_WARNING); return m_ssl_ctx; } SSL *TcpSocket::GetSsl() { if (!m_ssl) Handler().LogError(this, "GetSsl", 0, "SSL is NULL; check InitSSLServer/InitSSLClient", LOG_LEVEL_WARNING); return m_ssl; } #endif #ifdef ENABLE_RECONNECT void TcpSocket::SetReconnect(bool x) { m_b_reconnect = x; } #endif void TcpSocket::OnRawData(const char *buf_in,size_t len) { } size_t TcpSocket::GetInputLength() { return ibuf.GetLength(); } size_t TcpSocket::GetOutputLength() { return obuf.GetLength(); } uint64_t TcpSocket::GetBytesReceived(bool clear) { uint64_t z = m_bytes_received; if (clear) m_bytes_received = 0; return z; } uint64_t TcpSocket::GetBytesSent(bool clear) { uint64_t z = m_bytes_sent; if (clear) m_bytes_sent = 0; return z; } #ifdef ENABLE_RECONNECT bool TcpSocket::Reconnect() { return m_b_reconnect; } void TcpSocket::SetIsReconnect(bool x) { m_b_is_reconnect = x; } bool TcpSocket::IsReconnect() { return m_b_is_reconnect; } #endif #ifdef HAVE_OPENSSL const std::string& TcpSocket::GetPassword() { return m_password; } #endif void TcpSocket::DisableInputBuffer(bool x) { m_b_input_buffer_disabled = x; } void TcpSocket::OnOptions(int family,int type,int protocol,SOCKET s) { DEB( fprintf(stderr, "Socket::OnOptions()\n");) /* Handler().LogError(this, "OnOptions", family, "Address Family", LOG_LEVEL_INFO); Handler().LogError(this, "OnOptions", type, "Type", LOG_LEVEL_INFO); Handler().LogError(this, "OnOptions", protocol, "Protocol", LOG_LEVEL_INFO); */ SetReuse(true); SetKeepalive(true); } void TcpSocket::SetLineProtocol(bool x) { Socket::SetLineProtocol(x); DisableInputBuffer(x); } #ifdef SOCKETS_NAMESPACE } #endif