/* 
 * 
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <ctype.h>
#include <errno.h>
#include <pthread.h>
#include <unistd.h>
#include "VMS_Implementations/Vthread_impl/VPThread.h"
#include "C_Libraries/Queue_impl/PrivateQueue.h"

#include <linux/perf_event.h>
#include <linux/prctl.h>
#include <sys/syscall.h>

#undef DEBUG
//#define DEBUG

#if !defined(unix) && !defined(__unix__)
#ifdef __MACH__
#define unix		1
#define __unix__	1
#endif	/* __MACH__ */
#endif	/* unix */

/* find the appropriate way to define explicitly sized types */
/* for C99 or GNU libc (also mach's libc) we can use stdint.h */
#if (__STDC_VERSION__ >= 199900) || defined(__GLIBC__) || defined(__MACH__)
#include <stdint.h>
#elif defined(unix) || defined(__unix__)	/* some UNIX systems have them in sys/types.h */
#include <sys/types.h>
#elif defined(__WIN32__) || defined(WIN32)	/* the nameless one */
typedef unsigned __int8 uint8_t;
typedef unsigned __int32 uint32_t;
#endif	/* sized type detection */

/* provide a millisecond-resolution timer for each system */
#if defined(unix) || defined(__unix__)
#include <time.h>
#include <sys/time.h>
/* Millisecond timer: returns ms elapsed since the first call (first call
 * returns 0 and establishes the reference point). */
unsigned long get_msec(void) {
	static struct timeval now, start;

	gettimeofday(&now, 0);
	if(start.tv_sec == 0) {
		/* first invocation: remember the epoch, report zero elapsed */
		start = now;
		return 0;
	}
	return (now.tv_sec - start.tv_sec) * 1000
	     + (now.tv_usec - start.tv_usec) / 1000;
}
#elif defined(__WIN32__) || defined(WIN32)
#include <windows.h>
unsigned long get_msec(void) {
	/* GetTickCount(): milliseconds since system boot (wraps after ~49.7 days) */
	return GetTickCount();
}
#else
//#error "I don't know how to measure time on your platform"
#endif

//======================== Globals =========================
char __ProgrammName[] = "overhead_test";
char __DataSet[255];

/* set from cmd-line flags -o, -i, -t (see usage string) */
int outer_iters, inner_iters, num_threads;
size_t chunk_size = 0;  /* not referenced anywhere in this file */

/* perf_event fds: all-core cycle counter and all-core cache-miss counter */
int cycles_counter_main_fd;
int misses_counter_fd;

uint64_t cache_misses;  /* miss delta over the whole benchmark run, set in benchmark() */

/* one cycle counter per core, indexed by sched_getcpu() in the workers */
int cycles_counter_fd[NUM_CORES];
struct perf_event_attr* hw_event;  /* single attr struct reused for every perf_event_open call */

//======================== Defines =========================
/* one raw cycle-counter sample, padded to a cache line to avoid false sharing */
typedef struct perfData measurement_t;
struct perfData{
    uint64 cycles;
} __align_to_cacheline__;

/* Help text printed for -h or on argument errors.
 * Fixed typos ("Spwans" -> "Spawns", "internaly" -> "internally") and
 * unified the option placeholders to <num>. */
const char *usage =
	"Usage: malloc_test [options]\n"
	"  Spawns a number of threads and allocates memory.\n\n"
	"Options:\n"
	"  -t <num>   how many threads to use (default: 1). This is internally multiplied by the number of cores.\n"
	"  -o <num>   repeat workload and sync operation <num> times\n"
	"  -i <num>   size of workload, repeat <num> times\n"
	"  -h         this help screen\n\n";

/* Counting barrier built on a VPThread mutex/cond pair.  The last thread to
 * arrive also samples the all-core cycle counter into endBarrierCycles
 * (see barrier_wait). */
struct barrier_t
{
    int counter;    /* number of threads that have arrived so far */
    int nthreads;   /* number of participants to wait for */
    int32 mutex;
    int32 cond;
    measurement_t endBarrierCycles;  /* cycle sample taken when the barrier opens */

} __align_to_cacheline__;
typedef struct barrier_t barrier;

/* Initialize barr for nthreads participants.  Must run before any
 * barrier_wait().  The cond var is bound to the barrier's own mutex,
 * so the mutex must be created first. */
void inline barrier_init(barrier *barr, int nthreads, VirtProcr *animatingPr)
 {
   barr->counter = 0;
   barr->nthreads = nthreads;
   barr->mutex   = VPThread__make_mutex(animatingPr);
   barr->cond    = VPThread__make_cond(barr->mutex, animatingPr);
 }

/* Block until barr->nthreads threads have called this.  The last arriver
 * samples the all-core cycle counter into endBarrierCycles, resets the
 * counter for reuse, and signals every waiter.
 * NOTE(review): the read() return value is ignored here — presumably
 * acceptable for a benchmark; confirm. */
void inline barrier_wait(barrier *barr, VirtProcr *animatingPr)
 { int i;

   VPThread__mutex_lock(barr->mutex, animatingPr);
   barr->counter++;
   if(barr->counter == barr->nthreads)
    { 
        /* last thread in: timestamp the barrier release */
        read(cycles_counter_main_fd, &(barr->endBarrierCycles.cycles), \
                sizeof(barr->endBarrierCycles.cycles));
       
      barr->counter = 0;
      /* one signal per participant releases all waiters */
      for(i=0; i < barr->nthreads; i++)
         VPThread__cond_signal(barr->cond, animatingPr);
    }
   else
    { VPThread__cond_wait(barr->cond, animatingPr);
    }
   VPThread__mutex_unlock(barr->mutex, animatingPr);
 }



/* Per-worker result record; every field is written by worker_TLF and read
 * by main() after the run. */
struct WorkerParams_t
 { struct barrier_t* barrier;   /* shared end-of-run barrier, set by benchmark() */
   uint64_t  totalWorkCycles;   /* sum of "good" (non-outlier) task samples */
   uint64_t  totalBadCycles;    /* sum of outlier task samples */
   uint64_t  totalSyncCycles;
   uint64_t  totalBadSyncCycles;
   uint64     numGoodSyncs;
   uint64     numGoodTasks;     /* how many task samples passed the sanity check */
 };
 
 /* pad each worker's record out to a full cache line so that workers on
  * different cores never share a line (avoids false sharing) */
 typedef union
 {
     struct WorkerParams_t data;
     char padding[CACHELINE_SIZE];
 } WorkerParams __align_to_cacheline__;
 
WorkerParams *workerParamsArray;  /* one padded slot per worker, alloc'd in main() */

/* pointers into main()'s stack-resident start/end cycle samples,
 * handed to benchmark() through the VMS runtime */
typedef struct
 { measurement_t *startExeCycles;
   measurement_t *endExeCycles;
 } BenchParams __align_to_cacheline__;

//======================== App Code =========================
/*
 * Workload
 */

/* Read the per-core cycle counter for `core` into `cycles`.
 * On read failure, reports via perror and zeroes `cycles`. */
#define saveCyclesAndInstrs(core,cycles) do{     \
   int cycles_fd = cycles_counter_fd[core];             \
   int nread;                                           \
                                                        \
   nread = read(cycles_fd,&(cycles),sizeof(cycles));    \
   if(nread<0){                                         \
       perror("Error reading cycles counter");          \
       cycles = 0;                                      \
   }                                                    \
} while (0) //macro magic for scoping
 
/* Read the all-core cache-miss counter into `misses`; zeroes it on failure. */
#define saveMisses(misses) do{     \
   int nread;                                           \
                                                        \
   nread = read(misses_counter_fd,&(misses),sizeof(misses));    \
   if(nread<0){                                         \
       perror("Error reading misses counter");          \
       misses = 0;                                      \
   }                                                    \
} while (0) //macro magic for scoping


/*
 * Thread-level function run by each worker VP.
 *
 * Repeats a fixed arithmetic workload outer_iters times.  Each repetition is
 * timed with the per-core cycle counter; samples under 400K cycles count as
 * "good" work, larger ones (inflated by preemption/migration between the two
 * counter reads) are tallied separately as "bad".  After each task it takes
 * and releases a private mutex, which often triggers a switch to a different
 * Slave VP.  Results are published into this worker's padded WorkerParams
 * slot before waiting on the shared end-of-run barrier.
 *
 * Fix: the cycle/count accumulators were `unsigned int`; cycle totals can
 * exceed 32 bits over a long run, and the destination fields are 64-bit.
 * They are now uint64_t.
 */
double
worker_TLF(void* _params, VirtProcr* animatingPr)
 {
   int i,o;
   WorkerParams* params = (WorkerParams*)_params;
   uint64_t totalWorkCycles = 0, totalBadCycles = 0;
   uint64_t totalSyncCycles = 0, totalBadSyncCycles = 0;
   uint64_t numGoodSyncs = 0, numGoodTasks = 0;
   unsigned int workspace1=0;
   double workspace2=0.0;
   int32 privateMutex = VPThread__make_mutex(animatingPr);
   
   int cpuid = sched_getcpu();  /* index into the per-core counter fds */
   
   measurement_t startWorkload, endWorkload;
   uint64 numCycles;
   for(o=0; o < outer_iters; o++)
    {
      saveCyclesAndInstrs(cpuid,startWorkload.cycles);
       
      //task: fixed arithmetic workload, inner_iters iterations
      for(i=0; i < inner_iters; i++)
       {
         workspace1 += (workspace1 + 32)/2;
         workspace2 += (workspace2 + 23.2)/1.4;
       }
      
      saveCyclesAndInstrs(cpuid,endWorkload.cycles);
      numCycles = endWorkload.cycles - startWorkload.cycles;
      //sanity check (400K is about 20K iters): discard samples inflated by
      //preemption/migration between the two counter reads
      if( numCycles < 400000 ) {totalWorkCycles += numCycles; numGoodTasks++;}
      else                     {totalBadCycles  += numCycles; }

      //mutex access often causes switch to different Slave VP
      VPThread__mutex_lock(privateMutex, animatingPr);
      VPThread__mutex_unlock(privateMutex, animatingPr);
    }

   /* publish per-thread results to this worker's padded slot */
   params->data.totalWorkCycles = totalWorkCycles;
   params->data.totalBadCycles = totalBadCycles;
   params->data.numGoodTasks   = numGoodTasks;
   params->data.totalSyncCycles = totalSyncCycles;
   params->data.totalBadSyncCycles = totalBadSyncCycles;
   params->data.numGoodSyncs = numGoodSyncs;
   
   //Wait for all threads to end
   barrier_wait(params->data.barrier, animatingPr);
   
   //Shutdown worker
   VPThread__dissipate_thread(animatingPr);
   
     //below return never reached --> there for gcc
   return (workspace1 + workspace2);  //to prevent gcc from optimizing work out
 }

//local variables of benchmark, made global for alignment
struct barrier_t  barr __align_to_cacheline__;    /* end-of-run barrier shared with workers */
BenchParams      *params __align_to_cacheline__;  /* points at main()'s BenchParams */

/* this is run after the VMS is set up*/
/* Seed VP entry point: initializes the shared barrier, records the starting
 * cycle/miss counters, spawns num_threads workers, then joins the barrier and
 * records end-of-run counters into the BenchParams passed from main(). */
void benchmark(void *_params, VirtProcr *animatingPr)
 {
   int i;

   params = (BenchParams *)_params;

   //+1 because this seed VP also participates in the barrier
   barrier_init(&barr, num_threads+1, animatingPr);
      
   //prepare input
   for(i=0; i<num_threads; i++)
    { 
       workerParamsArray[i].data.barrier = &barr;
    }
    
   uint64_t cache_misses_at_start, cache_misses_at_end;
   saveMisses(cache_misses_at_start);
   //save cycles before execution of threads, to get total exe cycles
   int nread = read(cycles_counter_main_fd, &(params->startExeCycles->cycles),
                sizeof(params->startExeCycles->cycles));
   if(nread<0) perror("Error reading cycles counter");
   
   //create (which starts running) all threads
   for(i=0; i<num_threads; i++)
    { VPThread__create_thread((VirtProcrFnPtr)worker_TLF, &(workerParamsArray[i]), animatingPr);
    }
   //wait for all threads to finish
   barrier_wait(&barr, animatingPr);
  
   //endBarrierCycles read in barrier_wait()!  Merten, email me if want to chg
   params->endExeCycles->cycles = barr.endBarrierCycles.cycles;
   saveMisses(cache_misses_at_end);
   cache_misses = cache_misses_at_end-cache_misses_at_start;
/*
   uint64_t overallWorkCycles = 0;
   for(i=0; i<num_threads; i++){ 
       printf("WorkCycles: %lu\n",input[i].totalWorkCycles);
       overallWorkCycles += input[i].totalWorkCycles;
    }
   
   printf("Sum across threads of work cycles: %lu\n", overallWorkCycles);
   printf("Total Execution: %lu\n", endBenchTime.cycles-startBenchTime.cycles);
   printf("Runtime/Workcycle Ratio %lu\n", 
   ((endBenchTime.cycles-startBenchTime.cycles)*100)/overallWorkCycles);
*/

   //======================================================

   VPThread__dissipate_thread(animatingPr);
 }

/*
 * Parse command-line flags, open the perf_event cycle and cache-miss
 * counters, hand control to the VMS runtime (which runs benchmark()), then
 * aggregate the per-thread results and print the overhead statistics.
 *
 * Fixes in this revision:
 *  - hw_event->size was sizeof(hw_event) (pointer size); now sizeof(*hw_event).
 *  - argv[++i] was read without a bounds check: a flag given as the last
 *    argument dereferenced argv[argc] (NULL).
 *  - The -o and -i error messages were swapped.
 *  - isdigit() is given an unsigned char, per the <ctype.h> contract.
 *  - malloc results are checked before use; garbled error text corrected.
 */
int main(int argc, char **argv)
 {
   int i;

   //set global static variables, based on cmd-line args
   for(i=1; i<argc; i++)
    {
      if(argv[i][0] == '-' && argv[i][2] == 0)
       {
         switch(argv[i][1])
          {
            case 't':
               if(++i >= argc || !isdigit((unsigned char)argv[i][0]))
                {
                  fprintf(stderr, "-t must be followed by the number of worker threads to spawn\n");
                  return EXIT_FAILURE;
                }
               num_threads = atoi(argv[i]);
               if(!num_threads)
                {
                  fprintf(stderr, "invalid number of threads specified: %d\n", num_threads);
                  return EXIT_FAILURE;
                }
               break;
            case 'o':
               if(++i >= argc || !isdigit((unsigned char)argv[i][0]))
                {
                  fputs("-o must be followed by a number\n", stderr);
                  return EXIT_FAILURE;
                }
               outer_iters = atoi(argv[i]);
               break;
            case 'i':
               if(++i >= argc || !isdigit((unsigned char)argv[i][0]))
                {
                  fputs("-i must be followed by a number (workload size)\n", stderr);
                  return EXIT_FAILURE;
                }
               inner_iters = atoi(argv[i]);
               break;
            case 'h':
               fputs(usage, stdout);
               return 0;

            default:
               fprintf(stderr, "unrecognized argument: %s\n", argv[i]);
               fputs(usage, stderr);
               return EXIT_FAILURE;
          }//switch
       }//if arg
      else
       {
         fprintf(stderr, "unrecognized argument: %s\n", argv[i]);
         fputs(usage, stderr);
         return EXIT_FAILURE;
       }
    }//for


   //setup performance counters
    hw_event = calloc(1, sizeof(*hw_event));  /* attr must start zeroed */
    if(hw_event == NULL)
     { perror("Failed to allocate perf_event_attr");
       return EXIT_FAILURE;
     }

    hw_event->type = PERF_TYPE_HARDWARE;
    hw_event->size = sizeof(*hw_event); /* struct size, NOT pointer size */
    hw_event->disabled = 0;       /* counter starts running immediately */
    hw_event->freq = 0;
    hw_event->inherit = 1;        /* children inherit it */
    hw_event->pinned = 1;         /* this virt counter must always be on HW */
    hw_event->exclusive = 0;      /* not the only group on the PMU */
    hw_event->exclude_user = 0;   /* DO count user mode */
    hw_event->exclude_kernel = 1; /* don't count kernel */
    hw_event->exclude_hv = 1;     /* ditto hypervisor */
    hw_event->exclude_idle = 1;   /* don't count when idle */
    hw_event->mmap = 0;
    hw_event->comm = 0;

    hw_event->config = PERF_COUNT_HW_CPU_CYCLES; //cycles

    int cpuID, retries;

   //one cycle counter pinned to each core; read by the workers
   for( cpuID = 0; cpuID < NUM_CORES; cpuID++ )
    { retries = 0;
      do
       { retries += 1;
         cycles_counter_fd[cpuID] =
          syscall(__NR_perf_event_open, hw_event,
                  0,    //pid_t: 0 is "pid of calling process"
                  cpuID,//int: cpu this counter is bound to
                  -1,   //int: group_fd, -1 is "leader" or independent
                  0     //unsigned long: flags
                 );
       }
      while(cycles_counter_fd[cpuID]<0 && retries < 100);
      if(retries >= 100)
       {
         fprintf(stderr,"On core %d: ",cpuID);
         perror("Failed to open cycles counter");
       }
    }

   //Set up counter to accumulate total cycles to process, across all CPUs
   retries = 0;
   do
    { retries += 1;
      cycles_counter_main_fd =
       syscall(__NR_perf_event_open, hw_event,
               0, //pid_t: 0 is "pid of calling process"
               -1,//int: cpu, -1 means accumulate from all cores
               -1,//int: group_fd, -1 is "leader" == independent
               0  //unsigned long: flags
              );
    }
   while(cycles_counter_main_fd<0 && retries < 100);
   if(retries >= 100)
    {
      fprintf(stderr,"in main ");
      perror("Failed to open cycles counter");
    }

   //Set up counter to count cache misses (same attr, new config)
    hw_event->type = PERF_TYPE_HARDWARE;
    hw_event->config = PERF_COUNT_HW_CACHE_MISSES; //misses

   retries = 0;
   do
    { retries += 1;
      misses_counter_fd =
       syscall(__NR_perf_event_open, hw_event,
               0, //pid_t: 0 is "pid of calling process"
               -1,//int: cpu, -1 means accumulate from all cores
               -1,//int: group_fd, -1 is "leader" == independent
               0  //unsigned long: flags
              );
    }
   while(misses_counter_fd<0 && retries < 100);
   if(retries >= 100)
    {
      fprintf(stderr,"in main ");
      perror("Failed to open misses counter");
    }

   measurement_t startExeCycles, endExeCycles;
   BenchParams *benchParams;

   benchParams = malloc(sizeof(*benchParams));
   if(benchParams == NULL)
    { perror("Failed to allocate BenchParams");
      return EXIT_FAILURE;
    }

   benchParams->startExeCycles = &startExeCycles;
   benchParams->endExeCycles   = &endExeCycles;

   workerParamsArray = malloc( (num_threads + 1) * sizeof(WorkerParams) );
   if(workerParamsArray == NULL )
    { fprintf(stderr, "error mallocing worker params array\n");
      return EXIT_FAILURE;
    }


   //This is the transition to the VMS runtime
   VPThread__create_seed_procr_and_do_work( &benchmark, benchParams );

   //aggregate the per-thread results written by worker_TLF
   uint64_t totalWorkCyclesAcrossCores = 0, totalBadCyclesAcrossCores = 0;
   uint64_t totalSyncCyclesAcrossCores = 0, totalBadSyncCyclesAcrossCores = 0;
   for(i=0; i<num_threads; i++){
       printf("WorkCycles: %lu\n",workerParamsArray[i].data.totalWorkCycles);
       totalWorkCyclesAcrossCores += workerParamsArray[i].data.totalWorkCycles;
       totalBadCyclesAcrossCores  += workerParamsArray[i].data.totalBadCycles;
       totalSyncCyclesAcrossCores += workerParamsArray[i].data.totalSyncCycles;
       totalBadSyncCyclesAcrossCores  += workerParamsArray[i].data.totalBadSyncCycles;
    }

   //outlier ("bad") cycles are removed from the execution total so the
   //overhead figure is not skewed by preempted samples
   uint64_t totalExeCycles = endExeCycles.cycles - startExeCycles.cycles;
   totalExeCycles -= totalBadCyclesAcrossCores;
   uint64_t totalOverhead = totalExeCycles - totalWorkCyclesAcrossCores;
   int32  numSyncs = outer_iters * num_threads * 2; //lock + unlock per outer iter
   printf("Total Execution Cycles: %lu\n", totalExeCycles);
   printf("Total number of cache misses: %lu\n", cache_misses);
   printf("Sum across threads of work cycles: %lu\n", totalWorkCyclesAcrossCores);
   printf("Sum across threads of bad work cycles: %lu\n", totalBadCyclesAcrossCores);
   printf("Overhead per sync: %f\n", (double)totalOverhead / (double)numSyncs );
   printf("ExeCycles/WorkCycles Ratio %f\n",
          (double)totalExeCycles / (double)totalWorkCyclesAcrossCores);
   return 0;
 }
