#include<cuda_generator.h>
#include<stdlib.h>
#include<stdio.h>
#include<csvread.h>
#include<setup.h>
#include<math.h>

#include <sys/times.h>
#include <sys/types.h>
#include <sys/time.h>

#include <blp.h>
#include <solveBLP.h>

#include <mex.h>

/*
 * State cached across MEX calls. problemMex and dataMex are allocated and
 * filled on the first call to mexFunction (guarded by `initialized`) and are
 * released by cleanup().
 */
static BLPproblem *problemMex;
static BLPdata *dataMex;
static int initialized;   /* static storage duration: starts out as 0 */

/* NOTE(review): none of these are referenced by live code in this file
   (only a commented-out line mentions TransposeStatic). */
static mwIndex *JcStatic, *IrStatic;
static int *TransposeStatic;

void cleanup(void) {
  int i;
  int data_size = dataMex->index[problemMex->date*problemMex->market];

  mxFree(dataMex->index);

  mxFree(dataMex->jacobianIndex);

  for(i=0;i<data_size;++i) {
    mxFree(dataMex->data[i]);
  }
  mxFree(dataMex->data);

  if(problemMex->doInstruments==1) {
    for(i=0;i<problemMex->instruments;++i) {
      mxFree(dataMex->instruments[i]);
    }
    mxFree(dataMex->instruments);
    if(problem->ar) {
      mxFree(lag);
    }
  }

  for(i=0;i<problemMex->date*problemMex->market*problemMex->N;++i) {
    mxFree(data->demographics[i]);
  }
  mxFree(dataMex->demographics);

  for(i=0;i<problemMex->date*problemMex->demoGroups*problemMex->P;++i) {
    mxFree(dataMex->demoGroups[i]);
  }
  mxFree(dataMex->demoGroups);

  mxFree(dataMex->share);
  mxFree(dataMex->demoFromData);

  mxFree(dataMex);
  mxFree(problemMex);
}

/*
 * Return the real data of input `array` after checking that it is a full
 * double array holding at least `minCount` elements. Otherwise a MATLAB error
 * naming the input (`what`) is raised; mexErrMsgIdAndTxt does not return.
 */
static double *requireDoubleInput(const mxArray *array, const char *what,
                                  size_t minCount) {
  if (array == NULL || !mxIsDouble(array) || mxIsSparse(array)) {
    mexErrMsgIdAndTxt("BLP:badInput", "%s must be a full double array.", what);
  }
  if (mxGetNumberOfElements(array) < minCount) {
    mexErrMsgIdAndTxt("BLP:badInput", "%s must have at least %lu elements.",
                      what, (unsigned long) minCount);
  }
  return mxGetPr(array);
}

/*
 * MATLAB gateway: J = <mex>(problem, data, seed, x0, Ir, Jc, T)
 *
 *   prhs[0] problem : struct of scalar sizes/flags (market, date, product, ...).
 *   prhs[1] data    : struct of data arrays (index, jacobianIndex, data, ...).
 *                     Both structs are read ONLY on the first call. Their
 *                     contents are copied into persistent memory that is
 *                     reused by later calls and freed by cleanup().
 *   prhs[2] seed    : at least 4 values used to seed the random number generator.
 *   prhs[3] x0      : starting point passed to augmentedJacobian.
 *   prhs[4] Ir      : row indices of the sparse result (>= mysize entries).
 *   prhs[5] Jc      : column pointers of the sparse result (>= m+1 entries).
 *   prhs[6] T       : for each value produced by augmentedJacobian, its
 *                     0-based position in the result's value array.
 *   plhs[0]         : n-by-m real sparse matrix with mysize nonzeros.
 *
 * NOTE(review): individual struct fields are not validated. A missing field
 * makes mxGetField return NULL, which is then passed to mxGetPr.
 */
void mexFunction(int nlhs, mxArray *plhs[],
        int nrhs, const mxArray *prhs[]) {
  int i,j;
  int date_market, data_size;
  F_TYPE *dataInit;
  F_TYPE *demographicsInit;
  F_TYPE *demoGroupsInit;
  F_TYPE *instrumentsInit;
  F_TYPE *indexInit;
  F_TYPE *jacobianIndexInit;
  F_TYPE *mySeedInit;

  F_TYPE *arg;
  F_TYPE *shareInit, *demoFromDataInit;
  F_TYPE *output, *Ird, *Jcd;
  mwIndex *Ir, *Jc;
  double *target, *T;
  int m,n;
  int lagSize;
  F_TYPE *lagInit;

  /* Validate the call shape before touching any cached state. */
  if (nrhs < 7) {
    mexErrMsgIdAndTxt("BLP:nrhs",
        "Seven inputs required: problem, data, seed, starting point, Ir, Jc, T.");
  }
  if (nlhs > 1) {
    mexErrMsgIdAndTxt("BLP:nlhs", "At most one output is returned.");
  }

  if(!initialized) {
    /* Checked before anything is allocated, so a bad first call leaves the
       module uninitialized and the next call simply retries. */
    if (!mxIsStruct(prhs[0]) || mxIsEmpty(prhs[0]) ||
        !mxIsStruct(prhs[1]) || mxIsEmpty(prhs[1])) {
      mexErrMsgIdAndTxt("BLP:badInput",
          "The first two inputs must be non-empty structs (problem and data).");
    }

    initialized=1;
    problemMex = (BLPproblem *) mxMalloc(sizeof(BLPproblem));
    mexMakeMemoryPersistent(problemMex);
    mexAtExit(cleanup);

    /* Initialize the problem structure from scalar fields of prhs[0]. */
    problemMex->market=(int) mxGetPr(mxGetField(prhs[0], 0, "market"))[0];
    problemMex->date=(int) mxGetPr(mxGetField(prhs[0], 0, "date"))[0];
    problemMex->product=(int) mxGetPr(mxGetField(prhs[0], 0, "product"))[0];
    problemMex->demoCharacteristics=(int) mxGetPr(mxGetField(prhs[0], 0, "demoCharacteristics"))[0];
    problemMex->demoGroups=(int) mxGetPr(mxGetField(prhs[0], 0, "demoGroups"))[0];
    problemMex->parameters=(int) mxGetPr(mxGetField(prhs[0], 0, "parameters"))[0];
    problemMex->P=(int) mxGetPr(mxGetField(prhs[0], 0, "P"))[0];
    problemMex->Pn=(int) mxGetPr(mxGetField(prhs[0], 0, "Pn"))[0];
    problemMex->N=(int) mxGetPr(mxGetField(prhs[0], 0, "N"))[0];
    problemMex->allParams=problemMex->parameters+problemMex->product*(problemMex->demoCharacteristics+1);
    problemMex->doInstruments=(int) mxGetPr(mxGetField(prhs[0], 0, "doInstruments"))[0];
    if(problemMex->doInstruments==1) {
      problemMex->instruments=(int) mxGetPr(mxGetField(prhs[0], 0, "instruments"))[0];
      instrumentsSave = problemMex->instruments;
    } else {
      /* Instruments are not stored, but their count still sizes the output. */
      instrumentsSave = (int) mxGetPr(mxGetField(prhs[0], 0, "instruments"))[0];
      problemMex->instruments=0;
    }
    problemMex->ar=(int) mxGetPr(mxGetField(prhs[0], 0, "ar"))[0];

    problem=problemMex;

    dataMex = (BLPdata *) mxMalloc(sizeof(BLPdata));
    mexMakeMemoryPersistent(dataMex);
    data = dataMex;

    date_market=problem->market*problem->date;
    size = (int) mxGetPr(mxGetField(prhs[0], 0, "mysize"))[0];

    problem->dataLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "dataLineSize"))[0];
    problem->groupsLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "groupsLineSize"))[0];
    problem->demoLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "demoLineSize"))[0];

    /* Copy the data struct into persistent host memory (mxMalloc followed by
       mexMakeMemoryPersistent) so it survives between calls. */
    indexInit = mxGetPr(mxGetField(prhs[1], 0, "index"));
    data->index = (int *) mxMalloc((date_market+1)*sizeof(int));
    mexMakeMemoryPersistent(data->index);
    for(i=0;i<=date_market;++i) {
      data->index[i] = (int) indexInit[i];
    }

    /* Total number of data rows: the last entry of index. */
    data_size = data->index[date_market];

    jacobianIndexInit = mxGetPr(mxGetField(prhs[1], 0, "jacobianIndex"));
    data->jacobianIndex = (int *) mxMalloc((date_market+1)*sizeof(int));
    mexMakeMemoryPersistent(data->jacobianIndex);
    for(i=0;i<=date_market;++i) {
      data->jacobianIndex[i] = (int) jacobianIndexInit[i];
    }

    /* Flat input is split into data_size rows of dataLineSize values each. */
    dataInit = mxGetPr(mxGetField(prhs[1], 0, "data"));
    data->data = (F_TYPE **) mxMalloc(data_size*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(data->data);
    m=0;
    for(i=0;i<data_size;++i) {
      data->data[i] = (F_TYPE *) mxMalloc(problem->dataLineSize*sizeof(F_TYPE));
      mexMakeMemoryPersistent(data->data[i]);
      for(j=0;j<problem->dataLineSize;++j) {
        data->data[i][j]=dataInit[m];
        m++;
      }
    }

    demographicsInit = mxGetPr(mxGetField(prhs[1], 0, "demographics"));
    data->demographics = (F_TYPE **) mxMalloc(date_market*problem->N*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(data->demographics);
    m=0;
    for(i=0;i<date_market*problem->N;++i) {
      data->demographics[i] = (F_TYPE *) mxMalloc(problem->demoLineSize*sizeof(F_TYPE));
      mexMakeMemoryPersistent(data->demographics[i]);
      for(j=0;j<problem->demoLineSize;++j) {
        data->demographics[i][j]=demographicsInit[m];
        m++;
      }
    }

    demoGroupsInit = mxGetPr(mxGetField(prhs[1], 0, "demoGroups"));
    data->demoGroups = (F_TYPE **) mxMalloc(problem->date*problem->demoGroups*problem->P*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(data->demoGroups);
    m=0;
    for(i=0;i<problem->date*problem->demoGroups*problem->P;++i) {
      data->demoGroups[i] = (F_TYPE *) mxMalloc(problem->groupsLineSize*sizeof(F_TYPE));
      mexMakeMemoryPersistent(data->demoGroups[i]);
      for(j=0;j<problem->groupsLineSize;++j) {
        data->demoGroups[i][j]=demoGroupsInit[m];
        m++;
      }
    }

    shareInit = mxGetPr(mxGetField(prhs[1], 0, "share"));
    data->share = (F_TYPE *) mxMalloc(data_size*sizeof(F_TYPE));
    mexMakeMemoryPersistent(data->share);
    for(i=0;i<data_size;++i) {
      data->share[i] = shareInit[i];
    }

    demoFromDataInit = mxGetPr(mxGetField(prhs[1], 0, "demoFromData"));
    data->demoFromData = (F_TYPE *) mxMalloc(problem->product*problem->demoGroups*sizeof(F_TYPE));
    mexMakeMemoryPersistent(data->demoFromData);
    for(i=0;i<problem->product*problem->demoGroups;++i) {
      data->demoFromData[i] = demoFromDataInit[i];
    }

    if(problem->doInstruments==1) {
      instrumentsInit = mxGetPr(mxGetField(prhs[1], 0, "instruments"));
      data->instruments = (F_TYPE **) mxMalloc(problem->instruments*sizeof(F_TYPE *));
      mexMakeMemoryPersistent(data->instruments);
      m=0;
      for(i=0;i<problem->instruments;++i) {
        data->instruments[i] = (F_TYPE *) mxMalloc(data_size*sizeof(F_TYPE));
        mexMakeMemoryPersistent(data->instruments[i]);
        for(j=0;j<data_size;++j) {
          data->instruments[i][j]=instrumentsInit[m];
          m++;
        }
      }
      if(problem->ar) {
        lagInit = mxGetPr(mxGetField(prhs[1], 0, "lag"));
        lagSize = mxGetN(mxGetField(prhs[1], 0, "lag"));
        /* One int per data row. The original code sized this with
           sizeof(int *) and left it uninitialized; mxCalloc zero-fills rows
           that are not listed in `lag`.
           NOTE(review): confirm 0 is the intended value for unlisted rows. */
        lag = (int *) mxCalloc(data_size, sizeof(int));
        mexMakeMemoryPersistent(lag);
        /* Each column of `lag` is read as a pair (row, value).
           NOTE(review): assumes 2 rows and row < data_size; not checked here. */
        for(i=0,j=0;i<lagSize;++i,j+=2) {
          lag[(int) lagInit[j]]=(int) lagInit[j+1];
        }
      }
    }
  } else {
    problem=problemMex;
    data=dataMex;
  }

  /* Initialize random number generator (per call; mxMalloc memory that is
     not made persistent is released by MATLAB when this call returns). */
  mySeedInit = requireDoubleInput(prhs[2], "Seed (input 3)", 4);
  mySeed = (long *) mxMalloc(4*sizeof(long));
  for (i = 0; i < 4; ++i) {
    mySeed[i] = (long) mySeedInit[i];
  }

  threadsN = problem->market*problem->date+problem->demoGroups*problem->date;
  seed = (seedStruct *) mxMalloc(threadsN*sizeof(seedStruct));

  /* Load the starting point */
  arg = requireDoubleInput(prhs[3], "Starting point (input 4)", 0);

  /* Output dimensions: n rows, m columns. */
  n = data->index[problem->market*problem->date]+problem->demoGroups*problem->product+problem->instruments;

  if(problem->doInstruments==1) {
    m = problem->allParams+n;
    if(problem->ar) {
      m++;
    }
  } else {
    m = problem->allParams+n+instrumentsSave;
  }

  /* Sparsity pattern and the scatter map for the computed values. */
  Ird = requireDoubleInput(prhs[4], "Ir (input 5)", (size_t) size);
  Jcd = requireDoubleInput(prhs[5], "Jc (input 6)", (size_t) m + 1);
  T = requireDoubleInput(prhs[6], "T (input 7)", (size_t) size);

  /* Every T entry indexes the result's value array of length `size`.
     Compared in double form so that NaN is rejected too. */
  for (i = 0; i < size; i++) {
    if (!(T[i] >= 0.0 && T[i] < (double) size)) {
      mexErrMsgIdAndTxt("BLP:badInput",
          "T(%d) is outside the valid range [0, %d).", i + 1, (int) size);
    }
  }

  plhs[0] = mxCreateSparse(n,m,size,mxREAL);

  Ir = mxGetIr(plhs[0]);
  Jc = mxGetJc(plhs[0]);

  /* A sparse matrix with m columns has m+1 column pointers. */
  for(i=0;i<=m;i++) {
    Jc[i]=Jcd[i];
  }
  for(i=0;i<size;i++) {
    Ir[i]=Ird[i];
  }

  /* Evaluate, then scatter the values into the sparse result via T. */
  target = mxGetPr(plhs[0]);
  output = (double *) mxMalloc(size*sizeof(double));
  augmentedJacobian(output, arg);
  for(i=0;i<size;i++) {
    target[(int) T[i]]=output[i];
  }
  mxFree(output);
}
