#ifndef AST_HH
#define AST_HH

#include <iostream>
#include <string>
#include <list>

namespace ast {

    // FIXME: refactor this -- use classes instead of enums
    //
    // Arithmetic operator carried by a BinOpExpr node.  The values are
    // spelled out explicitly and match the implicit 0..3 numbering.
    enum binop_t {
	BINOP_PLUS  = 0,
	BINOP_MINUS = 1,
	BINOP_MULT  = 2,
	BINOP_DIV   = 3
    };

    // Scalar element type used by Update and by the declaration nodes
    // (ValDeclaration, FunDeclaration, MatrixDeclaration).  Values are
    // explicit and match the implicit 0..2 numbering.
    enum basic_type_t {
	TYPE_CHAR  = 0,
	TYPE_INT   = 1,
	TYPE_FLOAT = 2
    };

    // Forward declarations.  Fun and Expr are defined further down in this
    // header but are referenced through pointers before that point.
    // Visitor is not defined in this header (presumably in a separate
    // visitor header); only references to it are used here.
    class Visitor;
    class Expr;
    class Fun;

    // Root of the AST class hierarchy.
    //
    // Every node records the source line it came from and an id taken from
    // a class-wide counter that is pre-incremented once per constructed
    // node.  Nodes are neither default-constructible, copyable nor
    // assignable: those members are private and deliberately left
    // undefined.
    class Ast {
	int _line_no;
	int _ast_id;
	static int current_id;	// last id handed out; must be defined in one .cc file

	Ast();
	Ast(const Ast&);
	// Copy assignment was previously left implicit even though copy
	// construction was disabled; assigning one node to another would
	// have duplicated _ast_id.
	Ast &operator=(const Ast&);

    protected:
	explicit Ast(int line_no) : _line_no(line_no), _ast_id(++current_id) {}

    public:
	// Nodes are used polymorphically (pure virtual visitor hooks below);
	// without a virtual destructor, deleting a node through an Ast* or
	// intermediate base pointer is undefined behavior.
	virtual ~Ast() {}

	// Source line recorded at construction.
	int line_no() const { return _line_no; }
	// Id assigned at construction from current_id.
	int ast_id()  const { return _ast_id; }

	// Visitor entry points every concrete node must implement (defined
	// out of line).  Presumably accept() visits children before the node
	// and accept_top_down() visits the node first -- confirm against the
	// definitions.
	virtual void accept(Visitor &visitor) = 0;
	virtual void accept_top_down(Visitor &visitor) = 0;
    };

    // Abstract base of the binary relations BinRelLT, BinRelLTE, BinRelGT,
    // BinRelGTE, BinRelEQ and BinRelNEQ.  All three predicates default to
    // false; each concrete relation overrides the ones that hold for it:
    //   is_less()    -- overridden by BinRelLT and BinRelLTE
    //   is_greater() -- overridden by BinRelGT and BinRelGTE
    //   include()    -- overridden by the non-strict BinRelLTE and BinRelGTE
    class BinRel : public Ast {
    protected:
	BinRel(int line_no)
	    : Ast(line_no)
	{}

    public:
	virtual bool is_less() const
	{ return false; }
	virtual bool is_greater() const
	{ return false; }

	virtual bool include() const
	{ return false; }
    };

    // Strict less-than relation: is_less() is true, include() stays false.
    class BinRelLT : public BinRel {
    public:
	BinRelLT(int line_no)
	    : BinRel(line_no)
	{}

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);

	virtual bool is_less() const
	{ return true; }
    };

    // Non-strict less-than-or-equal relation: both is_less() and include()
    // report true.
    class BinRelLTE : public BinRel {
    public:
	BinRelLTE(int line_no)
	    : BinRel(line_no)
	{}

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);

	virtual bool is_less() const
	{ return true; }
	virtual bool include() const
	{ return true; }
    };

    // Strict greater-than relation: is_greater() is true, include() stays
    // false.
    class BinRelGT : public BinRel {
    public:
	BinRelGT(int line_no)
	    : BinRel(line_no)
	{}

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);

	virtual bool is_greater() const
	{ return true; }
    };

    // Non-strict greater-than-or-equal relation: both is_greater() and
    // include() report true.
    class BinRelGTE : public BinRel {
    public:
	BinRelGTE(int line_no)
	    : BinRel(line_no)
	{}

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);

	virtual bool is_greater() const
	{ return true; }
	virtual bool include() const
	{ return true; }
    };

    // Equality relation.  It overrides none of the BinRel predicates, so
    // is_less(), is_greater() and include() all report false; only the
    // visitor hooks distinguish it from BinRelNEQ.
    class BinRelEQ : public BinRel {
    public:
	BinRelEQ(int line_no)
	    : BinRel(line_no)
	{}

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Inequality relation.  Like BinRelEQ it overrides none of the BinRel
    // predicates (all report false); only the visitor hooks tell the two
    // apart.
    class BinRelNEQ : public BinRel {
    public:
	BinRelNEQ(int line_no)
	    : BinRel(line_no)
	{}

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };



    // Range expression built from two bound expressions, two relations and
    // an index name.  From the member names it presumably models
    // `begin <begin_br> index <end_br> end` (e.g. `0 <= i < n`) -- confirm
    // against the parser.
    class RExpr : public Ast {
	Expr   *_begin;
	BinRel *_begin_br;
	char   *_index;
	BinRel *_end_br;
	Expr   *_end;

    public:
	RExpr(int line_no,
	      Expr *begin, BinRel *begin_br,
	      char *index, BinRel *end_br, Expr *end);

	// Accessors, listed in the same order as the members they return.
	Expr *begin() const
	{ return _begin; }
	BinRel *begin_br() const
	{ return _begin_br; }
	const char *index() const
	{ return _index; }
	BinRel *end_br() const
	{ return _end_br; }
	Expr *end() const
	{ return _end; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Abstract base for boolean expressions (RelBExpr, ANDBExpr, ORBExpr,
    // NOTBExpr, TrueExpr).  Only subclasses can construct it.
    class BExpr : public Ast {
    protected:
	BExpr(int line_no)
	    : Ast(line_no)
	{}
    };

    // Boolean expression comparing two expressions with a relation:
    // `left <rel> right`.
    class RelBExpr : public BExpr {
	Expr   *_left;
	BinRel *_rel;
	Expr   *_right;

    public:
	RelBExpr(int line_no, Expr *left, BinRel *rel, Expr *right);

	Expr *left() const
	{ return _left; }
	BinRel *rel() const
	{ return _rel; }
	Expr *right() const
	{ return _right; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Conjunction of two boolean sub-expressions.
    class ANDBExpr : public BExpr {
	BExpr *_left;
	BExpr *_right;

    public:
	ANDBExpr(int line_no, BExpr *left, BExpr *right);

	BExpr *left() const
	{ return _left; }
	BExpr *right() const
	{ return _right; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Disjunction of two boolean sub-expressions.
    class ORBExpr : public BExpr {
	BExpr *_left;
	BExpr *_right;

    public:
	ORBExpr(int line_no, BExpr *left, BExpr *right);

	BExpr *left() const
	{ return _left; }
	BExpr *right() const
	{ return _right; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Negation of a single boolean sub-expression.
    class NOTBExpr : public BExpr {
	BExpr *_bexpr;

    public:
	NOTBExpr(int line_no, BExpr *bexpr);

	BExpr *bexpr() const
	{ return _bexpr; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Boolean expression that is always true.  The line number defaults to
    // 0 so it can be created without a source position; WhenExpr's default
    // condition is built that way.
    class TrueExpr : public BExpr {
    public:
	TrueExpr(int line_no = 0)
	    : BExpr(line_no)
	{}

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Pairs a function with a boolean guard (presumably "fun when bexpr").
    //
    // NOTE: the default argument allocates a fresh TrueExpr, with line
    // number 0 and its own ast_id, on every call that omits `bexpr`.
    // Who owns that node is decided by the out-of-line constructor.
    class WhenExpr : public Ast {
	Fun   *_fun;
	BExpr *_bexpr;

    public:
	WhenExpr(int line_no, Fun *fun, BExpr *bexpr = new TrueExpr);

	Fun *fun() const
	{ return _fun; }
	BExpr *bexpr() const
	{ return _bexpr; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Pairs a function with a list of range expressions (presumably
    // "fun where rexpr, ...").  The list is held by pointer, not by value.
    class WhereExpr : public Ast {
	Fun               *_fun;
	std::list<RExpr*> *_rexpr_list;

    public:
	WhereExpr(int line_no, Fun *fun, std::list<RExpr*> *rexpr_list);

	Fun *fun() const
	{ return _fun; }
	std::list<RExpr*> *rexpr_list() const
	{ return _rexpr_list; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Abstract base for function definitions (Range, Values, SimpleFun).
    // It stores only the function's identifier; the visitor hooks are left
    // to the subclasses.
    class Fun : public Ast {
	char *_id;

    protected:
	Fun(int line_no, char *id) : Ast(line_no), _id(id) {}

    public:
	const char *id() const
	{ return _id; }
    };

    // Named function described by a WhereExpr (presumably a function
    // defined over the ranges listed in that expression).
    class Range : public Fun {
	WhereExpr *_where_expr;

    public:
	Range(int line_no, char *id, WhereExpr *where_expr);

	WhereExpr *where_expr() const
	{ return _where_expr; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Named function given as a list of guarded cases (WhenExpr nodes).
    // The list is held by pointer.
    class Values : public Fun {
	std::list<WhenExpr*> *_when_expr_list;

    public:
	Values(int line_no, char *id, std::list<WhenExpr*> *when_expr_list);

	std::list<WhenExpr*> *when_expr_list() const { return _when_expr_list; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Repeats the forward declaration of Expr; harmless, and keeps this
    // block self-contained.
    class Expr;

    // Function defined by a single expression.  Unlike Range and Values its
    // constructor takes no identifier; the id handed to Fun is chosen by
    // the out-of-line constructor.
    class SimpleFun : public Fun {
	Expr *_expr;

    public:
	SimpleFun(int line_no, Expr *expr);

	Expr *expr() const
	{ return _expr; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };
	
    // Abstract base for value expressions: identifiers, integer literals,
    // matrix accesses, function calls, binary operations and negation.
    // The protected constructor is defined out of line.
    class Expr : public Ast {
    protected:
	Expr(int line_no);
    };

    // Expression consisting of a single identifier.
    class IDExpr : public Expr {
	char *_id;

    public:
	IDExpr(int line_no, char *id);

	const char *id() const
	{ return _id; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Integer literal; the value is stored as a plain int.
    class IntegerExpr : public Expr {
	int _integer;

    public:
	IntegerExpr(int line_no, int integer);

	int integer() const
	{ return _integer; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Access to the matrix named `id`, with the index expressions held in
    // expr_list (presumably one per dimension).  The list is exposed
    // read-only.
    class MatrixExpr : public Expr {
	char             *_id;
	std::list<Expr*> *_expr_list;

    public:
	MatrixExpr(int line_no, char *id, std::list<Expr*> *expr_list);

	const char *id() const
	{ return _id; }
	const std::list<Expr*> *expr_list() const
	{ return _expr_list; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Call of the function named `id` with the argument expressions in
    // expr_list.  The list is exposed read-only.
    class FunCallExpr : public Expr {
	char             *_id;
	std::list<Expr*> *_expr_list;

    public:
	FunCallExpr(int line_no, char *id, std::list<Expr*> *expr_list);

	const char *id() const
	{ return _id; }
	const std::list<Expr*> *expr_list() const
	{ return _expr_list; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Binary arithmetic expression `left <binop> right`, where binop is one
    // of the binop_t operators.
    class BinOpExpr : public Expr {
	Expr    *_left;
	binop_t  _binop;
	Expr    *_right;

    public:
	BinOpExpr(int line_no, Expr *left, binop_t binop, Expr *right);

	Expr *left() const
	{ return _left; }
	binop_t binop() const
	{ return _binop; }
	Expr *right() const
	{ return _right; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Unary negation of a single sub-expression.
    class NEGExpr : public Expr {
	Expr *_expr;

    public:
	NEGExpr(int line_no, Expr *expr);

	Expr *expr() const
	{ return _expr; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };


    // Update statement.  From the member names it presumably assigns the
    // values described by `fun` to `name[indices...]`, whose element type
    // is `type` -- confirm against the grammar.  Both the index list and
    // the name are exposed read-only.
    class Update : public Ast {
	basic_type_t             _type;
	char                    *_name;
	std::list<const char *> *_indices;
	Fun                     *_fun;

    public:
	Update(int line_no,
	       basic_type_t type,
	       char *name,
	       std::list<const char *> *indices,
	       Fun *fun);

	basic_type_t type() const
	{ return _type; }
	const char *name() const
	{ return _name; }
	const std::list<const char *> *indices() const
	{ return _indices; }
	Fun *fun() const
	{ return _fun; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Abstract base for declarations (values, functions, matrices).  It
    // stores only the declared name; the constructor is defined out of
    // line.
    class Declaration : public Ast {
	const char *_name;

    protected:
	Declaration(int line_no, const char *name);

    public:
	const char *name() const
	{ return _name; }
    };

    // Declaration of a single scalar value of a basic type.
    class ValDeclaration : public Declaration {
	basic_type_t _type;

    public:
	ValDeclaration(int line_no, const char *name, basic_type_t type);

	basic_type_t type() const
	{ return _type; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // FIXME: handle more than arity!
    //
    // Declaration of a function signature: its name, the list of parameter
    // types (held by pointer) and the return type.
    class FunDeclaration : public Declaration {
	std::list<basic_type_t> *_parameter_types;
	basic_type_t _return_type;
    public:
	FunDeclaration(int line_no, const char *name,
		       std::list<basic_type_t> *parameter_types,
		       basic_type_t return_type);

	// Number of declared parameters.  A null parameter list is reported
	// as 0 instead of being dereferenced, and the size_t -> int
	// narrowing is made explicit.
	int                      arity()            const
	{
	    return _parameter_types
		? static_cast<int>(_parameter_types->size())
		: 0;
	}
	std::list<basic_type_t> *parameter_types()  const
	{ return _parameter_types; }
	basic_type_t             return_type()      const
	{ return _return_type; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Declaration of a matrix: its name, element type and a list of
    // dimension expressions (held by pointer).
    class MatrixDeclaration : public Declaration {
	std::list<Expr*> *_dim_list;
	basic_type_t _type;
    public:
	MatrixDeclaration(int line_no, const char *name,
			  std::list<Expr*> *dim_list,
			  basic_type_t type);

	// Number of dimensions.  A null dimension list is reported as 0
	// instead of being dereferenced, and the size_t -> int narrowing is
	// made explicit.
	int               arity()    const
	{ return _dim_list ? static_cast<int>(_dim_list->size()) : 0; }
	std::list<Expr*> *dim_list() const { return _dim_list; }
	basic_type_t      type()     const { return _type; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // Root node of a program.  It holds three declaration lists (globals,
    // parameters, locals), a list of updates and a list of range
    // expressions.  All lists are held by pointer and exposed read-only.
    class DProg : public Ast {
	std::list<Declaration*> *_globals;
	std::list<Declaration*> *_parameters;
	std::list<Declaration*> *_locals;

	std::list<Update*> *_update_list;
	std::list<RExpr*>  *_rexpr_list;

    public:
	DProg(int line_no,
	      std::list<Declaration*> *globals,
	      std::list<Declaration*> *parameters,
	      std::list<Declaration*> *locals,
	      std::list<Update*> *update_list,
	      std::list<RExpr*>  *rexpr_list);

	// Declaration lists.
	const std::list<Declaration*> *globals() const
	{ return _globals; }
	const std::list<Declaration*> *parameters() const
	{ return _parameters; }
	const std::list<Declaration*> *locals() const
	{ return _locals; }

	// Program body.
	const std::list<Update*> *update_list() const
	{ return _update_list; }
	const std::list<RExpr*> *rexpr_list() const
	{ return _rexpr_list; }

	// Visitor hooks, defined out of line.
	virtual void accept(Visitor &visitor);
	virtual void accept_top_down(Visitor &visitor);
    };

    // generic exception refering to an ast node.
    class Exception {
	const Ast &_ast;

	Exception();
	Exception(Exception&);

    protected:
	Exception(const Ast &ast) : _ast(ast) {}
	virtual void print_error_msg(std::ostream &os) = 0;

    public:
	const Ast &ast() const { return _ast; }

	void print_error(std::ostream &os);
    };

};

std::ostream &operator<<(std::ostream &os, const ast::basic_type_t  bt);
std::ostream &operator<<(std::ostream &os, const ast::binop_t       bo);
std::ostream &operator<<(std::ostream &os, const ast::Ast          &ast);


#endif // AST_HH


// syntax highlighted by Code2HTML, v. 0.9.1