/* Web Polygraph       http://www.web-polygraph.org/
 * (C) 2003-2006 The Measurement Factory
 * Licensed under the Apache License, Version 2.0 */

#include "xstd/xstd.h"

#if OPENSSL_ENABLED
#include <openssl/err.h>
#include <openssl/rand.h>
#endif

#include "xstd/Assert.h"
#include "xstd/String.h"
#include "xstd/Ssl.h"

int Ssl::TheLevel = 0;

// XXX: we should set the [SSL] error when SSL library is not found
// XXX: we should set the global error to an SSL error when an SSL call fails

/* SslCtx class */

SslCtx::SslCtx(SslProtocol protocol, const String &cipher): theCtx(0) {
#if OPENSSL_ENABLED
	SSL_METHOD *method;
	switch(protocol) {
		case SSLv2:
			method = ::SSLv2_method();
			break;
		case TLSv1:
			method = ::TLSv1_method();
			break;
		case SSLv3:
			method = ::SSLv3_method();
			break;
		case SSLv23:
			method = ::SSLv23_method();
			break;
		default:
			method = ::SSLv23_method();
			Should(false);
	}
	theCtx = ::SSL_CTX_new(method);
	Must(::SSL_CTX_set_cipher_list(theCtx, cipher.cstr()));
#endif
}

SslCtx::SslCtx(const SslCtx &anSslCtx) {
	Assert(false);
}

SslCtx::~SslCtx() {
#if OPENSSL_ENABLED
	if (theCtx)
		::SSL_CTX_free(theCtx);
#endif
}

SslCtx &SslCtx::operator =(const SslCtx &anSslCtx) {
	Assert(false);
	return *this;
}

bool SslCtx::useCertificateChainFile(const String &fname) {
#if OPENSSL_ENABLED
	return ::SSL_CTX_use_certificate_chain_file(theCtx, fname.cstr()) > 0;
#endif
	return false;
}

bool SslCtx::usePrivateKeyFile(const String &fname) {
#if OPENSSL_ENABLED
	return ::SSL_CTX_use_PrivateKey_file(theCtx, fname.cstr(), SSL_FILETYPE_PEM) > 0;
#endif
	return false;
}

bool SslCtx::checkPrivateKey() {
#if OPENSSL_ENABLED
	return ::SSL_CTX_check_private_key(theCtx) > 0;
#endif
	return false;
}

void SslCtx::setDefaultPasswdCb(pem_password_cb *cb) {
#if OPENSSL_ENABLED
	::SSL_CTX_set_default_passwd_cb(theCtx, cb);
#endif
}

bool SslCtx::loadVerifyLocations(const String &fname, const String &dirName) {
#if OPENSSL_ENABLED
	const char *dname = dirName.len() ? dirName.cstr() : 0;
	return ::SSL_CTX_load_verify_locations(theCtx, fname.cstr(), dname) > 0;
#endif
	Assert(sizeof(fname) && sizeof(dirName));
	return false;
}

long SslCtx::sessionCacheMode(long mode) {
#if OPENSSL_ENABLED
	return SSL_CTX_set_session_cache_mode(theCtx, mode);
#endif
	Assert(sizeof(mode));
	return 0;
}

long SslCtx::sessionCacheSize(long count) {
#if OPENSSL_ENABLED
	return SSL_CTX_sess_set_cache_size(theCtx, count);
#endif
	Assert(sizeof(count));
	return 0;
}

bool SslCtx::sessionId(const String &id) {
#if OPENSSL_ENABLED
	return SSL_CTX_set_session_id_context(theCtx, (unsigned char*)id.data(), id.len()) > 0;
#endif
	Assert(sizeof(id));
	return false;
}

Ssl *SslCtx::makeConnection() const {
	return new Ssl(theCtx);
}

void SslCtx::setVerify(int mode) const {
#if OPENSSL_ENABLED
	return ::SSL_CTX_set_verify(theCtx, mode, 0);
#endif
}


/* Ssl Class */

int Ssl::Level() {
	return TheLevel;
}

Ssl::Ssl(const SSL_CTX *ctx): theConn(0) {
#if OPENSSL_ENABLED
	// cast to non-const context because OpenSSL does not use "const"
	theConn = ::SSL_new((SSL_CTX*)ctx);
	if (Should(theConn))
		TheLevel++;
#else
	Assert(sizeof(ctx));
#endif
}

Ssl::Ssl(const Ssl &anSsl) {
	Assert(false);
}

Ssl::~Ssl() {
#if OPENSSL_ENABLED
	if (Should(theConn)) {
		::SSL_free(theConn);
		TheLevel--;
	}
#endif
}

Ssl &Ssl::operator =(const Ssl &anSsl) {
	Assert(false);
	return *this;
}

bool Ssl::shutdown(int &res) {
#if OPENSSL_ENABLED
	res = ::SSL_shutdown(theConn);
	return res > 0;
#endif
	res = -1;
	return false;
}

bool Ssl::setFd(int fd) {
#if OPENSSL_ENABLED
	return ::SSL_set_fd(theConn, fd) > 0;
#endif
	return false && !sizeof(fd);
}

void Ssl::playRole(int role) {
	switch (role) {
		case rlClient:
			playClientRole();
			break;
		case rlServer:
			playServerRole();
			break;
		default:
			Should(false);
	}
}

void Ssl::playClientRole() {
#if OPENSSL_ENABLED
	::SSL_set_connect_state(theConn);
#endif
}

void Ssl::playServerRole() {
#if OPENSSL_ENABLED
	::SSL_set_accept_state(theConn);
#endif
}

bool Ssl::enablePartialWrite() {
#if OPENSSL_ENABLED
#	ifdef SSL_MODE_ENABLE_PARTIAL_WRITE
		return addMode(SSL_MODE_ENABLE_PARTIAL_WRITE);
#	else
		return false; // unsupported mode
#	endif
#endif
	return false;
}

bool Ssl::enableAutoRetry() {
#if OPENSSL_ENABLED
#	ifdef SSL_MODE_AUTO_RETRY
		return addMode(SSL_MODE_AUTO_RETRY);
#	else
		return false; // unsupported mode
#	endif
#endif
	return false;
}

bool Ssl::acceptMovingWriteBuffer() {
#if OPENSSL_ENABLED
#	ifdef SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER
		return addMode(SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
#	else
		return false; // unsupported mode
#	endif
#endif
	return false;
}

bool Ssl::resumeSession(SslSession *session) {
#if OPENSSL_ENABLED
	return ::SSL_set_session(theConn, session->raw()) > 0;
#endif
	return false;
}

bool Ssl::connect(int &res) {
#if OPENSSL_ENABLED
	res = ::SSL_connect(theConn);
	return res > 0;
#endif
	res = -1;
	return false;
}

bool Ssl::accept(int &res) {
#if OPENSSL_ENABLED
	res = ::SSL_accept(theConn);
	return res > 0;
#endif
	res = -1;
	return false;
}

Size Ssl::read(char *buf, Size sz) {
#if OPENSSL_ENABLED
	return ::SSL_read(theConn, buf, sz);
#endif
	Assert(sizeof(buf) && sizeof(sz));
	return -1;
}

Size Ssl::write(const char *buf, Size sz) {
#if OPENSSL_ENABLED
	return ::SSL_write(theConn, buf, sz);
#endif
	Assert(sizeof(buf) && sizeof(sz));
	return -1;
}

bool Ssl::addMode(long modeBit) {
#if OPENSSL_ENABLED
	const long newMode = SSL_set_mode(theConn, modeBit);
	return (newMode & modeBit) != 0;
#endif
	Assert(sizeof(modeBit));
	return false;
}

bool Ssl::dataPending() const {
#if OPENSSL_ENABLED
	return ::SSL_pending((SSL*)theConn);
#endif
	return false;
}

bool Ssl::reusedSession() const {
#if OPENSSL_ENABLED
	return ::SSL_session_reused((SSL*)theConn) > 0;
#endif
	return false;
}

SslSession *Ssl::refCountedSession() const {
#if OPENSSL_ENABLED
	SSL_SESSION *session = SSL_get1_session((SSL*)theConn);
	return session ? new SslSession(session) : 0;
#endif
	return 0;
}

const char *Ssl::getCipher() {
#if OPENSSL_ENABLED
	return ::SSL_get_cipher(theConn);
#else
	return "You don't have libssl";
#endif
}

int Ssl::getError(int e) {
#if OPENSSL_ENABLED
	return ::SSL_get_error(theConn, e);
#endif
	Assert(sizeof(e));
	return -1;
}

const char *Ssl::getErrorString(int e) {
#if OPENSSL_ENABLED
	switch(::SSL_get_error(theConn, e)) {
		case SSL_ERROR_NONE:
			return "SSL_ERROR_NONE";
		case SSL_ERROR_ZERO_RETURN:
			return "SSL_ERROR_ZERO_RETURN";
		case SSL_ERROR_WANT_READ:
			return "SSL_ERROR_WANT_READ";
		case SSL_ERROR_WANT_WRITE:
			return "SSL_ERROR_WANT_WRITE";
		case SSL_ERROR_WANT_X509_LOOKUP:
			return "SSL_ERROR_WANT_X509_LOOKUP";
		case SSL_ERROR_SYSCALL:
			return "SSL_ERROR_SYSCALL";
		case SSL_ERROR_SSL:
			return "SSL_ERROR_SSL";
		default:
			return "UNKNOWN";
	}
#endif
	return "NO_LIBSSL";
}


/* SslSession */

SslSession::SslSession(SSL_SESSION *aSession): theSession(aSession) {
}

SslSession::~SslSession() {
#if OPENSSL_ENABLED
	if (theSession)
		SSL_SESSION_free(theSession);
#endif
}

SSL_SESSION *SslSession::raw() {
	return theSession;
}



/* SslMisc class : MISC functions, unlreated to SSL_CTX or SSL */

const char *SslMisc::ErrErrorString(unsigned long e) {
#if OPENSSL_ENABLED
	return ::ERR_error_string(e, NULL);
#else
	return "You do not have libssl";
#endif
}

unsigned long SslMisc::ErrGetErrorLine(const char **fname, int *line) {
#if OPENSSL_ENABLED
	return ::ERR_get_error_line(fname, line);
#else
	return 0;
#endif
}

void SslMisc::LibraryInit() {
#if OPENSSL_ENABLED
	::SSL_load_error_strings(); // optional, uses extra memory, but faster?
	(void)::SSL_library_init();
#endif
}

void SslMisc::SeedRng(double seed) {
#if OPENSSL_ENABLED
#	if HAVE_RAND_STATUS
		if (RAND_status() > 0)
			return; // enough random data (e.g., library used /dev/urandom)
#	endif
	RAND_seed(&seed, sizeof(seed));
#	if HAVE_RAND_STATUS
		Must(RAND_status());
#	endif
#endif
}


syntax highlighted by Code2HTML, v. 0.9.1