"""Numeric module defining a multi-dimensional array and useful procedures for
   numerical computation.

   Functions
   =========
   array                       NumPy Array construction
   zeros                       Return an array of all zeros
   fromstring                  Construct array from (byte) string
   take                        Select sub-arrays using sequence of indices
   put                         Set sub-arrays using sequence of 1-D indices
   putmask                     Set portion of arrays using a mask 
   reshape                     Return array with new shape
   repeat                      Repeat elements of array
   choose                      Construct new array from indexed array tuple
   cross_correlate             Correlate two 1-d arrays
   searchsorted                Search for element in 1-d array
   sum                         Total sum over a specified dimension
   cumsum                      Cumulative sum over a specified dimension
   product                     Total product over a specified dimension
   cumproduct                  Cumulative product over a specified dimension
   alltrue                     Logical and over an entire axis
   sometrue                    Logical or over an entire axis
   allclose                    Test whether two sequences are equal within tolerances
   
   arrayrange (arange)         Return regularly spaced array
   asarray                     Guarantee NumPy array
   sarray                      Guarantee a NumPy array that keeps precision 
   convolve                    Convolve two 1-d arrays
   swapaxes                    Exchange axes
   concatenate                 Join arrays together
   transpose                   Permute axes
   sort                        Sort elements of array
   argsort                     Indices of sorted array
   argmax                      Index of largest value                      
   argmin                      Index of smallest value
   innerproduct                Innerproduct of two arrays
   dot                         Dot product (matrix multiplication)
   outerproduct                Outerproduct of two arrays
   resize                      Return array with arbitrary new shape
   indices                     Tuple of indices
   fromfunction                Construct array from universal function
   diagonal                    Return diagonal array
   trace                       Trace of array
   dump                        Dump array to file object (pickle)
   dumps                       Return pickled string representing data
   load                        Return array stored in file object
   loads                       Return array from pickled string
   ravel                       Return array as 1-D 
   nonzero                     Indices of nonzero elements for 1-D array
   shape                       Shape of array
   where                       Pick elements from two arrays according to a condition
   compress                    Elements of array where condition is true
   clip                        Clip array between two values
   zeros                       Array of all zeros
   ones                        Array of all ones
   identity                    2-D identity array (matrix)

   (Universal) Math Functions
   ==========================
   add                    logical_or             exp        
   subtract               logical_xor            log        
   multiply               logical_not            log10      
   divide                 maximum                sin        
   divide_safe            minimum                sinh       
   conjugate              bitwise_and            sqrt       
   power                  bitwise_or             tan        
   absolute               bitwise_xor            tanh       
   negative               invert                 ceil       
   greater                left_shift             fabs       
   greater_equal          right_shift            floor      
   less                   arccos                 arctan2    
   less_equal             arcsin                 fmod       
   equal                  arctan                 hypot      
   not_equal              cos                    around     
   logical_and            cosh                   sign
"""   
import numeric_version
__version__ = numeric_version.version
del numeric_version

import multiarray
from umath import * 
from Precision import *

import _numpy # for freeze dependency resolution (at least on Mac)

import string, types, math

# Use this in an index expression to add a new axis to an array,
# e.g. a[:, NewAxis].
NewAxis = None

#The following functions are considered builtin, they all might be
#in C some day

def arrayrange(start, stop=None, step=1, typecode=None):
    """arrayrange(start, stop=None, step=1, typecode=None) is like range()
    but returns an array whose type can be specified by the keyword
    argument typecode.

    With a single argument, that argument is the stop value and start
    defaults to 0.  The result has ceil((stop-start)/step) elements, or
    none if that count is not positive.  If typecode is given and differs
    from the computed result's typecode, the result is converted with
    astype(typecode).
    """
    # Identity tests: `== None` can be overloaded (e.g. elementwise by an
    # array argument) and need not give a plain truth value.
    if stop is None:
        stop = start
        start = 0
    n = int(math.ceil(float(stop-start)/step))
    if n <= 0:
        # Empty result; adding the arguments presumably lets their types
        # decide the typecode of the empty array.
        m = zeros( (0,) )+(step+start+stop)
    else:
        m = (add.accumulate(ones((n,), Int))-1)*step +(start+(stop-stop))
        # the last bit is to deal with e.g. Longs -- 3L-3L==0L
    if typecode is not None and m.typecode() != typecode:
        return m.astype(typecode)
    return m

# Re-export the basic array constructors from the multiarray module.
array = multiarray.array
zeros = multiarray.zeros

def asarray(a, typecode=None, savespace=0):
    """asarray(a, typecode=None, savespace=0) converts a to a NumPy array.

    This is array() called with copy=0: unlike array(), no copy is
    performed if a is already an array.  typecode and savespace are passed
    through to array() unchanged.
    """
    no_copy = 0
    return multiarray.array(a, typecode, copy=no_copy, savespace=savespace)

def sarray(a, typecode=None, copy=0):
    """sarray(a, typecode=None, copy=0) is array(a, typecode, copy) with
    savespace=1, i.e. the resulting array keeps its precision.
    """
    keep_precision = 1
    return multiarray.array(a, typecode, copy, savespace=keep_precision)

# Thin re-exports of multiarray routines (summarised in the module docstring).
fromstring = multiarray.fromstring
take = multiarray.take
reshape = multiarray.reshape
repeat = multiarray.repeat
choose = multiarray.choose
cross_correlate = multiarray.cross_correlate

def put(a, ind, v):
    """put(a, ind, v) results in a[n] = v[n] for all n in ind; a is
    modified in place.

    v is first cast to a's typecode.  If v is shorter than ind it will be
    repeated as necessary, so in particular v can be a scalar or a
    length-1 array.
    """
    values = array(v, copy=0).astype(a.typecode())
    multiarray.put(a, ind, values)

def putmask(a, mask, v):
    """putmask(a, mask, v) sets the elements of a to v wherever mask is
    true; a is modified in place.

    v is first cast to a's typecode.  If v is shorter than mask it will be
    repeated as necessary, so in particular v can be a scalar or a
    length-1 array.
    """
    values = array(v, copy=0).astype(a.typecode())
    multiarray.putmask(a, mask, values)

def convolve(a,v,mode=0):
    """convolve(a, v, mode=0) returns the discrete, linear convolution of
    the 1-D sequences a and v.

    mode selects the size of the result: 0 (full), 1 (same) or 2 (valid).
    The convolution is computed as a cross-correlation with the reversed
    shorter sequence.
    """
    # Make `a` the longer of the two sequences before correlating.
    if len(v) > len(a):
        a, v = v, a
    v_reversed = asarray(v)[::-1]
    return cross_correlate(a, v_reversed, mode)

# The array type object; used as the Pickler dispatch key and for the
# copy_reg registration further down.
ArrayType = multiarray.arraytype

def swapaxes(a, axis1, axis2):
    """swapaxes(a, axis1, axis2) returns array a with axis1 and axis2
    interchanged.

    Inputs with fewer than two dimensions are returned unchanged.
    """
    ndim = len(shape(a))
    if ndim < 2:
        return a
    perm = arange(ndim)
    # Assigned left to right: perm[axis1] first, then perm[axis2].
    perm[axis1], perm[axis2] = axis2, axis1
    return multiarray.transpose(a, perm)

# Lower-case alias for the same type object as ArrayType.
arraytype = multiarray.arraytype
# Python wrappers that add extra intelligence (e.g. axis handling) to the
# basic C functions.
def concatenate(a, axis=0):
    """concatenate(a, axis=0) joins the tuple of sequences in a into a single
    NumPy array along the given axis.

    multiarray.concatenate is called without an axis and joins along the
    first axis.  For any other axis, each piece has that axis swapped to the
    front, the pieces are joined, and the result is swapped back.
    """
    if axis == 0:
        return multiarray.concatenate(a)
    pieces = []
    for piece in a:
        pieces.append(swapaxes(piece, axis, 0))
    joined = multiarray.concatenate(pieces)
    return swapaxes(joined, axis, 0)

def transpose(a, axes=None):
    """transpose(a, axes=None) returns array with dimensions permuted
    according to axes.  If axes is None (default) returns array with
    dimensions reversed.
    """
    # `is None`, not `== None`: axes may itself be an array, in which case
    # `==` may compare elementwise (or fail) instead of giving a plain
    # truth value.
    if axes is None:
        # asarray avoids the full copy array() would make just to read
        # the number of dimensions.
        axes = arange(len(asarray(a).shape))[::-1]
    return multiarray.transpose(a, axes)

def sort(a, axis=-1):
    """sort(a, axis=-1) returns an array with the elements of a sorted
    along the given axis.
    """
    if axis == -1:
        return multiarray.sort(a)
    # multiarray.sort works on the last axis: move `axis` there, sort,
    # and move it back.
    moved = swapaxes(a, axis, -1)
    return swapaxes(multiarray.sort(moved), axis, -1)

def argsort(a, axis=-1):
    """argsort(a, axis=-1) returns the indices into a of the sorted array
    along the given axis, so that take(a, result, axis) is the sorted array.
    """
    if axis == -1:
        return multiarray.argsort(a)
    # multiarray.argsort works on the last axis: move `axis` there and back.
    moved = swapaxes(a, axis, -1)
    return swapaxes(multiarray.argsort(moved), axis, -1)

def argmax(a, axis=-1):
    """argmax(a, axis=-1) returns the indices of the maximum values of the
    1-D arrays along the given axis.

    multiarray.argmax reduces over the last axis.  For any other axis, that
    axis is first moved to the end while the remaining axes keep their
    original order, so the result's axes are a's axes with `axis` removed.

    Raises IndexError if axis is out of range for an array of 2 or more
    dimensions.
    """
    if axis == -1:
        return multiarray.argmax(a)
    a = asarray(a)
    ndim = len(a.shape)
    if ndim > 1:
        if axis < 0:
            axis = axis + ndim
        if axis < 0 or axis >= ndim:
            raise IndexError("argmax: axis out of range")
        if axis != ndim - 1:
            # Rotate `axis` to the end rather than swapping it with the last
            # axis: after the reduction drops a dimension, a swap-back
            # would scramble the remaining axes (e.g. axis=-2 on a 3-d array
            # used to return the transpose of the correct answer).
            perm = []
            for i in range(ndim):
                if i != axis:
                    perm.append(i)
            perm.append(axis)
            a = multiarray.transpose(a, perm)
    return multiarray.argmax(a)

def argmin(x, axis=-1):
    """argmin(x, axis=-1) returns the indices of the minimum values of the
    1-D arrays along the given axis, computed as argmax(negative(x), axis).
    """
    # NOTE(review): negating the most negative integer likely wraps around,
    # so for integer arrays containing that value the result may be wrong.
    negated = negative(x)
    return argmax(negated, axis)


# Binary search in a sorted 1-d array, exported under the name searchsorted.
searchsorted = multiarray.binarysearch

def innerproduct(a,b):
    """innerproduct(a,b) returns the dot product of two arrays, which has
    shape a.shape[:-1] + b.shape[:-1] with elements computed by summing the
    product of the elements from the last dimensions of a and b.

    If either argument is a scalar (0-d), the result is simply a*b.
    Raises TypeError for other argument combinations that
    multiarray.innerproduct rejects.
    """
    try:
        return multiarray.innerproduct(a,b)
    except TypeError:
        # Presumably the C routine rejects 0-d arguments; fall back to
        # plain multiplication in that case.
        if asarray(a).shape == () or asarray(b).shape == ():
            return a*b
        # Re-raise the original exception with its traceback intact.  The
        # old `raise TypeError, detail or "..."` dropped the traceback, and
        # its fallback message could never be used (an exception instance
        # is always true).
        raise

def outerproduct(a,b):
    """outerproduct(a,b) returns the outer product of two vectors:
    result[i,j] = a[i]*b[j].

    Any arguments that can be made into arrays are accepted; they are
    flattened to vectors first.
    """
    # ravel() goes through reshape(), which also accepts non-contiguous
    # arrays such as transposes; the .flat attribute is only available for
    # contiguous arrays.
    return ravel(a)[:,NewAxis]*ravel(b)[NewAxis,:]

def dot(a, b):
    """dot(a, b) returns the matrix product of a and b.

    The products are summed over the last axis of a and the second-to-last
    axis of b (for a 1-d b, over its only axis).
    """
    # innerproduct sums over the last axis of both arguments, so bring b's
    # second-to-last axis to the end first.
    b_swapped = swapaxes(b, -1, -2)
    return innerproduct(a, b_swapped)

# Obsolete alias of dot(); don't use it in new code.
matrixmultiply = dot

#Use Konrad's printing function (modified for both str and repr now)
from ArrayPrinter import array2string
def array_repr(a, max_line_width = None, precision = None, suppress_small = None):
    """array_repr(a, ...) formats array a for repr(): elements are
    separated by ', ' and array2string's final flag is 1.
    """
    separator, final_flag = ', ', 1
    return array2string(a, max_line_width, precision, suppress_small,
                        separator, final_flag)

def array_str(a, max_line_width = None, precision = None, suppress_small = None):
    """array_str(a, ...) formats array a for str(): elements are separated
    by ' ' and array2string's final flag is 0.
    """
    separator, final_flag = ' ', 0
    return array2string(a, max_line_width, precision, suppress_small,
                        separator, final_flag)

# Register the formatters with multiarray: array_str under mode 0 and
# array_repr under mode 1 (presumably str() and repr() respectively).
multiarray.set_string_function(array_str, 0)
multiarray.set_string_function(array_repr, 1)

# True when this machine is little-endian: the bytes 01 00 00 ... read back
# as 'i' integers give 1 as the first element only if the least-significant
# byte comes first.
#Maybe in sys some day
LittleEndian = fromstring("\001"+"\000"*7, 'i')[0] == 1

def resize(a, new_shape):
    """resize(a,new_shape) returns a new array with the specified shape.
    The original array's total size can be any size.

    The result is filled with the elements of a (in flattened order),
    repeated as often as needed.  If a is empty the result is all zeros of
    a's typecode; if new_shape has no elements the result is an empty
    array of a's typecode.
    """
    a = ravel(a)
    if not len(a): return zeros(new_shape, a.typecode())
    total_size = multiply.reduce(new_shape)
    if not total_size:
        # Nothing to fill (e.g. identity(0)).  The code below would call
        # concatenate() on an empty tuple of pieces, which fails.
        return zeros(new_shape, a.typecode())

    n_copies, extra = divmod(total_size, len(a))
    if extra:
        # One more (partial) copy, trimmed by the slice below.
        n_copies = n_copies + 1

    a = concatenate( (a,)*n_copies)[:total_size]
    return reshape(a, new_shape)

def indices(dimensions, typecode=None):
    """indices(dimensions, typecode=None) returns an array representing a
    grid of indices: element k of the result has shape `dimensions` and
    holds, at every position, that position's index along axis k.
    """
    base = ones(dimensions, typecode)
    grids = []
    for axis in range(len(dimensions)):
        # A running sum of ones along `axis` counts 1, 2, 3, ...; subtract 1.
        grids.append(add.accumulate(base, axis) - 1)
    return array(grids)

def fromfunction(function, dimensions):
    """fromfunction(function, dimensions) returns the array obtained by
    calling function with one index grid per dimension (see indices()).

    dimensions is a list of numbers giving the length of the desired output
    along each axis, and function must accept that many arguments.
    """
    grids = tuple(indices(dimensions))
    return apply(function, grids)
    

def diagonal(a, offset= 0, axis1=0, axis2=1):
    """diagonal(a, offset=0, axis1=0, axis2=1) returns the given diagonals
    defined by the last two dimensions of the array.

    For a 2-d array the result is the 1-d array of elements a[i, i+offset]:
    offset > 0 selects a diagonal above the main one, offset < 0 one below.
    For more dimensions the function recurses over the first axis and
    returns an array of the sub-arrays' diagonals.

    NOTE(review): for arrays with more than two dimensions, axis1/axis2 are
    effectively not honoured.  The recursion always takes the diagonal of
    the last two axes of each sub-array.  When axis2 > 1, the chosen axes
    are first moved to the *front*, so e.g. diagonal(a, 0, 0, 2) on a 3-d
    array yields d[i, m] = a[i, m, m] rather than a diagonal over axes 0
    and 2.  Verify results before relying on non-default axes.
    """
    a = array (a)
    # Order the axes so that axis1 < axis2.
    if axis2 < axis1: axis1, axis2 = axis2, axis1
    if axis2 > 1:
        # Move axis1 and axis2 to the front, keeping the others in order.
        new_axes = range (len (a.shape))
        del new_axes [axis2]; del new_axes [axis1]
        new_axes [0:0] = [axis1, axis2]
        a = transpose (a, new_axes)
    s = a.shape
    if len (s) == 2:
        # Flatten the n1 x n2 array: consecutive diagonal elements are then
        # n2+1 positions apart.
        n1 = s [0]
        n2 = s [1]
        n = n1 * n2
        s = (n,)
        a = reshape (a, s)
        if offset < 0:
            # Starts at row -offset, column 0; length min(n2, n1+offset).
            return take (a, range ( - n2 * offset, min(n2, n1+offset) * (n2+1) - n2 * offset, n2+1), 0)
        else:
            # Starts at row 0, column offset; length min(n1, n2-offset).
            return take (a, range (offset,         min(n1, n2-offset) * (n2+1) + offset,      n2+1), 0)
    else :
        # Recurse over the first axis with the default axis1/axis2.
        my_diagonal = []
        for i in range (s [0]) :
            my_diagonal.append (diagonal (a [i], offset))
        return array (my_diagonal)

def trace(a, offset=0, axis1=0, axis2=1):
    """trace(a, offset=0, axis1=0, axis2=1) returns the sum along the
    diagonals selected by diagonal(a, offset, axis1, axis2).
    """
    diag = diagonal(a, offset, axis1, axis2)
    return add.reduce(diag)


# These two functions are used in my modified pickle.py so that
# matrices can be pickled.  Notice that matrices are written in 
# binary format for efficiency, but that they pay attention to
# byte-order issues for  portability.

def DumpArray(m, fp):
    """Write array m to the file-like object fp.

    The header is 'A', the typecode, a byte-order flag ('L' when this
    machine is little-endian, otherwise 'B') and the item size, followed by
    a space; then each dimension followed by a space, and a newline.  The
    raw bytes of m.tostring() follow.  Arrays of Objects ('O') are rejected
    with TypeError.
    """
    if m.typecode() == 'O':
        raise TypeError("Numeric Pickler can't pickle arrays of Objects")
    dims = m.shape
    if LittleEndian:
        byte_order = "L"
    else:
        byte_order = "B"
    fp.write("A%s%s%d " % (m.typecode(), byte_order, m.itemsize()))
    for dim in dims:
        fp.write("%d " % dim)
    fp.write('\n')
    fp.write(m.tostring())

def LoadArray(fp):
    """Read one array in the format written by DumpArray from fp.

    fp must provide readline() and read().  The data are byte-swapped when
    the stored byte-order flag differs from this machine's.  0-d arrays
    (no dimensions in the header) are supported.
    """
    ln = string.split(fp.readline())
    # When called through the Unpickler's 'A' dispatch the opcode has
    # already been consumed; when reading a DumpArray stream directly it is
    # still present, so strip it.
    if ln[0][0] == 'A': ln[0] = ln[0][1:]
    typecode = ln[0][0]
    endian = ln[0][1]
    itemsize = string.atoi(ln[0][2:])
    shape = map(string.atoi, ln[1:])

    # Start the product at 1: a 0-d array has an empty shape, and
    # reduce() without an initial value raises on an empty sequence.
    count = reduce(lambda x, y: x*y, shape, 1)
    data = fp.read(count*itemsize)

    m = fromstring(data, typecode)
    m = reshape(m, shape)

    if (LittleEndian and endian == 'B') or (not LittleEndian and endian == 'L'):
        return m.byteswapped()
    return m

import pickle, copy
class Unpickler(pickle.Unpickler):
    """pickle.Unpickler that also reads the 'A' records written for arrays
    by Pickler (via DumpArray).

    Security: unpickling can execute arbitrary code; only use this on
    trusted data.
    """
    def load_array(self):
        # Parse the array record following the 'A' opcode and push it.
        arr = LoadArray(self)
        self.stack.append(arr)

    # Private copy of the dispatch table so the base class is not modified.
    dispatch = pickle.Unpickler.dispatch.copy()
    dispatch['A'] = load_array

class Pickler(pickle.Pickler):
    """pickle.Pickler that writes arrays (ArrayType) with DumpArray."""
    def save_array(self, object):
        # The pickler itself is passed as the output file object.
        DumpArray(object, self)

    # Private copy of the dispatch table so the base class is not modified.
    dispatch = pickle.Pickler.dispatch.copy()
    dispatch[ArrayType] = save_array

#Convenience functions
from StringIO import StringIO

def dump(object, file):
    """dump(object, file) pickles object to the open file, writing arrays
    in binary form.
    """
    pickler = Pickler(file)
    pickler.dump(object)

def dumps(object):
    """dumps(object) pickles object (arrays in binary form) and returns the
    resulting byte string.
    """
    buf = StringIO()
    dump(object, buf)
    return buf.getvalue()

def load(file):
    """load(file) reads one pickled object (e.g. an array) from the open
    file and returns it.

    Security: unpickling can execute arbitrary code; only load trusted data.
    """
    unpickler = Unpickler(file)
    return unpickler.load()

def loads(str):
    """loads(str) returns the object (e.g. an array) stored in the pickled
    byte string str.

    Security: unpickling can execute arbitrary code; only load trusted data.
    """
    return load(StringIO(str))

# slightly different format uses the copy_reg mechanism
import copy_reg

def array_constructor(shape, typecode, thestr, Endian=LittleEndian):
    """Rebuild an array from the pieces produced by pickle_array.

    Endian is the LittleEndian flag of the machine that produced thestr.
    If it differs from this machine's flag, the data are byte-swapped.
    """
    x = fromstring(thestr, typecode)
    x.shape = shape
    if LittleEndian == Endian:
        return x
    return x.byteswapped()

def pickle_array(a):
    """copy_reg reduction function for arrays.

    Returns (array_constructor, args), where args holds the shape, the
    typecode, the raw bytes and this machine's LittleEndian flag.
    """
    args = (a.shape, a.typecode(), a.tostring(), LittleEndian)
    return (array_constructor, args)

copy_reg.pickle(ArrayType, pickle_array, array_constructor)


# These are all essentially abbreviations
# These might wind up in a special abbreviations module

def ravel(m):
    """ravel(m) returns a 1-d array containing all the elements of its
    argument.
    """
    flat_shape = (-1,)
    return reshape(m, flat_shape)

def nonzero(a):
    """nonzero(a) returns the indices of the elements of a that are not
    zero; a must be 1-d.
    """
    positions = arange(len(a))
    # Repeat each index once where a is nonzero and zero times elsewhere.
    return repeat(positions, not_equal(a, 0))

#Move this into C to do it right!
def shape(a):
    """shape(a) returns the shape of a (after conversion with asarray) in
    functional form.
    """
    arr = asarray(a)
    return arr.shape

def where(condition, x, y):
    """where(condition, x, y) is shaped like condition and takes elements
    from x where condition is true (nonzero) and from y where it is false.
    """
    # The selector is 1 where condition is nonzero, so choose() picks the
    # second entry (x) there and the first (y) elsewhere.
    selector = not_equal(condition, 0)
    return choose(selector, (y, x))

def compress(condition, m, dimension=-1):
    """compress(condition, m, dimension=-1) returns the elements of m, along
    the given dimension, whose corresponding element of condition is true.

    condition must be 1-d with the same length as that dimension of m.
    """
    keep = nonzero(condition)
    return take(m, keep, dimension)

def clip(m, m_min, m_max):
    """clip(m, m_min, m_max) replaces every entry of m below m_min by m_min
    and every entry above m_max by m_max.
    """
    below = less(m, m_min)
    above = greater(m, m_max)
    # Selector: 0 keeps m, 1 picks m_min, 2 picks m_max.
    return choose(below + 2*above, (m, m_min, m_max))

def ones(shape, typecode='l', savespace=0):
    """ones(shape, typecode=Int, savespace=0) returns an array of the given
    shape with every element set to one.
    """
    result = zeros(shape, typecode, savespace)
    result[...] = 1
    return result

def identity(n):
    """identity(n) returns the identity matrix of shape n x n."""
    # Tiling [1, 0, ..., 0] (length n+1) over n*n slots puts a 1 at flat
    # position i*n + j exactly when i*n + j is a multiple of n+1, i.e. i == j.
    pattern = [1] + [0]*n
    return resize(pattern, (n, n))

# Reductions and accumulations built from the ufunc methods.
# NOTE: `sum` shadows the builtin sum() in code that star-imports this module.
sum = add.reduce
cumsum = add.accumulate
product = multiply.reduce
cumproduct = multiply.accumulate
alltrue = logical_and.reduce
sometrue = logical_or.reduce

# Short alias, as listed in the module docstring.
arange = arrayrange

def around(m, decimals=0):
    """around(m, decimals=0) rounds m to the given number of decimals, in
    the same way as standard Python rounding: halves go away from zero.

    Always returns a float array.
    """
    m = asarray(m)
    s = sign(m)
    if decimals:
        m = absolute(m*10.**decimals)
    else:
        m = absolute(m)
    # m >= 0 here, so floor(m) is its integer part.  Using floor instead of
    # astype(Int) avoids integer overflow for magnitudes beyond Int's range,
    # which gave meaningless remainders.
    rem = m - floor(m)
    m = where(less(rem,0.5), floor(m), ceil(m))
    # convert back
    if decimals:
        m = m*s/(10.**decimals)
    else:
        m = m*s
    return m
    
def sign(m):
    """sign(m) returns an array shaped like m holding -1 where m < 0, 1
    where m > 0 and 0 elsewhere.
    """
    m = asarray(m)
    negatives = less(m, 0)
    positives = greater(m, 0)
    return zeros(shape(m)) - negatives + positives

def allclose (a, b, rtol=1.e-5, atol=1.e-8):
    """allclose(a, b, rtol=1.e-5, atol=1.e-8) returns true if every element
    of a differs from the corresponding element of b by at most
    atol + rtol*|b|.

    rtol is the relative tolerance and must be positive and << 1.0.  atol
    is the absolute tolerance; it matters where elements of b are very
    small or zero.  The test is asymmetric: the relative term scales with
    |b| only.
    """
    x = array(a, copy=0)
    y = array(b, copy=0)
    # less_equal, not less: exactly equal values must always count as
    # close, even with zero tolerances (allclose(0, 0, atol=0) used to be
    # false).
    d = less_equal(absolute(x-y), atol + rtol * absolute(y))
    return alltrue(ravel(d))
    


# syntax highlighted by Code2HTML, v. 0.9.1