#include "findmtu.h"
#include <errno.h>
#include <net/route.h>

/**
 * Create the raw ICMPv6 socket on which error messages are received.
 *
 * The socket's ICMPv6 filter blocks every message type except Destination
 * Unreachable, Packet Too Big, Time Exceeded and Parameter Problem.
 *
 * Input: sending socket fd (not used by this function)
 * Output: fd of the receiving socket. The process exits if the socket
 *         cannot be created or the filter cannot be installed.
 */
int recv_init(int sendfd) {
    struct icmp6_filter wanted;
    int sock;

    /* Describe the message types we care about: the error messages only. */
    ICMP6_FILTER_SETBLOCKALL(&wanted);
    ICMP6_FILTER_SETPASS(ICMP6_DST_UNREACH, &wanted);
    ICMP6_FILTER_SETPASS(ICMP6_PACKET_TOO_BIG, &wanted);
    ICMP6_FILTER_SETPASS(ICMP6_TIME_EXCEEDED, &wanted);
    ICMP6_FILTER_SETPASS(ICMP6_PARAM_PROB, &wanted);

    sock = socket(PF_INET6, SOCK_RAW, IPPROTO_ICMPV6);
    if (sock < 0) {
        perror("socket");
        exit(1);
    }

    /* Install the filter on the freshly created socket. */
    if (setsockopt(sock, IPPROTO_ICMPV6, ICMP6_FILTER, &wanted, sizeof(wanted)) < 0) {
        perror("setsockopt ICMP6_FILTER");
        exit(-1);
    }

    return sock;
}

/**
 * Wait, for at most two seconds in total, for an ICMPv6 error message that
 * quotes one of our probes.
 *
 * Input: fd of the raw ICMPv6 socket, and the filter describing our probes
 * Output: 1 if a matching message is waiting on the socket. It has only been
 *         peeked at, so it is still queued for the caller to read.
 *         0 if the deadline passed without a match, or if select()/recvfrom()
 *         failed with an error other than an interruption.
 *
 * Non-matching messages are read from the queue and discarded.
 */
int wait_for_reply(int fd, struct icmpv6responsefilter *filter) {
    /* ICMPv6 error header, then the quoted IPv6 header and UDP header that
     * match_icmpv6_response() inspects. */
    const size_t min_len = sizeof(struct icmp6_hdr) + sizeof(struct ip6_hdr)
                         + sizeof(struct udphdr);
    char packet[DEFAULT_MTU];
    struct timeval now, deadline, tv;
    fd_set fds;
    ssize_t ret;
    int ready;

    /* Fixed deadline: discarding unrelated packets must not extend the wait. */
    gettimeofday(&deadline, NULL);
    deadline.tv_sec += 2;

    for (;;) {
        /* Remaining time = deadline - now, borrowing from the seconds if needed. */
        gettimeofday(&now, NULL);
        tv.tv_sec = deadline.tv_sec - now.tv_sec;
        tv.tv_usec = deadline.tv_usec - now.tv_usec;
        if (tv.tv_usec < 0) {
            tv.tv_sec -= 1;
            tv.tv_usec += 1000000L;
        }
        if (tv.tv_sec < 0)
            return 0;

        /* select() overwrites the set, so rebuild it on every iteration. */
        FD_ZERO(&fds);
        FD_SET(fd, &fds);

        ready = select(fd + 1, &fds, NULL, NULL, &tv);
        if (ready == 0)
            return 0;
        if (ready < 0) {
            if (errno == EINTR)
                continue;
            perror("select");
            return 0;
        }

        /* Look at the waiting packet without removing it from the queue. */
        ret = recvfrom(fd, packet, sizeof(packet), MSG_PEEK, NULL, NULL);
        if (ret < 0) {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                continue;
            perror("recvfrom");
            return 0;
        }

        /* Is it ours? The quoted packet starts right after the ICMPv6 header. */
        if ((size_t) ret >= min_len &&
            match_icmpv6_response((struct ip6_hdr *) (packet + sizeof(struct icmp6_hdr)),
                                  filter)) {
            return 1;
        }

        /* Too short or not ours: remove it from the queue and keep waiting. */
        (void) recvfrom(fd, packet, sizeof(packet), 0, NULL, NULL);
    }
}

/**
 * Read one ICMPv6 message from a socket and summarise it.
 * Input: socket fd
 * Output: struct mtureply with address of reporting host, MTU, and ICMP type/code
 *   - Packet Too Big:               mtu = MTU carried in the message
 *   - Destination Unreachable with
 *     code "no port" (host reached): mtu = -1
 *   - any other message:             mtu = -1, ee_type/ee_code = ICMPv6 type/code
 *   - message shorter than an
 *     ICMPv6 header:                 mtu = -1
 * Fields not mentioned for a case are zero. The process exits if recvfrom()
 * fails.
 */

struct mtureply recvmtu(int fd) {
    struct mtureply reply;
    char packet[DEFAULT_MTU];
    struct sockaddr_in6 addr;
    struct icmp6_hdr icmp6;
    socklen_t len;
    ssize_t ret;

    /* The reply is returned by value; never hand back indeterminate fields. */
    memset(&reply, 0, sizeof(reply));
    memset(&addr, 0, sizeof(addr));

    /* Get waiting packet, retrying if a signal interrupts the call */
    do {
        len = sizeof(addr);
        ret = recvfrom(fd, packet, sizeof(packet), 0, (struct sockaddr *) &addr, &len);
    } while (ret < 0 && errno == EINTR);
    if (ret < 0) {
        perror("recvfrom");
        exit(-1);
    }

    inet_ntop(AF_INET6, &addr.sin6_addr, reply.addr, sizeof(reply.addr));

    /* Without a complete ICMPv6 header there is nothing to classify. */
    if ((size_t) ret < sizeof(icmp6)) {
        reply.mtu = -1;
        return reply;
    }

    /* Copy the header out instead of casting the char buffer, so the access
     * does not depend on the buffer's alignment. */
    memcpy(&icmp6, packet, sizeof(icmp6));

    /* What type of reply did we get? */
    if (icmp6.icmp6_type == ICMP6_PACKET_TOO_BIG) {
        /* Path MTU response */
        reply.mtu = ntohl(icmp6.icmp6_mtu);
    } else if (icmp6.icmp6_type == ICMP6_DST_UNREACH &&
               icmp6.icmp6_code == ICMP6_DST_UNREACH_NOPORT) {
        /* Host reached */
        reply.mtu = -1;
    } else {
        /* Other error */
        reply.mtu = -1;
        reply.ee_type = icmp6.icmp6_type;
        reply.ee_code = icmp6.icmp6_code;
    }

    return reply;
}

struct icmpv6responsefilter *getfilter(int fd) {
    struct icmpv6responsefilter *filter;
    struct sockaddr_in6 src, dst;
    int len;

    if(!(filter = malloc(sizeof(filter)))) {
        perror("allocating filter");
        exit(-1);
    }

    len = sizeof(src);
    getsockname(fd, (struct sockaddr *) &src, &len);
    len = sizeof(dst);
    getpeername(fd, (struct sockaddr *) &dst, &len);

    filter->src = src.sin6_addr;
    filter->dst = dst.sin6_addr;
    filter->srcport = src.sin6_port;
    filter->dstport = dst.sin6_port;

    return filter;
}

/**
 * Decide whether a quoted packet is one of our probes.
 * Input: ip6h - IPv6 header of the quoted packet, with the UDP header
 *               expected directly behind it
 *        filter - addresses and ports of our probes
 * Output: 1 if the packet is UDP and its source/destination address and
 *         source/destination port all equal the filter's, otherwise 0
 */
int match_icmpv6_response(struct ip6_hdr *ip6h, struct icmpv6responsefilter *filter) {
    /* The transport header is taken to start right after the fixed IPv6 header. */
    struct udphdr *quoted_udp = (struct udphdr *) (ip6h + 1);
    int addresses_match;

    /* Only UDP probes are of interest. */
    if (ip6h->ip6_nxt != IPPROTO_UDP) {
        return 0;
    }

    addresses_match =
        memcmp(&filter->dst, &ip6h->ip6_dst, sizeof(ip6h->ip6_dst)) == 0 &&
        memcmp(&filter->src, &ip6h->ip6_src, sizeof(ip6h->ip6_src)) == 0;
    if (!addresses_match) {
        return 0;
    }

    if (filter->srcport != quoted_udp->uh_sport) {
        return 0;
    }
    return filter->dstport == quoted_udp->uh_dport;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */