/***************************************************************************
 *   Copyright (C) 2006 by Weintraub, Benkard, Van Roy and Jeziorski       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <iostream>
#include <fstream>
#include <cstdlib>
#include <math.h>
#include <compoe/compoe.h>
#include <compoe/compoe_dom.h>
#include "matrix.h"
#include <string.h>
#include "mat.h"
#include<matrixMTL.h>

/*************************************************************
 * Computes aggregate shocks equilibrium bounds              *
 *                                                           *
 * INPUT:                                                    *
 *    1) Output filename - string                            *
 *    2) Control parameters - vector                         *
 *    3) Profit function parameters - vector                 *
 *    4) Profit function parameters for ipopt - vector       *
 *    5) State transition parameters - vector                *
 *    6) Model Constants - vector                            *
 *    7) Model Initial values - vector                       *
 *    8) Convergence tolerance values - vector               *
 *    9) Bounds parameters - vector                          *
 *   10) Dominant firms specific parameters - vector         *
 *  Aggregate shocks equilibrium                             *
 *   11) Lambda - vector (xmax_dom)                          *
 *   12) Iota - Matrix (xmax_dom,xmax)                       *
 *   13) Iota Dom - Vector (xmax_dom)                        *
 *   14) Rho - Matrix (xmax_dom,xmax)                        *
 *   15) Initial Prices - Matrix (xmax_dom,xmax)             *
 *   16) Tildes - Matrix (xmax_dom,xmax)                     *
 *   17) Stationary distr. of dom firms - Vector (xmax_dom)  *
 *                                                           *
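 * OUTPUT (written to the second .mat file as "boundsPOE"):  *
 *    1) Bound 1 - not written by this program               *
 *    2) Bound 2 - field bound2 (struct: avr, pre, var)      *
 *    3) Bound 3 - field bound3 (struct: avr, pre, var)      *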
 *************************************************************/

/* Maximum MAT-file name length for fixed-size name buffers. */
#define NAME_LENGTH 30

/* Reads the named variable from an open MAT-file (defined at the end of this
 * file); terminates the program if the variable cannot be read. */
mxArray *readvar(MATFile *mfp, const char *var_name);

using namespace std;
using namespace MTL;


/* Text log shared by main() and readvar(): main() opens it on the file name
 * stored in the input MAT-file, and readvar() reports read failures (and,
 * when DEBUG is defined, every variable it reads) to it. */
ofstream myfile;

/*
 * Fills `out` from a flat column-major buffer.
 *
 * The shape is taken from `out` itself (getHeight() rows, getLength()
 * columns); element (row, col) is read from in[col * rows + row], i.e. the
 * buffer is consumed one full column at a time. `in` must hold at least
 * rows * cols values.
 */
void copyMatrix(MatrixMTL<double> &out, double *in) {
  const int rows = out.getHeight();
  const int cols = out.getLength();

  // Walk the source linearly: consecutive values fill a column top-to-bottom,
  // then move on to the next column.
  const double *src = in;
  for (int col = 0; col < cols; ++col) {
    for (int row = 0; row < rows; ++row, ++src) {
      out[row][col] = *src;
    }
  }
}

int main(int argc, char *argv[]) {
  if(argc<3) {
    cout << "The executable needs two arguments: output and input .mat file names" << endl;
    exit(1);
  }
  char filename[NAME_LENGTH];
  MATFile *mfp; 
  
  strncpy(filename, argv[1], NAME_LENGTH);
  
  if ((mfp = matOpen(filename, "r")) == NULL) {
    cout << "Cannot open MAT-file " << filename << endl;
    exit(1);
  } else {
    cout << "Opening MAT-file " << filename << endl;
  }

  mxArray *var = readvar(mfp, "filename");

  char *file=new char[mxGetN(var)+1];
  mxGetString(var, file, mxGetN(var)+1);

  myfile.open(file);

  double *control = mxGetPr(readvar(mfp, "control"));
  double *prof_fun = mxGetPr(readvar(mfp, "prof_fun"));
  double *prof_fun_ipopt = mxGetPr(readvar(mfp, "prof_fun_ipopt"));
  double *transition = mxGetPr(readvar(mfp, "transition"));
  double *constants = mxGetPr(readvar(mfp, "constants"));
  double *init = mxGetPr(readvar(mfp, "init"));
  double *convtol = mxGetPr(readvar(mfp, "convtol"));
  double *bounds = mxGetPr(readvar(mfp, "bounds"));
  double *dom = mxGetPr(readvar(mfp, "dom"));
  double *shocklevels = mxGetPr(readvar(mfp, "shocklevels"));
  int dom_firm_number = int(dom[10]);

  /* readvar the xmax and number of states */
  if(dom_firm_number>0) {
    var = readvar(mfp, "iotaMatrix");
  } else {
    var = readvar(mfp, "iota");
  }
  int xmax = int(mxGetN(var));
  int shocks = int(mxGetM(var));

  
  /* copy */
  MatrixMTL<double> iotaMatrix(shocks,xmax);
  copyMatrix(iotaMatrix, mxGetPr(var));

  MatrixMTL<double> rhoMatrix(shocks,xmax);
  VectorMTL<double> Q(shocks);
  MatrixMTL<double> sMatrix(shocks,xmax);
  MatrixMTL<double> pricesInit(shocks,xmax);
  copyMatrix(pricesInit, mxGetPr(readvar(mfp, "pricesInit")));
  VectorMTL<double> lambdaVector(shocks);
  int others;
  if(dom_firm_number>0) {  
    /* copy the rest of matrices */
    copyMatrix(rhoMatrix, mxGetPr(readvar(mfp, "rhoMatrix")));
    copyMatrix(sMatrix, mxGetPr(readvar(mfp, "sMatrix")));
    copyMatrix(pricesInit, mxGetPr(readvar(mfp, "pricesInit")));
  
    /* Dominant firm strategy matrix, just size, copy later */
    var = readvar(mfp, "domIotaMatrix");
    others = int(mxGetM(var));
  
    /* Some vectors */
    double *temp = mxGetPr(readvar(mfp, "lambdaVector"));
    for(int i=0;i<shocks;i++) {
      lambdaVector[i]=temp[i];
    }
    temp = mxGetPr(readvar(mfp, "Q"));
    for(int i=0;i<shocks;i++) {
      Q[i]=temp[i];
    }
  } else {
    copyMatrix(rhoMatrix, mxGetPr(readvar(mfp, "rho")));
    copyMatrix(sMatrix, mxGetPr(readvar(mfp, "tildes")));
    lambdaVector[0]=mxGetPr(readvar(mfp, "lambda"))[0];
    others=1;
  }
  
  MatrixMTL<double> iotaMatrixDom(others,xmax);
  if(dom_firm_number>0) {
    var = readvar(mfp, "domIotaMatrix");
    copyMatrix(iotaMatrixDom, mxGetPr(var));
  }

  var = readvar(mfp, "A");
  if(!mxIsSparse(var)) {
      cout << "Transition matrix must be sparse." << endl;
      exit(1);
  }

  mwIndex *ir = mxGetIr(var);
  mwIndex *jc = mxGetJc(var);

  double *c =mxGetPr(var);
  int dim = mxGetN(var);
  int nz = jc[dim];

  SparseMatrixMTL<double> shockTran(dim,dim,nz);

  int row_i = 0;
  for(int i=0;i<nz;i++) {
    shockTran.columns[i]=ir[i];
    if(jc[row_i]==i) row_i++;
    shockTran.rows[i]=row_i-1;
    shockTran[i]=c[i];
  }

  CompOE_dom *oe = new CompOE_dom;
  oe->setOutput(&myfile);
    
  oe->initialize(control, prof_fun, prof_fun_ipopt, transition, constants, init, 
      convtol, dom, shocklevels);

  oe->computeBounds_dom(bounds, iotaMatrix, iotaMatrixDom, rhoMatrix, sMatrix, lambdaVector, pricesInit, Q, shockTran);
  
  strncpy(filename, argv[2], NAME_LENGTH);
  if ((mfp = matOpen(filename, "w")) == NULL) {
    cout << "Cannot open MAT-file " << filename << endl;
    exit(1);
  } else {
    cout << "Opening MAT-file " << filename << endl;
  }
  mxArray *output;
  double *outputpr;
  
  const char *field_names[] = {"bound2", "bound3"};
  mwSize dims[1] = {1};
  mxArray *statsOE = mxCreateStructArray(1, dims, 2, field_names);

  const char *subfield_names[] = {"avr", "pre", "var"};
  mxArray *statsOEsub;
  mxArray *field_value;
  double *field_value_pr;

  /* Fill in the vector */
  VectorMTL<double> average;
  VectorMTL<double> precision;
  VectorMTL<double> variance;
  
  /* Bound 2 */
  oe->getBound2_dom(average, precision, variance);
  
  statsOEsub = mxCreateStructArray(1, dims, 3, subfield_names);
  mxSetFieldByNumber(statsOE,0,0,statsOEsub);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=average[i];
  }
  mxSetFieldByNumber(statsOEsub,0,0,field_value);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=precision[i];
  }
  mxSetFieldByNumber(statsOEsub,0,1,field_value);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=variance[i];
  }
  mxSetFieldByNumber(statsOEsub,0,2,field_value);

  /* Bound 3 */
  oe->getBound3_dom(average, precision, variance);

  statsOEsub = mxCreateStructArray(1, dims, 3, subfield_names);
  mxSetFieldByNumber(statsOE,0,1,statsOEsub);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=average[i];
  }
  mxSetFieldByNumber(statsOEsub,0,0,field_value);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=precision[i];
  }
  mxSetFieldByNumber(statsOEsub,0,1,field_value);

  field_value = mxCreateDoubleMatrix(xmax,1,mxREAL);
  field_value_pr=mxGetPr(field_value);
  for(int i=0;i<xmax;i++) {
    field_value_pr[i]=variance[i];
  }
  mxSetFieldByNumber(statsOEsub,0,2,field_value);

  matPutVariable(mfp, "boundsPOE", statsOE);
  
  if (matClose(mfp) != 0) {
    printf("Error closing file %s\n",file);
    return(EXIT_FAILURE);
  }
  
  myfile.close();
  delete[] file;

  printf("Done\n");
  return(EXIT_SUCCESS);
}

#if defined(COUT_DEBUG) || defined(DEBUG)
/*
 * Writes `ndims` dimensions as "d1xd2x...xdn" to `os`.
 * Templated so the MAT API's own dimension type (mwSize) is used as-is;
 * reinterpreting it as int* misreads it wherever mwSize is 64-bit.
 */
template <typename DimT>
static void writeDims(ostream &os, const DimT *dims, size_t ndims) {
    for (size_t k = 0; k < ndims; k++) {
        if (k > 0) os << "x";
        os << dims[k];
    }
}

/*
 * Debug trace for a variable just read: numeric scalars are printed with
 * their value, everything else with its dimensions.
 */
static void describeVar(ostream &os, const char *var_name, const mxArray *array_ptr) {
    os << "Reading matlab variable " << var_name;
    if ((mxGetNumberOfElements(array_ptr) == 1) && mxIsNumeric(array_ptr)) {
        os << "=" << mxGetScalar(array_ptr) << endl;
    } else {
        os << " (";
        writeDims(os, mxGetDimensions(array_ptr), mxGetNumberOfDimensions(array_ptr));
        os << ")" << endl;
    }
}
#endif

/*
 * Reads variable `var_name` from the open MAT-file `mfp`.
 *
 * Returns the array produced by matGetVariable. If the variable cannot be
 * read, reports it on cout and on the log stream `myfile`, then terminates
 * the program with exit(1). With COUT_DEBUG / DEBUG defined, a trace of each
 * variable read goes to cout / myfile respectively.
 */
mxArray *readvar(MATFile *mfp, const char *var_name) {
    mxArray *array_ptr;

    if ((array_ptr = matGetVariable(mfp, var_name)) == NULL) {
        cout << "Cannot read matlab variable " << var_name << endl;
        myfile << "Cannot read matlab variable " << var_name << endl;
        exit(1);
    }
#ifdef COUT_DEBUG
    describeVar(cout, var_name, array_ptr);
#endif
#ifdef DEBUG
    describeVar(myfile, var_name, array_ptr);
#endif

    return array_ptr;
}
