import Numeric
from MA import *
def test1(*s):
    """Check MA arithmetic, ufuncs, concatenation and reductions against Numeric.

    The positional arguments *s* give the shape that the 12-element test
    vectors x, y (and their masked counterparts xm, ym) are reshaped to
    before any comparison runs, e.g. test1(12), test1(4,3), test1(2,3,2).
    xf is x with 1.e+20 written into the slots masked in xm.
    """
    x = Numeric.array([1., 1., 1., -2., pi/2.0, 4., 5., -10., 10., 1., 2., 3.])
    y = Numeric.array([5., 0., 3., 2., -1., -4., 0., -10., 10., 1., 0., 3.])
    a10 = 10.
    m1 = [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
    m2 = [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1]
    xm = array(x, mask=m1)
    ym = array(y, mask=m2)
    z = Numeric.array([-.5, 0., .5, .8])
    zm = array(z, mask=[0, 1, 0, 0])
    xf = Numeric.where(m1, 1.e+20, x)
    xm.set_fill_value(1.e+20)
    for arr in (x, y, xm, ym, xf):
        arr.shape = s

    # Type detection, filling and plain equality.
    assert isMA(x) == 0
    assert isMA(xm) == 1
    assert allclose(xm, xf)
    assert allclose(filled(xm, 1.e20), xf)
    assert allclose(x, xm)
    assert allclose(-x, -xm)

    # Python operators in array/array, scalar/array and array/scalar form.
    arith = (lambda a, b: a + b,
             lambda a, b: a - b,
             lambda a, b: a * b,
             lambda a, b: a / b)
    for op in arith:
        assert allclose(op(x, y), op(xm, ym))
    for op in arith:
        assert allclose(op(a10, y), op(a10, ym))
    for op in arith:
        assert allclose(op(x, a10), op(xm, a10))
    assert allclose(x**2, xm**2)
    assert allclose(abs(x)**2.5, abs(xm) **2.5)
    assert allclose(x**y, xm**ym)

    # Two-argument arithmetic functions.
    for num_f, ma_f in ((Numeric.add, add), (Numeric.subtract, subtract),
                        (Numeric.multiply, multiply), (Numeric.divide, divide)):
        assert allclose(num_f(x, y), ma_f(xm, ym))
    # One-argument functions applied to x directly.
    for num_f, ma_f in ((Numeric.cos, cos), (Numeric.cosh, cosh),
                        (Numeric.sin, sin), (Numeric.sinh, sinh),
                        (Numeric.tan, tan), (Numeric.tanh, tanh)):
        assert allclose(num_f(x), ma_f(xm))
    # sqrt/log/log10: the Numeric reference is computed on abs(x).
    for num_f, ma_f in ((Numeric.sqrt, sqrt), (Numeric.log, log),
                        (Numeric.log10, log10)):
        assert allclose(num_f(abs(x)), ma_f(xm))
    assert allclose(Numeric.exp(x), exp(xm))
    # Inverse trig functions on the short z / zm vectors.
    for num_f, ma_f in ((Numeric.arcsin, arcsin), (Numeric.arccos, arccos),
                        (Numeric.arctan, arctan)):
        assert allclose(num_f(z), ma_f(zm))
    assert allclose(Numeric.arctan2(x, y), arctan2(xm, ym))
    assert allclose(Numeric.absolute(x), absolute(xm))
    # Element-wise comparisons.
    for num_f, ma_f in ((Numeric.equal, equal), (Numeric.not_equal, not_equal),
                        (Numeric.less, less), (Numeric.greater, greater),
                        (Numeric.less_equal, less_equal),
                        (Numeric.greater_equal, greater_equal)):
        assert allclose(num_f(x, y), ma_f(xm, ym))
    assert allclose(Numeric.conjugate(x), conjugate(xm))

    # Concatenation of plain / masked mixtures along the first axis.
    xy = Numeric.concatenate((x, y))
    for pieces in ((xm, ym), (x, y), (xm, y)):
        assert allclose(xy, concatenate(pieces))
    assert allclose(Numeric.concatenate((x, y, x)), concatenate((x, ym, x)))

    # With the masked 0. excluded, both the plain and weighted means are 2.0.
    ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0])
    assert allclose(2.0, average(ott))
    assert allclose(2.0, average(ott, weights=[1., 1., 2., 1.]))

    # Builtin max/min need a flat sequence; the data are chosen so that the
    # masked slots of xm do not hold the extremes.
    flat = Numeric.ravel(x)
    flat_m = ravel(xm)
    assert allclose(max(flat), maximum(flat_m))
    assert allclose(min(flat), minimum(flat_m))

    # Reductions over the default axis and axis 0.
    assert allclose(Numeric.add.reduce(x), add.reduce(x))
    assert allclose(Numeric.sum(x), sum(x))
    assert allclose(Numeric.sum(filled(xm, 0)), sum(xm))
    assert allclose(Numeric.sum(x, 0), sum(x, 0))
    assert allclose(Numeric.product(x), product(x))
    assert allclose(Numeric.product(x, 0), product(x, 0))
    assert allclose(Numeric.product(filled(xm, 1)), product(xm))

    # Axis-1 checks only make sense for shapes with at least two dimensions.
    if len(s) <= 1:
        return
    assert allclose(Numeric.concatenate((x, y), 1), concatenate((xm, ym), 1))
    assert allclose(Numeric.add.reduce(x, 1), add.reduce(x, 1))
    assert allclose(Numeric.sum(x, 1), sum(x, 1))
    assert allclose(Numeric.product(x, 1), product(x, 1))


def test2():
    """Check MA str/repr, indexing, element/slice assignment and masked_values.

    x1 is a plain Numeric array; x2, x3 and x4 are MA arrays built from it
    with different masks (x4 with none).  Several steps mutate x1 and x2 in
    lockstep and then compare them, so the statement order matters.
    """
    x1 = Numeric.array([1,2,4,3])
    x2 = array(x1, mask = [1,0,0,0])
    x3 = array(x1, mask = [0,1,0,1])
    x4 = array(x1)
    # test conversion to strings: the results are discarded, so this only
    # checks that str() and repr() of a masked array do not raise
    junk, garbage = str(x2), repr(x2)
    # MA sort (fill_value=0 passed for the masked slot) vs. Numeric.sort
    assert allclose(Numeric.sort(x1),sort(x2, fill_value=0))
    # tests of indexing
    # an unmasked element has the same type and value as the Numeric element
    assert type(x2[1]) is type(x1[1])
    assert x1[1] == x2[1]
    # indexing a masked slot returns the `masked` object itself
    assert x2[0] is masked
    assert allclose(x1[2],x2[2])
    assert allclose(x1[2:5],x2[2:5])
    assert allclose(x1[:],x2[:])
    # x3 is masked at 1 and 3; presumably allclose treats masked slots as equal
    assert allclose(x1[1:], x3[1:])
    # element and slice assignment applied identically to x1 and x2
    x1[2]=9
    x2[2]=9
    assert allclose(x1,x2)
    x1[1:3] = 99
    x2[1:3] = 99
    assert allclose(x1,x2)
    # masking an element / a slice of x2 must not break the comparison
    x2[1] = masked
    assert allclose(x1,x2)
    x2[1:3]=masked
    assert allclose(x1,x2)
    # after overwriting all of x2 with plain data, only the slot masked
    # afterwards remains masked
    x2[:] = x1
    x2[1] = masked
    assert allequal(getmask(x2),array([0,1,0,0]))
    # slice-assigning a masked array carries its mask over, replacing x3's
    # old mask and giving x4 (created without a mask) one
    x3[:] = masked_array([1,2,3,4],[0,1,1,0])
    assert allequal(getmask(x3), array([0,1,1,0]))
    x4[:] = masked_array([1,2,3,4],[0,1,1,0])
    assert allequal(getmask(x4), array([0,1,1,0]))
    assert allequal(x4, array([1,2,3,4]))
    # masked_values(x1, 3.0) masks the entry equal to 3.0 and records 3.0
    # as the fill value
    x1 = Numeric.arange(5)*1.0
    x2 = masked_values(x1, 3.0)
    assert allclose(x1,x2)
    assert allequal(array([0,0,0,1,0]), x2.mask())
    assert allclose(3.0, x2.fill_value())

def test3():
    """Tests of some subtle points of copying and sizing.

    Checks when make_mask() and array() share versus copy their mask and
    data arguments, what assigning to a masked slot does to the mask, and
    that the savespace flag survives astype().
    """
    n = [0,0,1,0,0]
    m = make_mask(n)
    # make_mask() returns an existing mask object as-is unless copy=1
    m2 = make_mask(m)
    assert m is m2
    m3 = make_mask(m, copy=1)
    assert m is not m3

    # by default array() copies the data but reuses the given mask object
    x1 = Numeric.arange(5)
    y1 = array(x1, mask=m)
    assert y1.raw_data() is not x1
    assert allequal(x1,y1.raw_data())
    assert y1.mask() is m

    # with copy=0 the data array itself is shared too
    y2 = array(x1, mask=m, copy=0)
    assert y2.raw_data() is x1
    assert y2.mask() is m
    assert y2[2] is masked
    # assigning to the only masked slot unmasks it; y2 then holds a
    # different mask object whose entries are all 0
    # NOTE(review): y2 shares x1's data, so this probably also sets x1[2]
    # to 9 before y3 is built below — confirm whether that is intended
    y2[2]=9
    assert y2[2] is not masked
    assert y2.mask() is not m
    assert allequal(y2.mask(), 0)

    # astype(Float32) on a savespace array keeps Float32 and the flag
    y3 = array(x1*1.0, mask=m, savespace=1).astype(Float32)
    assert filled(y3).typecode() == Float32
    assert y3.spacesaver()

def test4():
    """test of put

    Exercises the put() function on arrays of several types and shapes,
    then on a masked array, and finally the MA methods put(), putmask()
    and compressed().
    """
    x=arange(5)
    put (x, [1,4],[10,40])
    assert allclose(x, [0,10,2,3,40])

    # same with float data
    x=arange(5) * 1.0
    put (x, [1,4], [10.,40.])
    assert allclose(x, [0,10,2,3,40])

    # same with Float32 data
    x = arange(5).astype(Float32)
    put (x, [1,4],[10.,40.])
    assert allclose(x, [0,10,2,3,40])

    # on a 2-D array the indices address the flattened data:
    # 1 -> [0,1], 4 -> [1,1]
    x=arange(6)*1.0
    x.shape=(2,3)
    put(x, [1,4],[10,40])
    assert allclose(x, [[0,10,2],[3,40,5]])

    # put() into a masked array: the written masked slot 4 becomes
    # unmasked, slot 3 stays masked, and x no longer holds the object m
    d = arange(5)
    n = [0,0,0,1,1]
    m = make_mask(n)
    x = array(d, mask = m)
    assert x[3] is masked
    assert x[4] is masked
    put(x, [1,4], [10,40])
    assert x.mask() is not m
    assert x[3] is masked
    assert x[4] is not masked
    assert allclose(x, [0,10,2,-1,40])

    # the put() method writes the values into the unmasked slots in order
    # and leaves the mask (the very same object m) as it was
    x = array(d, mask = m)
    assert x.mask() is m
    x.put([-1,100,200])
    assert allclose(x, [-1,100,200,0,0])
    assert x[3] is masked
    assert x[4] is masked
    assert x.mask() is m

    # putmask() writes the values into the masked slots and drops the mask
    x = array(d, mask = m)
    x.putmask([30,40])
    assert allclose(x, [0,1,2,30,40])
    assert x.mask() is None

    # round trip: compressed() (presumably the unmasked values) put back
    # into an identically masked array reproduces x
    x = array(d, mask = m)
    y = x.compressed()
    z = array(x, mask = m)
    z.put(y)
    assert allclose (x, z)

def test5():
    """Test of other odd features.

    Covers assignment through .flat, the real/imaginary parts of a complex
    MA array, masking single elements, where()/masked_where(), and the
    masked_<comparison> convenience functions.
    """
    x = arange(20); x.shape=(4,5)
    # flat index 5 of a (4,5) array is row 1, column 0
    x.flat[5] = 12
    assert x[1,0] == 12
    z = x + 10j * x
    assert allclose(z.real, x)
    assert allclose(z.imaginary, 10*x)
    # (x + 10j*x) * conj(x + 10j*x) == 101*x*x
    assert allclose((z*conjugate(z)).real, 101*x*x)
    # writes through the imaginary part; the result is not checked here
    z.imaginary[...] = 0.0

    x = arange(10)
    x[3] = masked
    assert str(x[3]) == str(masked)
    # where(c, a, b) takes a where c is true (see the where(c, x, -x) check
    # below), so masking the entries >= 8 needs `masked` as the second
    # argument; written this way it must agree with masked_where(c, x).
    z = where(greater_equal(x,8), masked, x)
    assert z.typecode() == x.typecode()
    assert z[3] is masked
    assert z[8] is masked
    assert z[9] is masked
    assert allclose(x,z)
    z = masked_where(greater_equal(x,8), x)
    assert z.typecode() == x.typecode()
    assert z[3] is masked
    assert z[8] is masked
    assert z[9] is masked
    assert allclose(x,z)
    x = array([1.,2.,3.,4.,5.])
    c = array([1,1,1,0,0])
    x[2] = masked
    z = where(c, x, -x)
    assert allclose(z, [1.,2.,0., -4., -5])
    # a masked entry in the condition gives a masked entry in the result
    c[0] = masked
    z = where(c, x, -x)
    assert allclose(z, [1.,2.,0., -4., -5])
    assert z[0] is masked
    assert z[1] is not masked
    assert z[2] is masked
    # each masked_<op>(x, 2) must match masked_where(<op>(x, 2), x)
    assert allclose(masked_where(greater(x, 2), x), masked_greater(x,2))
    assert allclose(masked_where(greater_equal(x, 2), x), masked_greater_equal(x,2))
    assert allclose(masked_where(less(x, 2), x), masked_less(x,2))
    assert allclose(masked_where(less_equal(x, 2), x), masked_less_equal(x,2))
    assert allclose(masked_where(not_equal(x, 2), x), masked_not_equal(x,2))
    assert allclose(masked_where(equal(x, 2), x), masked_equal(x,2))
    # masked_where also accepts plain lists; the 99s sit in masked slots and
    # are presumably never compared by allclose
    assert allclose(masked_where([1,1,0,0,0], [1,2,3,4,5]), [99,99,3,4,5])

def test6():
    """Test of minimum, maximum.

    Covers the two-argument element-wise form (checked against where())
    and the one-argument reduction form.
    """
    assert allclose(minimum([1,2,3],[4,0,9]), [1,0,3])
    assert allclose(maximum([1,2,3],[4,0,9]), [4,2,9])
    x = arange(5)
    y = arange(5) - 2
    x[3] = masked
    y[0] = masked
    # Compare the results with allclose: a bare `assert(a, b)` asserts a
    # non-empty tuple, which is always true and so never tests anything.
    assert allclose(minimum(x,y), where(less(x,y), x, y))
    assert allclose(maximum(x,y), where(greater(x,y), x, y))
    # one-argument form reduces over x, which holds 0, 1, 2, --, 4
    assert minimum(x) == 0
    assert maximum(x) == 4

def test7():
    """Compare MA take/transpose/innerproduct/outerproduct with Numeric.

    A 2x3x4 MA array with one masked element (flat index 5) is checked
    against an unmasked Numeric array built from the same arange(24).
    """
    shape = (2, 3, 4)
    ma_arr = arange(24)
    ref = Numeric.arange(24)
    ma_arr[5:6] = masked
    ma_arr.shape = shape
    ref.shape = shape
    perm = (2, 0, 1)     # axis permutation for transpose
    picks = (2, 0, 1)    # indices taken along axis 1
    assert allclose(Numeric.transpose(ref, perm), transpose(ma_arr, perm))
    assert allclose(Numeric.take(ref, picks, 1), take(ma_arr, picks, 1))
    assert allclose(Numeric.innerproduct(filled(ma_arr), filled(ref)),
                    innerproduct(ma_arr, ref))
    assert allclose(Numeric.outerproduct(filled(ma_arr), filled(ref)),
                    outerproduct(ma_arr, ref))
   
def test8():
    print "Timing tests ...       Numeric .... MA maskless ... MA mask"
    print "Timing test n=1000:   ", test8a(1000), test8b(1000), test8c(1000)
    print "Timing test n=10000:  ", test8a(10000), test8b(10000), test8c(10000)
    print "Timing test n=100000: ", test8a(100000), test8b(100000), test8c(100000)

def test8a(n):
    """Time one arithmetic expression on a plain Numeric array of length n.

    Returns the elapsed seconds as measured by time.time().
    """
    import time
    # Explicit names instead of `from Numeric import *`: a star import is
    # only allowed at module level (SyntaxWarning in Python 2, SyntaxError
    # in Python 3) and it turns off fast local-variable lookup.
    from Numeric import arange, sqrt
    x = arange(n) + 1.0
    tn0 = time.time()
    # binding z keeps the result alive until return, so freeing it is not timed
    z = (sqrt(x) * (x + x + 2.*x * x) - x ) / (x + 1.)
    return time.time() - tn0

def test8b(n):
    """Time one arithmetic expression on an MA array of length n with no mask.

    Returns the elapsed seconds as measured by time.time().
    """
    import time
    # Explicit names instead of `from MA import *`: a star import is only
    # allowed at module level (SyntaxWarning in Python 2, SyntaxError in
    # Python 3) and it turns off fast local-variable lookup.
    from MA import arange, sqrt
    x = arange(n) + 1.0
    tn0 = time.time()
    # binding z keeps the result alive until return, so freeing it is not timed
    z = (sqrt(x) * (x + x + 2.*x * x) - x ) / (x + 1.)
    return time.time() - tn0

def test8c(n):
    """Time one arithmetic expression on an MA array of length n with x[0] masked.

    Returns the elapsed seconds as measured by time.time().
    """
    import time
    # Explicit names instead of `from MA import *`: a star import is only
    # allowed at module level (SyntaxWarning in Python 2, SyntaxError in
    # Python 3) and it turns off fast local-variable lookup.
    from MA import arange, sqrt, masked
    x = arange(n) + 1.0
    # a single masked element forces the masked code paths
    x[0] = masked
    tn0 = time.time()
    # binding z keeps the result alive until return, so freeing it is not timed
    z = (sqrt(x) * (x + x + 2.*x * x) - x ) / (x + 1.)
    return time.time() - tn0

# Module-level driver: the tests run whenever this file is executed or
# imported.  Each testN() raises AssertionError at its first failing check.
print "test 1"
# test1 receives the target shape as positional arguments (1-D, 2-D, 3-D)
test1(12)
test1(4,3)
test1(2,3,2)
print "test 2"
test2()
print "test 3"
test3()
print "test 4"
test4()
print "test 5"
test5()
print "test 6"
test6()
print "test 7"
test7()
print "test 8"
# test8 only prints timings; it asserts nothing
test8()
import gc
# DEBUG_UNCOLLECTABLE makes the collector report uncollectable objects it
# finds; gc.collect() returns the number of unreachable objects found, so
# this presumably checks that the tests left no reference cycles behind
gc.set_debug(gc.DEBUG_UNCOLLECTABLE)
assert gc.collect() == 0
print "Test of MA completed o.k."


# syntax highlighted by Code2HTML, v. 0.9.1