/*****************************************************************************\
* Copyright (c) 2004 Pelle Johansson.                                         *
* All rights reserved.                                                        *
*                                                                             *
* This file is part of the moftpd package. Use and distribution of            *
* this software is governed by the terms in the file LICENCE, which           *
* should have come with this package.                                         *
\*****************************************************************************/

/* $moftpd: tls_openssl.c 1223 2004-10-28 16:14:40Z morth $ */

#include "system.h"

#include <limits.h>
#include <string.h>

#include "tls.h"

#include "utf8fs/memory.h"

/* Shared server context, created lazily by the first tls_open call.
   NULL means "not initialised yet"; (SSL_CTX *)-1 marks a failed
   initialisation, which is never retried. */
static SSL_CTX *sslCtx;
/* Directory of trusted CA certificates, defined elsewhere in the program.
   Used as the CApath for certificate verification and, when client
   certificates are requested, as the source of the advertised CA list. */
extern char *sslCertsPath;

/*
 * Report OpenSSL's compiled-in default directory of trusted CA
 * certificates, as returned by X509_get_default_cert_dir.  The string is
 * returned as const and is presumably owned by OpenSSL; callers must not
 * modify or free it.
 */
const char *tls_get_cert_dir (void)
{
  const char *dir = X509_get_default_cert_dir ();

  return dir;
}

/*
 * Pop the oldest entry from the OpenSSL error queue and return its reason
 * text.  ERR_reason_error_string yields NULL for an empty queue or an
 * unregistered code, and NULL must never reach a "%s" conversion, so a
 * fixed fallback string is returned instead.
 */
static const char *tls_open_error_reason (void)
{
  const char *reason = ERR_reason_error_string (ERR_get_error ());

  return reason ? reason : "unknown error";
}

/*
 * Create a server-side TLS session on the connected socket FD that
 * presents CERT and KEY to the peer.
 *
 * The first call initialises the OpenSSL library and the shared SSL_CTX.
 * If creating the context fails, the failure is remembered and every later
 * call returns NULL without retrying.
 *
 * OPTIONS is a bitmask.  With tlsVerifyClient the peer must present a
 * certificate (SSL_VERIFY_FAIL_IF_NO_PEER_CERT).  The subjects of the
 * certificates in sslCertsPath are advertised as acceptable CAs.
 *
 * The socket BIO is created with BIO_NOCLOSE, so FD is never closed by
 * this module.  The BIO is only attached to the SSL object by tls_start;
 * until then it is owned by the returned session.
 *
 * Returns the new session, or NULL on invalid arguments or any failure.
 * Certificate and key errors are also logged with syslog.
 */
tls_t tls_open (int fd, int options, tlscert_t cert, tlskey_t key)
{
  tls_t res;

  if (fd < 0 || !cert || !key)
    return NULL;
  if (sslCtx == (SSL_CTX *)-1)
    return NULL;

  if (!sslCtx)
  {
    SSL_load_error_strings ();
    SSL_library_init ();

    sslCtx = SSL_CTX_new (SSLv23_server_method ());
    if (!sslCtx)
    {
      // We only try once.
      sslCtx = (SSL_CTX *)-1;
      return NULL;
    }
    SSL_CTX_load_verify_locations (sslCtx, NULL, sslCertsPath);
  }

  /* tls_t is a pointer type: size the allocation by what it points to.
     sizeof (tls_t) would only reserve room for a single pointer. */
  res = palloc (sizeof (*res), NULL, NULL);
  if (!res)
    return NULL;
  res->ssl = SSL_new (sslCtx);
  if (!res->ssl)
    goto fail_alloc;

  /* SSLv23_server_method negotiates any protocol version the library
     supports; refuse the broken SSL 2.0 and SSL 3.0 protocols
     (RFC 6176, RFC 7568). */
  SSL_set_options (res->ssl, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);

  if (!SSL_use_certificate (res->ssl, cert))
  {
    syslog (LOG_ERR, "Failed to load certificate file: %s",
	  tls_open_error_reason ());
    goto fail_ssl;
  }
  if (!SSL_use_PrivateKey (res->ssl, key))
  {
    syslog (LOG_ERR, "Failed to load private key: %s",
	  tls_open_error_reason ());
    goto fail_ssl;
  }
  if (options & tlsVerifyClient)
  {
    SSL_set_verify (res->ssl, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
	  SSL_VERIFY_CLIENT_ONCE, NULL);

    /* The CA list is only a hint sent to the client.  Verification itself
       is enforced by SSL_set_verify above, so a missing directory or a
       failed allocation simply leaves the list unset. */
    if (sslCertsPath)
    {
      STACK_OF(X509_NAME) *caStack = sk_X509_NAME_new_null ();

      if (caStack)
      {
        SSL_add_dir_cert_subjects_to_stack (caStack, sslCertsPath);
        /* Ownership of caStack passes to the SSL object. */
        SSL_set_client_CA_list (res->ssl, caStack);
      }
    }
  }
  res->bio = BIO_new_socket (fd, BIO_NOCLOSE);
  if (!res->bio)
    goto fail_ssl;
  BIO_set_nbio (res->bio, 1);
  return res;

fail_ssl:
  SSL_free (res->ssl);
fail_alloc:
  pfree (res, NULL);
  return NULL;
}

/*
 * Attach the socket BIO created by tls_open to the SSL object, using it
 * for both reading and writing.  SSL_set_bio transfers ownership of the
 * BIO to the SSL object, so from here on SSL_free releases it.
 */
void tls_start (tls_t tls)
{
  BIO *sock = tls->bio;

  SSL_set_bio (tls->ssl, sock, sock);
}

/*
 * Begin or continue the TLS shutdown of the session (SSL_shutdown).
 *
 * Returns:
 *    1   SSL_shutdown reported the shutdown complete;
 *    0   it failed with SSL_ERROR_WANT_READ (call again once readable);
 *   -1   SSL_shutdown returned 0;
 *   <0   SSL_shutdown's own negative result for any other failure.
 *
 * NOTE(review): SSL_shutdown returns 0 when our close_notify was sent but
 * the peer's has not arrived yet.  That is reported here as -1, the same
 * as a failure; presumably callers accept a one-way shutdown.  Confirm.
 */
int tls_stop (tls_t tls)
{
  int rc = SSL_shutdown (tls->ssl);

  switch (rc)
  {
  case 1:
    return 1;
  case 0:
    return -1;
  default:
    return SSL_get_error (tls->ssl, rc) == SSL_ERROR_WANT_READ ? 0 : rc;
  }
}

/*
 * Release a session created by tls_open, including its socket BIO.
 *
 * After tls_start the BIO belongs to the SSL object and SSL_free releases
 * it.  Before that, SSL_get_rbio is not our BIO and the BIO must be freed
 * here, or it would leak.  The BIO was created with BIO_NOCLOSE, so the
 * underlying socket stays open either way.  A NULL session is ignored.
 */
void tls_free (tls_t tls)
{
  if (!tls)
    return;
  if (SSL_get_rbio (tls->ssl) != tls->bio)
    BIO_free (tls->bio);
  SSL_free (tls->ssl);
  pfree (tls, NULL);
}

/*
 * Drive the server side of the TLS handshake (SSL_accept).
 *
 * Returns:
 *    1   handshake complete;
 *    0   SSL_accept failed with SSL_ERROR_WANT_READ (retry once readable);
 *   -1   SSL_accept returned 0;
 *   <0   SSL_accept's negative result for any other failure; errno is
 *        cleared first, presumably so callers do not report a stale
 *        system error (NOTE(review): confirm intent).
 *
 * NOTE(review): SSL_ERROR_WANT_WRITE is treated as a failure here.
 */
int tls_accept (tls_t tls)
{
  int rc = SSL_accept (tls->ssl);

  if (rc == 1)
    return 1;
  if (rc == 0)
    return -1;
  if (SSL_get_error (tls->ssl, rc) == SSL_ERROR_WANT_READ)
    return 0;
  errno = 0;
  return rc;
}

/*
 * Read up to MAXLEN bytes of decrypted application data into BUF.
 *
 * SSL_read takes an int length.  A request larger than INT_MAX is clamped
 * rather than converted to a possibly negative int, so the caller simply
 * sees a short read, as with read(2).
 *
 * Returns SSL_read's result: the number of bytes read (> 0), or <= 0 on
 * failure or shutdown (SSL_get_error classifies which).
 */
ssize_t tls_read (tls_t tls, void *buf, size_t maxlen)
{
  int len = maxlen > (size_t)INT_MAX ? INT_MAX : (int)maxlen;

  return SSL_read (tls->ssl, buf, len);
}

/*
 * Encrypt and send up to LEN bytes from BUF.
 *
 * SSL_write takes an int length.  A request larger than INT_MAX is clamped
 * rather than converted to a possibly negative int, so the caller sees a
 * short write, as with write(2).
 *
 * Returns SSL_write's result: the number of bytes written (> 0), or <= 0
 * on failure (SSL_get_error classifies which).
 */
ssize_t tls_write (tls_t tls, const void *buf, size_t len)
{
  int count = len > (size_t)INT_MAX ? INT_MAX : (int)len;

  return SSL_write (tls->ssl, buf, count);
}

/*
 * Write the NUM buffers described by VECS, in order, through tls_write.
 *
 * Like writev(2), returns the total number of bytes written if any were
 * written.  Otherwise it returns the result of the last tls_write attempt
 * (0 or negative), or 0 when there was nothing to write.
 *
 * Writing stops at the first vector that fails (result <= 0) or is only
 * partially written.  Carrying on would put later vectors' bytes on the
 * wire ahead of the unsent remainder.  Empty vectors are skipped because
 * a zero-length SSL_write is undefined or an error, depending on the
 * OpenSSL version.
 */
ssize_t tls_write_vecs (tls_t tls, struct iovec *vecs, int num)
{
  ssize_t total = 0;
  ssize_t last = 0;
  int i;

  for (i = 0; i < num; i++)
  {
    if (vecs[i].iov_len == 0)
      continue;
    last = tls_write (tls, vecs[i].iov_base, vecs[i].iov_len);
    if (last <= 0)
      break;
    total += last;
    if ((size_t)last < vecs[i].iov_len)
      break;
  }
  return total ? total : last;
}

/*
 * Load the first PEM-encoded X.509 certificate from the file at FILE.
 *
 * Returns the certificate, which the caller releases with tls_free_cert.
 * Returns NULL if FILE is NULL (fopen(NULL) is undefined), the file cannot
 * be opened, or no certificate can be parsed from it.
 */
tlscert_t tls_read_cert (const char *file)
{
  FILE *fp;
  tlscert_t res;

  if (!file)
    return NULL;
  fp = fopen (file, "rb");
  if (!fp)
    return NULL;

  res = PEM_read_X509 (fp, NULL, NULL, NULL);
  fclose (fp);
  return res;
}

/*
 * Fetch the certificate the peer presented during the handshake, or NULL
 * if it presented none.  SSL_get_peer_certificate hands back a new
 * reference, so the caller must release the result with tls_free_cert.
 */
tlscert_t tls_get_peer_cert (const tls_t tls)
{
  tlscert_t peer = SSL_get_peer_certificate (tls->ssl);

  return peer;
}

/*
 * Drop a certificate reference obtained from tls_read_cert or
 * tls_get_peer_cert.
 */
void tls_free_cert (tlscert_t cert)
{
  /* X509_free does nothing for NULL, so no guard is needed. */
  X509_free (cert);
}

/*
 * Return the subject commonName of CERT as a NUL-terminated string.
 *
 * Returns NULL if CERT is NULL, has no commonName, or its commonName:
 *   - is empty;
 *   - does not fit the internal buffer;
 *   - contains an embedded NUL byte.
 *
 * X509_NAME_get_text_by_NID copies the entry's bytes verbatim, with no
 * character-set conversion.  Rejecting instead of truncating matters if
 * the CN is used to identify the peer: a CN of "alice\0.example" must not
 * be reported as "alice".  A CN stored in a multi-byte encoding such as
 * BMPString will typically contain NUL bytes and is therefore rejected.
 *
 * The result lives in a static buffer that the next call overwrites, so
 * this function is not reentrant.
 */
const char *tls_get_cn (tlscert_t cert)
{
  static char buf[100];
  X509_NAME *subject;
  int len;

  if (!cert)
    return NULL;
  subject = X509_get_subject_name (cert);
  if (!subject)
    return NULL;

  /* With a NULL buffer the call returns the entry's full length, which
     lets us detect values that would otherwise be silently truncated. */
  len = X509_NAME_get_text_by_NID (subject, NID_commonName, NULL, 0);
  if (len <= 0 || len >= (int)sizeof (buf))
    return NULL;
  if (X509_NAME_get_text_by_NID (subject, NID_commonName, buf, sizeof (buf)) != len)
    return NULL;
  /* An embedded NUL makes the C string shorter than the real value. */
  if (strlen (buf) != (size_t)len)
    return NULL;
  return buf;
}

/*
 * Compare two certificates with X509_cmp.
 * Returns 0 when they are the same certificate; any other value means they
 * differ, and its sign gives an ordering.
 */
int tls_compare_certs (const tlscert_t c1, const tlscert_t c2)
{
  int order = X509_cmp (c1, c2);

  return order;
}

/*
 * Load a PEM-encoded private key from the file at FILE.
 *
 * Returns the key, which the caller releases with tls_free_key.  Returns
 * NULL if FILE is NULL (fopen(NULL) is undefined), the file cannot be
 * opened, or no key can be parsed from it.
 *
 * NOTE(review): no password callback is passed.  For an encrypted key,
 * OpenSSL's default callback would try to prompt on the terminal; confirm
 * that only unencrypted keys are expected here.
 */
tlskey_t tls_read_key (const char *file)
{
  FILE *fp;
  tlskey_t res;

  if (!file)
    return NULL;
  fp = fopen (file, "rb");
  if (!fp)
    return NULL;

  res = PEM_read_PrivateKey (fp, NULL, NULL, NULL);
  fclose (fp);
  return res;
}

/*
 * Release a private key obtained from tls_read_key.
 */
void tls_free_key (tlskey_t key)
{
  /* EVP_PKEY_free does nothing for NULL, so no guard is needed. */
  EVP_PKEY_free (key);
}

const char *tls_error (const tls_t tls, int res)
{
  return ERR_reason_error_string (ERR_get_error ());
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */