// This file is part of fityk program. Copyright (C) 2005 Marcin Wojdyr
// Licence: GNU General Public License version 2
// $Id: ast.cpp 294 2007-05-16 03:18:25Z wojdyr $
// Based on ast_calc example from Boost::Spirit by Daniel Nuffer
//TODO:
// CSE in tree (or in VM code?)
// new op: SQR? DUP? STORE/WRITE?
// AST -> VM code
// output VM code for tests
// constant-merge (merge identical constants)
#include <boost/spirit/core.hpp>
#include <boost/spirit/tree/ast.hpp>
#include <sstream>
#include <string>
#include <vector>
#include <cassert>
#include <cstdlib>
#include <cmath>
#include "common.h"
#include "ui.h"
#include "ast.h"
#include "var.h"
#include "datatrans.h"
#include "logic.h"
#include "numfuncs.h"
#include "voigt.h"
////////////////////////////////////////////////////////////////////////////
using namespace std;
using namespace boost::spirit;
//typedef char const* iterator_t;
//typedef tree_match<iterator_t> parse_tree_match_t;
//typedef parse_tree_match_t::tree_iterator iter_t;
//typedef parse_tree_match_t::const_tree_iterator const_iter_t;
//typedef tree_match<char const*>::const_tree_iterator const_iter_t;
//typedef tree_match<char const*>::const_tree_iterator const_tm_iter_t;
////////////////////////////////////////////////////////////////////////////
/// Creates a node for a one-argument operator.
/// @param n     operator code; asserted to lie in [OP_ONE_ARG, OP_TWO_ARG)
/// @param arg1  the operand, stored as the first (and only) child
OpTree::OpTree(int n, OpTree *arg1)
    : op(n), c1(arg1), c2(0), val(0.)
{
    assert(OP_ONE_ARG <= n && n < OP_TWO_ARG);
}
/// Creates a node for a two-argument operator.
/// @param n     operator code; asserted to be at least OP_TWO_ARG
/// @param arg1  left operand, stored as the first child
/// @param arg2  right operand, stored as the second child
OpTree::OpTree(int n, OpTree *arg1, OpTree *arg2)
    : op(n), c1(arg1), c2(arg2), val(0.)
{
    assert(OP_TWO_ARG <= n);
}
/// Renders this tree as an infix expression.
///
/// Negative `op` values denote variables: variable number `-op-1` is printed
/// using the matching entry of `vars`, or as "var<N>" when no name is
/// available. `op == 0` is a numeric constant. Sub-expressions are passed
/// through str_b() with a flag computed from the child's operator code.
///
/// @param vars  optional list of variable names; may be null or shorter than
///              the variable index (a generic "var<N>" name is used then)
/// @return the textual form of the expression
string OpTree::str(const vector<string> *vars)
{
    if (op < 0) {
        int v_nr = -op-1;
        // The DEBUG_SIMPLIFY traces in this file call str() without passing
        // a name list, so a missing or too short list must not be
        // dereferenced.
        if (vars == 0 || v_nr >= int(vars->size()))
            return "var"+S(v_nr);
        return (*vars)[v_nr];
    }
    switch (op) {
        case 0:       return S(val);
        case OP_NEG:  return "-" + c1->str_b(c1->op >= OP_POW, vars);
        case OP_EXP:  return "exp(" + c1->str(vars) + ")";
        case OP_ERFC: return "erfc(" + c1->str(vars) + ")";
        case OP_ERF:  return "erf(" + c1->str(vars) + ")";
        case OP_SIN:  return "sin(" + c1->str(vars) + ")";
        case OP_COS:  return "cos(" + c1->str(vars) + ")";
        case OP_ATAN: return "atan("+ c1->str(vars) + ")";
        case OP_TAN:  return "tan(" + c1->str(vars) + ")";
        case OP_ASIN: return "asin("+ c1->str(vars) + ")";
        case OP_ACOS: return "acos("+ c1->str(vars) + ")";
        case OP_LGAMMA: return "lgamma("+ c1->str(vars) + ")";
        case OP_DIGAMMA: return "digamma("+ c1->str(vars) + ")";
        case OP_LOG10:return "log10("+c1->str(vars) + ")";
        case OP_LN:   return "ln(" + c1->str(vars) + ")";
        case OP_SQRT: return "sqrt("+ c1->str(vars) + ")";
        case OP_VOIGT: return "voigt("+ c1->str(vars) +","+ c2->str(vars) +")";
        case OP_DVOIGT_DX: return "dvoigt_dx("+ c1->str(vars)
                                  + "," + c2->str(vars) + ")";
        case OP_DVOIGT_DY: return "dvoigt_dy("+ c1->str(vars)
                                  + "," + c2->str(vars) + ")";
        case OP_POW:  return c1->str_b(c1->op >= OP_POW, vars)
                             + "^" + c2->str_b(c2->op >= OP_POW, vars);
        case OP_ADD:  return c1->str(vars) + "+" + c2->str(vars);
        case OP_SUB:  return c1->str(vars) + "-"
                             + c2->str_b(c2->op >= OP_ADD, vars);
        case OP_MUL:  return c1->str_b(c1->op >= OP_ADD, vars)
                             + "*" + c2->str_b(c2->op >= OP_ADD, vars);
        case OP_DIV:  return c1->str_b(c1->op >= OP_ADD, vars)
                             + "/" + c2->str_b(c2->op >= OP_MUL, vars);
        default: assert(0); return "";
    }
}
/// Draws the tree as indented text, one node per line (debugging aid).
///
/// The node label is centred in a column of `width` characters that begins
/// at column `start`; the first child is drawn in the left half of that
/// column and the second child in the right half.
///
/// @param width  width of the column reserved for this subtree
/// @param start  first text column of that reserved area
/// @param vars   optional variable names; may be null or shorter than the
///               variable index (a generic "var<N>" name is used then)
/// @return multi-line string, every node terminated by '\n'
string OpTree::ascii_tree(int width, int start, const vector<string> *vars)
{
    string node = "???";
    if (op < 0) {
        int v_nr = -op-1;
        // do not dereference a missing or too short name list
        if (vars == 0 || v_nr >= int(vars->size()))
            node = "var"+S(v_nr);
        else
            node = (*vars)[v_nr];
    }
    else
        switch (op) {
            case 0:       node = S(val); break;
            case OP_NEG:  node = "NEG";  break;
            case OP_EXP:  node = "EXP";  break;
            case OP_ERFC: node = "ERFC"; break;
            case OP_ERF:  node = "ERF";  break;
            case OP_SIN:  node = "SIN";  break;
            case OP_COS:  node = "COS";  break;
            case OP_ATAN: node = "ATAN"; break;
            case OP_TAN:  node = "TAN";  break;
            case OP_ASIN: node = "ASIN"; break;
            case OP_ACOS: node = "ACOS"; break;
            // the five operators below are printed by str() as well
            case OP_LGAMMA:    node = "LGAMMA";    break;
            case OP_DIGAMMA:   node = "DIGAMMA";   break;
            case OP_VOIGT:     node = "VOIGT";     break;
            case OP_DVOIGT_DX: node = "DVOIGT_DX"; break;
            case OP_DVOIGT_DY: node = "DVOIGT_DY"; break;
            case OP_LOG10:node = "LOG";  break;
            case OP_LN:   node = "LN";   break;
            case OP_SQRT: node = "SQRT"; break;
            case OP_POW:  node = "POW";  break;
            case OP_ADD:  node = "ADD";  break;
            case OP_SUB:  node = "SUB";  break;
            case OP_MUL:  node = "MUL";  break;
            case OP_DIV:  node = "DIV";  break;
        }
    // centre the label when it fits, otherwise start at the left edge
    int n = (int(node.size()) < width ? start + (width-node.size())/2
                                      : start);
    node = string(n, ' ') + node + "\n";
    if (c1)
        node += c1->ascii_tree(width/2, start, vars);
    if (c2)
        node += c2->ascii_tree(width/2, start+width/2, vars);
    return node;
}
/// Returns a newly allocated deep copy of this node and its children.
OpTree* OpTree::copy() const
{
    OpTree *dup = new OpTree(*this);
    // the copy constructor call above is followed by duplicating the
    // children, so that the new tree does not point into this one
    if (c1 != 0)
        dup->c1 = c1->copy();
    if (c2 != 0)
        dup->c2 = c2->copy();
    return dup;
}
////////////////////////////////////////////////////////////////////////////
namespace {
void do_find_tokens(int tokenID, const_tm_iter_t const &i, vector<string> &vars)
{
for (const_tm_iter_t j = i->children.begin(); j != i->children.end(); ++j) {
if (j->value.id() == tokenID) {
string v(j->value.begin(), j->value.end());
if (find(vars.begin(), vars.end(), v) == vars.end())
vars.push_back(v);
}
else
do_find_tokens(tokenID, j, vars);
}
}
} // anonymous namespace
/// Collects the distinct texts of all nodes with id `tokenID` found in the
/// first tree of a parse result, in order of first appearance.
///
/// @param tokenID  id of the grammar rule to look for
/// @param info     result of a tree parse
/// @return the texts found; empty when `info` holds no tree
vector<string> find_tokens_in_ptree(int tokenID, const tree_parse_info<> &info)
{
    vector<string> vars;
    // nothing to scan; begin() of an empty container must not be dereferenced
    if (info.trees.empty())
        return vars;
    const_tm_iter_t const &root = info.trees.begin();
    if (root->value.id() == tokenID)
        vars.push_back(string(root->value.begin(), root->value.end()));
    else
        do_find_tokens(tokenID, root, vars);
    return vars;
}
////////////////////////////////////////////////////////////////////////////
OpTree* simplify_terms(OpTree *a);
OpTree* do_multiply(OpTree *a, OpTree *b);
/// Builds the tree for -a. The argument is consumed: it is either deleted
/// or becomes part of the returned tree.
///  - a numeric constant is negated in place of building a node,
///  - a double negation collapses to (a copy of) the inner operand,
///  - anything else is wrapped in an OP_NEG node.
OpTree* do_neg(OpTree *a)
{
    switch (a->op) {
        case 0: {
            const double negated = - a->val;
            delete a;
            return new OpTree(negated);
        }
        case OP_NEG: {
            OpTree *inner = a->c1->copy();
            delete a;
            return inner;
        }
        default:
            return new OpTree(OP_NEG, a);
    }
}
/// Builds the tree for `a + b` (op == OP_ADD) or `a - b` (any other op,
/// in practice OP_SUB), applying simple local simplifications.
/// Both operands are consumed: each is either deleted or becomes part of
/// the returned tree.
OpTree* do_add(int op, OpTree *a, OpTree *b)
{
    const bool adding = (op == OP_ADD);

    // p + q  -- two numbers fold into one
    if (a->op == 0 && b->op == 0) {
        const double folded = adding ? a->val + b->val : a->val - b->val;
        delete a;
        delete b;
        return new OpTree(folded);
    }

    // 0 + t
    if (a->op == 0 && is_eq(a->val, 0.)) {
        delete a;
        if (adding)
            return b;
        return do_neg(b);
    }

    // t + 0
    if (b->op == 0 && is_eq(b->val, 0.)) {
        delete b;
        return a;
    }

    // t + -u  -- drop the negation and flip the operator
    if (b->op == OP_NEG) {
        OpTree *inner = b->remove_c1();
        delete b;
        return do_add(adding ? OP_SUB : OP_ADD, a, inner);
    }

    // t + -p*v  -- make the leading constant positive and flip the operator
    const bool b_is_product = (b->op == OP_MUL || b->op == OP_DIV);
    if (b_is_product && b->c1->op == 0 && b->c1->val < 0) {
        b->c1->val = - b->c1->val;
        return do_add(adding ? OP_SUB : OP_ADD, a, b);
    }

    // identical operands
    if (*a == *b) {
        delete b;
        if (adding) // t + t
            return do_multiply(new OpTree(2.), a);
        // t - t
        delete a;
        return new OpTree(0.);
    }

    return new OpTree(op, a, b);
}
/// Convenience overload: simplified tree for `a + b`.
OpTree* do_add(OpTree *a, OpTree *b)
{
    OpTree *sum = do_add(OP_ADD, a, b);
    return sum;
}
/// Simplified tree for `a - b`.
OpTree* do_sub(OpTree *a, OpTree *b)
{
    OpTree *difference = do_add(OP_SUB, a, b);
    return difference;
}
/// Builds the tree for `a * b`, applying simple local simplifications.
/// Both operands are consumed: each is either deleted or becomes part of
/// the returned tree.
OpTree* do_multiply(OpTree *a, OpTree *b)
{
    const bool a_is_num = (a->op == 0);
    const bool b_is_num = (b->op == 0);

    // const * const
    if (a_is_num && b_is_num) {
        const double product = a->val * b->val;
        delete a;
        delete b;
        return new OpTree(product);
    }

    // 0 * ...   or   ... * 0
    if ((a_is_num && is_eq(a->val, 0.)) || (b_is_num && is_eq(b->val, 0.))) {
        delete a;
        delete b;
        return new OpTree(0.);
    }

    // 1 * ...
    if (a_is_num && is_eq(a->val, 1.)) {
        delete a;
        return b;
    }

    // ... * 1
    if (b_is_num && is_eq(b->val, 1.)) {
        delete b;
        return a;
    }

    // -1 * ...
    if (a_is_num && is_eq(a->val, -1.)) {
        delete a;
        return do_neg(b);
    }

    // ... * -1
    if (b_is_num && is_eq(b->val, -1.)) {
        delete b;
        return do_neg(a);
    }

    // const1 * (const2 / ...)  -- merge the constant into the numerator
    if (a_is_num && b->op == OP_DIV && b->c1->op == 0) {
        b->c1->val *= a->val;
        delete a;
        return b;
    }

    return new OpTree(OP_MUL, a, b);
}
/// Builds the tree for `a / b`, applying simple local simplifications.
/// Both operands are consumed. Division by zero is not checked.
OpTree* do_divide(OpTree *a, OpTree *b)
{
    // const / const
    if (a->op == 0 && b->op == 0) {
        const double quotient = a->val / b->val;
        delete a;
        delete b;
        return new OpTree(quotient);
    }

    // 0 / ...
    if (a->op == 0 && is_eq(a->val, 0.)) {
        delete a;
        delete b;
        return new OpTree(0.);
    }

    // ... / 1
    if (b->op == 0 && is_eq(b->val, 1.)) {
        delete b;
        return a;
    }

    return new OpTree(OP_DIV, a, b);
}
/// Builds the tree for a*a; `a` is consumed and a copy of it is used as the
/// second operand.
OpTree *do_sqr(OpTree *a)
{
    OpTree *twin = a->copy();
    return do_multiply(a, twin);
}
/// Builds the tree for 1/a; `a` is consumed.
OpTree *do_oneover(OpTree *a)
{
    OpTree *one = new OpTree(1.);
    return do_divide(one, a);
}
/// exp(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_EXP node. `a` is consumed.
OpTree* do_exp(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_EXP, simplify_terms(a));
    const double folded = exp(a->val);
    delete a;
    return new OpTree(folded);
}
/// erf(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_ERF node. `a` is consumed.
OpTree* do_erf(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_ERF, simplify_terms(a));
    const double folded = erf(a->val);
    delete a;
    return new OpTree(folded);
}
/// erfc(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_ERFC node. `a` is consumed.
OpTree* do_erfc(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_ERFC, simplify_terms(a));
    const double folded = erfc(a->val);
    delete a;
    return new OpTree(folded);
}
/// sqrt(a): a numeric constant is evaluated immediately, any other argument
/// is wrapped as-is (without simplification) in an OP_SQRT node.
/// `a` is consumed.
OpTree* do_sqrt(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_SQRT, a);
    const double folded = sqrt(a->val);
    delete a;
    return new OpTree(folded);
}
/// log10(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_LOG10 node. `a` is consumed.
OpTree* do_log10(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_LOG10, simplify_terms(a));
    const double folded = log10(a->val);
    delete a;
    return new OpTree(folded);
}
/// ln(a) (natural logarithm): a numeric constant is evaluated immediately,
/// any other argument is simplified and wrapped in an OP_LN node.
/// `a` is consumed.
OpTree* do_ln(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_LN, simplify_terms(a));
    const double folded = log(a->val);
    delete a;
    return new OpTree(folded);
}
/// sin(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_SIN node. `a` is consumed.
OpTree* do_sin(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_SIN, simplify_terms(a));
    const double folded = sin(a->val);
    delete a;
    return new OpTree(folded);
}
/// cos(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_COS node. `a` is consumed.
OpTree* do_cos(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_COS, simplify_terms(a));
    const double folded = cos(a->val);
    delete a;
    return new OpTree(folded);
}
/// tan(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_TAN node. `a` is consumed.
OpTree* do_tan(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_TAN, simplify_terms(a));
    const double folded = tan(a->val);
    delete a;
    return new OpTree(folded);
}
/// atan(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_ATAN node. `a` is consumed.
OpTree* do_atan(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_ATAN, simplify_terms(a));
    const double folded = atan(a->val);
    delete a;
    return new OpTree(folded);
}
/// asin(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_ASIN node. `a` is consumed.
OpTree* do_asin(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_ASIN, simplify_terms(a));
    const double folded = asin(a->val);
    delete a;
    return new OpTree(folded);
}
/// acos(a): a numeric constant is evaluated immediately, any other argument
/// is simplified and wrapped in an OP_ACOS node. `a` is consumed.
OpTree* do_acos(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_ACOS, simplify_terms(a));
    const double folded = acos(a->val);
    delete a;
    return new OpTree(folded);
}
/// lgamma(a): a numeric constant is evaluated immediately with lgammafn(),
/// any other argument is simplified and wrapped in an OP_LGAMMA node.
/// `a` is consumed.
OpTree* do_lgamma(OpTree *a)
{
    if (a->op != 0)
        return new OpTree(OP_LGAMMA, simplify_terms(a));
    const double folded = lgammafn(a->val);
    delete a;
    return new OpTree(folded);
}
/// digamma(a): wraps the simplified argument in an OP_DIGAMMA node.
/// `a` is consumed.
///
/// A constant argument is deliberately NOT folded here. The previous folding
/// code evaluated lgammafn() -- the log-gamma function, not digamma -- and so
/// produced a wrong number. Building the node keeps the result correct and
/// leaves the evaluation to whatever later evaluates OP_DIGAMMA.
/// NOTE(review): if a numeric digamma routine is available in numfuncs.h,
/// constant folding can be restored by calling it here -- confirm its name.
OpTree* do_digamma(OpTree *a)
{
    // simplify_terms() returns a constant node unchanged
    return new OpTree(OP_DIGAMMA, simplify_terms(a));
}
/// Builds the tree for `a ^ b`, applying simple local simplifications.
/// Both operands are consumed: each is either deleted or becomes part of
/// the returned tree.
OpTree* do_pow(OpTree *a, OpTree *b)
{
    const bool base_is_num = (a->op == 0);
    const bool expo_is_num = (b->op == 0);

    // const ^ const
    if (base_is_num && expo_is_num) {
        const double folded = pow(a->val, b->val);
        delete a;
        delete b;
        return new OpTree(folded);
    }

    // 0 ^ ...
    if (base_is_num && is_eq(a->val, 0.)) {
        delete a;
        delete b;
        return new OpTree(0.);
    }

    // ... ^ 0   or   1 ^ ...
    if ((expo_is_num && is_eq(b->val, 0.))
            || (base_is_num && is_eq(a->val, 1.))) {
        delete a;
        delete b;
        return new OpTree(1.);
    }

    // ... ^ 1
    if (expo_is_num && is_eq(b->val, 1.)) {
        delete b;
        return a;
    }

    // ... ^ -1
    if (expo_is_num && is_eq(b->val, -1.)) {
        delete b;
        return do_oneover(a);
    }

    return new OpTree(OP_POW, a, simplify_terms(b));
}
/// voigt(a, b): when both arguments are numeric constants the value
/// humlik(a, b) / sqrt(pi) is computed immediately; otherwise both arguments
/// are simplified and wrapped in an OP_VOIGT node.
/// Both operands are consumed.
OpTree* do_voigt(OpTree *a, OpTree *b)
{
    if (a->op == 0 && b->op == 0) {
        double val = humlik(a->val, b->val) / sqrt(M_PI);
        delete a;
        delete b; // was leaked: only `a` used to be released here
        return new OpTree(val);
    }
    else
        return new OpTree(OP_VOIGT, simplify_terms(a), simplify_terms(b));
}
/// dvoigt_dx(a, b): when both arguments are numeric constants the value
/// humdev_dkdx(a, b) / sqrt(pi) is computed immediately; otherwise both
/// arguments are simplified and wrapped in an OP_DVOIGT_DX node.
/// Both operands are consumed.
OpTree* do_dvoigt_dx(OpTree *a, OpTree *b)
{
    if (a->op == 0 && b->op == 0) {
        double val = humdev_dkdx(a->val, b->val) / sqrt(M_PI);
        delete a;
        delete b; // was leaked: only `a` used to be released here
        return new OpTree(val);
    }
    else
        return new OpTree(OP_DVOIGT_DX, simplify_terms(a), simplify_terms(b));
}
/// dvoigt_dy(a, b): when both arguments are numeric constants the value
/// humdev_dkdy(a, b) / sqrt(pi) is computed immediately; otherwise both
/// arguments are simplified and wrapped in an OP_DVOIGT_DY node.
/// Both operands are consumed.
OpTree* do_dvoigt_dy(OpTree *a, OpTree *b)
{
    if (a->op == 0 && b->op == 0) {
        double val = humdev_dkdy(a->val, b->val) / sqrt(M_PI);
        delete a;
        delete b; // was leaked: only `a` used to be released here
        return new OpTree(val);
    }
    else
        return new OpTree(OP_DVOIGT_DY, simplify_terms(a), simplify_terms(b));
}
////////////////////////////////////////////////////////////////////////////
/// One factor of a product, representing (*t)^(*e).
struct MultFactor
{
    OpTree *t; ///< base of the factor
    OpTree *e; ///< exponent of the factor

    MultFactor(OpTree *t_, OpTree *e_)
        : t(t_), e(e_)
    {
    }

    /// Deletes both trees and marks this factor as removed (null pointers).
    void clear()
    {
        delete t;
        delete e;
        t = 0;
        e = 0;
    }
};
/// Recursively walks through OP_MUL, OP_DIV, OP_NEG, OP_SQRT and OP_POW
/// nodes and builds a list of factors, such that tree a raised to the power
/// expo is equal to
///   constant * (v[0].t)^(v[0].e) * (v[1].t)^(v[1].e) * ...
///
/// @param a         tree to decompose; it is consumed (its nodes are deleted
///                  or handed over to `v`)
/// @param expo      exponent applied to `a`; only read and copied, the
///                  caller keeps ownership
/// @param constant  accumulates the product of purely numeric factors
/// @param v         accumulates the non-numeric factors; a factor equal to
///                  one already stored is merged by adding the exponents
void get_factors(OpTree *a, OpTree *expo,
                 double &constant, vector<MultFactor>& v)
{
    if (a->op == OP_ADD || a->op == OP_SUB)
        a = simplify_terms(a);
    if (a->op == 0 && expo->op == 0)
        constant *= pow(a->val, expo->val);
    else if (a->op == OP_MUL) {
        get_factors(a->c1, expo, constant, v);
        get_factors(a->c2, expo, constant, v);
    }
    else if (a->op == OP_DIV) {
        get_factors(a->c1, expo, constant, v);
        // the denominator contributes with the negated exponent
        OpTree *expo2 = do_neg(expo->copy());
        get_factors(a->c2, expo2, constant, v);
        delete expo2;
    }
    else if (a->op == OP_NEG) {
        // -p is handled as p * (-1)
        get_factors(a->c1, expo, constant, v);
        get_factors(new OpTree(-1.), expo, constant, v);
    }
    else if (a->op == OP_SQRT) {
        OpTree *expo2 = do_multiply(new OpTree(0.5), expo->copy());
        get_factors(a->c1, expo2, constant, v);
        delete expo2;
    }
    else if (a->op == OP_POW) {
        // (p^q)^expo == p^(q*expo); the exponent subtree is detached from a
        OpTree *expo2 = do_multiply(a->remove_c2(), expo->copy());
        get_factors(a->c1, expo2, constant, v);
        delete expo2;
    }
    else {
        // a cannot be decomposed further
        for (vector<MultFactor>::iterator i = v.begin(); i != v.end(); ++i)
            if (*i->t == *a) {
                i->e = do_add(i->e, expo->copy());
                // a duplicates a stored factor and nothing else refers to
                // its children, so the whole subtree is released here
                // (detaching the children first would leak them)
                delete a;
                return;
            }
        v.push_back(MultFactor(a, expo->copy()));
        return; //don't delete a
    }
    // we are here -- the children of a (if any) were consumed by the
    // recursive calls above, so only the node itself is released
    a->c1 = a->c2 = 0;
    delete a;
}
OpTree* simplify_factors(OpTree *a)
{
#ifdef DEBUG_SIMPLIFY
cout << "simplify_factors() [<] " << a->str() << endl;
#endif
vector<MultFactor> v;
OpTree expo(1.);
double constant = 1;
get_factors(a, &expo, constant, v); //deletes a
#ifdef DEBUG_SIMPLIFY
cout << "simplify_factors(): [.] {" << constant << "} ";
for (vector<MultFactor>::iterator i = v.begin(); i != v.end(); ++i)
cout << "{" << i->t->str() << "|" << i->e->str() << "} ";
cout << endl;
#endif
// tan*cos -> sin; tan/sin -> cos
for (vector<MultFactor>::iterator i = v.begin(); i != v.end(); ++i)
if (i->t && i->t->op == OP_TAN) {
for (vector<MultFactor>::iterator j = v.begin(); j != v.end(); ++j){
if (j->t && j->t->op == OP_COS && *j->e == *i->e) {
i->t->change_op(OP_SIN);
j->clear();
}
if (j->t && j->t->op == OP_SIN
&& ((j->e->op==0 && i->e->op==0 && j->e->val==-i->e->val)
|| (j->e->op==OP_NEG && *j->e->c1 == *i->e)
|| (i->e->op==OP_NEG && *i->e->c1 == *j->e))) {
i->t->change_op(OP_COS);
j->clear();
}
}
}
// sin/cos -> tan
for (vector<MultFactor>::iterator i = v.begin(); i != v.end(); ++i)
if (i->t && i->t->op == OP_SIN) {
for (vector<MultFactor>::iterator j = v.begin(); j != v.end(); ++j){
if (j->t && j->t->op == OP_COS
&& ((j->e->op==0 && i->e->op==0 && j->e->val==-i->e->val)
|| (j->e->op==OP_NEG && *j->e->c1 == *i->e)
|| (i->e->op==OP_NEG && *i->e->c1 == *j->e))) {
i->t->change_op(OP_TAN);
j->clear();
}
}
}
// -> tree
// TODO x^z * y^z -> (x*y)^z (if z != -1,0,1)
OpTree *tu = 0, *tb = 0; // preparing expression as (tu / tb)
for (vector<MultFactor>::iterator i = v.begin(); i != v.end(); ++i)
if (i->t) {
if ((i->e->op == 0 && i->e->val < 0) || i->e->op == OP_NEG) {
OpTree *p = do_pow(i->t, do_neg(i->e));
tb = (tb == 0 ? p : do_multiply(tb, p));
}
else {
OpTree *p = do_pow(i->t, i->e);
tu = (tu == 0 ? p : do_multiply(tu, p));
}
}
OpTree *constant_t = new OpTree(constant);
OpTree *ret = 0;
if (tu) {
if (tb)
ret = do_multiply(constant_t, do_divide(tu, tb));
else //tu && !tb
ret = do_multiply(constant_t, tu);
}
else {
if (tb) //!tu && tb
ret = do_divide(constant_t, tb);
else //!tu && !tb
ret = constant_t;
}
#ifdef DEBUG_SIMPLIFY
cout << "simplify_factors() [>] " << ret->str() << endl;
#endif
return ret;
}
////////////////////////////////////////////////////////////////////////////
/// One term of a sum, representing k * (*t).
struct MultTerm
{
    OpTree *t; ///< the non-numeric part of the term
    double k;  ///< numeric coefficient

    MultTerm(OpTree *t_, double k_)
        : t(t_), k(k_)
    {
    }

    /// Deletes the tree and marks this term as removed (null pointer).
    void clear()
    {
        delete t;
        t = 0;
    }
};
/// Recursively splits tree `a` into additive terms and accumulates them in
/// `v`, so that multiplier * a equals the sum of v[i].k * v[i].t.
/// Plain numbers are collected in a single entry whose tree is the
/// constant 1. Equal trees are merged by adding their coefficients.
///
/// @param a           tree to split; it is consumed
/// @param multiplier  numeric coefficient applied to the whole of `a`
/// @param v           accumulated list of terms
void get_terms(OpTree *a, double multiplier, vector<MultTerm> &v)
{
    const bool product_like = (a->op == OP_MUL || a->op == OP_DIV
                               || a->op == OP_SQRT || a->op == OP_POW);
    if (product_like)
        a = simplify_factors(a);

    // p + q   and   p - q
    if (a->op == OP_ADD || a->op == OP_SUB) {
        const double second = (a->op == OP_ADD ? multiplier : -multiplier);
        get_terms(a->c1, multiplier, v);
        get_terms(a->c2, second, v);
        a->c1 = a->c2 = 0; // children were consumed by the calls above
        delete a;
        return;
    }

    // - p
    if (a->op == OP_NEG) {
        get_terms(a->c1, -multiplier, v);
        a->c1 = a->c2 = 0;
        delete a;
        return;
    }

    // const * p
    if (a->op == OP_MUL && a->c1->op == 0) {
        get_terms(a->c2, multiplier*(a->c1->val), v);
        a->c2 = 0; // c1 (the constant) is deleted together with a
        delete a;
        return;
    }

    // const / p for const != 1 (to avoid loop)
    if (a->op == OP_DIV && a->c1->op == 0 && a->c1->val != 1.) {
        get_terms(do_oneover(a->c2), multiplier*(a->c1->val), v);
        a->c2 = 0;
        delete a;
        return;
    }

    // a can't be split -- try to merge it with a term that is already in v
    const bool is_number = (a->op == 0);
    for (vector<MultTerm>::iterator term = v.begin(); term != v.end(); ++term) {
        if (is_number && term->t && term->t->op == 0) { // number (not the first)
            term->k += multiplier * a->val;
            delete a;
            return;
        }
        if (term->t && *term->t == *a) { // token already in v
            term->k += multiplier;
            delete a;
            return;
        }
    }

    // we are here -- this is the first token of its kind
    if (is_number) { // add number, stored as value * 1
        v.push_back(MultTerm(new OpTree(1.), multiplier * a->val));
        delete a;
        return;
    }
    v.push_back(MultTerm(a, multiplier)); // add token
}
/// Simplifies a tree.
///
/// Product-like trees (OP_MUL, OP_DIV, OP_SQRT, OP_POW) are forwarded to
/// simplify_factors(). Sums (OP_ADD, OP_SUB, OP_NEG) are split into terms
/// k*t by get_terms(), equal terms are merged, sin^2(x) + cos^2(x) of the
/// SAME argument is replaced by 1, and the sum is rebuilt. Any other tree is
/// returned unchanged.
///
/// Not handled:
///   (x+y) * (x-y) == x^2 - y^2
///   (x+/-y)^2 == x^2 +/- 2xy + y^2
///
/// @param a  tree to simplify; it is consumed
/// @return the simplified tree (possibly `a` itself)
OpTree* simplify_terms(OpTree *a)
{
    if (a->op == OP_MUL || a->op == OP_DIV || a->op == OP_SQRT
            || a->op == OP_POW)
        return simplify_factors(a);
    else if (!(a->op == OP_ADD || a->op == OP_SUB || a->op == OP_NEG))
        return a;
#ifdef DEBUG_SIMPLIFY
    cout << "simplify_terms() [<] " << a->str() << endl;
#endif
    vector<MultTerm> v;
    get_terms(a, 1., v); //deletes a
#ifdef DEBUG_SIMPLIFY
    cout << "simplify_terms() [.] ";
    for (vector<MultTerm>::iterator i = v.begin(); i != v.end(); ++i)
        cout << "{" << i->t->str() << "|" << i->k << "} ";
    cout << endl;
#endif

    // p*sin^2(x) + q*cos^2(x) = (p-q)*sin^2(x) + q
    double to_add = 0.;
    for (vector<MultTerm>::iterator i = v.begin(); i != v.end(); ++i) {
        if (!(i->t && i->t->op == OP_POW && i->t->c1->op == OP_SIN
                && i->t->c2->op == 0 && is_eq(i->t->c2->val, 2.)))
            continue;
        for (vector<MultTerm>::iterator j = v.begin(); j != v.end(); ++j)
            if (j->t && j->t->op == OP_POW && j->t->c1->op == OP_COS
                    && j->t->c2->op == 0 && is_eq(j->t->c2->val, 2.)
                    // the identity holds only for the same argument
                    && *i->t->c1->c1 == *j->t->c1->c1) {
                double k = j->k;
                i->k -= k;
                j->clear();
                to_add += k;
                break; // equal terms were merged, there is no other match
            }
    }
    if (to_add)
        get_terms(new OpTree(1.), to_add, v);

    // -> tree
    OpTree *t = 0;
    for (vector<MultTerm>::iterator i = v.begin(); i != v.end(); ++i) {
        if (!i->t)
            continue;
        if (is_eq(i->k, 0)) {
            // the term cancelled out; release its tree instead of leaking it
            i->clear();
            continue;
        }
        if (!t)
            t = do_multiply(new OpTree(i->k), i->t);
        else if (i->k > 0)
            t = do_add(t, do_multiply(new OpTree(i->k), i->t));
        else //i->k < 0
            t = do_sub(t, do_multiply(new OpTree(-i->k), i->t));
    }
    if (!t)
        t = new OpTree(0.);
#ifdef DEBUG_SIMPLIFY
    cout << "simplify_terms() [>] " << t->str() << endl;
#endif
    return t;
}
////////////////////////////////////////////////////////////////////////////
/// Converts the text of a constant token into a number.
///
/// Three forms are recognized:
///  - "pi",
///  - "{expr}" or "{expr in @n}": expr is evaluated by
///    get_transform_expression_value() for dataset n (or for the only
///    dataset, if exactly one exists),
///  - anything else is parsed with strtod(); a warning is issued for
///    non-zero literals with absolute value below epsilon.
///
/// @throws ExecuteError when the text after " in " is not "@<number>"
fp get_constant_value(string const &s)
{
    if (s == "pi")
        return M_PI;

    if (s[0] != '{') {
        // plain numeric literal
        fp number = strtod(s.c_str(), 0);
        if (number != 0. && fabs(number) < epsilon)
            AL->warn("Warning: Numeric literal 0 < |" + s + "| < epsilon="
                     + S(epsilon) + ".");
        return number;
    }

    // expression in braces, optionally followed by " in @n"
    assert(*(s.end()-1) == '}');
    string expr(s.begin()+1, s.end()-1);
    Data const* data = 0;
    string::size_type in_pos = expr.find(" in ");
    const bool has_in_clause = (in_pos != string::npos
                                && in_pos+4 < expr.size());
    if (has_in_clause) {
        string in_expr(expr, in_pos+4);
        int n;
        const bool parsed = parse(in_expr.c_str(),
                 *ch_p(' ') >> '@' >> uint_p[assign_a(n)] >> *ch_p(' ')
                 ).full;
        if (!parsed)
            throw ExecuteError("Syntax error near: `" + in_expr + "'");
        data = AL->get_data(n);
        expr.resize(in_pos); // cut off the " in @n" part
    }
    else if (AL->get_ds_count() == 1) {
        data = AL->get_data(0);
    }
    return get_transform_expression_value(expr, data);
}
/// Builds expression trees for the value and for the partial derivatives of
/// the expression rooted at parse-tree node `i`.
///
/// Returns an array of vars.size()+1 trees: the first n=vars.size() elements
/// are the derivatives with respect to vars[0..n-1] and the last element is
/// the tree for the value itself. The function calls itself for every child
/// node and combines the children's results with the usual differentiation
/// rules; every returned tree is passed through simplify_terms().
/// The trees returned by the recursive calls are either handed over to the
/// do_*() builders or deleted in this function.
vector<OpTree*> calculate_deriv(const_tm_iter_t const &i,
vector<string> const &vars)
{
int len = vars.size();
vector<OpTree*> results(len + 1);
// text matched by this node: a number, a variable name, a function name
// or an operator sign, depending on the node id
string s(i->value.begin(), i->value.end());
if (i->value.id() == FuncGrammar::real_constID)
{
// constant: all derivatives are 0
assert(s.size() > 0);
for (int k = 0; k < len; ++k)
results[k] = new OpTree(0.);
double v = get_constant_value(s);
results[len] = new OpTree(v);
}
else if (i->value.id() == FuncGrammar::variableID)
{
// variable: derivative is 1 with respect to itself and 0 otherwise;
// the value tree is created only when the name is found in vars
for (int k = 0; k < len; ++k)
if (s == vars[k]) {
results[k] = new OpTree(1.);
results[len] = new OpTree(k, s);
}
else
results[k] = new OpTree(0.);
}
else if (i->value.id() == FuncGrammar::exptokenID)
{
// function call, with one or two arguments
if (i->children.size() == 1) {
vector<OpTree*> arg = calculate_deriv(i->children.begin(), vars);
// do_op builds the value f(arg); der is the tree for f'(arg)
OpTree* (* do_op)(OpTree *) = 0;
OpTree* der = 0;
// copy of the argument's value tree, consumed while building der
OpTree* larg = arg.back()->copy();
if (s == "sqrt") {
// d/dz sqrt(z) = 0.5 / sqrt(z)
der = do_divide(new OpTree(0.5), do_sqrt(larg));
do_op = do_sqrt;
}
else if (s == "exp") {
der = do_exp(larg);
do_op = do_exp;
}
else if (s == "erfc") {
// d/dz erfc(z) = -2/sqrt(pi) * exp(-z^2)
der = do_multiply(do_exp(do_neg(do_sqr(larg))),
new OpTree(-2/sqrt(M_PI)));
do_op = do_erfc;
}
else if (s == "erf") {
// d/dz erf(z) = 2/sqrt(pi) * exp(-z^2)
der = do_multiply(do_exp(do_neg(do_sqr(larg))),
new OpTree(2/sqrt(M_PI)));
do_op = do_erf;
}
else if (s == "log10") {
// d/dz log10(z) = 1 / (z * ln(10))
OpTree *ln_10 = do_ln(new OpTree(10.));
der = do_oneover(do_multiply(larg, ln_10));
do_op = do_log10;
}
else if (s == "ln") {
der = do_oneover(larg);
do_op = do_ln;
}
else if (s == "sin") {
der = do_cos(larg);
do_op = do_sin;
}
else if (s == "cos") {
der = do_neg(do_sin(larg));
do_op = do_cos;
}
else if (s == "tan") {
// d/dz tan(z) = 1 / cos(z)^2
der = do_oneover(do_sqr(do_cos(larg)));
do_op = do_tan;
}
else if (s == "atan") {
// d/dz atan(z) = 1 / (1 + z^2)
der = do_oneover(do_add(new OpTree(1.), do_sqr(larg)));
do_op = do_atan;
}
else if (s == "asin") {
// d/dz asin(z) = 1 / sqrt(1 - z^2)
OpTree *root_arg = do_sub(new OpTree(1.), do_sqr(larg));
der = do_oneover(do_sqrt(root_arg));
do_op = do_asin;
}
else if (s == "acos") {
// d/dz acos(z) = -1 / sqrt(1 - z^2)
OpTree *root_arg = do_sub(new OpTree(1.), do_sqr(larg));
der = do_divide(new OpTree(-1.), do_sqrt(root_arg));
do_op = do_acos;
}
else if (s == "lgamma") {
// d/dz lgamma(z) = digamma(z)
der = do_digamma(larg);
do_op = do_lgamma;
}
else
assert(0);
// chain rule: f(arg)' = f'(arg) * arg'
for (int k = 0; k < len; ++k)
results[k] = do_multiply(der->copy(), arg[k]);
delete der;
results[len] = (*do_op)(arg[len]);
}
else if (i->children.size() == 2) {
vector<OpTree*>
left = calculate_deriv(i->children.begin(), vars),
right = calculate_deriv(i->children.begin() + 1, vars);
// d1, d2: partial derivatives of the function with respect to its
// first and second argument
OpTree *d1=0, *d2=0;
if (s == "voigt") {
d1 = do_dvoigt_dx(left[len]->copy(), right[len]->copy());
d2 = do_dvoigt_dy(left[len]->copy(), right[len]->copy());
for (int k = 0; k < len; ++k) {
results[k] = do_add(do_multiply(d1->copy(), left[k]),
do_multiply(d2->copy(), right[k]));
}
results[len] = do_voigt(left[len], right[len]);
}
else
assert(0);
delete d1;
delete d2;
}
else
assert(0);
}
else if (i->value.id() == FuncGrammar::signargID)
{
// power: left ^ right
assert(s == "^");
assert(i->children.size() == 2);
vector<OpTree*> left = calculate_deriv(i->children.begin(), vars);
vector<OpTree*> right = calculate_deriv(i->children.begin() + 1, vars);
// this special case is needed, because in cases like -2^4
// there was a problem with logarithm in formula below
if (left[len]->op == 0 && right[len]->op == 0) {
for (int k = 0; k < len; ++k) {
assert(left[k]->op == 0 && right[k]->op == 0);
delete left[k];
delete right[k];
results[k] = new OpTree(0.);
}
results[len] = do_pow(left[len], right[len]);
}
// other special cases, like a(x)^n, are not handled separately.
// Should they be?
else {
for (int k = 0; k < len; ++k) {
OpTree *a = left[len],
*b = right[len],
*ap = left[k],
*bp = right[k];
// general rule for a(x)^b(x):
//   (a^b)' = a^b * (b * a' / a  +  ln(a) * b')
// a and b are copied, because they are needed again for every k
// and for the value tree below; ap and bp are consumed
OpTree *pow_a_b = do_pow(a->copy(), b->copy());
OpTree *term1 = do_divide(do_multiply(b->copy(),ap), a->copy());
OpTree *term2 = do_multiply(do_ln(a->copy()), bp);
results[k] = do_multiply(pow_a_b, do_add(term1, term2));
}
results[len] = do_pow(left[len], right[len]);
}
}
else if (i->value.id() == FuncGrammar::factorID)
{
// unary minus: value and all derivatives are negated
assert (s == "-");
vector<OpTree*> arg = calculate_deriv(i->children.begin(), vars);
for (int k = 0; k < len+1; ++k)
results[k] = do_neg(arg[k]);
}
else if (i->value.id() == FuncGrammar::termID)
{
// product or quotient
assert(s == "*" || s == "/");
assert(i->children.size() == 2);
int op = (s == "*" ? OP_MUL : OP_DIV);
vector<OpTree*> left = calculate_deriv(i->children.begin(), vars);
vector<OpTree*> right = calculate_deriv(i->children.begin() + 1, vars);
for (int k = 0; k < len; ++k) {
// a, b: values of the operands (copied, as they are reused);
// ap, bp: their derivatives with respect to vars[k] (consumed)
OpTree *a = left[len],
*b = right[len],
*ap = left[k],
*bp = right[k];
if (op == OP_MUL) { // a*b' + a'*b
results[k] = do_add(do_multiply(a->copy(), bp),
do_multiply(ap, b->copy()));
}
else { //OP_DIV (a'*b - b'*a) / (b*b)
OpTree *upper = do_sub(do_multiply(ap, b->copy()),
do_multiply(bp, a->copy()));
results[k] = do_divide(upper, do_sqr(b->copy()));
}
}
results[len] = (op == OP_MUL ? do_multiply(left[len], right[len])
: do_divide(left[len], right[len]));
}
else if (i->value.id() == FuncGrammar::expressionID)
{
// sum or difference: applied element-wise to derivatives and value
assert(s == "+" || s == "-");
assert(i->children.size() == 2);
vector<OpTree*> left = calculate_deriv(i->children.begin(), vars);
vector<OpTree*> right = calculate_deriv(i->children.begin() + 1, vars);
for (int k = 0; k < len+1; ++k)
results[k] = (s == "+" ? do_add(left[k], right[k])
: do_sub(left[k], right[k]));
}
else
assert(0); // error
// simplify every derivative and the value before returning them
for (int k = 0; k < len+1; ++k)
results[k] = simplify_terms(results[k]);
return results;
}
/// Debug utility: parses `formula` and returns a text with the simplified
/// formula on the first line, followed by one line per variable with the
/// symbolic derivative with respect to that variable.
/// @throws ExecuteError when the formula cannot be fully parsed
string get_derivatives_str(string const &formula)
{
    tree_parse_info<> info = ast_parse(formula.c_str(), FuncG, space_p);
    if (!info.full)
        throw ExecuteError("Can't parse formula: " + formula);

    const_tm_iter_t const &root = info.trees.begin();
    vector<string> vars = find_tokens_in_ptree(FuncGrammar::variableID, info);
    vector<OpTree*> trees = calculate_deriv(root, vars);

    // the value tree is the last element, derivatives come before it
    string text = "f(" + join_vector(vars, ", ") + ") = "
                  + trees.back()->str(&vars);
    for (size_t k = 0; k != vars.size(); ++k) {
        text += "\ndf / d " + vars[k] + " = ";
        text += trees[k]->str(&vars);
    }

    purge_all_elements(trees);
    return text;
}
// syntax highlighted by Code2HTML, v. 0.9.1