#ifndef vm_m3math_h_INCLUDED
#define vm_m3math_h_INCLUDED

// File:   vm_m3math.h
// Author: Terry Gaetz

/* --8<--8<--8<--8<--
 *
 * Copyright (C) 2006 Smithsonian Astrophysical Observatory
 *
 * This file is part of vm_math.cvs
 *
 * vm_math.cvs is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * vm_math.cvs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the 
 *       Free Software Foundation, Inc. 
 *       51 Franklin Street, Fifth Floor
 *       Boston, MA  02110-1301, USA
 *
 * -->8-->8-->8-->8-- */

/****************************************************************************
 * Description: collection of methods for vm_M3Math
 *
 * History
 *--------
 * 0.0.0 1998-Jan-30  tjg  original version
 */

#include <vm_math/vm_vmath.h>  // vm_VMath<T,N>
#include <cstring>             // memcpy strlen
#include <cmath>               // sqrt fabs
#include <iostream>            // ostream, <<
#include <cstdio>              // FILE*

//########################################################################
// vm_M3Math<T_fp>
//########################################################################

/** 
 * \class vm_M3Math vm_m3math.h <vm_math/vm_m3math.h>
 *
 * A template class providing common numerical operations 
 * on 3x3 matrices of T_fp's (a floating point type).  Each matrix 
 * is assumed to be stored as a contiguous one-dimensional 
 * array of 9 T_fp's in row-major order, properly aligned for type T_fp.
 *
 * The class is a simple class to handle common numerical
 * operations on 3x3 matrices of floating point T_fp's.
 * 
 * Unless otherwise noted, the 
 * operations are component by component, e.g., 
   \verbatim
       mul_eq(m1,m2)
   \endverbatim
 * results in
   \verbatim
       m1[i][j] *= m2[i][j], where i = 0,1,2 and j = 0,1,2.
   \endverbatim
 * 
 * vm_M3Math has only static member functions; there are no data
 * members.
 * 
 * Where possible, the static member functions are inlined.
 */

template <class T_fp>
class vm_M3Math
  : public vm_VMath<T_fp,9>
{
private:

  enum ENelts_     { ENelts     = 9 };                  // elements per matrix
  enum EColStride_ { EColStride = 3 };                  // distance between rows
  enum ERows_      { ERow0 = 0, ERow1 = 3, ERow2 = 6 }; // offsets of rows 0,1,2

public:

  /// Element type of the matrices and vectors handled by this class.
  typedef T_fp value_type;

  /*!
   * \defgroup mat_index_calculations Index calculations
   */
  /*@{*/

  /**
   * Index calculation.
   *
     \verbatim
     Array value m[i][j] is *(m + i*EColStride + j)
     \endverbatim
   *
   * @return offset of index set i, j (used for flat array storage)
   *
   * @param i      row index
   * @param j      column index
   */
  static int at( int i, int j );
  /*@}*/

  /*!
   * \defgroup mat_init Initialize a matrix.
   */
  /*@{*/

  /**
   * Initialize matrix by row; set rows of m to row0, row1, row2.
   *
   * @param m      matrix (as stored flat 1D array)
   * @param row0   1st row vector (3 elements)
   * @param row1   2nd row vector (3 elements)
   * @param row2   3rd row vector (3 elements)
   */
  inline static void init_by_row( T_fp       m[],
                                  T_fp const row0[], 
                                  T_fp const row1[], 
                                  T_fp const row2[] );

  /**
   * Initialize matrix by column; set columns of m to col0, col1, col2.
   *
   * @param m      matrix (as stored flat 1D array)
   * @param col0   1st column vector (3 elements)
   * @param col1   2nd column vector (3 elements)
   * @param col2   3rd column vector (3 elements)
   */
  inline static void init_by_col( T_fp       m[],
                                  T_fp const col0[], 
                                  T_fp const col1[], 
                                  T_fp const col2[] );

  /**
   * Initialize a matrix to a dyadic (outer) product of vectors
   *
   * For each i, j:  m[at(i,j)] = v1[i] * v2[j]
   *
   * @param m        matrix (as stored flat 1D array) 
   * @param v1       1st vector (supplies the row factor)
   * @param v2       2nd vector (supplies the column factor)
   */
  inline static void dyad_product( T_fp m[], T_fp const v1[], T_fp const v2[] );
  /*@}*/

  /*!
   * \defgroup mat_insert_extract For a matrix, insert or extract a vector.
   */
  /*@{*/

  /**
   * Copy a vector to a given row of m.
   *
   * @param m        matrix (as stored flat 1D array) 
   * @param row      vector to be stored (3 elements)
   * @param whichrow row index; expected to be 0, 1 or 2 (not range-checked)
   */
  inline static void inject_row( T_fp       m[],
                                 T_fp const row[], int whichrow );

  /**
   * Copy a vector to a given column of m.
   *
   * @param m        matrix (as stored flat 1D array) 
   * @param col      vector to be stored (3 elements)
   * @param whichcol column index; expected to be 0, 1 or 2 (not range-checked)
   */
  inline static void inject_col( T_fp       m[],
                                 T_fp const col[], int whichcol );

  /**
   * Copy a given row of m to a vector.
   *
   * @param m        matrix (as stored flat 1D array) 
   * @param row      vector to receive the copy (3 elements)
   * @param whichrow row index; expected to be 0, 1 or 2 (not range-checked)
   */
  inline static void extract_row( T_fp const m[],
                                  T_fp       row[], int whichrow );

  /**
   * Copy a given column of m to a vector.
   *
   * @param m        matrix (as stored flat 1D array) 
   * @param col      vector to receive the copy (3 elements)
   * @param whichcol column index; expected to be 0, 1 or 2 (not range-checked)
   */
  inline static void extract_col( T_fp const m[],
                                  T_fp       col[], int whichcol );
  /*@}*/

  // ---------------------------------------------------------------------
  /*!
   * \defgroup mat_vec_ops Matrix Vector operations.
   */
  /*@{*/

  /**
   * Matrix multiplication of vector v by matrix m.
   *
   * res = m _matrix_multiply_ v.
   *
   * @param res      resulting vector (3 elements)
   * @param m        matrix to be multiplied (as stored flat 1D array)
   * @param v        vector to be multiplied (3 elements)
   */
  inline static void mvmult( T_fp res[], T_fp const m[], T_fp const v[] );

  /**
   * Matrix multiplication of vector v by transpose of matrix m.
   *
   * res = m_transpose _matrix_multiply_ v.
   *
   * @param res      resulting vector (3 elements)
   * @param m        matrix to be transposed and multiplied
   *                 (as stored flat 1D array)
   * @param v        vector to be multiplied (3 elements)
   */
  inline static void mtvmult( T_fp res[], T_fp const m[], T_fp const v[] );
  /*@}*/

  // ---------------------------------------------------------------------
  /*!
   * \defgroup mat_mat_ops Matrix Matrix operations.
   */
  /*@{*/

  /**
   * Matrix multiplication of matrix m1 by matrix m2.
   *
   * mres = m1 _matrix_multiply_ m2, i.e.
   * mres[i][j] = sum over k of m1[i][k] * m2[k][j].
   *
   * @param mres     resulting matrix (as stored flat 1D array)
   * @param m1       left-hand matrix
   * @param m2       right-hand matrix
   */
  inline static void mmult( T_fp mres[], T_fp const m1[], T_fp const m2[] );
  /*@}*/

  // ---------------------------------------------------------------------
  /*!
   * \defgroup mat_io IO operations.
   */
  /*@{*/
  /**
   * Print a matrix to an ostream, one row per line with the elements
   * separated by single spaces.
   *
   * @param os       the ostream
   * @param m        matrix to be printed
   * @param prefix   optional prefix string, written before the matrix
   * @param postfix  optional postfix string, written after the matrix
   *
   * @return os
   */
  inline static std::ostream&
  print_on( std::ostream& os, T_fp const m[], 
            char const prefix[] = "", char const postfix[] = "" );

  /**
   * Print a matrix to a FILE* stream, one row per line with the elements
   * in "%.18e" format separated by single spaces.
   *
   * @param of       the FILE*
   * @param m        matrix to be printed
   * @param prefix   optional prefix string, written before the matrix
   * @param postfix  optional postfix string, written after the matrix
   */
  inline static void
  cprint_on( FILE* of, T_fp const m[], 
             char const prefix[] = "", char const postfix[] = "" );
  /*@}*/

};

//########################################################################
//########################################################################
//
//    #    #    #  #          #    #    #  ######   ####
//    #    ##   #  #          #    ##   #  #       #
//    #    # #  #  #          #    # #  #  #####    ####
//    #    #  # #  #          #    #  # #  #            #
//    #    #   ##  #          #    #   ##  #       #    #
//    #    #    #  ######     #    #    #  ######   ####
//
//########################################################################
//########################################################################

//-------------------------------------------------------------------------
// index calculation

template <class T_fp>
inline int vm_M3Math<T_fp>::
at( int i, int j )
{
  // Row-major layout: skip i complete rows of EColStride entries,
  // then step j entries along that row.
  int const row_offset = EColStride * i;
  return row_offset + j;
}

//-------------------------------------------------------------------------
// initialize matrix by row or by column

// Set row k of m to the three elements of rowk (k = 0, 1, 2).
//
// Element-wise assignment is used instead of memcpy: it is valid for any
// copy-assignable T_fp (memcpy requires a trivially copyable type) and
// does not rely on <cstring> also declaring memcpy in the global
// namespace, which the standard leaves unspecified.
template <class T_fp>
inline void vm_M3Math<T_fp>::
init_by_row( T_fp       m[],
             T_fp const row0[],
             T_fp const row1[],
             T_fp const row2[] )
{
  m[ERow0    ] = row0[0];   m[ERow0 + 1] = row0[1];   m[ERow0 + 2] = row0[2];
  m[ERow1    ] = row1[0];   m[ERow1 + 1] = row1[1];   m[ERow1 + 2] = row1[2];
  m[ERow2    ] = row2[0];   m[ERow2 + 1] = row2[1];   m[ERow2 + 2] = row2[2];
}

// Set column c of m to the three elements of colc (c = 0, 1, 2).
// Elements are written column by column, top to bottom.
template <class T_fp>
inline void vm_M3Math<T_fp>::
init_by_col( T_fp       m[],
             T_fp const col0[],
             T_fp const col1[],
             T_fp const col2[] )
{
  T_fp const* const cols[3] = { col0, col1, col2 };
  for ( int c = 0; c < 3; ++c )
  {
    for ( int r = 0; r < 3; ++r )
    {
      m[r * EColStride + c] = cols[c][r];
    }
  }
}

//-------------------------------------------------------------------------
// copy vector to given row or column of m

// Copy the 3 elements of row into row whichrow of m.
// Element-wise copy avoids memcpy's trivially-copyable requirement and the
// unqualified-memcpy portability issue; whichrow is not range-checked.
template <class T_fp>
inline void vm_M3Math<T_fp>::
inject_row( T_fp       m[],
            T_fp const row[], int whichrow )
{
  T_fp* const prow = &m[whichrow * EColStride];
  prow[0] = row[0];
  prow[1] = row[1];
  prow[2] = row[2];
}

// Copy the 3 elements of col into column whichcol of m.
// whichcol is not range-checked.
template <class T_fp>
inline void vm_M3Math<T_fp>::
inject_col( T_fp       m[],
            T_fp const col[], int whichcol )
{ 
  // The cursor must have the element type T_fp (it was hard-wired to
  // double*, which fails to compile for any other T_fp). Successive rows
  // of the same column are EColStride elements apart.
  T_fp* pm = &m[whichcol];
  *pm = col[0];   pm += EColStride;
  *pm = col[1];   pm += EColStride;
  *pm = col[2];
}

//-------------------------------------------------------------------------
// copy a given row or column of m to a vector

// Copy row whichrow of m into the 3-element vector row.
// Element-wise copy avoids memcpy's trivially-copyable requirement and the
// unqualified-memcpy portability issue; whichrow is not range-checked.
template <class T_fp>
inline void vm_M3Math<T_fp>::
extract_row( T_fp const m[],
             T_fp       row[], int whichrow )
{
  T_fp const* const prow = &m[whichrow * EColStride];
  row[0] = prow[0];
  row[1] = prow[1];
  row[2] = prow[2];
}

// Copy column whichcol of m into the 3-element vector col.
// whichcol is not range-checked.
template <class T_fp>
inline void vm_M3Math<T_fp>::
extract_col( T_fp const m[],
             T_fp       col[], int whichcol )
{ 
  // The cursor must have the element type T_fp (it was hard-wired to
  // double const*, which fails to compile for any other T_fp). Successive
  // rows of the same column are EColStride elements apart.
  T_fp const* pm = &m[whichcol];
  col[0] = *pm;  pm += EColStride;
  col[1] = *pm;  pm += EColStride;
  col[2] = *pm;
}

//-------------------------------------------------------------------------
// initialize a matrix to a dyadic product of vectors

// Outer product of v1 and v2: m[at(i,j)] = v1[i] * v2[j].
// The matrix is filled row by row, left to right.
template <class T_fp>
inline void vm_M3Math<T_fp>::
dyad_product( T_fp m[], T_fp const v1[], T_fp const v2[] )
{
  for ( int i = 0; i < 3; ++i )
  {
    T_fp* const mrow = m + i * EColStride;
    for ( int j = 0; j < 3; ++j )
    {
      mrow[j] = v1[i] * v2[j];
    }
  }
}

//-------------------------------------------------------------------------
// matrix-vector operations.

// res = m * v  (3x3 matrix times 3-vector).
template <class T_fp>
inline void vm_M3Math<T_fp>::
mvmult( T_fp res[], T_fp const m[], T_fp const v[] )
{
  // Compute every component before storing any of them, so that res may
  // alias v (or m) -- e.g. an in-place mvmult(v, m, v) -- without later
  // components reading already-overwritten inputs.
  T_fp const r0 = m[0] * v[0] + m[1] * v[1] + m[2] * v[2];
  T_fp const r1 = m[3] * v[0] + m[4] * v[1] + m[5] * v[2];
  T_fp const r2 = m[6] * v[0] + m[7] * v[1] + m[8] * v[2];
  res[0] = r0;
  res[1] = r1;
  res[2] = r2;
}

// res = transpose(m) * v  (transposed 3x3 matrix times 3-vector).
template <class T_fp>
inline void vm_M3Math<T_fp>::
mtvmult( T_fp res[], T_fp const m[], T_fp const v[] )
{
  // Compute every component before storing any of them, so that res may
  // alias v (or m) -- e.g. an in-place mtvmult(v, m, v) -- without later
  // components reading already-overwritten inputs.
  T_fp const r0 = m[0] * v[0] + m[3] * v[1] + m[6] * v[2];
  T_fp const r1 = m[1] * v[0] + m[4] * v[1] + m[7] * v[2];
  T_fp const r2 = m[2] * v[0] + m[5] * v[1] + m[8] * v[2];
  res[0] = r0;
  res[1] = r1;
  res[2] = r2;
}

// mres = m1 * m2  (3x3 matrix product),
// mres[i][j] = sum over k of m1[i][k] * m2[k][j].
template <class T_fp>
inline void vm_M3Math<T_fp>::
mmult( T_fp mres[], T_fp const m1[], T_fp const m2[] )
{
  // Accumulate into a local buffer and copy out at the end, so that mres
  // may alias m1 or m2 (e.g. mmult(a, a, b) for a = a * b) without later
  // elements reading already-overwritten inputs.
  T_fp r[ENelts];

  r[0] = m1[0] * m2[0]  +  m1[1] * m2[3]  +  m1[2] * m2[6];
  r[1] = m1[0] * m2[1]  +  m1[1] * m2[4]  +  m1[2] * m2[7];
  r[2] = m1[0] * m2[2]  +  m1[1] * m2[5]  +  m1[2] * m2[8];

  r[3] = m1[3] * m2[0]  +  m1[4] * m2[3]  +  m1[5] * m2[6];
  r[4] = m1[3] * m2[1]  +  m1[4] * m2[4]  +  m1[5] * m2[7];
  r[5] = m1[3] * m2[2]  +  m1[4] * m2[5]  +  m1[5] * m2[8];

  r[6] = m1[6] * m2[0]  +  m1[7] * m2[3]  +  m1[8] * m2[6];
  r[7] = m1[6] * m2[1]  +  m1[7] * m2[4]  +  m1[8] * m2[7];
  r[8] = m1[6] * m2[2]  +  m1[7] * m2[5]  +  m1[8] * m2[8];

  for ( int k = 0; k < ENelts; ++k )
  {
    mres[k] = r[k];
  }
}

// Write m to os as three lines of three space-separated values, optionally
// preceded by prefix and followed by postfix. Returns os.
template <class T_fp>
inline std::ostream& vm_M3Math<T_fp>::
print_on( std::ostream& os, T_fp const m[], 
          char const prefix[], char const postfix[] )
{
  // Null-safe emptiness test: std::strlen(nullptr) is undefined behavior,
  // and only the first character is needed to decide whether to print.
  if ( prefix && *prefix ) { os << prefix; }
  os << m[0] << " " << m[1] << " " << m[2] << "\n"
     << m[3] << " " << m[4] << " " << m[5] << "\n"
     << m[6] << " " << m[7] << " " << m[8] << "\n";
  if ( postfix && *postfix ) { os << postfix; }
  return os;
}

// Write m to the C stream of as three lines of three "%.18e" values,
// optionally preceded by prefix and followed by postfix.
template <class T_fp>
inline void vm_M3Math<T_fp>::
cprint_on( FILE* of, T_fp const m[], 
           char const prefix[], char const postfix[] )
{
  // Null-safe emptiness test: std::strlen(nullptr) is undefined behavior.
  if ( prefix && *prefix ) { std::fprintf(of, "%s", prefix); }
  // "%e" requires a double argument; passing e.g. a long double through
  // the varargs is undefined behavior, so convert explicitly. (float is
  // promoted to double anyway, so output for float/double is unchanged.)
  std::fprintf(of, "%.18e %.18e %.18e\n", static_cast<double>(m[0]),
               static_cast<double>(m[1]), static_cast<double>(m[2]));
  std::fprintf(of, "%.18e %.18e %.18e\n", static_cast<double>(m[3]),
               static_cast<double>(m[4]), static_cast<double>(m[5]));
  std::fprintf(of, "%.18e %.18e %.18e\n", static_cast<double>(m[6]),
               static_cast<double>(m[7]), static_cast<double>(m[8]));
  if ( postfix && *postfix ) { std::fprintf(of, "%s", postfix); }
}

#endif  /* vm_m3math.h */
