""" This module manages different implementations of matrix multiplication
related routines.
"""
import generic as _gen
import numarrayall as _na
# Optional BLAS acceleration
# --------------------------
# Provides a BLAS-optimized (i.e. much faster) replacement `dot`
# function for numarray arrays.
#
# BLAS routines exist only for 32 & 64 bit float and complex types;
# if BLAS routines cannot be used the dot defined here defers to the
# standard `numarray.dot`.
#
# This package is integrated with numarray so its improved functions
# are used whenever they're available; they're not always available
# because not everyone needs or wants to install a BLAS.
__author__ = "Richard Everson (R.M.Everson@exeter.ac.uk)"
__revision__ = "$Revision: 1.2 $"
__version__ = "1.0"

# Probe for the BLAS extension.  USING_BLAS records whether both extension
# modules could be imported; the rest of the module keys off this flag.
try:
    import numarray._dotblas as _dotblas
    import numarray._numarray as _numarray
except ImportError:
    USING_BLAS = 0
else:
    USING_BLAS = 1
if not USING_BLAS:
    # No BLAS extension: expose the standard implementations unchanged.
    from numarray._numarray import dot, innerproduct
else:
    # NOTE(review): indentation was lost in the copy this was restored from.
    # All three functions are placed in the BLAS branch because they use
    # _dotblas / _numarray, which are only bound when USING_BLAS is true.

    def _blas_or_generic(blas_func, generic_func, a, b, default_message):
        """Apply `blas_func(a, b)`, degrading gracefully on TypeError.

        The attempts are, in order:
          1. `blas_func(a, b)`;
          2. `generic_func(a, b)` if the BLAS routine raised TypeError;
          3. plain `a*b` if that also raised TypeError and either operand
             has shape ().

        Raises TypeError when all three are inapplicable.  The error raised
        by `generic_func` is re-raised as is; `default_message` is used only
        when that error carries no arguments of its own.
        """
        # Imported here so the block does not depend on a module-level
        # `import sys` elsewhere in the file.
        import sys

        try:
            return blas_func(a, b)
        except TypeError:
            pass  # fall through to the generic routine
        try:
            return generic_func(a, b)
        except TypeError:
            # sys.exc_info() is used (rather than binding the exception in
            # the except clause) so this parses on every Python version.
            detail = sys.exc_info()[1]
            if _na.shape(a) == () or _na.shape(b) == ():
                return a * b
            if detail.args:
                raise detail
            # An exception instance is always truthy, so the previous
            # `detail or "..."` expression could never select this message.
            raise TypeError(default_message)

    def dot(a, b):
        """returns matrix-multiplication between a and b.

        The product-sum is over the last dimension of a and the
        second-to-last dimension of b.

        NB: No conjugation of complex arguments is performed.

        This version uses the BLAS optimized routines where possible.  If
        the BLAS routine raises TypeError the standard routine is tried,
        and if that raises TypeError too and either argument has shape (),
        `a*b` is returned.  Otherwise TypeError is raised.
        """
        return _blas_or_generic(_dotblas.dot, _numarray.dot, a, b,
                                "invalid types for dot")

    def innerproduct(a, b):
        """returns inner product between a and b.

        The product-sum is over the last dimension of a and b.

        NB: No conjugation of complex arguments is performed.

        This version uses the BLAS optimized routines where possible.  If
        the BLAS routine raises TypeError the standard routine is tried,
        and if that raises TypeError too and either argument has shape (),
        `a*b` is returned.  Otherwise TypeError is raised.
        """
        return _blas_or_generic(_dotblas.innerproduct, _numarray.innerproduct,
                                a, b, "invalid types for innerproduct")

    def vdot(a, b):
        """Returns the dot product of 2 vectors (or anything that can be made
        into a vector).

        NB: this is not the same as `dot`, as it takes the conjugate of its
        first argument if complex and always returns a scalar.

        Both arguments are flattened with `ravel` first.  If the BLAS routine
        raises TypeError, the standard `dot` of the flattened arguments is
        returned instead.
        """
        a, b = _na.ravel(a), _na.ravel(b)
        try:
            return _dotblas.vdot(a, b)
        except TypeError:
            # In case we get an integer value (per the original author's
            # note; the BLAS routines cover only float and complex types).
            return _numarray.dot(a, b)
# Alias: matrixmultiply is bound to the same function object as dot.
matrixmultiply = dot
def outerproduct(array1, array2):
    """Return the outer product of two vectors.

    Both arguments are converted to arrays and flattened.  For an N element
    'array1' and an M element 'array2' the result is the NxM array with
    result[i,j] = array1[i]*array2[j].
    """
    column = _gen.reshape(_na.asarray(array1), (-1, 1))  # array1 as Nx1
    row = _gen.reshape(_na.asarray(array2), (1, -1))     # array2 as 1xM
    # (Nx1) times (1xM) gives the NxM table of pairwise products.
    return matrixmultiply(column, row)
def tensormultiply(array1, array2):
    r"""Return the tensor product-sum of two arrays of rank >= 1.

    The result is defined as:

        r_{xxx, yyy} = \sum_k array1_{xxx, k} array2_{k, yyy}

    where xxx, yyy denote the rest of the array1 and array2 dimensions.
    The result shape is array1.shape[:-1] + array2.shape[1:].

    Raises ValueError if the last dimension of array1 differs from the
    first dimension of array2.
    """
    left = _na.asarray(array1)
    right = _na.asarray(array2)
    inner = left.shape[-1]
    if inner != right.shape[0]:
        raise ValueError("Unmatched dimensions")
    # Collapse each operand to 2-D around the summed index, multiply, then
    # restore the leading dimensions of array1 and trailing ones of array2.
    left_2d = _gen.reshape(left, (-1, inner))
    right_2d = _gen.reshape(right, (inner, -1))
    result_shape = left.shape[:-1] + right.shape[1:]
    return _gen.reshape(dot(left_2d, right_2d), result_shape)
def kroneckerproduct(a, b):
    '''Computes a otimes b where otimes is the Kronecker product operator.

    Note: the Kronecker product is also known as the matrix direct product
    or tensor product.  It is defined as follows for 2D arrays a and b
    where shape(a)=(m,n) and shape(b)=(p,q):
    c = a otimes b  => cij = a[i,j]*b  where cij is the ij-th submatrix of c.
    So shape(c)=(m*p,n*q).

    Raises ValueError unless both inputs are 2D.

    >>> print kroneckerproduct([[1,2]],[[3],[4]])
    [[3 6]
     [4 8]]
    >>> print kroneckerproduct([[1,2]],[[3,4]])
    [ [3 4 6 8]]
    >>> print kroneckerproduct([[1],[2]],[[3],[4]])
    [[3]
     [4]
     [6]
     [8]]
    '''
    a = _na.asarray(a)
    b = _na.asarray(b)
    if len(a.shape) != 2 or len(b.shape) != 2:
        raise ValueError('Input must be 2D arrays.')
    # Non-contiguous inputs are passed through reshape with their own shape
    # (presumably to obtain a contiguous copy -- confirm against reshape).
    if not a.iscontiguous():
        a = _gen.reshape(a, a.shape)
    if not b.iscontiguous():
        b = _gen.reshape(b, b.shape)
    # All pairwise products, viewed as a 4-D array of shape (m, n, p, q).
    blocks = outerproduct(a, b)
    blocks.shape = a.shape + b.shape
    # Two successive concatenations along axis 1 collapse the 4-D view
    # into the final 2-D result.
    merged = _gen.concatenate(blocks, axis=1)
    return _gen.concatenate(merged, axis=1)
# syntax highlighted by Code2HTML, v. 0.9.1