//Copyright (C) 2004 Dominic Letourneau (dominic.letourneau@usherbrooke.ca)

// Multiplication operators for the FD type system.
//
// Each function below takes two ObjectRef operands, multiplies them and
// returns the product as a newly allocated ObjectRef. Every function is
// registered into `mulVtable` through a REGISTER_* macro (defined outside this
// file, presumably in operators.h -- confirm), presumably so that the generic
// multiply dispatch can select an implementation from the operand types.
//
// All vector/matrix variants compute ELEMENT-WISE products (result[i] =
// a[i] * b[i], result(i,j) = a(i,j) * b(i,j)); none of them is a
// linear-algebra matrix product.
//
// Size mismatches are reported with `throw new GeneralException(...)`, i.e. a
// heap-allocated exception thrown by pointer, so handlers must catch a
// pointer.
//
// NOTE(review): in this copy the parameter list after each `template`, the
// type arguments of `RCPtr` / `Vector` / `Matrix`, and the target types of the
// `static_cast`s are absent, so the file cannot compile as shown. Restore them
// from the upstream source before building -- TODO confirm what they were.

#ifndef _MUL_OPERATORS_CC_
#define _MUL_OPERATORS_CC_

#include "operators.h"
#include "net_types.h"
#include "Vector.h"
#include "Matrix.h"
#include "Complex.h"

//@implements core

using namespace std;

namespace FD {

// Scalar * scalar.
// Reads both operands through val(), casts each factor before multiplying and
// casts the product again, then wraps it with Z::alloc (Z is the result type).
template ObjectRef mulCTypeFunction(ObjectRef op1, ObjectRef op2) {
  RCPtr op1Value = op1;
  RCPtr op2Value = op2;
  // The outer cast converts the product back to the result's element type,
  // presumably because the arithmetic may promote to a wider type -- confirm.
  RCPtr resultValue(Z::alloc(static_cast (static_cast(op1Value->val()) * static_cast(op2Value->val()))));
  return resultValue;
}
REGISTER_ALL_SCALAR_VTABLE(mulVtable, mulCTypeFunction);

// Vector * vector, element by element.
// Throws GeneralException* when the two vectors differ in size; otherwise the
// result is allocated with Z::alloc(size) and filled with the casted products.
template ObjectRef mulVectorFunction(ObjectRef op1, ObjectRef op2) {
  RCPtr op1Value = op1;
  RCPtr op2Value = op2;
  if (op1Value->size() != op2Value->size()) {
    throw new GeneralException("MulVectorFunction : Vector size mismatch ",__FILE__,__LINE__);
  }
  //creating new vector
  RCPtr resultValue(Z::alloc(op1Value->size()));
  for (int i = 0; i < resultValue->size(); i++) {
    (*resultValue)[i] = static_cast ((*op1Value)[i]) * static_cast ((*op2Value)[i]);
  }
  return resultValue;
}
REGISTER_ALL_VECTOR_VTABLE(mulVtable, mulVectorFunction);

// Matrix * matrix, element by element (Hadamard product, not a matrix product).
// Throws GeneralException* unless both row and column counts match. Unlike the
// vector variants, the result is allocated with `new Z(nrows, ncols)` rather
// than Z::alloc.
template ObjectRef mulMatrixFunction(ObjectRef op1, ObjectRef op2) {
  RCPtr op1Value = op1;
  RCPtr op2Value = op2;
  if (op1Value->nrows() != op2Value->nrows() || op1Value->ncols() != op2Value->ncols()) {
    throw new GeneralException("MulMatrixFunction : Matrix size mismatch ",__FILE__,__LINE__);
  }
  //creating new Matrix
  //TODO use Matrix pool?
  RCPtr resultValue(new Z(op1Value->nrows(), op1Value->ncols()));
  for (int i = 0; i < resultValue->nrows(); i++) {
    for (int j = 0; j < resultValue->ncols(); j++) {
      (*resultValue)(i,j) = static_cast ((*op1Value)(i,j)) * static_cast ((*op2Value)(i,j));
    }
  }
  return resultValue;
}
REGISTER_ALL_MATRIX_VTABLE(mulVtable, mulMatrixFunction);

// Vector * vector for the non-template Vector overload registered below.
// Elements are multiplied with their own operator* (no casts). Judging by the
// name the elements are ObjectRef, which would make each product a nested
// dispatch -- confirm.
// Note: the size-mismatch message reuses the text "MulVectorFunction".
ObjectRef mulVectorObjectRef(ObjectRef op1, ObjectRef op2) {
  RCPtr > op1Value = op1;
  RCPtr > op2Value = op2;
  if (op1Value->size() != op2Value->size()) {
    throw new GeneralException("MulVectorFunction : Vector size mismatch ",__FILE__,__LINE__);
  }
  //creating new vector
  RCPtr > resultValue(new Vector(op1Value->size()));
  for (int i = 0; i < resultValue->size(); i++) {
    (*resultValue)[i] = (*op1Value)[i] * (*op2Value)[i];
  }
  return resultValue;
}
REGISTER_DOUBLE_VTABLE(mulVtable,mulVectorObjectRef,Vector,Vector);

// Matrix * matrix, element by element, for the non-template Matrix overload
// registered below. Elements are multiplied with their own operator* (no casts;
// presumably ObjectRef elements, as the name suggests -- confirm).
// Note: the size-mismatch message reuses the text "MulMatrixFunction".
ObjectRef mulMatrixObjectRef(ObjectRef op1, ObjectRef op2) {
  RCPtr > op1Value = op1;
  RCPtr > op2Value = op2;
  if (op1Value->nrows() != op2Value->nrows() || op1Value->ncols() != op2Value->ncols()) {
    throw new GeneralException("MulMatrixFunction : Matrix size mismatch ",__FILE__,__LINE__);
  }
  //creating new Matrix
  RCPtr > resultValue(new Matrix(op1Value->nrows(), op1Value->ncols()));
  for (int i = 0; i < resultValue->nrows(); i++) {
    for (int j = 0; j < resultValue->ncols(); j++) {
      (*resultValue)(i,j) = (*op1Value)(i,j) * (*op2Value)(i,j);
    }
  }
  return resultValue;
}
REGISTER_DOUBLE_VTABLE(mulVtable,mulMatrixObjectRef,Matrix,Matrix);

// Vector * scalar: every element of op1 is multiplied by op2's val().
// No size check is needed because op2 is a single value. op2Value->val() is
// re-read (and re-cast) on every iteration.
template ObjectRef mulVectorScalarFunction(ObjectRef op1, ObjectRef op2) {
  RCPtr op1Value = op1;
  RCPtr op2Value = op2;
  //creating new vector
  RCPtr resultValue(Z::alloc(op1Value->size()));
  for (int i = 0; i < resultValue->size(); i++) {
    (*resultValue)[i] = static_cast ((*op1Value)[i]) * static_cast (op2Value->val());
  }
  return resultValue;
}
REGISTER_ALL_VECTOR_SCALAR_VTABLE(mulVtable, mulVectorScalarFunction);

// Matrix * scalar: every element of op1 is multiplied by op2's val().
// The result has op1's shape and is allocated with `new Z(nrows, ncols)`.
template ObjectRef mulMatrixScalarFunction(ObjectRef op1, ObjectRef op2) {
  RCPtr op1Value = op1;
  RCPtr op2Value = op2;
  //creating new Matrix
  //TODO use Matrix pool?
  RCPtr resultValue(new Z(op1Value->nrows(), op1Value->ncols()));
  for (int i = 0; i < resultValue->nrows(); i++) {
    for (int j = 0; j < resultValue->ncols(); j++) {
      (*resultValue)(i,j) = static_cast ((*op1Value)(i,j)) * static_cast (op2Value->val());
    }
  }
  return resultValue;
}
REGISTER_ALL_MATRIX_SCALAR_VTABLE(mulVtable, mulMatrixScalarFunction);

}//namespace FD

#endif