/***************************************************************************
 *   Copyright (C) 2006 by Weintraub, Benkard, Van Roy and Jeziorski       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <iostream>
#include <fstream>
#include <cstdlib>
#include <math.h>
#include <compoe/compoe.h>
#include <compoe/compoe_nonstat.h>
#include <mex.h>
#include <string.h>

/***********************************************************
 * Computes bounds for a given oblivious equilibrium (oe)  *
 *                                                         *
 * INPUT:                                                  *
 *    1) Output filename - string                          *
 *    2) Control parameters - vector                       *
 *    3) Profit function parameters - vector               *
 *    4) Profit function parameters for Ipopt - vector     *
 *    5) State transition parameters - vector              *
 *    6) Model Constants - vector                          *
 *    7) Model Initial values - vector                     *
 *    8) Convergence tolerance values - vector             *
 *    9) Bounds parameters - vector                        *
 *  Oblivious equilibrium (vectors must be of size xmax):  *
 *   10) xmax - int                                        *
 *   11) Lambda - scalar                                   *
 *   12) Iota - vector                                     *
 *   13) Rho - vector                                      *
 *   14) Value function - vector                           *
 *   15) Initial Prices - vector                           *
 *   16) Tildes - vector                                   *
 *                                                         *
 * OUTPUT:                                                 *
 *    1) Bound 1                                           *
 *    2) Bound 2                                           *
 *    3) Bound 3                                           *
 *    4) Bound 4                                           *
 *    5) Bound 5                                           *
 *    6) Industry stats - 8x3 matrix                       *
 ***********************************************************/

void mexFunction (int nlhs,
  mxArray *plhs[],
  int nrhs,
  const mxArray *prhs[]) {

  if(nrhs!=16) {
    mexErrMsgTxt("You must provide 16 parameters");
  }
  char *file=new char[mxGetN(prhs[0])+1];
  mxGetString(prhs[0], file, mxGetN(prhs[0])+1);

  ofstream myfile;
  myfile.open(file);


  double *control = mxGetPr(prhs[1]);
  double *prof_fun = mxGetPr(prhs[2]);
  double *prof_fun_ipopt = mxGetPr(prhs[3]);
  double *transition = mxGetPr(prhs[4]);
  double *constants = mxGetPr(prhs[5]);
  double *init = mxGetPr(prhs[6]);
  double *convtol = mxGetPr(prhs[7]);
  double *bounds = mxGetPr(prhs[8]);

  int xmax = int(mxGetPr(prhs[9])[0]);

  double lambda_oe = mxGetPr(prhs[10])[0];

  int size = mxGetN(prhs[11]);
  if(size!=xmax) {
    myfile << "Size of iota is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of iota");
  }
  size = mxGetN(prhs[12]);
  if(size!=xmax) {
    myfile << "Size of rho is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of rho");
  }
  size = mxGetN(prhs[13]);
  if(size!=xmax) {
    myfile << "Size of V is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of V");
  }
  size = mxGetN(prhs[14]);
  if(size!=xmax) {
    myfile << "Size of prices is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of prices");
  }

  size = mxGetN(prhs[15]);
  if(size!=xmax) {
    myfile << "Size of tildes is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    mexErrMsgTxt("Wrong size of tildes");
  }

  VectorMTL<double> iota_oe(xmax);
  VectorMTL<double> rho_oe(xmax);
  VectorMTL<double> V_oe(xmax);
  VectorMTL<double> tildes(xmax);
  VectorMTL<double> prices_end(xmax);

  for(int i=0;i<xmax;i++) {
    iota_oe[i]=(mxGetPr(prhs[11]))[i];
    rho_oe[i]=(mxGetPr(prhs[12]))[i];
    V_oe[i]=(mxGetPr(prhs[13]))[i];
    prices_end[i]=(mxGetPr(prhs[14]))[i];
    tildes[i]=(mxGetPr(prhs[15]))[i];
  }

  CompOE *oe = new CompOE();
  oe->setOutput(&myfile);

  oe->initialize(control, prof_fun, prof_fun_ipopt, transition, constants, init, convtol);

  oe->computeBounds(bounds, iota_oe, rho_oe, V_oe, tildes, lambda_oe, prices_end);

  switch(nlhs) {
    // Industry stats //
    case(6): {
      plhs[5] = mxCreateDoubleMatrix(8,3,mxREAL);
      double *output = mxGetPr(plhs[5]);

      MatrixMTL<double> output_init(8,3);

      oe->getStats(output_init);

      int k=0;
      for(int i=0;i<3;i++) {
        for(int j=0;j<8;j++) {
          output[k]=output_init[j][i];
          k++;
        }
      }
    }

    // Bound 5 //
    case(5): {
      const char *field_names[] = {"average", "precision", "variance", "relative"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif*/
      mwSize dims[1] = {1};

      plhs[4] = mxCreateStructArray(1, dims, 4, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[4],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[4],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[4],0,2,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> relative(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[4],0,3,field_value);

      oe->getBound5(average, precision, variance, relative);

      relative.setData(NULL);
      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }
    case(4): {
      const char *field_names[] = {"average", "precision", "variance", "relative"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif*/
      mwSize dims[1] = {1};

      plhs[3] = mxCreateStructArray(1, dims, 4, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,2,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> relative(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[3],0,3,field_value);

      oe->getBound4(average, precision, variance, relative);

      relative.setData(NULL);
      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }

    // Bound 3 //
    case(3): {
      const char *field_names[] = {"average", "precision", "variance", "relative"};
      
/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif*/
      mwSize dims[1] = {1};

      plhs[2] = mxCreateStructArray(1, dims, 4, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,2,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> relative(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[2],0,3,field_value);

      oe->getBound3(average, precision, variance, relative);

      relative.setData(NULL);
      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }

    // Bound 2 //
    case(2): {
      const char *field_names[] = {"average", "precision", "variance", "relative"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif*/
      mwSize dims[1] = {1};

      plhs[1] = mxCreateStructArray(1, dims, 4, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> average(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> precision(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> variance(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,2,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> relative(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[1],0,3,field_value);

      oe->getBound2(average, precision, variance, relative);

      relative.setData(NULL);
      average.setData(NULL);
      precision.setData(NULL);
      variance.setData(NULL);
    }

    // Bound 1 //
    case(1): {
      const char *field_names[] = {"average", "precision", "variance", "relative"};

/*      #if ( __WORDSIZE == 32 )
        int dims[1] = {1};
      #elif ( __WORDSIZE == 64 )
        mwSize dims[1] = {1};
      #else
        #error Compatible with gcc on 32/64 bit systems only
      #endif*/
      mwSize dims[1] = {1};

      plhs[0] = mxCreateStructArray(1, dims, 4, field_names);
      mxArray *field_value;

      field_value = mxCreateDoubleMatrix(1,1,mxREAL); 
      double *average = mxGetPr(field_value);
      mxSetFieldByNumber(plhs[0],0,0,field_value);

      field_value = mxCreateDoubleMatrix(1,1,mxREAL); 
      double *precision = mxGetPr(field_value);
      mxSetFieldByNumber(plhs[0],0,1,field_value);

      field_value = mxCreateDoubleMatrix(1,1,mxREAL); 
      double *variance = mxGetPr(field_value);
      mxSetFieldByNumber(plhs[0],0,2,field_value);

      field_value = mxCreateDoubleMatrix(1,xmax,mxREAL); 
      VectorMTL<double> relative(xmax,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[0],0,3,field_value);

      oe->getBound1(*average, *precision, *variance, relative);

      relative.setData(NULL);
    }
  }

  delete oe;

  myfile.close();
  delete[] file;
}
