/*
 * dnsutl - utilities to make DNS easier to configure
 * Copyright (C) 1991-1993, 1995, 1996, 1999-2001, 2006, 2007 Peter Miller
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

#include <ac/ctype.h>
#include <ac/errno.h>
#include <ac/stdarg.h>
#include <ac/stdio.h>
#include <ac/stdlib.h>
#include <ac/string.h>
#include <ac/time.h>
#include <ac/unistd.h>

#include <srrf.h>
#include <srrf/private.h>
#include <srrf/origin.h>
#include <mem.h>
#include <error.h>

/*
 * Kinds of token produced by lex().  The config, include, line,
 * origin and ttl kinds are produced only for the reserved "$..."
 * directive words listed in the table below.
 */
typedef enum lex_token_t
{
    lex_token_config,
    lex_token_eof,
    lex_token_eoln,
    lex_token_include,
    lex_token_line,
    lex_token_number,
    lex_token_origin,
    lex_token_string,
    lex_token_ttl
} lex_token_t;

/*
 * One reserved word: its spelling, the token kind reserved() stores
 * in lex_token, and a number reserved() stores in lex_number.
 */
typedef struct table_t
{
    const char      *name;
    int             token;
    int             number;
} table_t;

/*
 * The directive words, matched case-insensitively by reserved()
 * (except while lex_tail is set).
 */
static const table_t table[] =
{
    { .name = "$config",  .token = lex_token_config,  .number = 0 },
    { .name = "$include", .token = lex_token_include, .number = 0 },
    { .name = "$line",    .token = lex_token_line,    .number = 0 },
    { .name = "$origin",  .token = lex_token_origin,  .number = 0 },
    { .name = "$ttl",     .token = lex_token_ttl,     .number = 0 },
};

/*
 * One level of the input stack.  srrf_open creates the outermost
 * level, each $include pushes another on top (lex_include), and
 * lex_getc pops back to the including file when an included file
 * reaches end of file.
 */
typedef struct context_t
{
    string_ty       *fname;             /* file actually opened */
    string_ty       *logical_file_name; /* name used in diagnostics; $line may change it */
    FILE            *lex_fp;            /* input stream */
    long            lino;               /* current line number */
    struct context_t *prev;             /* including level, or 0 for the outermost */
} context_t;

/* top of the input stack: the file currently being read */
static context_t *context = 0;

/* lexer state */
static int      lex_token;              /* kind of the current token */
static int      lex_nest = 0;           /* depth of open '(' grouping */
static string_ty *lex_value = 0;        /* text of the current string/number token */
static int      lex_number = 0;         /* number from the reserved word table */
static int      lex_tail = 0;           /* non-zero while reading the rest of a record line */

/* parser state; the "cur"/"last" values are defaults for later lines */
static srrf_class_ty *lex_cur_class = 0;
static srrf_type_ty *lex_cur_type = 0;
static string_ty *last_name = 0;
static int      error_count = 0;
static int      default_time_to_live = 0;   /* set by $ttl */
static int      record_line_number = 0;     /* line reported in diagnostics */

/* include search path */
static strlist_ty include;
static int      include_specified = 0;
static string_ty *dot = 0;              /* cached "." string */

/* $config tunables */
static int      maximum_name_length = 15;
static int      warning_name_length = 0;
static int      allow_upper_case = 0;


/*
 * path_join - combine a directory and a file name into one path.
 *
 * Any leading slashes on s2 are discarded, so the result is always
 * relative to s1.  When s1 is exactly "." the bare file name is
 * returned rather than "./name".  The result is a new string which
 * the caller must str_free.
 */
static string_ty *
path_join(string_ty *s1, string_ty *s2)
{
    char            *tail;

    for (tail = s2->str_text; *tail == '/'; ++tail)
        ;
    return
        str_equal(s1, dot)
    ?
        str_from_c(tail)
    :
        str_format("%s/%s", s1->str_text, tail);
}


/*
 * srrf_find - locate a file along the include search path.
 *
 * An absolute name that exists is returned unchanged.  Otherwise each
 * include directory is tried in order (the path defaults to just "."
 * when nothing has been added), and the first candidate that exists
 * is returned.  If nothing is found, a copy of s itself is returned
 * (lex_include then reports the failed open).  The result is always
 * a new string reference owned by the caller.
 */
string_ty *
srrf_find(string_ty *s)
{
    size_t          idx;

    if (s->str_text[0] == '/' && access(s->str_text, F_OK) >= 0)
        return str_copy(s);

    if (dot == 0)
        dot = str_from_c(".");
    if (include.nstrings == 0)
        strlist_append(&include, dot);

    idx = 0;
    while (idx < include.nstrings)
    {
        string_ty      *candidate = path_join(include.string[idx], s);

        if (access(candidate->str_text, F_OK) >= 0)
            return candidate;
        str_free(candidate);
        ++idx;
    }
    return str_copy(s);
}


/*
 * lex_include - push a new input file onto the lexer's context stack.
 *
 * The file is located with srrf_find, and failure to open it is
 * reported with nfatal.  The new level starts at line 1.  Its
 * diagnostic name is s as written in the directive, not the resolved
 * path.
 */
static void
lex_include(string_ty *s)
{
    context_t      *frame;
    string_ty      *path = srrf_find(s);
    FILE           *fp = fopen(path->str_text, "r");

    if (fp == 0)
        nfatal("open \"%s\"", path->str_text);

    frame = mem_alloc(sizeof(context_t));
    frame->fname = path;
    frame->logical_file_name = str_copy(s);
    frame->lex_fp = fp;
    frame->lino = 1;

    /* the new file now sits on top of the including one */
    frame->prev = context;
    context = frame;
}


/*
 * srrf_close - finish reading the current SRRF input.
 *
 * If any errors were counted by srrf_lex_error, a summary naming the
 * outermost (originally opened) file is issued via fatal.  Every
 * level of the include stack is then released: its stream is closed
 * (unless it is stdin), and its names and memory are freed.
 */
void
srrf_close(void)
{
    context_t      *frame;

    assert(context);
    if (error_count)
    {
        frame = context;
        while (frame->prev)
            frame = frame->prev;
        fatal
        (
            "%s: found %d fatal error%s",
            frame->fname->str_text,
            error_count,
            (error_count == 1 ? "" : "s")
        );
    }

    /* unwind the whole include stack */
    while ((frame = context) != 0)
    {
        context = frame->prev;

        assert(frame->lex_fp);
        if (frame->lex_fp != stdin)
            fclose(frame->lex_fp);
        str_free(frame->fname);
        str_free(frame->logical_file_name);
        frame->lex_fp = 0;
        frame->fname = 0;
        frame->logical_file_name = 0;
        frame->prev = 0;
        mem_free(frame);
    }
}


/*
 * is_a_number - report whether s is a non-empty run of decimal digits.
 *
 * Returns 1 if s is not empty and every character is in '0'..'9',
 * and 0 otherwise.  No sign or radix prefix is accepted.  The empty
 * string (e.g. an empty quoted TTL field) is rejected; previously it
 * was accepted as a number.
 */
static int
is_a_number(const char *s)
{
    if (*s == '\0')
        return 0;
    while (isdigit((unsigned char)*s))
        ++s;
    return *s == '\0';
}


static int
lex_getc(void)
{
    int             c;

    for (;;)
    {
        assert(context);
        c = getc(context->lex_fp);
        if (c == EOF)
        {
            if (ferror(context->lex_fp))
                nfatal("read \"%s\"", context->fname->str_text);
            if (context->prev)
            {
                context_t      *cp;

                cp = context;
                context = cp->prev;
                str_free(cp->fname);
                str_free(cp->logical_file_name);
                mem_free(cp);
                continue;
            }
            return EOF;
        }
        break;
    }
    if (c == '\n')
        context->lino++;
    return c;
}


/*
 * lex_getc_undo - push one character back onto the current input.
 *
 * EOF is ignored.  Pushing back a newline also rolls back the line
 * counter that lex_getc advanced when it read that newline.
 */
static void
lex_getc_undo(int c)
{
    if (c != EOF)
    {
        if (c == '\n')
            --context->lino;
        ungetc(c, context->lex_fp);
    }
}


/*
 * reserved - check whether a bare word is a directive.
 *
 * Directives are never recognized while lex_tail is set, that is
 * while the rest of a record line is being read.  On a
 * (case-insensitive) match, lex_token and lex_number are set from
 * the table entry and 1 is returned.  Otherwise 0 is returned.
 */
static int
reserved(const char *s)
{
    size_t          k;
    size_t          count = sizeof(table) / sizeof(table[0]);

    for (k = 0; !lex_tail && k < count; ++k)
    {
        if (strcasecmp(s, table[k].name) != 0)
            continue;
        lex_token = table[k].token;
        lex_number = table[k].number;
        return 1;
    }
    return 0;
}


/*
 * srrf_lex_error - report an error in the input, printf-style.
 *
 * The message is prefixed with the current logical file name and
 * record_line_number, and the error is counted.  At 20 errors,
 * processing is aborted with fatal.  The message is formatted with
 * vsnprintf, so an over-long message is truncated rather than
 * overflowing the local buffer.
 */
void
srrf_lex_error(const char *s, ...)
{
    va_list         ap;
    char            buffer[1 << 11];

    va_start(ap, s);
    vsnprintf(buffer, sizeof(buffer), s, ap);
    va_end(ap);
    assert(context);
    error("%s: %d: %s", context->logical_file_name->str_text,
        record_line_number, buffer);
    ++error_count;
    if (error_count >= 20)
    {
        context_t      *cp;

        /* name the outermost (originally opened) file */
        cp = context;
        while (cp->prev)
            cp = cp->prev;
        fatal("%s: too many errors, aborting", cp->fname->str_text);
    }
}


/*
 * lex_warning - report a warning about the input, printf-style.
 *
 * The output format is like srrf_lex_error, but warnings are not
 * counted.  The message is formatted with vsnprintf, so an over-long
 * message is truncated rather than overflowing the local buffer.
 */
static void
lex_warning(const char *s, ...)
{
    va_list         ap;
    char            buffer[1 << 11];

    va_start(ap, s);
    vsnprintf(buffer, sizeof(buffer), s, ap);
    va_end(ap);
    assert(context);
    error("%s: %d: warning: %s", context->logical_file_name->str_text,
        record_line_number, buffer);
}


/*
 * expected - report a syntax error of the form "expected <what>".
 */
static void
expected(const char *what)
{
    srrf_lex_error("expected %s", what);
}


/*
 * lex - read the next token from the input.
 *
 * Sets lex_token and returns it.  For string and number tokens,
 * lex_value is set to a freshly allocated string; the previous
 * lex_value is freed on entry.  Token rules:
 *
 *   - ';' starts a comment that runs to end of line;
 *   - '(' and ')' group a record across lines: newlines inside the
 *     parentheses are skipped instead of yielding lex_token_eoln;
 *   - '"' starts a quoted string with \" \\ and decimal \DDD escapes;
 *   - anything else is a bare word, ended by white space, ';', ')'
 *     or end of file.  A bare word is checked against the directive
 *     table, then classified as a number (all digits) or a string.
 *
 * Over-long words and strings are reported once and truncated, and
 * an unterminated string is reported and returned as it stands.
 */
static lex_token_t
lex(void)
{
    int             c;
    int             too_long;
    char            *cp;
    char            buffer[1 << 12];
    /* last writable position: one byte is kept for the NUL */
    char            *const limit = buffer + sizeof(buffer) - 1;

    if (lex_value)
        str_free(lex_value);
    lex_value = 0;
    lex_number = 0;
    for (;;)
    {
        c = lex_getc();
        switch (c)
        {
        case EOF:
            lex_token = lex_token_eof;
            goto ret;

        default:
            cp = buffer;
            too_long = 0;
            for (;;)
            {
                if (cp < limit)
                    *cp++ = c;
                else if (!too_long)
                {
                    srrf_lex_error("name too long");
                    too_long = 1;
                }
                c = lex_getc();
                /*
                 * EOF must end the word as well.  Otherwise a file
                 * whose last word has no trailing newline loops here
                 * forever.
                 */
                if (c == EOF || isspace(c) || c == ';' || c == ')')
                {
                    lex_getc_undo(c);
                    break;
                }
            }
            *cp = 0;
            if (reserved(buffer))
                goto ret;
            lex_value = str_from_c(buffer);
            if (is_a_number(buffer))
                lex_token = lex_token_number;
            else
                lex_token = lex_token_string;
            goto ret;

        case ';':
            /* comment: skip to (but not past) the end of the line */
            for (;;)
            {
                c = lex_getc();
                if (c == '\n' || c == EOF)
                {
                    lex_getc_undo(c);
                    break;
                }
            }
            break;

        case ' ':
        case '\t':
        case '\r':
        case '\f':
        case '\v':
            /* white space other than newline; '\r' lets CRLF files parse */
            break;

        case '(':
            lex_nest++;
            break;

        case ')':
            /* a stray ')' must not drive the nesting count negative */
            if (lex_nest > 0)
                lex_nest--;
            else
                srrf_lex_error("')' without '('");
            break;

        case '\n':
            if (lex_nest > 0)
                break;
            lex_token = lex_token_eoln;
            lex_tail = 0;
            goto ret;

        case '"':
            /*
             * Strings are read with fgetc directly: they may not span
             * lines, so lex_getc's line counting is not involved.
             */
            cp = buffer;
            too_long = 0;
            for (;;)
            {
                c = fgetc(context->lex_fp);
                if (c == '"' || c == '\n' || c == EOF)
                    break;
                if (c == '\\')
                {
                    c = fgetc(context->lex_fp);
                    if (c == '\n' || c == EOF)
                        break;
                    if (isdigit(c))
                    {
                        int             n;
                        int             value;

                        /*
                         * RFC1035 says escapes are decimal: take up
                         * to three digits and push back the first
                         * character that is not part of the escape.
                         */
                        value = 0;
                        for (n = 0; n < 3 && isdigit(c); ++n)
                        {
                            value = value * 10 + c - '0';
                            c = fgetc(context->lex_fp);
                        }
                        if (c != EOF)
                            ungetc(c, context->lex_fp);
                        if (value >= 256)
                            srrf_lex_error("escape '\\%d' out of range", value);
                        c = value;
                    }
                    else if (c != '"' && c != '\\')
                    {
                        /* the character itself is kept */
                        srrf_lex_error("unknown '\\%c' escape sequence", c);
                    }
                }
                if (cp < limit)
                    *cp++ = c;
                else if (!too_long)
                {
                    srrf_lex_error("string too long");
                    too_long = 1;
                }
            }
            if (c != '"')
            {
                /*
                 * Give the newline back so it still ends the record.
                 * It was read with fgetc, so lino was never advanced
                 * for it and must not be rolled back (lex_getc_undo
                 * would do that).
                 */
                if (c == '\n')
                    ungetc(c, context->lex_fp);
                srrf_lex_error("unterminated string");
            }
            *cp = 0;
            lex_value = str_from_c(buffer);
            lex_token = lex_token_string;
            goto ret;
        }
    }
  ret:
    assert(!lex_value || str_valid(lex_value));
#if 0
    error("lex -> token=%d value=\"%s\" number=%d", lex_token,
        (lex_value ? lex_value->str_text : ""), lex_number);
#endif
    return lex_token;
}


/*
 * suffix_start - find where domain s2 begins as a suffix of name s1.
 *
 * Returns a pointer into s1's text at the '.' that separates s1's
 * leading labels from a trailing copy of s2 (compared byte-for-byte).
 * Returns 0 if s1 is not a proper sub-domain of s2.
 *
 * When s2 is the root ".", any s1 ending in '.' matches, and the
 * result points at that final '.'.  An empty s1 never matches; it
 * previously caused a read of s1->str_text[-1].
 */
static char *
suffix_start(string_ty *s1, string_ty *s2)
{
    if (s1->str_length == 0)
        return 0;

    /* Be careful when s2 is root. */
    if (s2->str_text[0] == '.' && s2->str_text[1] == 0)
    {
        if (s1->str_text[s1->str_length - 1] == '.')
            return (s1->str_text + s1->str_length - 1);
        return 0;
    }

    if
    (
        s1->str_length > s2->str_length
    &&
        s1->str_text[s1->str_length - s2->str_length - 1] == '.'
    &&
        !memcmp
        (
            s1->str_text + s1->str_length - s2->str_length,
            s2->str_text,
            s2->str_length
        )
    )
        return (s1->str_text + s1->str_length - s2->str_length - 1);
    return 0;
}


/*
 * check_name - sanity check a domain name read from the input.
 *
 * If an origin is set and the name is not a proper sub-domain of it
 * (per suffix_start), nothing is checked.  Otherwise only the labels
 * before the origin are examined, and a leading "*." label is
 * skipped.  Each dot-separated component is checked as follows:
 *   - empty component: error;
 *   - characters other than letters, digits and '-': error;
 *   - upper-case letters: warning, unless allow_upper_case is set.
 *
 * When length_flag is non-zero, a component longer than
 * maximum_name_length is an error, and one longer than
 * warning_name_length is a warning (each limit applies only if
 * positive).  The reported "by N" is the excess over the limit that
 * was exceeded.
 */
static void
check_name(string_ty *orig, int length_flag)
{
    char            *s;
    char            *ep;
    int             ill;
    int             upper;
    int             c;
    char            *the_end;
    string_ty       *origin;

    ill = 0;
    upper = 0;
    s = orig->str_text;
    origin = srrf_origin_get();

    /*
     * If the name is not in our domain, ignore it.  If the name is
     * in our domain, only check the part before the domain name.
     */
    if (!origin)
        the_end = orig->str_text + orig->str_length;
    else
    {
        the_end = suffix_start(orig, origin);
        if (!the_end)
            return;
    }

    /*
     * This is a hack to allow MX records to work: a leading "*."
     * label is not checked.
     */
    if (s[0] == '*' && s[1] == '.' && s[2])
        s += 2;
    if (s >= the_end)
        return;

    for (;;)
    {
        ep = s;
        while (ep < the_end && *ep != '.')
        {
            /* <ctype.h> needs a value representable as unsigned char */
            c = (unsigned char)*ep;
            if (isupper(c) && !allow_upper_case)
                ++upper;
            if (!isalnum(c) && c != '-')
                ++ill;
            ++ep;
        }
        if (s == ep && *ep)
        {
            /* Watch out for root. */
            if (orig->str_text[0] != '.' || orig->str_text[1] != 0)
                srrf_lex_error("name \"%s\" has an empty component",
                    orig->str_text);
        }
        if (length_flag)
        {
            int             len = (int)(ep - s);

            if (maximum_name_length > 0 && len > maximum_name_length)
            {
                srrf_lex_error
                (
                    "name \"%s\", component \"%.*s\" is far too long (by %d)",
                    orig->str_text,
                    len,
                    s,
                    len - maximum_name_length
                );
            }
            else if (warning_name_length > 0 && len > warning_name_length)
            {
                lex_warning
                (
                    "name \"%s\", component \"%.*s\" is too long (by %d)",
                    orig->str_text,
                    len,
                    s,
                    len - warning_name_length
                );
            }
        }
        s = ep;
        if (s >= the_end)
            break;
        if (*s == '.')
            ++s;
    }
    if (upper)
    {
        lex_warning
        (
            "name \"%s\" contains %d upper case character%s; "
                "please use lower-case names",
            orig->str_text,
            upper,
            (upper == 1 ? "" : "s")
        );
    }
    if (ill)
    {
        srrf_lex_error
        (
            "name \"%s\" contains %d illegal character%s; "
                "only letters, digits, hyphen and dot are legal",
            orig->str_text,
            ill,
            (ill == 1 ? "" : "s")
        );
    }
}


/*
 * Tunable parameters, set by "$config name value" lines (config())
 * and listed by srrf_print_config().  Each entry maps a parameter
 * name, matched case-insensitively, to the int variable it controls.
 */
typedef struct cfg_table_ty
{
    const char  *name;
    int         *value;
} cfg_table_ty;

static cfg_table_ty cfg_table[] =
{
    { .name = "maximum_name_length", .value = &maximum_name_length },
    { .name = "warning_name_length", .value = &warning_name_length },
    { .name = "allow_upper_case",    .value = &allow_upper_case },
};


/*
 * config - set the tunable parameter called name to value.
 *
 * The name is matched case-insensitively against cfg_table.  An
 * unknown name is reported as an error.
 */
static void
config(string_ty *name, int value)
{
    size_t          k;
    size_t          count = sizeof(cfg_table) / sizeof(cfg_table[0]);

    for (k = 0; k < count; ++k)
    {
        if (strcasecmp(name->str_text, cfg_table[k].name) != 0)
            continue;
        *cfg_table[k].value = value;
        return;
    }
    srrf_lex_error("config parameter \"%s\" unknown", name->str_text);
}


static void
parse_origin_line(void)
{
    if (lex_token != lex_token_string)
    {
        expected("domain");
chew_until_eoln:
        while (lex_token != lex_token_eoln && lex_token != lex_token_eoln)
            lex();
        return;
    }
    check_name(lex_value, 0);
    srrf_origin_set(lex_value);
    if (!last_name)
        last_name = str_copy(lex_value);
    lex();
    if (lex_token != lex_token_eoln)
    {
        expected("eoln");
        goto chew_until_eoln;
    }
    assert(str_valid(srrf_origin_get()));
}


/*
 * parse_include_line - parse the operand of a "$include file" line.
 *
 * On entry lex_token is the token after "$include".  Once the end of
 * line has been consumed, the file is pushed with lex_include, so
 * the next token comes from it.  On a syntax error, the rest of the
 * line is skipped, and the copied file name is now freed rather
 * than leaked.
 */
static void
parse_include_line(void)
{
    string_ty       *new_file_name;

    if (lex_token != lex_token_string)
    {
        expected("file name");
        goto chew_until_eoln;
    }
    new_file_name = str_copy(lex_value);
    lex();
    if (lex_token != lex_token_eoln)
    {
        str_free(new_file_name);
        expected("eoln");
        goto chew_until_eoln;
    }
    lex_include(new_file_name);
    str_free(new_file_name);
    return;

chew_until_eoln:
    while (lex_token != lex_token_eoln && lex_token != lex_token_eof)
        lex();
}


/*
 * parse_line_line - parse the operands of a '$line N "file"' line.
 *
 * On entry lex_token is the token after "$line".  On success, the
 * current level's logical file name becomes "file", and its line
 * counter is set to N.  The newline has already been consumed at
 * that point, so N is the number reported for the line that follows.
 * On a syntax error, the rest of the line is skipped, and the copied
 * file name is now freed rather than leaked.
 */
static void
parse_line_line(void)
{
    long            new_line_number;
    string_ty       *new_file_name;

    if (lex_token != lex_token_number)
    {
        expected("number");
        goto chew_until_eoln;
    }
    new_line_number = atol(lex_value->str_text);
    lex();
    if (lex_token != lex_token_string)
    {
        expected("file name");
        goto chew_until_eoln;
    }
    new_file_name = str_copy(lex_value);
    lex();
    if (lex_token != lex_token_eoln)
    {
        str_free(new_file_name);
        expected("end of line");
        goto chew_until_eoln;
    }
    str_free(context->logical_file_name);
    context->logical_file_name = new_file_name;
    context->lino = new_line_number;
    return;

chew_until_eoln:
    while (lex_token != lex_token_eof && lex_token != lex_token_eoln)
        lex();
}


/*
 * parse_config_line - parse the operands of a "$config name value" line.
 *
 * On entry lex_token is the token after "$config".  The parameter is
 * set with config() as soon as both operands have been read.  On a
 * syntax error, the rest of the line is skipped.  The copied
 * parameter name is now freed on every path; previously it was
 * always leaked.
 */
static void
parse_config_line(void)
{
    string_ty       *param_name;
    long            param_value;

    if (lex_token != lex_token_string)
    {
        expected("config name");
        goto chew_until_eoln;
    }
    param_name = str_copy(lex_value);
    lex();
    if (lex_token != lex_token_number)
    {
        str_free(param_name);
        expected("number");
        goto chew_until_eoln;
    }
    param_value = atol(lex_value->str_text);
    config(param_name, param_value);
    str_free(param_name);
    lex();
    if (lex_token != lex_token_eoln)
    {
        expected("end of line");
        goto chew_until_eoln;
    }
    return;

chew_until_eoln:
    while (lex_token != lex_token_eof && lex_token != lex_token_eoln)
        lex();
}


/*
 * parse_ttl_line - parse the operand of a "$ttl N" line.
 *
 * On entry lex_token is the token after "$ttl".  If the line is well
 * formed, default_time_to_live becomes N.  Otherwise an error is
 * reported, and the rest of the line is skipped.
 */
static void
parse_ttl_line(void)
{
    long            ttl;

    if (lex_token == lex_token_number)
    {
        ttl = atol(lex_value->str_text);
        lex();
        if (lex_token == lex_token_eoln)
        {
            default_time_to_live = ttl;
            return;
        }
        expected("end of line");
    }
    else
        expected("number");

    /* resynchronize: discard the rest of the directive */
    while (lex_token != lex_token_eof && lex_token != lex_token_eoln)
        lex();
}


/*
 * srrf_test_fitting - try to interpret a word list as a full record.
 *
 * line must hold, in order: name, TTL, class, type, then the
 * type-specific arguments.  Returns 0, without changing any state in
 * this file, when any of these hold:
 *   - there are fewer than four words;
 *   - the TTL is not all digits;
 *   - the class or the type is unknown.
 *
 * Otherwise a new srrf_t is built and returned.  last_name,
 * lex_cur_class and lex_cur_type are updated, so that later lines
 * can default from this record.  A wrong argument count is reported
 * as an error, but the record is still returned: missing arguments
 * are padded with "?", and the type's check_arguments hook is run
 * regardless.
 */
static srrf_t *
srrf_test_fitting(strlist_ty *line)
{
    srrf_t          *result;
    size_t          j;

    /*
     * Reject trivially invalid lines.
     */
    if (line->nstrings < 4)
    {
        return 0;
    }

    /*
     * The TTL must be a valid number.
     */
    if (!is_a_number(line->string[1]->str_text))
    {
        return 0;
    }

    /*
     * Make sure the class is valid.
     */
    srrf_class_ty *class_p =
        srrf_class_by_name(line->string[2]->str_text);
    if (!class_p)
    {
        return 0;
    }

    /*
     * Make sure the type is valid.
     */
    srrf_type_ty *type_p =
        srrf_type_by_name(class_p, line->string[3]->str_text);
    if (!type_p)
    {
        return 0;
    }

    /*
     * Looks like a valid line.
     */
    result = srrf_alloc();
    result->file_name = str_copy(context->logical_file_name);
    result->line_number = record_line_number;

    result->name = srrf_relative_to_absolute(line->string[0]);
    check_name(result->name, 1);
    if (last_name)
        str_free(last_name);
    last_name = str_copy(result->name);

    /*
     * get the time to live
     */
    result->ttl = atol(line->string[1]->str_text);

    /*
     * get the class
     */
    result->class = class_p;
    lex_cur_class = class_p;

    /*
     * get the type
     */
    result->type = type_p;
    lex_cur_type = type_p;

    /*
     * get all the other stuff
     */
    for (j = 4; j < line->nstrings; ++j)
    {
        strlist_append(&result->arg, line->string[j]);
    }

    if
    (
        result->type->number_of_arguments
    &&
        result->type->number_of_arguments != result->arg.nstrings
    )
    {
        string_ty      *s;

        /*
         * nstrings is a size_t; it must be converted to match %d in
         * this variadic call.
         */
        srrf_lex_error
        (
            "'%s' requires %d arguments, but %d were given",
            result->type->name,
            (int)result->type->number_of_arguments,
            (int)result->arg.nstrings
        );
        s = str_from_c("?");
        while (result->arg.nstrings < result->type->number_of_arguments)
            strlist_append(&result->arg, s);
        str_free(s);
    }

    /*
     * check the arguments, even if we have added bogus args,
     * because the arg check can often have side effects
     */
    if (result->type->check_arguments)
        result->type->check_arguments(result);

    return result;
}


/*
 * parse_srrf_line - parse one resource record line.
 *
 * The defaulting rules for the first four SRRF fields (name, TTL,
 * class, type) are arcane and mysterious, and the ones implemented
 * here may not be correct.  The whole line is read first.  Then 16
 * combinations of defaulted fields are tried in order, and the
 * first one that srrf_test_fitting accepts is returned.  In trial j,
 * these bits select which fields are supplied from defaults instead
 * of being taken from the line:
 *
 *     1 - TTL   (default_time_to_live, or "0")
 *     2 - name  (last_name, or "bogus.")
 *     4 - class (class of the previous record, or "?")
 *     8 - type  (type of the previous record, or "?")
 *
 * On entry lex_value holds the first word of the line.  On return
 * lex_token is lex_token_eoln or lex_token_eof.  Returns the new
 * record, or 0 (after reporting "syntax error") if no combination
 * fits.  All temporary strings and lists are released on both paths.
 */
static srrf_t *
parse_srrf_line(void)
{
    srrf_t          *result;
    strlist_ty      line;
    int             j;

    /*
     * Read all of the line's words.  Setting lex_tail stops lex from
     * recognizing directives until the end of the line.
     */
    lex_tail = 1;
    strlist_zero(&line);
    for (;;)
    {
        strlist_append(&line, lex_value);
        lex();
        if (lex_token == lex_token_eof || lex_token == lex_token_eoln)
            break;
    }

    /*
     * Now try a bunch of different defaulting alternatives,
     * stopping at the first which makes any sense.
     */
    for (j = 0; j < 16; ++j)
    {
        size_t          pos;
        strlist_ty      trial;

        pos = 0;
        strlist_zero(&trial);

        /*
         * Default the name.
         */
        if (j & 2)
        {
            if (last_name)
                strlist_append(&trial, last_name);
            else
            {
                string_ty       *s;

                /* strlist_append takes its own reference */
                s = str_from_c("bogus.");
                strlist_append(&trial, s);
                str_free(s);
            }
        }
        else
        {
            if (pos < line.nstrings)
                strlist_append(&trial, line.string[pos++]);
        }

        /*
         * Default the TTL.
         */
        if (j & 1)
        {
            string_ty       *s;

            if (default_time_to_live)
                s = str_format("%d", default_time_to_live);
            else
                s = str_from_c("0");
            strlist_append(&trial, s);
            str_free(s);
        }
        else
        {
            if (pos < line.nstrings)
                strlist_append(&trial, line.string[pos++]);
        }

        /*
         * Default the class.
         */
        if (j & 4)
        {
            string_ty       *s;

            if (lex_cur_class)
                s = str_from_c(lex_cur_class->name);
            else
                s = str_from_c("?");
            strlist_append(&trial, s);
            str_free(s);
        }
        else
        {
            if (pos < line.nstrings)
                strlist_append(&trial, line.string[pos++]);
        }

        /*
         * Default the type.
         */
        if (j & 8)
        {
            string_ty       *s;

            if (lex_cur_type)
                s = str_from_c(lex_cur_type->name);
            else
                s = str_from_c("?");
            strlist_append(&trial, s);
            str_free(s);
        }
        else
        {
            if (pos < line.nstrings)
                strlist_append(&trial, line.string[pos++]);
        }

        /*
         * Now append everything else.
         */
        while (pos < line.nstrings)
            strlist_append(&trial, line.string[pos++]);

        /*
         * See if the trial line we created is a valid SRRF record.
         */
        result = srrf_test_fitting(&trial);
        strlist_free(&trial);
        if (result)
        {
            strlist_free(&line);
            return result;
        }
    }
    strlist_free(&line);
    srrf_lex_error("syntax error");
    return 0;
}


void
srrf_open(const char *s)
{
    context_t       *cp;

    cp = mem_alloc(sizeof(context_t));
    cp->prev = 0;
    if (s)
    {
        cp->fname = str_from_c(s);
        cp->lex_fp = fopen(s, "r");
        if (!cp->lex_fp)
            nfatal("open \"%s\"", s);
    }
    else
    {
        cp->fname = str_from_c("standard input");
        cp->lex_fp = stdin;
    }
    cp->logical_file_name = str_copy(cp->fname);
    cp->lino = 1;
    error_count = 0;
    default_time_to_live = 0;
    lex_cur_class = 0;
    lex_cur_type = 0;

    context = cp;

    /*
     * Read the first token.
     * The srrf_read function expects this.
     */
    lex();
}


/*
 * srrf_read - return the next resource record, or 0 at end of input.
 *
 * Blank lines are skipped, and directive lines ($origin, $include,
 * $line, $config, $ttl) are processed here.  A record line that
 * cannot be parsed has already been reported by parse_srrf_line.
 * Reading now continues with the next line instead of returning 0,
 * which callers could not tell apart from end of input.  This lets
 * later errors be reported as well, up to the limit enforced by
 * srrf_lex_error; srrf_close then reports the total.
 */
srrf_t *
srrf_read(void)
{
    assert(context);
    assert(context->lex_fp);
    for (;;)
    {
        srrf_t          *rp;

        record_line_number = context->lino;
        switch (lex_token)
        {
        case lex_token_eof:
            return 0;

        case lex_token_eoln:
            lex();
            break;

#ifndef DEBUG
        default:
#endif
        case lex_token_number:
        case lex_token_string:
            rp = parse_srrf_line();
            if (rp)
                return rp;
            /* parse_srrf_line always stops at eoln or eof */
            break;

        case lex_token_origin:
            lex();
            parse_origin_line();
            break;

        case lex_token_include:
            lex();
            parse_include_line();
            break;

        case lex_token_line:
            lex();
            parse_line_line();
            break;

        case lex_token_config:
            lex();
            parse_config_line();
            break;

        case lex_token_ttl:
            lex();
            parse_ttl_line();
            break;
        }
    }
}


/*
 * srrf_print_config - write every tunable as a "$config" line.
 *
 * The output can be read back as SRRF input.  Returns the number of
 * lines written.
 */
int
srrf_print_config(FILE *fp)
{
    int             k;
    int             count = SIZEOF(cfg_table);

    for (k = 0; k < count; ++k)
        fprintf(fp, "$config %s %d\n", cfg_table[k].name, *cfg_table[k].value);
    return count;
}


/*
 * srrf_include_path - add a directory to the include search path.
 *
 * Duplicate directories are not added twice.  Calling this also
 * records that an include path was explicitly specified.
 */
void
srrf_include_path(const char *s)
{
    string_ty      *dir = str_from_c(s);

    include_specified = 1;
    strlist_append_unique(&include, dir);
    str_free(dir);
}


/*
 * srrf_include_path_specified - report whether srrf_include_path has
 * ever been called.
 */
int
srrf_include_path_specified(void)
{
    int             specified = include_specified;

    return specified;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */