/* $Cambridge: hermes/src/prayer/accountd/ssl.c,v 1.1.1.1 2003/04/15 13:00:03 dpc22 Exp $ */
/************************************************
 *    Prayer - a Webmail Interface              *
 ************************************************/

/* Copyright (c) University of Cambridge 2000 - 2002 */
/* See the file NOTICE for conditions of use and distribution. */

#include "accountd.h"

#include <limits.h>

/* Headers files for OpenSSL */

#include <openssl/lhash.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/rand.h>

/* ====================================================================== */

/* ssl_is_available() ****************************************************
 *
 * Report whether SSL support is compiled in. This file is the OpenSSL
 * implementation, so the answer is always yes.
 *
 * Returns: T
 ************************************************************************/
BOOL ssl_is_available()
{
    BOOL available = T;

    return available;
}

/* ====================================================================== */

/* Assorted bits stolen straight from Stunnel (ssl.c) that we might need */

/* Global SSL context shared by server iostreams (created by ssl_context_init()) */
static SSL_CTX *server_ctx;

/* Global SSL context shared by client iostreams (created by ssl_context_init()) */
static SSL_CTX *client_ctx;

/* Session ID context string applied to every SSL object created by
 * ssl_start_server() and ssl_start_client() */
static unsigned char *sid_ctx = (unsigned char *) "Prayer SID";

/* Enable full logging? Set from config->log_debug in ssl_context_init() */
static BOOL ssl_verbose_logging = NIL;

/* Length in bits of the temporary RSA key generated by ssl_make_rsakey() */
#define SSL_RSA_KEYLENGTH   (1024)

/* OpenSSL cipher list string used for both contexts: every cipher except
 * those in OpenSSL's "LOW" strength group */
#define SSLCIPHERLIST "ALL:!LOW"

/* ====================================================================== */

/* PRNG stuff for SSL */

/* prng_seeded() *********************************************************
 *
 * Ask OpenSSL whether its PRNG currently holds enough entropy.
 *   bytes - Number of random bytes gathered so far. Not consulted: the
 *           decision is delegated entirely to RAND_status().
 *
 * Returns: 1 if RAND_status() reports sufficient entropy, otherwise 0.
 ************************************************************************/
static int prng_seeded(int bytes)
{
    int seeded = (RAND_status() != 0);

    (void) bytes;

    if (seeded)
        log_misc("RAND_status claims sufficient entropy for the PRNG\n");

    return (seeded ? 1 : 0);
}

/* add_rand_file() *******************************************************
 *
 * Feed up to 2048 bytes from a file into the OpenSSL PRNG.
 *   filename - File to read (e.g. "/dev/urandom")
 *
 * Returns: Number of bytes RAND_load_file() reported reading; 0 if the
 *          file cannot be stat()ed or nothing could be read. An error
 *          (negative) result from RAND_load_file() is reported as 0 so
 *          that callers summing byte counts never go backwards.
 ************************************************************************/
static int add_rand_file(char *filename)
{
    struct stat sb;
    int readbytes;

    if (stat(filename, &sb) != 0)
        return (0);

    readbytes = RAND_load_file(filename, 2048);
    if (readbytes <= 0) {
        log_misc("Unable to retrieve any random data from %s\n", filename);
        return (0);
    }

    log_misc("Snagged %lu random bytes from %s\n",
             (unsigned long) readbytes, filename);
    return (readbytes);
}

/* os_initialize_prng() **************************************************
 *
 * Seed the OpenSSL PRNG.
 *   config - If config->egd_socket is set, that EGD socket is tried
 *            first; a successful EGD read is treated as sufficient.
 *            Otherwise /dev/urandom is loaded and RAND_status() decides.
 *
 * Calls log_fatal() on EGD failure, and log_fatal() then exit(1) if the
 * PRNG is still not seeded after reading /dev/urandom.
 ************************************************************************/
static void os_initialize_prng(struct config *config)
{
    int total = 0;

    if (config->egd_socket) {
        int egd_bytes = RAND_egd(config->egd_socket);

        if (egd_bytes != -1) {
            log_debug("Snagged %d random bytes from EGD Socket %s",
                      egd_bytes, config->egd_socket);
            log_misc("PRNG seeded successfully\n");
            return;
        }
        log_fatal("EGD Socket %s failed", config->egd_socket);
    }

    /* Fall back to the good-old default /dev/urandom, if available */
    total += add_rand_file("/dev/urandom");

    if (!prng_seeded(total)) {
        log_fatal("PRNG seeded with %lu bytes total (insufficent)\n",
                  (unsigned long) total);
        exit(1);
    }

    log_misc("PRNG seeded successfully\n");
}

/* ====================================================================== */

/* Temporary RSA key state, shared by ssl_make_rsakey(), ssl_freshen_rsakey(),
 * ssl_check_rsakey() and rsa_callback() */
static struct config *rsa_config = NIL; /* Config passed to ssl_check_rsakey() from rsa_callback() */
static RSA *rsa_tmp = NIL;      /* Current temporary RSA key (NIL until first generated) */
static time_t rsa_timeout = (time_t) 0; /* Expiry time of rsa_tmp; 0 => key never expires */

/* ssl_make_rsakey() *****************************************************
 *
 * Replace the temporary RSA key with a freshly generated
 * SSL_RSA_KEYLENGTH-bit key and recompute its expiry time.
 *   config - config->ssl_rsakey_lifespan is added to the current time to
 *            form the expiry; 0 => the key never expires.
 ************************************************************************/

static void ssl_make_rsakey(struct config *config)
{
    log_misc("Generating fresh RSA key");

    /* Discard the previous key before generating its replacement */
    if (rsa_tmp)
        RSA_free(rsa_tmp);

    rsa_tmp = RSA_generate_key(SSL_RSA_KEYLENGTH, RSA_F4, NULL, NULL);
    if (rsa_tmp == NULL)
        log_fatal("tmp_rsa_cb");

    log_misc("Generated fresh RSA key");

    rsa_timeout = (config->ssl_rsakey_lifespan)
        ? time(NIL) + config->ssl_rsakey_lifespan
        : 0;
}

/* ssl_init_rsakey() *****************************************************
 *
 * Initialise RSA key handling: generate the first temporary key, then
 * remember config so that rsa_callback() can regenerate the key later.
 ************************************************************************/

static void ssl_init_rsakey(struct config *config)
{
    ssl_make_rsakey(config);    /* first key, with its expiry time */
    rsa_config = config;        /* used later by rsa_callback() */
}

/* ssl_freshen_rsakey() ***************************************************
 *
 * Extend life of RSA key (unless it has already expired).
 *   config - config->ssl_rsakey_freshen: the key is kept valid for at
 *            least this long from now.
 *
 * Nothing happens when there is no key, when the key never expires
 * (rsa_timeout == 0), or when the key has already expired: expired keys
 * are left for ssl_check_rsakey() to replace. The expiry time is only
 * ever moved later, never earlier.
 *************************************************************************/

void ssl_freshen_rsakey(struct config *config)
{
    time_t now = time(NIL);
    time_t fresh;

    if (!rsa_tmp || (rsa_timeout == (time_t) 0L))
        return;

    /* Already expired: do NOT extend, or the key could live forever */
    if (rsa_timeout < now)
        return;

    fresh = now + config->ssl_rsakey_freshen;
    if (fresh > rsa_timeout)
        rsa_timeout = fresh;
}

/* ssl_check_rsakey() *****************************************************
 *
 * Generate a fresh RSA key if the existing key has expired.
 *   config - Passed on to ssl_make_rsakey() when regeneration is needed
 *
 * Keys with rsa_timeout == 0 never expire.
 *************************************************************************/

void ssl_check_rsakey(struct config *config)
{
    /* No key yet, or key configured never to expire: nothing to do */
    if (!rsa_tmp || (rsa_timeout == (time_t) 0L))
        return;

    if (time(NIL) > rsa_timeout)
        ssl_make_rsakey(config);
}

/* ====================================================================== */

/* A pair of OpenSSL callbacks */

/* rsa_callback() *********************************************************
 *
 * OpenSSL temporary RSA key callback. Regenerates the key first if it
 * has expired.
 *   s      - SSL connection (unused)
 *   export - Export flag supplied by OpenSSL (unused)
 *   keylen - Key length requested by OpenSSL. Only logged: the key
 *            returned is always the SSL_RSA_KEYLENGTH-bit rsa_tmp.
 *
 * Returns: The current temporary RSA key.
 ************************************************************************/
static RSA *rsa_callback(SSL * s, int export, int keylen)
{
    (void) s;
    (void) export;

    ssl_check_rsakey(rsa_config);

    /* keylen is an int, so it must be formatted with %d (not %lu) */
    log_misc("rsa_callback(): Requested %d bit key", keylen);
    return rsa_tmp;
}

/* OpenSSL info callback: installed on both contexts, deliberately a no-op */
static void info_callback(SSL * s, int where, int ret)
{
    (void) s;
    (void) where;
    (void) ret;
}

/* ====================================================================== */

#ifdef SESSION_CACHE_ENABLE
/* SSL Session database, stolen from Cyrus */

/* Database backend used for the session cache (init/open/store/fetch/
 * delete/foreach/close/done are all called through this pointer) */
#define DB (&mydb_db3_nosync)

static struct db *sessdb = NULL;        /* Handle of the open sessions.db */
static int sess_dbopen = 0;     /* Set to 1 by ssl_context_init() once sessdb is open */

/*
 * The new_session_cb() is called whenever a new session has been
 * negotiated and session caching is enabled.  We save the session in
 * a database so that we can share sessions between processes.
 *
 * Each record is stored under the session id as:
 *     <expire time (time_t)><ASN1 encoding of the session>
 *
 * Returns: always 0. OpenSSL takes an extra reference on sess before
 * calling this, and a non-zero return tells it the callback kept that
 * reference. We only keep a serialised copy, so returning non-zero
 * would leak every SSL_SESSION.
 */
static int new_session_cb(SSL * ssl, SSL_SESSION * sess)
{
    int len;
    unsigned char *data = NULL, *asn;
    time_t expire;
    int ret = -1;

    (void) ssl;

    if (!sess_dbopen)
        return 0;

    /* set the expire time for the external cache */
    expire = SSL_SESSION_get_time(sess) + SSL_SESSION_get_timeout(sess);

    /* find the size of the ASN1 representation of the session */
    len = i2d_SSL_SESSION(sess, NULL);
    if (len <= 0) {
        log_panic("i2d_SSL_SESSION failed");
        return 0;
    }

    /*
     * create the data buffer.  the data is stored as:
     * <expire time><ASN1 data>
     */
    data = (unsigned char *)
        pool_alloc(NIL, sizeof(time_t) + len * sizeof(unsigned char));

    if (data) {
        /* prepend the expire time: only once data is known to be valid */
        memcpy(data, &expire, sizeof(time_t));

        /* transform the session into its ASN1 representation */
        asn = data + sizeof(time_t);
        len = i2d_SSL_SESSION(sess, &asn);
        if (len <= 0) {
            log_panic("i2d_SSL_SESSION failed");
        } else {
            /* store the session in our database */
            do {
                ret = DB->store(sessdb, (void *) sess->session_id,
                                sess->session_id_length,
                                (void *) data, len + sizeof(time_t), NULL);
            } while (ret == MYDB_AGAIN);
        }
        free(data);
    }

    /* log this transaction */
    if (ssl_verbose_logging) {
        unsigned int i, n = sess->session_id_length;
        char idstr[SSL_MAX_SSL_SESSION_ID_LENGTH * 2 + 1];

        /* Bound the hex dump by the buffer; start empty so that a
         * zero-length id still gives a terminated string */
        if (n > SSL_MAX_SSL_SESSION_ID_LENGTH)
            n = SSL_MAX_SSL_SESSION_ID_LENGTH;
        idstr[0] = '\0';
        for (i = 0; i < n; i++)
            sprintf(idstr + i * 2, "%02X", sess->session_id[i]);

        log_debug("new SSL session: id=%s, expire=%s, status=%s",
                  idstr, ctime(&expire), ret ? "failed" : "ok");
    }
    return 0;
}

/*
 * Function for removing session from our database.
 *   id    - Session ID bytes (database key)
 *   idlen - Length of id
 *
 * Does nothing if the session database is not open.
 */
static void remove_session(unsigned char *id, int idlen)
{
    int ret;

    if (!sess_dbopen)
        return;

    do {
        ret = DB->delete(sessdb, (void *) id, idlen, NULL, 1);
    } while (ret == MYDB_AGAIN);

    /* log this transaction */
    if (ssl_verbose_logging) {
        int i, n = idlen;
        char idstr[SSL_MAX_SSL_SESSION_ID_LENGTH * 2 + 1];

        /* Bound the hex dump by the buffer; start empty so that a
         * zero-length id still gives a terminated string */
        if (n > SSL_MAX_SSL_SESSION_ID_LENGTH)
            n = SSL_MAX_SSL_SESSION_ID_LENGTH;
        idstr[0] = '\0';
        for (i = 0; i < n; i++)
            sprintf(idstr + i * 2, "%02X", id[i]);

        log_debug("remove SSL session: id=%s", idstr);
    }
}

/*
 * OpenSSL calls remove_session_cb() whenever it drops a session from its
 * internal cache: for example, when the session has expired or when a
 * connection was not shut down cleanly. Drop our database copy as well.
 *   ctx  - Owning SSL context (unused)
 *   sess - Session being removed
 */
static void remove_session_cb(SSL_CTX * ctx, SSL_SESSION * sess)
{
    unsigned char *session_key = sess->session_id;
    int session_key_len = sess->session_id_length;

    (void) ctx;
    remove_session(session_key, session_key_len);
}

/*
 * The get_session_cb() is only called on SSL/TLS servers with the
 * session id proposed by the client. The get_session_cb() is always
 * called, also when session caching was disabled.  We lookup the
 * session in our database in case it was stored by another process.
 *
 * The id is supplied by the client and the stored record may be
 * damaged, so both are bounds-checked before use. A record too short
 * to hold <expire time><ASN1 data> is removed and treated as a miss.
 *
 * Returns: A freshly decoded SSL_SESSION (with *copy set to 0), or NULL.
 */
static SSL_SESSION *get_session_cb(SSL * ssl, unsigned char *id, int idlen,
                                   int *copy)
{
    int ret;
    const char *data = NULL;
    unsigned char *asn;
    int len = 0;
    time_t expire = 0, now = time(0);
    SSL_SESSION *sess = NULL;
    int corrupt = 0;

    (void) ssl;

    if (!sess_dbopen)
        return NULL;

    do {
        ret =
            DB->fetch(sessdb, (void *) id, idlen, (void *) &data, &len,
                      NULL);
    } while (ret == MYDB_AGAIN);

    if (data && (len <= (int) sizeof(time_t))) {
        /* Cannot even hold the expire time plus some ASN1: discard */
        corrupt = 1;
        remove_session(id, idlen);
    } else if (data) {
        /* grab the expire time */
        memcpy(&expire, data, sizeof(time_t));

        /* check if the session has expired */
        if (expire < now) {
            remove_session(id, idlen);
        } else {
            /* transform the ASN1 representation of the session
               into an SSL_SESSION object */
            asn = (unsigned char *) data + sizeof(time_t);
            sess = d2i_SSL_SESSION(NULL, &asn,
                                   (long) (len - (int) sizeof(time_t)));
            if (!sess)
                log_panic("d2i_SSL_SESSION failed");
        }
    }

    /* log this transaction */
    if (ssl_verbose_logging) {
        int i, n = idlen;
        char idstr[SSL_MAX_SSL_SESSION_ID_LENGTH * 2 + 1];

        /* id is client-supplied: never format more than idstr holds */
        if (n > SSL_MAX_SSL_SESSION_ID_LENGTH)
            n = SSL_MAX_SSL_SESSION_ID_LENGTH;
        idstr[0] = '\0';
        for (i = 0; i < n; i++)
            sprintf(idstr + i * 2, "%02X", id[i]);

        log_debug("get SSL session: id=%s, expire=%s, status=%s",
                  idstr, ctime(&expire),
                  !data ? "not found" : corrupt ? "corrupt"
                  : expire < now ? "expired" : "ok");
    }

    /* sess (if any) was freshly allocated by d2i: OpenSSL takes it over */
    *copy = 0;
    return sess;
}
#endif

/* ====================================================================== */

/* ssl_context_init() ****************************************************
 *
 * Initialise SSL "context"es: one for server size activity and one for
 * client side activity.
 ************************************************************************/

void ssl_context_init(struct config *config)
{
    /* Set up random number generator */
    os_initialize_prng(config);

    /* Set up debug flag */
    ssl_verbose_logging = config->log_debug;

    SSLeay_add_ssl_algorithms();
    SSL_load_error_strings();

    if (1) {
        /* Set up client context: only used by accountd */
        client_ctx = SSL_CTX_new(SSLv3_client_method());
        SSL_CTX_set_session_cache_mode(client_ctx, SSL_SESS_CACHE_OFF);
        SSL_CTX_set_info_callback(client_ctx, info_callback);
        SSL_CTX_set_mode(client_ctx, SSL_MODE_AUTO_RETRY);

        if (SSL_CTX_need_tmp_RSA(client_ctx))
            SSL_CTX_set_tmp_rsa_callback(client_ctx, rsa_callback);

        /* Don't bother with session cache for client side: not enough
         * connections to worry about caching */
        SSL_CTX_set_timeout(client_ctx, 0);

        /* Set cipherlist */
        if (!SSL_CTX_set_cipher_list(client_ctx, SSLCIPHERLIST))
            log_fatal("SSL_CTX_set_cipher_list");
    }

    /* Set up server context */
    server_ctx = SSL_CTX_new(SSLv23_server_method());

    /* Enable all (sensible) bug fixes. A few others are not recommended */
    /* See ssl.h for details */
    SSL_CTX_set_options(server_ctx, SSL_OP_ALL);
#if 0
    /* Following appears to break Netscape? */
    SSL_CTX_set_options(server_ctx,
                        SSL_OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG);
#endif

    /* Start off with the session cache disabled */
    SSL_CTX_set_session_cache_mode(server_ctx, SSL_SESS_CACHE_OFF);
    SSL_CTX_sess_set_cache_size(server_ctx, 0);
    SSL_CTX_set_timeout(server_ctx, 0);

    SSL_CTX_set_info_callback(server_ctx, info_callback);
    SSL_CTX_set_mode(server_ctx, SSL_MODE_AUTO_RETRY);
    SSL_CTX_set_quiet_shutdown(server_ctx, 1);

#ifdef SESSION_CACHE_ENABLE
    if (config->ssl_session_timeout > 0) {
        int r;

        /* Set the callback functions for the external session cache */
        SSL_CTX_set_session_cache_mode(server_ctx, SSL_SESS_CACHE_BOTH);
        SSL_CTX_sess_set_cache_size(server_ctx, 128);
        SSL_CTX_set_timeout(server_ctx, config->ssl_session_timeout);

        /* Initialise the session cache */
        /* Initialize DB environment */
        r = DB->init(config->ssl_session_dir, 0);
        if (r != 0)
            log_fatal("DBERROR init: %s", mydb_strerror(r));

        /* create the name of the db file */
        r = DB->open("sessions.db", &sessdb);
        if (r != 0)
            log_fatal("DBERROR: opening %s: %s",
                      "sessions.db", mydb_strerror(r));
        sess_dbopen = 1;

        /* Set the callback functions for the external session cache */
        SSL_CTX_sess_set_new_cb(server_ctx, new_session_cb);
        SSL_CTX_sess_set_remove_cb(server_ctx, remove_session_cb);
        SSL_CTX_sess_set_get_cb(server_ctx, get_session_cb);
    }
#endif

    /* Set up DH file, if required */
    if (config->ssl_dh_file) {
        static DH *dh = NULL;
        BIO *bio = NULL;

        if (!(bio = BIO_new_file(config->ssl_dh_file, "r")))
            log_fatal("Error reading DH file: %s\n", config->ssl_dh_file);

        if (!(dh = PEM_read_bio_DHparams(bio, NULL, NULL, NULL)))
            log_fatal("Could not load DH parameters from: %s\n",
                      config->ssl_dh_file);

        SSL_CTX_set_tmp_dh(server_ctx, dh);
        if (bio)
            BIO_free(bio);
        if (dh)
            DH_free(dh);
    }

    /* Set up certificate file */
    if (!SSL_CTX_use_certificate_file(server_ctx, config->ssl_cert_file,
                                      SSL_FILETYPE_PEM))
        log_fatal("Error reading certificate file: %s\n",
                  config->ssl_cert_file);

#if 0
    /* Experiment: try setting up certificate chain here... */
    if (!SSL_CTX_use_certificate_chain_file
        (server_ctx, config->ssl_cert_file))
        log_fatal("Error reading certificate chains from file: %s\n",
                  config->ssl_cert_file);
#endif

    /* Set up PrivateKey file */
    if (!SSL_CTX_use_PrivateKey_file
        (server_ctx, config->ssl_privatekey_file, SSL_FILETYPE_PEM))
        log_fatal
            ("SSL_CTX_use_RSAPrivateKey_file: failed to use file %s\n",
             config->ssl_privatekey_file);

    /* Set cipherlist */
    if (!SSL_CTX_set_cipher_list(server_ctx, SSLCIPHERLIST))
        log_fatal("SSL_CTX_set_cipher_list() failed");

    /* Set up RSA temporary key callback routine */
    if (SSL_CTX_need_tmp_RSA(server_ctx))
        SSL_CTX_set_tmp_rsa_callback(server_ctx, rsa_callback);

    /* Initialise RSA temporary key (will take a couple of secs to complete) */
    ssl_init_rsakey(config);
}

void ssl_context_free()
{
    SSL_CTX_free(server_ctx);
}

void ssl_shutdown(void *ssl)
{
    SSL_shutdown((SSL *) ssl);
}

int ssl_get_error(void *ssl, int code)
{
    return (SSL_get_error((SSL *) ssl, code));
}

void ssl_free(void *ssl)
{
    SSL_free((SSL *) ssl);
    ERR_remove_state(0);
}

/* ====================================================================== */

/* ssl_start_server() ****************************************************
 *
 * Start server side SSL on an accepted socket.
 *   fd      - Connected socket
 *   timeout - Seconds to wait (via select()) for the client to start
 *             negotiating; 0 => no wait, go straight to SSL_accept()
 *
 * Returns: SSL handle (as void *), or NIL on failure or timeout.
 ************************************************************************/

void *ssl_start_server(int fd, unsigned long timeout)
{
    SSL *ssl;
    SSL_CIPHER *c;
    X509 *client_cert;
    char *ver;
    int bits;

    if (!(ssl = SSL_new(server_ctx)))
        return (NIL);

    SSL_set_session_id_context(ssl, sid_ctx, strlen((char *) sid_ctx));

    SSL_set_fd(ssl, fd);
    SSL_set_accept_state(ssl);

    if (timeout > 0) {
        fd_set readfds;
        struct timeval timeval;
        int rc;

        /* Check for SSL negotiation. Both arguments are rebuilt on each
         * attempt: timeval is undefined after a failed select() on some
         * systems (e.g. Linux) and is modified on success. */
        do {
            FD_ZERO(&readfds);
            FD_SET(fd, &readfds);
            timeval.tv_sec = timeout;
            timeval.tv_usec = 0;
            rc = select(fd + 1, &readfds, NIL, NIL, &timeval);
        } while ((rc < 0) && (errno == EINTR));

        /* Error, timeout, or our fd not readable */
        if ((rc <= 0) || !FD_ISSET(fd, &readfds)) {
            SSL_shutdown(ssl);  /* Safe? */
            SSL_free(ssl);
            ERR_remove_state(0);
            return (NIL);
        }
    }

    if (SSL_accept(ssl) <= 0) {
        SSL_shutdown(ssl);
        SSL_free(ssl);
        ERR_remove_state(0);
        return (NIL);
    }

    /* SSL_get_peer_certificate() hands back a reference we must free */
    if ((client_cert = SSL_get_peer_certificate(ssl))) {
        log_debug("SSL: Have client certificate");
        X509_free(client_cert);
    } else
        log_debug("SSL: No client certificate");

    switch (ssl->session->ssl_version) {
    case SSL2_VERSION:
        ver = "SSLv2";
        break;
    case SSL3_VERSION:
        ver = "SSLv3";
        break;
    case TLS1_VERSION:
        ver = "TLSv1";
        break;
    default:
        ver = "UNKNOWN";
    }
    c = SSL_get_current_cipher(ssl);

    SSL_CIPHER_get_bits(c, &bits);
    log_debug("Opened with %s, cipher %s (%lu bits)\n",
              ver, SSL_CIPHER_get_name(c), (unsigned long) bits);

    return ((void *) ssl);
}

/* ssl_server_client() ***************************************************
 *
 * Start client side SSL
 ************************************************************************/

void *ssl_start_client(int fd, unsigned long timeout)
{
    SSL *ssl;
    SSL_CIPHER *c;
    char *ver;
    int bits;

    if (!(ssl = (void *) SSL_new(client_ctx)))
        return (NIL);

    SSL_set_session_id_context((SSL *) ssl, sid_ctx,
                               strlen((char *) sid_ctx));

    SSL_set_fd((SSL *) ssl, fd);
    SSL_set_connect_state((SSL *) ssl);

    if (SSL_connect((SSL *) ssl) <= 0)
        return (NIL);

    /* Verify certificate here? Need local context to play with? */

    switch (((SSL *) ssl)->session->ssl_version) {
    case SSL2_VERSION:
        ver = "SSLv2";
        break;
    case SSL3_VERSION:
        ver = "SSLv3";
        break;
    case TLS1_VERSION:
        ver = "TLSv1";
        break;
    default:
        ver = "UNKNOWN";
    }
    c = SSL_get_current_cipher((SSL *) ssl);
    SSL_CIPHER_get_bits(c, &bits);
    log_debug("Opened client connection with %s, cipher %s (%lu bits)\n",
              ver, SSL_CIPHER_get_name(c), (unsigned long) bits);

    return ((void *) ssl);
}

/* ====================================================================== */

/* ssl_read() ************************************************************
 *
 * read() from SSL pipe:
 *    ssl     - SSL abstraction
 *  buffer    - Buffer to read into
 *  blocksize - Size of buffer. SSL_read() takes an int, so sizes above
 *              INT_MAX are clamped to INT_MAX (a short read) rather than
 *              being narrowed to a truncated or negative count.
 *
 * Returns: Result of SSL_read(): >0 bytes read, 0 => EOF/connection
 *          closed, <0 => error (see ssl_get_error()).
 ************************************************************************/

int ssl_read(void *ssl, unsigned char *buffer, unsigned long blocksize)
{
    int len = (blocksize > (unsigned long) INT_MAX)
        ? INT_MAX : (int) blocksize;

    return (SSL_read((SSL *) ssl, (char *) buffer, len));
}

/* ssl_write() ***********************************************************
 *
 * write() to SSL pipe:
 *    ssl  - SSL abstraction
 *  buffer - Buffer to write from
 *  bytes  - Number of bytes to write. SSL_write() takes an int, so counts
 *           above INT_MAX are clamped to INT_MAX (a short write) rather
 *           than being narrowed to a truncated or negative count.
 *
 * Returns: Numbers of bytes written; <=0 => error (see ssl_get_error()).
 ************************************************************************/

int ssl_write(void *ssl, unsigned char *buffer, unsigned long bytes)
{
    int len = (bytes > (unsigned long) INT_MAX) ? INT_MAX : (int) bytes;

    return (SSL_write((SSL *) ssl, (char *) buffer, len));
}

/* ssl_read() ************************************************************
 *
 * Check for pending input on SSL pipe.
 ************************************************************************/

int ssl_pending(void *ssl)
{
    return (SSL_pending((SSL *) ssl));
}

/* ====================================================================== */

#ifdef SESSION_CACHE_ENABLE
/*
 * Delete expired sessions: again stolen from Cyrus.
 */
struct prunerock {
    struct db *db;
    int count;
    int deletions;
};

/* prune_p() *************************************************************
 *
 * DB->foreach() predicate used by ssl_prune_sessions(): decides whether
 * a session record should be passed to prune_cb() for deletion.
 *   rock    - struct prunerock * (count is incremented)
 *   id      - Session ID (record key); idlen is its length
 *   data    - Record: <expire time (time_t)><ASN1 session data>
 *   datalen - Length of data
 *
 * Returns: Non-zero if the record has expired, or is too short to hold
 *          an expire time (corrupt), otherwise 0.
 ************************************************************************/
static int
prune_p(void *rock, const char *id, int idlen,
        const char *data, int datalen)
{
    struct prunerock *prock = (struct prunerock *) rock;
    time_t expire;

    prock->count++;

    /* Never read past the record: a short record is corrupt, prune it */
    if (datalen < (int) sizeof(time_t))
        return (1);

    /* grab the expire time */
    memcpy(&expire, data, sizeof(time_t));

    /* log this transaction */
    if (ssl_verbose_logging) {
        int i, n = idlen;
        char idstr[SSL_MAX_SSL_SESSION_ID_LENGTH * 2 + 1];

        if (n > SSL_MAX_SSL_SESSION_ID_LENGTH)
            n = SSL_MAX_SSL_SESSION_ID_LENGTH;
        idstr[0] = '\0';
        /* Cast: a negative (signed) char would print as 8 hex digits
         * under %02X and overflow idstr */
        for (i = 0; i < n; i++)
            sprintf(idstr + i * 2, "%02X", (unsigned char) id[i]);

        log_debug("found SSL session: id=%s, expire=%s",
                  idstr, ctime(&expire));
    }

    /* check if the session has expired */
    return (expire < time(0));
}

/* prune_cb() ************************************************************
 *
 * DB->foreach() callback used by ssl_prune_sessions(): deletes a record
 * that prune_p() selected.
 *   rock  - struct prunerock * (deletions is incremented, db is used)
 *   id    - Session ID (record key); idlen is its length
 *   data, datalen - Record contents (unused)
 *
 * Returns: 0
 ************************************************************************/
static int
prune_cb(void *rock, const char *id, int idlen,
         const char *data, int datalen)
{
    struct prunerock *prock = (struct prunerock *) rock;
    int ret;

    (void) data;
    (void) datalen;

    prock->deletions++;

    /* Same argument list as the DB->delete() call in remove_session() */
    do {
        ret = DB->delete(prock->db, id, idlen, NULL, 1);
    } while (ret == MYDB_AGAIN);

    /* log this transaction */
    if (ssl_verbose_logging) {
        int i, n = idlen;
        char idstr[SSL_MAX_SSL_SESSION_ID_LENGTH * 2 + 1];

        if (n > SSL_MAX_SSL_SESSION_ID_LENGTH)
            n = SSL_MAX_SSL_SESSION_ID_LENGTH;
        idstr[0] = '\0';
        /* Cast: a negative (signed) char would print as 8 hex digits
         * under %02X and overflow idstr */
        for (i = 0; i < n; i++)
            sprintf(idstr + i * 2, "%02X", (unsigned char) id[i]);

        log_debug("expiring SSL session: id=%s", idstr);
    }

    return 0;
}

/* ssl_prune_sessions() **************************************************
 *
 * Open the shared session database and delete every record that
 * prune_p() selects (expired or corrupt), via prune_cb().
 *   config - Supplies ssl_session_dir
 *
 * Returns: 0. Database initialisation/open failures go to log_fatal().
 ************************************************************************/
int ssl_prune_sessions(struct config *config)
{
    int ret;
    struct prunerock prock;

    /* initialize DB environment (checked, as ssl_context_init() does) */
    ret = DB->init(config->ssl_session_dir, 0);
    if (ret != MYDB_OK)
        log_fatal("DBERROR init: %s", mydb_strerror(ret));

    /* open the db file */
    ret = DB->open("sessions.db", &sessdb);
    if (ret != MYDB_OK)
        log_fatal("DBERROR: opening %s: %s",
                  "sessions.db", mydb_strerror(ret));

    /* check each session in our database */
    prock.db = sessdb;
    prock.count = prock.deletions = 0;
    DB->foreach(sessdb, "", 0, &prune_p, &prune_cb, &prock, NULL);
    DB->close(sessdb);
    sessdb = NULL;

    log_debug("tls_prune: purged %d out of %d entries",
              prock.deletions, prock.count);

    DB->done();

    return (0);
}
#else
/* Session cache support not compiled in: there is nothing to prune. */
int ssl_prune_sessions(struct config *config)
{
    (void) config;
    return 0;
}
#endif


/* syntax highlighted by Code2HTML, v. 0.9.1 */