/* FILE _macros/reach/roc.c */
     /* Produce an 'ROC' plot for reach data */
     /* Uses -i arguments (or -i default) to set span to calculate over	*/
     /* One form (88) does calc everywhere, gives time plot	*/
     /* Other form (89) does calc at just one location, prints it */
     /* Third form (87) does cumulative ROC across all files given to it*/

#define WIDTH	250		/* Half of base of triangular kernel(8),*/
				/* in ms: each spike adds WIDTH at its	*/
				/* own ms and WIDTH-k at +/-k ms, k<WIDTH*/
				/* If zero, no smoothing (1 per spike)	*/
				/* Have used 25 and 50 and 80		*/



	/* Height of triangle is irrelevant; just a scale factor	*/
	/* (Scale_ROC() divides by the kernel's total area).		*/

#define ADD_TTEST    1		/* Macro 89 only: if nonzero, also write*/
				/* a t-test value (or NA) after the ROC	*/

#include "../Deffs.h"
/* ********************************************************************	*/

static void Init_ROC();
static void Cumulate_Spikes(int type);
static void Scale_ROC();
static void Setup_Plot();
static void Plot_ROC();
static void Cumulate_ROC();
static void Plot_Cumulative_ROC();
static float Return_OneValue_of_ROC(int ms);
/* ********************************************************************	*/
/* FUNCTION ROC87 */
	 /* Cumulative ROC curve (from many cells).
	  * Each call adds the current file's cell to a running sum and
	  * redraws the mean curve over all cells seen so far.
	  * Use -i to specify the span to average over.
	  * Will generate n ps.files, with 1,2,3,...,n cells in each.
	  * Files that do not have exactly two trial types are listed on
	  * stderr and skipped.
	  * NOTE(review): unlike ROC88/ROC89 this path never calls
	  * Scale_ROC(), so the ROC is computed on unscaled kernel sums --
	  * confirm that is intended.					*/
void ROC87() {
	int ntypes;			/* Trial types in this file	*/
	int t;

	Rewind_InputFile();

	ntypes = Count_TrialTypes();
	if (ntypes != 2) {
	   fprintf(stderr, "%d.%d has %d trial types (skipping):  ",
	   	UnitNumber_Header(), RunNumber_Header(), ntypes);
	   for (t = 0; t < ntypes; t++)
	       fprintf(stderr,"  %d: %d.%d",
	       	t, Get_TrialType_Info(t, STACK), Get_TrialType_Info(t, CLASS));
	   fprintf(stderr, "\n");
	   return;
	   }

	Init_ROC();
	while (Read_Next_Trial(WITH_DATA))
	   Cumulate_Spikes(0);		/* Sort trials by type		*/

	Cumulate_ROC();			/* Fold this cell into the sum	*/
	Setup_Plot();
	Plot_Cumulative_ROC();
	}
/* ********************************************************************	*/

/* FUNCTION ROC88 */
	 /* Single-cell ROC curve as a function of time.
	  * Use -i to specify the span to average over.
	  * Requires exactly two trial types; otherwise reports and skips.
	  * The curve is plotted and also written to the "rocs" output.	*/
void ROC88() {
	int ntypes;			/* Trial types in this file	*/

	Rewind_InputFile();

	ntypes = Count_TrialTypes();
	if (ntypes == 2) {
	   Init_ROC();
	   while (Read_Next_Trial(WITH_DATA))
	      Cumulate_Spikes(0);	/* Sort trials by type		*/
	   Scale_ROC();			/* Undo kernel area		*/
	   Setup_Plot();
	   Plot_ROC();
	   return;
	   }

	fprintf(stderr, "%d.%d has %d trial types (skipping)\n",
		UnitNumber_Header(), RunNumber_Header(), ntypes);
	}
/* ********************************************************************	*/

/* FUNCTION Open_ROC89_Output */
	 /* Open the "roc" macro output file; exits on failure.		*/
static FILE *Open_ROC89_Output() {
	FILE *file = Open_Macro_Output("roc", -1);

	if (file == NULL) {
	   fprintf(stderr, "Err opening 'roc'");
	   Exit("", "ROC89()");
	   }
	return(file);
	}

/* FUNCTION Write_ROC89_Placeholder */
	 /* Write a dummy "roc" record (0.0 versus 0.0, ROC of -9999) for a
	  * cell whose best class could not be obtained, so the cell still
	  * produces an output record.					*/
static void Write_ROC89_Placeholder() {
	FILE *file = Open_ROC89_Output();

	fprintf(file, "%d.%d versus %d.%d\n%d:%d ms from %s\n%.3f\n",
	    0, 0, 0, 0,
	    Get_Interval_Begin(1), Get_Interval_End(1),
	    AlignTime_String(), -9999.);
	fclose(file);
	}

/* FUNCTION Open_Best_File_For_Run */
	 /* Open "best_<unit>.<run>" for reading; NULL if it can't be opened*/
static FILE *Open_Best_File_For_Run(int run) {
	char Name[80];

	snprintf(Name, sizeof Name, "best_%d.%d", UnitNumber_Header(), run);
	return(fopen(Name, "r"));
	}

/* FUNCTION Open_Best_File */
	 /* Find this unit's 'best' file: try the current run first, then
	  * earlier runs down to 1, then later runs up to 19.
	  * Returns NULL if none can be opened; caller must fclose().	*/
static FILE *Open_Best_File() {
	FILE *file;
	int run;

	if ((file = Open_Best_File_For_Run(RunNumber_Header())) != NULL)
	   return(file);
	for (run = RunNumber_Header()-1; run > 0; run--)
	   if ((file = Open_Best_File_For_Run(run)) != NULL)
	      return(file);
	for (run = RunNumber_Header()+1; run < 20; run++)
	   if ((file = Open_Best_File_For_Run(run)) != NULL)
	      return(file);
	return(NULL);
	}

/* FUNCTION ROC89 */
	 /* ROC value from single point (actually, a single _span_: -i).
	  * Command-line data values: 'best' and 'null'.
	  *   best == -1 : read the class from a "best_<unit>.<run>" file and
	  *                compare that class across the two listed stacks.
	  *   best ==  0 : needs exactly two trial types; trials are split by
	  *                Cumulate_Spikes(0).
	  *   best  >  0 : compare class 'best' against class 'null' (if
	  *                given), else against a fixed partner class from the
	  *                table below; if the partner has < 3 trials, against
	  *                all classes except 'best' (mod 8).
	  * Writes "s1.c1 versus s2.c2", the interval, the ROC value and
	  * (ADD_TTEST) a t-test value or NA to the "roc" output.	*/
void ROC89() {
	FILE *file;
	/* NOTE(review): the two calls differ in case (dataValue vs
	 * DataValue); presumably they return the 1st and 2nd cmd-line
	 * data values -- confirm against Deffs.h.			*/
	int best = Get_dataValue_From_CmdLine();
	int null = Get_DataValue_From_CmdLine();
	int BestDirectionFromFile = 0;
	int class1, class2;
	int stack1, stack2;
	float ROC;


	Rewind_InputFile();
	if (best==-1) {			/* Special flag (arm-eye stuff)*/
	   FILE *bestfile;
	   int got;

	   if (List_Length(STACK) != 2) {
	     fprintf(stderr, "%d.%d has %d trial types (skipping)\n",
	   	UnitNumber_Header(), RunNumber_Header(), Count_TrialTypes());
	     return;
	     }

	   bestfile = Open_Best_File();
	   if (bestfile == NULL) {
	      fprintf(stderr, "No 'best' file (unit %d)\n",
	      				UnitNumber_Header());
	      Write_ROC89_Placeholder();	/* Was fprintf to NULL	*/
	      return;
	      }
	   got = fscanf(bestfile, "%d", &best);	/* GET BEST CLASS FROM FILE! */
	   fclose(bestfile);
	   if (got != 1) {
	      fprintf(stderr, "Unreadable 'best' file (unit %d)\n",
	      				UnitNumber_Header());
	      Write_ROC89_Placeholder();
	      return;
	      }
	   BestDirectionFromFile = 1;
	 } else 
	if (best==0 && Count_TrialTypes() != 2) {
	   fprintf(stderr, "%d.%d has %d trial types (skipping)\n",
	   	UnitNumber_Header(), RunNumber_Header(), Count_TrialTypes());
	   return;
	 } else
	if (best!=0 && List_Length(STACK) != 1) {
	   fprintf(stderr,
	         "%d.%d specifies the class but has > 1 stack (skipping)\n",
				UnitNumber_Header(), RunNumber_Header());
	   return;
	   }


	if (List_Length(STACK) > 1) {
	   stack1 = List_Element(0,STACK);
	   stack2 = List_Element(1,STACK);
	 } else
	   stack1 = stack2 = List_Element(0,STACK);

	if (BestDirectionFromFile)
	    class1 = class2 = best;	/* Same class, two stacks	*/
	else {
	    if (List_Length(CLASS) > 1) {
	        class1 = List_Element(0,CLASS);
	        class2 = List_Element(1,CLASS);
	      } else
	        class1 = class2 = List_Element(0,CLASS);

	    if (best) {			/* Overwrite previous choice	*/
	       class1 = best;
	       if (null)
	          class2 = null;
	       else {
	          switch (best%8) {		/* Type 1-8?	*/
	             case 1: class2 = 5;   break;
	             case 2: class2 = 4;   break;
	             case 4: class2 = 2;   break;
	             case 5: class2 = 1;   break;
	             case 6: class2 = 8;   break;
	             case 0: class2 = 6;   break;
	             default:		/* No partner: really skip	*/
	                fprintf(stderr,
	                    "%d.%d No entry for best class %d (skipping)\n",
	                    UnitNumber_Header(), RunNumber_Header(), best);
	                return;
	             }
	          class2 += 8*((best-1)/8);	/* Type 10's or 20's row?*/
	          if (Get_TrialType_Info(
	          		Get_StackClass_TrialType(stack1,class2), NUMBER) < 3)
	             class2 = -class1;		/* All but this class	*/
	          }
	       }
	    }

	/* Open only once the comparison is settled, so skipped cells
	 * leave no partial output behind.				*/
	file = Open_ROC89_Output();
	fprintf(file, "%d.%d versus %d.%d\n%d:%d ms from %s\n",
	        stack1, class1,  stack2, class2,
		Get_Interval_Begin(1), Get_Interval_End(1),
		AlignTime_String());

	Init_ROC();
	while (Read_Next_Trial(WITH_DATA)) {
	   if (best==0)
	      /* All rightward RF's, & 1st stack of pair is always to right,
	       * and class 3/7 is always target to right, so if stack is 1st,
	       * use one, else use the other				*/
	      Cumulate_Spikes(0);
	    else if (StackNumber_Header()==stack1 &&
	            TableNumber_Header()==class1)
	      Cumulate_Spikes(1);
	    else if ((StackNumber_Header()==stack2 &&
	             TableNumber_Header()==class2)
	          || (class2 < 0 && (TableNumber_Header()%8) != ((-class2)%8)))
	       Cumulate_Spikes(2);
	   }
	Scale_ROC();

	ROC = Return_OneValue_of_ROC(Get_Interval_Begin(1));

	fprintf(file, "%.3f\n", ROC);
	if (ADD_TTEST) {
	   if (Get_TrialType_Info(
		Get_StackClass_TrialType(stack1,class1), NUMBER) > 2 &&
	       Get_TrialType_Info(
		Get_StackClass_TrialType(stack2,class2), NUMBER) > 2)
	      fprintf(file, "%.3f\n", Ttest_On_Interval(1,
	   		Get_StackClass_TrialType(stack1,class1),
	   		Get_StackClass_TrialType(stack2,class2),
			2));
	    else fprintf(file, "NA\n");
	    }
	fclose(file);
	}
/* ********************************************************************	*/
/* ********************************************************************	*/

/* WAS: 30 600 40 5 */

#define MAX_TRIALS	800		/* Rows per condition; extra trials */
					/* are dropped by Cumulate_Spikes() */
#define MAX_BINS	1000		/* Hold entire ROC histogram	*/
					/* (Init_ROC exits if too short)	*/

static int ROC_SPAN;			/* Calc over this span (ms): set by */
					/* Init_ROC from -i interval 1	*/
#define ROC_BIN		25		/* [25] Calc/plot pts this far apart*/
					/* (ms width of each stored bin)	*/
#define BINS		(ROC_SPAN/ROC_BIN)	/* Bins summed per ROC value */

	/* One row per trial: smoothed spike activity over [begin,end),
	 * summed into ROC_BIN-ms bins.  A = type 1 (or first listed
	 * stack/class), B = type 2 (or the other).  See Cumulate_Spikes. */
static float ConditionA[MAX_TRIALS][MAX_BINS];
static float ConditionB[MAX_TRIALS][MAX_BINS];
static int trialsA, trialsB;		/* Number of trials (pref'd)	*/
					/* = rows filled in A and B	*/
static int begin;			/* Time to start  ROC analysis	*/
static int end;				/* Time to finish ROC analysis	*/
static int KernelArea;			/* Area under the kernel	*/

static float Get_RightA(int criterion, int bin);
static float Get_WrongB(int criterion, int bin);
static float Calc_ROC(int bin);
/* ********************************************************************	*/

/* FUNCTION Init_ROC */
	 /* Initialize ROC variables for the current file:
	  *   ROC_SPAN   : length of -i interval 1 (ms); each ROC value sums
	  *                BINS = ROC_SPAN/ROC_BIN consecutive bins.
	  *   ConditionA/B, trialsA/B : cleared.
	  *   begin/end  : analysed range, 0 .. DurationTime()-1.
	  *   KernelArea : total weight of the triangular kernel, used by
	  *                Scale_ROC() (WIDTH + 2*(1+...+WIDTH-1) = WIDTH^2).
	  * Exits if the trial is too long for MAX_BINS, or the span is not a
	  * positive multiple of ROC_BIN.				*/
static void Init_ROC() {
	int i,j;

	ROC_SPAN = Get_Interval_End(1) - Get_Interval_Begin(1);	

	for (i=0; i<MAX_TRIALS; i++)
	     for (j=0; j<MAX_BINS; j++)
		 ConditionA[i][j] = ConditionB[i][j] = 0.;

	trialsA = trialsB = 0;
	begin = 0;
	end   = DurationTime()-1;
					/* Use Get_Times_From_CmdLine()? */

	if (((end-begin)/ROC_BIN) >= MAX_BINS)
	   Exit("Increase MAX_BINS", "Init_ROC");

	/* A non-positive span makes BINS <= 0 and pushes the plot loops'
	 * last bin ((end-begin-ROC_SPAN)/ROC_BIN) past the data, and
	 * possibly past the end of the accumulators.			*/
	if (ROC_SPAN <= 0) {
	   fprintf(stderr, "ROC_SPAN is %d (interval %d:%d)\n", ROC_SPAN,
		   Get_Interval_Begin(1), Get_Interval_End(1));
	   Exit("Need a positive -i span", "Init_ROC");
	   }

	if (BINS * ROC_BIN != ROC_SPAN) {
	   fprintf(stderr, "Bins is %d, ROC_BIN is %d, ROC_SPAN is %d\n", BINS, ROC_BIN, ROC_SPAN);
	   Exit("Pick divisible ROC_BIN", "Init_ROC");
	   }

	/* Centre weight WIDTH plus 2*(WIDTH-k) for k = 1..WIDTH-1	*/
 	KernelArea = WIDTH * WIDTH;
	}
/* ********************************************************************	*/

/* FUNCTION Cumulate_Spikes */
	 /* Tally spikes from the current trial for ROC, smoothed with a
	  * triangular kernel: each spike adds WIDTH at its own ms and
	  * WIDTH-k at +/-k ms (k < WIDTH); with WIDTH 0 it adds 1.
	  * The smoothed trace over [begin, end) is summed into ROC_BIN-ms
	  * bins in the next free row of:
	  *   type 1   : ConditionA
	  *   type 2   : ConditionB
	  *   otherwise: ConditionA if the trial type matches the first
	  *              listed STACK (or CLASS when only one stack is
	  *              listed), else ConditionB.
	  * If the chosen condition already holds MAX_TRIALS rows the trial
	  * is dropped (warned once).  Exits if a spike time is negative or
	  * >= SMOOTH_SIZE, or if the trial is longer than the buffer.	*/
static void Cumulate_Spikes(int type) {
	int SpikeCount = SpikeCount_Header(-1);	/* Spikes in this trial	*/
	extern short Spikes[];		/* Spike times, ms from TAPE_ON	*/
	int toA;			/* 1: ConditionA, 0: ConditionB	*/
	float *histo;			/* Row of the chosen cumulator	*/
#	define SMOOTH_SIZE	20000
# 	define MARGIN		(WIDTH+2)	/* Prevent over-runs	*/
	int Smoothed[MARGIN+SMOOTH_SIZE+MARGIN];   /* Long buffer	*/
	int i;					/* Index raw spike times*/
	int j;					/* Index smoothed spikes*/

	if (type == 1)
	   toA = 1;
	else if (type == 2)
	   toA = 0;
	else {
	   int SORT_ON = (List_Length(STACK) > 1) ? STACK: CLASS;
	   int fig = Get_TrialType();	/* Fig type: indexes info	*/
	   toA = (Get_TrialType_Info(fig, SORT_ON) == List_Element(0, SORT_ON));
	   }

	/* Check only the destination: a full A must not block B trials	*/
	if ((toA ? trialsA : trialsB) >= MAX_TRIALS) {
	   static int informed = 0;
	   if (informed == 0)
	      fprintf(stderr, "Increase 'MAX_TRIALS' in roc.c\n");
	   informed = 1;
	   return;
	   }
	histo = toA ? ConditionA[trialsA++] : ConditionB[trialsB++];

	/* The read-out loop below indexes Smoothed[MARGIN+j] for j < end */
	if (end > SMOOTH_SIZE + MARGIN)
	   Exit("Make SMOOTH_SIZE larger", "Cumulate_Spikes()");

	for (j=0; j<SMOOTH_SIZE+2*MARGIN; j++)			/* Init	*/
	   Smoothed[j] = 0;

			/* Transfer triangle into buffer for each spike	*/
	for (i=0; i<SpikeCount; i++) {
	   int t = Spikes[i];
	   int *center;

	   /* Validate BEFORE writing, so the kernel stays inside Smoothed */
	   if (t < 0)
	      Exit("Negative spike time", "Cumulate_Spikes()");
	   if (t >= SMOOTH_SIZE)
	      Exit("Make SMOOTH_SIZE larger", "Cumulate_Spikes()");

	   center = Smoothed + MARGIN + t;
	   if (WIDTH) {
	      center[0] += WIDTH;			/* Center bin	*/
	      for (j=1; j<WIDTH; j++) {			/* Side bins	*/
	          center[j]  += WIDTH-j;
	          center[-j] += WIDTH-j;
	          }
	    } else
	      center[0] += 1;
	   }

	for (j=begin; j<end; j++)		/* For ROC interval	*/
	    histo[(j-begin)/ROC_BIN] += Smoothed[MARGIN+j];
	}
/* ********************************************************************	*/

/* FUNCTION Scale_ROC */
	 /* Correct for non-unity kernel area */
	 /* Was non-functional until 1-7-2010! (bug!) */
static void Scale_ROC() {
	int i,j;

	if (WIDTH)
	 for (i=0; i<MAX_TRIALS; i++)
	   for (j=0; j<MAX_BINS; j++) {
	       ConditionA[i][j] /= KernelArea;
	       ConditionB[i][j] /= KernelArea;
	   }
	}
/* ********************************************************************	*/
static float Cumulative_ROC[MAX_BINS];		/* Values as f() of time*/
static int n = 0;				/* Number of cells	*/

/* Cumulate_ROC */
	 /* Add the current cell's ROC curve into Cumulative_ROC[] (one
	  * value per window later drawn by Plot_Cumulative_ROC()) and count
	  * the cell in 'n'.  Exits if the -i span differs from the span used
	  * for the first cell.						*/
static void Cumulate_ROC() {
	static int StoreSpan = 0;		/* Remember (so can ck)	*/
	int end_bin = (end-begin-ROC_SPAN)/ROC_BIN;	/* Last full window */
	int i;

	if (StoreSpan==0)			/* First time through:	*/
	    StoreSpan = ROC_SPAN;		/* Store this value	*/
	else if (StoreSpan != ROC_SPAN)
	    Exit("Can't change span in mid-cumulation", "ROC87");

	if (++n == 1)					/* First cell?	*/
	   for (i=0; i<MAX_BINS; i++)
	     Cumulative_ROC[i] = 0.;

	/* Only windows that fit inside [begin, end): later windows would
	 * sum empty bins and, for long trials, read past the end of each
	 * accumulator row.  These are exactly the bins that get plotted. */
	for (i=0; i<=end_bin && i<MAX_BINS; i++)
	     Cumulative_ROC[i] += Calc_ROC(i);
	}
/* ********************************************************************	*/
/* ********************************************************************	*/
#define SCALE	(.6 * HEIGHT)

/* FUNCTION Setup_Plot */
	 /* Draw the plot frame: coordinates, time line, vertical markers,
	  * analog traces of up to 19 trials (pictures for the first 2),
	  * a 0..1 ROC scale to the right of the trial (ROC of 1.0 = SCALE
	  * plot units above HISTO_OFFSET), reference lines at 0, .25, .5,
	  * .75 and 1, and a caption.  Rewinds the input file.	*/
static void Setup_Plot() {
	char tag[80];
	int len;				/* Caption length so far*/
	int i = 0;
	extern void Graph_TimeLine();
	extern void Do_Vertical_Lines_On_One_Graph();
	extern void Graph_Analog_Data();
	extern void Picture();
	extern int  function_HISTO_OFFSET();
	extern int  function_HEIGHT();

	Setup_Coords_For_TrialType(Count_TrialTypes());
	Graph_TimeLine();
	Do_Vertical_Lines_On_One_Graph();	/* Don't change coords	*/

	Rewind_InputFile();
	while (Read_Next_Trial(WITH_DATA) && (++i<20)) {
	   SetColorByTrialType(Get_TrialType());
	   Graph_Analog_Data();
	   SetColor(0,0,0);				/* Black 	*/
	   if (i<3)					/* Just a few	*/
	      Picture();
	   }

	SetLineWidth(5);
	move(DurationTime()+100, (int)(HISTO_OFFSET));
	cont(DurationTime(),     (int)(HISTO_OFFSET));
	cont(DurationTime(),     (int)(HISTO_OFFSET + SCALE));
	cont(DurationTime()+50,  (int)(HISTO_OFFSET + SCALE));
	move(DurationTime(),     (int)(HISTO_OFFSET + .5 * SCALE));
	cont(DurationTime()+100, (int)(HISTO_OFFSET + .5 * SCALE));

	fontsize(16);
	move(DurationTime()+150, (int)(HISTO_OFFSET));
	alabel('l', 'c', "0");
	move(DurationTime()+150, (int)(HISTO_OFFSET + SCALE));
	alabel('l', 'c', "1.0");
	move(DurationTime()+150, (int)(HISTO_OFFSET + SCALE/2));
	alabel('l', 'c', "0.5");
	move(0, -HISTO_OFFSET/2);
	/* Append the cell count at the end of the caption.  The old
	 * sprintf(tag, "%s...", tag, ...) read and wrote 'tag' at once,
	 * which is undefined behaviour.				*/
	len = snprintf(tag, sizeof tag, "ROC: %d ms bins (%d), kernel base is %d",
		ROC_SPAN, ROC_BIN, WIDTH * 2);
	if (n > 1 && len >= 0 && (size_t)len < sizeof tag)
	   snprintf(tag + len, sizeof tag - len, ", %d cells", n);
	alabel('l', 'c', tag);

	fontsize(12);
	move(DurationTime()+150, (int)(HISTO_OFFSET + .25 * SCALE));
	alabel('l', 'c', "0.25");
	move(DurationTime()+150, (int)(HISTO_OFFSET + .75 * SCALE));
	alabel('l', 'c', "0.75");		/* CRITERION VALUE	*/

	SetLineWidth(1);
	move(DurationTime(),	 (int)(HISTO_OFFSET + SCALE));
	cont(0,	 		 (int)(HISTO_OFFSET + SCALE));
	move(DurationTime(),	 (int)(HISTO_OFFSET));
	cont(0,	 		 (int)(HISTO_OFFSET));
	linemod("longdashed");
	move(DurationTime(),	 (int)(HISTO_OFFSET + .75 * SCALE));
	cont(0,	 		 (int)(HISTO_OFFSET + .75 * SCALE));
	move(DurationTime(),	 (int)(HISTO_OFFSET + .25 * SCALE));
	cont(0,	 		 (int)(HISTO_OFFSET + .25 * SCALE));
	move(DurationTime(),	 (int)(HISTO_OFFSET + .50 * SCALE));
	cont(0,	 		 (int)(HISTO_OFFSET + .50 * SCALE));
	linemod("solid");
	}
/* ********************************************************************	*/

/* FUNCTION Return_OneValue_of_ROC */
	 /* Return one value */
static float Return_OneValue_of_ROC(int ms) {
	if (trialsA < 3 || trialsB < 3) {
	   fprintf(stderr,
	     "Warning: %s (%s) %d.%d compares %d vs %d trials (best=%d)\n",
	   	MonkName_Header(), Date_From_Header(),
		UnitNumber_Header(), RunNumber_Header(),
		trialsA, trialsB,
	        Get_dataValue_From_CmdLine());
	   if (trialsA < 2 || trialsB < 2)	/* 1 or 0	*/
	       return(FAIL);		/* Illegal value	*/
	   }

	return(Calc_ROC((ms+Get_ZeroTime())/ROC_BIN));
	}
/* ********************************************************************	*/
/* FUNCTION Plot_ROC */
	 /* Plot ROC values as function of bin (or time); scale to fit.
	  * One point per ROC_BIN, at the centre of each ROC_SPAN window that
	  * fits in [begin, end).  Each point is also written to the "rocs"
	  * output as "<ms from zero time> <ROC>".  Exits if "rocs" cannot
	  * be opened.							*/
static void Plot_ROC() {
	int end_bin = (end-begin-ROC_SPAN)/ROC_BIN;	/* Last full window */
	int bin;
	FILE *file;
	extern int  function_HISTO_OFFSET();
	extern int  function_HEIGHT();

	file = Open_Macro_Output("rocs", -1);
	if (file == NULL) {			/* Was fprintf to NULL	*/
	   fprintf(stderr, "Err opening 'rocs'");
	   Exit("", "Plot_ROC()");
	   }

	SetLineWidth(3);
	for (bin = 0; ; bin++) {
	   int   x   = begin + ROC_SPAN/2 + bin*ROC_BIN;	/* Window centre */
	   float roc = Calc_ROC(bin);			/* Computed once */
	   int   y   = (int)(.5 + HISTO_OFFSET + SCALE*roc);
						/* y value: Round,offset,scale	*/
	   if (bin == 0)
	      move(x, y);
	   else
	      cont(x, y);
	   fprintf(file, "%d %.3f\n", x - Get_ZeroTime(), roc);
	   if (bin >= end_bin)
	      break;
	   }

	fclose(file);
	SetLineWidth(1);
	}	
/* ********************************************************************	*/

/* FUNCTION Plot_Cumulative_ROC */
	 /* Draw the mean ROC over the n cells cumulated so far
	  * (Cumulative_ROC[k] / n): one point per ROC_BIN, at the centre of
	  * each ROC_SPAN window, scaled like the frame (SCALE above
	  * HISTO_OFFSET).  The first point is always drawn.		*/
static void Plot_Cumulative_ROC() {
	int last = (end-begin-ROC_SPAN)/ROC_BIN;	/* Final window	*/
	int k;
	extern int  function_HISTO_OFFSET();
	extern int  function_HEIGHT();

	SetLineWidth(3);
	for (k = 0; ; k++) {
	   int x = begin + ROC_SPAN/2 + k*ROC_BIN;
	   int y = (int)(.5 + HISTO_OFFSET + (SCALE * Cumulative_ROC[k])/n);
						/* Round,offset,scale	*/
	   if (k == 0)
	      move(x, y);			/* Start the line	*/
	   else
	      cont(x, y);
	   if (k >= last)
	      break;
	   }
	SetLineWidth(1);
	}
/* ********************************************************************	*/
/* ********************************************************************	*/

/* FUNCTION Get_WrongB */
	 /* Return # of condition B incorrectly id'd using 'criterion'	*/
static float Get_WrongB(int criterion, int bin) {
	float count = 0;
	int i;

	for (i=0;i<trialsB;i++) {
	   float sum = 0.;
	   int b;
	   for (b=0; b<BINS; b++)		/* Avg across bins	*/
	      sum += ConditionB[i][bin+b];
	   if (sum > (float)criterion)
	      count += 1;
	   else if (sum == (float)criterion)
	      count += .5;
	   }
	return(count);
	}
/* ********************************************************************	*/

/* FUNCTION Get_RightA */
	 /* Return # of condition A correctly id'd using 'criterion'	*/
static float Get_RightA(int criterion, int bin) {
	float count = 0;
	int i;

	for (i=0;i<trialsA;i++) {
	   float sum = 0;
	   int b;
	   for (b=0; b<BINS; b++)		/* Avg across bins	*/
	      sum += ConditionA[i][bin+b];
	   if (sum > (float)criterion)
	      count += 1;
	   else if (sum == (float)criterion)
	      count += .5;
	   }
	return(count);
	}
/* ********************************************************************	*/

/* FUNCTION Calc_ROC */
	 /* Return 1 ROC value (for a given time slice): the area under the
	  * curve of P(A > criterion) against P(B > criterion) for the window
	  * of bins bin .. bin+BINS-1, with ties counted half.  0.5 means no
	  * separation; above 0.5 means condition A tends to be higher.
	  * Criteria are integers stepped up from 0, so the resolution is one
	  * unit of summed activity.
	  * If either condition has no trials no comparison is possible and
	  * 0.5 is returned.  Before this guard, trialsB == 0 looped forever
	  * and trialsA == 0 divided 0 by 0.				*/
static float Calc_ROC(int bin) {
	int criterion = 0;
	float RightA = trialsA;		/* Get_RightA(0)		*/
	float WrongB = trialsB;		/* Get_WrongB(0)		*/
	float Area = 0;


#ifdef  WANT_EQUAL_NUMBERS	/* Not really an issue: works regardless*/
	if (trialsA != trialsB) {
	   fprintf(stderr, "No ROC plot: unequal numbers of trials\n");
	   fprintf(stderr, "   Using %d %d\n", trialsA,trialsB);
	   Exit("Unequal trials", "roc.c:Calc_ROC()");
	   }
#endif

	if (trialsA == 0 || trialsB == 0)
	   return(.5);

/* In the plot of 'y axis = true hits for condition A' vs 'x axis = false
 * alarms for condition B' (where a value greater than criterion is a hit
 * [either true or false]), we start at the upper right corner and march
 * to the left and down.  Code marked "/ *** /" handles the case where the
 * line being traced drops straight down, that is, the x axis value does
 * not change - must recalculate the corresponding 'y' value in order to
 * get the area.
 * */

	do {
	   float nextB, nextA;		/* Next point on the curve	*/

	   while ((nextB = Get_WrongB(criterion, bin)) == WrongB) {
	      criterion++;
	      RightA = Get_RightA(criterion-1, bin);		/***/
	      }
	   nextA = Get_RightA(criterion, bin);
	   Area +=  (WrongB - nextB) *			/* Width	*/
		   ((RightA + nextA) / 2);		/* Avg height	*/
	   RightA = nextA;
	   WrongB = nextB;
	 } while (WrongB > 0.);
	return(Area/(trialsA*trialsB));
	}
/* ********************************************************************	*/
