/***************************************************************************
 *   Copyright (C) 2006 by Weintraub, Benkard, Van Roy and Jeziorski       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <iostream>
#include <fstream>
#include <cstdlib>
#include <math.h>
#include <compoe/compoe.h>
#include <compoe/compoe_dom.h>
#include "matrix.h"
#include <string.h>
#include "mat.h"
#include<matrixMTL.h>

/*************************************************************
 * Computes aggregate shocks equilibrium bounds              *
 *                                                           *
 * INPUT:                                                    *
 *    1) Output filename - string                            *
 *    2) Control parameters - vector                         *
 *    3) Profit function parameters - vector                 *
 *    4) Profit function parameters for ipopt - vector       *
 *    5) State transition parameters - vector                *
 *    6) Model Constants - vector                            *
 *    7) Model Initial values - vector                       *
 *    8) Convergence tolerance values - vector               *
 *    9) Bounds parameters - vector                          *
 *   10) Dominant firms specific parameters - vector         *
 *  Aggregate shocks equilibrium                             *
 *   11) Lambda - vector (xmax_dom)                          *
 *   12) Iota - Matrix (xmax_dom,xmax)                       *
 *   13) Iota Dom - Vector (xmax_dom)                        *
 *   14) Rho - Matrix (xmax_dom,xmax)                        *
 *   15) Initial Prices - Matrix (xmax_dom,xmax)             *
 *   16) Tildes - Matrix (xmax_dom,xmax)                     *
 *   17) Stationary distr. of dom firms - Vector (xmax_dom)  *
 *                                                           *
 * OUTPUT:                                                   *
 *    1) Bound 1                                             *
 *    2) Bound 2                                             *
 *    3) Bound 3                                             *
 *************************************************************/

#define NAME_LENGTH 50

/* Loads the named variable from an open MAT-file; defined at the bottom of
   this file. On a missing variable it reports the error and exits. */
mxArray *readvar(MATFile *mfp, const char *var_name);

using namespace std;
using namespace MTL;


/* Text log stream. main() opens it on the name stored in the input file's
   "filename" variable and passes it to CompOE_dom::setOutput; readvar also
   writes its error (and DEBUG) messages here. */
ofstream myfile;

/*
 * Fills `out` from a column-major buffer (the layout returned by mxGetPr).
 *
 * `in` must hold at least out.getHeight() * out.getLength() doubles;
 * element (row, col) of `out` receives in[col * height + row].
 */
void copyMatrix(MatrixMTL<double> &out, double *in) {
  const int rows = out.getHeight();
  const int cols = out.getLength();

  double *src = in;  /* walks the buffer sequentially, one column at a time */
  for (int col = 0; col < cols; ++col) {
    for (int row = 0; row < rows; ++row, ++src) {
      out[row][col] = *src;
    }
  }
}

int main(int argc, char *argv[]) {
  if(argc<3) {
    cout << "The executable needs two arguments: output and input .mat file names" << endl;
    exit(1);
  }
  char filename[NAME_LENGTH];
  MATFile *mfp; 
  
  strncpy(filename, argv[1], NAME_LENGTH);
  
  if ((mfp = matOpen(filename, "r")) == NULL) {
    cout << "Cannot open MAT-file " << filename << endl;
    exit(1);
  } else {
    cout << "Opening MAT-file " << filename << endl;
  }

  mxArray *var = readvar(mfp, "filename");

  char *file=new char[mxGetN(var)+1];
  mxGetString(var, file, mxGetN(var)+1);

  myfile.open(file);

  double *control = mxGetPr(readvar(mfp, "control"));
  double *prof_fun = mxGetPr(readvar(mfp, "prof_fun"));
  double *prof_fun_ipopt = mxGetPr(readvar(mfp, "prof_fun_ipopt"));
  double *transition = mxGetPr(readvar(mfp, "transition"));
  double *constants = mxGetPr(readvar(mfp, "constants"));
  double *init = mxGetPr(readvar(mfp, "init"));
  double *convtol = mxGetPr(readvar(mfp, "convtol"));
  double *bounds = mxGetPr(readvar(mfp, "bounds"));
  double *dom = mxGetPr(readvar(mfp, "dom"));
  double *shocklevels = mxGetPr(readvar(mfp, "shocklevels"));
  int dom_firm_number = int(dom[10]);

  /* readvar the xmax and number of states */
  if(dom_firm_number>0) {
    var = readvar(mfp, "iotaMatrix");
  } else {
    var = readvar(mfp, "iota");
  }
  int xmax = int(mxGetN(var));
  int shocks = int(mxGetM(var));

  VectorMTL<double> s0_init(xmax);
  int w0_init;
  if (matGetVariable(mfp, "s0") == NULL) {
    s0_init[0]=-1;
  } else {
    double *s0_initpr = mxGetPr(readvar(mfp, "s0"));
    for(int i=0;i<xmax;i++) { 
      s0_init[i]=s0_initpr[i];
    }
    w0_init = mxGetPr(readvar(mfp, "w0"))[0];
  }

  
  /* copy */
  MatrixMTL<double> iotaMatrix(shocks,xmax);
  copyMatrix(iotaMatrix, mxGetPr(var));

  MatrixMTL<double> rhoMatrix(shocks,xmax);
  VectorMTL<double> Q(shocks);
  MatrixMTL<double> sMatrix(shocks,xmax);
  MatrixMTL<double> pricesInit(shocks,xmax);
  copyMatrix(pricesInit, mxGetPr(readvar(mfp, "pricesInit")));
  VectorMTL<double> lambdaVector(shocks);
  int others;
  if(dom_firm_number>0) {  
    /* copy the rest of matrices */
    copyMatrix(rhoMatrix, mxGetPr(readvar(mfp, "rhoMatrix")));
    copyMatrix(sMatrix, mxGetPr(readvar(mfp, "sMatrix")));
    copyMatrix(pricesInit, mxGetPr(readvar(mfp, "pricesInit")));
  
    /* Dominant firm strategy matrix, just size, copy later */
    var = readvar(mfp, "domIotaMatrix");
    others = int(mxGetM(var));
  
    /* Some vectors */
    double *temp = mxGetPr(readvar(mfp, "lambdaVector"));
    for(int i=0;i<shocks;i++) {
      lambdaVector[i]=temp[i];
    }
    temp = mxGetPr(readvar(mfp, "Q"));
    for(int i=0;i<shocks;i++) {
      Q[i]=temp[i];
    }
  } else {
    copyMatrix(rhoMatrix, mxGetPr(readvar(mfp, "rho")));
    copyMatrix(sMatrix, mxGetPr(readvar(mfp, "tildes")));
    lambdaVector[0]=mxGetPr(readvar(mfp, "lambda"))[0];
    others=1;
  }
  MatrixMTL<double> iotaMatrixDom(others,xmax);
  if(dom_firm_number>0) {
    var = readvar(mfp, "domIotaMatrix");
    copyMatrix(iotaMatrixDom, mxGetPr(var));
  }

  var = readvar(mfp, "A");
  if(!mxIsSparse(var)) {
      cout << "Transition matrix must be sparse." << endl;
      exit(1);
  }
  
  mwIndex *ir = mxGetIr(var);
  mwIndex *jc = mxGetJc(var);
  
  double *c =mxGetPr(var);
  int dim = mxGetN(var);
  int nz = jc[dim];
    
  SparseMatrixMTL<double> shockTran(dim,dim,nz);
  
  int row_i = 0;
  for(int i=0;i<nz;i++) {
    shockTran.columns[i]=ir[i];
    if(jc[row_i]==i) row_i++;
    shockTran.rows[i]=row_i-1;
    shockTran[i]=c[i];
  }

  CompOE_dom *oe = new CompOE_dom;
  oe->setOutput(&myfile);
    
  oe->initialize(control, prof_fun, prof_fun_ipopt, transition, constants, init, 
      convtol, dom, shocklevels);

  oe->computeStats_dom(bounds, iotaMatrix, iotaMatrixDom, rhoMatrix, sMatrix, lambdaVector, pricesInit, Q, shockTran, s0_init, w0_init);
/*
  double *output;

  switch(nlhs) {
    case(1): {
      plhs[0] = mxCreateDoubleMatrix(8,3,mxREAL);
      double *output = mxGetPr(plhs[0]);

      MatrixMTL<double> output_init(8,3);

      oe->getStats(output_init);

      int k=0;
      for(int i=0;i<3;i++) {
        for(int j=0;j<8;j++) {
          output[k]=output_init[j][i];
          k++;
        }
      }
    }
  }

  delete oe;*/
  
  strncpy(filename, argv[2], NAME_LENGTH);
  if ((mfp = matOpen(filename, "w")) == NULL) {
    cout << "Cannot open MAT-file " << filename << endl;
    exit(1);
  } else {
    cout << "Opening MAT-file " << filename << endl;
  }
  mxArray *output;
  double *outputpr;
  
  const char *field_names[] = {"cs", "ps", "ts", "c1", "c2", "c4", "c10", "c20", "hhi", "entrants", "investment", "firms_number"};
  mwSize dims[1] = {1};
  mxArray *statsOE = mxCreateStructArray(1, dims, 12, field_names);

  const char *subfield_names[] = {"avr", "pre", "var"};
  mxArray *statsOEsub;
  mxArray *field_value;
  double *field_value_pr;

  MatrixMTL<double> output_init(11,3);
  oe->getStats(output_init);

  /* Fill in the one-dimentional fields */
  for(int i=0;i<11;i++) {
    /* Create a subarray */
    statsOEsub = mxCreateStructArray(1, dims, 3, subfield_names);
    /* Plant a subarray */
    mxSetFieldByNumber(statsOE,0,i,statsOEsub);
    /* Fill in subarray */
    for(int j=0;j<3;j++) {
       /* Create value */
       field_value=mxCreateDoubleMatrix(1,1,mxREAL);
       field_value_pr=mxGetPr(field_value);
       /* Fill in value */
       field_value_pr[0]=output_init[i][j];
       /* Plant value */
       mxSetFieldByNumber(statsOEsub,0,j,field_value);
    }
  }
  cout << "done\n";
  
  /* Fill in the vector */

  /* Expected number of firms */
  VectorMTL<double> firmsnumberavg(xmax);
  VectorMTL<double> firmsnumberpre(xmax);
  VectorMTL<double> firmsnumbervar(xmax);
  
  oe->getFirmNumber(firmsnumberavg,firmsnumberpre,firmsnumbervar);

  cout << firmsnumberavg;
  statsOEsub = mxCreateStructArray(1, dims, 3, subfield_names);
  mxSetFieldByNumber(statsOE,0,11,statsOEsub);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=firmsnumberavg[i];
  }
  mxSetFieldByNumber(statsOEsub,0,0,field_value);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=firmsnumberpre[i];
  }
  mxSetFieldByNumber(statsOEsub,0,1,field_value);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=firmsnumbervar[i];
  }
  mxSetFieldByNumber(statsOEsub,0,2,field_value);

  matPutVariable(mfp, "statsPOE", statsOE);

  int discounted=0;
  if(s0_init[0]>-1) {
    if(dom[21]>0) {
      discounted=1;
    }
  }


  if(discounted==1) {
    int M = int(dom[12]);

    statsOE = mxCreateStructArray(1, dims, 12, field_names);

    MatrixMTL<double> output_init_path(11*M,3);

    oe->getStats_path(output_init_path, M);

    int z=0;
    for(int i=0;i<11;i++) {
      statsOEsub = mxCreateStructArray(1, dims, 3, subfield_names);
      mxSetFieldByNumber(statsOE,0,i,statsOEsub);

      /* avg */
      mxArray *field_value1=mxCreateDoubleMatrix(1,M,mxREAL);
      double *field_value_pr1=mxGetPr(field_value1);

      /* var */
      mxArray *field_value2=mxCreateDoubleMatrix(1,M,mxREAL);
      double *field_value_pr2=mxGetPr(field_value2);

      /* pre */
      mxArray *field_value3=mxCreateDoubleMatrix(1,M,mxREAL);
      double *field_value_pr3=mxGetPr(field_value3);

      for(int k=0;k<M;++k,++z) {
        field_value_pr1[k]=output_init_path[z][0];
        field_value_pr2[k]=output_init_path[z][1];
        field_value_pr3[k]=output_init_path[z][2];
      }

      mxSetFieldByNumber(statsOEsub,0,0,field_value1);
      mxSetFieldByNumber(statsOEsub,0,1,field_value2);
      mxSetFieldByNumber(statsOEsub,0,2,field_value3);
    }


    matPutVariable(mfp, "statsPOE_path", statsOE);
  }  
  cout << "done2\n";

  /* Fringe state matrix */
  MatrixMTL<double> *fringe = oe->getStatsFringe();

  output = mxCreateDoubleMatrix(shocks,xmax,mxREAL);
  outputpr = mxGetPr(output);
  int k=0;
  for(int i=0;i<xmax;i++) {
    for(int j=0;j<shocks;j++) {
      outputpr[k]=(*fringe)[j][i];
      k++;
    }
  }
  matPutVariable(mfp, "fringeState", output);
  
  if (matClose(mfp) != 0) {
    printf("Error closing file %s\n",file);
    return(EXIT_FAILURE);
  }
  
  myfile.close();
  delete[] file;

  printf("Done\n");
  return(EXIT_SUCCESS);
}

/*
 * Loads variable `var_name` from the open MAT-file `mfp`.
 *
 * If the variable is missing, the error is reported on cout and on the log
 * stream `myfile`, and the program exits with status 1.
 *
 * When COUT_DEBUG or DEBUG is defined, the variable is echoed to cout or
 * myfile respectively. Numeric scalars are echoed by value; anything else is
 * echoed by its dimensions, e.g. "(3x4)".
 *
 * Returns the array produced by matGetVariable. That is a copy owned by the
 * caller.
 */
mxArray *readvar(MATFile *mfp, const char *var_name) {
    mxArray *array_ptr;

    if ((array_ptr = matGetVariable(mfp, var_name)) == NULL) {
        cout << "Cannot read matlab variable " << var_name << endl;
        myfile << "Cannot read matlab variable " << var_name << endl;
        exit(1);
    }
    if ((mxGetNumberOfElements(array_ptr) == 1) && mxIsNumeric(array_ptr)) {
      #ifdef COUT_DEBUG
        cout << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
      #ifdef DEBUG
        myfile << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
    } else {
#if defined(COUT_DEBUG) || defined(DEBUG)
        /* mxGetDimensions returns mwSize entries, which are size_t in 64-bit
           builds. They must not be reinterpreted as int. */
        const mwSize *dims = mxGetDimensions(array_ptr);
        const mwSize ndims = mxGetNumberOfDimensions(array_ptr);
        #ifdef COUT_DEBUG
          cout << "Reading matlab variable " << var_name << " (";
        #endif
        #ifdef DEBUG
          myfile << "Reading matlab variable " << var_name << " (";
        #endif
        for (mwSize d = 0; d < ndims; d++) {
          /* "x" between dimensions, ")" after the last one */
          const char *sep = (d + 1 < ndims) ? "x" : ")";
          #ifdef COUT_DEBUG
            cout << dims[d] << sep;
          #endif
          #ifdef DEBUG
            myfile << dims[d] << sep;
          #endif
        }
        #ifdef COUT_DEBUG
          cout << endl;
        #endif
        #ifdef DEBUG
          myfile << endl;
        #endif
#endif
    }

    return array_ptr;
}
