#include <string.h>
#include "mex.h"
#include <stdio.h>
#include <stdlib.h>
#include "epsilon.h"
#include "setup.h"
#include "cuda_generator.h"
#include "globals.h"

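/* Input (right-hand-side arguments, in order):
 *  - problem   (prhs[0]): scalar struct with the problem dimensions and
 *                         flags, including reInit
 *  - data      (prhs[1]): struct with index, data, demographics, lag and
 *                         instruments; read only while (re)initializing
 *  - structure (prhs[2])
 *     - market
 *     - q
 *     - arg
 *     - gamma
 *     - ar
 *     - rho    (read only when ar is nonzero)
 *     - P
 *  - mySeed    (prhs[3]): four seed values
 */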

/* Persistent state shared between successive calls of this MEX function. */

/* Nonzero while problemMex/dataMex hold a built copy of the inputs. */
static int initialized = 0;

/* Copy of the problem description, allocated as persistent MEX memory. */
static BLPproblem *problemMex = NULL;

/* Copy of the data tables, allocated as persistent MEX memory. */
static BLPdata *dataMex = NULL;

/*
 * Release every persistent allocation made while initializing: the index
 * vector, the data / instruments / demographics row tables, the lag vector
 * and the two top-level structures.
 *
 * Called explicitly before a re-initialization and registered with
 * mexAtExit.  The function is a no-op when nothing is initialized, so being
 * invoked twice (re-init followed by the exit handler) cannot double-free.
 */
void cleanup(void) {
  int i;
  int data_size;
  int demo_rows;

  if((!initialized) || (problemMex==NULL) || (dataMex==NULL)) {
    return;
  }
  initialized=0;

  /* Both row counts must be read before the arrays they come from are freed. */
  data_size = dataMex->index[problemMex->date*problemMex->market];
  demo_rows = problemMex->date*problemMex->market*problemMex->N;

  mxFree(dataMex->index);

  for(i=0;i<data_size;++i) {
    mxFree(dataMex->data[i]);
  }
  mxFree(dataMex->data);

  for(i=0;i<problemMex->instruments;++i) {
    mxFree(dataMex->instruments[i]);
  }
  mxFree(dataMex->instruments);

  /* lag is allocated on every initialization regardless of problemMex->ar,
   * so it has to be released unconditionally as well. */
  mxFree(lag);
  lag = NULL;

  for(i=0;i<demo_rows;++i) {
    mxFree(dataMex->demographics[i]);
  }
  mxFree(dataMex->demographics);

  mxFree(dataMex);
  dataMex = NULL;
  mxFree(problemMex);
  problemMex = NULL;
}

void mexFunction(int nlhs, mxArray *plhs[ ],int nrhs, const mxArray *prhs[ ])  {
  int i,j,n,d,stations,m,r1,o;
  char ar;
  double *foc_out;
  double q[100];
  double *moment, *indexInit, *demographicsInit, *lagInit, *instrumentsInit;
  double *dataInit;
  double *mySeedInit = mxGetPr(prhs[3]);
  double rho,aggrQ,aggrQ_old;
  int market, date_market, data_size, lagSize;
  double *ms,*sumNu;

  reInit = (int) mxGetPr(mxGetField(prhs[0], 0, "reInit"))[0];
  if((reInit) && (initialized)) {
    cleanup();
  }

  if((!initialized) || (reInit)) {
    printf("Initing\n");

    initialized=1;
    problemMex = (BLPproblem *) mxMalloc(sizeof(BLPproblem));
    mexMakeMemoryPersistent(problemMex);
    mexAtExit(cleanup);


    /* Initialize problem structure */

    problemMex->market=(int) mxGetPr(mxGetField(prhs[0], 0, "market"))[0];
    problemMex->date=(int) mxGetPr(mxGetField(prhs[0], 0, "date"))[0];
    problemMex->product=(int) mxGetPr(mxGetField(prhs[0], 0, "product"))[0];
    problemMex->demoCharacteristics=(int) mxGetPr(mxGetField(prhs[0], 0, "demoCharacteristics"))[0];
    problemMex->demoGroups=(int) mxGetPr(mxGetField(prhs[0], 0, "demoGroups"))[0];
    problemMex->parameters=(int) mxGetPr(mxGetField(prhs[0], 0, "parameters"))[0];
    problemMex->P=(int) mxGetPr(mxGetField(prhs[0], 0, "P"))[0];
    problemMex->Pn=(int) mxGetPr(mxGetField(prhs[0], 0, "Pn"))[0];
    problemMex->N=(int) mxGetPr(mxGetField(prhs[0], 0, "N"))[0];
    problemMex->allParams=problemMex->parameters+problemMex->product*(problemMex->demoCharacteristics+1);
    problemMex->doInstruments=(int) mxGetPr(mxGetField(prhs[0], 0, "doInstruments"))[0];
    problemMex->instruments=(int) mxGetPr(mxGetField(prhs[0], 0, "instruments"))[0];
    
    problemMex->ar=(int) mxGetPr(mxGetField(prhs[0], 0, "ar"))[0];

    problemMex->bs=(int) mxGetPr(mxGetField(prhs[0], 0, "bs"))[0];

    dataMex = (BLPdata *) mxMalloc(sizeof(BLPdata));
    mexMakeMemoryPersistent(dataMex);

    date_market=problemMex->market*problemMex->date;

    problemMex->dataLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "dataLineSize"))[0];;
    problemMex->groupsLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "groupsLineSize"))[0];;
    problemMex->demoLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "demoLineSize"))[0];;

    /* Transfer data into the device */
    indexInit = mxGetPr(mxGetField(prhs[1], 0, "index"));
    dataMex->index = (int *) mxMalloc((date_market+1)*sizeof(int));
    mexMakeMemoryPersistent(dataMex->index);
    for(i=0;i<=date_market;++i) {
      dataMex->index[i] = (int) indexInit[i];
    }

    data_size = indexInit[date_market];

    dataInit = mxGetPr(mxGetField(prhs[1], 0, "data"));
    dataMex->data = (F_TYPE **) mxMalloc(data_size*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(dataMex->data);
    m=0;
    for(i=0;i<data_size;++i) {
      dataMex->data[i] = (F_TYPE *) mxMalloc(problemMex->dataLineSize*sizeof(F_TYPE));
      mexMakeMemoryPersistent(dataMex->data[i]);
      for(j=0;j<problemMex->dataLineSize;++j) {
        dataMex->data[i][j]=dataInit[m];
        m++;
      }
    }

    demographicsInit = mxGetPr(mxGetField(prhs[1], 0, "demographics"));
    dataMex->demographics = (F_TYPE **) mxMalloc(date_market*problemMex->N*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(dataMex->demographics);
    m=0;
    for(i=0;i<date_market*problemMex->N;++i) {
      dataMex->demographics[i] = (F_TYPE *) mxMalloc(problemMex->demoLineSize*sizeof(F_TYPE));
      mexMakeMemoryPersistent(dataMex->demographics[i]);
      for(j=0;j<problemMex->demoLineSize;++j) {
        dataMex->demographics[i][j]=demographicsInit[m];
        m++;
      }
    }

    lagInit = mxGetPr(mxGetField(prhs[1], 0, "lag"));
    lagSize = mxGetN(mxGetField(prhs[1], 0, "lag"));
    lag = (int *) mxMalloc(data_size*sizeof(int *));
    mexMakeMemoryPersistent(lag);
    for(i=0,j=0;i<lagSize;++i,j+=2) {
       lag[(int) lagInit[j]]=lagInit[j+1];
    }

    instrumentsInit = mxGetPr(mxGetField(prhs[1], 0, "instruments"));
    dataMex->instruments = (F_TYPE **) mxMalloc(problemMex->instruments*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(dataMex->instruments);
    m=0;
    for(i=0;i<problemMex->instruments;++i) {
      dataMex->instruments[i] = (F_TYPE *) mxMalloc(data_size*sizeof(F_TYPE));
      mexMakeMemoryPersistent(dataMex->instruments[i]);
      for(j=0;j<data_size;++j) {
        dataMex->instruments[i][j]=instrumentsInit[m];
        m++;
      }
    }
  } else {
    date_market=problemMex->market*problemMex->date;
    data_size = dataMex->index[date_market];
  }

  if(reInit) {
    plhs[0]=mxCreateDoubleMatrix(1, 1, mxREAL);
    foc_out = mxGetPr(plhs[0]);
    foc_out[0]=0;
    return;
  }

  advDem = (advDemStructure *) mxMalloc(sizeof(advDemStructure));
  advDem->gamma[1]=mxGetPr(mxGetField(prhs[2], 0, "gamma"))[0];

  advDem->q = mxGetPr(mxGetField(prhs[2], 0, "q"));
  advDem->arg = mxGetPr(mxGetField(prhs[2], 0, "arg"));

  ar = mxGetPr(mxGetField(prhs[2], 0, "ar"))[0];
  if(ar) {
    rho = mxGetPr(mxGetField(prhs[2], 0, "rho"))[0];
  }

  market = mxGetPr(mxGetField(prhs[2], 0, "market"))[0];

  problem = problemMex;
  data = dataMex;

  /* Create the seed */
  seed = (seedStruct *) malloc(1*sizeof(seedStruct));

  mySeed = (long *) mxMalloc(4*sizeof(long));
  mySeed[0] = (long) mySeedInit[0];
  mySeed[1] = (long) mySeedInit[1];
  mySeed[2] = (long) mySeedInit[2];
  mySeed[3] = (long) mySeedInit[3];

  createGenerators(1,mySeed,seed);

  /* Count stations */
  stations = 0;
  advDem->idx=market+problem->market;
  for(d=1;d<problem->date;d++) {
    stations+=data->index[advDem->idx+1]-data->index[advDem->idx];
    advDem->idx+=problem->market;
  }

  /* Inititalize the output matrix */
  plhs[0] = mxCreateDoubleMatrix(stations,2,mxREAL);
  moment = mxGetPr(plhs[0]);

  plhs[1] = mxCreateDoubleMatrix(problem->date,1,mxREAL);
  sumNu = mxGetPr(plhs[1]);

  foc_out = (double *) mxMalloc(data_size*sizeof(double));
  ms = (double *) mxMalloc(data_size*sizeof(double));

  advDem->P = mxGetPr(mxGetField(prhs[2], 0, "P"));

  advDem->idx=market;
  for(d=0;d<problem->date;d++) {
    /*printf("Doing d=%d\n",d);*/
    /* count owners */
    o=data->data[data->index[advDem->idx]][4];
    advDem->number_of_owners=1;
    for(j=data->index[advDem->idx];j<data->index[advDem->idx+1];++j) {
      if(data->data[j][4]!=o) {
        advDem->number_of_owners++;
	o=data->data[j][4];
      }
    }

    /* create structure vector */
    advDem->structure = malloc((advDem->number_of_owners+1)*sizeof(int));
    advDem->structure[0]=0;
    o=data->data[data->index[advDem->idx]][4];
    r1=0;
    for(j=data->index[advDem->idx],n=0;j<data->index[advDem->idx+1];++j,++n) {
      if(data->data[j][4]!=o) {
        r1++;
        advDem->structure[r1]=n;
        o=data->data[j][4];
      }
    }

    advDem->structure[advDem->number_of_owners]=data->index[advDem->idx+1]-data->index[advDem->idx];

    for(j=data->index[advDem->idx],n=0;j<data->index[advDem->idx+1];++j,++n) {
      q[n]=advDem->q[j];
    }

    epsilon(q, foc_out, ms);

    sumNu[d]=0;
    for(j=data->index[advDem->idx],n=0;j<data->index[advDem->idx+1];++j,++n) {
      if(ms[n]>0) {
        sumNu[d]+=foc_out[j]/ms[n];
      }
    }
    sumNu[d]/=n;

    free(advDem->structure);

    advDem->idx+=problem->market;
  }

  /* collect */
  m=0;
  advDem->idx=market+problem->market;
  for(d=1;d<problem->date;d++) {
    for(j=data->index[advDem->idx];j<data->index[advDem->idx+1];++j) {
      if((ar==0) || (lag[j]==-1)) {
        moment[m]=foc_out[j];
        moment[m+stations]=moment[m]*data->instruments[3][j];
      } else {
        moment[m]=foc_out[j]-rho*foc_out[lag[j]];
        moment[m+stations]=moment[m]*data->instruments[3][j];
      }
      m++;
    }
    advDem->idx+=problem->market;
  }
}
