#define _GNU_SOURCE

#include "findmtu.h"

/**
 * Send a single zero-filled probe datagram sized so that the resulting
 * IPv6 packet is `mtu` bytes long.
 * Input: connected UDP/IPv6 socket fd, candidate path MTU in bytes
 *        (total packet size, IPv6 and UDP headers included)
 * Output: return value of send(): bytes written, or -1 on error
 *
 * The payload is mtu minus 48 bytes (40-byte IPv6 header + 8-byte UDP
 * header). It is clamped to [0, sizeof sndbuf] so that an out-of-range
 * mtu can never make send() read past the local buffer.
 */
int send_packet(int fd, int mtu) {
    enum { IPV6_UDP_HEADERS = 40 + 8 };
    char sndbuf[DEFAULT_MTU];
    int payload = mtu - IPV6_UDP_HEADERS;

    if (payload < 0)
        payload = 0;
    else if ((size_t)payload > sizeof(sndbuf))
        payload = (int)sizeof(sndbuf);

    memset(sndbuf, '\0', sizeof(sndbuf));
    return send(fd, sndbuf, (size_t)payload, 0);
}

/**
 * Open a UDP/IPv6 socket connected to the given host.
 *
 * The host is resolved with gethostbyname2(AF_INET6) and the socket is
 * connected to its first address, on port htons(0x8000 | getpid()).
 * On any failure a diagnostic is printed and the process exits with
 * status 1; this function only returns on success.
 *
 * Input: hostname or IPv6 address literal
 * Output: connected socket fd
 */
int opensocket(char *host) {
    int fd;
    struct sockaddr_in6 sin6;
    struct hostent *he;

    fd = socket(AF_INET6, SOCK_DGRAM, 0);
    if (fd < 0) {
        perror("socket");
        exit(1);
    }

    he = gethostbyname2(host, AF_INET6);
    /* Reject anything that is not exactly one in6_addr, otherwise the
     * memcpy below would copy the wrong number of bytes. */
    if (!he || he->h_addrtype != AF_INET6 ||
        he->h_length != (int)sizeof(sin6.sin6_addr) ||
        !he->h_addr_list[0]) {
        fprintf(stderr, "Unknown host %s\n", host);
        exit(1);
    }

    /* Zero the whole address first: sin6_flowinfo and sin6_scope_id would
     * otherwise be passed to connect() as uninitialised stack data. */
    memset(&sin6, 0, sizeof(sin6));
    sin6.sin6_family = AF_INET6;
    sin6.sin6_port = htons(0x8000 | getpid());
    memcpy(&sin6.sin6_addr, he->h_addr_list[0], sizeof(sin6.sin6_addr));

    if (connect(fd, (struct sockaddr *)&sin6, sizeof(sin6)) < 0) {
        perror("connect");
        exit(1);
    }

    return fd;
}

/**
 * Print "<addr> <reason>" (no trailing newline) for an ICMPv6 error.
 *
 * addr: printable address of the node that reported the error
 * type: ICMPv6 message type
 * code: ICMPv6 message code
 *
 * Destination-unreachable codes beyond the table, unknown time-exceeded
 * codes and unhandled ICMPv6 types are printed as an "unknown" reason
 * instead of leaving the reason empty.
 */
void printerrormessage(char *addr, u_int8_t type, u_int8_t code) {
    /* Indexed by destination-unreachable code; the last entry doubles as
     * the fallback for codes past the end of the table. */
    static const char *const unreach_codes[] = {
        "noroute", "admin", "notneighbor", "addr", "noport", "unknown"
    };
    const size_t n_unreach = sizeof(unreach_codes) / sizeof(unreach_codes[0]);

    printf("%s ", addr);
    switch (type) {
        case ICMP6_DST_UNREACH:
            printf("unreach-%s",
                   unreach_codes[(size_t)code < n_unreach ? (size_t)code : n_unreach - 1]);
            break;
        case ICMP6_TIME_EXCEEDED:
            switch (code) {
                case ICMP6_TIME_EXCEED_TRANSIT:
                    printf("hoplimit");
                    break;
                case ICMP6_TIME_EXCEED_REASSEMBLY:
                    printf("reasm");
                    break;
                default:
                    printf("exceeded-unknown");
                    break;
            }
            break;
        case ICMP6_PARAM_PROB:
            printf("paramprob");
            break;
        default:
            printf("unknown");
            break;
    }
}

int main(int argc, char **argv) {
    struct mtureply mtuvalues[32];
    int current, currentmtu = DEFAULT_MTU;
    int sendfd, recvfd, i;
    struct mtureply reply;
    struct icmpv6responsefilter *filter;

    if(argc < 2) {
        fprintf(stderr, "Usage: %s <host>\n", argv[0]);
        exit(1);
    }

    sendfd = opensocket(argv[1]);
    recvfd = recv_init(sendfd);
    filter = getfilter(sendfd);

    /* Try to solicit a Packet Too Big ICMPv6 error three times.
     * If this fails, suppose the current MTU is the real MTU
     * and give up (could be a dead host or a black hole).
     */
    current = 0;
    for(i = 3; i > 0; i--) {
        send_packet(sendfd, currentmtu);
        if(wait_for_reply(recvfd, filter))
            reply = recvmtu(recvfd);
        else
            continue;

        if(reply.mtu > 0) {          /* Packet too big: lower estimate and start again */
            currentmtu = reply.mtu;
            i = 3;
            if(reply.addr[0]) {
                mtuvalues[current++] = reply;
            }
        } else if(reply.mtu == -1) { /* Destination reached or other error. Stop here. */
                mtuvalues[current++] = reply;
            break;
        }
    }

    printf("%d ", currentmtu);

    /* Did we only get an error? */
    if( (current == 0) && (mtuvalues[0].ee_type != 0) && (mtuvalues[0].ee_code != 0) ) {
        printerrormessage(mtuvalues[0].addr, mtuvalues[0].ee_type, mtuvalues[0].ee_code);
    }

    if(current > 0) {
        printf("(");
        for (i = 0; i < current; i++) {
            if(*mtuvalues[i].addr != '\0') {
                if(i != 0) printf(", ");
                if( (mtuvalues[i].ee_type) == 0 && (mtuvalues[i].ee_code == 0) ) {
                    if(mtuvalues[i].mtu == -1) {
                        printf("%s reached", mtuvalues[i].addr);
		    } else {
                    	printf("%s %d", mtuvalues[i].addr, mtuvalues[i].mtu);
                    }
                } else {
                    printerrormessage(mtuvalues[i].addr, mtuvalues[i].ee_type,
                                      mtuvalues[i].ee_code);
                }
            }
        }
        printf(")");
    }
    printf("\n");
    return 0;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */