/* This is -*- C -*- */
/* vim: set sw=2: */
/* $Id: guppi-matrix.c,v 1.3 2002/01/14 05:01:23 trow Exp $ */

/*
 * guppi-matrix.c
 *
 * Copyright (C) 2001 The Free Software Foundation, Inc.
 *
 * Developed by Jon Trowbridge <trow@gnu.org>
 */

/*
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA.
 */

#include <config.h>
#include "guppi-matrix.h"

#include <limits.h>
#include <math.h>
#include <string.h>
#include "guppi-memory.h"


/*
 * guppi_matrix_new:
 * Allocate an r x c matrix.  Entry storage comes from guppi_new0
 * (assumed to zero-fill, like g_new0), and epsilon is set to the
 * default tolerance 1e-8.
 *
 * Returns NULL (with a GLib critical warning) if either dimension is
 * not positive, or if r * c would overflow a gint.
 */
GuppiMatrix *
guppi_matrix_new (gint r, gint c)
{
  GuppiMatrix *m;

  g_return_val_if_fail (r > 0 && c > 0, NULL);
  /* r * c is evaluated in gint arithmetic below; refuse sizes that would
     overflow it (signed overflow is undefined behavior). */
  g_return_val_if_fail (r <= INT_MAX / c, NULL);

  m = guppi_new0 (GuppiMatrix, 1);
  m->r = r;
  m->c = c;
  m->data = guppi_new0 (double, r * c);
  m->epsilon = 1e-8; /* Default Epsilon */
  return m;
}

/*
 * guppi_matrix_copy:
 * Return a newly allocated matrix with the same dimensions, entries and
 * epsilon as m.  The cached LU factorization is not copied; the copy
 * starts with whatever empty cache guppi_matrix_new gives it.
 * Returns NULL if m is NULL.
 */
GuppiMatrix *
guppi_matrix_copy (GuppiMatrix *m)
{
  GuppiMatrix *copy;
  gint r, c;

  if (m == NULL)
    return NULL;

  r = guppi_matrix_rows (m);
  c = guppi_matrix_cols (m);

  copy = guppi_matrix_new (r, c);
  memcpy (copy->data, m->data, r * c * sizeof (double));

  /* guppi_matrix_new installs the default tolerance.  The LU cache is
     built from a copy of the matrix, so dropping the caller's epsilon here
     would make solve/invert silently ignore it. */
  copy->epsilon = m->epsilon;

  return copy;
}

/*
 * guppi_matrix_free:
 * Release m, its entry storage, and any cached LU factorization and
 * pivot array.  Passing NULL does nothing.
 */
void
guppi_matrix_free (GuppiMatrix *m)
{
  if (m == NULL)
    return;

  /* Scrub the descriptor before handing it back. */
  guppi_free (m->data);
  m->data = NULL;
  m->r = m->c = 0;

  /* The cached factorization is itself a GuppiMatrix: free it recursively. */
  guppi_matrix_free (m->LU);
  guppi_free (m->perm);

  guppi_free (m);
}

/*
 * guppi_matrix_touch:
 * Drop the cached LU factorization (m->LU) and pivot array (m->perm) so
 * that the next solve factors the current entries again.  The cache is
 * reused as-is once built, so call this after changing entries directly.
 * NULL is ignored.
 */
void
guppi_matrix_touch (GuppiMatrix *m)
{
  if (m == NULL)
    return;

  guppi_matrix_free (m->LU);
  m->LU = NULL;
  /* guppi_free0 presumably also resets m->perm to NULL -- verify in
     guppi-memory.h. */
  guppi_free0 (m->perm);
}

/*
 * guppi_matrix_set_constant:
 * Set every entry of m to c.  Because the contents change, any cached LU
 * factorization is discarded.
 */
void
guppi_matrix_set_constant (GuppiMatrix *m, double c)
{
  gint N;
  double *p;

  g_return_if_fail (m != NULL);

  N = guppi_matrix_rows (m) * guppi_matrix_cols (m);
  p = guppi_matrix_ptr (m, 0, 0);
  /* Entries are stored contiguously, so a flat walk covers all of them. */
  while (N > 0) {
    *p = c;
    ++p;
    --N;
  }

  /* Otherwise a later solve/invert would reuse a factorization of the old
     contents. */
  guppi_matrix_touch (m);
}

/*
 * guppi_matrix_get_row:
 * Return a newly allocated vector containing a copy of row r of m, or
 * NULL if m is NULL or r is out of range.
 */
GuppiVector *
guppi_matrix_get_row (GuppiMatrix *m, gint r)
{
  GuppiVector *row_vec;
  gint ncols, j;

  g_return_val_if_fail (m != NULL, NULL);
  g_return_val_if_fail (0 <= r && r < guppi_matrix_rows (m), NULL);

  ncols = guppi_matrix_cols (m);
  row_vec = guppi_vector_new (ncols);
  for (j = 0; j < ncols; j++)
    guppi_vector_entry (row_vec, j) = guppi_matrix_entry (m, r, j);

  return row_vec;
}

/*
 * guppi_matrix_get_col:
 * Return a newly allocated vector containing a copy of column c of m, or
 * NULL if m is NULL or c is out of range.
 */
GuppiVector *
guppi_matrix_get_col (GuppiMatrix *m, gint c)
{
  GuppiVector *col_vec;
  gint nrows, j;

  g_return_val_if_fail (m != NULL, NULL);
  g_return_val_if_fail (0 <= c && c < guppi_matrix_cols (m), NULL);

  nrows = guppi_matrix_rows (m);
  col_vec = guppi_vector_new (nrows);
  for (j = 0; j < nrows; j++)
    guppi_vector_entry (col_vec, j) = guppi_matrix_entry (m, j, c);

  return col_vec;
}

/*
 * guppi_matrix_normalize_row:
 * Scale row r of m to unit Euclidean length.  A row whose norm is exactly
 * zero is left untouched, because dividing by zero would fill it with
 * NaNs.  When the row is modified, any cached LU factorization is
 * discarded.
 */
void
guppi_matrix_normalize_row (GuppiMatrix *m, gint r)
{
  double sumsq = 0, norm;
  double *p;
  double *q;
  gint i;

  g_return_if_fail (m != NULL);
  g_return_if_fail (0 <= r && r < guppi_matrix_rows (m));

  /* First pass: sum of squares along the row. */
  p = q = guppi_matrix_ptr (m, r, 0);
  for (i = 0; i < guppi_matrix_cols (m); ++i) {
    sumsq += (*q) * (*q);
    q = guppi_matrix_ptr_col_incr (m, q);
  }

  norm = sqrt (sumsq);
  if (norm == 0)
    return;

  /* Second pass: divide each entry by the norm. */
  for (i = 0; i < guppi_matrix_cols (m); ++i) {
    *p /= norm;
    p = guppi_matrix_ptr_col_incr (m, p);
  }

  guppi_matrix_touch (m);
}

/*
 * guppi_matrix_row_is_nonzero:
 * Return TRUE if some entry of row r has magnitude greater than
 * m->epsilon.  Return FALSE if the whole row is negligible, or if the
 * arguments are invalid (NULL m, r out of range).
 */
gboolean
guppi_matrix_row_is_nonzero (GuppiMatrix *m, gint r)
{
  double *p;
  gint i;

  g_return_val_if_fail (m != NULL, FALSE);
  g_return_val_if_fail (0 <= r && r < guppi_matrix_rows (m), FALSE);

  p = guppi_matrix_ptr (m, r, 0);
  for (i = 0; i < guppi_matrix_cols (m); ++i) {
    if (fabs (*p) > m->epsilon)
      return TRUE;
    p = guppi_matrix_ptr_col_incr (m, p);
  }
  return FALSE;
}

/*
 * guppi_matrix_column_is_nonzero:
 * Return TRUE if some entry of column c has magnitude greater than
 * m->epsilon.  Return FALSE if the whole column is negligible, or if the
 * arguments are invalid (NULL m, c out of range).
 */
gboolean
guppi_matrix_column_is_nonzero (GuppiMatrix *m, gint c)
{
  double *p;
  gint i;

  g_return_val_if_fail (m != NULL, FALSE);
  g_return_val_if_fail (0 <= c && c < guppi_matrix_cols (m), FALSE);

  p = guppi_matrix_ptr (m, 0, c);
  for (i = 0; i < guppi_matrix_rows (m); ++i) {
    if (fabs (*p) > m->epsilon)
      return TRUE;
    p = guppi_matrix_ptr_row_incr (m, p);
  }
  return FALSE;
}

/*
 * guppi_matrix_row_dot:
 * Return the dot product of rows r1 and r2 of m.  Returns 0 if m is NULL
 * or either row index is out of range.
 */
double
guppi_matrix_row_dot (GuppiMatrix *m, gint r1, gint r2)
{
  double *p;
  double *q;
  double sum = 0;
  gint i;

  g_return_val_if_fail (m != NULL, 0);
  g_return_val_if_fail (0 <= r1 && r1 < guppi_matrix_rows (m), 0);
  g_return_val_if_fail (0 <= r2 && r2 < guppi_matrix_rows (m), 0);

  p = guppi_matrix_ptr (m, r1, 0);
  q = guppi_matrix_ptr (m, r2, 0);

  for (i = 0; i < guppi_matrix_cols (m); ++i) {
    sum += (*p) * (*q);
    p = guppi_matrix_ptr_col_incr (m, p);
    q = guppi_matrix_ptr_col_incr (m, q);
  }
  return sum;
}

/*
 * guppi_matrix_subtract_scaled_row_from_row:
 * Elementary row operation: row[r_sub_from] -= scale * row[r].
 *
 * A scale whose magnitude is below m->epsilon is treated as zero, and m
 * is left unchanged.  When m is modified, any cached LU factorization is
 * discarded.
 */
void
guppi_matrix_subtract_scaled_row_from_row (GuppiMatrix *m, double scale,
					   gint r, gint r_sub_from)
{
  double *p;
  double *q;
  gint i;

  /* Validate before touching m: the original dereferenced it unchecked. */
  g_return_if_fail (m != NULL);
  g_return_if_fail (0 <= r && r < guppi_matrix_rows (m));
  g_return_if_fail (0 <= r_sub_from && r_sub_from < guppi_matrix_rows (m));

  if (fabs (scale) < m->epsilon)
    return;

  p = guppi_matrix_ptr (m, r_sub_from, 0);
  q = guppi_matrix_ptr (m, r, 0);
  for (i = 0; i < guppi_matrix_cols (m); ++i) {
    *p -= scale * (*q);
    p = guppi_matrix_ptr_col_incr (m, p);
    q = guppi_matrix_ptr_col_incr (m, q);
  }

  guppi_matrix_touch (m);
}

/*
 * guppi_matrix_LU_decompose:
 * Compute an LU factorization of the square matrix orig by Gaussian
 * elimination with partial pivoting, and cache it in orig->LU and
 * orig->perm.  If a factorization is already cached, return at once.
 *
 * Layout of the result (N = dimension):
 *  - The upper triangle of orig->LU, including the diagonal, holds U.
 *  - The strictly-lower part of column k of orig->LU holds the
 *    multipliers of the Gauss transform applied at step k.  Rows are
 *    swapped only from column k onward, so later steps do NOT permute
 *    these multipliers.  A solver must therefore apply the row swap and
 *    the Gauss transform of each step in order, interleaved.
 *  - orig->perm[k] (0 <= k < N-1) is the row swapped with row k at
 *    step k.
 *
 * A pivot whose magnitude is not above epsilon is treated as zero.  No
 * elimination is done for that column, and the tiny U diagonal entry is
 * left for the solver to detect.
 */
static void
guppi_matrix_LU_decompose (GuppiMatrix *orig)
{
  gint *perm;
  GuppiMatrix *m;
  gint N, pi, i, j, k, max_ind;
  double x, max;

  g_return_if_fail (orig != NULL);
  g_return_if_fail (guppi_matrix_is_square (orig));

  /* LU and perm are created and discarded together. */
  if (orig->LU || orig->perm) {
    g_assert (orig->LU && orig->perm);
    return;
  }

  m = orig->LU = guppi_matrix_copy (orig);
  /* Use the caller's tolerance for pivoting (and for the solver, which
     reads m->LU->epsilon), not the default the copy may have been given. */
  m->epsilon = orig->epsilon;

  N = guppi_matrix_rows (m);
  perm = orig->perm = guppi_new0 (gint, N);
  pi = 0;

  for (k = 0; k < N - 1; ++k) {

    /* Partial pivoting: find the largest |entry| in column k, rows k..N-1. */
    max = fabs (guppi_matrix_entry (m, k, k));
    max_ind = k;
    for (i = k + 1; i < N; ++i) {
      x = fabs (guppi_matrix_entry (m, i, k));
      if (max < x) {
	max_ind = i;
	max = x;
      }
    }
    perm[pi] = max_ind;
    ++pi;

    /* Swap rows k and max_ind, in columns k..N-1 only (see layout note). */
    for (i = k; i < N; ++i) {
      x = guppi_matrix_entry (m, k, i);
      guppi_matrix_entry (m, k, i) = guppi_matrix_entry (m, max_ind, i);
      guppi_matrix_entry (m, max_ind, i) = x;
    }

    if (fabs (x = guppi_matrix_entry (m, k, k)) > m->epsilon) {
      /* Multipliers: the sub-column divided by the pivot. */
      for (i = k + 1; i < N; ++i) {
	guppi_matrix_entry (m, i, k) /= x;
      }

      /* Rank-one update of the trailing submatrix. */
      for (i = k + 1; i < N; ++i) {
	x = guppi_matrix_entry (m, i, k);
	for (j = k + 1; j < N; ++j) {
	  guppi_matrix_entry (m, i, j) -= x * guppi_matrix_entry (m, k, j);
	}
      }
    } else {
      /* Negligible pivot.  The pivot is the largest entry of the
	 sub-column, so every entry below it is negligible too.  No
	 elimination was done, so this step's Gauss transform is the
	 identity.  Store zero multipliers so the solver does not apply
	 the leftover tiny entries as if they were multipliers. */
      for (i = k + 1; i < N; ++i)
	guppi_matrix_entry (m, i, k) = 0;
    }
  }
}

/*
 * guppi_matrix_solve:
 * Solve m * x = vec for x.  This is guppi_matrix_solve_with_fallback
 * with no fallback, so the solve fails (returns NULL) as soon as a
 * negligible pivot is met.
 */
GuppiVector *
guppi_matrix_solve (GuppiMatrix *m, GuppiVector *vec)
{
  GuppiVector *x;

  x = guppi_matrix_solve_with_fallback (m, vec, NULL, NULL);
  return x;
}

/*
 * guppi_matrix_solve_with_fallback:
 * Solve m * x = vec for x, using the LU factorization of m.  The
 * factorization is computed and cached in m on first use.
 *
 * Returns a newly allocated solution vector, or NULL on failure.  The
 * function itself never writes to vec.  vec must have exactly
 * cols(m) entries.
 *
 * During back-substitution, a diagonal entry U(i,i) whose magnitude is
 * not above the factorization's epsilon triggers the fallback.  If
 * fallback is non-NULL it is called as
 *
 *     fallback (m->LU, soln, i, user_data)
 *
 * where soln is the solution being built.  Entries i+1..N-1 of soln are
 * already final; entry i still holds the forward-reduced right-hand
 * side.  The fallback may store a value in entry i (for example 0, to
 * pick one particular solution of a singular system).  It returns TRUE
 * to continue or FALSE to abort.  With no fallback, or on FALSE, the
 * solve returns NULL.
 */
GuppiVector *
guppi_matrix_solve_with_fallback (GuppiMatrix *m,
				  GuppiVector *vec,
				  gboolean (*fallback) (GuppiMatrix *,
							GuppiVector *,
							gint, gpointer),
				  gpointer user_data)
{
  gint i, j, k, N;
  double x, t;
  GuppiVector *soln;

  g_return_val_if_fail (m != NULL, NULL);
  g_return_val_if_fail (vec != NULL, NULL);
  g_return_val_if_fail (guppi_matrix_is_square (m), NULL);
  /* The loops below index soln (a copy of vec) up to N-1. */
  g_return_val_if_fail (guppi_vector_dim (vec) == guppi_matrix_cols (m), NULL);

  guppi_matrix_LU_decompose (m);
  g_assert (m->LU && m->perm);

  N = guppi_matrix_cols (m);
  soln = guppi_vector_copy (vec);

  /* Forward pass: for each elimination step, apply its row swap and then
     its Gauss transform.  The stored multipliers are not permuted, so the
     two must be interleaved in step order. */
  for (k = 0; k < N - 1; ++k) {
    i = m->perm[k];
    x = guppi_vector_entry (soln, k);
    guppi_vector_entry (soln, k) = guppi_vector_entry (soln, i);
    guppi_vector_entry (soln, i) = x;

    for (i = k + 1; i < N; ++i)
      guppi_vector_entry (soln, i) -= guppi_vector_entry (soln, k) * guppi_matrix_entry (m->LU, i, k);
  }

  /* Back-substitution with U, starting from the last row.  For
     i == N-1 the inner loop is empty. */
  for (i = N - 1; i >= 0; --i) {
    x = guppi_vector_entry (soln, i);
    for (j = i + 1; j < N; ++j)
      x -= guppi_matrix_entry (m->LU, i, j) * guppi_vector_entry (soln, j);
    t = guppi_matrix_entry (m->LU, i, i);
    if (fabs (t) > m->LU->epsilon)
      guppi_vector_entry (soln, i) = x / t;
    else if (!(fallback && fallback (m->LU, soln, i, user_data)))
      goto fail;
  }

  return soln;

 fail:
  guppi_vector_free (soln);
  return NULL;
}

/*
 * guppi_matrix_invert:
 * Return a newly allocated inverse of the square matrix m.  It is built
 * one column at a time by solving m * x = b for
 * b = guppi_vector_new_basis (N, col), which is assumed to be the col-th
 * unit vector.  As a side effect of the solves, m's LU factorization is
 * cached in m.  Returns NULL if any solve fails.
 */
GuppiMatrix *
guppi_matrix_invert (GuppiMatrix *m)
{
  GuppiMatrix *inv;
  gint N, row, col;

  g_return_val_if_fail (m != NULL, NULL);
  g_return_val_if_fail (guppi_matrix_is_square (m), NULL);

  N = guppi_matrix_cols (m);
  inv = guppi_matrix_new (N, N);

  for (col = 0; col < N; ++col) {
    GuppiVector *b = guppi_vector_new_basis (N, col);
    GuppiVector *x = guppi_matrix_solve (m, b);

    guppi_vector_free (b);

    if (x == NULL) {
      guppi_matrix_free (inv);
      return NULL;
    }

    for (row = 0; row < N; ++row)
      guppi_matrix_entry (inv, row, col) = guppi_vector_entry (x, row);

    guppi_vector_free (x);
  }

  return inv;
}

/*
 * guppi_matrix_apply:
 * Return the matrix-vector product m * v as a newly allocated vector of
 * length rows(m).  Returns NULL unless v has exactly cols(m) entries.
 */
GuppiVector *
guppi_matrix_apply (GuppiMatrix *m, GuppiVector *v)
{
  GuppiVector *result;
  gint nrows, ncols, row, col;

  g_return_val_if_fail (m != NULL, NULL);
  g_return_val_if_fail (v != NULL, NULL);
  g_return_val_if_fail (guppi_matrix_cols (m) == guppi_vector_dim (v), NULL);

  nrows = guppi_matrix_rows (m);
  ncols = guppi_matrix_cols (m);
  result = guppi_vector_new (nrows);

  for (row = 0; row < nrows; ++row) {
    double *a = guppi_matrix_ptr (m, row, 0);
    double acc = 0;

    /* Dot product of this row of m with v. */
    for (col = 0; col < ncols; ++col, a = guppi_matrix_ptr_col_incr (m, a))
      acc += *a * guppi_vector_entry (v, col);

    guppi_vector_entry (result, row) = acc;
  }

  return result;
}

/*
 * guppi_matrix_spew:
 * Debugging aid.  Print m with g_print as one "| ... |" line per row,
 * formatting each entry with "%8g ".
 */
void
guppi_matrix_spew (GuppiMatrix *m)
{
  gint row, col;
  gint nrows, ncols;

  g_return_if_fail (m);

  nrows = guppi_matrix_rows (m);
  ncols = guppi_matrix_cols (m);

  for (row = 0; row < nrows; ++row) {
    g_print ("| ");
    for (col = 0; col < ncols; ++col)
      g_print ("%8g ", guppi_matrix_entry (m, row, col));
    g_print ("|\n");
  }
}




/* syntax highlighted by Code2HTML, v. 0.9.1 */