/***************************************************************************
 *   Copyright (C) 2006 by Weintraub, Benkard, Van Roy and Jeziorski       *
 *   przemekj@stanford.edu                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#define USE_MUMPS
#include <iostream>
#include <fstream>
#include <cstdlib>
#include <string>
#include <vector>
#include <math.h>
#include <compoe/compoe.h>
#include <compoe/compoe_dom.h>
#include "matrix.h"
#include <string.h>
#include "mat.h"
#include "setup.h"
#include<matrixMTL.h>

/***********************************************************
 * Computes stationary partial OE                          *
 *                                                         *
 * INPUT:                                                  *
 *    1) Output filename - string                          *
 *    2) Control parameters - vector                       *
 *    3) Profit function parameters - vector               *
 *    4) Profit function parameters for ipopt - vector     *
 *    5) State transition parameters - vector              *
 *    6) Model Constants - vector                          *
 *    7) Model Initial values - vector                     *
 *    8) Convergence tolerance values - vector             *
 *    9) Partial OE parameters - vector                    *
 *   10) OE iota                                           *
 *   11) OE rho                                            *
 *                                                         *
 * OUTPUT:                                                 *
 *    1) Value function - Matrix                           *
 *    2) Iota - Matrix                                     *
 *    3) Rho - Matrix                                      *
 *    4) Lambda - vector (Tbar)                            *
 *    5) s - Matrix                                        *
 *    6) Profit matrix - Matrix                            *
 *    7) Expected value function - Matrix                  *
 *    8) Statistics structure                              *
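 *    9) Last initial condition for profits                *
 *                                                         *
 * Usage: <executable> <input.mat> <output.mat>            *
 *                                                         *
 * Variables read from the input MAT-file (argv[1]):       *
 *   filename, control, prof_fun, prof_fun_ipopt,          *
 *   transition, constants, init, convtol, bounds, dom,    *
 *   shocklevels, iotaInit, rhoInit, lambdaInit,           *
 *   mean_fringe_state (only when dom[5] is nonzero)       *
 *   and the sparse shock transition matrix A.             *
 *                                                         *
 * Variables written to the output MAT-file (argv[2]):     *
 *   success; when success is nonzero also enc, encOther,  *
 *   Q, lambdaVector, domIotaMatrix, domVMatrix,           *
 *   pricesInit, EVMatrix, piMatrix, prcontMatrix,         *
 *   sMatrix, rhoMatrix, iotaMatrix and VMatrix.           *
 ***********************************************************/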
using namespace std;
using namespace MTL;

// Maximum MAT-file name length for fixed-size name buffers.
#define NAME_LENGTH 30

// Reads variable var_name from an open MAT-file; exits the program if the
// variable cannot be read.
mxArray *readvar(MATFile *mfp, const char *var_name);

// Log stream; main() opens it with the name stored in the input MAT-file
// variable "filename".
ofstream myfile;

// Copies the n-by-m row-major matrix `in` into `out` in column-major order,
// i.e. out[col*n + row] = in[row*m + col]. Note: the DESTINATION comes first
// (parameter names match the definition below).
void transposeCopy(double *out, double *in, int n, int m);

int main(int argc, char *argv[]) {
  if(argc<3) {
    cout << "The executable needs two arguments: output and input .mat file names" << endl;
    exit(1);
  }
  char filename[NAME_LENGTH];
  MATFile *mfp;

  strncpy(filename, argv[1], NAME_LENGTH);

  if ((mfp = matOpen(filename, "r")) == NULL) {
    #ifdef COUT_DEBUG
      cout << "Cannot open MAT-file " << filename << endl;
    #endif
    exit(1);
  } else {
    #ifdef COUT_DEBUG
      cout << "Opening MAT-file " << filename << endl;
    #endif
  }

  mxArray *var = readvar(mfp, "filename");

  char *file=new char[mxGetN(var)+1];
  mxGetString(var, file, mxGetN(var)+1);

  myfile.open(file);
  
  mxArray *controlmx = readvar(mfp, "control");
  mxArray *prof_funmx = readvar(mfp, "prof_fun");
  mxArray *prof_fun_ipoptmx = readvar(mfp, "prof_fun_ipopt");
  mxArray *transitionmx = readvar(mfp, "transition");
  mxArray *constantsmx = readvar(mfp, "constants");
  mxArray *initmx = readvar(mfp, "init");
  mxArray *convtolmx = readvar(mfp, "convtol");
  mxArray *boundsmx = readvar(mfp, "bounds");
  mxArray *dommx = readvar(mfp, "dom");
  mxArray *shocklevelsmx = readvar(mfp, "shocklevels");

  double *control = mxGetPr(controlmx);
  double *prof_fun = mxGetPr(prof_funmx);
  double *prof_fun_ipopt = mxGetPr(prof_fun_ipoptmx);
  double *transition = mxGetPr(transitionmx);
  double *constants = mxGetPr(constantsmx);
  double *init = mxGetPr(initmx);
  double *convtol = mxGetPr(convtolmx);
  double *bounds = mxGetPr(boundsmx);
  double *dom = mxGetPr(dommx);
  double *shocklevels = mxGetPr(shocklevelsmx);
  int dom_firm_number = int(dom[10]);
  int xmax = int(dom[9]);
  if(control[2]>0) {
    transition[0]=1e-16;
  }
  #ifdef DEBUG
  myfile << "Stationary partial OE Solver" << endl;
  myfile << "----------------------------" << endl << endl;

  myfile << "\nProfit Function Parameters:" << endl;
  for(int ii=0;ii<mxGetM(prof_funmx);ii++) myfile << "  " << ii << ": " << prof_fun[ii] << endl;

  myfile << "\nProfit Function Parameters for Ipopt:" << endl;
  for(int ii=0;ii<mxGetM(prof_fun_ipoptmx);ii++) myfile << "  " << ii << ": " << prof_fun_ipopt[ii] << endl;
 
  myfile << "\n\nDynamic Parameters:\n\n";
  for(int ii=0;ii<mxGetM(constantsmx);ii++) myfile << "  " << ii << ": " << constants[ii] << endl;

  myfile << "\n\nTransition Parameters:\n\n";
  for(int ii=0;ii<mxGetM(transitionmx);ii++) myfile << "  " << ii << ": " << transition[ii] << endl;

  myfile << "\n\nDominant firms Parameters:\n\n";
  for(int ii=0;ii<mxGetM(dommx);ii++) myfile << "  " << ii << ": " << dom[ii] << endl;

  myfile << "\n\nShock levels Parameters:\n\n";
  for(int ii=0;ii<mxGetM(shocklevelsmx);ii++) myfile << "  " << ii << ": " << shocklevels[ii] << endl;

  if(control[2]>0) {
    myfile << "  Fixed Number of Firms (No entry/No exit): " << control[2] << endl;
  } else {
    if(control[1]==1) {
        myfile << "  Poisson Entry Process" << endl;
    } else {
        myfile << "  Constant Entry Process" << endl;
    }
  }

  myfile << "\nParameters:\n\n";
  myfile << "  xmax: " << xmax << endl;
  #endif
  CompOE_dom *oe_dom = new CompOE_dom;
  oe_dom->setOutput(&myfile);

  VectorMTL<double> iotaInit(xmax);
  VectorMTL<double> rhoInit(xmax);
  VectorMTL<double> lambdaInit(xmax);
  

  double *temp;
  temp = mxGetPr(readvar(mfp, "iotaInit"));
  for(int i=0;i<xmax;i++) iotaInit[i]=temp[i];
  temp = mxGetPr(readvar(mfp, "rhoInit"));
  for(int i=0;i<xmax;i++) rhoInit[i]=temp[i];
  temp = mxGetPr(readvar(mfp, "lambdaInit"));
  for(int i=0;i<xmax;i++) lambdaInit[i]=temp[i];

  oe_dom->initialize(control, prof_fun, prof_fun_ipopt, transition, constants, init, convtol, dom, shocklevels);

  VectorMTL<double> mean_fringe_state(xmax);
  if(dom[5]) {
    temp = mxGetPr(readvar(mfp, "mean_fringe_state"));
    for(int i=0;i<xmax;i++) mean_fringe_state[i]=temp[i];
  }

  var = readvar(mfp, "A");
  if(!mxIsSparse(var)) {
      cout << "Transition matrix must be sparse." << endl;
      exit(1);
  }
  
  mwIndex *ir = mxGetIr(var);
  mwIndex *jc = mxGetJc(var);

  double *c =mxGetPr(var);
  int dim = mxGetN(var);
  int nz = jc[dim];

  SparseMatrixMTL<double> shockTran(dim,dim,nz);

  int row_i = 0;
  for(int i=0;i<nz;i++) {
    shockTran.columns[i]=ir[i];
    if(jc[row_i]==i) row_i++;
    shockTran.rows[i]=row_i-1;
    shockTran[i]=c[i];
  }

  if (matClose(mfp) != 0) {
    printf("Error closing file %s\n",filename);
    return(EXIT_FAILURE);
  }
  #ifdef COUT_DEBUG
    cout << "Init done...\n";
  #endif

  strncpy(filename, argv[2], NAME_LENGTH);
  if ((mfp = matOpen(filename, "w")) == NULL) {
    #ifdef COUT_DEBUG
      cout << "Cannot open MAT-file " << filename << endl;
    #endif
    #ifdef DEBUG
      myfile << "Cannot open MAT-file " << filename << endl;
    #endif
    exit(1);
  } else {
    #ifdef COUT_DEBUG
      cout << "Opening MAT-file " << filename << endl;
    #endif
    #ifdef DEBUG
      myfile << "Opening MAT-file " << filename << endl;
    #endif
  }
  mxArray *output;
  double *outputpr;

  int success = oe_dom->compute_dom(xmax,iotaInit,rhoInit,lambdaInit,mean_fringe_state,shockTran);
  output = mxCreateDoubleMatrix(1,1,mxREAL);
  outputpr = mxGetPr(output);
  outputpr[0] = success;
  matPutVariable(mfp, "success", output);

  if(success==0) {
    cout << "Computing of dominant OE failed" << endl;

    delete oe_dom;

    myfile.close();
    delete[] file;

    if (matClose(mfp) != 0) {
      printf("Error closing file %s\n",file);
      return(EXIT_FAILURE);
    }
    return 0;
  }


  int other_number = dom_firm_number-1;
  int xmax_dom = xmax;
  int other_dom_firms = (dom_firm_number == 1) ? 1 :
    oe_dom->newton(xmax_dom+other_number-1,other_number)/oe_dom->factorial(other_number); // Number of states of competitors of dom firm (symetric)
  int shocks_states = oe_dom->newton(xmax_dom+dom_firm_number-1,dom_firm_number)/oe_dom->factorial(dom_firm_number); // Number of states of dominant firms (symetric)


  MatrixMTL<double> *pricesStart;
  MatrixMTL<double> *tildesMatrix;

  pricesStart = oe_dom->getPricesInitDom();
  tildesMatrix = oe_dom->getTildesDom();

  MatrixMTL<double> *temp_m;

  output = mxCreateDoubleMatrix(shocks_states,dom_firm_number,mxREAL);
  outputpr = mxGetPr(output);
  int **encoding = oe_dom->getTotalEncodingDom();
  int k=0;
  for(int j=0;j<dom_firm_number;j++) {
    for(int i=0;i<shocks_states;i++) {
      outputpr[k]=encoding[i][j];
      k++;
    }
  }
  matPutVariable(mfp, "enc", output); 
 
  output = mxCreateDoubleMatrix(other_dom_firms,other_number,mxREAL);
  outputpr = mxGetPr(output);
  encoding = oe_dom->getOtherEncodingDom();
  k=0;
  for(int j=0;j<other_number;j++) {
    for(int i=0;i<other_dom_firms;i++) {
      outputpr[k]=encoding[i][j];
      k++;
    }
  }
  matPutVariable(mfp, "encOther", output);
 
  VectorMTL<double> *temp_v;
  // Stationary distribution of dominant firms
 
  output = mxCreateDoubleMatrix(shocks_states*dim,1,mxREAL);
  outputpr = mxGetPr(output);
  temp_v = oe_dom->getQ_domDom();
  memcpy(outputpr,temp_v->getData(),shocks_states*dim*sizeof(double));
  matPutVariable(mfp, "Q", output);

  // Lambda vector
  output = mxCreateDoubleMatrix(shocks_states*dim,1,mxREAL);
  outputpr = mxGetPr(output);
  temp_v = oe_dom->getLambdaDom();
  memcpy(outputpr,temp_v->getData(),shocks_states*dim*sizeof(double));
  matPutVariable(mfp, "lambdaVector", output);

  // Investment of dominant firms
  output = mxCreateDoubleMatrix(other_dom_firms*dim,xmax_dom,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getIota_domDom();
  transposeCopy(outputpr,temp_m->getData(),other_dom_firms*dim,xmax_dom);
  matPutVariable(mfp, "domIotaMatrix", output);

  // Value function of dominant firm
  output = mxCreateDoubleMatrix(other_dom_firms*dim,xmax_dom,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getV_domDom();
  transposeCopy(outputpr,temp_m->getData(),other_dom_firms*dim,xmax_dom);
  matPutVariable(mfp, "domVMatrix", output);

  // Prices
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getPricesInitDom();
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "pricesInit", output);

  // Expected value function
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getVExpectedDom();
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "EVMatrix", output);

  // Profit Matrix
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getProfitDom();
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "piMatrix", output);

  // Prcont Matrix
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getPrcontMatrix();
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "prcontMatrix", output);

  // s Matrix
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = tildesMatrix;
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "sMatrix", output);
  
  // Rho Matrix
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getRhoDom();
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "rhoMatrix", output);
 
  // Iota Matrix
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getIotaDom();
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "iotaMatrix", output);
      
  // VMatrix
  output = mxCreateDoubleMatrix(shocks_states*dim,xmax,mxREAL);
  outputpr = mxGetPr(output);
  temp_m = oe_dom->getVDom();
  transposeCopy(outputpr,temp_m->getData(),shocks_states*dim,xmax);
  matPutVariable(mfp, "VMatrix", output);

#ifdef ACW
  SparseMatrix3D<double> *tranDom = oe_dom->getTranDom();
  mwSize dims[1];
  dims[0] = tranDom->getDeep();
  mxArray *output2;
  output2 = mxCreateCellArray(1, dims);
  for(int i=0;i<dims[0];++i) {
    output = mxCreateDoubleMatrix(xmax,xmax,mxREAL);
    outputpr = mxGetPr(output);
    transposeCopy(outputpr,(*tranDom)[i].getData(),xmax,xmax);
    mxSetCell(output2, i, output);
  }
  matPutVariable(mfp, "tranDom", output2);

  MatrixMTL3D<double> *tranFringe = oe_dom->getTranFringe();
  dims[0] = tranFringe->getDeep();
  output2 = mxCreateCellArray(1, dims);
  for(int i=0;i<dims[0];++i) {
    output = mxCreateDoubleMatrix(xmax,xmax,mxREAL);
    outputpr = mxGetPr(output);
    transposeCopy(outputpr,(*tranFringe)[i].getData(),xmax,xmax);
    mxSetCell(output2, i, output);
  } 
  matPutVariable(mfp, "tranFringe", output2);
#endif

  delete oe_dom;

  myfile.close();
  delete[] file;

  if (matClose(mfp) != 0) {
    printf("Error closing file %s\n",filename);
    return(EXIT_FAILURE);
  }

  return 0;
}

/*
 * Reads variable var_name from the open MAT-file mfp and returns it.
 * The returned array is not freed here; the caller is responsible for it.
 * If the variable cannot be read, prints an error to cout and myfile and
 * exits with status 1.
 * With COUT_DEBUG / DEBUG defined, logs to cout / myfile either the value
 * (numeric scalars) or the dimensions, e.g. "(3x4)", of the variable.
 */
mxArray *readvar(MATFile *mfp, const char *var_name) {
    mxArray *array_ptr;

    if ((array_ptr = matGetVariable(mfp, var_name)) == NULL) {
        cout << "Cannot read matlab variable " << var_name << endl;
        myfile << "Cannot read matlab variable " << var_name << endl;
        exit(1);
    }
    if ((mxGetNumberOfElements(array_ptr) == 1) && mxIsNumeric(array_ptr)) {
      #ifdef COUT_DEBUG
        cout << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
      #ifdef DEBUG
        myfile << "Reading matlab variable " << var_name << "=" << mxGetScalar(array_ptr) << endl;
      #endif
    } else {
      #ifdef COUT_DEBUG
        cout << "Reading matlab variable " << var_name << " (";
      #endif
      #ifdef DEBUG
        myfile << "Reading matlab variable " << var_name << " (";
      #endif
        // mxGetDimensions returns mwSize entries (size_t on 64-bit-index
        // builds); reading them through an int* printed wrong dimensions.
        const mwSize *dims = mxGetDimensions(array_ptr);
        const mwSize last = mxGetNumberOfDimensions(array_ptr) - 1;
        (void)dims; // only used when a debug macro is defined
        for (mwSize d = 0; d < last; d++) {
        #ifdef COUT_DEBUG
          cout << dims[d] << "x";
        #endif
        #ifdef DEBUG
          myfile << dims[d] << "x";
        #endif
        }
      #ifdef COUT_DEBUG
        cout << dims[last] << ")"<< endl;
      #endif
      #ifdef DEBUG
        myfile << dims[last] << ")"<< endl;
      #endif
    }

    return array_ptr;
}

/*
 * Reorders an n-by-m matrix from row-major storage (`in`, m entries per row)
 * into column-major storage (`out`, n entries per column):
 *   out[col*n + row] = in[row*m + col]
 * Output is written sequentially, one column at a time.
 */
void transposeCopy(double *out, double *in, int n, int m) {
  for (int col = 0; col < m; ++col) {
    double *dstColumn = out + col * n;   // start of output column `col`
    for (int row = 0; row < n; ++row) {
      dstColumn[row] = in[row * m + col];
    }
  }
}
