/*
  
  $Id: ntlmauth.c,v 1.7 2004/03/10 13:35:21 thivillon Exp $

  © 2003 Alain Thivillon et Hervé Schauer Consultants 

*/

#include "config.h"

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#include <openssl/md4.h>
#include <errno.h>
#include <sys/types.h> 
#include <sys/socket.h>
#include <signal.h>
#include <sys/wait.h>
#include <sys/stat.h>
#include <sys/ioctl.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <fcntl.h>
#include <string.h>
#include <termios.h>
#include <syslog.h>
#include <iconv.h>
#include <ctype.h>

#ifndef INADDR_NONE
#define       INADDR_NONE             0xffffffff
#endif

#include "ntlmauth.h"
#include "base64.h"

/* Read a 16-bit little-endian value from the byte array p (p[0] is the
   low byte). p is evaluated twice. */
#define FROMLITTLE(p)  (((p)[0] & 0xff) + (((p)[1] & 0xff) << 8))
/* Store the low 16 bits of i into p[0] (low byte) and p[1] (high byte).
   Expands to a comma expression; p and i are each evaluated twice. */
#define TOLITTLE(p, i) ((p)[0] = (i) & 0xff, (p)[1] = ((i) >> 8) & 0xff)

/* Smaller of a and b. Being a macro, the selected argument is evaluated
   twice: do not pass expressions with side effects. */
#define min(a,b) (((a) < (b)) ? (a) : (b))

/* Helpers provided outside this file. Callers below treat a return value
   <= 0 from write_timeout/read_timeout as failure and always pass tmo = 10
   (NOTE(review): unit presumably seconds -- confirm in the defining file). */
extern int write_timeout(int s,  char *buffer, int size, int tmo);
extern int read_timeout(int s, char *buffer, int size, int tmo);
/* printf-style logger; prio is a <syslog.h> level (LOG_ERR, LOG_INFO, ...). */
extern void do_log( int prio, char *format , ... );

/* Forward declarations of the file-local helpers defined below. */
static void setup_des_key(unsigned char key_56[], des_key_schedule *ks);
static void calc_resp(unsigned char *keys, unsigned char *plaintext, unsigned char *results);
static char *do_ntlm_auth_1(char *username, char *password);
static char *do_ntlm_auth_2(iconv_t myiconv, char *username, char *password, char *challenge);

static void dump_hex(unsigned char *string, int len);

/*
 * Build the NTLM type-1 (negotiate) message and return it base64-encoded.
 *
 * username may be "user" or "DOMAIN\user"; in the latter case the domain
 * part is announced (upper-cased) together with the local short host name.
 * password is not needed for this message.
 *
 * Returns a malloc'd base64 string that the caller must free, or NULL on
 * error (the error is logged).
 */
static char *do_ntlm_auth_1(char *username, char *password) {

  char hostname[256];
  char domainname[256];
  mess_ntlm_1 mess_1;
  char *ptr, *ptr_dest;
  byte NTLM_string[] =  { 'N', 'T', 'L', 'M', 'S', 'S', 'P', '\0' };
  char *buffer_in = NULL, *buffer_out = NULL;
  unsigned short h_len;
  unsigned short d_len;
  unsigned short client_flags;
  size_t msg_len;

  (void) password;

  memset(&mess_1, 0, sizeof(mess_ntlm_1));

  memcpy(&mess_1.protocol, &NTLM_string, 8);
  mess_1.type = NTLMSSP_NEGOCIATE;

  client_flags = NTLMSSP_ALWAYS_SIGN | NTLMSSP_WANT_NTLM | NTLMSSP_WANT_LM;
  client_flags |= (NTLMSSP_WANT_UNICODE | NTLMSSP_WANT_OEM);

  if (gethostname(hostname, 255)) {
    do_log(LOG_ERR, "gethostname : %s", strerror(errno));
    goto catch;
  }
  hostname[255] = '\0';
  /* Only the short host name (up to the first dot) is sent */
  if ((ptr = strchr(hostname, '.')) != NULL) *ptr = '\0';

  /* strncpy does not terminate the copy when username is too long */
  strncpy(domainname, username, 255);
  domainname[255] = '\0';

  /* If username is DOMAIN\user, keep DOMAIN and set the matching flags */
  if ((ptr = strchr(domainname, '\\')) != NULL) {
    *ptr = '\0';
    client_flags |= (NTLMSSP_HAS_DOMAIN | NTLMSSP_HAS_WORKGROUP);
  }
  else domainname[0] = '\0';

  TOLITTLE(mess_1.flags, client_flags);

  h_len = strlen(hostname);
  d_len = strlen(domainname);

  TOLITTLE(mess_1.dom_len, d_len);
  TOLITTLE(mess_1.dom_len_2, d_len);
  TOLITTLE(mess_1.host_len, h_len);
  TOLITTLE(mess_1.host_len_2, h_len);

  /* Payload: hostname right after the fixed header, then the domain.
     NOTE: the 0x20 offset assumes sizeof(mess_ntlm_1) == 0x20. */
  TOLITTLE(mess_1.host_off, (unsigned short) 0x20);
  TOLITTLE(mess_1.dom_off, 0x20 + h_len);

  msg_len = sizeof(mess_1) + h_len + d_len;

  buffer_in = malloc(msg_len + 10);
  if (buffer_in == NULL) {
    do_log(LOG_ERR, "malloc : %s", strerror(errno));
    goto catch;
  }
  memcpy(buffer_in, &mess_1, sizeof(mess_1));
  /* Cast to unsigned char: toupper() on a negative char is undefined */
  for (ptr = hostname, ptr_dest = buffer_in + sizeof(mess_1); *ptr; ptr++, ptr_dest++)
    *ptr_dest = toupper((unsigned char) *ptr);
  for (ptr = domainname, ptr_dest = buffer_in + sizeof(mess_1) + h_len; *ptr; ptr++, ptr_dest++)
    *ptr_dest = toupper((unsigned char) *ptr);

  buffer_out = malloc(msg_len * 2);
  if (buffer_out == NULL) {
    do_log(LOG_ERR, "malloc : %s", strerror(errno));
    goto catch;
  }

  base64_encode(buffer_in, buffer_out, msg_len);
  free(buffer_in);
  return buffer_out;

catch:
  free(buffer_in);
  return NULL;

}

static char *do_ntlm_auth_2(iconv_t myiconv, char *username, char *password, char *challenge) {

  char *bufplain = NULL;
  mess_ntlm_2 *mess2; 
  mess_ntlm_3 mess3; 
  byte NTLM_string[] =  { 'N', 'T', 'L', 'M', 'S', 'S', 'P', '\0' };
  char *buffer_in = NULL, *buffer_out = NULL;
  unsigned short h_len;
  unsigned short d_len; 
  unsigned short u_len; 
  char *ptr,*inbuf,*outbuf,*nonce;
  char *user=username;
  size_t len_uni;
  size_t len_uni_after;
  char hostname[256];
  char domainname[256];
  char  lm_pw[14];
  int   len,idx;
  char  *nt_pw = NULL,*lm_resp,*nt_resp;
  const unsigned char magic[] = "KGS!@#$%";
  unsigned char lm_hpw[50];
  des_key_schedule ks1;
  des_key_schedule ks2;
  unsigned char nt_hpw[50];
  MD4_CTX context;
  unsigned short server_flags;
  unsigned short client_flags;
  unsigned short ntlm_offset;
  unsigned short total_len;

  bufplain = malloc (strlen(challenge));
  base64_decode(bufplain, challenge);
  mess2 = (mess_ntlm_2 *) bufplain;
  
  if (strcmp(mess2->protocol,"NTLMSSP")) {
    do_log(LOG_WARNING,"Bad NTLM challenge");
    goto catch;
  }

  server_flags = FROMLITTLE(mess2->flags);

  memset(&mess3, 0, sizeof(mess3));
  memcpy(&mess3.protocol , &NTLM_string, 8);

  /* Set client answer flags */
  mess3.type = NTLMSSP_AUTH;
  client_flags = NTLMSSP_ALWAYS_SIGN;
  if (server_flags & NTLMSSP_WANT_UNICODE) client_flags |= NTLMSSP_WANT_UNICODE;
  else client_flags |= NTLMSSP_WANT_OEM;
  if (server_flags & NTLMSSP_WANT_NTLM) client_flags |= NTLMSSP_WANT_NTLM;
  else if (server_flags & NTLMSSP_WANT_LM) client_flags |= NTLMSSP_WANT_LM;

  TOLITTLE(mess3.dom_off, 0x40);

  if (gethostname ( hostname, 255) ) goto catch;
  hostname[255] = 0;
  if ((ptr = strchr(hostname, '.')) != NULL) *ptr= '\0'; 

  strncpy(domainname, username, 255);
  if ((ptr = strchr(domainname, '\\')) != NULL) {
    *ptr = '\0';
    user = ptr+1;
    client_flags |= (NTLMSSP_HAS_DOMAIN | NTLMSSP_HAS_WORKGROUP);
  }
  else domainname[0] = '\0';

  TOLITTLE(mess3.flags , (unsigned short) client_flags);
  
  h_len = strlen(hostname);
  d_len = strlen(domainname);
  u_len = strlen(user);

  ptr = hostname;
  while (*ptr) { *ptr = toupper(*ptr); ptr ++; }
  ptr = domainname;
  while (*ptr) { *ptr = toupper(*ptr); ptr ++; }

  buffer_in = malloc(sizeof(mess3) + 2*h_len + 2*u_len + 2*d_len + 2*0x18);

  /* Start Iconv */
  iconv(myiconv, NULL, NULL, NULL, NULL);

  if (client_flags & NTLMSSP_WANT_UNICODE) {
    /* Conv Domain Name */
    len_uni = 2 * d_len;
    len_uni_after = len_uni;
    
    inbuf = domainname;
    outbuf = buffer_in + sizeof(mess3);
    iconv(myiconv, (char* *) &inbuf, (size_t * ) &len_uni, 
		   (char* *) &outbuf, (size_t *) &len_uni_after);
    if (len_uni_after != 0) {
      do_log(LOG_WARNING,"Bad Conversion of domainname");
      goto catch;
    }
  }
  else {
    len_uni = d_len;
    len_uni_after = len_uni;
    inbuf = domainname;
    outbuf = buffer_in + sizeof(mess3);
    memcpy(outbuf, inbuf, len_uni);
  }

  /* Conv User Name */
  iconv(myiconv, NULL, NULL, NULL, NULL);
  if (client_flags & NTLMSSP_WANT_UNICODE) {
    len_uni = 2 * u_len;
    len_uni_after = len_uni;
    inbuf = user;
    outbuf = buffer_in + sizeof(mess3) +  2 * d_len;
    iconv(myiconv, (char* *) &inbuf, (size_t * ) &len_uni, 
		   (char* *) &outbuf, (size_t *) &len_uni_after);
    if (len_uni_after != 0) {
      do_log(LOG_WARNING,"Bad Conversion of username");
      goto catch;
    }
  }
  else {
    len_uni =  u_len;
    len_uni_after = len_uni;
    inbuf = user;
    outbuf = buffer_in + sizeof(mess3) +  d_len;
    memcpy(outbuf, inbuf, len_uni);
  }

  /* Conv Hostname */
  iconv(myiconv, NULL, NULL, NULL, NULL);
  if (client_flags & NTLMSSP_WANT_UNICODE) {
    len_uni = 2 * h_len;
    len_uni_after = len_uni;
    inbuf = hostname;
    outbuf = buffer_in + sizeof(mess3) +  2 * d_len + 2 * u_len;
    iconv(myiconv, (char* *) &inbuf, (size_t * ) &len_uni, 
		   (char* *) &outbuf, (size_t *) &len_uni_after);
    if (len_uni_after != 0) {
      do_log(LOG_WARNING,"Bad Conversion of hostname");
      goto catch;
    }
  }
  else {
    len_uni = h_len;
    len_uni_after = len_uni;
    inbuf = hostname;
    outbuf = buffer_in + sizeof(mess3) +   d_len +  u_len;
    memcpy(outbuf, inbuf, len_uni);
  }

  if (client_flags & NTLMSSP_WANT_UNICODE) {
    TOLITTLE(mess3.dom_len_1, d_len * 2);
    TOLITTLE(mess3.dom_len_2, d_len * 2);
    TOLITTLE(mess3.host_len_1, h_len * 2);
    TOLITTLE(mess3.host_len_2 , h_len * 2);
    TOLITTLE(mess3.user_len_1, u_len * 2);
    TOLITTLE(mess3.user_len_2 , u_len * 2);
    TOLITTLE(mess3.user_off, sizeof(mess3) + d_len * 2);
    TOLITTLE(mess3.host_off, sizeof(mess3) + d_len * 2 + u_len *2);


    nonce = (char *) &(mess2->nonce);

    ntlm_offset = sizeof(mess3) + d_len * 2 + h_len *2 + u_len * 2 ;
    total_len = sizeof(mess3) + d_len * 2 + h_len *2 + u_len * 2;
    /* Not necessary everytime, but does not choke NT and 
     permits Ethereal dissector to work ... */
    TOLITTLE(mess3.lm_resp_off, 0x40 + d_len * 2 + h_len *2 + u_len * 2);
  }
  else {
    TOLITTLE(mess3.dom_len_1, d_len );
    TOLITTLE(mess3.dom_len_2, d_len );
    TOLITTLE(mess3.host_len_1, h_len );
    TOLITTLE(mess3.host_len_2 , h_len);
    TOLITTLE(mess3.user_len_1, u_len );
    TOLITTLE(mess3.user_len_2 , u_len);
    TOLITTLE(mess3.user_off, sizeof(mess3) + d_len );
    TOLITTLE(mess3.host_off, sizeof(mess3) + d_len + u_len );

    nonce = (char *) &(mess2->nonce);

    ntlm_offset = sizeof(mess3) + d_len + h_len + u_len ;
    total_len = sizeof(mess3) + d_len + h_len + u_len ;
    TOLITTLE(mess3.lm_resp_off, 0x40 + d_len + h_len + u_len );
  }

  len = strlen(password);

  if (1) {
    /* Squid NTLM Auth announce NTLM but does not use it !!!
       So we must stuck to send LM hash ! 
      //if (client_flags & NTLMSSP_WANT_LM) {
    */

    /* Lan Manager part */

    TOLITTLE(mess3.lm_resp_len_1, 0x18);
    TOLITTLE(mess3.lm_resp_len_2, 0x18);
    ntlm_offset += 0x18;
    total_len+= 0x18;

    /* setup LanManager password */
    if (len > 14)  len = 14;
    for (idx=0; idx<len; idx++)
	lm_pw[idx] = toupper(password[idx]);
    for (; idx<14; idx++)
	lm_pw[idx] = 0;

    /* create LanManager hashed password */

    setup_des_key(lm_pw, &ks1);
    DES_ecb_encrypt((const_DES_cblock *) magic, (const_DES_cblock *) lm_hpw, &ks1, DES_ENCRYPT);

    setup_des_key(lm_pw+7, &ks2);
    DES_ecb_encrypt((const_DES_cblock *) magic, (const_DES_cblock *) (char *) (lm_hpw+8), 
        &ks2, DES_ENCRYPT);

    memset(lm_hpw+16, 0, 5);

    /* create LanMan response */
    if (client_flags & NTLMSSP_WANT_UNICODE) 
      lm_resp = buffer_in + sizeof(mess3) +  2 * d_len + 2 * u_len + 2* h_len ;
    else 
      lm_resp = buffer_in + sizeof(mess3) +  d_len + u_len + h_len ;

    calc_resp(lm_hpw, nonce, lm_resp);

  }

  if (client_flags & NTLMSSP_WANT_NTLM) {

    /* NTLM Part */

    TOLITTLE(mess3.nt_resp_len_1, 0x18);
    TOLITTLE(mess3.nt_resp_len_2, 0x18);
    TOLITTLE(mess3.nt_resp_off, ntlm_offset);
    total_len+= 0x18;

    /* create NT hashed password */

    nt_pw = malloc(2*len);
    memset(nt_pw, 0, 2 * len);
    for (idx=0; idx<len; idx++)
    {
	nt_pw[2*idx]   = password[idx];
	nt_pw[2*idx+1] = 0;
    }

    MD4_Init(&context);
    MD4_Update(&context, nt_pw, 2*len);
    MD4_Final(nt_hpw, &context);

    memset(nt_hpw+16, 0, 5);

    nt_resp = buffer_in + ntlm_offset;
    calc_resp(nt_hpw, nonce, nt_resp);

  }

  TOLITTLE(mess3.msg_len, total_len);
  memcpy(buffer_in, &mess3, sizeof(mess3));

  buffer_out = malloc(2 * total_len);
  base64_encode (buffer_in, buffer_out , total_len);

  if (buffer_in != NULL) free(buffer_in);
  if (bufplain != NULL) free(bufplain);
  if (nt_pw != NULL) free(nt_pw);

  return buffer_out;

catch:
  
  /* Something is wrong */
  if (buffer_in != NULL) free(buffer_in);
  if (bufplain != NULL) free(bufplain);
  if (nt_pw != NULL) free(nt_pw);

  return NULL;

}

/*
 * Run the client side of the NTLM handshake on an already connected socket.
 *
 * Sends the request line uri with a type-1 message, reads the server's
 * answer (headers, then Content-Length bytes of body), extracts the NTLM
 * challenge and sends the same request again with the type-3 message. The
 * answer to that second request is left unread on sock for the caller.
 *
 * mode selects the header pair: PROXY_AUTH uses Proxy-Authorization /
 * Proxy-Authenticate, WWW_AUTH uses Authorization / WWW-Authenticate.
 *
 * Returns 0 once the second request has been written, -1 on any error.
 */
int do_ntlm_auth (int sock, char *username , char *password, char *uri,
                  char *host, char *useragent, int mode) {

  char *first_packet = NULL, *second_packet = NULL;
  char *ptr, *ptr_data = NULL, *line_end;
  char buffer[1024];
  char response[2048];
  int crlf;
  iconv_t myiconv;
  int lg;
  int req_len;
  int doc_length;
  char *nt_challenge = NULL, *ptr_authenticate = NULL;
  char *auth_header = "Authorization:";
  char *ptr_content_length = NULL;
  int result = -1;

  myiconv = iconv_open("utf-16le", "latin1");
  if (myiconv == (iconv_t) -1) {
    do_log(LOG_ERR, "iconv_open failed : %s", strerror(errno));
    goto catch;
  }

  first_packet = do_ntlm_auth_1(username, password);
  if (first_packet == NULL) goto catch;

  if (mode == PROXY_AUTH) auth_header = "Proxy-Authorization:";
  req_len = snprintf(buffer, sizeof(buffer),
       "%s\r\nConnection: Keep-Alive\r\nHost: %s\r\n%s NTLM %s\r\nUser-Agent: %s\r\n\r\n",
           uri,
           host,
           auth_header,
           first_packet,
           useragent
  );
  /* A truncated request would lose its final blank line: refuse to send it */
  if (req_len < 0 || (size_t) req_len >= sizeof(buffer)) {
    do_log(LOG_ERR, "Request too long");
    goto catch;
  }

  do_log(LOG_INFO, "Sent to proxy : \n%s", buffer);
  if (write_timeout(sock, buffer, strlen(buffer), 10) <= 0) goto catch;

  /* Read until the end of the headers; one byte is always kept free for
     the terminating NUL written after each read. */
  ptr = response;
  crlf = 0;
  while (!crlf && ptr < response + sizeof(response) - 1) {
    lg = read_timeout(sock, ptr, (int) (sizeof(response) - 1 - (ptr - response)), 10);
    if (lg <= 0) goto catch;
    ptr[lg] = '\0';
    if ((ptr_data = strstr(response, "\r\n\r\n")) != NULL) {
      crlf = 1;
      *ptr_data = '\0';
    }
    ptr += lg;
  }

  if (!crlf) {
    do_log(LOG_WARNING, "No Headers end\n%s", response);
    goto catch;
  }

  do_log(LOG_INFO, "Proxy Answer:\n%s\n", response);

  /* Body bytes already received after the blank line count towards
     Content-Length */
  doc_length = 0;
  if ((ptr_content_length = strcasestr(response, "Content-Length: ")) != NULL) {
    sscanf(ptr_content_length + 16, "%d", &doc_length);
    doc_length -= (ptr - (ptr_data + 4));
  }

  /* The header block was cut at its final CRLF, so the last header line
     has no CRLF left: cut at the next CRLF only when there is one. */
  if (mode == WWW_AUTH &&
      (ptr_authenticate = strcasestr(response, "WWW-Authenticate: NTLM ")) != NULL) {
    if ((line_end = strstr(ptr_authenticate, "\r\n")) != NULL) *line_end = '\0';
    nt_challenge = strdup(ptr_authenticate + 23);
  }
  if (mode == PROXY_AUTH &&
      (ptr_authenticate = strcasestr(response, "Proxy-Authenticate: NTLM ")) != NULL) {
    if ((line_end = strstr(ptr_authenticate, "\r\n")) != NULL) *line_end = '\0';
    nt_challenge = strdup(ptr_authenticate + 25);
  }

  /* Discard the rest of the body before reusing the keep-alive connection */
  while (doc_length > 0) {
    lg = read_timeout(sock, response, min((int) sizeof(response), doc_length), 10);
    if (lg <= 0) {
      do_log(LOG_WARNING, "Short read");
      goto catch;
    }
    doc_length -= lg;
  }

  /* We have read all data */

  if (nt_challenge == NULL) {
    do_log(LOG_WARNING, "No valid NTLM challenge");
    goto catch;
  }

  second_packet = do_ntlm_auth_2(myiconv, username, password, nt_challenge);
  if (second_packet == NULL) goto catch;

  req_len = snprintf(buffer, sizeof(buffer),
       "%s\r\nConnection: Keep-Alive\r\nHost: %s\r\n%s NTLM %s\r\nUser-Agent: %s\r\n\r\n",
          uri,
          host,
          auth_header,
          second_packet,
          useragent
  );
  if (req_len < 0 || (size_t) req_len >= sizeof(buffer)) {
    do_log(LOG_ERR, "Request too long");
    goto catch;
  }
  do_log(LOG_INFO, "Sent to proxy : \n%s", buffer);

  if (write_timeout(sock, buffer, strlen(buffer), 10) <= 0) goto catch;

  result = 0;

catch:
  free(first_packet);
  free(second_packet);
  free(nt_challenge);
  if (myiconv != (iconv_t) -1) iconv_close(myiconv);
  return result;

}

/*
 * Expand the 7-byte (56-bit) key at key_56 into an 8-byte DES key and
 * store its key schedule in *ks.
 *
 * Key bits are spread seven per output byte, in the high bits;
 * DES_set_odd_parity then fixes the low (parity) bit of every byte.
 */
static void setup_des_key(unsigned char key_56[], des_key_schedule *ks)
{
    des_cblock key;
    int pos;

    key[0] = key_56[0];
    /* Byte pos takes the low pos bits of key_56[pos-1] (moved to the top)
       followed by the high (8 - pos) bits of key_56[pos]. */
    for (pos = 1; pos < 7; pos++)
        key[pos] = ((key_56[pos - 1] << (8 - pos)) & 0xFF) | (key_56[pos] >> pos);
    key[7] = (key_56[6] << 1) & 0xFF;

    DES_set_odd_parity(&key);
    DES_set_key_unchecked(&key, ks);
}

/*
 * Compute a 24-byte LM/NT challenge response.
 *
 * keys is read as 21 bytes forming three consecutive 7-byte DES keys
 * (callers pass a 16-byte hash followed by five zero bytes). The 8-byte
 * plaintext is DES-ECB encrypted once with each key, and the three 8-byte
 * ciphertexts are written back to back into results (24 bytes).
 */
static void calc_resp(unsigned char *keys, unsigned char *plaintext, unsigned char *results)
{
    des_key_schedule ks;
    int part;

    for (part = 0; part < 3; part++) {
        unsigned char *key_part = keys + 7 * part;
        unsigned char *out_part = results + 8 * part;

        setup_des_key(key_part, &ks);
        /* NOTE(review): the output is cast to const_DES_cblock like the rest
           of the file; DES_cblock is the type DES_ecb_encrypt declares. */
        DES_ecb_encrypt((const_DES_cblock *) plaintext, (const_DES_cblock *) out_part,
                        &ks, DES_ENCRYPT);
    }
}


#if 0

/*
 * Manual test driver (compiled out): connects to an HTTP proxy, runs the
 * NTLM handshake for a CONNECT request, then copies whatever the proxy
 * sends back to stdout.
 */
int main (int argc, char **argv) {

#define NTLM_USER "FENETRE2k\\at"
#define NTLM_PASS "zboing"
/* #define URL "GET http://www.hsc.fr/ HTTP/1.1" */
#define URL "CONNECT www.hsc.fr:443 HTTP/1.1"
#define SERVER_IP "192.70.106.143"
#define SERVER_HOST "www.hsc.fr"
#define SERVER_PORT 8080
#define USER_AGENT "ntlmauth-test"

  int sock;
  struct sockaddr_in sin;
  in_addr_t addr;
  char buffer[2048];
  int one = 1;
  int len;

  if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    fprintf(stderr, "socket : %s\n", strerror(errno));
    return (-2);
  }

  /* memset(ptr, value, size): the arguments were previously swapped */
  memset(&sin, 0, sizeof(struct sockaddr_in));
  sin.sin_family = AF_INET;

  addr = inet_addr(SERVER_IP);
  sin.sin_port = htons(SERVER_PORT);
  sin.sin_addr.s_addr = addr;

  setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));

  if (connect(sock, (struct sockaddr *) &sin, sizeof(sin)) == -1) {
    fprintf(stderr, "connect: %s\n", strerror(errno));
    close(sock);
    return 1;
  }
  ioctl(sock, FIONBIO, &one);
  /* do_ntlm_auth takes the User-Agent string before the mode */
  if (do_ntlm_auth(sock, NTLM_USER, NTLM_PASS, URL, SERVER_HOST, USER_AGENT, PROXY_AUTH) == 0) {
    while ((len = read_timeout(sock, buffer, 2048, 10)) > 0) write(1, buffer, len);
  }
  close(sock);
  return 0;
}

#endif

/*
 * Debug helper: print the first len bytes of string to stderr as
 * two-digit lowercase hex values, each followed by a space, then a newline.
 * Nothing but the newline is printed when len <= 0.
 */
static void dump_hex(unsigned char *string, int len) {
  int pos = 0;

  while (pos < len)
    fprintf(stderr, "%02x ", (unsigned int) string[pos++]);
  fputc('\n', stderr);
}



/* syntax highlighted by Code2HTML, v. 0.9.1 */