/***************************************************************************
 *   Copyright (C) 2006 by Weintraub, Benkard, Van Roy and Jeziorski       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <iostream>
#include <fstream>
#include <cstdlib>
#include <math.h>
#include <compoe/compoe.h>
#include <compoe/compoe_aggr.h>
#include <mex.h>
#include "matrix.h"
#include <string.h>

/***********************************************************
 * Computes aggregate shocks equilibrium bounds            *
 *                                                         *
 * INPUT:                                                  *
 *    1) Output filename - string                          *
 *    2) Control parameters - vector                       *
 *    3) Profit function parameters - vector               *
 *    4) Profit function parameters for ipopt - vector     *
 *    5) State transition parameters - vector              *
 *    6) Model Constants - vector                          *
 *    7) Model Initial values - vector                     *
 *    8) Convergence tolerance values - vector             *
 *    9) Bounds parameters - vector                        *
 *   10) Aggregate shocks specific parameters - vector     *
 *       (element 10 is xmax)                              *
 *  Aggregate shocks equilibrium inputs:                   *
 *   11) Lambda - vector (shocks)                          *
 *   12) Iota - Matrix (shocks,xmax)                       *
 *   13) Rho - Matrix (shocks,xmax)                        *
 *   14) Initial Prices - Matrix (shocks,xmax)             *
 *   15) Tildes - Matrix (shocks,xmax)                     *
 *   16) Transition Matrix - Sparse matrix                 *
 *                                                         *
 * OUTPUT (each a 1x1 struct with fields average,          *
 *         precision, variance; each field is 1 x xmax):   *
 *    1) Bound 1                                           *
 *    2) Bound 2                                           *
 *    3) Bound 3                                           *
 ***********************************************************/

void mexFunction (int nlhs,
  mxArray *plhs[],
  int nrhs,
  const mxArray *prhs[]) {

  if(nrhs!=16) {
    mexErrMsgTxt("You must provide 15 parameters");
  }
  char *file=new char[mxGetN(prhs[0])+1];
  mxGetString(prhs[0], file, mxGetN(prhs[0])+1);

  ofstream myfile;
  myfile.open(file);


  double *control = mxGetPr(prhs[1]);
  double *prof_fun = mxGetPr(prhs[2]);
  double *prof_fun_ipopt = mxGetPr(prhs[3]);
  double *transition = mxGetPr(prhs[4]);
  double *constants = mxGetPr(prhs[5]);
  double *init = mxGetPr(prhs[6]);
  double *convtol = mxGetPr(prhs[7]);
  double *bounds = mxGetPr(prhs[8]);
  double *aggrparams = mxGetPr(prhs[9]);
  int xmax = int(aggrparams[9]);
  int shocks = int(mxGetN(prhs[15]));

/*  int size = mxGetN(prhs[10]);
  if(size!=xmax) {
    myfile << "Size of iota is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of iota");
  }
  size = mxGetN(prhs[11]);
  if(size!=xmax) {
    myfile << "Size of rho is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of rho");
  }
  size = mxGetN(prhs[12]);
  if(size!=xmax) {
    myfile << "Size of V is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of V");
  }
  size = mxGetN(prhs[13]);
  if(size!=xmax) {
    myfile << "Size of prices is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of prices");
  }

  size = mxGetN(prhs[14]);
  if(size!=xmax) {
    myfile << "Size of tildes is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of tildes");
  }*/

  VectorMTL<double> lambdaVector(shocks);
  MatrixMTL<double> iotaMatrix(shocks,xmax);
  MatrixMTL<double> rhoMatrix(shocks,xmax);
  MatrixMTL<double> sMatrix(shocks,xmax);
  MatrixMTL<double> pMatrix(shocks,xmax);

  for(int i=0;i<shocks;i++) {
    lambdaVector[i]=(mxGetPr(prhs[10]))[i];
  }
  int offset=0;
  for(int j=0;j<xmax;j++) {
    for(int i=0;i<shocks;i++) {
      iotaMatrix[i][j]=(mxGetPr(prhs[11]))[offset+i];
      rhoMatrix[i][j]=(mxGetPr(prhs[12]))[offset+i];
      pMatrix[i][j]=(mxGetPr(prhs[13]))[offset+i];
      sMatrix[i][j]=(mxGetPr(prhs[14]))[offset+i];
    }
    offset+=shocks;
  }
/*  #if ( __WORDSIZE == 32 )
    int *ir = mxGetIr(prhs[15]);
    int *jc = mxGetJc(prhs[15]);
  #elif ( __WORDSIZE == 64 )
    mwIndex *ir = mxGetIr(prhs[15]);
    mwIndex *jc = mxGetJc(prhs[15]);
  #else
    #error Compatible with gcc on 32/64 bit systems only
  #endif*/
  
  mwIndex *ir = mxGetIr(prhs[15]);
  mwIndex *jc = mxGetJc(prhs[15]);

  double *c =mxGetPr(prhs[15]);
  int dim = mxGetN(prhs[15]);
  int nz = jc[dim];

  SparseMatrixMTL<double> shockTran(dim,dim,nz);

  int row_i = 0;
  for(int i=0;i<nz;i++) {
    shockTran.columns[i]=ir[i];
    if(jc[row_i]==i) row_i++;
    shockTran.rows[i]=row_i-1;
    shockTran[i]=c[i];
  }


  CompOE_aggr *oe = new CompOE_aggr();
  oe->setOutput(&myfile);

  oe->initialize(control, prof_fun, prof_fun_ipopt, transition, constants, init, 
      convtol, aggrparams);


  oe->computeBounds_aggr(bounds, iotaMatrix, rhoMatrix, sMatrix, lambdaVector, pMatrix, shockTran);

  double *output;

  switch(nlhs) {
    case(3): {
      const char *field_names[] = {"average", "precision", "variance"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif*/
     
      mwSize dims[1] = {1};
 
      plhs[2] = mxCreateStructArray(1, dims, 3, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,2,field_value);

      oe->getBound3_aggr(average, precision, variance);

      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }

    case(2): {
      const char *field_names[] = {"average", "precision", "variance"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif*/

      mwSize dims[1] = {1};

      plhs[1] = mxCreateStructArray(1, dims, 3, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,2,field_value);

      oe->getBound2_aggr(average, precision, variance);

      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }

    case(1): {
      const char *field_names[] = {"average", "precision", "variance"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif
*/
      mwSize dims[1] = {1};

      plhs[0] = mxCreateStructArray(1, dims, 3, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[0],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[0],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[0],0,2,field_value);

      oe->getBound1_aggr(average, precision, variance);

      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }
  }

  delete oe;

  myfile.close();
  delete[] file;
}
