/***************************************************************************
 *   Copyright (C) 2006 by Weintraub, Benkard, Van Roy and Jeziorski       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <iostream>
#include <fstream>
#include <cstdlib>
#include <math.h>
#include <compoe/compoe.h>
#include <compoe/compoe_nonstat.h>
#include <mat.h>
#include <string.h>

/***********************************************************
 * Computes non-stationary equilibrium for given oe        *
 *                                                         *
 * INPUT:                                                  *
 *    1) Output filename - string                          *
 *    2) Control parameters - vector                       *
 *    3) Profit function parameters - vector               *
 *    4) Profit function parameters for ipopt - vector     *
 *    5) State transition parameters - vector              *
 *    6) Model Constants - vector                          *
 *    7) Model Initial values - vector                     *
 *    8) Convergence tolerance values - vector             *
 *    9) Bounds parameters - vector                        *
 *   10) Non-stationary eq. specific parameters - vector   *
 *  Oblivious equilibrium                                  *
 *  (Tbar = input 10, element 11; xmax = element 13)       *
 *   11) Lambda - vector (Tbar+1)                          *
 *   12) Iota - Matrix (Tbar+1,xmax)                       *
 *   13) Rho - Matrix (Tbar+1,xmax)                        *
 *   14) Initial Prices - Matrix (Tbar+1,xmax)             *
 *   15) Tildes - Matrix (Tbar+2,xmax)                     *
 *                                                         *
 * OUTPUT:                                                 *
 *    1) Bound 1 - struct (average, precision, variance)   *
 *    2) Bound 2 - struct (average, precision, variance)   *
 *    3) Bound 3 - struct (average, precision, variance)   *
 *    4) Industry statistics - struct (cs, ps, ts, inv)    *
 ***********************************************************/

// Reads the variable `var_name` from the MAT-file handle `mfp` and returns it;
// the definition is at the bottom of this file.
mxArray *readvar(MATFile *mfp, const char *var_name);

// Log stream shared by mexFunction and readvar. mexFunction opens it on the
// filename passed as the first MEX input and closes it before returning.
ofstream myfile;

void mexFunction (int nlhs,
  mxArray *plhs[],
  int nrhs,
  const mxArray *prhs[]) {

  if(nrhs!=15) {
    mexErrMsgTxt("You must provide 16 parameters");
  }
  char *file=new char[mxGetN(prhs[0])+1];
  mxGetString(prhs[0], file, mxGetN(prhs[0])+1);

  myfile.open(file);


  double *control = mxGetPr(prhs[1]);
  double *prof_fun = mxGetPr(prhs[2]);
  double *prof_fun_ipopt = mxGetPr(prhs[3]);
  double *transition = mxGetPr(prhs[4]);
  double *constants = mxGetPr(prhs[5]);
  double *init = mxGetPr(prhs[6]);
  double *convtol = mxGetPr(prhs[7]);
  double *bounds = mxGetPr(prhs[8]);
  double *nonstatparams = mxGetPr(prhs[9]);
  int xmax = int(nonstatparams[12]);
  int Tbar = int(nonstatparams[10]);

  MatrixMTL<double> *A;
  if (matGetVariable(mfp, "A") == NULL) {
    A = NULL;
  } else {
    mxArray *Amx = readvar(mfp, "A");
    double *Apr = mxGetPr(Amx);
    int shockn = mxGetN(Amx);

    A = new MatrixMTL<double>(shockn,shockn);
    int z=0;
    for(int i=0;i<shockn;++i) {
      for(int j=0;j<shockn;++j) {
        (*A)[j][i]=Apr[z];
        z++;
      }
    }
  }

/*  int size = mxGetN(prhs[10]);
  if(size!=xmax) {
    myfile << "Size of iota is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of iota");
  }
  size = mxGetN(prhs[11]);
  if(size!=xmax) {
    myfile << "Size of rho is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of rho");
  }
  size = mxGetN(prhs[12]);
  if(size!=xmax) {
    myfile << "Size of V is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of V");
  }
  size = mxGetN(prhs[13]);
  if(size!=xmax) {
    myfile << "Size of prices is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of prices");
  }

  size = mxGetN(prhs[14]);
  if(size!=xmax) {
    myfile << "Size of tildes is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of tildes");
  }*/

  VectorMTL<double> lambdaVector(Tbar+1);
  MatrixMTL<double> iotaMatrix(Tbar+1,xmax);
  MatrixMTL<double> rhoMatrix(Tbar+1,xmax);
  MatrixMTL<double> sMatrix(Tbar+2,xmax);
  MatrixMTL<double> pMatrix(Tbar+1,xmax);

  for(int i=0;i<Tbar+1;i++) {
    lambdaVector[i]=(mxGetPr(prhs[10]))[i];
  }
  int offset=0;
  for(int j=0;j<xmax;j++) {
    for(int i=0;i<Tbar+1;i++) {
      iotaMatrix[i][j]=(mxGetPr(prhs[11]))[offset+i];
      rhoMatrix[i][j]=(mxGetPr(prhs[12]))[offset+i];
      pMatrix[i][j]=(mxGetPr(prhs[13]))[offset+i];
    }
    offset+=Tbar+1;
  }
  offset=0;
  for(int j=0;j<xmax;j++) {
    for(int i=0;i<Tbar+2;i++) {
      sMatrix[i][j]=(mxGetPr(prhs[14]))[offset+i];
    }
    offset+=(Tbar+2);
  }
  CompOE_nonstat *oe = new CompOE_nonstat();
  oe->setOutput(&myfile);

  oe->initialize(control, prof_fun, prof_fun_ipopt, transition, constants, init, 
      convtol, nonstatparams, A);

  oe->computeBoundsNonstat(bounds, iotaMatrix, rhoMatrix, sMatrix, lambdaVector, pMatrix);

  double *output;

  switch(nlhs) {
    // Industry statistics //
    case(4): {
      const char *field_names[] = { "cs", "ps", "ts", "inv" };

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif
*/
      mwSize dims[1] = {1};

      plhs[3] = mxCreateStructArray(1, dims, 4, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,3,mxREAL);
      VectorMTL<double> cs(3,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,3,mxREAL);
      VectorMTL<double> ps(3,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,3,mxREAL);
      VectorMTL<double> ts(3,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,2,field_value);

      field_value = mxCreateDoubleMatrix(1,3,mxREAL);
      VectorMTL<double> inv(3,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,3,field_value);

      oe->getStats_nonstat(cs, ps, ts, inv);

      cs.setData(NULL);
      ps.setData(NULL);
      ts.setData(NULL);
      inv.setData(NULL);
    }
    // Bound 3 //
    case(3): {
      const char *field_names[] = {"average", "precision", "variance"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif
*/
      mwSize dims[1] = {1};

      plhs[2] = mxCreateStructArray(1, dims, 3, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,2,field_value);

      oe->getBound3_nonstat(average, precision, variance);

      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }
    // Bound 2 //
    case(2): {
      const char *field_names[] = {"average", "precision", "variance"};
/*
      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif
*/
      mwSize dims[1] = {1};

      plhs[1] = mxCreateStructArray(1, dims, 3, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,2,field_value);

      oe->getBound2_nonstat(average, precision, variance);

      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }

    case(1): {
      const char *field_names[] = {"average", "precision", "variance"};
/*
      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif
*/

      mwSize dims[1] = {1};

      plhs[0] = mxCreateStructArray(1, dims, 3, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[0],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[0],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[0],0,2,field_value);

      oe->getBound1_nonstat(average, precision, variance);

      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }
  }

  if(A!=NULL) delete A;

  delete oe;

  myfile.close();
  delete[] file;
}

// Reads the variable `var_name` from the MAT-file `mfp` and returns the
// resulting mxArray.
//
// If the variable cannot be read, a message is written to cout and to the
// global log stream `myfile`, and the process exits with status 1.
// NOTE(review): exit() inside a MEX file also terminates the hosting MATLAB
// session; kept as-is so existing behaviour does not change.
//
// When COUT_DEBUG and/or DEBUG are defined, the value (for numeric scalars) or
// the dimensions (for everything else) are echoed to cout and/or `myfile`.
mxArray *readvar(MATFile *mfp, const char *var_name) {
    mxArray *array_ptr = matGetVariable(mfp, var_name);

    if (array_ptr == NULL) {
        cout << "Cannot read matlab variable " << var_name << endl;
        myfile << "Cannot read matlab variable " << var_name << endl;
        exit(1);
    }
    if ((mxGetNumberOfElements(array_ptr) == 1) && mxIsNumeric(array_ptr)) {
      #ifdef COUT_DEBUG
        cout << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
      #ifdef DEBUG
        myfile << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
    } else {
      #if defined(COUT_DEBUG) || defined(DEBUG)
        // mxGetDimensions returns const mwSize*; indexing it with its own type
        // (instead of casting to int*) keeps the printed sizes correct when
        // mwSize is wider than int.
        const mwSize *dims = mxGetDimensions(array_ptr);
        const mwSize last = mxGetNumberOfDimensions(array_ptr) - 1;

        #ifdef COUT_DEBUG
          cout << "Reading matlab variable " << var_name << " (";
        #endif
        #ifdef DEBUG
          myfile << "Reading matlab variable " << var_name << " (";
        #endif
        // Every dimension except the last is followed by an "x" separator.
        for (mwSize d = 0; d < last; ++d) {
        #ifdef COUT_DEBUG
          cout << dims[d] << "x";
        #endif
        #ifdef DEBUG
          myfile << dims[d] << "x";
        #endif
        }
        #ifdef COUT_DEBUG
          cout << dims[last] << ")"<< endl;
        #endif
        #ifdef DEBUG
          myfile << dims[last] << ")"<< endl;
        #endif
      #endif
    }

    return array_ptr;
}
