/*
**	This file is part of XDowl
**	Copyright (c) 1994 Jamie Mazer
**	California Institute of Technology
**	<mazer@asterix.cns.caltech.edu>
*/

/******************************************************************
**  RCSID: $Id: pm_bam.c,v 1.17 2002/07/15 04:30:15 cmalek Exp $
** Program: xdphys
**  Module: pm_bam.c
**  Author: bjarthur
** Descrip: binaurally out-of-phase amplitude modulation resulting
**  in continuously varying IID
**
** Revision History (most recent last)
**
** 97.4 bjarthur  CREATION DATE
**  adapted from pm_beats.c
**
*******************************************************************/

#include "xdphyslib.h"
#include "xdphys.h"
#include "plotter.h"

/* File-local helpers; all are defined further down in this file. */
static void bam_perhistFN(FDO * fdo, FILEDATA * fd, int l, FILE * fptr,
			  int opt);
static void bam_rasterFN(FDO * fdo, FILEDATA * fd, int l, FILE * fptr);
static int bam_plotter(FDO * fdo, FILEDATA * fd, FDObj_ViewType * view,
		       int l, FILE * fptr);
static void bam_summarize(FILEDATA * fd, int unitid);
static void bam_freeSumData(FILEDATA * data);
static int bam_free(FILEDATA * data);
static int bam_valid_view(FDO * fdo, int view);

/*
 * Cached per-file summary, rebuilt by bam_summarize() and released by
 * bam_freeSumData().  Three period histograms are kept; each one spans a
 * single modulation period (nbins bins) whose bin positions are finally
 * relabelled as IID in dB:
 *   l_*  the "decreasing IID" alignment (bin order reversed after pooling)
 *   r_*  the "increasing IID" alignment
 *   b_*  bin-by-bin pool of l_* and r_*
 *   s_*  spontaneous-activity statistics
 * Within each group, *_ndata is the length of every array in that group.
 */
typedef struct {
	int unit;		/* unit id or "mask"; not referenced in this file */

	int l_ndata;		/* number of bins in the l_* arrays */
	float *l_depvals;	/* bin position, converted to IID (dB) */
	float *l_means;		/* per-bin mean response */
	float *l_stderrs;	/* per-bin std error of the response */
	float *l_stddevs;	/* per-bin std dev of the response */
	int *l_n;		/* per-bin count the statistics are based on */

	int r_ndata;		/* number of bins in the r_* arrays */
	float *r_depvals;	/* bin position, converted to IID (dB) */
	float *r_means;		/* per-bin mean response */
	float *r_stderrs;	/* per-bin std error of the response */
	float *r_stddevs;	/* per-bin std dev of the response */
	int *r_n;		/* per-bin count the statistics are based on */

	int b_ndata;		/* number of bins in the b_* arrays */
	float *b_depvals;	/* bin position, converted to IID (dB) */
	float *b_means;		/* per-bin pooled mean response */
	float *b_stderrs;	/* per-bin pooled std error */
	float *b_stddevs;	/* per-bin pooled std dev */
	int *b_n;		/* per-bin pooled count (l_n + r_n) */

	float s_mean;		/* mean spontaneous response */
	float s_stderr;		/* std error of the spontaneous response */
	float s_stddev;		/* std dev of the spontaneous response */
	int s_n;		/* # data the spontaneous statistics are based on */
} SumData;

/* Typed view of a FILEDATA's summary slot; the argument is parenthesized
   so arbitrary pointer expressions expand correctly. */
#define SUMDATA(data) ((SumData *) ((data)->sumdata))

/*
 * Views offered by this module.  The row index is stored in fdo->view.lr
 * and the column index in fdo->view.ud (see bam_valid_view); nsubviews[i]
 * is the number of used columns in row i.
 */
static int view_order[][2] = {
	{PM_PERHIST},
	{PM_PERHIST_LEFT},
	{PM_PERHIST_RIGHT},
	{PM_RASTER},
	{PM_PSTH},
	{PM_ISIH}
};
static int nsubviews[] = { 1, 1, 1, 1, 1, 1 };
/* Derived from the table so adding or removing a row cannot leave it stale. */
static int nviews = (int) (sizeof(view_order) / sizeof(view_order[0]));

/*
 * Draw view `view` of file fd into plot slot l.  An unsupported view is
 * reported through pm_type_error() and replaced by PM_DEFAULT.  The three
 * period-histogram views share bam_perhistFN and differ only in the option
 * passed to it.  Always returns 1.
 */
int bam_do_plot(FDO * fdo, FILEDATA * fd, int view, int l, FILE * fptr)
{
	/* 0 = combined, 1 = decreasing, 2 = increasing, -1 = not a perhist view */
	int perhist_opt = -1;

	if (bam_valid_view(fdo, view) == 0) {
		pm_type_error("bam", view);
		view = PM_DEFAULT;
		(void) bam_valid_view(fdo, view);
	}

	switch (view) {
	case PM_DEFAULT:
	case PM_PERHIST:
		perhist_opt = 0;
		break;
	case PM_PERHIST_LEFT:
		perhist_opt = 1;
		break;
	case PM_PERHIST_RIGHT:
		perhist_opt = 2;
		break;
	case PM_RASTER:
		bam_rasterFN(fdo, fd, l, fptr);
		break;
	case PM_PSTH:
		FD1_psthFN(fdo, fd, l, fptr);
		break;
	case PM_ISIH:
		FD1_isihFN(fdo, fd, l, fptr);
		break;
	default:
		pm_type_error("bam", view);
		break;
	}

	if (perhist_opt >= 0)
		bam_perhistFN(fdo, fd, l, fptr, perhist_opt);

	return (1);
}

/*
 * Return 1 if `view` is PM_DEFAULT or appears in view_order, else 0.
 * For a table match, the matching row/column is stored in
 * fdo->view.lr / fdo->view.ud.  PM_DEFAULT leaves fdo untouched.
 */
static int bam_valid_view(FDO * fdo, int view)
{
	int found = 0;
	int row, col;

	if (view == PM_DEFAULT)
		return (1);

	for (row = 0; row < nviews; row++) {
		for (col = 0; col < nsubviews[row]; col++) {
			if (view_order[row][col] != view)
				continue;
			fdo->view.lr = row;
			fdo->view.ud = col;
			found = 1;
			break;	/* stop scanning this row; later rows still checked */
		}
	}

	return (found);
}

/*
 * Plot-by-position entry: normalize (view->lr, view->ud) with the
 * adjust_index() macro (defined elsewhere; presumably keeps the indices
 * inside view_order), look up the view id there and draw it.
 * Always returns 1.
 */
static int bam_plotter(FDO * fdo, FILEDATA * fd, FDObj_ViewType * view,
		       int l, FILE * fptr)
{
	int which;

	adjust_index(view->lr, view->ud);
	which = view_order[view->lr][view->ud];
	(void) bam_do_plot(fdo, fd, which, l, fptr);

	return (1);
}

/*
 * Plot one of the period histograms of fd into slot l.
 *   opt 1 -> "decreasing" histogram (l_*), opt 2 -> "increasing" (r_*),
 *   opt 0 (or any unexpected value) -> combined histogram (b_*).
 * The summary is recomputed from the raw rasters on every call.
 */
static void bam_perhistFN(FDO * fdo, FILEDATA * fd, int l, FILE * fptr,
			  int opt)
{
/* fdo->no_errs=1 here means it's a true histogram, ie spikes are simply summed
   and no division by # trials or # cycles occurs.  fdo->no_errs=0 here means
   that a division by # cycles occurs.  note that this is different from
   the old perhist, which divides by # trials */

	int i;
	char *ti;
	char *xlabel, *ylabel;
	float *re_depvals, *re_means, *re_stderrs;
	int *re_n, re_ndata;
	float *px, *py, *pz, *sx, *sy, *sz;
	int pn, sn;
	SumData *sd;

	bam_summarize(fd, fdo->sum_code[l]);
	sd = SUMDATA(fd);

	switch (opt) {
	case (1):		/* decreasing IIDs */
		re_ndata = sd->l_ndata;
		re_depvals = sd->l_depvals;
		re_means = sd->l_means;
		re_stderrs = sd->l_stderrs;
		re_n = sd->l_n;
		ti = "decreasing per_hist";
		break;
	case (2):		/* increasing IIDs */
		re_ndata = sd->r_ndata;
		re_depvals = sd->r_depvals;
		re_means = sd->r_means;
		re_stderrs = sd->r_stderrs;
		re_n = sd->r_n;
		ti = "increasing per_hist";
		break;
	case (0):		/* both */
	default:		/* never leave the selection uninitialized */
		re_ndata = sd->b_ndata;
		re_depvals = sd->b_depvals;
		re_means = sd->b_means;
		re_stderrs = sd->b_stderrs;
		re_n = sd->b_n;
		ti = "combined per_hist";
		break;
	}

	if (fdo->no_errs) {
		/* mean * n turns the per-bin means back into summed counts */
		if (re_ndata > 0)
			sd->s_mean *= re_n[0];
		for (i = 0; i < re_ndata; i++)
			re_means[i] *= re_n[i];
		/* presumably flags that the cached summary was modified in place */
		fd->trashed = 1;
	}

	FD_plotter_copy(re_depvals, re_means,
			!fdo->no_errs ? re_stderrs : NULL, re_ndata, 1,
			&px, &py, &pz, &pn, &sx, &sy, &sz, &sn);

	if (fdo->no_X)
		return;

	FDObj_Add_Data_All(fdo, px, py, pn);
	FDObj_AddLine(fdo, l, px, py, pz, pn, AtFloat,
		      atQuadLinePlotWidgetClass, AtTypeLINEPOINTS,
		      AtMarkCIRCLE, ConvertColor(fdo->graph,
						 FDObj_Colors[l %
							      FDObj_numColors]),
		      FDObj_Legend(fdo, l));
	FDObj_Add_Data_All(fdo, sx, sy, sn);
	FDObj_AddLine(fdo, l, sx, sy, sz, sn, AtFloat,
		      atQuadLinePlotWidgetClass, AtTypeLINEPOINTS,
		      AtMarkCIRCLE, ConvertColor(fdo->graph,
						 FDObj_Colors[l %
							      FDObj_numColors]),
		      FDObj_Legend(fdo, l));

	XtVaSetValues(fdo->graph, XtNtitle, FDObj_Title(fdo, ti, l),
		      XtNshowTitle, True, NULL);
	ylabel = "nspikes";
	xlabel = "iid (dB)";
	XtVaSetValues(fdo->xaxis, XtNlabel, xlabel, NULL);
	XtVaSetValues(fdo->yaxis, XtNlabel, ylabel, NULL);
}

/*
 * Raw raster plot for slot l: one point per spike.  x is the spike time
 * and y is the raster's presentation order (pres_order).  The x axis is
 * fixed to [0, epoch] and the y axis to [0, number of rasters].
 */
static void bam_rasterFN(FDO * fdo, FILEDATA * fd, int l, FILE * fptr)
{
	int i, j, k;
	int nrasters;
	char *ti;
	SA_SpikeRasterList *srl = NULL;
	char *xlabel, *ylabel;
	float *xdata, *ydata;
	int ndata;
	double xmin, xmax;

	srl = FD_to_SpikeRasterList(FD1_RAWDATA(fdo->fds[l])->nrasters,
				    FD1_RAWDATA(fdo->fds[l])->rastlist,
				    FD_GI(fdo->fds[l], "epoch"),
				    fdo->sum_code[l]);
	ndata = SA_TotalSpikesInSRL(srl);

	/* allocate outside assert() so NDEBUG builds still allocate;
	   malloc(0) may legally return NULL */
	xdata = (float *) malloc(sizeof(float) * ndata);
	ydata = (float *) malloc(sizeof(float) * ndata);
	assert(ndata == 0 || (xdata != NULL && ydata != NULL));

	for (k = i = 0; i < srl->nrasters; i++) {
		for (j = 0; j < srl->sr[i]->nspikes; j++) {
			xdata[k] = srl->sr[i]->tm[j];
			ydata[k++] = FD1_RAWDATA(fd)->pres_order[i];
		}
	}

	/* keep the raster count: the y-axis range below needs it after srl is gone */
	nrasters = srl->nrasters;
	SA_FreeSpikeRasterList(srl);

	if (fdo->no_X) {
		/* nothing was handed to the plot widget, so release the buffers */
		free(xdata);
		free(ydata);
		return;
	}

	FDObj_Add_Data_All(fdo, xdata, ydata, ndata);
	FDObj_AddLine(fdo, l, xdata, ydata, NULL, ndata, AtFloat,
		      atQuadLinePlotWidgetClass, AtTypePOINTS, AtMarkVBAR,
		      ConvertColor(fdo->graph,
				   FDObj_Colors[l % FDObj_numColors]),
		      FDObj_Legend(fdo, l));

	xmin = 0.0;
	xmax = (double) FD_GI(fd, "epoch");
	XtVaSetValues(fdo->xaxis, XtNmin, &xmin, XtNmax, &xmax,
		      XtNautoScale, False, NULL);
	xmin = 0.0;
	xmax = (double) nrasters;
	XtVaSetValues(fdo->yaxis, XtNmin, &xmin, XtNmax, &xmax,
		      XtNautoScale, False, NULL);

	ylabel = "raster # (pres order)";
	ti = "rawraster";
	XtVaSetValues(fdo->graph, XtNtitle, FDObj_Title(fdo, ti, l),
		      XtNshowTitle, True, NULL);
	xlabel = "time (ms)";
	XtVaSetValues(fdo->xaxis, XtNlabel, xlabel, NULL);
	XtVaSetValues(fdo->yaxis, XtNlabel, ylabel, NULL);
}

/*
 * Find the first raster with depints == 0 (stored in *from).  From there to
 * the last raster, reset phase to 0 and set delay_vec: rasters before the
 * first depints == 1 get first_delay, the rest get rest_delay.  *to receives
 * the raster count.
 */
static void bam_init_delays(FILEDATA * fd, float *delay_vec, float *phase,
			    float first_delay, float rest_delay,
			    int *from, int *to)
{
	int nrasters = FD1_RAWDATA(fd)->nrasters;
	int k = 0;

	while (k < nrasters && FD1_RAWDATA(fd)->depints[k] != 0)
		k++;
	*from = k;
	while (k < nrasters && FD1_RAWDATA(fd)->depints[k] != 1) {
		delay_vec[k] = first_delay;
		phase[k] = 0.0;
		k++;
	}
	while (k < nrasters) {
		delay_vec[k] = rest_delay;
		phase[k] = 0.0;
		k++;
	}
	*to = k;
}

/*
 * For each of num_per modulation periods, advance delay_vec[from..to) by one
 * period ((float) active / num_per) and run FD_perhist with a half-period
 * window over those rasters.  Pool the per-period histograms bin by bin into
 * *depvals .. *n.  Takes ownership of the first period's FD_perhist arrays;
 * later periods' arrays are freed once folded in.  Outputs are left NULL if
 * num_per < 1.
 */
static void bam_pool_periods(FILEDATA * fd, int from, int to,
			     float *delay_vec, float *phase, int active,
			     float num_per, int nbins, float **depvals,
			     float **means, float **stderrs, float **stddevs,
			     int **n)
{
	float *t_depvals, *t_means, *t_stderrs, *t_stddevs;
	int *t_n;
	float old_mean, new_mean, new_stddev, new_stderr;
	int i, j;

	*depvals = *means = *stderrs = *stddevs = NULL;
	*n = NULL;

	for (i = 0; i < num_per; i++) {
		for (j = from; j < to; j++)
			delay_vec[j] += (float) active / (float) num_per;
		FD_perhist((to - from), FD1_RAWDATA(fd)->rastlist + from,
			   delay_vec + from,
			   (float) active / (2 * num_per),
			   (float) active / (2 * num_per),
			   phase + from, nbins, &t_depvals, &t_means,
			   &t_stderrs, &t_stddevs, &t_n, NULL, NULL, NULL);

		if (i == 0) {
			/* first period seeds the accumulator (ownership moves) */
			*depvals = t_depvals;
			*means = t_means;
			*stderrs = t_stderrs;
			*stddevs = t_stddevs;
			*n = t_n;
			continue;
		}

		for (j = 0; j < nbins; j++) {
			/* combine_stddev/stderr need the accumulator's own
			   mean, not the freshly pooled one, so compute all
			   three into temporaries before storing */
			old_mean = (*means)[j];
			combine_mean(t_means[j], t_n[j], old_mean, (*n)[j],
				     &new_mean);
			combine_stddev(t_means[j], t_stddevs[j], t_n[j],
				       old_mean, (*stddevs)[j], (*n)[j],
				       &new_stddev);
			combine_stderr(t_means[j], t_stderrs[j], t_n[j],
				       old_mean, (*stderrs)[j], (*n)[j],
				       &new_stderr);
			(*means)[j] = new_mean;
			(*stddevs)[j] = new_stddev;
			(*stderrs)[j] = new_stderr;
			(*n)[j] += t_n[j];
		}

		/* this period has been folded in; its arrays are not kept */
		free(t_depvals);
		free(t_means);
		free(t_stderrs);
		free(t_stddevs);
		free(t_n);
	}
}

/*
 * Rebuild SUMDATA(fd) from fd's raw rasters for unit `unitid`.  This fills
 * the spontaneous statistics (s_*), the increasing (r_*) and decreasing
 * (l_*, bin order reversed) period histograms, and their bin-wise pool
 * (b_*).  All depvals are finally mapped from [0, 2*M_PI] to
 * [-depth, +depth] dB of IID.
 */
static void bam_summarize(FILEDATA * fd, int unitid)
{
	int i, j;
	int nrasters;
	int spont_stims;
	float *spont_resp;
	float *phase;
	float *delay_vec;
	float num_per, depth;
	int dur, delay, rise, fall, active;
	float delay_a, delay_b;
	int from, to;
	int nbins = 100;
	int use_delay;
	SumData *sd;

	bam_freeSumData(fd);
	/* calloc: every pointer starts NULL; assign through the real field
	   rather than through a cast (cast-as-lvalue is not standard C) */
	sd = (SumData *) calloc(1, sizeof(SumData));
	assert(sd != NULL);
	fd->sumdata = (void *) sd;

	/* determine params */

	spont_stims = FD_GI(fd, "Spont_Stims");
	dur = FD_GI(fd, "Dur");
	delay = FD_GI(fd, "Delay");
	rise = FD_GI(fd, "Rise");
	fall = FD_GI(fd, "Fall");
	num_per = (float) FD_GI(fd, "bam.Num_Periods");
	depth = (float) FD_GI(fd, "bam.Depth");
	use_delay = lookupParms_int("parms.use_delay");
	nrasters = FD1_RAWDATA(fd)->nrasters;
	active = dur - rise - fall;	/* modulated part of the stimulus */

	/* the per-period pooling below needs at least one period */
	assert(num_per >= 1);

	/* summarize the spont */

	if ((!spont_stims) || use_delay) {
		spont_resp = (float *) malloc(nrasters * sizeof(float));
		assert(nrasters == 0 || spont_resp != NULL);
		for (i = 0; i < nrasters; i++) {
			spont_resp[i] =
			    countRaster(0.0, (float) delay,
					FD1_RAWDATA(fd)->rastlist[i],
					unitid);
			spont_resp[i] /= ((float) delay * nbins * 2 *
					  num_per / (float) active);
		}
		sd->s_n = nrasters;
		sd->s_mean = mean(spont_resp, nrasters);
		sd->s_stderr = stderror(spont_resp, nrasters);
		sd->s_stddev = stddev(spont_resp, nrasters);
		free(spont_resp);
	} else {
		assert(FD1_RAWDATA(fd)->depints[0] == SPONT);
		sd->s_n = 1;
		sd->s_mean =
		    countRaster((float) delay, (float) (delay + dur),
				FD1_RAWDATA(fd)->rastlist[0], unitid);
		sd->s_mean /= ((float) dur * nbins * 2 * num_per /
			       (float) active);
		sd->s_stderr = 0;
		sd->s_stddev = 0;
	}

	phase = (float *) malloc(nrasters * sizeof(float));
	delay_vec = (float *) malloc(nrasters * sizeof(float));
	assert(nrasters == 0 || (phase != NULL && delay_vec != NULL));

	/* Initial delays sit one period early, because bam_pool_periods adds a
	   period before each histogram: delay_a + period == delay + rise, and
	   delay_b is half a period later than delay_a. */
	delay_a = delay + rise - (float) active / (float) num_per;
	delay_b = delay + rise + (float) active / (float) (2 * num_per) -
	    (float) active / (float) num_per;

	/* summarize the increasing IID portions */
	bam_init_delays(fd, delay_vec, phase, delay_a, delay_b, &from, &to);
	bam_pool_periods(fd, from, to, delay_vec, phase, active, num_per,
			 nbins, &sd->r_depvals, &sd->r_means, &sd->r_stderrs,
			 &sd->r_stddevs, &sd->r_n);

	/* summarize the decreasing IID portions (alignments swapped) */
	bam_init_delays(fd, delay_vec, phase, delay_b, delay_a, &from, &to);
	bam_pool_periods(fd, from, to, delay_vec, phase, active, num_per,
			 nbins, &sd->l_depvals, &sd->l_means, &sd->l_stderrs,
			 &sd->l_stddevs, &sd->l_n);

	free(phase);
	free(delay_vec);

	/* swap the decreasing IIDs around so they increase from 0 to 2*M_PI */

	assert((nbins % 2) == 0);

	for (i = 0; i < nbins / 2; i++) {
		SWAP(sd->l_means[i], sd->l_means[nbins - i - 1], float);
		SWAP(sd->l_stderrs[i], sd->l_stderrs[nbins - i - 1], float);
		SWAP(sd->l_stddevs[i], sd->l_stddevs[nbins - i - 1], float);
		SWAP(sd->l_n[i], sd->l_n[nbins - i - 1], int);
	}

	/*  sum over increasing and decreasing to get combined */

	sd->b_depvals = (float *) malloc(nbins * sizeof(float));
	sd->b_means = (float *) malloc(nbins * sizeof(float));
	sd->b_stddevs = (float *) malloc(nbins * sizeof(float));
	sd->b_stderrs = (float *) malloc(nbins * sizeof(float));
	sd->b_n = (int *) malloc(nbins * sizeof(int));
	assert(sd->b_depvals != NULL && sd->b_means != NULL &&
	       sd->b_stddevs != NULL && sd->b_stderrs != NULL &&
	       sd->b_n != NULL);

	for (j = 0; j < nbins; j++) {
		sd->b_depvals[j] = sd->l_depvals[j];
		combine_mean(sd->l_means[j], sd->l_n[j],
			     sd->r_means[j], sd->r_n[j], &(sd->b_means[j]));
		combine_stddev(sd->l_means[j], sd->l_stddevs[j], sd->l_n[j],
			       sd->r_means[j], sd->r_stddevs[j], sd->r_n[j],
			       &(sd->b_stddevs[j]));
		combine_stderr(sd->l_means[j], sd->l_stderrs[j], sd->l_n[j],
			       sd->r_means[j], sd->r_stderrs[j], sd->r_n[j],
			       &(sd->b_stderrs[j]));
		sd->b_n[j] = sd->l_n[j] + sd->r_n[j];
	}

	sd->l_ndata = nbins;
	sd->r_ndata = nbins;
	sd->b_ndata = nbins;

	/* convert to dB: [0, 2*M_PI] -> [-depth, +depth] */

	for (i = 0; i < nbins; i++) {
		sd->r_depvals[i] -= M_PI;
		sd->r_depvals[i] *= (depth / M_PI);
		sd->l_depvals[i] -= M_PI;
		sd->l_depvals[i] *= (depth / M_PI);
		sd->b_depvals[i] -= M_PI;
		sd->b_depvals[i] *= (depth / M_PI);
	}
}

/*
 * Release the cached summary of `data` (every l_*, r_* and b_* array plus
 * the record itself) and reset data->sumdata to NULL.  A NULL summary is
 * a no-op.
 */
static void bam_freeSumData(FILEDATA * data)
{
	SumData *sd = SUMDATA(data);

	if (sd == NULL)
		return;

	FREE(sd->l_depvals);
	FREE(sd->l_means);
	FREE(sd->l_stderrs);
	FREE(sd->l_stddevs);
	FREE(sd->l_n);

	FREE(sd->r_depvals);
	FREE(sd->r_means);
	FREE(sd->r_stderrs);
	FREE(sd->r_stddevs);
	FREE(sd->r_n);

	/* the combined arrays are allocated by bam_summarize too */
	FREE(sd->b_depvals);
	FREE(sd->b_means);
	FREE(sd->b_stderrs);
	FREE(sd->b_stddevs);
	FREE(sd->b_n);

	/* free through a plain pointer and clear the real field explicitly,
	   instead of relying on assignment to a cast expression */
	free(sd);
	data->sumdata = NULL;
}

/*
 * FD free-method for "bam": releases the raw rasters, then the cached
 * summary.  Always reports success (1).
 */
static int bam_free(FILEDATA * data)
{
	const int ok = 1;

	FD1_freeRawData(data);
	bam_freeSumData(data);
	return ok;
}

/*
 * Register the "bam" handlers through the setFDO* / setFD* hooks: view
 * validation, plot-by-view-id, plot-by-view-position, reading (the shared
 * FD1_reader) and freeing.  Always returns 1.
 */
int pm_bam_init(void)
{
	setFDOvalidviewMethod("bam", bam_valid_view);
	setFDOdoplotMethod("bam", bam_do_plot);
	setFDOplotMethod("bam", bam_plotter);
	setFDreadMethod("bam", FD1_reader);
	setFDfreeMethod("bam", bam_free);
	return (1);
}
