/*
**	This file is part of XDowl
**	Copyright (c) 1994 Jamie Mazer
**	California Institute of Technology
**	<mazer@asterix.cns.caltech.edu>
*/

/******************************************************************
**  RCSID: $Id: pm_abi.c,v 1.16 2003/09/15 18:52:55 bja Exp $
** Program: dowl
**  Module: pm_abi.c
**  Author: bjarthur
** Descrip: xdowl plot method -- ABI curves (spikes vs. dB) with logistic fit
**
** Revision History (most recent last)
**
** NOV96 bjarthur
**  copied from pm_freq.c
**
** 97.3 bjarthur
**   added FitDataToTanh courtesy YA
**
*******************************************************************/

#include "xdphyslib.h"
#include "xdphys.h"
#include "plotter.h"

#include <stdio.h>

/* Forward declarations of this file's static helpers. */
/* Rebuilds the "fit" menu; its single entry invokes Logistic(). */
static void fit_menu(FDO *);
/* Per-file plot hook, registered via setFDOplotMethod() in pm_abi_init(). */
static int abi_plotter(FDO *, FILEDATA *, FDObj_ViewType *, int, FILE *);
/* View validator, registered via setFDOvalidviewMethod() in pm_abi_init(). */
static int abi_valid_view(FDO *, int);
/* Initial guess for the logistic midpoint (x where y crosses mid-range). */
static float EstimateInflexion(float *, float *, int);
/* Evaluates the model a[1] + a[2] / (1 + exp(-a[3]*x - a[4])). */
static float FitLogisticPoint(float, float *);
/* Model value + partial derivatives; the callback handed to mrqmin(). */
static void FitLogisticDerivatives(float, float[], float *, float[], int);
/* Levenberg-Marquardt logistic fit; results returned through pointers. */
static int FitDataToLogistic(float *, float *, float *, int, float **, float *,
			 float *, float *, float *, float *, float *);

static int view_order[][3] = {
	{PM_CURVE},
	{PM_MEANPHASE, PM_VS, PM_RAYLEIGH},
	{PM_PERHIST},
	{PM_RASTER, PM_RASTER_RAW},
	{PM_PSTH},
	{PM_ISIH}
};
static int nsubviews[] = { 1, 3, 1, 2, 1, 1 };
static int nviews = 6;

/*
 * Draw one view of file l for the "abi" plot method.
 *
 * An unsupported view is reported via pm_type_error() and replaced by
 * PM_DEFAULT, which draws the same thing as PM_CURVE.  The period and bin
 * count used by the period-histogram style views are computed up front
 * from fd.  The view drawn on the previous call is remembered so that
 * entering PM_PERHIST from another view resets fdo->vsm_data.
 *
 * Always returns 1.
 */
int abi_do_plot(FDO * fdo, FILEDATA * fd, int view, int l, FILE * fptr)
{
	static int prev_view = PM_NONE;
	float per;
	int bins;

	if (!abi_valid_view(fdo, view)) {
		pm_type_error("abi", view);
		view = PM_DEFAULT;
		(void) abi_valid_view(fdo, view);
	}

	FD_perhist_compute_period_nbins(fd, &per, &bins);

	switch (view) {
	case PM_DEFAULT:
	case PM_CURVE:
		fit_menu(fdo);
		FD1_curveFN(fdo, fd, l, fptr);
		break;

	case PM_MEANPHASE:
	case PM_VS:
	case PM_RAYLEIGH:
		/* perhist1 mode selector: 0 = mean phase, 1 = VS, 2 = Rayleigh */
		FD1_perhist1FN(fdo, fd, l, fptr, per, bins,
			       (view == PM_MEANPHASE) ? 0 :
			       ((view == PM_VS) ? 1 : 2));
		break;

	case PM_PERHIST:
		/* reset the VSM selection only when arriving from another view */
		if (prev_view != PM_PERHIST)
			fdo->vsm_data = VSM_ALL_CODE;
		FD1_perhist_stats_menu(fdo);
		FD_vsm_menu(fdo, 1);
		FD1_perhist2FN(fdo, fd, l, fptr, per, bins);
		break;

	case PM_RASTER:
	case PM_RASTER_RAW:
		/* 4th argument: 1 for the raw raster, 0 otherwise */
		FD1_rasterFN(fdo, fd, l, view == PM_RASTER_RAW, fptr);
		break;

	case PM_PSTH:
		FD1_psthFN(fdo, fd, l, fptr);
		break;

	case PM_ISIH:
		FD1_isih_stats_menu(fdo);
		FD1_isihFN(fdo, fd, l, fptr);
		break;

	default:
		pm_type_error("abi", view);
		break;
	}

	prev_view = view;
	return 1;
}

/*
 * Report whether `view` is supported by the "abi" plot method.
 *
 * PM_DEFAULT is always accepted and leaves fdo->view untouched.  Any other
 * view is looked up in view_order[]; on the first match fdo->view.lr/ud
 * are set to its (row, column) and the search stops.  (The original inner
 * `break` only left the column loop, so scanning continued and a later
 * duplicate entry would have overridden the first match.)
 *
 * Returns 1 if the view is supported, 0 otherwise.
 */
static int abi_valid_view(FDO * fdo, int view)
{
	int lr, ud;

	if (view == PM_DEFAULT)
		return (1);

	for (lr = 0; lr < nviews; lr++) {
		for (ud = 0; ud < nsubviews[lr]; ud++) {
			if (view_order[lr][ud] == view) {
				fdo->view.lr = lr;
				fdo->view.ud = ud;
				return (1);
			}
		}
	}

	return (0);
}

static int abi_plotter(FDO * fdo, FILEDATA * fd, FDObj_ViewType * view,
		       int l, FILE * fptr)
{
	syn_spec ss;

	fd_syn_spec_parse(FD_GV(fdo->fds[l], "abi.Stim"), 0, &ss);

	adjust_index(view->lr, view->ud);

	if (l == 0) {
		if (!((ss.class == SC_TONE) ||
		      ((ss.class == SC_STACK)
		       && (ss.parms.stack.num_freqs == 1)))) {
			if (view->lr == 1) {
				view->lr = 3;
				view->ud = 0;
			} else if (view->lr == 2) {
				view->lr = 0;
				view->ud = 0;
			}
		}
	} else {
		if (!((ss.class == SC_TONE) ||
		      ((ss.class == SC_STACK)
		       && (ss.parms.stack.num_freqs == 1))))
			if ((view->lr == 1) || (view->lr == 2)) {
				notify
				    ("some curves missing due to different stim types");
				return (0);
			}
	}

	if (nsubviews[view->lr] > 1) {
		XtSetSensitive(fdo->up, True);
		XtSetSensitive(fdo->down, True);
	}

	abi_do_plot(fdo, fd, view_order[view->lr][view->ud], l, fptr);

	return (1);
}


/*
 * Append the NUL-terminated string s to the heap buffer buf.  buf currently
 * holds *count characters (plus NUL) in *size bytes.  It is grown with
 * realloc() whenever s would not fit.  Returns the (possibly moved) buffer.
 */
static char *logistic_append(char *buf, size_t *size, size_t *count,
			     const char *s)
{
	size_t len = strlen(s);
	char *grown;

	if (*count + len + 1 > *size) {
		*size = *count + len + 1 + 1024;
		grown = (char *) realloc(buf, *size);
		assert(grown != NULL);	/* file-wide policy: abort on OOM */
		buf = grown;
	}
	strcpy(buf + *count, s);
	*count += len;
	return buf;
}

/*
 * Logistic -- Xt callback behind the "Logistic" entry added by fit_menu().
 *
 * For every file in the FDO (fdoptr), fits A + B * logistic(C * x + D) to
 * the means/stderrs vs. depvals of that file, ignoring element 0 of each
 * array (as the original code did; presumably the spontaneous entry --
 * confirm).  One table row is produced per file:
 *   - fdo->to_tty: each row is printed to stdout and nothing is plotted.
 *   - otherwise:   rows are collected (with header and footer) into a text
 *                  window via pop_text(), and unless fdo->no_X the fitted
 *                  curve is overlaid on the graph.
 * Files with fewer than two data points are skipped (nothing to fit).
 *
 * w, call_data: unused (Xt callback signature).
 */
void Logistic(Widget w, XtPointer fdoptr, XtPointer call_data)
{
	static const char header[] =
	    "equation:  A + B * logistic(C * x + D)\n"
	    "units:  A,B (spikes);  C (1/dB);  D (none)\n"
	    "units:  Min, Max (spikes);  Slope (spikes/dB);  Range, Low, High (dB)\n\n"
	    "file                 A      B       C        D       chi    |  Min    Max    Slope   Range   From   To\n"
	    "----                 -      -       -        -       ---    |  ---    ---    -----   -----   ----   --\n";
	static const char footer[] =
	    "\n* = chi is big, fitted parameters are questionable";
	FDO *fdo = (FDO *) fdoptr;
	int i, ndata, pn;
	char *buf;
	const char *name;
	char line[1024];	/* bounded by snprintf; absurd rows are truncated */
	size_t size, count;
	float *yfit, A, B, C, D, chi, big;
	float *depvals, *means, *stderrs;
	float Min, Max, Slope, From, To;
	float *px, *py, *pz;

	(void) w;
	(void) call_data;

	size = 1024;
	count = 0;
	buf = (char *) malloc(size);
	assert(buf != NULL);
	buf[0] = '\0';
	buf = logistic_append(buf, &size, &count, header);

	for (i = 0; i < fdo->nfds; i++) {
		ndata = FD1_get_ndata(fdo->fds[i]);
		/* element 0 is excluded from the fit, so need >= 1 more point */
		if (ndata < 2)
			continue;
		depvals = FD1_get_depvals(fdo->fds[i]);
		means = FD1_get_means(fdo->fds[i]);
		stderrs = FD1_get_stderrs(fdo->fds[i]);

		FitDataToLogistic(depvals + 1, means + 1, stderrs + 1, ndata - 1,
				  &yfit, &A, &B, &C, &D, &chi, &big);

		/*
		 * NOTE(review): Max is reported as B, the fitted amplitude; the
		 * curve's upper asymptote is A + B.  Likewise Slope = B*C is 4x
		 * the curve's maximum slope (B*C/4 at the midpoint).  Kept as-is
		 * because users may rely on the printed values -- confirm intent.
		 */
		Min = A;
		Max = B;
		Slope = B * C;
		/* x where logistic(C*x + D) equals 0.1 (From) and 0.9 (To) */
		From = -(log(9.0) + D) / C;
		To = -(log(1.0 / 9.0) + D) / C;

		name = strrchr(fdo->fds[i]->filename, '/');
		name = (name == NULL) ? fdo->fds[i]->filename : name + 1;

		if (snprintf(line, sizeof line,
			     "%-13s   %5.1f  %5.1f     %5.3f    %5.1f   %5.1f%1s |  %5.1f  %5.1f  %5.1f   %5.1f   %5.1f  %5.1f\n",
			     name, A, B, C, D, chi,
			     ((int) big) ? "*" : " ", Min, Max, Slope,
			     (To - From), From, To) < 0)
			line[0] = '\0';	/* encoding error: emit an empty row */

		if (fdo->to_tty) {
			printf("%s", line);
			free(yfit);
			continue;
		}
		buf = logistic_append(buf, &size, &count, line);

		FD_plotter_copy(depvals + 1, yfit + 1, NULL, ndata - 1, 0,
				&px, &py, &pz, &pn, NULL, NULL, NULL, NULL);
		/*
		 * NOTE(review): yfit (malloc'd by FitDataToLogistic) is not freed
		 * here because it is not visible whether FD_plotter_copy copies
		 * or aliases it; free it here once that is confirmed.
		 */

		if (!fdo->no_X) {
			FDObj_Add_Data_All(fdo, px, py, pn);
			FDObj_AddLine(fdo, i, px, py, pz, pn, AtFloat,
				      atQuadLinePlotWidgetClass,
				      AtTypeLINEPOINTS, AtMarkCIRCLE,
				      ConvertColor(fdo->graph,
						   FDObj_Colors[(i + 1) % FDObj_numColors]),
				      (XtArgVal) "fit");
		}
	}

	if (fdo->to_tty) {
		/* rows were already printed; the text buffer is never shown */
		free(buf);
		return;
	}

	buf = logistic_append(buf, &size, &count, footer);
	/* NOTE(review): ownership of buf after pop_text() is not visible here. */
	pop_text("Logistic", buf, strlen(buf) + 1, False);
}

/*
 * Rebuild the FDO's "fit" menu so that it offers a single "Logistic"
 * entry (which calls Logistic() with fdobj).  The menu button is made
 * insensitive while its contents are replaced.  Does nothing when the FDO
 * has no X display (fdobj->no_X).
 */
static void fit_menu(FDO * fdobj)
{
	if (!fdobj->no_X) {
		menubutton_clear(fdobj->fit, &(fdobj->fitpsh));
		XtSetSensitive(fdobj->fit, False);

		menubutton_add(fdobj->fitpsh, "Logistic", Logistic, fdobj);
		XtSetSensitive(fdobj->fit, True);
	}
}

/*
 * Register the "abi" plot method.  This file supplies its view validator,
 * do-plot and per-file plot hooks; file reading and freeing reuse the
 * generic FD1_reader / FD1_free.  Always returns 1.
 */
int pm_abi_init(void)
{
	char *method = "abi";

	setFDOvalidviewMethod(method, abi_valid_view);
	setFDOdoplotMethod(method, abi_do_plot);
	setFDOplotMethod(method, abi_plotter);
	setFDreadMethod(method, FD1_reader);
	setFDfreeMethod(method, FD1_free);

	return 1;
}




/****************** from yuda **************************/

/*
 * Convergence threshold for the mrqmin() loop in FitDataToLogistic():
 * iteration stops once chi-square changes by a non-zero amount no larger
 * than this (or alamda overflows past MAXFLOAT).
 */
#define CHI2_STOP_STEP (0.1)

/*
 * Abscissa window for the logistic fit: points with x < ndBRangeMin or
 * x > ndBRangeMax are excluded by FitDataToLogistic().  Non-static, so
 * presumably adjustable from elsewhere -- confirm.
 */
int ndBRangeMin = -70;
int ndBRangeMax = 70;

/*
 * Estimate the abscissa at which the sampled curve crosses the midpoint of
 * its range.  FitDataToLogistic() uses it to seed the logistic offset.
 *
 * x, y:  1-based arrays (Numerical Recipes vector(1, ndata) style); only
 *        elements 1..ndata are read.
 * ndata: number of points.
 *
 * Finds the first adjacent pair straddling (min+max)/2, rising or falling,
 * and returns the linearly interpolated crossing.  If no pair straddles it
 * (flat data or a single point), returns the midpoint of x[1] and x[ndata].
 * Returns 0 when ndata < 1.
 *
 * Previously the min/max scan skipped y[ndata], and a first point already
 * above mid-range made the code read the unset x[0]/y[0].
 */
static float EstimateInflexion(float *x, float *y, int ndata)
{
	float miny, maxy, medy;
	float x1, x2, y1, y2;
	int i;

	if (ndata < 1)
		return 0.0;

	miny = maxy = y[1];
	for (i = 2; i <= ndata; i++) {
		if (y[i] > maxy)
			maxy = y[i];
		if (y[i] < miny)
			miny = y[i];
	}
	medy = (miny + maxy) / 2;

	/* first segment whose endpoints lie on opposite sides of medy */
	for (i = 2; i <= ndata; i++)
		if ((y[i - 1] > medy) != (y[i] > medy))
			break;
	if (i > ndata)
		return (x[1] + x[ndata]) / 2;

	x1 = x[i - 1];
	x2 = x[i];
	y1 = y[i - 1];
	y2 = y[i];
	/* y1 != y2 is guaranteed: exactly one of them is above medy */
	return x1 + (medy - y1) * (x2 - x1) / (y2 - y1);
}

static float FitLogisticPoint(float x, float *a)
{
	return a[1] + a[2] / (1.0 + exp(-a[3] * x - a[4]));
}

void FitLogisticDerivatives(float x, float a[], float *y, float dyda[], int na)
{
  int i;

	float X = -a[3]*x - a[4];
	float E = exp(X);
	*y = a[1] + a[2] / (1.0 + E);

	dyda[1] = 1.0;
	dyda[2] = 1.0/(1.0+E);
	dyda[3] = a[2]*x*E/(1.0+E)/(1.0+E);
	dyda[4] = a[2]*E/(1.0+E)/(1.0+E);
}

/*
 * Fit y = A + B / (1 + exp(-C*x - D)) to (fAxis[i], fData[i]) with
 * per-point errors fSig[i] (floored at 0.1), i = 0..ndata-1, using the
 * Numerical Recipes Levenberg-Marquardt routine mrqmin().
 *
 * Points with fAxis outside [ndBRangeMin, ndBRangeMax] are excluded by
 * trimming them from the two ends, which assumes fAxis is sorted
 * ascending.  Axis, data and sigma are trimmed consistently (sigma used to
 * stay un-shifted and was misaligned).
 *
 * Outputs:
 *   *yfit  newly malloc'd array; (*yfit)[1..ndata] is the model evaluated
 *          at fAxis[0..ndata-1], including trimmed points, so it lines up
 *          with the caller's input.  The caller frees it.
 *   *A..*D fitted parameters; *chi final chi-square.
 *   *big   1 if chi-square exceeds dof + sqrt(2*dof), or if no fit was
 *          possible; 0 otherwise.
 *
 * With fewer in-range points than parameters the LM system is singular,
 * so no fit is attempted.  In that case the initial parameter estimate is
 * reported with *chi = 0 and *big = 1.
 *
 * Returns 1 if a fit was performed, 0 otherwise.
 */
static int FitDataToLogistic(float *fAxis, float *fData, float *fSig,
			 int ndata, float **yfit, float *A, float *B,
			 float *C, float *D, float *chi,
			 float *big)
{
	float *x, *y, *sig;
	int i, *ia, ma, undermin, abovemax, nfit, nalloc, fitted;
	float miny, maxy, *a, **covar, **alpha, alamda, chisq, oldchi;
	void (*GeneralFit) (float, float[], float *, float[], int);

	GeneralFit = FitLogisticDerivatives;
	ma = 4;

	for (undermin = abovemax = i = 0; i < ndata; i++) {
		if (fAxis[i] < ndBRangeMin)
			undermin++;
		if (fAxis[i] > ndBRangeMax)
			abovemax++;
	}
	nfit = ndata - (undermin + abovemax);
	if (nfit < 0)
		nfit = 0;
	/* never ask vector() for an empty range */
	nalloc = (nfit > 0) ? nfit : 1;

	alpha = matrix(1, ma, 1, ma);
	covar = matrix(1, ma, 1, ma);
	ia = ivector(1, ma);
	a = vector(1, ma);
	x = vector(1, nalloc);
	y = vector(1, nalloc);
	sig = vector(1, nalloc);

	/* copy the in-range window into 1-based arrays, all with the same offset */
	for (i = 1; i <= nfit; i++) {
		x[i] = fAxis[undermin + i - 1];
		y[i] = fData[undermin + i - 1];
		sig[i] = (fSig[undermin + i - 1] > 0.1) ?
		    fSig[undermin + i - 1] : 0.1;
	}

	for (i = 1; i <= ma; i++)
		ia[i] = 1;	/* fit all four parameters */

	/* initial guess: baseline = min, amplitude = range, midpoint from data */
	if (nfit >= 1) {
		miny = maxy = y[1];
		for (i = 2; i <= nfit; i++) {
			if (y[i] > maxy)
				maxy = y[i];
			if (y[i] < miny)
				miny = y[i];
		}
		a[1] = miny;
		a[2] = maxy - miny;
		a[3] = 1.0 / 10.0;
		a[4] = -a[3] * EstimateInflexion(x, y, nfit);
	} else {
		a[1] = a[2] = a[3] = a[4] = 0.0;
	}

	chisq = 0.0;
	fitted = 0;
	if (nfit >= ma) {
		chisq = 10;
		alamda = -1;	/* first mrqmin() call initialises */
		do {
			oldchi = chisq;
			mrqmin(x, y, sig, nfit, a, ia, ma, covar, alpha, &chisq,
			       GeneralFit, &alamda);
		}
		while ((((oldchi - chisq) > CHI2_STOP_STEP) || ((oldchi - chisq) < -CHI2_STOP_STEP) || (oldchi == chisq)) && (alamda < MAXFLOAT));

		alamda = 0;	/* final call: fill covar/alpha */
		mrqmin(x, y, sig, nfit, a, ia, ma, covar, alpha, &chisq,
		       GeneralFit, &alamda);
		fitted = 1;
	}

	*yfit = (float *) malloc(((ndata > 0 ? ndata : 0) + 1) * sizeof(float));
	assert(*yfit != NULL);	/* allocation kept outside assert (NDEBUG-safe) */
	for (i = 1; i <= ndata; i++)
		(*yfit)[i] = FitLogisticPoint(fAxis[i - 1], a);

	(*A) = a[1];
	(*B) = a[2];
	(*C) = a[3];
	(*D) = a[4];
	(*chi) = chisq;
	if (fitted)
		(*big) = (chisq > ((nfit - ma) + (float) sqrt(2.0 * (nfit - ma)))) ? 1 : 0;
	else
		(*big) = 1;

	free_matrix(covar, 1, ma, 1, ma);
	free_matrix(alpha, 1, ma, 1, ma);
	free_ivector(ia, 1, ma);
	free_vector(a, 1, ma);
	free_vector(x, 1, nalloc);
	free_vector(y, 1, nalloc);
	free_vector(sig, 1, nalloc);
	fflush(stderr);
	return fitted;
}
