/*
* R : A Computer Language for Statistical Data Analysis
* Copyright (C) 1995, 1996 Robert Gentleman and Ross Ihaka
* Copyright (C) 1998-2003 The R Development Core Team.
* Copyright (C) 2004-5 The R Foundation
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* A copy of the GNU General Public License is available via WWW at
* http://www.gnu.org/copyleft/gpl.html. You can also obtain it by
* writing to the Free Software Foundation, Inc., 51 Franklin Street
* Fifth Floor, Boston, MA 02110-1301 USA.
*
*
* Symbolic Differentiation
*/
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include "Defn.h"
/* Cached symbols for the operators and functions known to the
 * derivative table.  They are filled in once by InitDerivSymbols(). */
static SEXP ParenSymbol;	/* "(" */
static SEXP PlusSymbol;		/* "+" */
static SEXP MinusSymbol;	/* "-" */
static SEXP TimesSymbol;	/* "*" */
static SEXP DivideSymbol;	/* "/" */
static SEXP PowerSymbol;	/* "^" */
static SEXP ExpSymbol;		/* "exp" */
static SEXP LogSymbol;		/* "log" */
static SEXP SinSymbol;		/* "sin" */
static SEXP CosSymbol;		/* "cos" */
static SEXP TanSymbol;		/* "tan" */
static SEXP SinhSymbol;		/* "sinh" */
static SEXP CoshSymbol;		/* "cosh" */
static SEXP TanhSymbol;		/* "tanh" */
static SEXP SqrtSymbol;		/* "sqrt" */
static SEXP PnormSymbol;	/* "pnorm" */
static SEXP DnormSymbol;	/* "dnorm" */
static SEXP AsinSymbol;		/* "asin" */
static SEXP AcosSymbol;		/* "acos" */
static SEXP AtanSymbol;		/* "atan" */
static SEXP GammaSymbol;	/* "gamma" */
static SEXP LGammaSymbol;	/* "lgamma" */
static SEXP PsiSymbol;		/* "psigamma" */
/* Set to TRUE after the symbols above have been installed. */
static Rboolean Initialized = FALSE;
static void InitDerivSymbols()
{
/* Called from do_D() and do_deriv() */
if(Initialized) return;
ParenSymbol = install("(");
PlusSymbol = install("+");
MinusSymbol = install("-");
TimesSymbol = install("*");
DivideSymbol = install("/");
PowerSymbol = install("^");
ExpSymbol = install("exp");
LogSymbol = install("log");
SinSymbol = install("sin");
CosSymbol = install("cos");
TanSymbol = install("tan");
SinhSymbol = install("sinh");
CoshSymbol = install("cosh");
TanhSymbol = install("tanh");
SqrtSymbol = install("sqrt");
PnormSymbol = install("pnorm");
DnormSymbol = install("dnorm");
AsinSymbol = install("asin");
AcosSymbol = install("acos");
AtanSymbol = install("atan");
GammaSymbol = install("gamma");
LGammaSymbol = install("lgamma");
PsiSymbol = install("psigamma");
Initialized = TRUE;
}
static SEXP Constant(double x)
{
SEXP s = allocVector(REALSXP, 1);
REAL(s)[0] = x;
return s;
}
/* Non-zero when asReal(s) compares equal to 0.0. */
static int isZero(SEXP s)
{
    double value = asReal(s);

    return value == 0.0;
}
/* Non-zero when asReal(s) compares equal to 1.0. */
static int isOne(SEXP s)
{
    double value = asReal(s);

    return value == 1.0;
}
/* Is s a call to "-" used as unary minus?
 * Returns 1 for a two-element call, or a three-element call whose
 * third element is R_MissingArg; 0 for anything else that is not a
 * malformed "-" call.  A "-" call of any other length is an error. */
static int isUminus(SEXP s)
{
    if (TYPEOF(s) != LANGSXP || CAR(s) != MinusSymbol)
        return 0;

    switch (length(s)) {
    case 2:
        return 1;
    case 3:
        return CADDR(s) == R_MissingArg;
    default:
        error(_("invalid form in unary minus check"));
        return -1; /* for -Wall */
    }
}
/* Push the argument on the protection stack and hand it straight back,
 * so that a PROTECT can be written inside an expression. */
static SEXP PP(SEXP s)
{
    SEXP guarded = s;

    PROTECT(guarded);
    return guarded;
}
static SEXP simplify(SEXP fun, SEXP arg1, SEXP arg2);

/* Simplify fun(a, b), then return the simplified unary minus of that
 * result.  The intermediate value is protected while it is negated. */
static SEXP NegatedSimplify(SEXP fun, SEXP a, SEXP b)
{
    SEXP inner, negated;

    PROTECT(inner = simplify(fun, a, b));
    negated = simplify(MinusSymbol, inner, R_MissingArg);
    UNPROTECT(1);
    return negated;
}

/* Build the call fun(arg1, arg2), applying a few algebraic
 * simplifications (zeros, ones, unary minus) along the way.
 *
 * fun  : one of the symbols installed by InitDerivSymbols()
 * arg1 : first operand
 * arg2 : second operand, or R_MissingArg for unary use
 *
 * Returns an unprotected SEXP: either one of the arguments, a fresh
 * constant, or a freshly built call.  An unrecognised `fun` yields the
 * constant NA_REAL. */
static SEXP simplify(SEXP fun, SEXP arg1, SEXP arg2)
{
    SEXP ans;

    if (fun == PlusSymbol) {
        if (isZero(arg1))
            ans = arg2;
        else if (isZero(arg2))
            ans = arg1;
        else if (isUminus(arg1))
            ans = simplify(MinusSymbol, arg2, CADR(arg1));
        else if (isUminus(arg2))
            ans = simplify(MinusSymbol, arg1, CADR(arg2));
        else
            ans = lang3(PlusSymbol, arg1, arg2);
    }
    else if (fun == MinusSymbol && arg2 == R_MissingArg) {
        /* unary minus */
        if (isZero(arg1))
            ans = Constant(0.);
        else if (isUminus(arg1))
            ans = CADR(arg1);
        else
            ans = lang2(MinusSymbol, arg1);
    }
    else if (fun == MinusSymbol) {
        /* binary minus */
        if (isZero(arg2))
            ans = arg1;
        else if (isZero(arg1))
            ans = simplify(MinusSymbol, arg2, R_MissingArg);
        else if (isUminus(arg1))
            ans = NegatedSimplify(PlusSymbol, CADR(arg1), arg2);
        else if (isUminus(arg2))
            ans = simplify(PlusSymbol, arg1, CADR(arg2));
        else
            ans = lang3(MinusSymbol, arg1, arg2);
    }
    else if (fun == TimesSymbol) {
        if (isZero(arg1) || isZero(arg2))
            ans = Constant(0.);
        else if (isOne(arg1))
            ans = arg2;
        else if (isOne(arg2))
            ans = arg1;
        else if (isUminus(arg1))
            ans = NegatedSimplify(TimesSymbol, CADR(arg1), arg2);
        else if (isUminus(arg2))
            ans = NegatedSimplify(TimesSymbol, arg1, CADR(arg2));
        else
            ans = lang3(TimesSymbol, arg1, arg2);
    }
    else if (fun == DivideSymbol) {
        if (isZero(arg1))
            ans = Constant(0.);
        else if (isZero(arg2))
            ans = Constant(NA_REAL);
        else if (isOne(arg2))
            ans = arg1;
        else if (isUminus(arg1))
            ans = NegatedSimplify(DivideSymbol, CADR(arg1), arg2);
        else if (isUminus(arg2))
            ans = NegatedSimplify(DivideSymbol, arg1, CADR(arg2));
        else
            ans = lang3(DivideSymbol, arg1, arg2);
    }
    else if (fun == PowerSymbol) {
        if (isZero(arg2))
            ans = Constant(1.);
        else if (isZero(arg1))
            ans = Constant(0.);
        else if (isOne(arg1))
            ans = Constant(1.);
        else if (isOne(arg2))
            ans = arg1;
        else
            ans = lang3(PowerSymbol, arg1, arg2);
    }
    else if (fun == ExpSymbol    /* FIXME: simplify exp(lgamma( E )) = gamma( E ) */
             || fun == LogSymbol /* FIXME: simplify log(gamma( E )) = lgamma( E ) */
             || fun == CosSymbol   || fun == SinSymbol   || fun == TanSymbol
             || fun == CoshSymbol  || fun == SinhSymbol  || fun == TanhSymbol
             || fun == SqrtSymbol  || fun == PnormSymbol || fun == DnormSymbol
             || fun == AsinSymbol  || fun == AcosSymbol  || fun == AtanSymbol
             || fun == GammaSymbol || fun == LGammaSymbol
             || fun == PsiSymbol) {
        /* one-argument functions: just build the call */
        ans = lang2(fun, arg1);
    }
    else
        ans = Constant(NA_REAL);

    /* FIXME */
#ifdef NOTYET
    if (length(ans) == 2 && isAtomic(CADR(ans)) && CAR(ans) != MinusSymbol)
	c = eval(c, rho);
    if (length(c) == 3 && isAtomic(CADR(ans)) && isAtomic(CADDR(ans)))
	c = eval(c, rho);
#endif
    return ans;
}/* simplify() */
/* D() implements the "derivative table" :
 *
 * expr : the expression to differentiate (constant, symbol, call, or
 *        a pairlist inheriting from "expression")
 * var  : the symbol to differentiate with respect to
 *
 * Returns the (unprotected) derivative expression.  Constants and
 * symbols other than `var` give 0, `var` itself gives 1, unsupported
 * object types give NA_REAL, and a call to a function that is not in
 * the table raises an error.
 *
 * Every intermediate value that is alive while another allocation can
 * happen is pushed on the protection stack through PP()/PP_S()/PP_S2();
 * the UNPROTECT count at the end of each branch equals the number of
 * such pushes in that branch.
 */
static SEXP D(SEXP expr, SEXP var)
{
#define PP_S(F,a1,a2) PP(simplify(F,a1,a2))
#define PP_S2(F,a1) PP(simplify(F,a1, R_MissingArg))
    SEXP ans=R_NilValue, expr1, expr2;
    switch(TYPEOF(expr)) {
    case LGLSXP:
    case INTSXP:
    case REALSXP:
    case CPLXSXP:
        ans = Constant(0);
        break;
    case SYMSXP:
        if (expr == var) ans = Constant(1.);
        else ans = Constant(0.);
        break;
    case LISTSXP:
        if (inherits(expr, "expression")) ans = D(CAR(expr), var);
        else ans = Constant(NA_REAL);
        break;
    case LANGSXP:
        if (CAR(expr) == ParenSymbol) {
            ans = D(CADR(expr), var);
        }
        else if (CAR(expr) == PlusSymbol) {
            if (length(expr) == 2)
                ans = D(CADR(expr), var);
            else {
                ans = simplify(PlusSymbol,
                               PP(D(CADR(expr), var)),
                               PP(D(CADDR(expr), var)));
                UNPROTECT(2);
            }
        }
        else if (CAR(expr) == MinusSymbol) {
            if (length(expr) == 2) {
                ans = simplify(MinusSymbol,
                               PP(D(CADR(expr), var)),
                               R_MissingArg);
                UNPROTECT(1);
            }
            else {
                ans = simplify(MinusSymbol,
                               PP(D(CADR(expr), var)),
                               PP(D(CADDR(expr), var)));
                UNPROTECT(2);
            }
        }
        else if (CAR(expr) == TimesSymbol) {
            /* product rule */
            ans = simplify(PlusSymbol,
                           PP_S(TimesSymbol,PP(D(CADR(expr),var)), CADDR(expr)),
                           PP_S(TimesSymbol,CADR(expr), PP(D(CADDR(expr),var))));
            UNPROTECT(4);
        }
        else if (CAR(expr) == DivideSymbol) {
            /* a'/b - (a * b') / b^2 */
            PROTECT(expr1 = D(CADR(expr), var));
            PROTECT(expr2 = D(CADDR(expr), var));
            ans = simplify(MinusSymbol,
                           PP_S(DivideSymbol, expr1, CADDR(expr)),
                           PP_S(DivideSymbol,
                                PP_S(TimesSymbol, CADR(expr), expr2),
                                PP_S(PowerSymbol,CADDR(expr),PP(Constant(2.)))));
            UNPROTECT(7);
        }
        else if (CAR(expr) == PowerSymbol) {
            if (isLogical(CADDR(expr)) || isNumeric(CADDR(expr))) {
                /* constant exponent: n * (a' * a^(n-1)) */
                ans = simplify(TimesSymbol,
                               CADDR(expr),
                               PP_S(TimesSymbol,
                                    PP(D(CADR(expr), var)),
                                    PP_S(PowerSymbol,
                                         CADR(expr),
                                         PP(Constant(asReal(CADDR(expr))-1.)))));
                UNPROTECT(4);
            }
            else {
                /* general exponent:
                   a^(b-1) * (b * a')  +  a^b * (log(a) * b') */
                expr1 = simplify(TimesSymbol,
                                 PP_S(PowerSymbol,
                                      CADR(expr),
                                      PP_S(MinusSymbol,
                                           CADDR(expr),
                                           PP(Constant(1.0)))),
                                 PP_S(TimesSymbol,
                                      CADDR(expr),
                                      PP(D(CADR(expr), var))));
                UNPROTECT(5);
                PROTECT(expr1);
                expr2 = simplify(TimesSymbol,
                                 PP_S(PowerSymbol, CADR(expr), CADDR(expr)),
                                 PP_S(TimesSymbol,
                                      PP_S2(LogSymbol, CADR(expr)),
                                      PP(D(CADDR(expr), var))));
                UNPROTECT(4);
                PROTECT(expr2);
                ans = simplify(PlusSymbol, expr1, expr2);
                UNPROTECT(2);
            }
        }
        else if (CAR(expr) == ExpSymbol) {
            ans = simplify(TimesSymbol,
                           expr,
                           PP(D(CADR(expr), var)));
            UNPROTECT(1);
        }
        else if (CAR(expr) == LogSymbol) {
            ans = simplify(DivideSymbol,
                           PP(D(CADR(expr), var)),
                           CADR(expr));
            UNPROTECT(1);
        }
        else if (CAR(expr) == CosSymbol) {
            ans = simplify(TimesSymbol,
                           PP_S2(SinSymbol, CADR(expr)),
                           PP_S2(MinusSymbol, PP(D(CADR(expr), var))));
            UNPROTECT(3);
        }
        else if (CAR(expr) == SinSymbol) {
            ans = simplify(TimesSymbol,
                           PP_S2(CosSymbol, CADR(expr)),
                           PP(D(CADR(expr), var)));
            UNPROTECT(2);
        }
        else if (CAR(expr) == TanSymbol) {
            ans = simplify(DivideSymbol,
                           PP(D(CADR(expr), var)),
                           PP_S(PowerSymbol,
                                PP_S2(CosSymbol, CADR(expr)),
                                PP(Constant(2.0))));
            UNPROTECT(4);
        }
        else if (CAR(expr) == CoshSymbol) {
            ans = simplify(TimesSymbol,
                           PP_S2(SinhSymbol, CADR(expr)),
                           PP(D(CADR(expr), var)));
            UNPROTECT(2);
        }
        else if (CAR(expr) == SinhSymbol) {
            /* was terminated by a stray ',' instead of ';' */
            ans = simplify(TimesSymbol,
                           PP_S2(CoshSymbol, CADR(expr)),
                           PP(D(CADR(expr), var)));
            UNPROTECT(2);
        }
        else if (CAR(expr) == TanhSymbol) {
            ans = simplify(DivideSymbol,
                           PP(D(CADR(expr), var)),
                           PP_S(PowerSymbol,
                                PP_S2(CoshSymbol, CADR(expr)),
                                PP(Constant(2.0))));
            UNPROTECT(4);
        }
        else if (CAR(expr) == SqrtSymbol) {
            /* rewrite sqrt(a) as a^0.5 and differentiate that */
            PROTECT(expr1 = allocList(3));
            SET_TYPEOF(expr1, LANGSXP);
            SETCAR(expr1, PowerSymbol);
            SETCADR(expr1, CADR(expr));
            SETCADDR(expr1, Constant(0.5));
            ans = D(expr1, var);
            UNPROTECT(1);
        }
        else if (CAR(expr) == PnormSymbol) {
            ans = simplify(TimesSymbol,
                           PP_S2(DnormSymbol, CADR(expr)),
                           PP(D(CADR(expr), var)));
            UNPROTECT(2);
        }
        else if (CAR(expr) == DnormSymbol) {
            ans = simplify(TimesSymbol,
                           PP_S2(MinusSymbol, CADR(expr)),
                           PP_S(TimesSymbol,
                                PP_S2(DnormSymbol, CADR(expr)),
                                PP(D(CADR(expr), var))));
            UNPROTECT(4);
        }
        else if (CAR(expr) == AsinSymbol) {
            /* The constants are protected too: the order in which the
               sibling arguments are evaluated is unspecified, so a
               fresh constant could otherwise be collected while a
               sibling argument allocates. */
            ans = simplify(DivideSymbol,
                           PP(D(CADR(expr), var)),
                           PP_S(SqrtSymbol,
                                PP_S(MinusSymbol, PP(Constant(1.)),
                                     PP_S(PowerSymbol,CADR(expr),PP(Constant(2.)))),
                                R_MissingArg));
            UNPROTECT(6);
        }
        else if (CAR(expr) == AcosSymbol) {
            /* constants protected for the same reason as in asin */
            ans = simplify(MinusSymbol,
                           PP_S(DivideSymbol,
                                PP(D(CADR(expr), var)),
                                PP_S(SqrtSymbol,
                                     PP_S(MinusSymbol,PP(Constant(1.)),
                                          PP_S(PowerSymbol,
                                               CADR(expr),PP(Constant(2.)))),
                                     R_MissingArg)), R_MissingArg);
            UNPROTECT(7);
        }
        else if (CAR(expr) == AtanSymbol) {
            /* constants protected for the same reason as in asin */
            ans = simplify(DivideSymbol,
                           PP(D(CADR(expr), var)),
                           PP_S(PlusSymbol,PP(Constant(1.)),
                                PP_S(PowerSymbol, CADR(expr),PP(Constant(2.)))));
            UNPROTECT(5);
        }
        else if (CAR(expr) == LGammaSymbol) {
            ans = simplify(TimesSymbol,
                           PP(D(CADR(expr), var)),
                           PP_S2(PsiSymbol, CADR(expr)));
            UNPROTECT(2);
        }
        else if (CAR(expr) == GammaSymbol) {
            ans = simplify(TimesSymbol,
                           PP(D(CADR(expr), var)),
                           PP_S(TimesSymbol,
                                expr,
                                PP_S2(PsiSymbol, CADR(expr))));
            UNPROTECT(3);
        }
        else {
            SEXP u = deparse1(CAR(expr), 0, SIMPLEDEPARSE);
            error(_("Function '%s' is not in the derivatives table"),
                  CHAR(STRING_ELT(u, 0)));
        }
        break;
    default:
        ans = Constant(NA_REAL);
    }
    return ans;
#undef PP_S
#undef PP_S2
} /* D() */
/* Non-zero when expr is a three-element call whose head is "+". */
static int isPlusForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
        return 0;
    if (length(expr) != 3)
        return 0;
    return CAR(expr) == PlusSymbol;
}
/* Non-zero when expr is a three-element call whose head is "-". */
static int isMinusForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
        return 0;
    if (length(expr) != 3)
        return 0;
    return CAR(expr) == MinusSymbol;
}
/* Non-zero when expr is a three-element call whose head is "*". */
static int isTimesForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
        return 0;
    if (length(expr) != 3)
        return 0;
    return CAR(expr) == TimesSymbol;
}
/* Non-zero when expr is a three-element call whose head is "/". */
static int isDivideForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
        return 0;
    if (length(expr) != 3)
        return 0;
    return CAR(expr) == DivideSymbol;
}
/* Non-zero when expr is a three-element call whose head is "^". */
static int isPowerForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
        return 0;
    if (length(expr) != 3)
        return 0;
    return CAR(expr) == PowerSymbol;
}
/* Binary "+" or "-" call? */
static int AddParens_isAdditive(SEXP e)
{
    return isPlusForm(e) || isMinusForm(e);
}

/* Binary "+", "-", "*" or "/" call? */
static int AddParens_isArithmetic(SEXP e)
{
    return AddParens_isAdditive(e) || isTimesForm(e) || isDivideForm(e);
}

/* Walk expr and wrap operands of binary arithmetic calls in "(" calls
 * according to the rules below.  The arguments of every call are
 * processed first (recursively); expr is modified in place and
 * returned.
 *
 *   a + b  : b is wrapped when it is a "+" form
 *   a - b  : b is wrapped when it is a "+" or "-" form
 *   a * b, a / b :
 *            b is wrapped when it is a "+", "-", "*" or "/" form,
 *            a is wrapped when it is a "+" or "-" form
 *   a ^ b  : a is wrapped when it is a "^" form,
 *            b is wrapped when it is a "+", "-", "*" or "/" form
 */
static SEXP AddParens(SEXP expr)
{
    SEXP arg;

    if (TYPEOF(expr) == LANGSXP) {
        for (arg = CDR(expr); arg != R_NilValue; arg = CDR(arg))
            SETCAR(arg, AddParens(CAR(arg)));
    }

    if (isPlusForm(expr)) {
        if (isPlusForm(CADDR(expr)))
            SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
    }
    else if (isMinusForm(expr)) {
        if (AddParens_isAdditive(CADDR(expr)))
            SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
    }
    else if (isTimesForm(expr) || isDivideForm(expr)) {
        if (AddParens_isArithmetic(CADDR(expr)))
            SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
        if (AddParens_isAdditive(CADR(expr)))
            SETCADR(expr, lang2(ParenSymbol, CADR(expr)));
    }
    else if (isPowerForm(expr)) {
        if (isPowerForm(CADR(expr)))
            SETCADR(expr, lang2(ParenSymbol, CADR(expr)));
        if (AddParens_isArithmetic(CADDR(expr)))
            SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
    }
    return expr;
}
/* Entry point for D(expr, name).
 *
 * args[0] : the expression; for an expression vector its first
 *           element is used
 * args[1] : character vector naming the variable; only the first
 *           element is used (with a warning if there are more)
 *
 * Returns the derivative expression after AddParens() has been applied
 * to it. */
SEXP attribute_hidden do_D(SEXP call, SEXP op, SEXP args, SEXP env)
{
    SEXP target, varname, deriv;

    checkArity(op, args);

    target = CAR(args);
    if (isExpression(target))
        target = VECTOR_ELT(target, 0);

    varname = CADR(args);
    if (!isString(varname) || length(varname) < 1)
        errorcall(call, _("variable must be a character string"));
    if (length(varname) > 1)
        warningcall(call,
                    _("only the first element is used as variable name"));
    varname = install(CHAR(STRING_ELT(varname, 0)));

    InitDerivSymbols();

    /* keep the derivative alive while AddParens() allocates */
    PROTECT(deriv = D(target, varname));
    deriv = AddParens(deriv);
    UNPROTECT(1);
    return deriv;
}
/* ------ FindSubexprs ------ and ------ Accumulate ------ */
/* Signal an error reporting an invalid expression; `where` is the name
 * inserted into the message. */
static void InvalidExpression(char *where)
{
    char *origin = where;

    error(_("invalid expression in \"%s\""), origin);
}
/* Structural equality of two expressions.
 *
 * Objects of different type are never equal.  Symbols are compared by
 * identity, atomic vectors by their first element only, and
 * calls/pairlists recursively over CAR and CDR.  Any other type
 * triggers InvalidExpression("equal").
 *
 * Returns 1 when equal, 0 otherwise. */
static int equal(SEXP expr1, SEXP expr2)
{
    if (TYPEOF(expr1) != TYPEOF(expr2))
        return 0;

    switch (TYPEOF(expr1)) {
    case NILSXP:
        return 1;
    case SYMSXP:
        return expr1 == expr2;
    case LGLSXP:
    case INTSXP:
        return INTEGER(expr1)[0] == INTEGER(expr2)[0];
    case REALSXP:
        return REAL(expr1)[0] == REAL(expr2)[0];
    case CPLXSXP:
        if (COMPLEX(expr1)[0].r != COMPLEX(expr2)[0].r)
            return 0;
        return COMPLEX(expr1)[0].i == COMPLEX(expr2)[0].i;
    case LANGSXP:
    case LISTSXP:
        if (!equal(CAR(expr1), CAR(expr2)))
            return 0;
        return equal(CDR(expr1), CDR(expr2));
    default:
        InvalidExpression("equal");
        return 0;
    }
}
/* Look expr up among the elements following the head cell of exprlist.
 * If an element equal() to it is found, return that element's 1-based
 * position; otherwise append expr at the end and return the position
 * of the new element. */
static int Accumulate(SEXP expr, SEXP exprlist)
{
    SEXP tail = exprlist;
    int pos;

    for (pos = 1; CDR(tail) != R_NilValue; pos++) {
        tail = CDR(tail);
        if (equal(expr, CAR(tail)))
            return pos;
    }
    SETCDR(tail, CONS(expr, R_NilValue));
    return pos;
}
/* Append expr at the end of exprlist unconditionally (no search for an
 * equal element) and return the 1-based position of the new element,
 * counted from the cell after the head of exprlist. */
static int Accumulate2(SEXP expr, SEXP exprlist)
{
    SEXP tail = exprlist;
    int pos;

    for (pos = 1; CDR(tail) != R_NilValue; pos++)
        tail = CDR(tail);
    SETCDR(tail, CONS(expr, R_NilValue));
    return pos;
}
/* Return the symbol whose name is the first string of `tag` followed
 * by the decimal number k.  The name is formatted into a 64-byte
 * buffer with snprintf. */
static SEXP MakeVariable(int k, SEXP tag)
{
    char name[64];
    const char *prefix = CHAR(STRING_ELT(tag, 0));

    snprintf(name, sizeof name, "%s%d", prefix, k);
    return install(name);
}
/* Register the sub-expressions of expr in exprlist.
 *
 * For a call, every argument is processed first; an argument that was
 * itself registered (non-zero result) is replaced IN PLACE by the
 * variable MakeVariable(k, tag).  The call is then passed to
 * Accumulate() and its position is returned.  A "(" call is skipped in
 * favour of its content.  Symbols and atomic constants return 0.
 * Anything else is reported through InvalidExpression(). */
static int FindSubexprs(SEXP expr, SEXP exprlist, SEXP tag)
{
    SEXP arg;
    int idx;

    switch (TYPEOF(expr)) {
    case LANGSXP:
        if (CAR(expr) == install("("))
            return FindSubexprs(CADR(expr), exprlist, tag);
        for (arg = CDR(expr); arg != R_NilValue; arg = CDR(arg)) {
            idx = FindSubexprs(CAR(arg), exprlist, tag);
            if (idx != 0)
                SETCAR(arg, MakeVariable(idx, tag));
        }
        return Accumulate(expr, exprlist);

    case LISTSXP:
        if (!inherits(expr, "expression")) {
            InvalidExpression("FindSubexprs");
            return -1; /*-Wall*/
        }
        return FindSubexprs(CAR(expr), exprlist, tag);

    case SYMSXP:
    case LGLSXP:
    case INTSXP:
    case REALSXP:
    case CPLXSXP:
        return 0;

    default:
        InvalidExpression("FindSubexprs");
        return -1; /*-Wall*/
    }
}
/* Count how many times the symbol `sym` appears in lst, descending
 * through both CAR and CDR of pairlists and calls.  Objects of any
 * other type contribute 0. */
static int CountOccurrences(SEXP sym, SEXP lst)
{
    int kind = TYPEOF(lst);

    if (kind == SYMSXP)
        return lst == sym;
    if (kind == LISTSXP || kind == LANGSXP) {
        int in_head = CountOccurrences(sym, CAR(lst));
        int in_rest = CountOccurrences(sym, CDR(lst));
        return in_head + in_rest;
    }
    return 0;
}
/* Substitute `expr` for every occurrence of the symbol `sym` in lst.
 * Pairlists and calls are rewritten in place (CAR first, then CDR) and
 * returned; a bare symbol equal to `sym` is replaced by `expr`;
 * everything else is returned untouched. */
static SEXP Replace(SEXP sym, SEXP expr, SEXP lst)
{
    int kind = TYPEOF(lst);

    if (kind == SYMSXP)
        return (lst == sym) ? expr : lst;

    if (kind == LISTSXP || kind == LANGSXP) {
        SETCAR(lst, Replace(sym, expr, CAR(lst)));
        SETCDR(lst, Replace(sym, expr, CDR(lst)));
    }
    return lst;
}
/* Build the call
 *
 *   .grad <- array(0, c(length(.value), N), list(NULL, c(names...)))
 *
 * where N is length(names) (stored as a real constant) and each name
 * is inserted as its own length-one character vector. */
static SEXP CreateGrad(SEXP names)
{
    SEXP call, cell, zero, dims, dnames, str;
    int idx, count;

    count = length(names);

    /* list(NULL, c("name1", "name2", ...)) */
    PROTECT(dnames = lang3(R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dnames, install("list"));
    call = install("c");
    PROTECT(cell = allocList(count));
    SETCADDR(dnames, LCONS(call, cell));
    UNPROTECT(1); /* the cells are now reachable through dnames */
    for (idx = 0; idx < count; idx++, cell = CDR(cell)) {
        str = allocVector(STRSXP, 1);
        SETCAR(cell, str);
        SET_STRING_ELT(str, 0, STRING_ELT(names, idx));
    }

    /* c(length(.value), N) */
    PROTECT(dims = lang3(R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dims, install("c"));
    SETCADR(dims, lang2(install("length"), install(".value")));
    SETCADDR(dims, allocVector(REALSXP, 1));
    REAL(CADDR(dims))[0] = length(names);

    /* the fill value 0 */
    PROTECT(zero = allocVector(REALSXP, 1));
    REAL(zero)[0] = 0;

    PROTECT(call = lang4(install("array"), zero, dims, dnames));
    call = lang3(install("<-"), install(".grad"), call);
    UNPROTECT(4);
    return call;
}
/* Build the call
 *
 *   .hessian <- array(0, c(length(.value), N, N),
 *                     list(NULL, c(names...), c(names...)))
 *
 * where N is length(names) (stored as real constants); the second
 * names vector is a duplicate of the first. */
static SEXP CreateHess(SEXP names)
{
    SEXP call, cell, zero, dims, dnames, str;
    int idx, count;

    count = length(names);

    /* list(NULL, c("name1", ...), c("name1", ...)) */
    PROTECT(dnames = lang4(R_NilValue, R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dnames, install("list"));
    call = install("c");
    PROTECT(cell = allocList(count));
    SETCADDR(dnames, LCONS(call, cell));
    UNPROTECT(1); /* the cells are now reachable through dnames */
    for (idx = 0; idx < count; idx++, cell = CDR(cell)) {
        str = allocVector(STRSXP, 1);
        SETCAR(cell, str);
        SET_STRING_ELT(str, 0, STRING_ELT(names, idx));
    }
    SETCADDDR(dnames, duplicate(CADDR(dnames)));

    /* c(length(.value), N, N) */
    PROTECT(dims = lang4(R_NilValue, R_NilValue, R_NilValue,R_NilValue));
    SETCAR(dims, install("c"));
    SETCADR(dims, lang2(install("length"), install(".value")));
    SETCADDR(dims, allocVector(REALSXP, 1));
    REAL(CADDR(dims))[0] = length(names);
    SETCADDDR(dims, allocVector(REALSXP, 1));
    REAL(CADDDR(dims))[0] = length(names);

    /* the fill value 0 */
    PROTECT(zero = allocVector(REALSXP, 1));
    REAL(zero)[0] = 0;

    PROTECT(call = lang4(install("array"), zero, dims, dnames));
    call = lang3(install("<-"), install(".hessian"), call);
    UNPROTECT(4);
    return call;
}
/* Build the call  .grad[, "name"] <- expr  where `name` is a CHARSXP
 * that is wrapped in a new length-one character vector. */
static SEXP DerivAssign(SEXP name, SEXP expr)
{
    SEXP assign, colname, target;

    PROTECT(assign = lang3(install("<-"), R_NilValue, expr));
    PROTECT(colname = allocVector(STRSXP, 1));
    SET_STRING_ELT(colname, 0, name);
    target = lang4(install("["), install(".grad"), R_MissingArg, colname);
    SETCADR(assign, target);
    UNPROTECT(2);
    return assign;
}
/* Build the five-element call  s(t, u, v, w).  The head `s` is
 * protected while the argument list is being allocated. */
static SEXP lang5(SEXP s, SEXP t, SEXP u, SEXP v, SEXP w)
{
    SEXP call;

    PROTECT(s);
    call = LCONS(s, list4(t, u, v, w));
    UNPROTECT(1);
    return call;
}
/* Build the call  .hessian[, "name", "name"] <- expr  for a diagonal
 * entry; `name` is a CHARSXP wrapped in a new length-one character
 * vector that is used for both indices. */
static SEXP HessAssign1(SEXP name, SEXP expr)
{
    SEXP assign, idxname, target;

    PROTECT(assign = lang3(install("<-"), R_NilValue, expr));
    PROTECT(idxname = allocVector(STRSXP, 1));
    SET_STRING_ELT(idxname, 0, name);
    target = lang5(install("["), install(".hessian"), R_MissingArg,
                   idxname, idxname);
    SETCADR(assign, target);
    UNPROTECT(2);
    return assign;
}
/* Build the chained assignment
 *
 *   .hessian[, "name1", "name2"] <- .hessian[, "name2", "name1"] <- expr
 *
 * for an off-diagonal entry.  name1 and name2 are CHARSXPs; each is
 * wrapped in a new length-one character vector.
 *
 * Every intermediate call is protected before the next one is built.
 * Nesting the two freshly allocated calls directly as sibling
 * arguments of the outer lang3() would leave one of them unprotected
 * while the other is being allocated. */
static SEXP HessAssign2(SEXP name1, SEXP name2, SEXP expr)
{
    SEXP ans, newname1, newname2, lhs12, lhs21, inner;

    PROTECT(newname1 = allocVector(STRSXP, 1));
    PROTECT(newname2 = allocVector(STRSXP, 1));
    SET_STRING_ELT(newname1, 0, name1);
    SET_STRING_ELT(newname2, 0, name2);

    /* .hessian[, name1, name2] */
    PROTECT(lhs12 = lang5(install("["), install(".hessian"), R_MissingArg,
                          newname1, newname2));
    /* .hessian[, name2, name1] */
    PROTECT(lhs21 = lang5(install("["), install(".hessian"), R_MissingArg,
                          newname2, newname1));
    /* .hessian[, name2, name1] <- expr */
    PROTECT(inner = lang3(install("<-"), lhs21, expr));

    ans = lang3(install("<-"), lhs12, inner);
    UNPROTECT(5);
    return ans;
}
/* Build the call  attr(.value, "gradient") <- .grad */
static SEXP AddGrad()
{
    SEXP attrname, target, assign;

    PROTECT(attrname = mkString("gradient"));
    PROTECT(target = lang3(install("attr"), install(".value"), attrname));
    assign = lang3(install("<-"), target, install(".grad"));
    UNPROTECT(2);
    return assign;
}
/* Build the call  attr(.value, "hessian") <- .hessian */
static SEXP AddHess()
{
    SEXP attrname, target, assign;

    PROTECT(attrname = mkString("hessian"));
    PROTECT(target = lang3(install("attr"), install(".value"), attrname));
    assign = lang3(install("<-"), target, install(".hessian"));
    UNPROTECT(2);
    return assign;
}
/* Remove from the pairlist lst every cell whose CAR is R_MissingArg.
 * The tail is pruned first (recursively) and relinked; the function
 * returns the first surviving cell, or R_NilValue if none remains. */
static SEXP Prune(SEXP lst)
{
    SEXP rest;

    if (lst == R_NilValue)
        return lst;

    rest = Prune(CDR(lst));
    SETCDR(lst, rest);
    return (CAR(lst) == R_MissingArg) ? CDR(lst) : lst;
}
/* Entry point for deriv.default(expr, namevec, function.arg, tag, hessian).
 *
 * args (in order):
 *   expr         : expression to differentiate; for an expression
 *                  vector its first element is used
 *   namevec      : character vector of variable names (length >= 1)
 *   function.arg : TRUE (use namevec as formals), a character vector
 *                  of formal names, a closure whose formals and
 *                  environment are reused, or anything else to get an
 *                  expression vector back
 *   tag          : character prefix (1..60 chars) for generated names
 *   hessian      : logical; also generate second derivatives
 *
 * The function builds a "{" block that assigns common sub-expressions
 * to <tag><k> variables, sets .value, .grad (and .hessian), attaches
 * them as attributes and finally returns .value.  The block is
 * returned wrapped in a closure or in an expression vector, depending
 * on function.arg.
 */
SEXP attribute_hidden do_deriv(SEXP call, SEXP op, SEXP args, SEXP env)
{
    SEXP ans, ans2, expr, funarg, names, s;
    int f_index, *d_index, *d2_index;
    int i, j, k, nexpr, nderiv=0, hessian;
    char *vmax;
    SEXP exprlist, tag;

    checkArity(op, args);
    vmax = vmaxget();
    InitDerivSymbols();
    PROTECT(exprlist = LCONS(install("{"), R_NilValue));
    /* expr: */
    if (isExpression(CAR(args)))
        PROTECT(expr = VECTOR_ELT(CAR(args), 0));
    else PROTECT(expr = CAR(args));
    args = CDR(args);
    /* namevec: */
    names = CAR(args);
    if (!isString(names) || (nderiv = length(names)) < 1)
        errorcall(call, _("invalid variable names"));
    args = CDR(args);
    /* function.arg: */
    funarg = CAR(args);
    args = CDR(args);
    /* tag: */
    tag = CAR(args);
    if (!isString(tag) || length(tag) < 1
        || length(STRING_ELT(tag, 0)) < 1 || length(STRING_ELT(tag, 0)) > 60)
        errorcall(call, _("invalid tag"));
    args = CDR(args);
    /* hessian: */
    hessian = asLogical(CAR(args));

    /* Pass 1: collect the sub-expressions of the function value and of
       every derivative into exprlist, remembering their positions. */

    /* NOTE: FindSubexprs is destructive, hence the duplication */
    PROTECT(ans = duplicate(expr));
    f_index = FindSubexprs(ans, exprlist, tag);
    d_index = (int*)R_alloc(nderiv, sizeof(int));
    if (hessian)
        d2_index = (int*)R_alloc((nderiv * (1 + nderiv))/2, sizeof(int));
    else d2_index = d_index;/*-Wall*/
    UNPROTECT(1);
    for(i=0, k=0; i<nderiv ; i++) {
        PROTECT(ans = duplicate(expr));
        PROTECT(ans = D(ans, install(CHAR(STRING_ELT(names, i)))));
        /* Keep a pristine copy of the derivative.  It must be
           protected: FindSubexprs() below allocates. */
        PROTECT(ans2 = duplicate(ans));
        d_index[i] = FindSubexprs(ans, exprlist, tag); /* examine the derivative first */
        /* Restore the copy.  It must be protected as well, because it
           is read again on every iteration of the loop below while
           duplicate(), D() and FindSubexprs() allocate. */
        PROTECT(ans = duplicate(ans2));
        if (hessian) {
            for(j = i; j < nderiv; j++) {
                PROTECT(ans2 = duplicate(ans));
                PROTECT(ans2 = D(ans2, install(CHAR(STRING_ELT(names, j)))));
                d2_index[k] = FindSubexprs(ans2, exprlist, tag);
                k++;
                UNPROTECT(2);
            }
        }
        UNPROTECT(4);
    }
    nexpr = length(exprlist) - 1;

    /* Pass 2: append placeholders/values for .value, .grad, .hessian
       and the individual derivative assignments. */
    if (f_index) {
        Accumulate2(MakeVariable(f_index, tag), exprlist);
    }
    else {
        /* NOTE(review): the duplicate made here is not used; the
           original expr is what gets appended. */
        PROTECT(ans = duplicate(expr));
        Accumulate2(expr, exprlist);
        UNPROTECT(1);
    }
    Accumulate2(R_NilValue, exprlist);
    if (hessian) { Accumulate2(R_NilValue, exprlist); }
    for (i = 0, k = 0; i < nderiv ; i++) {
        if (d_index[i]) {
            Accumulate2(MakeVariable(d_index[i], tag), exprlist);
            if (hessian) {
                PROTECT(ans = duplicate(expr));
                PROTECT(ans = D(ans, install(CHAR(STRING_ELT(names, i)))));
                for (j = i; j < nderiv; j++) {
                    if (d2_index[k]) {
                        Accumulate2(MakeVariable(d2_index[k], tag), exprlist);
                    } else {
                        PROTECT(ans2 = duplicate(ans));
                        PROTECT(ans2 = D(ans2,
                                         install(CHAR(STRING_ELT(names, j)))));
                        Accumulate2(ans2, exprlist);
                        UNPROTECT(2);
                    }
                    k++;
                }
                UNPROTECT(2);
            }
        } else { /* the first derivative is constant or simple variable */
            PROTECT(ans = duplicate(expr));
            PROTECT(ans = D(ans, install(CHAR(STRING_ELT(names, i)))));
            Accumulate2(ans, exprlist);
            /* ans stays reachable through exprlist after this */
            UNPROTECT(2);
            if (hessian) {
                for (j = i; j < nderiv; j++) {
                    if (d2_index[k]) {
                        Accumulate2(MakeVariable(d2_index[k], tag), exprlist);
                    } else {
                        PROTECT(ans2 = duplicate(ans));
                        PROTECT(ans2 = D(ans2,
                                         install(CHAR(STRING_ELT(names, j)))));
                        /* a zero second derivative is dropped later by Prune() */
                        if(isZero(ans2)) Accumulate2(R_MissingArg, exprlist);
                        else Accumulate2(ans2, exprlist);
                        UNPROTECT(2);
                    }
                    k++;
                }
            }
        }
    }
    Accumulate2(R_NilValue, exprlist);
    Accumulate2(R_NilValue, exprlist);
    if (hessian) { Accumulate2(R_NilValue, exprlist); }

    /* Pass 3: a sub-expression used fewer than twice is substituted
       back in place and its slot is marked for pruning; the others
       become "<tag><k> <- ..." assignments. */
    i = 0;
    ans = CDR(exprlist);
    while (i < nexpr) {
        if (CountOccurrences(MakeVariable(i+1, tag), CDR(ans)) < 2) {
            SETCDR(ans, Replace(MakeVariable(i+1, tag), CAR(ans), CDR(ans)));
            SETCAR(ans, R_MissingArg);
        }
        else SETCAR(ans, lang3(install("<-"), MakeVariable(i+1, tag), AddParens(CAR(ans))));
        i = i + 1;
        ans = CDR(ans);
    }
    /* .value <- ... */
    SETCAR(ans, lang3(install("<-"), install(".value"), AddParens(CAR(ans))));
    ans = CDR(ans);
    /* .grad <- ... */
    SETCAR(ans, CreateGrad(names));
    ans = CDR(ans);
    /* .hessian <- ... */
    if (hessian) { SETCAR(ans, CreateHess(names)); ans = CDR(ans); }
    /* .grad[, "..."] <- ... */
    for (i = 0; i < nderiv ; i++) {
        SETCAR(ans, DerivAssign(STRING_ELT(names, i), AddParens(CAR(ans))));
        ans = CDR(ans);
        if (hessian) {
            for (j = i; j < nderiv; j++) {
                if (CAR(ans) != R_MissingArg) {
                    if (i == j) {
                        SETCAR(ans, HessAssign1(STRING_ELT(names, i),
                                                AddParens(CAR(ans))));
                    } else {
                        SETCAR(ans, HessAssign2(STRING_ELT(names, i),
                                                STRING_ELT(names, j),
                                                AddParens(CAR(ans))));
                    }
                }
                ans = CDR(ans);
            }
        }
    }
    /* attr(.value, "gradient") <- .grad */
    SETCAR(ans, AddGrad());
    ans = CDR(ans);
    if (hessian) { SETCAR(ans, AddHess()); ans = CDR(ans); }
    /* .value */
    SETCAR(ans, install(".value"));
    /* Prune the expression list removing eliminated sub-expressions */
    SETCDR(exprlist, Prune(CDR(exprlist)));

    /* Wrap the block according to function.arg. */
    if (TYPEOF(funarg) == LGLSXP && LOGICAL(funarg)[0]) { /* fun = TRUE */
        funarg = names;
    }
    if (TYPEOF(funarg) == CLOSXP)
    {
        /* new closure with the formals and environment of the one supplied */
        s = allocSExp(CLOSXP);
        SET_FORMALS(s, FORMALS(funarg));
        SET_CLOENV(s, CLOENV(funarg));
        funarg = s;
        SET_BODY(funarg, exprlist);
    }
    else if (isString(funarg)) {
        /* new closure whose formals are the given names, without defaults */
        PROTECT(names = duplicate(funarg));
        PROTECT(funarg = allocSExp(CLOSXP));
        PROTECT(ans = allocList(length(names)));
        SET_FORMALS(funarg, ans);
        for(i = 0; i < length(names); i++) {
            SET_TAG(ans, install(CHAR(STRING_ELT(names, i))));
            SETCAR(ans, R_MissingArg);
            ans = CDR(ans);
        }
        UNPROTECT(3);
        SET_BODY(funarg, exprlist);
        SET_CLOENV(funarg, R_GlobalEnv);
    }
    else {
        funarg = allocVector(EXPRSXP, 1);
        SET_VECTOR_ELT(funarg, 0, exprlist);
        /* funarg = lang2(install("expression"), exprlist); */
    }
    UNPROTECT(2);
    vmaxset(vmax);
    return funarg;
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */