"""$URL: svn+ssh://svn/repos/trunk/grouch/lib/schema_util.py $
$Id: schema_util.py 24750 2004-07-21 15:26:51Z dbinger $

Search a set of directories for Python modules and classes, and parse
the class docstrings to generate an object schema.  Lots of
project-specific information (eg. directories to search, classes to
exclude, ...) can be supplied via a project description file.
"""

import sys, os, string, re
import getopt
import types
import traceback
from fnmatch import fnmatch
from cPickle import dump

from grouch.schema import ObjectSchema, ClassDefinition, InvalidAlias
from grouch.script_util import announce, error, warn


# -- Mid-level workers -------------------------------------------------
# (called by generate_schema())

def find_modules (dirs, base_dir=None, prefix=None, exclude=None):
    """find_modules(dirs : [string],
                    base_dir : string = None,
                    prefix : string = None,
                    exclude : [string] = None)
       -> [(modname : string, filename : string)]

    Searches a list of directories for Python module files (*.py).
    Returns a list of (modname, filename) tuples.  If 'prefix' is
    supplied, it is prepended (with a dot interpolated) to each module
    name.  If 'exclude' is supplied, it must be a list of
    fully-qualified module names; if a module is found in 'exclude'
    (after adding prefix), it will not be included in the returned list
    of modules.  Raises ValueError if any directory in 'dirs' escapes
    upwards (ie. contains os.pardir after normalization).

    Example: dirs == ["foo"] and directory "foo" contains "bar.py" and
    "baz.py".  If no prefix or base_dir is supplied, returns
      [("foo.bar", "foo/bar.py"),
       ("foo.baz", "foo/baz.py")]

    If base_dir == "d" and prefix == "p", returns
      [("p.foo.bar", "d/foo/bar.py"),
       ("p.foo.baz", "d/foo/baz.py")]
    """
    if exclude is None:
        exclude = []                    # bug fix: "modname not in None" blew up

    modules = []                        # list of (modname, filename) tuples
    for dir in dirs:
        # Turn the directory path into dotted-name components.
        components = os.path.normpath(dir).split(os.sep)
        if os.pardir in components:     # disallow "../foo" after normpath()
            raise ValueError(
                "invalid directory '%s': cannot contain '%s'"
                % (dir, os.pardir))
        if components == [os.curdir]: # eg. dir == "." or ""
            components = []
        if prefix:
            components.insert(0, prefix)

        if base_dir:
            real_dir = os.path.normpath(os.path.join(base_dir, dir))
        else:
            real_dir = os.path.normpath(dir)

        for basename in os.listdir(real_dir):
            # Skip files that aren't Python modules; "__init__.py" is
            # package plumbing, not a module in its own right.
            if not fnmatch(basename, "*.py") or basename == "__init__.py":
                continue

            real_filename = os.path.join(real_dir, basename)
            bare_modname = os.path.splitext(basename)[0]
            modname = ".".join(components + [bare_modname])

            if modname not in exclude:
                modules.append((modname, real_filename))

    return modules


class ClassInfo:
    """
    Encapsulates everything we need to know about a class in order to
    add it to an object schema.  find_classes() returns a list of
    ClassInfo instances to find_all_classes(), which collects them in
    a dictionary mapping module names to lists of ClassInfo instances.

    Instance attributes:
      bare_name : string
        the bare class name as seen in the "class" statement
      full_name :
        the fully-qualified class name (ie. including module name)
      base_classes : [string]
        the list of base class names from the class statement
      docstring : string
        the class docstring as a single (probably multi-line) string
    """

    def __init__ (self, bare_name, full_name, base_classes, docstring):
        self.bare_name = bare_name
        self.full_name = full_name
        self.base_classes = base_classes
        self.docstring = docstring

    def __str__ (self):
        return self.full_name

    def __repr__ (self):
        return "<%s at %08x: %s>" % (self.__class__.__name__, id(self), self)

    def expand_base_classes (self, schema):
        """
        Ensure that every name in self.base_classes is either a class
        or an alias for a class in 'schema'.  Names that turn out to
        be aliases are rewritten in place to the full class name.
        Raises ValueError if any base class name is bad.
        """
        for (i, base_name) in enumerate(self.base_classes):
            if schema.get_class_definition(base_name):
                continue                # already a real class name

            alias = schema.get_alias(base_name)
            if not alias:
                raise ValueError(
                    "%s: invalid base class %r (no such class or alias)"
                    % (self.full_name, base_name))
            if not alias.is_plain_instance_type():
                raise ValueError(
                    "%s: invalid base class %r (alias to non-class)"
                    % (self.full_name, base_name))
            self.base_classes[i] = alias.klass_name


def find_all_classes (modules, schema, exclude_classes=None):
    """find_all_classes(modules : [(modname : string, filename : string)],
                        schema : ObjectSchema,
                        exclude_classes : [string] = None)
       -> { modname:string : [ClassInfo] }

    Find all classes in a list of modules.  Returns a dictionary mapping
    module name to list of ClassInfo objects.  For each class found,
    adds two things to schema: an empty ClassDefinition, and an alias
    mapping the bare class name to its full name (eg. for class
    foo.bar.FooBar, the alias maps "FooBar" to "foo.bar.FooBar").

    If two classes share a bare name, the alias is ambiguous: it is
    invalidated in the schema and a warning is issued.  After all
    modules have been scanned, each class' base-class names are
    expanded to fully-qualified names (see
    ClassInfo.expand_base_classes()); a class whose base-class names
    can't be expanded is left with no base classes in the schema.
    """

    # Map module name to list of class objects
    module_classes = {}
    num_classes = 0

    announce("looking for classes...\n")
    for (modname, filename) in modules:
        announce("module %s:\n" % modname, threshold=2)
        klasses = find_classes(filename, modname, exclude_classes)
        num_classes += len(klasses)

        ok_klasses = []                 # XXX never used -- leftover?
        for klass in klasses:
            bare_name = klass.bare_name
            full_name = klass.full_name

            # Should we create a new alias mapping bare_name to
            # full_name and add it to the schema?
            add_alias = 1

            alias = schema.get_alias(bare_name)
            if isinstance(alias, InvalidAlias):
                # We've already had one name collision on 'bare_name' and
                # invalidated that alias, so don't try to create it again!
                warn("can't alias %r to %r: name collision"
                     % (bare_name, full_name))
                add_alias = 0

            elif alias:
                # An alias for 'bare_name' already exists.  If it maps
                # to this very class, it's not a problem after all.
                if (alias.is_instance_type() and
                    alias.get_class_name() == full_name):
                    add_alias = 0

                # Uh-oh, alias already exists and it maps 'bare_name'
                # to something else.  Ambiguity is bad, so invalidate
                # this alias.
                else:
                    warn("invalidating ambiguous alias %r"
                         % bare_name)
                    schema.invalidate_alias(bare_name)
                    add_alias = 0

            # If we need to, add a new class definition to the schema.
            klass_def = schema.get_class_definition(full_name)
            if klass_def is None:
                klass_def = ClassDefinition(full_name, schema)
                schema.add_class(klass_def)

            # The above code decided there's no problem aliasing
            # 'bare_name' to 'full_name', so do it now.
            if add_alias:
                schema.add_alias(bare_name, full_name)

        module_classes[modname] = klasses
        for klass in klasses:
            announce("  %s\n" % klass, threshold=2)

    # This has to be done after scanning all modules, because it depends
    # on us having seen all classes in the application.
    for klasses in module_classes.values():
        for klass in klasses:
            try:
                klass.expand_base_classes(schema)
            except ValueError, err:
                # Should we exclude the class from the schema entirely
                # if this happens?  Right now we just refrain from giving
                # it any base classes if one of its base classes has
                # a problem.
                error(str(err))
            else:
                klass_def = schema.get_class_definition(klass.full_name)
                klass_def.set_bases(klass.base_classes)

    announce("found %d classes\n" % num_classes)
    return module_classes


def get_names (nodes):
    """get_names(nodes : [compiler.ast.Node]) -> [string]

    Extract the (possibly dotted) names from a list of AST nodes, as
    found in the 'bases' list of a class statement.  Nodes that are
    neither Name nor Getattr are silently skipped.
    """
    from compiler import ast
    names = []
    for node in nodes:
        if isinstance(node, ast.Name):
            names.append(node.name)
        elif isinstance(node, ast.Getattr):
            # Dotted name: "foo.bar.baz" in source becomes
            # Getattr(Getattr(Name(foo),bar),baz) in the AST.  Unwind
            # this stack to reconstruct the original string.
            parts = []
            cur = node
            while isinstance(cur, ast.Getattr):
                parts.append(cur.attrname)
                cur = cur.expr
            assert isinstance(cur, ast.Name), \
                   "expected Name at bottom of Getattr stack"
            parts.append(cur.name)
            parts.reverse()
            names.append(".".join(parts))

    return names


def find_classes (filename, modname, exclude=None):
    """find_classes(filename : string,
                    modname : string,
                    exclude : [string] = None)
       -> [ClassInfo]

    Parses the Python source file 'filename' looking for classes.
    Assumes this source file contains the module 'modname'.  Returns a
    list of ClassInfo instances, one for each class in the file (an
    empty list if the file can't be parsed at all).

    Eg. if parsing a file foo.py which (according to 'modname') contains
    a module 'foo', this code:
      class Foo (bar.Bar):
        '''
        foo
        bar
        '''
    results in a ClassInfo instance like:
      bare_name = "Foo"
      full_name = "foo.Foo"
      base_classes = ["bar.Bar"]
      docstring = "\n    foo\n    bar\n    "

    No attempt is made to determine what "bar.Bar" really refers to --
    that would require interpreting the module, and in that case we
    might as well just import the damn thing.  Note that the
    ClassInfo.expand_base_classes() method, called by find_all_classes()
    just before returning, attempts to expand all base class names to
    their true names.
    """
    from compiler import parseFile, ast
    from parser import ParserError

    klasses = []

    try:
        module_ast = parseFile(filename)
    except ParserError:
        error("%s: unable to parse module (try importing it for more details)"
              % filename)
        # bug fix: error() does not exit, and falling through used to
        # crash on the unbound 'module_ast' -- an unparseable module
        # simply contributes no classes.
        return klasses

    for node in module_ast.node.nodes:
        if isinstance(node, ast.Class):
            # Qualify the bare class name with the module name, if known.
            if modname:
                fullname = modname + "." + node.name
            else:
                fullname = node.name

            # bug fix: 'exclude' may be None (the default)
            if exclude and fullname in exclude:
                continue

            base_classes = get_names(node.bases)
            klasses.append(ClassInfo(node.name, fullname,
                                     base_classes, node.doc))

    return klasses


def parse_class_docstrings (modules, module_classes, schema):
    """parse_class_docstrings(modules : [(string, string)],
                              module_classes : {string : [ClassInfo]},
                              schema : ObjectSchema)

    Parse the docstring of every class found in 'modules', updating the
    corresponding class definitions in 'schema' via parse_docstring().
    Parse failures are reported (error()/warn()) but are never fatal.
    """
    announce("parsing class docstrings...\n")
    for (modname, _) in modules:
        klasses = module_classes.get(modname, [])
        announce("  module: %s (%d classes)\n" % (modname, len(klasses)),
                 threshold=2)
        for klass in klasses:
            try:
                errors = parse_docstring(klass, schema)
            except ValueError, exc:
                # docstring missing or hopeless: report it and move on
                error(str(exc))
                announce("   failed to parse %s docstring\n" % klass, threshold=2)
            else:
                # minor problems only: the class was parsed, warn about each
                for e in errors:
                    warn(e)
                announce("   parsed %s docstring\n" % klass, threshold=2)


# -- Parsing code ------------------------------------------------------
# (parse_class_docstrings() calls parse_docstring(), which uses
#  everything else in this section)

# Matches the run of leading whitespace at the start of a string.
# (raw string: '\s' in a plain literal relies on Python passing the
# unknown escape through verbatim)
leading_ws_re = re.compile(r'^(\s*)')

def get_indent_level (s):
    """Return the length of the run of whitespace that 's' starts with."""
    # re.match anchors at the start, and "\s*" always matches (possibly
    # emptily), so the match never fails.
    return re.match(r'\s*', s).end()


def clean_docstring (doc, klass_name):
    """clean_docstring(doc : string, klass_name : string) -> [string]

    Split the docstring 'doc' into lines, determine its common leading
    indentation, and return the lines with that indentation, any "#"
    comments, and trailing whitespace removed.  'klass_name' is used
    only for error messages.  Raises ValueError if any line is
    inconsistently indented (dedented relative to the reference line).
    """
    lines = (doc or '').split("\n")
    assert lines, "string.split() returned empty list"

    def indent_of(line):
        # length of the leading whitespace run ("\s*" always matches)
        return re.match(r'\s*', line).end()

    if lines[0] and lines[0][0] == ' ':
        # The very first line is indented: it sets the indent level.
        leading_indent = indent_of(lines[0])
        start_line = 0
    elif len(lines) > 1:
        # Unindented summary line: the indent level comes from the
        # first non-blank line after it.
        j = 1
        while j < len(lines) and not lines[j]:
            j += 1
        assert j < len(lines), "arg, screwy docstring"
        leading_indent = indent_of(lines[j])
        start_line = 1
    else:
        # Single-line docstring: nothing to dedent.
        leading_indent = 0
        start_line = 0

    for i in range(start_line, len(lines)):
        line = lines[i]
        if not line:                    # leave blank lines alone
            continue
        if len(line) < leading_indent:
            raise ValueError(
                "class %s: inconsistent indent in docstring "
                "(line %d too short)" % (klass_name, i))
        if line[:leading_indent].lstrip() != "":
            raise ValueError(
                "class %s: inconsistent indent in docstring "
                "(line %d dedented relative to line %d)" %
                (klass_name, i, start_line))

        # Drop "#" comments, then dedent and strip trailing whitespace.
        line = re.sub(r'#.*', '', line)
        lines[i] = line[leading_indent:].rstrip()

    return lines

def find_attrs (lines, klass_name):
    """find_attrs(lines : [string], klass_name : string) -> int | None

    Locate the "Instance attributes:" line in a cleaned-up docstring.
    Return the index of the line after it, or None if the class
    declares "Instance attributes: none".  Raises ValueError if no
    such line exists ('klass_name' is for the error message only).
    """
    for (i, line) in enumerate(lines):
        if line.startswith("Instance attributes:"):
            if line.endswith("none"):
                return None             # explicitly declared attribute-free
            return i+1
    raise ValueError(
        "class %s: no \"Instance attributes:\" line in docstring" %
        klass_name)


# Regexes for picking apart attribute lines in docstrings.
_name_pat = r'[a-zA-Z_][a-zA-Z0-9_]*'                       # one identifier
_dotted_name_pat = r'%s(?:\.%s)*' % (_name_pat, _name_pat)  # eg. "foo.bar.Baz"
_attr_line_re = re.compile(r'\s*(%s)\s*:\s*(.*)' % _name_pat)   # "name : spec"
_element_name_re = re.compile(r'(%s):(%s)' % (_name_pat, _dotted_name_pat))

def massage_typespec (typespec):
    """massage_typespec(typespec : string) -> string

    Massage a docstring type specification so it can be parsed as a
    ValueType.  The docstring type specification language is almost the
    ValueType language; the differences handled here are:
      - container elements can have names as well as types,
        eg.  { key:keytype : value:valuetype }
        or   (val1:type1, val2:type2)
      - plain container types can include the name of the
        container, eg.
          dictionary {string : int}
      - typespecs can be trailed by a default value description, eg.
          foo : int = 37
          bar : [string] = ["hello"]

    Note that for named container elements, whitespace matters!
    {key:keytype : value:valuetype} is *not* the same as
    {key : keytype : value : valuetype} -- the latter is illegal.
    """
    # Deal with the container-name exception first, since it's easiest:
    # strip a leading 'list'/'tuple'/'dictionary' word.
    words = typespec.split(None, 1)
    if len(words) > 1:
        remainder = words[1]
        if words[0] in ('list', 'tuple', 'dictionary'):
            typespec = remainder
    else:
        remainder = None

    # Strip anything that looks like a default value, ie. " = ..."
    typespec = re.sub(r'\s*=.*', '', typespec)

    # Deal with the named-elements thing ("name:type" -> "type").  This
    # is a kludge, but it should work given the tight syntax constraints
    # on named elements.  (bug fix: guard against an empty typespec, eg.
    # a line that was nothing but a default value.)
    if typespec and (typespec[0] in "[({" or
                     (remainder and remainder[0] in "[({")):
        typespec = _element_name_re.sub(r'\2', typespec)

    return typespec


def parse_docstring(klass, schema):
    """(klass : ClassInfo, schema : ObjectSchema) -> errors : [string]

    Parse the class docstring in 'klass' and use the docstring to update
    the class definition already in 'schema'.  Return a list of error
    messages resulting from parsing the docstring which should be presented
    to the user.  Raises ValueError if the docstring is a lost cause,
    ie. missing or completely unparseable.
    """
    klass_name = klass.full_name        # for error messages

    klass_def = schema.get_class_definition(klass_name)
    if klass_def.attrs or klass_def.all_attrs:
        raise ValueError("class %s: already seen" % klass_name)

    doc = klass.docstring
    if not doc:
        raise ValueError("class %s: no docstring" % klass_name)

    (errors, attrs) = parse_docstring_helper(schema, doc, klass_name)

    for (name, attr_type) in attrs:
        # "anything" attributes should still reject arbitrary instances
        if attr_type.is_any_type():
            attr_type.set_allow_any_instance(0)
        klass_def.add_attribute(name, attr_type)

    # XXX ClassDefinition should expose len() of its attrs list
    if not klass_def.attrs:
        errors.append("class %s: no attributes successfully parsed" %
                      klass_def.name)

    return errors


def parse_docstring_helper(schema, doc, klass_name):
    """(schema : ObjectSchema, doc : string, klass_name : string)
       -> (errors : [string], attrs : [(string, ValueType)])

    Parse the "Instance attributes:" section of docstring 'doc',
    returning a list of non-fatal error messages and a list of
    (name, type) pairs, one per successfully parsed attribute line.
    Raises ValueError (via clean_docstring() or find_attrs()) if the
    docstring is a lost cause.  This is separated from parse_docstring()
    so that it can be used by outside code.
    """

    errors = []
    attrs = []

    lines = clean_docstring(doc, klass_name)
    # Now 'lines' is a list of consistently-indented lines.  Find the
    # one that says "Instance attributes:", and start parsing attribute
    # type specifications from there.  Stop when we get back to
    # zero-indent level.
    start_line = find_attrs(lines, klass_name)
    if start_line is None:              # class is declared to have no
        return errors, attrs            # instance attributes

    i = start_line
    while i < len(lines):
        line = lines[i]
        if not line:                    # skip blanks
            i += 1
            continue
        indent = get_indent_level(line)
        if indent == 0:                 # out of the attribute list
            break

        # Attribute lines look like "name : typespec".
        m = _attr_line_re.match(line)
        if m:
            (name, typespec) = m.group(1,2)
            typespec = massage_typespec(typespec)
            try:
                attrs.append((name, schema.parse_type(typespec)))
            except ValueError, exc:
                # bad typespec: record the problem and keep going
                errors.append("class %s, attribute %s: %s" %
                              (klass_name, name, exc))
        else:
            errors.append("class %s: couldn't parse line %d of docstring: %s" %
                          (klass_name, i, `line`))

        # Read lines until back at the indent level of this line
        # (ie. skip the indented attribute description).
        i += 1
        while i < len(lines):
            if get_indent_level(lines[i]) <= indent:
                break
            i += 1
    return errors, attrs


# -- High-level workers ------------------------------------------------
# (called from main())

def generate_schema (project, base_dir):
    """generate_schema(project : ProjectDescription, base_dir : string)
       -> ObjectSchema

    Build the complete object schema for 'project': seed the schema
    with the project's atomic types, forward class declarations and
    type aliases; then find every module in the project's directories
    (interpreted relative to 'base_dir', which may be None), find
    every class in those modules, and parse each class' docstring
    into a class definition.
    """
    schema = ObjectSchema()

    # Preparatory work -- stuff that needs to go in the schema, but
    # can't be discovered by searching for and parsing *.py files.
    for atomic_type in project.atomic_types:
        # an atomic type may be supplied bare or as a (value, name) pair
        if type(atomic_type) is types.TupleType and len(atomic_type) == 2:
            schema.add_atomic_type(*atomic_type)
        else:
            schema.add_atomic_type(atomic_type)

    for name in project.forward_classes:
        cdef = ClassDefinition(name, schema)
        schema.add_class(cdef)

    for (name, value) in project.type_aliases:
        schema.add_alias(name, value)

    # The meat of the schema: class definitions for classes in whatever
    # *.py files we can find in project.dirs.
    if project.dirs:
        announce("searching for modules...")
        modules = find_modules(project.dirs,
                               base_dir=base_dir,
                               prefix=project.prefix,
                               exclude=project.exclude_modules)
        announce("found %d modules\n" % len(modules))
    else:
        modules = []

    project.add_extra_modules(modules, base_dir)

    module_classes = find_all_classes(modules, schema, project.exclude_classes)
    parse_class_docstrings(modules, module_classes, schema)

    if project.post_parse_hook:
        project.post_parse_hook(schema)

    # Finish up all class definitions -- ie. look at the inheritance tree
    # and gather up the list of all attributes that should be in instances
    # of a class, including those inherited from superclasses.
    for klass_name in schema.get_class_names():
        klass_def = schema.get_class_definition(klass_name)
        klass_def.finish_definition()

    return schema


def write_schema (schema, text_filename, pickle_filename):
    """write_schema(schema : ObjectSchema,
                    text_filename : string,
                    pickle_filename : string)

    Write 'schema' in human-readable form to 'text_filename' and/or in
    pickled form to 'pickle_filename'.  Either filename may be None to
    skip that output.
    """
    if text_filename:
        announce("writing object schema to %s..." % text_filename)
        schema_file = open(text_filename, "w")
        try:
            schema.write_aliases(schema_file)
            schema_file.write("\n\n")

            for klass_name in schema.get_class_names():
                klass_def = schema.get_class_definition(klass_name)
                klass_def.write(schema_file)
                schema_file.write("\n")
        finally:
            # close the file even if the schema fails to write itself
            schema_file.close()
        announce("\n")

    if pickle_filename:
        announce("pickling object schema to %s..." % pickle_filename)
        # bug fix: protocol 1 is a binary pickle format, so the file
        # must be opened in binary mode -- text mode corrupts it on
        # platforms that translate newlines
        schema_file = open(pickle_filename, "wb")
        try:
            dump(schema, schema_file, 1)
        finally:
            schema_file.close()
        announce("\n")


# -- Main program ------------------------------------------------------

class ProjectDescription:
    """
    Instance attributes:
      atomic_types : [(any, string) | any]
      forward_classes : [string]
      type_aliases : [(alias:string, alias_expansion:string)]
      prefix : string
      dirs : [string]
      extra_modules : [string | (string, string)]
      exclude_modules : [string]
      exclude_classes : [string]

      post_parse_hook : function
    """
    # NB. the docstring above is not just documentation: check_types()
    # parses it (via parse_docstring()) to typecheck the values read
    # from a project description file.

    def __init__ (self):
        self.atomic_types = []
        self.forward_classes = []
        self.type_aliases = []
        self.prefix = None
        self.dirs = []
        self.extra_modules = []
        self.exclude_modules = []
        self.exclude_classes = []

        self.post_parse_hook = None

    def read (self, filename):
        # Execute the project description file 'filename' and copy any
        # top-level name that matches one of our attributes onto self,
        # then typecheck the result.  Exits the process if the file
        # can't be executed or fails the typecheck; raises ValueError
        # if it doesn't exist.
        data = {}
        if not os.path.isfile(filename):
            raise ValueError, "no such file: %s" % filename
        try:
            execfile(filename, data)
        except:
            # NOTE(review): bare "except" also traps KeyboardInterrupt /
            # SystemExit raised while executing the description file --
            # confirm that's intended before narrowing it
            (t, v, tb) = sys.exc_info()
            error("error in %s:" % filename)
            exc = "".join(traceback.format_exception_only(t, v))
            sys.stderr.write(exc)
            sys.exit(1)

        # Only names matching existing attributes are copied; anything
        # else (including __builtins__, added by execfile) is ignored.
        for (name, value) in data.items():
            if hasattr(self, name):
                setattr(self, name, value)

        self.check_types(filename)

    def check_types (self, filename):
        # Typecheck this project description against the declarations
        # in ProjectDescription's own class docstring; exit with an
        # error report if any attribute is mistyped.
        from grouch.context import TypecheckContext

        # Build a throwaway one-class schema describing
        # ProjectDescription itself, by parsing our class docstring.
        schema = ObjectSchema()
        schema.add_atomic_type(lambda: None) # add 'function' type
        cdef = ClassDefinition(self.__class__.__name__, schema)
        schema.add_class(cdef)
        klass = ClassInfo("ProjectDescription",
                          "ProjectDescription",
                          [], self.__class__.__doc__)
        errors = parse_docstring(klass, schema)
        assert not errors, "errors in my own docstring!"
        cdef.finish_definition()
        context = TypecheckContext(report_errors=0)
        schema.check_value(self, context)
        if context.num_errors() > 0:
            error("type errors in %s:" % filename)
            context.write_errors(sys.stderr)
            sys.exit(1)

    def add_extra_modules (self, modules, base_dir):
        """add_extra_modules(modules : [(string, string)])

        Adds the modules listed in self.extra_modules to modules.  A bit
        tricky because extra_modules might not include filenames, in
        which case we have to go out and find the files.  Raises
        ValueError if a listed module's file cannot be found, TypeError
        if an extra_modules entry has the wrong form.
        """

        # extra_modules is a list of any of the following
        #   string
        #     fully-qualified module name -- we'll have to search
        #     sys.path to find the file
        #   (modname : string, filename : string)
        #     module name with filename; we have to make sure the
        #     file exists, and try it relative to base_dir if not

        # yow! this is hairy and undertested!

        for module in self.extra_modules:
            if type(module) is types.StringType:
                # Bare module name: translate "a.b.c" to "a/b/c.py" and
                # look for that file along sys.path.
                modname = module
                comps = module.split(".")
                tail = os.path.join(*comps) + ".py"
                for dir in sys.path:
                    filename = os.path.join(dir, tail)
                    if os.path.exists(filename):
                        break
                else:
                    raise ValueError(
                        "no such module %s: %s not found in sys.path"
                        % (module, tail))

            elif type(module) is types.TupleType and len(module) == 2:
                (modname, filename) = module
                if os.path.exists(filename):
                    pass                # found it, good
                elif base_dir:          # we have a second chance
                    filename2 = os.path.join(base_dir, filename)
                    if os.path.exists(filename2):
                        filename = filename2
                    else:
                        raise ValueError(
                            "no such module %s: neither %s nor %s exist"
                            % (modname, filename, filename2))
                else:
                    raise ValueError(
                        "no such module %s: %s does not exist"
                        % (modname, filename))
            else:
                raise TypeError("bad extra_modules")

            assert os.path.exists(filename)
            modules.append((modname, filename))

        # for module

    # add_extra_modules ()


def main ():
    """Command-line entry point: parse the options, build the schema
    for the described project, and write it out.
    """
    global VERBOSITY                    # global because announce() reads it

    prog = os.path.basename(sys.argv[0])
    args = sys.argv[1:]

    usage = """\
usage: %s [options]
options:
  -v                        verbose: run noisily (repeat for more noise)
  -q                        run quietly
  -p FILE, --project=FILE   read project description info from FILE
  -o FILE, --output=FILE    output file: write pickled schema to FILE
                            [default: schema.pkl]
  -t FILE, --text=FILE      text output: write human-readable schema to FILE
                            [default: none]
  -d DIR, --base-dir=DIR    interpret DIRS in project description file
                            relative to DIR
""" % prog

    try:
        (opts, args) = getopt.getopt(args, "p:o:t:d:vq",
                                     ["project=",
                                      "output=",
                                      "text=",
                                      "base-dir=",])
    except getopt.error, msg:
        sys.exit(usage + str(msg))

    project = ProjectDescription()
    VERBOSITY = 1
    text_filename = None
    pickle_filename = "schema.pkl"
    base_dir = None
    for (opt, val) in opts:
        # bug fix: the long forms were matched as "project" and "text"
        # (no dashes), so --project and --text were silently ignored
        if opt in ("-p", "--project"):
            project.read(val)
        elif opt in ("-o", "--output"):
            pickle_filename = val
        elif opt in ("-t", "--text"):
            text_filename = val
        elif opt in ("-d", "--base-dir"):
            base_dir = val
        elif opt == "-v":
            VERBOSITY += 1
        elif opt == "-q":
            VERBOSITY = 0

    if len(args) != 0:
        raise SystemExit(usage + "error: too many arguments")

    # Find modules in project.dirs, classes in those modules, and parse
    # the docstring of each class.  The result is an ObjectSchema, a
    # collection of atomic types, type aliases, and class definitions.
    schema = generate_schema(project, base_dir)

    # Write the schema out to either or both of a text and pickle file.
    write_schema(schema, text_filename, pickle_filename)

# main ()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # exit with a short message (and nonzero status) instead of a
        # traceback on Ctrl-C
        sys.exit("interrupted")