/* FILE interval */
     /* Count spikes in any ONE or TWO intervals; if 2, also the diff	*/
#include "deffs.h"
#include "config.h"
#include <math.h>

     /* CRITICAL POINT:
      *	There are 3 conditions, depending on command line arguments.
      *	Get_Interval_Type() identifies these conditions:
      *
      *	   1:  -i specified
      *	   		compute only one interval.
      *		*** Interval 1 returns interval specified by -i ***
      *		    Interval 2 undefined. 
      *
      *	   2:  -i & -B specified
      *	   		compute DIFFERENCE between 2 intervals.
      *		*** Interval 1 returns difference between the 2 intervals ***
      *		    Interval 2 returns the background interval
      *
      *	   3:  -i & -I specified
      *	   		compute two intervals
      *		*** Interval 1 returns interval specified by -i ***
      *		*** Interval 2 returns interval specified by -I ***
      *		
      */

/* Compile-time debug verbosity for this file.  Any nonzero value makes
 * Calc_Spikes_Per_Interval() print one summary line per trial type to
 * stderr; the value 2 additionally makes Cumulate_Interval() print one
 * line per tallied trial. */
#define DEBUG		0	/* Can be 0, 1 or 2			*/

/* ********************************************************************	*/
/* Init_Intervals		Initialize structures
 * Spikes_In_Interval		Return spikes in interval (current trial)
 * Cumulate_Interval		Add in 1 trial's data to interval stats
 * Calc_Spikes_Per_Interval	Do post-calculations on spike,interval data
  
 * Show_Interval		Plot interval bounds on current figure
 * Set_Interval_By_Bin		Set interval 1 from histogram bin numbers
 * Set_Interval_By_Peak_Variance	IN HISTO.C, NOT HERE!

 * Set_Interval_From_CmdLine	Interval 1 or 2 (call with -i or -I)
 * Set_bg_Interval_From_CmdLine	Int2 duration == Int1 (call with -B endtime)
  
 * Print_Histo_Interval, Print_Histo_Stats	Print values (for macros)

 * Get_Interval_Begin		In ms
 * Get_Interval_End
 * Get_Interval_Mean		Mean firing rate within interval for a trial type
 * Get_Interval_SEM
 * Get_Interval_SD
 * Get_Interval_Count		How many individual values
 * Get_Interval_Data		The individual values
 * Get_Interval_TrialNumber	Trial number assoc'd with each individual value
 * Get_Interval_Type		1:standard  2:minus background  3:two ints

 * Ttest_On_Interval		Do a ttest comparing two trial types
 * Ttest_On_Two_Intervals	Do a ttest comparing one trial type, int 1 vs 2
 * Ttest_Vs_Zero_On_Interval	Do a ttest on one trial type

 * Anova_On_Interval		Do non-parametric anova across types

 * Bad_Interval			Used when -w, -a and -i/I/B flags are given
 
 *   PRIVATE:
 * Check_Interval		Boring: use previously set times
 */

/* ********************************************************************	*/
/* ********************************************************************	*/

/* Number of trials tallied so far for each trial type.  Cumulate_Interval()
 * also uses it as the index of the next free slot in the per-trial arrays
 * below. */
static int Trials[MAX_TYPES+1];	/* 'n' for each particular type		*/
	/* Histogram[MAX_TYPE] may be mean; [MAX+1] may be variance	*/
	/* NOTE(review): Histogram[] is not declared in the visible part
	 * of this file; the line above may be a leftover -- confirm.	*/
/* Per-type capacity of the per-trial arrays; Calc_Spikes_Per_Interval()
 * exits once a type has reached this many trials. */
#define MAX_TRIALS 16000  /* inc from 2000 to 16000 on 1/4/11 by jk */

/* FIRST INTERVAL */
/* [type][0] is the running sum of per-trial spike counts, [type][1] the
 * running sum of their squares. */
static int SpikesInInterval1[MAX_TYPES][2]; /* (sum,sum2) for each fig	*/
/* Spike count of every individual trial, in tally order, per type. */
static int SpikesInEachInterval1[MAX_TYPES][MAX_TRIALS+1];
/* Per-type results filled in by Calc_Spikes_Per_Interval(); each is
 * divided by the interval length in seconds. */
static float Mean1[MAX_TYPES], SEM1[MAX_TYPES], SD1[MAX_TYPES];
/* Defaults below are used unless the command line or Check_Interval()
 * overwrites them. */
static int BeginMs1 = -300;		/* Int'vl bound in ms, re: align*/
static int EndMs1 = -1;			/* Int'vl bound in ms, re: align*/

/* SECOND INTERVAL */
/* Same layout as the first interval; only updated when -B or -I is
 * in effect. */
static int SpikesInInterval2[MAX_TYPES][2]; /* (sum,sum2) for each fig	*/
static int SpikesInEachInterval2[MAX_TYPES][MAX_TRIALS+1];
static float Mean2[MAX_TYPES], SEM2[MAX_TYPES], SD2[MAX_TYPES];
static int BeginMs2 = 0;		/* Int'vl bound in ms, re: align*/
static int EndMs2 = 0;			/* Int'vl bound in ms, re: align*/

/* Value of TrialNumber_Header() recorded for each tallied trial; the
 * index matches the SpikesInEachInterval arrays. */
static int TrialNumberForEachSpikeCount[MAX_TYPES][MAX_TRIALS+1];
	/* Store (for both intervals); else have to search data file	*/

/* No initializer, so this starts at 0 until -B supplies a value. */
static int  bg_EndMs;			/* Set by -B option's value	*/
static int  SubtractBackground = NO;	/* Boolean: set by -B option	*/
static int  KeepSecondInterval = NO;	/* Boolean: set by -I option	*/

/* Private bounds checker, defined further down in this file; it is also
 * what actually stores the accepted bounds. */
static void Check_Interval(int b1, int e1, int b2, int e2, int bg);
/* ********************************************************************	*/

/* FUNCTION Init_Intervals */
	 /* Called once per file by main, after options are read.
	  * Zeroes the trial count and both intervals' running sums for
	  * every trial type, then hands the current bounds to
	  * Check_Interval(), which stores the values it accepts.
	  * The per-trial arrays themselves are not cleared here.	*/
void Init_Intervals() {
	int type = 0;

	while (type < MAX_TYPES) {
		Trials[type] = 0;
		SpikesInInterval1[type][0] = 0;
		SpikesInInterval1[type][1] = 0;
		SpikesInInterval2[type][0] = 0;
		SpikesInInterval2[type][1] = 0;
		type++;
	}

	Check_Interval(BeginMs1, EndMs1, BeginMs2, EndMs2, bg_EndMs);
}
/* ********************************************************************	*/

/* FUNCTION Reset_Intervals */
	 /* Call once per file by main, BEFORE options are read!
	  * Clears the -B and -I flags so that a later option parse
	  * starts from the single-interval mode.			*/
void Reset_Intervals() {
	KeepSecondInterval = NO;
	SubtractBackground = NO;
}
/* ********************************************************************	*/

/* FUNCTION Spikes_In_Interval */
	 /* Return spike count in a single interval for current trial.
	  *
	  *   interval  1 or 2; any other value exits.
	  *
	  * Counts the entries of the global Spikes[] array (the first
	  * SpikeCount_Header(-1) of them) whose value t satisfies
	  * begin <= t < end, where begin/end are the interval bounds
	  * shifted by Get_ZeroTime().
	  * NOTE(review): the Exit() tags below say "Interval_Spike_Count",
	  * which is not this function's name -- possibly an older name.	*/
int Spikes_In_Interval(int interval) {
	extern short Spikes[];
	int total = SpikeCount_Header(-1);
	int lower = Get_Interval_Begin(interval) + Get_ZeroTime();
	int upper = Get_Interval_End(interval)   + Get_ZeroTime();
	int inside = 0;
	int k;

	if (!(interval == 1 || interval == 2))
		Exit("No interval specified", "Interval_Spike_Count");
	if (interval == 2 && Get_Interval_Begin(2) == FAIL)
		Exit("No interval 2 given", "Interval_Spike_Count");

	/* Half-open window: the end bound itself is excluded. */
	for (k = 0; k < total; k++) {
		if (Spikes[k] >= lower && Spikes[k] < upper)
			inside++;
	}
	return inside;
}
/* ********************************************************************	*/

/* FUNCTION Cumulate_Interval */
	 /* Called by Cumulate_Histogram(): Tally one trial's spikes	*/
	 /* If -B flag, save difference only; if -I, save both counts	*/
/* ********************************************************************	*/

/* FUNCTION Cumulate_Interval */
	 /* Called by Cumulate_Histogram(): Tally one trial's spikes	*/
	 /* If -B flag, save difference only; if -I, save both counts	*/
void Cumulate_Interval(int spikes1, int spikes2, int type) {
	if (SubtractBackground)		/* If -B flag set, subtract!	*/
	   spikes1 -= spikes2;

	SpikesInInterval1[type][0] += spikes1;		 /* Tally sum	*/
	SpikesInInterval1[type][1] += spikes1 * spikes1; /* Tally sum2	*/
	SpikesInEachInterval1[type][Trials[type]] = spikes1;
	TrialNumberForEachSpikeCount[type][Trials[type]] =
		TrialNumber_Header();	/* Keep assoc'd trial#, too	*/

	if (SubtractBackground || KeepSecondInterval) {
	   SpikesInInterval2[type][0] += spikes2;	 /* Tally sum	*/
	   SpikesInInterval2[type][1] += spikes2 * spikes2; /* Tally sum2*/
	   SpikesInEachInterval2[type][Trials[type]] = spikes2;
	   }

	(Trials[type])++;				/* Tally trial	*/
	if (DEBUG==2)
	   fprintf(stderr, "Cum_Int: %d,%d spikes in %d-th trial (type %d)\n",
					spikes1,spikes2,Trials[type], type);
	}
/* ********************************************************************	*/

/* FUNCTION Calc_Spikes_Per_Interval */
	 /* Post-process: Calc mean, SEM and SD for each trial type.
	  *
	  * Every result is divided by the interval length in seconds.
	  * Types are visited in order and the scan stops at the first
	  * type with zero trials.  Interval 2 results are computed only
	  * when -B or -I is in effect.  Exits if a type has reached
	  * MAX_TRIALS trials.						*/
void Calc_Spikes_Per_Interval() {
	float secs1 = (EndMs1-BeginMs1) / 1000.0 ;	/* ms -> s */
	float secs2 = (EndMs2-BeginMs2) / 1000.0 ;
	int second_in_use = (SubtractBackground || KeepSecondInterval);
	int t;

	for (t = 0; t < MAX_TYPES && Trials[t] != 0; t++) {
		int n    = Trials[t];
		int sum1 = SpikesInInterval1[t][0];
		int sq1  = SpikesInInterval1[t][1];

		if (n >= MAX_TRIALS)
			Exit("Increase MAX_TRIALS", "interval.c:Calc_Spikes_Per_Interval()");

		Mean1[t] = (sum1 / secs1) / n;
		SEM1[t]  = (SE(sum1, sq1, n)) / secs1;
		SD1[t]   = (SD(sum1, sq1, n)) / secs1;

		if (second_in_use) {
			int sum2 = SpikesInInterval2[t][0];
			int sq2  = SpikesInInterval2[t][1];

			Mean2[t] = (sum2 / secs2) / n;
			SEM2[t]  = (SE(sum2, sq2, n)) / secs2;
			SD2[t]   = (SD(sum2, sq2, n)) / secs2;
		}

		if (DEBUG)
			fprintf(stderr,
"type %2d: %.4f,%.4f = %2d,%2d /%.4f,%.4f /%2d  SD=%.2f,%.2f  SEM=%.2f,%.2f\n",
				t, Mean1[t], Mean2[t],
				SpikesInInterval1[t][0], SpikesInInterval2[t][0],
				secs1, secs2,
				Trials[t], SD1[t], SD2[t], SEM1[t], SEM2[t]);
	}
}
/* ********************************************************************	*/

/* FUNCTION Get_Interval_Begin */
/* FUNCTION Get_Interval_End */
/* FUNCTION Get_Interval_Mean */
/* FUNCTION Get_Interval_SEM */
/* FUNCTION Get_Interval_SD */
/* FUNCTION Get_Interval_Count */
/* FUNCTION Get_Interval_Data */		/* Tell *all* values	*/
/* FUNCTION Get_Interval_TrialNumber */		/* Tell *all* values	*/
	 /* For macros and/or histogram */

/* Begin bound (ms re: alignment) of interval 1 when which_interval is 1,
 * of interval 2 for any other value. */
int Get_Interval_Begin(int which_interval) {
	if (which_interval == 1)
		return BeginMs1;
	return BeginMs2;
}
/* End bound (ms re: alignment) of interval 1 when which_interval is 1,
 * of interval 2 for any other value. */
int Get_Interval_End(int which_interval) {
	if (which_interval == 1)
		return EndMs1;
	return EndMs2;
}
/* Mean stored by Calc_Spikes_Per_Interval() for the given trial type:
 * interval 1 when which_interval is 1, interval 2 otherwise. */
float Get_Interval_Mean(int type, int which_interval) {
	if (which_interval == 1)
		return Mean1[type];
	return Mean2[type];
}
/* SEM stored by Calc_Spikes_Per_Interval() for the given trial type:
 * interval 1 when which_interval is 1, interval 2 otherwise. */
float Get_Interval_SEM (int type, int which_interval) {
	if (which_interval == 1)
		return SEM1[type];
	return SEM2[type];
}
/* SD stored by Calc_Spikes_Per_Interval() for the given trial type:
 * interval 1 when which_interval is 1, interval 2 otherwise. */
float Get_Interval_SD (int type, int which_interval) {
	if (which_interval == 1)
		return SD1[type];
	return SD2[type];
}

/* Number of trials tallied for a trial type, or 0 when type is FAIL.
 * The count is shared by both intervals, so which_interval is accepted
 * only for symmetry with the other getters. */
int Get_Interval_Count(int type, int which_interval) {
	(void)which_interval;		/* Deliberately unused		*/

	if (type == FAIL)		/* FAIL --> no data		*/
		return 0;
	return Trials[type];
}

/* Pointer to the per-trial spike counts of a trial type: interval 1 when
 * which_interval is 1, interval 2 otherwise.  For type FAIL the first
 * row of interval 1 is returned so the caller still gets a valid
 * pointer. */
int *Get_Interval_Data(int type, int which_interval) {
	int row = (type == FAIL) ? 0 : type;

	if (type == FAIL || which_interval == 1)
		return SpikesInEachInterval1[row];
	return SpikesInEachInterval2[row];
}

/* Pointer to the trial numbers recorded alongside each per-trial spike
 * count of a trial type.  For type FAIL the first row is returned so
 * the caller still gets a valid pointer. */
int *Get_Interval_TrialNumber(int type) {
	int row = (type == FAIL) ? 0 : type;	/* Nonsense, but real	*/

	return TrialNumberForEachSpikeCount[row];
}
/* ********************************************************************	*/

/* FUNCTION Get_Interval_Type */
	 /* Return 1:(one interval), 2:(interval minus bg) or 3:(2 intervals).
	  * The -B flag is tested first, so it wins if both flags are set. */
int Get_Interval_Type() {
	if (SubtractBackground == YES)
		return 2;
	if (KeepSecondInterval)
		return 3;
	return 1;
}
/* ********************************************************************	*/
/* ********************************************************************	*/

/* Run Ttest() on the per-trial spike counts of two trial types within
 * the same interval; 'sides' is passed through to Ttest() unchanged. */
float Ttest_On_Interval(int interval, int type1, int type2, int sides) {
	int *first_data  = Get_Interval_Data(type1, interval);
	int  first_n     = Get_Interval_Count(type1, interval);
	int *second_data = Get_Interval_Data(type2, interval);
	int  second_n    = Get_Interval_Count(type2, interval);

	return Ttest(first_data, first_n, second_data, second_n, sides);
}

/* Run Ttest_short() on one trial type, interval 1 versus interval 2,
 * from the stored mean, trial count and SEM of each interval; 'sides'
 * is passed through unchanged. */
float Ttest_On_Two_Intervals(int type, int sides) {
	/* Values are fetched into locals first: the original author
	 * reported compile-time errors with some of them in-line. */
	float mean1 = Get_Interval_Mean(type, 1);
	int   n1    = Get_Interval_Count(type, 1);
	float sem1  = Get_Interval_SEM(type, 1);
	float mean2 = Get_Interval_Mean(type, 2);
	int   n2    = Get_Interval_Count(type, 2);
	float sem2  = Get_Interval_SEM(type, 2);

	return Ttest_short(mean1, n1, sem1, mean2, n2, sem2, sides);
}

/* Run Ttest_vs_zero() on the per-trial spike counts of one trial type
 * in the given interval; 'sides' is passed through unchanged. */
float Ttest_Vs_Zero_On_Interval(int interval, int type, int sides) {
	int *counts = Get_Interval_Data(type, interval);
	int  n      = Get_Interval_Count(type, interval);

	return Ttest_vs_zero(counts, n, sides);
}
/* ********************************************************************	*/
/* ********************************************************************	*/

/* FUNCTION Anova_On_Interval */
	 /* Always uses interval 1.
	  *
	  * Builds a Count_TrialTypes() x max_trials matrix, where
	  * max_trials is the largest trial count of any type.  Row i
	  * holds the per-trial spike counts of type i, followed by FAIL
	  * in every unused slot.  The matrix is passed to anova() and
	  * the first element of its result is returned.
	  *
	  * The copy loop stops at Trials[i].  Previously it ran to
	  * max_trials, so the FAIL padding never executed and unused
	  * slots were handed to anova() as if they were data.
	  *
	  * NOTE(review): the matrix from dmatrix() and the anova() result
	  * are never released here; who owns them is not visible in this
	  * file -- confirm.						*/
float Anova_On_Interval() {
	double **data;
	double *results;
	int Types = Count_TrialTypes();
	int max_trials = 0;
	int i, j;

	for (i = 0; i < Types; i++)
		if (Trials[i] > max_trials)
			max_trials = Trials[i];

	data = dmatrix(Types, max_trials);

	for (i = 0; i < Types; i++) {
		for (j = 0; j < Trials[i]; j++)		/* Real data	*/
			data[i][j] = SpikesInEachInterval1[i][j];
		for ( ; j < max_trials; j++)
			data[i][j] = FAIL;		/* Pad with empties */
	}

	results = anova(data, Types, max_trials);
	return((float)(*results));
}

/* ********************************************************************	*/
/* ********************************************************************	*/

/* FUNCTION Set_Interval_From_CmdLine */
	 /* FROM COMMAND LINE ONLY! interval times in ms re: alignment.
	  *
	  *   b, e            begin and end of the interval
	  *   which_interval  1 sets interval 1; any other value sets
	  *                   interval 2 and turns on the -I flag
	  *
	  * No range checking happens here.				*/
void Set_Interval_From_CmdLine(int b, int e, int which_interval) {
	if (which_interval != 1) {
		BeginMs2 = b;
		EndMs2 = e;
		KeepSecondInterval = YES;
		return;
	}

	BeginMs1 = b;
	EndMs1 = e;
}
/* ********************************************************************	*/

/* FUNCTION Set_bg_Interval_From_CmdLine */
	 /* Specify only ENDING time (re: align): will set int1==int2.
	  * Records the background end time and turns on the -B flag;
	  * the interval 2 bounds themselves are derived later by
	  * Check_Interval().						*/
void Set_bg_Interval_From_CmdLine(int e) {
	SubtractBackground = YES;
	bg_EndMs = e;
}
/* ********************************************************************	*/

/* FUNCTION Check_Interval */
	 /* Bounds checking; see also _macro/bounds.c			*/
	 /* Note: -B interval set to same length as fg interval		*/
	 /*
	  * Validates the requested bounds (ms re: alignment) and stores
	  * the accepted values in BeginMs1/EndMs1 and, when -B or -I is
	  * in effect, in BeginMs2/EndMs2.  The allowed range is
	  * -Get_ZeroTime() .. DurationTime() - Get_ZeroTime().
	  *
	  *   b1, e1  requested interval 1
	  *   b2, e2  requested interval 2 (replaced when -B is set)
	  *   bg      end of the background interval (used with -B only)
	  *
	  * Reversed bounds are swapped BEFORE anything is stored, so the
	  * swap actually reaches the stored values.  With -B, when the
	  * begin bound has to be clamped, the end bound is moved so both
	  * intervals keep the same length, and that moved end is what
	  * gets range-checked and stored.
	  *
	  * Does nothing when Get_Macro_Number() is 150 (reason not
	  * visible in this file).  Exits when -B and -I are combined, or
	  * when a -B interval ends past the allowed range.		*/
static void Check_Interval(int b1, int e1, int b2, int e2, int bg) {
	int earliest, latest;

	if (Get_Macro_Number() == 150)
		return;

	earliest = -Get_ZeroTime();
	latest   = DurationTime() - Get_ZeroTime();

	/* INTERVAL 1 */
	if (e1 < b1) {
		int temp = e1;
		if (Is_Batch())
			Warning("");
		fprintf(stderr, "end comes before beginning - swapping!\n");
		e1 = b1;
		b1 = temp;
	}

	if (b1 < earliest) {
		BeginMs1 = earliest;
		if (Is_Batch())
			Warning("");
		fprintf(stderr, "BeginMs1 --> %5d\n", BeginMs1);
	} else
		BeginMs1 = b1;

	if (e1 > latest) {
		EndMs1 = latest;
		if (Is_Batch())
			Warning("");
		fprintf(stderr, "EndMs1   --> %5d\n", EndMs1);
	} else
		EndMs1 = e1;

	/* BACKGROUND INTERVAL */
	if (!(SubtractBackground || KeepSecondInterval))
		return;

	if (SubtractBackground) {
		if (KeepSecondInterval)
			Exit("Cannot use both -B and -I", "Check_Interval");
		e2 = bg;
		b2 = e2 - (EndMs1-BeginMs1);		/* Equal lengths*/
	}

	/* INTERVAL 2  (either -I or -B) */
	if (e2 < b2) {
		int temp = e2;
		if (Is_Batch())
			Warning("");
		fprintf(stderr, "end comes before beginning - swapping!\n");
		e2 = b2;
		b2 = temp;
	}

	if (b2 < earliest) {
		BeginMs2 = earliest;
		if (Is_Batch())
			Warning("");
		fprintf(stderr, "BeginMs2 --> %5d\n", BeginMs2);
		if (SubtractBackground) {	/* -B: move end forward	*/
			e2 = BeginMs2 + (EndMs1-BeginMs1); /* Keep = sizes */
			fprintf(stderr, "EndMs2   --> %5d\n", e2);
		}
	} else
		BeginMs2 = b2;

	if (e2 > latest) {
		if (SubtractBackground)		/* -B: cannot shrink	*/
			Exit("End of bg interval too far: reset!", "Check_Interval");
		EndMs2 = latest;
		if (Is_Batch())
			Warning("");
		fprintf(stderr, "EndMs2   --> %5d\n", EndMs2);
	} else
		EndMs2 = e2;
}
/* ********************************************************************	*/

/* FUNCTION Bad_Interval */
	 /* For grab -w -a -i/I/B	*/
	 /* Also -xi500 (or -xi) : skip if bad interval (or if end+500 is bad)*/
	 /*
	  *   align  alignment time subtracted from both tape times
	  *   extra  added to the shifted tape-off time
	  *
	  * Returns 1 only when, for interval 1 and (if -B or -I is in
	  * effect) for interval 2 as well, the begin bound is >= the
	  * shifted tape-on time AND the end bound is >= the shifted
	  * tape-off time plus 'extra'.  Returns 0 otherwise.
	  * NOTE(review): which return value means "bad", and whether the
	  * end-bound comparison points the intended way, cannot be
	  * determined from this file -- confirm against the caller.
	  *
	  * The FAIL sentinel is tested on the raw tape times.  Testing
	  * it after subtracting 'align' (as before) only caught a
	  * failure when align happened to be 0.
	  * Assumes EventExtractTapeOn/Off() report failure by returning
	  * FAIL -- TODO confirm.					*/
int Bad_Interval(int align, int extra) {
	int tape_on  = EventExtractTapeOn();
	int tape_off = EventExtractTapeOff();
	int start, end;

	if (tape_on == FAIL || tape_off == FAIL)
		Exit("Bad timing", "Bad_Interval");

	start = tape_on - align;
	end   = tape_off - align + extra;

	if (BeginMs1 < start)
		return(0);
	if (EndMs1 < end)
		return(0);

	if (!(SubtractBackground || KeepSecondInterval))
		return(1);

	if (BeginMs2 < start)
		return(0);
	if (EndMs2 < end)
		return(0);

	return(1);
}
/* ********************************************************************	*/

/* Draw one vertical stroke at each of the two x positions, spanning
 * y_low..y_high, then call fill(0).  Pen colour and line mode are
 * whatever the caller selected beforehand. */
static void Draw_Interval_Bounds(int x_left, int x_right, int y_low, int y_high) {
	move(x_left, y_low);
	cont(x_left, y_high);
	move(x_right, y_low);
	cont(x_right, y_high);
	fill(0);
}

/* FUNCTION Show_Interval */
	 /* Plot interval boundaries using current coordinates.
	  *
	  * Each interval is drawn twice: a solid pass, then a
	  * long-dashed pass in a second colour on top of it.  A
	  * zero-length interval 1 is skipped when no second interval is
	  * in use.  Interval 2 is drawn only with -B or -I.  Line mode
	  * is left at "solid" afterwards.				*/
void Show_Interval() {
	int left   = BeginMs1 + Get_ZeroTime();
	int right  = EndMs1 + Get_ZeroTime();
	int bottom = HISTO_OFFSET-10;
	int top    = HISTO_OFFSET + (HEIGHT-HISTO_OFFSET)/3;
	int only_one = (SubtractBackground==NO && KeepSecondInterval==NO);

	if (left == right && only_one)		/* Skip 0 ms intervals	*/
		return;

	/* Interval 1: solid white, then dashed black on top */
	Color(8,8,8);
	Draw_Interval_Bounds(left, right, bottom, top);

	Color(0,0,0);
	linemod("longdashed");
	Draw_Interval_Bounds(left, right, bottom, top);
	linemod("solid");			/* Restore solid	*/

	if (only_one)
		return;

	/* Interval 2: solid grey, then dashed darker grey on top */
	left  = BeginMs2 + Get_ZeroTime();
	right = EndMs2 + Get_ZeroTime();

	Color(6,6,6);
	Draw_Interval_Bounds(left, right, bottom, top);

	Color(2,2,2);
	linemod("longdashed");
	Draw_Interval_Bounds(left, right, bottom, top);
	linemod("solid");			/* Restore solid	*/
	Color(0,0,0);
}
/* ********************************************************************	*/

/* FUNCTION Set_Interval_By_Bin */
	 /* Used ONLY when setting interval by 'most variance' method.
	  *
	  *   b, e  begin and end expressed as histogram bin numbers
	  *
	  * Converts each bin number to ms re: alignment (bin number
	  * times Get_Histo_BinSize(), minus Get_ZeroTime()) and lets
	  * Check_Interval() validate and store the result as interval 1.
	  * Interval 2 and the background end are passed through
	  * unchanged.							*/
void Set_Interval_By_Bin(int b, int e) {
	int begin_ms = b * Get_Histo_BinSize() - Get_ZeroTime();
	int end_ms   = e * Get_Histo_BinSize() - Get_ZeroTime();

	/* Using 'Check()' to set! */
	Check_Interval(begin_ms, end_ms, BeginMs2, EndMs2, bg_EndMs);
}
/* ********************************************************************	*/
