/* Web Polygraph http://www.web-polygraph.org/
* (C) 2003-2006 The Measurement Factory
* Licensed under the Apache License, Version 2.0 */
#include "xstd/xstd.h"
#if OPENSSL_ENABLED
#include <openssl/err.h>
#include <openssl/rand.h>
#endif
#include "xstd/Assert.h"
#include "xstd/String.h"
#include "xstd/Ssl.h"
// Number of live Ssl objects holding an OpenSSL connection; the Ssl
// constructor increments it after a successful SSL_new() and the
// destructor decrements it after SSL_free().
int Ssl::TheLevel(0);
// XXX: we should set the [SSL] error when SSL library is not found
// XXX: we should set the global error to an SSL error when an SSL call fails
/* SslCtx class */
// Creates an OpenSSL context for the requested protocol and restricts it
// to the given cipher list. Unknown protocol values fall back to SSLv23
// after a Should() complaint. Without OpenSSL support, theCtx stays null.
SslCtx::SslCtx(SslProtocol protocol, const String &cipher): theCtx(0) {
#if OPENSSL_ENABLED
	SSL_METHOD *method;
	switch(protocol) {
		case SSLv2:
			method = ::SSLv2_method();
			break;
		case TLSv1:
			method = ::TLSv1_method();
			break;
		case SSLv3:
			method = ::SSLv3_method();
			break;
		case SSLv23:
			method = ::SSLv23_method();
			break;
		default:
			method = ::SSLv23_method();
			Should(false); // unexpected protocol value
			break;
	}
	theCtx = ::SSL_CTX_new(method);
	// SSL_CTX_new() returns NULL on failure; fail loudly here instead of
	// handing a NULL context to SSL_CTX_set_cipher_list() below
	Must(theCtx);
	Must(::SSL_CTX_set_cipher_list(theCtx, cipher.cstr()));
#endif
}
// Copying an SSL context is not supported; this must never be called.
// theCtx is still nulled so that the destructor cannot free an
// indeterminate pointer should Assert() ever return.
SslCtx::SslCtx(const SslCtx &): theCtx(0) {
	Assert(false);
}
// Releases the OpenSSL context, if one was created.
SslCtx::~SslCtx() {
#if OPENSSL_ENABLED
	if (!theCtx)
		return; // nothing was allocated
	::SSL_CTX_free(theCtx);
#endif
}
// Assigning SSL contexts is not supported; reaching this is a bug.
SslCtx &SslCtx::operator =(const SslCtx &) {
	Assert(false);
	return *this;
}
// Loads the certificate chain stored in fname into this context.
// Returns true on success; always false without OpenSSL support.
bool SslCtx::useCertificateChainFile(const String &fname) {
#if OPENSSL_ENABLED
	const int rc = ::SSL_CTX_use_certificate_chain_file(theCtx, fname.cstr());
	return rc > 0;
#else
	return false;
#endif
}
// Loads a PEM-encoded private key from fname into this context.
// Returns true on success; always false without OpenSSL support.
bool SslCtx::usePrivateKeyFile(const String &fname) {
#if OPENSSL_ENABLED
	const int rc = ::SSL_CTX_use_PrivateKey_file(theCtx, fname.cstr(), SSL_FILETYPE_PEM);
	return rc > 0;
#else
	return false;
#endif
}
// Asks OpenSSL to check the loaded private key against the loaded
// certificate. Returns true on success; always false without OpenSSL.
bool SslCtx::checkPrivateKey() {
#if OPENSSL_ENABLED
	const int rc = ::SSL_CTX_check_private_key(theCtx);
	return rc > 0;
#else
	return false;
#endif
}
// Installs cb as the context's default PEM pass phrase callback.
// Does nothing without OpenSSL support.
void SslCtx::setDefaultPasswdCb(pem_password_cb *cb) {
#if OPENSSL_ENABLED
	::SSL_CTX_set_default_passwd_cb(theCtx, cb);
#else
	(void)cb; // unused without OpenSSL
#endif
}
// Tells the context where to find trusted CA certificates: a CA file
// (fname) and/or a CA directory (dirName). An empty string means "not
// specified" for either argument, so callers may give only one of them.
// Returns true on success; always false without OpenSSL support.
bool SslCtx::loadVerifyLocations(const String &fname, const String &dirName) {
#if OPENSSL_ENABLED
	// OpenSSL expects NULL (not "") for an unspecified location; passing ""
	// makes it try, and fail, to load a file or directory named "".
	const char *const file = fname.len() ? fname.cstr() : 0;
	const char *const dname = dirName.len() ? dirName.cstr() : 0;
	return ::SSL_CTX_load_verify_locations(theCtx, file, dname) > 0;
#else
	Assert(sizeof(fname) && sizeof(dirName));
	return false;
#endif
}
// Sets the context's session cache mode and returns whatever
// SSL_CTX_set_session_cache_mode() returns (0 without OpenSSL).
long SslCtx::sessionCacheMode(long mode) {
#if OPENSSL_ENABLED
	const long result = SSL_CTX_set_session_cache_mode(theCtx, mode);
	return result;
#else
	Assert(sizeof(mode));
	return 0;
#endif
}
// Limits the context's session cache to count entries and returns
// whatever SSL_CTX_sess_set_cache_size() returns (0 without OpenSSL).
long SslCtx::sessionCacheSize(long count) {
#if OPENSSL_ENABLED
	const long result = SSL_CTX_sess_set_cache_size(theCtx, count);
	return result;
#else
	Assert(sizeof(count));
	return 0;
#endif
}
// Sets the session ID context used to scope session reuse to this
// context. Returns true on success; always false without OpenSSL.
bool SslCtx::sessionId(const String &id) {
#if OPENSSL_ENABLED
	const unsigned char *const sidCtx =
		reinterpret_cast<const unsigned char*>(id.data());
	return SSL_CTX_set_session_id_context(theCtx, sidCtx, id.len()) > 0;
#else
	Assert(sizeof(id));
	return false;
#endif
}
// Creates a new heap-allocated connection object bound to this context.
// The object is not retained here, so the caller is responsible for it.
Ssl *SslCtx::makeConnection() const {
	Ssl *const conn = new Ssl(theCtx);
	return conn;
}
// Sets the peer verification mode flags; no verify callback is installed.
// Does nothing without OpenSSL support.
void SslCtx::setVerify(int mode) const {
#if OPENSSL_ENABLED
	::SSL_CTX_set_verify(theCtx, mode, 0);
#endif
}
/* Ssl Class */
// Returns the current count of Ssl objects holding a live connection.
int Ssl::Level() {
	const int level = TheLevel;
	return level;
}
// Creates an OpenSSL connection from ctx and counts it in TheLevel if
// SSL_new() succeeded (Should() complains otherwise).
Ssl::Ssl(const SSL_CTX *ctx): theConn(0) {
#if OPENSSL_ENABLED
	// SSL_new() takes a non-const context, so drop the const qualifier
	SSL_CTX *const mutableCtx = const_cast<SSL_CTX*>(ctx);
	theConn = ::SSL_new(mutableCtx);
	if (Should(theConn))
		++TheLevel;
#else
	Assert(sizeof(ctx));
#endif
}
// Copying a connection is not supported; this must never be called.
// theConn is still nulled so that the destructor cannot SSL_free() an
// indeterminate pointer should Assert() ever return.
Ssl::Ssl(const Ssl &): theConn(0) {
	Assert(false);
}
// Frees the OpenSSL connection and removes it from TheLevel; Should()
// complains if there is no connection to free.
Ssl::~Ssl() {
#if OPENSSL_ENABLED
	if (!Should(theConn))
		return;
	::SSL_free(theConn);
	--TheLevel;
#endif
}
// Assigning connections is not supported; reaching this is a bug.
Ssl &Ssl::operator =(const Ssl &) {
	Assert(false);
	return *this;
}
// Calls SSL_shutdown(). res receives the raw result (-1 without OpenSSL);
// returns true only when that result is positive.
bool Ssl::shutdown(int &res) {
#if OPENSSL_ENABLED
	res = ::SSL_shutdown(theConn);
#else
	res = -1;
#endif
	return res > 0;
}
// Attaches socket descriptor fd to the connection.
// Returns true on success; always false without OpenSSL support.
bool Ssl::setFd(int fd) {
#if OPENSSL_ENABLED
	const int rc = ::SSL_set_fd(theConn, fd);
	return rc > 0;
#else
	(void)fd; // unused without OpenSSL
	return false;
#endif
}
// Switches the connection into client or server mode; any other role
// value triggers a Should() complaint and leaves the connection as is.
void Ssl::playRole(int role) {
	if (role == rlClient)
		playClientRole();
	else
	if (role == rlServer)
		playServerRole();
	else
		Should(false);
}
// Puts the connection into client (connecting) mode via
// SSL_set_connect_state(); no-op without OpenSSL support.
void Ssl::playClientRole() {
#if OPENSSL_ENABLED
	::SSL_set_connect_state(theConn);
#else
	// nothing to do
#endif
}
// Puts the connection into server (accepting) mode via
// SSL_set_accept_state(); no-op without OpenSSL support.
void Ssl::playServerRole() {
#if OPENSSL_ENABLED
	::SSL_set_accept_state(theConn);
#else
	// nothing to do
#endif
}
// Turns on SSL_MODE_ENABLE_PARTIAL_WRITE. Returns true if the mode bit is
// set afterwards; false if the mode or OpenSSL itself is unavailable.
bool Ssl::enablePartialWrite() {
#if OPENSSL_ENABLED && defined(SSL_MODE_ENABLE_PARTIAL_WRITE)
	return addMode(SSL_MODE_ENABLE_PARTIAL_WRITE);
#else
	return false; // unsupported mode or no OpenSSL
#endif
}
// Turns on SSL_MODE_AUTO_RETRY. Returns true if the mode bit is set
// afterwards; false if the mode or OpenSSL itself is unavailable.
bool Ssl::enableAutoRetry() {
#if OPENSSL_ENABLED && defined(SSL_MODE_AUTO_RETRY)
	return addMode(SSL_MODE_AUTO_RETRY);
#else
	return false; // unsupported mode or no OpenSSL
#endif
}
// Turns on SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER. Returns true if the mode
// bit is set afterwards; false if the mode or OpenSSL is unavailable.
bool Ssl::acceptMovingWriteBuffer() {
#if OPENSSL_ENABLED && defined(SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER)
	return addMode(SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
#else
	return false; // unsupported mode or no OpenSSL
#endif
}
// Offers a previously saved session for reuse on this connection.
// session must not be null (it is dereferenced when OpenSSL is enabled).
// Returns true on success; always false without OpenSSL support.
bool Ssl::resumeSession(SslSession *session) {
#if OPENSSL_ENABLED
	SSL_SESSION *const saved = session->raw();
	return ::SSL_set_session(theConn, saved) > 0;
#else
	return false;
#endif
}
// Runs (or continues) the client-side handshake. res receives the raw
// SSL_connect() result (-1 without OpenSSL); returns true when it is positive.
bool Ssl::connect(int &res) {
#if OPENSSL_ENABLED
	res = ::SSL_connect(theConn);
#else
	res = -1;
#endif
	return res > 0;
}
// Runs (or continues) the server-side handshake. res receives the raw
// SSL_accept() result (-1 without OpenSSL); returns true when it is positive.
bool Ssl::accept(int &res) {
#if OPENSSL_ENABLED
	res = ::SSL_accept(theConn);
#else
	res = -1;
#endif
	return res > 0;
}
// Reads up to sz bytes into buf; returns the raw SSL_read() result
// (-1 without OpenSSL support).
Size Ssl::read(char *buf, Size sz) {
#if OPENSSL_ENABLED
	const int got = ::SSL_read(theConn, buf, sz);
	return got;
#else
	Assert(sizeof(buf) && sizeof(sz));
	return -1;
#endif
}
// Writes up to sz bytes from buf; returns the raw SSL_write() result
// (-1 without OpenSSL support).
Size Ssl::write(const char *buf, Size sz) {
#if OPENSSL_ENABLED
	const int put = ::SSL_write(theConn, buf, sz);
	return put;
#else
	Assert(sizeof(buf) && sizeof(sz));
	return -1;
#endif
}
// Adds modeBit to the connection's mode; returns true if that bit is set
// in the mode bitmask SSL_set_mode() reports back.
bool Ssl::addMode(long modeBit) {
#if OPENSSL_ENABLED
	return (SSL_set_mode(theConn, modeBit) & modeBit) != 0;
#else
	Assert(sizeof(modeBit));
	return false;
#endif
}
// Returns true when SSL_pending() reports buffered, readable bytes;
// always false without OpenSSL support.
bool Ssl::dataPending() const {
#if OPENSSL_ENABLED
	const int buffered = ::SSL_pending(theConn);
	return buffered != 0;
#else
	return false;
#endif
}
// Returns true if the handshake reused a previous session;
// always false without OpenSSL support.
bool Ssl::reusedSession() const {
#if OPENSSL_ENABLED
	const long reused = ::SSL_session_reused(theConn);
	return reused > 0;
#else
	return false;
#endif
}
// Returns a new SslSession holding an extra reference to the current
// session (taken by SSL_get1_session() and released by ~SslSession), or 0
// when there is no session or no OpenSSL support.
SslSession *Ssl::refCountedSession() const {
#if OPENSSL_ENABLED
	if (SSL_SESSION *const session = SSL_get1_session(theConn))
		return new SslSession(session);
#endif
	return 0;
}
// Returns the name of the connection's current cipher, or a placeholder
// message when built without OpenSSL.
const char *Ssl::getCipher() {
#if OPENSSL_ENABLED
	const char *const name = ::SSL_get_cipher(theConn);
	return name;
#else
	return "You don't have libssl";
#endif
}
// Maps the return value e of a preceding SSL I/O call to an SSL_ERROR_*
// code via SSL_get_error(); returns -1 without OpenSSL support.
int Ssl::getError(int e) {
#if OPENSSL_ENABLED
	const int code = ::SSL_get_error(theConn, e);
	return code;
#else
	Assert(sizeof(e));
	return -1;
#endif
}
// Returns the symbolic name of the SSL_get_error() code for the return
// value e of a preceding SSL I/O call. Codes not listed below map to
// "UNKNOWN"; without OpenSSL support the result is "NO_LIBSSL".
const char *Ssl::getErrorString(int e) {
#if OPENSSL_ENABLED
	switch(::SSL_get_error(theConn, e)) {
		case SSL_ERROR_NONE:
			return "SSL_ERROR_NONE";
		case SSL_ERROR_ZERO_RETURN:
			return "SSL_ERROR_ZERO_RETURN";
		case SSL_ERROR_WANT_READ:
			return "SSL_ERROR_WANT_READ";
		case SSL_ERROR_WANT_WRITE:
			return "SSL_ERROR_WANT_WRITE";
		case SSL_ERROR_WANT_X509_LOOKUP:
			return "SSL_ERROR_WANT_X509_LOOKUP";
		case SSL_ERROR_SYSCALL:
			return "SSL_ERROR_SYSCALL";
		case SSL_ERROR_SSL:
			return "SSL_ERROR_SSL";
		// the codes below exist only in some OpenSSL versions
#	ifdef SSL_ERROR_WANT_CONNECT
		case SSL_ERROR_WANT_CONNECT:
			return "SSL_ERROR_WANT_CONNECT";
#	endif
#	ifdef SSL_ERROR_WANT_ACCEPT
		case SSL_ERROR_WANT_ACCEPT:
			return "SSL_ERROR_WANT_ACCEPT";
#	endif
#	ifdef SSL_ERROR_WANT_ASYNC
		case SSL_ERROR_WANT_ASYNC:
			return "SSL_ERROR_WANT_ASYNC";
#	endif
#	ifdef SSL_ERROR_WANT_ASYNC_JOB
		case SSL_ERROR_WANT_ASYNC_JOB:
			return "SSL_ERROR_WANT_ASYNC_JOB";
#	endif
#	ifdef SSL_ERROR_WANT_CLIENT_HELLO_CB
		case SSL_ERROR_WANT_CLIENT_HELLO_CB:
			return "SSL_ERROR_WANT_CLIENT_HELLO_CB";
#	endif
#	ifdef SSL_ERROR_WANT_RETRY_VERIFY
		case SSL_ERROR_WANT_RETRY_VERIFY:
			return "SSL_ERROR_WANT_RETRY_VERIFY";
#	endif
		default:
			return "UNKNOWN";
	}
#else
	return "NO_LIBSSL";
#endif
}
/* SslSession */
// Wraps aSession; the destructor releases it with SSL_SESSION_free(),
// so the caller hands over one reference to the session.
SslSession::SslSession(SSL_SESSION *aSession):
	theSession(aSession) {
}
// Releases the wrapped session reference, if any.
SslSession::~SslSession() {
#if OPENSSL_ENABLED
	if (!theSession)
		return;
	SSL_SESSION_free(theSession);
#endif
}
// Exposes the wrapped OpenSSL session; ownership stays with this object.
SSL_SESSION *SslSession::raw() {
	SSL_SESSION *const session = theSession;
	return session;
}
/* SslMisc class: miscellaneous functions unrelated to SSL_CTX or SSL */
// Converts OpenSSL error code e into a human-readable string. Because no
// buffer is supplied, ERR_error_string() returns its own internal buffer,
// which later calls may overwrite.
const char *SslMisc::ErrErrorString(unsigned long e) {
#if !OPENSSL_ENABLED
	return "You do not have libssl";
#else
	const char *const text = ::ERR_error_string(e, NULL);
	return text;
#endif
}
// Fetches the next error code from OpenSSL's error queue together with
// the source file name and line that reported it; returns 0 without OpenSSL.
unsigned long SslMisc::ErrGetErrorLine(const char **fname, int *line) {
#if !OPENSSL_ENABLED
	return 0;
#else
	const unsigned long code = ::ERR_get_error_line(fname, line);
	return code;
#endif
}
// Initializes the OpenSSL library: loads error strings, then calls
// SSL_library_init(). Does nothing without OpenSSL support.
void SslMisc::LibraryInit() {
#if OPENSSL_ENABLED
	// optional, uses extra memory, but faster?
	::SSL_load_error_strings();
	// the result is deliberately ignored
	static_cast<void>(::SSL_library_init());
#endif
}
// Feeds seed into OpenSSL's random number generator. When RAND_status()
// is available, seeding is skipped if the generator already has enough
// entropy, and the generator is required (Must) to be ready afterwards.
void SslMisc::SeedRng(double seed) {
#if OPENSSL_ENABLED && HAVE_RAND_STATUS
	if (RAND_status() > 0)
		return; // enough random data (e.g., library used /dev/urandom)
	RAND_seed(&seed, sizeof(seed));
	Must(RAND_status());
#elif OPENSSL_ENABLED
	RAND_seed(&seed, sizeof(seed));
#endif
}
// syntax highlighted by Code2HTML, v. 0.9.1