#ifndef MATRIX_H
#define MATRIX_H

#define DEBUG
#include <assert.h>
#include <stdio.h>
#include <math.h>

/****************************************************************/
/** Template class for multi-dimensional matrices */

/* Dense multi-dimensional array of T stored in one contiguous buffer in
 * row-major order (the last index varies fastest).
 *
 * An empty matrix has order==0 and owns no storage. Otherwise `dimensions`
 * holds `order` extents and `a` holds `size` elements, where size is the
 * product of the extents. Copy construction and assignment are deep. */
template<class T> class Matrix {
 protected:
  int order;        /* number of dimensions; 0 means empty */
  int size;         /* number of elements = product of dimensions */
  int *dimensions;  /* owned array of `order` extents */
  T *a;             /* owned element buffer of `size` entries */

 public:
  /* norder dimensions with extents ndimensions[0..norder-1], every element
   * set to defaultVal */
  Matrix(int norder, int *ndimensions, T defaultVal=0);
  /* deep copy of na */
  inline Matrix(Matrix &na);

  /* empty matrix: owns nothing; pointers nulled so nothing is indeterminate */
  Matrix() { order=0; size=0; dimensions=0; a=0; }
  /* 1- to 5-dimensional matrices with the given extents, elements set to 0 */
  Matrix(int d1);
  Matrix(int d1, int d2);
  Matrix(int d1, int d2, int d3);
  Matrix(int d1, int d2, int d3, int d4);
  Matrix(int d1, int d2, int d3, int d4, int d5);
  /* set every element to val */
  void fill(T val);

  /* Deep assignment: take na's shape and a copy of its elements.
   * Self-assignment is a no-op; assigning an empty matrix empties this one. */
  Matrix & operator = (const Matrix &na) {
    /* reinitMatrix() frees our dimensions array first, which for
     * self-assignment is na.dimensions itself */
    if (this == &na) return *this;
    if (na.order == 0) {
      /* initMatrix does not support order 0, so handle the empty source
       * here instead of going through reinitMatrix/copyMatrix */
      deleteMatrix();
      order=0; size=0; dimensions=0; a=0;
      return *this;
    }
    reinitMatrix(na.order, na.dimensions);
    copyMatrix(na);
    return *this;
  }

  inline int getSize() const { return size; }
  inline int getOrder() const { return order; }
  inline const int *getDimensions() const { return dimensions; }

  inline T &operator[](int *d);
  inline T &operator[](char *d); /* index can be array of chars as well */
  inline T &operator()(int d1);
  inline T &operator()(int d1, int d2);
  inline T &operator()(int d1, int d2, int d3);
  inline T &operator()(int d1, int d2, int d3, int d4);
  inline T &operator()(int d1, int d2, int d3, int d4, int d5);

  inline T* vector(int d1, int d2); /* get 1d vector from 3D matrix */
  inline T* vector(int *d); /* get 1d vector from a matrix,
			     * d should have length order-1 */

  /* change dimensions of the matrix, all original data lost */
  void reinitMatrix(int norder, int *ndimensions, T defaultVal=0);

  void reinitMatrix() { deleteMatrix(); }
  void reinitMatrix(int d1);
  void reinitMatrix(int d1, int d2);
  void reinitMatrix(int d1, int d2, int d3);
  void reinitMatrix(int d1, int d2, int d3, int d4);
  void reinitMatrix(int d1, int d2, int d3, int d4, int d5);

  /* Prints matrix with named coordinates in each row.
   * Needs a function that prints one element of the matrix
   * as well as optional names of dimensions. */
  void printMatrixLong(FILE *f, char **names,
		       void(*print)(FILE *f, T item));

  ~Matrix();

 protected:
  /* allocate memory for the matrix, fill entries with defaultVal */
  void initMatrix(int norder, int *ndimensions, T defaultVal=0);
  /* deallocate memory, set order=0 */
  void deleteMatrix();
  void copyMatrix(const Matrix &na); /* assume the same dimensions */
};

/****************************************************************/
/* Matrix<double> with extra file I/O and normalization operations.
 * Every constructor forwards unchanged to the matching Matrix<double>
 * constructor; the extra operations are only declared in this header. */
class DoubleMatrix:public Matrix<double> {

 public:
  /* empty matrix */
  DoubleMatrix() : Matrix<double>() {}
  /* deep copy */
  DoubleMatrix(DoubleMatrix &na) : Matrix<double>(na) {}
  /* arbitrary order, extents taken from ndimensions */
  DoubleMatrix(int norder, int *ndimensions, double defaultVal=0)
    : Matrix<double>(norder, ndimensions, defaultVal) {}

  /* fixed-order shapes */
  DoubleMatrix(int d1) : Matrix<double>(d1) {}
  DoubleMatrix(int d1, int d2) : Matrix<double>(d1, d2) {}
  DoubleMatrix(int d1, int d2, int d3) : Matrix<double>(d1, d2, d3) {}
  DoubleMatrix(int d1, int d2, int d3, int d4)
    : Matrix<double>(d1, d2, d3, d4) {}
  DoubleMatrix(int d1, int d2, int d3, int d4, int d5)
    : Matrix<double>(d1, d2, d3, d4, d5) {}

  /* I/O helpers; definitions are not in this header */
  void writeSimple(FILE *f);
  void readSimple(FILE *f);
  void readRecompute(FILE *f);
  void logAll();

  /* for fixed indexes in first order-1 dimensions the
   * sum over last dimension should be 1 */
  void normalize();
};


/****************************************************************/
/* Matrix<int> with simple file I/O. Every constructor forwards unchanged to
 * the matching Matrix<int> constructor; the I/O operations are only
 * declared in this header. */
class IntMatrix:public Matrix<int> {

 public:
  /* empty matrix */
  IntMatrix() : Matrix<int>() {}
  /* deep copy */
  IntMatrix(IntMatrix &na) : Matrix<int>(na) {}
  /* arbitrary order, extents taken from ndimensions */
  IntMatrix(int norder, int *ndimensions, int defaultVal=0)
    : Matrix<int>(norder, ndimensions, defaultVal) {}

  /* fixed-order shapes */
  IntMatrix(int d1) : Matrix<int>(d1) {}
  IntMatrix(int d1, int d2) : Matrix<int>(d1, d2) {}
  IntMatrix(int d1, int d2, int d3) : Matrix<int>(d1, d2, d3) {}
  IntMatrix(int d1, int d2, int d3, int d4)
    : Matrix<int>(d1, d2, d3, d4) {}
  IntMatrix(int d1, int d2, int d3, int d4, int d5)
    : Matrix<int>(d1, d2, d3, d4, d5) {}

  /* I/O helpers; definitions are not in this header */
  void writeSimple(FILE *f);
  void readSimple(FILE *f);
};

/****************************************************************/
/* General constructor: norder dimensions whose extents are read from
 * ndimensions[0..norder-1]; every element starts as defaultVal. */
template<class T>
Matrix<T>::Matrix(int norder, int *ndimensions, T defaultVal)
{
  initMatrix(norder, ndimensions, defaultVal);  /* allocate + fill */
}

/* Allocate storage for a matrix with norder dimensions, copying the extents
 * from ndimensions, and set every element to defaultVal.
 * Overwrites `a`/`dimensions` without freeing them, so the matrix must not
 * currently own storage (reinitMatrix calls deleteMatrix first).
 * norder and every extent must be positive. */
template<class T> void Matrix<T>::initMatrix(int norder, int *ndimensions,
					      T defaultVal)
{
  /* order==0 is the "empty, owns nothing" state: deleteMatrix() only frees
   * when order!=0, so allocating here with norder==0 would leak */
  assert(norder>0);

  order=norder;
  dimensions = new int[norder];

  size=1;
  for (int i=0; i<order; i++) {
    /* check each extent: a product-only check accepts pairs of negatives */
    assert(ndimensions[i]>0);
    dimensions[i]=ndimensions[i];
    size*=dimensions[i];
  }
  assert(size>0);

  a=new T[size];
  fill(defaultVal);
}

/* Deep-copy constructor: duplicates na's shape and elements.
 * Copying an empty (order 0) matrix yields an empty matrix. */
template<class T> Matrix<T>::Matrix(Matrix &na)
{
  if (na.order==0) {
    /* initMatrix does not support order 0 (it always allocates), and the
     * size check in copyMatrix would then fail; just become empty */
    order=0; size=0; dimensions=0; a=0;
    return;
  }
  initMatrix(na.order, na.dimensions);
  copyMatrix(na);
}

/* 1-D matrix of d1 elements, each initialised to 0 (initMatrix default). */
template<class T> Matrix<T>::Matrix(int d1)
{
  int extents[1] = { d1 };
  initMatrix(1, extents);
}

/* 2-D d1 x d2 matrix, each element initialised to 0 (initMatrix default). */
template<class T> Matrix<T>::Matrix(int d1, int d2)
{
  int extents[2] = { d1, d2 };
  initMatrix(2, extents);
}

/* 3-D d1 x d2 x d3 matrix, elements initialised to 0 (initMatrix default). */
template<class T> Matrix<T>::Matrix(int d1, int d2, int d3)
{
  int extents[3] = { d1, d2, d3 };
  initMatrix(3, extents);
}

/* 4-D matrix with extents d1..d4, elements initialised to 0. */
template<class T> Matrix<T>::Matrix(int d1, int d2, int d3, int d4)
{
  int extents[4] = { d1, d2, d3, d4 };
  initMatrix(4, extents);
}

/* 5-D matrix with extents d1..d5, elements initialised to 0. */
template<class T> Matrix<T>::Matrix(int d1, int d2, int d3, int d4, int d5)
{
  int extents[5] = { d1, d2, d3, d4, d5 };
  initMatrix(5, extents);
}

/* Element access through an index array d holding `order` indices,
 * d[0] being the slowest-varying (row-major). Each index is range-checked
 * when DEBUG is defined. */
template<class T> T &Matrix<T>::operator[](int *d)
{
  #ifdef DEBUG
  assert(order>0);
  /* the loop below starts at 1, so the first index is checked here */
  assert(0<=d[0] && d[0]<dimensions[0]);
  #endif

  int index=d[0];
  for(int i=1;i<order;i++) {
    #ifdef DEBUG
    assert(0<=d[i] && d[i]<dimensions[i]);
    #endif
    index*=dimensions[i];
    index+=d[i];
  }
  return a[index];
}

/* Element access through an array of `order` small indices stored as chars,
 * d[0] being the slowest-varying (row-major). Each index is range-checked
 * when DEBUG is defined (this also catches negative values if char is
 * signed). */
template<class T> T &Matrix<T>::operator[](char *d)
{
  #ifdef DEBUG
  assert(order>0);
  /* the loop below starts at 1, so the first index is checked here */
  assert(0<=d[0] && d[0]<dimensions[0]);
  #endif

  int index=d[0];
  for(int i=1;i<order;i++) {
    #ifdef DEBUG
    assert(0<=d[i] && d[i]<dimensions[i]);
    #endif
    index*=dimensions[i];
    index+=d[i];
  }
  return a[index];
}

/* Element d1 of a 1-D matrix (shape/range asserted under DEBUG). */
template<class T> T &Matrix<T>::operator()(int d1)
{
  #ifdef DEBUG
  assert(order==1);
  assert(0<=d1 && d1<dimensions[0]);
  #endif
  return *(a + d1);
}

/* Element (d1,d2) of a 2-D matrix, row-major (shape/range asserted under
 * DEBUG). */
template<class T> T &Matrix<T>::operator()(int d1, int d2)
{
  #ifdef DEBUG
  assert(order==2);
  assert(0<=d1 && d1<dimensions[0]);
  assert(0<=d2 && d2<dimensions[1]);
  #endif
  const int offset = d1*dimensions[1] + d2;
  return a[offset];
}

/* Element (d1,d2,d3) of a 3-D matrix, row-major (shape/range and final
 * offset asserted under DEBUG). */
template<class T> T &Matrix<T>::operator()(int d1, int d2, int d3)
{
  #ifdef DEBUG
  assert(order==3);
  assert(0<=d1 && d1<dimensions[0]);
  assert(0<=d2 && d2<dimensions[1]);
  assert(0<=d3 && d3<dimensions[2]);
  #endif

  /* computed only after the shape checks so dimensions[2] is valid */
  const int offset = (d1*dimensions[1] + d2)*dimensions[2] + d3;

  #ifdef DEBUG
  assert(offset<size);
  #endif
  return a[offset];
}

/* Element (d1,d2,d3,d4) of a 4-D matrix, row-major (shape/range asserted
 * under DEBUG). */
template<class T> T &Matrix<T>::operator()(int d1, int d2, int d3, int d4)
{
  #ifdef DEBUG
  assert(order==4);
  assert(0<=d1 && d1<dimensions[0]);
  assert(0<=d2 && d2<dimensions[1]);
  assert(0<=d3 && d3<dimensions[2]);
  assert(0<=d4 && d4<dimensions[3]);
  #endif
  /* Horner-style fold of the indices into one linear offset */
  int offset = d1;
  offset = offset*dimensions[1] + d2;
  offset = offset*dimensions[2] + d3;
  offset = offset*dimensions[3] + d4;
  return a[offset];
}

/* Element (d1,...,d5) of a 5-D matrix, row-major (shape/range and final
 * offset asserted under DEBUG). */
template<class T> T &Matrix<T>::operator()(int d1, int d2, int d3,
					   int d4, int d5)
{
  #ifdef DEBUG
  assert(order==5);
  assert(0<=d1 && d1<dimensions[0]);
  assert(0<=d2 && d2<dimensions[1]);
  assert(0<=d3 && d3<dimensions[2]);
  assert(0<=d4 && d4<dimensions[3]);
  assert(0<=d5 && d5<dimensions[4]);
  #endif

  /* Horner-style fold, done after the shape checks */
  int offset = d1;
  offset = offset*dimensions[1] + d2;
  offset = offset*dimensions[2] + d3;
  offset = offset*dimensions[3] + d4;
  offset = offset*dimensions[4] + d5;

  #ifdef DEBUG
  assert(offset<size);
  #endif
  return a[offset];
}


/* Pointer to the first element of the contiguous run along the last
 * dimension selected by the order-1 leading indices d[0..order-2]
 * (for a 1-D matrix, d is unused and the buffer start is returned). */
template<class T> T* Matrix<T>::vector(int *d)
{
  #ifdef DEBUG
  assert(order>0);
  #endif

  const int lead = order-1;   /* number of fixed leading indices */
  int row = 0;
  for (int k=0; k<lead; ++k) {
    #ifdef DEBUG
    assert(0<=d[k] && d[k]<dimensions[k]);
    #endif
    row = row*dimensions[k] + d[k];
  }
  return a + row*dimensions[lead];
}


/* For a 3-D matrix, pointer to the contiguous run of dimensions[2]
 * elements at fixed (d1,d2). */
template<class T> T* Matrix<T>::vector(int d1, int d2)
{
  #ifdef DEBUG
  assert(order==3);
  assert(0<=d1 && d1<dimensions[0]);
  assert(0<=d2 && d2<dimensions[1]);
  #endif

  const int row = d1*dimensions[1] + d2;
  return &a[row*dimensions[2]];
}

/* Assign val to each of the `size` elements (does nothing when size==0). */
template<class T> void Matrix<T>::fill(T val)
{
  int i = 0;
  while (i < size) {
    a[i] = val;
    ++i;
  }
}

/* Destructor: hands all cleanup to deleteMatrix(), which frees the buffers
 * only when the matrix is non-empty. */
template<class T>
Matrix<T>::~Matrix() { deleteMatrix(); }

/* Free owned storage and reset to the empty state (order==0, size==0,
 * null pointers). Calling it on an already-empty matrix frees nothing. */
template<class T> void Matrix<T>::deleteMatrix()
{
  if (order) {
    delete[] a;
    delete[] dimensions;
  }
  /* Reset every field, not only order: fill() and getSize() read `size`,
   * so a stale value would let fill() write into the freed buffer. */
  order=0;
  size=0;
  a=0;
  dimensions=0;
}

/* Element-wise copy of na into this matrix. The two shapes must already be
 * identical; that precondition is only asserted, not enforced. */
template<class T> void Matrix<T>::copyMatrix(const Matrix &na)
{
  // same order, same total size, and the same extent in every dimension
  assert(order==na.order && size==na.size);
  for (int k=0; k<order; ++k)
    assert(dimensions[k]==na.dimensions[k]);

  for (int k=0; k<size; ++k)
    a[k] = na.a[k];
}


/* Discard the current contents and reshape to norder dimensions with
 * extents ndimensions[0..norder-1], filling every element with defaultVal.
 * ndimensions may point into this matrix's own dimension array. */
template<class T> void Matrix<T>::reinitMatrix(int norder, int *ndimensions,
						T defaultVal)
{
  /* Snapshot the requested extents before deleteMatrix() frees
   * `dimensions`, in case ndimensions aliases it. */
  int *extents = new int[norder];
  for (int i=0; i<norder; i++) extents[i] = ndimensions[i];

  deleteMatrix();
  try {
    initMatrix(norder, extents, defaultVal);
  } catch (...) {
    delete[] extents;   /* don't leak the snapshot on allocation failure */
    throw;
  }
  delete[] extents;
}

/* Reshape to a 1-D matrix of d1 zero elements; old contents are discarded. */
template<class T> void Matrix<T>::reinitMatrix(int d1)
{
  int extents[] = { d1 };
  reinitMatrix(1, extents);
}

/* Reshape to a d1 x d2 matrix of zeros; old contents are discarded. */
template<class T> void Matrix<T>::reinitMatrix(int d1, int d2)
{
  int extents[] = { d1, d2 };
  reinitMatrix(2, extents);
}

/* Reshape to a d1 x d2 x d3 matrix of zeros; old contents are discarded
 * and their storage released. */
template<class T> void Matrix<T>::reinitMatrix(int d1, int d2, int d3)
{
  int dim[3];
  dim[0]=d1; dim[1]=d2; dim[2]=d3;
  /* go through reinitMatrix, not initMatrix, so the old buffers are freed
   * before new ones are allocated (initMatrix alone leaks them) */
  reinitMatrix(3,dim);
}

/* Reshape to a 4-D matrix of zeros with extents d1..d4; old contents are
 * discarded and their storage released. */
template<class T> void Matrix<T>::reinitMatrix(int d1, int d2,
						int d3, int d4)
{
  int dim[4];
  dim[0]=d1; dim[1]=d2; dim[2]=d3; dim[3]=d4;
  /* go through reinitMatrix, not initMatrix, so the old buffers are freed */
  reinitMatrix(4,dim);
}

/* Reshape to a 5-D matrix of zeros with extents d1..d5; old contents are
 * discarded and their storage released. */
template<class T> void Matrix<T>:: reinitMatrix(int d1, int d2,
						 int d3, int d4, int d5)
{
  int dim[5];
  dim[0]=d1; dim[1]=d2; dim[2]=d3; dim[3]=d4; dim[4]=d5;
  /* go through reinitMatrix, not initMatrix, so the old buffers are freed */
  reinitMatrix(5,dim);
}

/* Print one line per element, in storage order. Each line holds the
 * element's coordinates separated by spaces, then a space, the element
 * written by `print`, and a newline. Coordinates are printed as
 * "name=value" when names (one per dimension) is non-null, else as bare
 * values. */
template<class T> void Matrix<T> :: printMatrixLong
(FILE *f, char **names, void(*print)(FILE *f, T item))
{
  /* Heap buffer allocated once: `int coords[order]` was a variable-length
   * array, a compiler extension rather than standard C++. */
  int *coords = new int[order];
  for (int i=0; i<size; i++) {
    /* decode the row-major linear index i into per-dimension coordinates */
    int idx = i;
    for(int j=order-1; j>=0; j--) {
      coords[j] = idx % dimensions[j];
      idx /= dimensions[j];
    }
    assert(idx==0);
    for(int j=0; j<order; j++) {
      /* separator written with fputc: passing a non-literal string as the
       * fprintf format is unsafe if it ever contains '%' */
      if (j>0) fputc(' ', f);
      if(names) {
	fprintf(f, "%s=", names[j]);
      }
      fprintf(f,"%d", coords[j]);
    }
    fputc(' ', f);
    print(f, a[i]);
    fputc('\n', f);
  }
  delete[] coords;
}



#endif



