/* Web Polygraph       http://www.web-polygraph.org/
 * (C) 2003-2006 The Measurement Factory
 * Licensed under the Apache License, Version 2.0 */

#include "base/polygraph.h"

#include <stdlib.h>
#include <string>
#include "xstd/h/iomanip.h"

#include "xstd/Ssl.h"
#include "xstd/StrIdentifier.h"
#include "xstd/rndDistrs.h"
#include "base/RndPermut.h"
#include "pgl/SslWrapSym.h"
#include "runtime/LogComment.h"
#include "runtime/SslWrap.h"

// Scratch files shared by the server-side key/certificate generation steps
// in this file. The names are fixed, so every server context created by this
// process (and any other process using the same paths) overwrites them.
// NOTE(review): fixed, predictable names under /tmp look unsafe on shared
// hosts -- confirm whether per-process unique names are needed.

// certificate request: written by "openssl req", read by "openssl x509"
static String tmpServerReqPem = "/tmp/serverreq.pem";
// server private key: written by "openssl req", then loaded into the context
static String tmpServerKeyPem = "/tmp/serverkey.pem";
// signed server certificate: written by "openssl x509"
static String tmpServerCertPem = "/tmp/servercert.pem";
// concatenation of the above, loaded as the certificate chain file
static String tmpServerChainPem = "/tmp/serverchain.pem";
// serial number file passed to "openssl x509" via -CAserial
static String tmpCASerialFile = "/tmp/cert.srl";

// Runs the shell command accumulated in the stream; defined near the end
// of this file. descr names the step in the failure message.
static bool SslWrap_RunCommand(ostringstream &os, const String &descr);


// Starts with no selectors and with -1 for both the session resumption
// probability and the session cache size; configure() may overwrite them.
SslWrap::SslWrap():
	theProtocolSel(0),
	theRsaKeySizeSel(0),
	theCipherSel(0),
	theResumpProb(-1),
	theSessionCacheSize(-1)
{
}

// Loads all settings from the given configuration symbol and then makes
// the session resumption probability and the session cache size agree:
// if either one disables session reuse, the other is zeroed as well.
void SslWrap::configure(const SslWrapSym &cfg) {
	theRootCertificate = cfg.rootCertificate();
	configureProtocols(cfg);
	configureRsaKeySizes(cfg);
	configureCiphers(cfg);

	cfg.sessionResumpt(theResumpProb);
	cfg.sessionCacheSize(theSessionCacheSize);

	// snapshot the configured state before we start normalizing it
	const bool resumptionOff = theResumpProb <= 0;
	const bool resumptionOn = theResumpProb > 0;
	const bool cacheOff = theSessionCacheSize == 0;

	if (cacheOff) {
		Comment << cfg.loc() << "fyi: session cache size of zero means " <<
			"no cache, and not unlimited-size cache as in OpenSSL" << endc;
	}

	if (resumptionOff) {
		// no resumption: any configured cache would never be used
		if (theSessionCacheSize > 0) {
			Comment << cfg.loc() << "warning: positive session cache size " <<
				"ignored since session resumption is disabled" << endc;
		}
		theSessionCacheSize = 0;
	} else if (cacheOff) {
		// no cache: sessions cannot be resumed regardless of probability
		if (resumptionOn) {
			Comment << cfg.loc() << "warning: positive session resumption " <<
				"probability is ignored since session cache size is zero" << endc;
		}
		theResumpProb = 0;
	}
}

void SslWrap::configureProtocols(const SslWrapSym &cfg) {
	static StrIdentifier sidf;
	if (!sidf.count()) {
		sidf.add("SSLv2", SslCtx::SSLv2);
		sidf.add("SSLv3", SslCtx::SSLv3);
		sidf.add("TLSv1", SslCtx::TLSv1);
		sidf.add("any", SslCtx::SSLv23); // all of the above
		sidf.optimize();
	}

	theProtocolSel = cfg.protocols(sidf);
	if (!theProtocolSel)
		theProtocolSel = new ConstDistr(new RndGen, SslCtx::SSLv23);
	theProtocolSel->rndGen(GlbRndGen("ssl_protocols"));
}

// Reads the RSA key sizes and their selector from the configuration and,
// if that succeeds, attaches the "rsa_key_sizes" random generator.
void SslWrap::configureRsaKeySizes(const SslWrapSym &cfg) {
	if (!cfg.rsaKeySizes(theRsaKeySizes, theRsaKeySizeSel))
		return; // leave the selector as is

	theRsaKeySizeSel->rndGen(GlbRndGen("rsa_key_sizes"));
}

// Reads the cipher list and its selector from the configuration and,
// if that succeeds, attaches the "ssl_ciphers" random generator.
void SslWrap::configureCiphers(const SslWrapSym &cfg) {
	if (!cfg.ciphers(theCiphers, theCipherSel))
		return; // leave the selector as is

	theCipherSel->rndGen(GlbRndGen("ssl_ciphers"));
}

// Picks one of the configured RSA key sizes using the key size selector;
// returns 1024 bits when there is no selector.
Size SslWrap::selectRsaKeySize() const {
	if (theRsaKeySizeSel) {
		// the selector's trial value is used as an index into the size list
		const int pos = static_cast<int>(theRsaKeySizeSel->trial());
		return theRsaKeySizes[pos];
	}

	return Size::Bit(1024);
}

// Picks one of the configured ciphers using the cipher selector;
// returns "ALL" when there is no selector.
String SslWrap::selectCipher() const {
	if (theCipherSel) {
		// the selector's trial value is used as an index into the cipher list
		const int pos = static_cast<int>(theCipherSel->trial());
		return *theCiphers[pos];
	}

	return String("ALL");
}

// Current session cache size; configure() sets it to zero when session
// resumption is disabled.
int SslWrap::sessionCacheSize() const
{
	const int cacheSize = theSessionCacheSize;
	return cacheSize;
}

// Current session resumption probability; configure() sets it to zero
// when resumption is enabled but the session cache size is zero.
double SslWrap::resumpProb() const
{
	const double prob = theResumpProb;
	return prob;
}

// Builds an SSL context for a client (robot). Without a root certificate,
// verification is switched off (SSL_VERIFY_NONE). With one, the certificate
// is loaded as a verify location; if that load fails, the SSL error stack
// is reported and the process exits with status 2.
SslCtx *SslWrap::makeClientCtx(const NetAddr &addr) const {
	SslCtx *const clientCtx = makeCtx(addr);

	if (theRootCertificate) {
		// robots need the CA cert to verify server's key
		if (!clientCtx->loadVerifyLocations(theRootCertificate, String())) {
			Comment << "loadVerifyLocations() failed to load root certificate"
				<< endc;

			ReportErrors();
			exit(2);
		}

		// XXX: add this: ctx->setVerify(SSL_VERIFY_PEER);
		return clientCtx;
	}

	Comment << "no root certificate, setting SSL_VERIFY_NONE"
		<< endc;
	clientCtx->setVerify(SSL_VERIFY_NONE);
	return clientCtx;
}

#ifdef UNUSED_CODE
// Password callback that always supplies the fixed string "password".
// Copies at most size-1 characters into buf, always terminates the result,
// and returns the number of characters stored. Returns 0 when there is no
// usable buffer.
static
int SslWrap_passwdCb(char *buf, int size, int, void *) {
	if (!buf || size <= 0)
		return 0;

	// strncpy() does not terminate the destination when the source does
	// not fit, so reserve the last byte and terminate it ourselves;
	// otherwise strlen() below would read past the end of buf
	strncpy(buf, "password", size - 1);
	buf[size - 1] = '\0';
	return static_cast<int>(strlen(buf));
}
#endif

// Builds an SSL context for a server: generates and installs the private
// key, then the certificate. If either step fails, the SSL error stack is
// reported and the process exits with status 2.
SslCtx *SslWrap::makeServerCtx(const NetAddr &addr) const {
	SslCtx *const serverCtx = makeCtx(addr);

	// Always set SSL_VERIFY_PEER on the server.  The handshake
	// fails only if the client provides an invalid certificate
	serverCtx->setVerify(SSL_VERIFY_PEER);

	// ctx->setDefaultPasswdCb(&SslWrap_passwdCb); // not needed due to -nodes

	// servers need a private key and the root CA cert;
	// the key step runs first and the cert step is skipped if it fails
	const bool ready =
		configureSrvPrivateKey(serverCtx) && configureSrvCert(serverCtx);

	if (!ready) {
		ReportErrors();
		exit(2);
	}

	return serverCtx;
}

// Creates a new SSL context with a protocol and a cipher drawn from the
// configured selectors. The first call also initializes and seeds the SSL
// library. The address argument is not used.
SslCtx *SslWrap::makeCtx(const NetAddr &) const {
	// one-time SSL library setup
	static bool sslLibReady = false;
	if (!sslLibReady) {
		SslMisc::LibraryInit();
		SslMisc::SeedRng(LclPermut(rndSslSeed));
		Comment(5) << "fyi: SSL library initialized and seeded" << endc;
		sslLibReady = true;
	}

	const SslCtx::SslProtocol chosenProtocol =
		(SslCtx::SslProtocol)theProtocolSel->trial();
	const String chosenCipher = selectCipher();

	// only mention the cipher when it is not the "ALL" default
	if (chosenCipher.cmp("ALL") != 0)
		Comment << "SSL context using cipher " << chosenCipher << endc;

	return new SslCtx(chosenProtocol, chosenCipher);
}

// Turns the pending certificate request (tmpServerReqPem) into a server
// certificate with "openssl x509", concatenates the result into a chain
// file, and loads that chain into ctx.
//
// With a root certificate configured, the request is signed by the CA; the
// same file is passed as both -CA and -CAkey. Without one, the certificate
// is self-signed with the server key.
//
// Returns false if either shell command fails or the chain cannot be loaded.
bool SslWrap::configureSrvCert(SslCtx *ctx) const {
	// to make a server certificate, we need the CA public key
	// and the CA certificate

	// this command assumes passphrase-less root/CA key
	ostringstream cmd1;
	cmd1 << "openssl x509"
		<< " -req"
		<< " -in " << tmpServerReqPem
		<< " -sha1"
		<< " -extensions usr_cert";

	if (theRootCertificate) {
		cmd1 << " -CA " << theRootCertificate
			<< " -CAkey " << theRootCertificate;
	} else {
		cmd1 << " -signkey " << tmpServerKeyPem;
	}

	//
	// Use -CAserial option because diskless drones probably won't
	// be able to write the serial file in their current directory.
	//
	cmd1 << " -CAcreateserial"
		<< " -CAserial " << tmpCASerialFile
		<< " -out " << tmpServerCertPem
		<< ends;

	if (!SslWrap_RunCommand(cmd1, "x509 key generation"))
		return false;

	// To create a certificate chain, we must concatenate certificates
	ostringstream cmd2;
	cmd2	<< "cat " << tmpServerCertPem << ' ' << tmpServerKeyPem;
	// append the root certificate only when one is configured; the
	// previous test was inverted and appended it only when it was empty,
	// so the chain never included the CA certificate
	if (theRootCertificate)
		cmd2 << ' ' << theRootCertificate;
	cmd2 << " > " << tmpServerChainPem << ends;
	if (!SslWrap_RunCommand(cmd2, "certificate chain creation"))
		return false;

	if (!ctx->useCertificateChainFile(tmpServerChainPem)) {
		// name the failing file, as configureSrvPrivateKey() does
		Comment << "error: failed to use certificate chain from " <<
			tmpServerChainPem << endc;
		ReportErrors();
		return false;
	}

	return true;
}

// Generates a new passphrase-less RSA key and a certificate request with
// "openssl req", then installs the key into ctx. The key goes to
// tmpServerKeyPem and the request to tmpServerReqPem.
// Returns false if the command fails or the key cannot be loaded.
bool SslWrap::configureSrvPrivateKey(SslCtx *ctx) const {
	const Size keySize = selectRsaKeySize();

	ostringstream genCmd;
	genCmd << "openssl req";
	genCmd << " -newkey rsa:" << 8*keySize.byte(); // key length in bits
	genCmd << " -sha1" << " -nodes";
	genCmd << " -config myssl.conf"; // XXX: hardcoded openssl.cnf
	genCmd << " -keyout " << tmpServerKeyPem;
	genCmd << " -out " << tmpServerReqPem;
	genCmd << ends;

	if (!SslWrap_RunCommand(genCmd, "server private key generation"))
		return false;

	if (ctx->usePrivateKeyFile(tmpServerKeyPem))
		return true;

	Comment << "error: failed to use private key from " <<
		tmpServerKeyPem << endc;
	ReportErrors();
	return false;
}

#ifdef UNUSED_CODE
// Returns value unchanged if it is non-empty; otherwise reports the
// missing parameter (pname) at the symbol's location and exits with
// status 2.
String SslWrap::needParam(const SslWrapSym &sym, String value, const char *pname) const {
	if (value.len() > 0)
		return value;

	cerr << sym.loc() << "error: an SslWrap configuration " <<
		" is missing valid " << pname << endc;
	exit(2);
	return value;
}
#endif

// Logs and runs the shell command accumulated in os via system().
// Returns true if system() returned zero; otherwise logs a failure
// message that names the step with descr and returns false.
static
bool SslWrap_RunCommand(ostringstream &os, const String &descr) {
	// Keep our own copy of the command: os.str() returns a temporary
	// string, so a pointer obtained from os.str().c_str() would dangle
	// as soon as that statement ends.
	const std::string cmd = os.str();
	Comment << "executing: " << cmd.c_str() << endc;
	const bool res = ::system(cmd.c_str()) == 0;
	if (!res)
		Comment << "error: " << descr << " command failed" << endc;
	streamFreeze(os, false);
	return res;
}

void SslWrap::ReportErrors() {
	const char *fname;
	int line;
	ostream &os = Comment << "SSL error stack:" << endl;
	while (const unsigned long e = SslMisc::ErrGetErrorLine(&fname, &line)) {
		os << "\t" << fname << ":" << line << ": " <<
			SslMisc::ErrErrorString(e) << endl;
	}
	os << endc;
}


// syntax highlighted by Code2HTML, v. 0.9.1