/***************************************************************************
 *   Copyright (C) 2006 by Weintraub, Benkard, Van Roy and Jeziorski       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#define USE_MUMPS
#include <iostream>
#include <fstream>
#include <cstdlib>
#include <math.h>
#include <compoe/compoe.h>
#include <compoe/compoe_aggr.h>
#include "matrix.h"
#include <string.h>
#include "mat.h"
#include "setup.h"

#define NAME_LENGTH 30

// Fetches the variable `var_name` from the open MAT-file `mfp`.
// Terminates the process with exit(1) when the variable cannot be read.
mxArray *read(MATFile *mfp, const char *var_name);

// Log stream shared by main() and read(). main() opens it on the path stored
// in the MAT-file variable "filename".
ofstream myfile;

// Copies `in` into `out` with the two index orders swapped:
//   out[i*n + j] = in[j*m + i]   for 0 <= i < m, 0 <= j < n.
// Parameter names and order match the definition at the end of this file.
void transposeCopy(double *out, double *in, int n, int m);

int main(int argc, char *argv[]) {
  if(argc<3) {
    cout << "The executable needs two arguments: output and input .mat file names" << endl;
    exit(1);
  }
  char filename[NAME_LENGTH];
  MATFile *mfp;

  strncpy(filename, argv[1], NAME_LENGTH);

  if ((mfp = matOpen(filename, "r")) == NULL) {
    #ifdef COUT_DEBUG
      cout << "Cannot open MAT-file " << filename << endl;
    #endif
    exit(1);
  } else {
    #ifdef COUT_DEBUG
      cout << "Opening MAT-file " << filename << endl;
    #endif
  }

  mxArray *var = read(mfp, "filename");

  char *file=new char[mxGetN(var)+1];
  mxGetString(var, file, mxGetN(var)+1);

  myfile.open(file);

  mxArray *controlmx = read(mfp, "control");
  mxArray *prof_funmx = read(mfp, "prof_fun");
  mxArray *prof_fun_ipoptmx = read(mfp, "prof_fun_ipopt");
  mxArray *transitionmx = read(mfp, "transition");
  mxArray *constantsmx = read(mfp, "constants");
  mxArray *initmx = read(mfp, "init");
  mxArray *convtolmx = read(mfp, "convtol");
  mxArray *boundsmx = read(mfp, "bounds");
  mxArray *aggrparamsmx = read(mfp, "aggr");

  double *control = mxGetPr(controlmx);
  double *prof_fun = mxGetPr(prof_funmx);
  double *prof_fun_ipopt = mxGetPr(prof_fun_ipoptmx);
  double *transition = mxGetPr(transitionmx);
  double *constants = mxGetPr(constantsmx);
  double *init = mxGetPr(initmx);
  double *convtol = mxGetPr(convtolmx);
  double *bounds = mxGetPr(boundsmx);
  double *aggrparams = mxGetPr(aggrparamsmx);

  int xmax = int(aggrparams[9]);

  if(control[2]>0) {
    transition[0]=1e-16;
  }
 
  double lambda_oe = mxGetPr(read(mfp, "lambda"))[0];

  VectorMTL<double> iota_oe(xmax);
  var = read(mfp, "iota");
  int size = mxGetN(var);
  if(size!=xmax) {
    myfile << "Size of iota is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    cout << "Size of iota is " << size << ". Should be " << xmax << ". Exiting..." << endl; 
    exit(1);
  }
  for(int i=0;i<xmax;i++) {
    iota_oe[i]=(mxGetPr(var))[i];
  }

  VectorMTL<double> rho_oe(xmax);
  var = read(mfp, "rho");
  size = mxGetN(var);
  if(size!=xmax) {
    myfile << "Size of rho is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    cout << "Size of rho is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    exit(1);
  }
  for(int i=0;i<xmax;i++) {
    rho_oe[i]=(mxGetPr(var))[i];
  }

  VectorMTL<double> V_oe(xmax);
  var = read(mfp, "V");
  size = mxGetN(var);
  if(size!=xmax) {
    myfile << "Size of V is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    cout << "Size of V is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    exit(1);
  }
  for(int i=0;i<xmax;i++) {
    V_oe[i]=(mxGetPr(var))[i];
  }

  VectorMTL<double> prices_oe(xmax);
  var = read(mfp, "prices");
  size = mxGetN(var);
  if(size!=xmax) {
    myfile << "Size of prices is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    cout << "Size of prices is " << size << ". Should be " << xmax << ". Exiting..." << endl;
    exit(1);
  }
  for(int i=0;i<xmax;i++) {
    prices_oe[i]=(mxGetPr(var))[i];
  }

  myfile << "Aggregate shocks OE Solver" << endl;
  myfile << "----------------------------" << endl << endl;

  myfile << "\nProfit Function Parameters:" << endl;
  for(int ii=0;ii<mxGetM(prof_funmx);ii++) myfile << "  " << ii << ": " << prof_fun[ii] << endl;

  myfile << "\nProfit Function Parameters for Ipopt:" << endl;
  for(int ii=0;ii<mxGetM(prof_fun_ipoptmx);ii++) myfile << "  " << ii << ": " << prof_fun_ipopt[ii] << endl;
 
  myfile << "\n\nDynamic Parameters:\n\n";
  for(int ii=0;ii<mxGetM(constantsmx);ii++) myfile << "  " << ii << ": " << constants[ii] << endl;

  myfile << "\n\nTransition Parameters:\n\n";
  for(int ii=0;ii<mxGetM(transitionmx);ii++) myfile << "  " << ii << ": " << transition[ii] << endl;

  myfile << "\n\nAggregate Shocks Parameters:\n\n";
  for(int ii=0;ii<mxGetM(aggrparamsmx);ii++) myfile << "  " << ii << ": " << aggrparams[ii] << endl;

  if(control[2]>0) {
    myfile << "  Fixed Number of Firms (No entry/No exit): " << control[2] << endl;
  } else {
    if(control[1]==1) {
        myfile << "  Poisson Entry Process" << endl;
    } else {
        myfile << "  Constant Entry Process" << endl;
    }
  }

  myfile << "\nParameters:\n\n";
  myfile << "  xmax: " << xmax << endl;
  myfile << "  Entry Rate: " << lambda_oe << endl;

  CompOE_aggr *oe_aggr = new CompOE_aggr;
  oe_aggr->setOutput(&myfile);

  oe_aggr->initialize(control, prof_fun, prof_fun_ipopt, transition, constants, init, 
      convtol, aggrparams);

  var = read(mfp, "A");
  if(!mxIsSparse(var)) {
      cout << "Transition matrix must be sparse." << endl;
      exit(1);
  }

  /*#if ( __WORDSIZE == 32 )
    int *ir = mxGetIr(prhs[14]);
    int *jc = mxGetJc(prhs[14]);
  #elif ( __WORDSIZE == 64 )
    mwIndex *ir = mxGetIr(prhs[14]);
    mwIndex *jc = mxGetJc(prhs[14]);
  #else
    #error Compatible with gcc on 32/64 bit systems only
  #endif*/
  mwIndex *ir = mxGetIr(var);
  mwIndex *jc = mxGetJc(var);

  double *c =mxGetPr(var);
  int dim = mxGetN(var);
  int nz = jc[dim];

  SparseMatrixMTL<double> shockTran(dim,dim,nz);

  int row_i = 0;
  for(int i=0;i<nz;i++) {
    shockTran.columns[i]=ir[i];
    if(jc[row_i]==i) row_i++;
    shockTran.rows[i]=row_i-1;
    shockTran[i]=c[i];
  }

  if (matClose(mfp) != 0) {
    printf("Error closing file %s\n",filename);
    return(EXIT_FAILURE);
  }

  oe_aggr->compute_aggr(xmax,iota_oe,rho_oe,lambda_oe,shockTran);
  int shocks_states = oe_aggr->getShocksSize();

  MatrixMTL<double> *pricesStart;
  MatrixMTL<double> *tildesMatrix;

  pricesStart = oe_aggr->getPricesInitAggr();
  tildesMatrix = oe_aggr->getTildesAggr();

  strncpy(filename, argv[2], NAME_LENGTH);
  if ((mfp = matOpen(filename, "w")) == NULL) {
    #ifdef COUT_DEBUG
      cout << "Cannot open MAT-file " << filename << endl;
    #endif
    #ifdef DEBUG
      myfile << "Cannot open MAT-file " << filename << endl;
    #endif
    exit(1);
  } else {
    #ifdef COUT_DEBUG
      cout << "Opening MAT-file " << filename << endl;
    #endif
    #ifdef DEBUG
      myfile << "Opening MAT-file " << filename << endl;
    #endif
  }
  mxArray *output;
  double *outputpr;

  // s Matrix //
  output = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
  outputpr = mxGetPr(output);
  memcpy(outputpr,tildesMatrix->getData(),shocks_states*xmax*sizeof(double));
  matPutVariable(mfp, "sMatrix", output);

  // V Matrix
  output = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
  outputpr = mxGetPr(output);
  memcpy(outputpr,oe_aggr->getVAggr()->getData(),shocks_states*xmax*sizeof(double));
  matPutVariable(mfp, "VMatrix", output);
  
  //
  output = mxCreateDoubleMatrix(1,shocks_states,mxREAL);
  outputpr = mxGetPr(output);
  memcpy(outputpr,oe_aggr->getInvariant()->getData(),shocks_states*sizeof(double));
  matPutVariable(mfp, "Q", output);

  output = mxCreateDoubleMatrix(1,shocks_states,mxREAL);
  outputpr = mxGetPr(output);
  memcpy(outputpr,oe_aggr->getLambdaAggr()->getData(),shocks_states*sizeof(double));
  matPutVariable(mfp, "lambdaVector", output);
  
  #ifdef ACW
  mxArray *cell = mxCreateCellMatrix(shocks_states,1);
  for(int w=0;w<shocks_states;++w) {
    output = mxCreateDoubleMatrix(xmax,(int) constants[6],mxREAL);
    outputpr = mxGetPr(output);
    mxSetCell(cell,w,output);
    memcpy(outputpr,oe_aggr->getActionProfit(w)->getData(),xmax*constants[6]*sizeof(double));
  }
  matPutVariable(mfp, "action_profit", cell);

  MatrixMTL3D<double> *tranFringe = oe_aggr->getTranFringe();
  mwSize dims[1];
  dims[0] = tranFringe->getDeep();
  mxArray *output2;
  output2 = mxCreateCellArray(1, dims);
  for(int i=0;i<dims[0];++i) {
    output = mxCreateDoubleMatrix(xmax,xmax,mxREAL);
    outputpr = mxGetPr(output);
    transposeCopy(outputpr,(*tranFringe)[i].getData(),xmax,xmax);
    mxSetCell(output2, i, output);
  }
  matPutVariable(mfp, "tranFringe", output2);
  #endif

  if (matClose(mfp) != 0) {
    printf("Error closing file %s\n",filename);
    return(EXIT_FAILURE);
  }


  // Check if we need statistics
/*
  switch(nlhs) {
    case(9): {
      plhs[8] = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
      output = mxGetPr(plhs[8]);
      MatrixMTL<double> *temp = oe_aggr->getPricesInitAggr();
      memcpy(output,temp->data,shocks_states*xmax*sizeof(double));
    }

    case(8): {
      // Create the array
      const char *field_names[] = {"profit", "prices", "ms", "prodsur", "conssur", "totalsur", "c"};
      mwSize dims[1] = {1};

      plhs[7] = mxCreateStructArray(1, dims, 7, field_names);
      // Create the matrices and scalars
      mxArray *field_value;
      field_value = mxCreateDoubleMatrix(shocks_states,xmax,mxREAL); // profit
      MatrixMTL<double> profit(xmax,shocks_states,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[7],0,0,field_value);

      field_value = mxCreateDoubleMatrix(shocks_states,xmax,mxREAL); // prices
      MatrixMTL<double> prices(xmax,shocks_states,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[7],0,1,field_value);

      field_value = mxCreateDoubleMatrix(shocks_states,xmax,mxREAL); // ms
      MatrixMTL<double> ms(xmax,shocks_states,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[7],0,2,field_value);

      field_value = mxCreateDoubleMatrix(shocks_states,1,mxREAL); // prodsur
      VectorMTL<double> prodsur(shocks_states,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[7],0,3,field_value);

      field_value = mxCreateDoubleMatrix(shocks_states,1,mxREAL); // cosssur
      VectorMTL<double> conssur(shocks_states,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[7],0,4,field_value);

      field_value = mxCreateDoubleMatrix(shocks_states,1,mxREAL); // totalsur
      VectorMTL<double> totalsur(shocks_states,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[7],0,5,field_value);

      field_value = mxCreateDoubleMatrix(shocks_states,4,mxREAL); // c
      MatrixMTL<double> c(4,shocks_states,mxGetPr(field_value));
      mxSetFieldByNumber(plhs[7],0,6,field_value);

      VectorMTL<double> state_w(xmax);
      VectorMTL<double> profit_w(xmax);
      VectorMTL<double> prices_w(xmax);
      VectorMTL<double> pricesStart_w(xmax);
      VectorMTL<double> ms_w(xmax);
      double prodsur_w;
      double conssur_w;
      double totalsur_w;
      VectorMTL<double> c_w(4);

      VectorMTL<double> q(shocks_states);
      oe_aggr->getInvariant(q);
      // Compute stats
      for(int w=0;w<shocks_states;w++) {
        if(q[w]>0) {
          state_w = tildesMatrix->getRow(w);
          pricesStart_w = pricesStart->getRow(w);

          oe_aggr->compProfAndStats_aggr(state_w,pricesStart_w,prices_w,profit_w,ms_w,prodsur_w,conssur_w,w);
        } else {
          state_w = -1;
          prices_w = -1;
          ms_w = -1;
          prodsur_w = -1;
          conssur_w = -1;
        }

        profit.setColumn(w,profit_w);
        prices.setColumn(w,prices_w);
        ms.setColumn(w,ms_w);
        prodsur[w]=prodsur_w;
        conssur[w]=conssur_w;
        totalsur[w]=prodsur_w+conssur_w;        

        oe_aggr->concentrationRatios(ms_w, state_w, c_w[0], c_w[1], c_w[2], c_w[3]);
        c.setColumn(w,c_w);
      }

      // Reset the data pointers so that the will be no double free()
      profit.setData(NULL);
      prices.setData(NULL);
      ms.setData(NULL);
      prodsur.setData(NULL);
      conssur.setData(NULL);
      totalsur.setData(NULL);
      c.setData(NULL);
    }
    // Expected value function //
    case(7): {
      plhs[6] = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
      output = mxGetPr(plhs[6]);
      MatrixMTL<double> *temp = oe_aggr->getVExpectedAggr();
      memcpy(output,temp->data,shocks_states*xmax*sizeof(double));
    }
    // Profit Matrix //
    case(6): {
      plhs[5] = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
      output = mxGetPr(plhs[5]);
      MatrixMTL<double> *temp = oe_aggr->getProfitAggr();
      memcpy(output,temp->data,shocks_states*xmax*sizeof(double));
    }
    // s Matrix //
    case(5): {
      plhs[4] = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
      output = mxGetPr(plhs[4]);
      MatrixMTL<double> *temp = tildesMatrix;
      memcpy(output,temp->data,shocks_states*xmax*sizeof(double));
    }
    // Lambda Matrix //
    case(4): {
      plhs[3] = mxCreateDoubleMatrix(shocks_states,1,mxREAL);
      output = mxGetPr(plhs[3]);
      VectorMTL<double> *temp = oe_aggr->getLambdaAggr();
      memcpy(output,temp->data,shocks_states*sizeof(double));
    }
    // Rho Matrix //
    case(3): {
      plhs[2] = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
      output = mxGetPr(plhs[2]);
      MatrixMTL<double> *temp = oe_aggr->getRhoAggr(); 
      memcpy(output,temp->getData(),shocks_states*xmax*sizeof(double));
   }
    // Iota Matrix //
    case(2): {
      plhs[1] = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
      output = mxGetPr(plhs[1]);
      MatrixMTL<double> *temp = oe_aggr->getIotaAggr();
      memcpy(output,temp->getData(),shocks_states*xmax*sizeof(double));
    }
    // V Matrix //
    case(1): {
      plhs[0] = mxCreateDoubleMatrix(xmax,shocks_states,mxREAL);
      output = mxGetPr(plhs[0]);
      MatrixMTL<double> *temp = oe_aggr->getVAggr();
      memcpy(output,temp->getData(),shocks_states*xmax*sizeof(double));
    }

    default: {}
  }
  */
  delete oe_aggr;

  myfile.close();
  delete[] file;
}

// Fetches the variable `var_name` from the open MAT-file `mfp`.
//
// On failure a message is written to stdout and to the log stream and the
// process terminates with exit(1); the function therefore never returns NULL.
// When COUT_DEBUG and/or DEBUG are defined, a trace line is written as well:
// the value for a numeric scalar, otherwise the dimensions as "(d1xd2x...)".
mxArray *read(MATFile *mfp, const char *var_name) {
    mxArray *array_ptr;

    if ((array_ptr = matGetVariable(mfp, var_name)) == NULL) {
        cout << "Cannot read matlab variable " << var_name << endl;
        myfile << "Cannot read matlab variable " << var_name << endl;
        exit(1);
    }
    if ((mxGetNumberOfElements(array_ptr) == 1) && mxIsNumeric(array_ptr)) {
      #ifdef COUT_DEBUG
        cout << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
      #ifdef DEBUG
        myfile << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
    } else {
        #ifdef COUT_DEBUG
          cout << "Reading matlab variable " << var_name << " (";
        #endif
        #ifdef DEBUG
          myfile << "Reading matlab variable " << var_name << " (";
        #endif
        // The dimension array is indexed at its native element type. Viewing
        // it through an int pointer prints wrong values whenever that type is
        // wider than int (e.g. 64-bit builds).
        const int last_dim = (int) mxGetNumberOfDimensions(array_ptr) - 1;
        for (int d = 0; d < last_dim; d++) {
        #ifdef COUT_DEBUG
          cout << mxGetDimensions(array_ptr)[d] << "x";
        #endif
        #ifdef DEBUG
          myfile << mxGetDimensions(array_ptr)[d] << "x";
        #endif
        }
        #ifdef COUT_DEBUG
          cout << mxGetDimensions(array_ptr)[last_dim] << ")"<< endl;
        #endif
        #ifdef DEBUG
          myfile << mxGetDimensions(array_ptr)[last_dim] << ")"<< endl;
        #endif
    }

    return array_ptr;
}


// Copies `in` into `out` with the two index orders swapped.
//
// For every 0 <= col < m and 0 <= row < n:
//     out[col*n + row] = in[row*m + col]
// so `in` is read with stride m while `out` is filled front to back.
// Both buffers must hold at least n*m doubles; nothing is written when
// n <= 0 or m <= 0.
void transposeCopy(double *out, double *in, int n, int m) {
  for (int col = 0; col < m; ++col) {
    for (int row = 0; row < n; ++row) {
      out[col * n + row] = in[row * m + col];
    }
  }
}
