/*
* Copyright (c) 1997-1999, 2003 Massachusetts Institute of Technology
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
*/
#include <stdio.h>
#include <math.h>
#include "fftw_mpi.h"
#include "fftw-int.h"
/************************** Twiddle Factors *****************************/
/* To conserve space, we share twiddle factor arrays between forward and
backward plans and plans of the same size (just as in the uniprocessor
transforms). */
static fftw_mpi_twiddle *fftw_mpi_twiddles = NULL;
static fftw_mpi_twiddle *fftw_mpi_create_twiddle(int rows, int rowstart,
int cols, int n)
{
fftw_mpi_twiddle *tw = fftw_mpi_twiddles;
while (tw && (tw->rows != rows || tw->rowstart != rowstart ||
tw->cols != cols || tw->n != n))
tw = tw->next;
if (tw) {
tw->refcount++;
return tw;
}
tw = (fftw_mpi_twiddle *) fftw_malloc(sizeof(fftw_mpi_twiddle));
tw->rows = rows;
tw->rowstart = rowstart;
tw->cols = cols;
tw->n = n;
tw->refcount = 1;
tw->next = fftw_mpi_twiddles;
{
fftw_complex *W = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
rows * (cols - 1));
int j, i;
FFTW_TRIG_REAL twoPiOverN = FFTW_K2PI / (FFTW_TRIG_REAL) n;
for (j = 0; j < rows; ++j)
for (i = 1; i < cols; ++i) {
int k = (j * (cols - 1) - 1) + i;
FFTW_TRIG_REAL
ij = (FFTW_TRIG_REAL) (i * (j + rowstart));
c_re(W[k]) = FFTW_TRIG_COS(twoPiOverN * ij);
c_im(W[k]) = FFTW_FORWARD * FFTW_TRIG_SIN(twoPiOverN * ij);
}
tw->W = W;
}
fftw_mpi_twiddles = tw;
return tw;
}
/* Drop one reference to tw; a NULL tw is ignored.

   When the last reference goes away, tw is unlinked from the global
   fftw_mpi_twiddles list, and both its table and the entry itself are
   freed.  fftw_mpi_die is called if the list is empty or does not
   contain tw. */
static void fftw_mpi_destroy_twiddle(fftw_mpi_twiddle *tw)
{
     fftw_mpi_twiddle **link;

     if (tw == NULL)
          return;
     if (--tw->refcount != 0)
          return;               /* still shared by another plan */

     /* locate the link (list head or some entry's next) that points at tw */
     link = &fftw_mpi_twiddles;
     if (*link != tw) {
          if (*link == NULL)
               fftw_mpi_die("unexpected empty MPI twiddle list");
          while ((*link)->next && (*link)->next != tw)
               link = &(*link)->next;
          if ((*link)->next != tw)
               fftw_mpi_die("tried to destroy unknown MPI twiddle");
          link = &(*link)->next;
     }
     *link = tw->next;

     fftw_free(tw->W);
     fftw_free(tw);
}
/* Multiply one local row of data in place by row cur_row of tw.

   d holds tw->cols elements, each made of n_fields consecutive complex
   values (n_fields <= 1 is handled as a single value per element).
   Element 0 is left unchanged.  Every value of element j
   (1 <= j < tw->cols) is multiplied by
   tw->W[cur_row * (tw->cols - 1) + j - 1] when dir == FFTW_FORWARD,
   and by its complex conjugate otherwise. */
static void fftw_mpi_mult_twiddles(fftw_complex *d, int n_fields,
                                   int cur_row,
                                   fftw_mpi_twiddle *tw,
                                   fftw_direction dir)
{
     const int ncols = tw->cols;
     const int nf = (n_fields > 1) ? n_fields : 1;
     const int forward = (dir == FFTW_FORWARD);
     const fftw_complex *w_row = tw->W + cur_row * (ncols - 1);
     int col, fld;

     for (col = 1; col < ncols; ++col) {
          const fftw_real wr = c_re(w_row[col - 1]);
          const fftw_real wi = c_im(w_row[col - 1]);
          fftw_complex *x = d + col * nf;

          for (fld = 0; fld < nf; ++fld) {
               const fftw_real xr = c_re(x[fld]);
               const fftw_real xi = c_im(x[fld]);

               if (forward) {
                    c_re(x[fld]) = wr * xr - wi * xi;
                    c_im(x[fld]) = wr * xi + wi * xr;
               } else {
                    /* conjugated twiddle */
                    c_re(x[fld]) = wr * xr + wi * xi;
                    c_im(x[fld]) = wr * xi - wi * xr;
               }
          }
     }
}
/***************************** Plan Creation ****************************/
/* Return the factor of n closest to sqrt(n); 1 when n <= 1 or n is prime.

   For any divisor pair d * e == n with d <= e, AM-GM gives
   (d + e) / 2 >= sqrt(n), i.e. e - sqrt(n) >= sqrt(n) - d.  So the
   closest factor is always the largest divisor d with d * d <= n, and
   an upward scan that remembers the last divisor finds it.
   The test "d <= n / d" avoids both floating-point rounding and
   overflow of d * d.  Non-positive n never enters the loop. */
static int find_sqrt_factor(int n)
{
     int d, best = 1;

     for (d = 2; d <= n / d; ++d)
          if (n % d == 0)
               best = d;
     return best;
}
/* Pick the factor r used to split n = r * (n / r), where n_pes is the
   size of comm (queried with MPI_Comm_size).

   - If n_pes^2 divides n, r = n_pes * (factor of n / n_pes^2 nearest its
     square root), so both r and n / r are multiples of n_pes.
   - Else if n_pes divides n, r = n_pes * (factor of n / n_pes nearest its
     square root), so at least r is a multiple of n_pes.
   - Otherwise r is simply the factor of n nearest sqrt(n).

   Multiples of n_pes give even load balance across processes, and a
   near-square split keeps the two FFT lengths comparable. */
static int find_best_r(int n, MPI_Comm comm)
{
     int n_pes, per_pe, f;

     MPI_Comm_size(comm, &n_pes);
     if (n % n_pes != 0)
          return find_sqrt_factor(n);

     per_pe = n / n_pes;
     f = (per_pe % n_pes == 0) ? find_sqrt_factor(per_pe / n_pes)
                               : find_sqrt_factor(per_pe);
     return n_pes * f;
}
#define MAX2(a,b) ((a) > (b) ? (a) : (b))

/* Create a plan for a distributed 1D complex transform of size n over comm.

   n is split as r * m using the factor from find_best_r.  A forward
   plan uses that factor as r; a backward plan uses it as m.  The swap
   means forward and backward plans of the same size end up with
   matching twiddle-table keys.

   FFTW_IN_PLACE is always added to flags.  The plan owns:
   - one-processor sub-plans for length-r and length-m FFTs;
   - transpose plans for m x r and r x m arrays;
   - a reference to a shared twiddle table;
   - a MAX2(m, r)-element scratch buffer.
   Free it with fftw_mpi_destroy_plan. */
fftw_mpi_plan fftw_mpi_create_plan(MPI_Comm comm,
                                   int n, fftw_direction dir, int flags)
{
     fftw_mpi_plan p;
     int split, r, m;

     p = (fftw_mpi_plan) fftw_malloc(sizeof(struct fftw_mpi_plan_struct));

     split = find_best_r(n, comm);
     if (dir == FFTW_FORWARD) {
          r = split;
          m = n / split;
     } else {
          m = split;
          r = n / split;
     }

     flags |= FFTW_IN_PLACE;
     p->n = n;
     p->r = r;
     p->m = m;
     p->flags = flags;
     p->dir = dir;

     /* one-processor plans for the row FFTs */
     p->pr = fftw_create_plan(r, dir, flags);
     p->pm = fftw_create_plan(m, dir, flags);

     /* distributed transposes between the two layouts */
     p->p_transpose = transpose_mpi_create_plan(m, r, comm);
     p->p_transpose_inv = transpose_mpi_create_plan(r, m, comm);

     /* this process's share of the r rows and of the m rows */
     transpose_mpi_get_local_size(r,
                                  p->p_transpose_inv->my_pe,
                                  p->p_transpose_inv->n_pes,
                                  &p->local_r,
                                  &p->local_r_start);
     transpose_mpi_get_local_size(m,
                                  p->p_transpose->my_pe,
                                  p->p_transpose->n_pes,
                                  &p->local_m,
                                  &p->local_m_start);

     p->tw = (dir == FFTW_FORWARD)
          ? fftw_mpi_create_twiddle(p->local_r, p->local_r_start, m, n)
          : fftw_mpi_create_twiddle(p->local_m, p->local_m_start, r, n);

     /* default scratch for the in-place row FFTs */
     p->fft_work = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
                                                MAX2(m, r));
     return p;
}
/********************* Getting Local Size ***********************/
/* Report this process's portion of the data for plan p.

   Data is laid out in one of two ways:
   - "r-layout": local_r rows of length m, i.e. local_r * m elements
     starting at local_r_start * m;
   - "m-layout": local_m rows of length r, i.e. local_m * r elements
     starting at local_m_start * r.

   The input uses the m-layout, or the r-layout with
   FFTW_SCRAMBLED_INPUT.  The output uses the r-layout, or the m-layout
   with FFTW_SCRAMBLED_OUTPUT.  total_local_size is the storage size
   transpose_mpi_get_local_storage_size gives for p's p_transpose.
   Nothing is written when p is NULL. */
void fftw_mpi_local_sizes(fftw_mpi_plan p,
                          int *local_n,
                          int *local_start,
                          int *local_n_after_transform,
                          int *local_start_after_transform,
                          int *total_local_size)
{
     int r_count, r_start, m_count, m_start;

     if (!p)
          return;

     r_count = p->local_r * p->m;
     r_start = p->local_r_start * p->m;
     m_count = p->local_m * p->r;
     m_start = p->local_m_start * p->r;

     if (p->flags & FFTW_SCRAMBLED_INPUT) {
          *local_n = r_count;
          *local_start = r_start;
     } else {
          *local_n = m_count;
          *local_start = m_start;
     }

     if (p->flags & FFTW_SCRAMBLED_OUTPUT) {
          *local_n_after_transform = m_count;
          *local_start_after_transform = m_start;
     } else {
          *local_n_after_transform = r_count;
          *local_start_after_transform = r_start;
     }

     *total_local_size =
          transpose_mpi_get_local_storage_size(p->p_transpose->nx,
                                               p->p_transpose->ny,
                                               p->p_transpose->my_pe,
                                               p->p_transpose->n_pes);
}
/* Write a description of p to f, in this order:
   - a header line;
   - the one-processor plan for the length-m FFTs;
   - the one-processor plan for the length-r FFTs. */
static void fftw_mpi_fprint_plan(FILE *f, fftw_mpi_plan p)
{
     fputs("mpi plan:\n", f);

     fprintf(f, "m = %d plan:\n", p->m);
     fftw_fprint_plan(f, p->pm);

     fprintf(f, "r = %d plan:\n", p->r);
     fftw_fprint_plan(f, p->pr);
}
/* Print the description produced by fftw_mpi_fprint_plan to stdout. */
void fftw_mpi_print_plan(fftw_mpi_plan p)
{
     FILE *out = stdout;

     fftw_mpi_fprint_plan(out, p);
}
/********************** Plan Destruction ************************/
/* Release plan p and everything it owns: both row-FFT sub-plans, both
   transpose plans, its reference on the shared twiddle table, and its
   scratch buffer.  A NULL p is ignored. */
void fftw_mpi_destroy_plan(fftw_mpi_plan p)
{
     if (!p)
          return;

     /* one-processor row-FFT plans */
     fftw_destroy_plan(p->pr);
     fftw_destroy_plan(p->pm);

     /* distributed transposes */
     transpose_mpi_destroy_plan(p->p_transpose);
     transpose_mpi_destroy_plan(p->p_transpose_inv);

     /* reference-counted twiddles, then the scratch buffer */
     fftw_mpi_destroy_twiddle(p->tw);
     fftw_free(p->fft_work);

     fftw_free(p);
}
/******************** Computing the Transform *******************/
/* Execute the distributed 1D transform planned by p, in place on
local_data.

Each of the n transform elements consists of n_fields consecutive
complex values: field f of local element j sits at index
j*n_fields + f.  Thus n_fields interleaved transforms are computed at
once.  The local extent of local_data before and after the call is
what fftw_mpi_local_sizes reports for p.  Nothing is done when
n_fields < 1.

work is passed unchanged to every transpose_mpi call.  When non-NULL,
it also serves as scratch for the row FFTs.  When NULL, the plan's own
p->fft_work is used for that scratch instead, so concurrent calls on
the same plan with work == NULL would share that buffer.
NOTE(review): how transpose_mpi treats a NULL work is not visible here.

With n = r * m, the steps are:
1. transpose with p_transpose, unless FFTW_SCRAMBLED_INPUT is set;
2. length-m FFTs (p->pm) on the local_r local rows of length m;
3. transpose with p_transpose_inv;
4. length-r FFTs (p->pr) on the local_m local rows of length r;
5. transpose with p_transpose, unless FFTW_SCRAMBLED_OUTPUT is set.
Twiddle factors are applied after step 2 for forward plans, and
before step 4 for backward plans. */
void fftw_mpi(fftw_mpi_plan p, int n_fields,
fftw_complex *local_data, fftw_complex *work)
{
int i;
/* size of one element (n_fields complex values) in TRANSPOSE_EL_TYPE units */
int el_size = (sizeof(fftw_complex) / sizeof(TRANSPOSE_EL_TYPE))
* n_fields;
fftw_complex *fft_work;
fftw_direction dir;
fftw_mpi_twiddle *tw;
if (n_fields < 1)
return;
/* step 1: bring the input into the local_r x m layout */
if (!(p->flags & FFTW_SCRAMBLED_INPUT))
transpose_mpi(p->p_transpose, el_size,
(TRANSPOSE_EL_TYPE *) local_data,
(TRANSPOSE_EL_TYPE *) work);
tw = p->tw;
dir = p->dir;
/* row-FFT scratch: the caller's buffer if given, else the plan's own */
fft_work = work ? work : p->fft_work;
/* For forward plans, we multiply by the twiddle factors here,
before the second transpose. For backward plans, we multiply
by the twiddle factors after the second transpose. We do
this so that forward and backward transforms can share the
same twiddle factor array (noting that m and r are swapped
for the two directions so that the local sizes will be compatible). */
{
/* step 2: local_r rows of length m, transformed by p->pm */
int rows = p->local_r, cols = p->m;
fftw_plan p_fft = p->pm;
if (dir == FFTW_FORWARD) {
/* per row: FFT the n_fields interleaved transforms (istride n_fields,
idist 1), then apply twiddle row i.  fft_work fills fftw()'s output
slot; the plan was created FFTW_IN_PLACE, so presumably FFTW uses it
only as scratch -- confirm against the fftw() documentation. */
for (i = 0; i < rows; ++i) {
fftw_complex *d = local_data + i * (cols * n_fields);
fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_FORWARD);
}
}
else {
/* backward plans apply no twiddles in this stage */
if (n_fields > 1)
for (i = 0; i < rows; ++i)
fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
n_fields, 1, fft_work, 1, 0);
else
/* single field: one call covers every row
(howmany = rows, stride 1, distance cols) */
fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
}
}
/* step 3: switch to the local_m x r layout */
transpose_mpi(p->p_transpose_inv, el_size,
(TRANSPOSE_EL_TYPE *) local_data,
(TRANSPOSE_EL_TYPE *) work);
{
/* step 4: local_m rows of length r, transformed by p->pr */
int rows = p->local_m, cols = p->r;
fftw_plan p_fft = p->pr;
if (dir == FFTW_BACKWARD) {
/* per row: apply the conjugated twiddle row i, then FFT */
for (i = 0; i < rows; ++i) {
fftw_complex *d = local_data + i * (cols * n_fields);
fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_BACKWARD);
fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
}
}
else {
/* forward plans already applied their twiddles in step 2 */
if (n_fields > 1)
for (i = 0; i < rows; ++i)
fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
n_fields, 1, fft_work, 1, 0);
else
/* single field: one call covers every row */
fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
}
}
/* step 5: return to the unscrambled output layout */
if (!(p->flags & FFTW_SCRAMBLED_OUTPUT))
transpose_mpi(p->p_transpose, el_size,
(TRANSPOSE_EL_TYPE *) local_data,
(TRANSPOSE_EL_TYPE *) work);
/* Yes, we really had to do three transposes...sigh. */
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */