#include <string.h>
#include "mex.h"
#include <stdio.h>
#include <stdlib.h>

#include "cuda_generator.h"
#include "profitGlobals.h"

/* Non-zero while profitDataMex holds state built by mexFunction; cleanup()
   resets it to 0. */
static int initialized = 0;
/* State kept between MEX calls. mexFunction allocates it with mxMalloc and
   marks it with mexMakeMemoryPersistent; cleanup() releases it. */
static profitStructure *profitDataMex;

/*
 * Release the persistent state referenced by profitDataMex and mark the
 * module as uninitialized.
 *
 * The function is registered with mexAtExit and is also called directly by
 * mexFunction, so it can run more than once for the same state.  It is
 * therefore a no-op when there is nothing to free, and it resets
 * profitDataMex to NULL afterwards so a later run cannot free the same
 * memory twice.
 *
 * Only data, its rows, structure and the struct itself are freed here; the
 * demographics pointer is set by mexFunction to the caller's input array and
 * is not owned by this module.
 */
void cleanup(void) {
  int i;

  initialized=0;

  if(profitDataMex==NULL) {
    return;
  }

  if(profitDataMex->data!=NULL) {
    /* One row was allocated per station. */
    for(i=0;i<profitDataMex->stations;++i) {
      mxFree(profitDataMex->data[i]);
    }
    mxFree(profitDataMex->data);
  }

  mxFree(profitDataMex->structure);

  mxFree(profitDataMex);
  profitDataMex=NULL;
}

void mexFunction(int nlhs, mxArray *plhs[ ],int nrhs, const mxArray *prhs[ ])  {
  int i,j,n,d,stations,m,r1,o;
  char ar;
  double *foc_out, *output, *dgamma1, *dgamma2;
  double *indexInit, *demographicsInit, *lagInit, *instrumentsInit;
  double *dataInit;
  double *mySeedInit = mxGetPr(prhs[1]);
  double rho;
  double *omega;
  double aggrQ;
  int market, date, date_market, data_size, lagSize;

  reInit = (int) mxGetPr(mxGetField(prhs[0], 0, "reInit"))[0];
  if((reInit) && (initialized)) {
    cleanup();
  }

  if((!initialized) || (reInit)) {
/*    printf("Initing\n");*/

    initialized=1;
    profitDataMex = (profitStructure *) mxMalloc(sizeof(profitStructure));
    mexMakeMemoryPersistent(profitDataMex);
    mexAtExit(cleanup);

    profitDataMex->stations=(int) mxGetPr(mxGetField(prhs[0], 0, "stations"))[0];
    profitDataMex->demoCharacteristics=(int) mxGetPr(mxGetField(prhs[0], 0, "demoCharacteristics"))[0];
    profitDataMex->parameters=(int) mxGetPr(mxGetField(prhs[0], 0, "parameters"))[0];
    profitDataMex->N=(int) mxGetPr(mxGetField(prhs[0], 0, "N"))[0];
    profitDataMex->random_effects=(int) mxGetPr(mxGetField(prhs[0], 0, "random_effects"))[0];
    profitDataMex->product=(int) mxGetPr(mxGetField(prhs[0], 0, "product"))[0];
    profitDataMex->power=profitDataMex->parameters+profitDataMex->product*(profitDataMex->demoCharacteristics)+profitDataMex->random_effects;
    int power = (int) mxGetPr(mxGetField(prhs[0], 0, "extra_parameters"))[0];
    profitDataMex->year_dummies=profitDataMex->parameters+profitDataMex->product*(profitDataMex->demoCharacteristics)+profitDataMex->random_effects+
      power;
    profitDataMex->dummy=mxGetPr(mxGetField(prhs[0], 0, "dummy"))[0];
    profitDataMex->allParams=profitDataMex->parameters+profitDataMex->product*profitDataMex->demoCharacteristics+profitDataMex->random_effects+power+profitDataMex->year_dummies;
    profitDataMex->dataLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "dataLineSize"))[0];;
    profitDataMex->demoLineSize=(int) mxGetPr(mxGetField(prhs[0], 0, "demoLineSize"))[0];;

    stations = profitDataMex->stations;

    dataInit = mxGetPr(mxGetField(prhs[0], 0, "data"));
    profitDataMex->data = (F_TYPE **) mxMalloc(stations*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(profitDataMex->data);
    m=0;
    for(i=0;i<stations;++i) {
      profitDataMex->data[i] = (F_TYPE *) mxMalloc(profitDataMex->dataLineSize*sizeof(F_TYPE));
      mexMakeMemoryPersistent(profitDataMex->data[i]);
      for(j=0;j<profitDataMex->dataLineSize;++j) {
        profitDataMex->data[i][j]=dataInit[m];
        m++;
      }
    }

/*    demographicsInit = mxGetPr(mxGetField(prhs[0], 0, "demographics"));
    profitDataMex->demographics = (F_TYPE **) mxMalloc(profitDataMex->N*sizeof(F_TYPE *));
    mexMakeMemoryPersistent(profitDataMex->demographics);
    m=0;
    for(i=0;i<profitDataMex->N;++i) {
      profitDataMex->demographics[i] = (F_TYPE *) mxMalloc(profitDataMex->demoLineSize*sizeof(F_TYPE));
      mexMakeMemoryPersistent(profitDataMex->demographics[i]);
      for(j=0;j<profitDataMex->demoLineSize;++j) {
        profitDataMex->demographics[i][j]=demographicsInit[m];
        m++;
      }
    }*/

    profitDataMex->gamma[0]=mxGetPr(mxGetField(prhs[0], 0, "gamma"))[0];
    profitDataMex->gamma[1]=mxGetPr(mxGetField(prhs[0], 0, "gamma"))[1];
    
    /* count owners */
    o=profitDataMex->data[0][2];
    profitDataMex->number_of_owners=1;
    for(j=0;j<stations;++j) {
      if(profitDataMex->data[j][2]!=o) {
        profitDataMex->number_of_owners++;
        o=profitDataMex->data[j][2];
      }
    }

    /* create structure vector */
    profitDataMex->structure = (int *) mxMalloc((profitDataMex->number_of_owners+1)*sizeof(int));
    mexMakeMemoryPersistent(profitDataMex->structure);
    profitDataMex->structure[0]=0;
    o=profitDataMex->data[0][2];
    r1=0;
    for(j=0;j<stations;++j) {
      if(profitDataMex->data[j][2]!=o) {
        r1++;
        profitDataMex->structure[r1]=j;
        o=profitDataMex->data[j][2];
      }
    }

    profitDataMex->structure[profitDataMex->number_of_owners]=stations;
  }

  if(reInit) {
    plhs[0]=mxCreateDoubleMatrix(1, 1, mxREAL);
    foc_out = mxGetPr(plhs[0]);
    foc_out[0]=0;
    return;
  }
  
  profitDataMex->q = mxGetPr(mxGetField(prhs[0], 0, "q"));
  profitDataMex->arg = mxGetPr(mxGetField(prhs[0], 0, "arg"));
  profitDataMex->demographics = mxGetPr(mxGetField(prhs[0], 0, "demographics"));

  omega = mxGetPr(mxGetField(prhs[0], 0, "omega"));

  profitDataMex->Gamma=(double **) mxMalloc(profitDataMex->stations*sizeof(double *));
  for(i=0;i<profitDataMex->stations;i++) { 
    profitDataMex->Gamma[i]=(double *) mxMalloc(profitDataMex->stations*sizeof(double));
    for(j=0;j<profitDataMex->stations;j++) {
      profitDataMex->Gamma[i][j]=-profitDataMex->gamma[1]*
        omega[(int) (profitDataMex->data[i][1]+profitDataMex->data[j][1]*8)];
    }
  }

  profitData = profitDataMex;

  seed = (seedStruct *) mxMalloc(1*sizeof(seedStruct));

  mySeed = (long *) mxMalloc(4*sizeof(long));
  mySeed[0] = (long) mySeedInit[0];
  mySeed[1] = (long) mySeedInit[1];
  mySeed[2] = (long) mySeedInit[2];
  mySeed[3] = (long) mySeedInit[3];

  createGenerators(1,mySeed,seed);

  plhs[0] = mxCreateDoubleMatrix(profitData->stations,1,mxREAL);
  foc_out = mxGetPr(plhs[0]);

  plhs[1] = mxCreateDoubleMatrix(profitData->stations,1,mxREAL);
  output = mxGetPr(plhs[1]);

  plhs[2] = mxCreateDoubleMatrix(profitData->stations,1,mxREAL);
  dgamma1 = mxGetPr(plhs[2]);

  plhs[3] = mxCreateDoubleMatrix(profitData->stations,1,mxREAL);
  dgamma2 = mxGetPr(plhs[3]);

  epsilonDer(foc_out, output, dgamma1, dgamma2);
}
