/***************************************************************************
 *   Copyright (C) 2006 by Jeziorski, Weintraub, Benkard and Van Roy       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/
#ifndef TRIDIAGMATRIX_H
#define TRIDIAGMATRIX_H

#include<cmath>
#include<cstdlib>
#include<iostream>
//#include<vectorMTL.h>
//#include<matrixMTL.h>


/**
	@author Jeziorski, Weintraub, Benkard and Van Roy <przemekj@stanford.edu>
*/

using namespace std;

namespace MTL {

// Forward declarations: the matrix and vector classes refer to each other
// (friendship and member signatures), so they are declared up front.
template <class T> class VectorMTL;
template <class T> class MatrixMTL; // Forward reference for MatrixMTL class. It is needed because of cross referencing.
template <class T> class TriDiagMatrixMTL;

// Free arithmetic operators on TriDiagMatrixMTL. They are declared before the
// class so that the class can name their <T> specializations as friends.
template <typename T> inline const TriDiagMatrixMTL<T> operator+(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
template <typename T> inline const TriDiagMatrixMTL<T> operator+(const TriDiagMatrixMTL<T>&, const T);
template <typename T> inline const TriDiagMatrixMTL<T> operator+(const T, const TriDiagMatrixMTL<T>&);
template <typename T> inline const TriDiagMatrixMTL<T> operator-(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
template <typename T> inline const TriDiagMatrixMTL<T> operator-(const TriDiagMatrixMTL<T>&, const T);
template <typename T> inline const TriDiagMatrixMTL<T> operator-(const T, const TriDiagMatrixMTL<T>&);
template <typename T> inline const TriDiagMatrixMTL<T> operator*(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
template <typename T> inline const TriDiagMatrixMTL<T> operator*(const TriDiagMatrixMTL<T>&, const T);
template <typename T> inline const TriDiagMatrixMTL<T> operator*(const T, const TriDiagMatrixMTL<T>&);
template <typename T> inline const TriDiagMatrixMTL<T> operator/(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
template <typename T> inline const TriDiagMatrixMTL<T> operator/(const TriDiagMatrixMTL<T>&, const T);
template <typename T> inline const TriDiagMatrixMTL<T> operator/(const T, const TriDiagMatrixMTL<T>&);

// Stream output (prints the matrix densely; defined after the class).
template <typename T> std::ostream& operator<<(ostream &output, const TriDiagMatrixMTL<T> &p);

// Norm helpers over the stored band entries (defined after the class).
template <typename T> T diffNorm(const TriDiagMatrixMTL<T> &input1, const TriDiagMatrixMTL<T> &input2);
template <typename T> T norm(const TriDiagMatrixMTL<T> &input);


/**
 * Square tridiagonal matrix stored as three bands.
 *
 * For an n x n matrix A (n == size), the layout used by dot() and getRow() is
 *   diag[i]  == A(i, i)      0 <= i < n
 *   upper[i] == A(i, i+1)    0 <= i < n-1
 *   lower[i] == A(i+1, i)    0 <= i < n-1
 *
 * The compound operators (+=, -=, *=, /=) and operator=(const T) act
 * element-wise on the stored band entries, so the scalar versions also change
 * the off-diagonal bands; addDiag() shifts only the diagonal.
 *
 * The object owns its three arrays (new[] / delete[]). A default-constructed
 * matrix is empty (size 0, null arrays) until init() or copy assignment gives
 * it storage. dot(const VectorMTL&), primedot() and getRow() assume size >= 2.
 */
template <class T>
class TriDiagMatrixMTL {
  friend class MatrixMTL<T>;
  friend class VectorMTL<T>;
  friend std::ostream& operator<< <T>(std::ostream &output, const TriDiagMatrixMTL<T> &p);

  friend const TriDiagMatrixMTL<T> operator+<T>(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
  friend const TriDiagMatrixMTL<T> operator+<T>(const TriDiagMatrixMTL<T>&, const T);
  friend const TriDiagMatrixMTL<T> operator+<T>(const T, const TriDiagMatrixMTL<T>&);
  friend const TriDiagMatrixMTL<T> operator-<T>(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
  friend const TriDiagMatrixMTL<T> operator-<T>(const TriDiagMatrixMTL<T>&, const T);
  friend const TriDiagMatrixMTL<T> operator-<T>(const T, const TriDiagMatrixMTL<T>&);
  friend const TriDiagMatrixMTL<T> operator*<T>(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
  friend const TriDiagMatrixMTL<T> operator*<T>(const TriDiagMatrixMTL<T>&, const T);
  friend const TriDiagMatrixMTL<T> operator*<T>(const T, const TriDiagMatrixMTL<T>&);
  friend const TriDiagMatrixMTL<T> operator/<T>(const TriDiagMatrixMTL<T>&, const TriDiagMatrixMTL<T>&);
  friend const TriDiagMatrixMTL<T> operator/<T>(const TriDiagMatrixMTL<T>&, const T);
  friend const TriDiagMatrixMTL<T> operator/<T>(const T, const TriDiagMatrixMTL<T>&);

  friend T norm <T>(const TriDiagMatrixMTL<T> &);
  friend T diffNorm <T>(const TriDiagMatrixMTL<T> &input1, const TriDiagMatrixMTL<T> &input2);
public:
    int size;               // number of rows (== number of columns)
    T *upper,*diag,*lower;  // the three bands; see the class comment for the layout

    // Dense products and dense inverse; defined out of class.
    MatrixMTL<T> dot(MatrixMTL<T> &input);
    MatrixMTL<T> dot(TriDiagMatrixMTL<T> &input);
    int inverse(MatrixMTL<T> &output);

    /// Empty matrix: size 0 and no storage.
    inline TriDiagMatrixMTL() : size(0), upper(NULL), diag(NULL), lower(NULL) {
    }

    /// (Re)allocate storage for a sizeInit x sizeInit matrix. Any storage the
    /// object already owns is released first, so repeated calls do not leak.
    /// New entries are default-initialized (indeterminate for built-in T).
    void init(int sizeInit) {
      release();
      allocate(sizeInit);
    }

    /// Allocate a sizeInit x sizeInit matrix with default-initialized entries.
    inline TriDiagMatrixMTL(int sizeInit) {
      allocate(sizeInit);
    }

    /// Deep copy. Copying an empty matrix yields an empty matrix.
    inline TriDiagMatrixMTL(const TriDiagMatrixMTL<T> &copy) : size(0), upper(NULL), diag(NULL), lower(NULL) {
      if (copy.size > 0) {
        allocate(copy.size);
        copyEntries(copy);
      }
    }

    // ---- Element-wise compound operators on the stored bands. ----
    // Matrix versions require input.size == size; scalar versions apply the
    // scalar to every stored entry (all three bands).

    inline TriDiagMatrixMTL<T> &operator +=(const TriDiagMatrixMTL<T> &input) {
      for (int i=0; i<size-1; i++) {
        diag[i]+=input.diag[i];
        upper[i]+=input.upper[i];
        lower[i]+=input.lower[i];
      }
      diag[size-1]+=input.diag[size-1];

      return *this;
    }

    inline TriDiagMatrixMTL<T> &operator +=(const T input) {
      for (int i=0; i<size-1; i++) {
        diag[i]+=input;
        upper[i]+=input;
        lower[i]+=input;
      }
      diag[size-1]+=input;

      return *this;
    }

    inline TriDiagMatrixMTL<T> &operator *=(const TriDiagMatrixMTL<T> &input) {
      for (int i=0; i<size-1; i++) {
        diag[i]*=input.diag[i];
        upper[i]*=input.upper[i];
        lower[i]*=input.lower[i];
      }

      diag[size-1]*=input.diag[size-1];

      return *this;
    }

    inline TriDiagMatrixMTL<T> &operator *=(const T input) {
      for (int i=0; i<size-1; i++) {
        diag[i]*=input;
        upper[i]*=input;
        lower[i]*=input;
      }

      diag[size-1]*=input;

      return *this;
    }

    inline TriDiagMatrixMTL<T> &operator -=(const TriDiagMatrixMTL<T> &input) {
      for (int i=0; i<size-1; i++) {
        diag[i]-=input.diag[i];
        upper[i]-=input.upper[i];
        lower[i]-=input.lower[i];
      }

      diag[size-1]-=input.diag[size-1];

      return *this;
    }

    inline TriDiagMatrixMTL<T> &operator -=(const T input) {
      for (int i=0; i<size-1; i++) {
        diag[i]-=input;
        upper[i]-=input;
        lower[i]-=input;
      }

      diag[size-1]-=input;

      return *this;
    }


    inline TriDiagMatrixMTL<T> &operator /=(const TriDiagMatrixMTL<T> &input) {
      for (int i=0; i<size-1; i++) {
        diag[i]/=input.diag[i];
        upper[i]/=input.upper[i];
        lower[i]/=input.lower[i];
      }

      diag[size-1]/=input.diag[size-1];

      return *this;
    }

    inline TriDiagMatrixMTL<T> &operator /=(const T input) {
      for (int i=0; i<size-1; i++) {
        diag[i]/=input;
        upper[i]/=input;
        lower[i]/=input;
      }

      diag[size-1]/=input;

      return *this;
    }


    /// Band access: -1 -> upper, 0 -> diag, 1 -> lower. Any other index
    /// prints an error and terminates the program with exit(0).
    inline T *operator[](const int input) {
      switch(input) {
        case(-1) : return upper;
        case(0) : return diag;
        case(1) : return lower;
        default : { 
          cout << "Index must be -1,0 or 1" << endl;
          exit(0);
        }
      }
    }

  /// Set every stored entry (all three bands) to `input`.
  inline TriDiagMatrixMTL<T> &operator=(const T input) {
    for (int i=0; i<size-1; i++) {
      diag[i]=input;
      lower[i]=input;
      upper[i]=input;
    }
    diag[size-1]=input;

    return *this;
  }

  /// Add `element` to every diagonal entry (off-diagonal bands untouched).
  void addDiag(const T element) {
    for(int i=0;i<size;i++) {
      diag[i]+=element;
    }
  }

  /// Deep copy assignment. When the sizes differ the storage is reallocated
  /// to match `input`, so the copy never runs past either set of arrays.
  inline TriDiagMatrixMTL<T> &operator=(const TriDiagMatrixMTL<T> &input) {
    if (this == &input) return *this;

    if (size != input.size) {
      release();
      if (input.size > 0) allocate(input.size);
    }
    copyEntries(input);

    return *this;
  }


    inline int getSize() {
      return size;
    }

    /// Transpose in place by swapping the upper and lower band pointers.
    inline void transpose() {
      T *temp = lower;
      lower = upper;
      upper = temp;
    }

  /** 
	Compute the stationary distribution of this matrix (viewed as a
	transition matrix) into expfreq; defined out of class.
   */
   int statDistr(VectorMTL<T> &expfreq);


  /**
	Largest absolute element-wise difference between input1 and input2,
	iterating over this->size band positions.
	NOTE(review): the parameters are MatrixMTL, yet the body reads the
	tridiagonal bands (diag/upper/lower); it presumably was meant to take
	TriDiagMatrixMTL and only compiles if MatrixMTL has such members.
	The free function diffNorm(const TriDiagMatrixMTL&, const TriDiagMatrixMTL&)
	computes the same measure.
	\param input1
        \param input2
   */
  T diffNorm(MatrixMTL<T> &input1, MatrixMTL<T> &input2) {
    T norm = fabs(input1.diag[0]-input2.diag[0]);
    T temp;

    for(int i=0;i<size-1;i++) {
      if((temp=fabs(input1.upper[i]-input2.upper[i]))>norm) norm = temp;
      if((temp=fabs(input1.diag[i+1]-input2.diag[i+1]))>norm) norm = temp;
      if((temp=fabs(input1.lower[i]-input2.lower[i]))>norm) norm = temp;
    }

    return norm;
  }


    /// Return (*this) * input. Requires size >= 2.
    const VectorMTL<T> dot(const VectorMTL<T> &input) {
      VectorMTL<T> output(size);

      output[0]=diag[0]*input.data[0]+upper[0]*input.data[1];

      for(int i=1; i<size-1; i++) {
        output[i]=lower[i-1]*input.data[i-1]+diag[i]*input.data[i]+upper[i]*input.data[i+1];
      }

      output[size-1]=lower[size-2]*input.data[size-2]+diag[size-1]*input.data[size-1];
      return output;
    }

    /// Return transpose(*this) * input without forming the transpose
    /// (the roles of upper and lower are swapped). Requires size >= 2.
    const VectorMTL<T> primedot(const VectorMTL<T> &input) {
      VectorMTL<T> output(size);

      output[0]=diag[0]*input.data[0]+lower[0]*input.data[1];

      for(int i=1; i<size-1; i++) {
        output[i]=upper[i-1]*input.data[i-1]+diag[i]*input.data[i]+lower[i]*input.data[i+1];
      }

      output[size-1]=upper[size-2]*input.data[size-2]+diag[size-1]*input.data[size-1];
      return output;
    }


  /// Make this a diagonal matrix with `input` on the diagonal and zeros in
  /// both off-diagonal bands. (The bands are cleared element by element;
  /// the pointers themselves stay owned by the object.)
  void diagonal(const T input) {
    for(int i=0;i<size-1;i++) {
      diag[i]=input;
      lower[i]=0;
      upper[i]=0;
    }
    diag[size-1]=input;
  }

  /// Make this the identity matrix.
  void eye() {
    diagonal(1);
  }

  /// Return row `row` as a dense vector of length size (zeros outside the
  /// band). Requires size >= 2.
  VectorMTL<T> getRow(const int row) {
    VectorMTL<T> output(size);

    /* first row */
    if(row==0) {
      output.data[0]=diag[0];
      output.data[1]=upper[0];
      for (int i=2;i<size;i++) {
        output.data[i]=0;
      }
    } else if(row==size-1) { //Last row
      for (int i=0;i<size-2;i++) {
        output.data[i]=0;
      }
      output.data[size-2]=lower[size-2];
      output.data[size-1]=diag[size-1];
    } else { 
      for(int i=0; i<row-1; i++) {
        output.data[i]=0;
      }

      output.data[row-1] = lower[row-1];
      output.data[row] = diag[row];
      output.data[row+1] = upper[row];

      for(int i=row+2; i<size; i++) {
        output.data[i]=0;
      }
    }

    return output;
  }


  ~TriDiagMatrixMTL() {
       delete[] diag;
       delete[] upper;
       delete[] lower;
  }

private:
    // Allocate fresh (default-initialized) bands for an n x n matrix and set
    // size = n. Does not free existing storage; callers use release() first.
    void allocate(int n) {
      size = n;
      diag = new T[n];
      upper = new T[n-1];
      lower = new T[n-1];
    }

    // Free the bands and return to the empty state (size 0, null arrays).
    void release() {
      delete[] diag;
      delete[] upper;
      delete[] lower;
      diag = NULL;
      upper = NULL;
      lower = NULL;
      size = 0;
    }

    // Copy every band entry from src; src.size must equal size.
    void copyEntries(const TriDiagMatrixMTL<T> &src) {
      for (int i=0; i<size-1; i++) {
        diag[i]=src.diag[i];
        upper[i]=src.upper[i];
        lower[i]=src.lower[i];
      }
      if (size > 0) diag[size-1]=src.diag[size-1];
    }

};

/**
 * Compute the dense inverse A^{-1} of this tridiagonal matrix into `output`.
 *
 * Solves A X = I for all columns at once with the Thomas algorithm
 * (tridiagonal Gaussian elimination without pivoting). A forward sweep
 * stores the eliminated super-diagonal factors in gam, and back
 * substitution follows. Row j of `output` holds the j-th unknown for every
 * right-hand-side column.
 *
 * \param output  receives A^{-1}. NOTE(review): assumed to be pre-sized to
 *                size x size by the caller; it is not resized here.
 * \return 1 on success. Returns 0 if an exactly-zero pivot is met; a message
 *         is printed to cout and `output` is left partially written. Because
 *         there is no pivoting, this can happen even for a nonsingular A.
 */
template <class T> 
int TriDiagMatrixMTL<T>::inverse(MatrixMTL<T> &output) {
      T bet;                   // current pivot of the forward sweep
      VectorMTL<T> gam(size);  // gam[j]: eliminated super-diagonal factor of row j (gam[0] unused)

      if(diag[0]==0) {
        cout << "Tridiagonal matrix is singular" << endl;
        return 0;
      }

      // Row 0 of X: delta(0,i) / diag[0].
      // NOTE(review): writes through output.data, assuming data[0..size-1]
      // is row 0 (the same storage as output[0][i]); confirm in MatrixMTL.
      output.data[0]=1/(bet=diag[0]);
      for(int i=1; i<size;i++) {
        output.data[i]=0;
      }

      // Forward sweep: eliminate the sub-diagonal row by row. The right-hand
      // side of row j is row j of the identity, hence the (i==j) case.
      for(int j=1;j<size; j++) {
        gam[j]=upper[j-1]/bet;
        bet=diag[j]-lower[j-1]*gam[j];
        if (bet==0) {
          cout << "Error on diagonal in tridiag matrix inversion" <<endl;
          return 0;
        }
        for(int i=0;i<size; i++) { 
          if (i==j)
            output[j][i]=(1-lower[j-1]*output[j-1][i])/bet;
          else 
            output[j][i]=-lower[j-1]*output[j-1][i]/bet;
        }
      }

      // Back substitution: X[j] -= gam[j+1] * X[j+1], from the bottom up.
      for(int j=size-2;j>=0;j--) {
        for(int i=size-1;i>=0;i--) {
          output[j][i] -= gam[j+1]*output[j+1][i];
        }
      }
  return 1;
}

/**
 * Dense product (*this) * input, where input has `size` rows and
 * input.getLength() columns. The result is a size x getLength() matrix.
 * Requires size >= 2 (the first and last rows touch upper[0] / lower[size-2]).
 */
template <class T> 
MatrixMTL<T> TriDiagMatrixMTL<T>::dot(MatrixMTL<T> &input) {
      const int cols = input.getLength();
      const int last = size - 1;
      MatrixMTL<T> result(size, cols);

      // Top row: only the diagonal and super-diagonal bands contribute.
      for (int c = 0; c < cols; ++c)
        result[0][c] = diag[0]*input[0][c] + upper[0]*input[1][c];

      // Interior rows combine all three bands.
      for (int r = 1; r < last; ++r)
        for (int c = 0; c < cols; ++c)
          result[r][c] = lower[r-1]*input[r-1][c] + diag[r]*input[r][c] + upper[r]*input[r+1][c];

      // Bottom row: sub-diagonal and diagonal only.
      for (int c = 0; c < cols; ++c)
        result[last][c] = lower[last-1]*input[last-1][c] + diag[last]*input[last][c];

      return result;
}

// Quite inefficient (builds a dense size x size temporary), but only used in bound computations; may be optimized later.
template <class T> 
MatrixMTL<T> TriDiagMatrixMTL<T>::dot(TriDiagMatrixMTL<T> &input) {
      // Product (*this) * input returned as a dense matrix: the right operand
      // is expanded into a dense n x n matrix and multiplied with
      // dot(MatrixMTL&). The bands must be read from `input`, not from *this.
      const int n = input.getSize();
      MatrixMTL<T> dense(n, n);

      dense = 0;
      for(int i=0;i<n-1;i++) {
        dense[i][i]   = input.diag[i];
        dense[i][i+1] = input.upper[i];   // super-diagonal: column i+1
        dense[i+1][i] = input.lower[i];   // sub-diagonal: row i+1
      }
      dense[n-1][n-1] = input.diag[n-1];

      return this->dot(dense);
}


template <typename T> 
inline const TriDiagMatrixMTL<T> operator *(const TriDiagMatrixMTL<T> &self, const TriDiagMatrixMTL<T> &input) {
    // Element-wise (band-by-band) product, computed on a copy of the left operand.
    TriDiagMatrixMTL<T> product(self);
    product *= input;
    return product;
}	

template <typename T> 
inline const TriDiagMatrixMTL<T> operator *(const TriDiagMatrixMTL<T> &self, const T input) {
    // Scale every stored entry by `input`. The copy must be a
    // TriDiagMatrixMTL (not a MatrixMTL) so the result matches the
    // declared return type.
    return TriDiagMatrixMTL<T>(self)*=input;
}

template <typename T> 
inline const TriDiagMatrixMTL<T> operator *(const T input, const TriDiagMatrixMTL<T> &self) {
    // Scalar on the left: scale every stored entry of a copy of the matrix.
    TriDiagMatrixMTL<T> scaled(self);
    scaled *= input;
    return scaled;
}	

template <typename T> 
inline const TriDiagMatrixMTL<T> operator +(const TriDiagMatrixMTL<T> &self, const TriDiagMatrixMTL<T> &input) {
    // Band-wise sum; both operands must have the same size.
    TriDiagMatrixMTL<T> result(self);
    result += input;
    return result;
}	

template <typename T> 
inline const TriDiagMatrixMTL<T> operator +(const T input, const TriDiagMatrixMTL<T> &self) {
    // Add the scalar to every stored entry (all three bands) of a copy.
    TriDiagMatrixMTL<T> shifted(self);
    shifted += input;
    return shifted;
}	

template <typename T> 
inline const TriDiagMatrixMTL<T> operator +(const TriDiagMatrixMTL<T> &self, const T input) {
    // Matrix + scalar: the scalar is added to every stored band entry.
    TriDiagMatrixMTL<T> shifted(self);
    shifted += input;
    return shifted;
}

template <typename T> 
inline const TriDiagMatrixMTL<T> operator /(const TriDiagMatrixMTL<T> &self, const TriDiagMatrixMTL<T> &input) {
    // Element-wise quotient of the stored bands (not a matrix inverse).
    TriDiagMatrixMTL<T> quotient(self);
    quotient /= input;
    return quotient;
}	

template <typename T> 
inline const TriDiagMatrixMTL<T> operator /(const TriDiagMatrixMTL<T> &self, const T input) {
    // Divide every stored entry of a copy by the scalar.
    TriDiagMatrixMTL<T> quotient(self);
    quotient /= input;
    return quotient;
}

template <typename T> 
inline const TriDiagMatrixMTL<T> operator /(const T lhs, const TriDiagMatrixMTL<T> &input) {
    // Element-wise lhs / entry for every stored band entry, matching the
    // element-wise semantics of the other division operators. Each output
    // entry is assigned (the freshly allocated output is not initialized).
    TriDiagMatrixMTL<T> output(input.size);

    for (int i=0; i<input.size-1; i++) {
      output.diag[i]  = lhs/input.diag[i];
      output.lower[i] = lhs/input.lower[i];
      output.upper[i] = lhs/input.upper[i];
    }

    output.diag[input.size-1] = lhs/input.diag[input.size-1];

    return output;
}	

template <typename T> 
inline const TriDiagMatrixMTL<T> operator -(const TriDiagMatrixMTL<T> &self, const TriDiagMatrixMTL<T> &input) {
    // Band-wise difference self - input; both operands must have the same size.
    TriDiagMatrixMTL<T> difference(self);
    difference -= input;
    return difference;
}	

template <typename T> 
inline const TriDiagMatrixMTL<T> operator -(const TriDiagMatrixMTL<T> &self, const T input) {
    // Subtract the scalar from every stored entry of a copy.
    TriDiagMatrixMTL<T> difference(self);
    difference -= input;
    return difference;
}

template <typename T> 
inline const TriDiagMatrixMTL<T> operator -(const T input, const TriDiagMatrixMTL<T> &self) {
    // Element-wise input - entry for every stored band entry, the mirror of
    // operator-(matrix, scalar). TriDiagMatrixMTL has no unary minus, so
    // the result is built entry by entry.
    TriDiagMatrixMTL<T> output(self.size);

    for (int i=0; i<self.size-1; i++) {
      output.diag[i]  = input - self.diag[i];
      output.upper[i] = input - self.upper[i];
      output.lower[i] = input - self.lower[i];
    }

    output.diag[self.size-1] = input - self.diag[self.size-1];

    return output;
}	

/**
 * Print the matrix as a dense size x size table, one row per line, with
 * entries separated by two tabs. Entries outside the band are printed as "0".
 * Works for any size >= 1; an empty matrix (size <= 0) prints nothing.
 */
template <typename T> std::ostream& operator<<(ostream &output, const TriDiagMatrixMTL<T> &p) {
  for (int row = 0; row < p.size; row++) {
    for (int col = 0; col < p.size; col++) {
      if (col > 0) output << "\t\t";

      if (col == row)          output << p.diag[row];   // A(r,r)
      else if (col == row + 1) output << p.upper[row];  // A(r,r+1)
      else if (col == row - 1) output << p.lower[col];  // A(r,r-1) == lower[r-1]
      else                     output << "0";
    }
    output << endl;
  }

  return output;
}

/**
 * Max-norm of the stored entries: the largest |entry| over all three bands.
 * This is the same measure diffNorm() applies to the difference of two
 * matrices, so norm(A) matches diffNorm(A, Z) for a zero matrix Z.
 * Returns 0 for an empty matrix.
 */
template <typename T> 
T norm(const TriDiagMatrixMTL<T> &input) {
    T Norm=0;
    if (input.size <= 0) return Norm;

    T temp;
    for(int i=0;i<input.size-1; i++) {
      if((temp=fabs(input.lower[i]))>Norm) Norm=temp;
      if((temp=fabs(input.diag[i]))>Norm) Norm=temp;
      if((temp=fabs(input.upper[i]))>Norm) Norm=temp;
    }

    if((temp=fabs(input.diag[input.size-1]))>Norm) Norm=temp;

    return Norm;
}

/**
 * Largest absolute element-wise difference between two tridiagonal matrices
 * (max-norm of input1 - input2 over the stored bands). Both matrices are
 * expected to have input1.size rows.
 */
template <typename T> 
T diffNorm(const TriDiagMatrixMTL<T> &input1, const TriDiagMatrixMTL<T> &input2) {
  T largest = fabs(input1.diag[0]-input2.diag[0]);
  const int bandLength = input1.size - 1;

  for (int k = 0; k < bandLength; ++k) {
    const T upperGap = fabs(input1.upper[k]-input2.upper[k]);
    if (upperGap > largest) largest = upperGap;

    const T diagGap = fabs(input1.diag[k+1]-input2.diag[k+1]);
    if (diagGap > largest) largest = diagGap;

    const T lowerGap = fabs(input1.lower[k]-input2.lower[k]);
    if (lowerGap > largest) largest = lowerGap;
  }

  return largest;
}

/**
 * Compute the stationary (invariant) distribution of this matrix, viewed as
 * a transition matrix P, and write it to expfreq normalized to sum to 1.
 *
 * The method builds temp = P^T - I, except that the last diagonal entry is
 * replaced by lower[size-2] + 1 so the system becomes invertible. It then
 * solves temp * x = e_{size-1} through the dense inverse and normalizes x.
 * Rows 0..size-2 of temp equal those of P^T - I, so their solution is the
 * stationary vector up to scale. The replaced last equation only fixes that
 * scale, and the final normalization discards it.
 * NOTE(review): this reasoning assumes the rows of P sum to 1 and the
 * stationary distribution is unique; confirm for the intended inputs.
 *
 * Requires size >= 2 (reads lower[size-2]). If the modified matrix cannot be
 * inverted, the method prints a message and terminates the program with exit(0).
 * \return 1 on success.
 */
template <typename T> 
int TriDiagMatrixMTL<T>::statDistr(VectorMTL<T> &expfreq) {
    TriDiagMatrixMTL<T> temp(size);

    // temp = transpose(*this) - I (upper and lower swap roles).
    for(int i=0;i<size-1;i++) {
      temp.upper[i]=lower[i];
      temp.diag[i]=diag[i]-1;
      temp.lower[i]=upper[i];
    }

    // Replace the last diagonal entry so that temp is invertible.
    temp.diag[size-1]=lower[size-2]+1;

    MatrixMTL<T> inv(size,size);
    if(!(temp.inverse(inv))) {
      cout << "Cannot compute invariant distribution\n";
      exit(0);
    }

    // Right-hand side e_{size-1}: only the modified last equation is non-zero.
    VectorMTL<T> temp2(size);
    temp2=0;
    temp2[size-1]=1;

    expfreq = inv.dot(temp2);

    double sum=0;  // normalize (accumulated in double regardless of T)
    for(int i=0;i<size;i++) {
      sum+=expfreq[i];
    }
    expfreq/=sum;

    return 1;
}
}
#endif
