/* Functions to calculate the spike density (Richmond et al., 1990). */

#ifdef MNEMOSYNE
#include "mnemosyne.h"
#endif

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include "srl.h"
#include "density.h"

/*
 * Allocate a SA_SpikeDensity for a collection of `dur` msecs holding `ng`
 * kernels and store it in *sdp.
 *
 *   m      : (int)(dur + 1) zero-filled density bins
 *   gu, gs : ng kernel centres / widths (zero-filled)
 *   dur    : (int)dur
 *   ng     : ng
 *
 * On allocation failure nothing is leaked and *sdp is set to NULL.
 * Release the result with SA_FreeSpikeDensity().
 */
void SA_InitSpikeDensity(
     SA_SpikeDensity **sdp,
     float dur,			/* collection duration (msecs) */
     int ng)			/* total number of kernels */
{
  SA_SpikeDensity *sd;
  int nbins = (int)(dur + 1);
  size_t nm = nbins > 0 ? (size_t)nbins : 0;
  size_t nk = ng > 0 ? (size_t)ng : 0;

  *sdp = NULL;

  sd = malloc(sizeof *sd);
  if (sd == NULL)
    return;

  sd->m = calloc(nm, sizeof *sd->m);
  sd->dur = (int)dur;

  sd->gu = calloc(nk, sizeof *sd->gu);
  sd->gs = calloc(nk, sizeof *sd->gs);
  sd->ng = ng;

  /* calloc(0, ...) may legitimately return NULL, so only a NULL result for
     a non-empty request counts as an allocation failure. */
  if ((nm > 0 && sd->m == NULL) ||
      (nk > 0 && (sd->gu == NULL || sd->gs == NULL))) {
    free(sd->m);
    free(sd->gu);
    free(sd->gs);
    free(sd);
    return;
  }

  *sdp = sd;
}

/*
 * Release a SA_SpikeDensity together with the three arrays it owns
 * (m, gu and gs).  Passing NULL is a no-op.
 */
void SA_FreeSpikeDensity(
     SA_SpikeDensity *sd)
{
  if (sd == NULL)
    return;

  free(sd->gs);
  free(sd->gu);
  free(sd->m);
  free(sd);
}

/*
 * Allocate a SA_SpikeDensityList and store it in *sdlp.  The list has:
 *
 *   sd[]            : room for `ndensities` SA_SpikeDensity pointers, all NULL
 *   mean, sig, sem  : `dur` zero-filled floats each
 *
 * Every sd[] slot starts out NULL, so the list can be handed to
 * SA_FreeSpikeDensityList() even if some slots are never filled in
 * (SA_CalcSpikeDensity() returns without writing its output for a NULL
 * raster).
 *
 * On allocation failure nothing is leaked and *sdlp is set to NULL.
 */
void SA_InitSpikeDensityList(
     SA_SpikeDensityList **sdlp,
     int ndensities,
     int dur)			/* duration of each SpikeDensity */
{
  SA_SpikeDensityList *sdl;
  size_t nd = ndensities > 0 ? (size_t)ndensities : 0;
  size_t nt = dur > 0 ? (size_t)dur : 0;
  size_t i;

  *sdlp = NULL;

  sdl = malloc(sizeof *sdl);
  if (sdl == NULL)
    return;

  sdl->sd = calloc(nd, sizeof *sdl->sd);
  sdl->ndensities = ndensities;

  sdl->mean = calloc(nt, sizeof *sdl->mean);
  sdl->sig = calloc(nt, sizeof *sdl->sig);
  sdl->sem = calloc(nt, sizeof *sdl->sem);
  sdl->dur = dur;

  /* calloc(0, ...) may legitimately return NULL, so only a NULL result for
     a non-empty request counts as an allocation failure. */
  if ((nd > 0 && sdl->sd == NULL) ||
      (nt > 0 && (sdl->mean == NULL || sdl->sig == NULL || sdl->sem == NULL))) {
    free(sdl->sd);
    free(sdl->mean);
    free(sdl->sig);
    free(sdl->sem);
    free(sdl);
    return;
  }

  /* Null the slots explicitly: all-bits-zero from calloc is not guaranteed
     to be a null pointer, and unfilled slots must be safe to free. */
  for (i = 0; i < nd; i++)
    sdl->sd[i] = NULL;

  *sdlp = sdl;
}

/*
 * Release a SA_SpikeDensityList: every SA_SpikeDensity in sd[] (via
 * SA_FreeSpikeDensity, which ignores NULL entries), the sd[] array itself,
 * and the mean/sig/sem arrays.  Passing NULL is a no-op.
 *
 * Each sd[] entry must be NULL or a pointer obtained from
 * SA_InitSpikeDensity().
 */
void SA_FreeSpikeDensityList(
     SA_SpikeDensityList *sdl)
{
  int i;

  if (sdl == NULL)
    return;

  /* sd may itself be NULL (e.g. a failed allocation); don't index it then */
  if (sdl->sd != NULL) {
    for (i = 0; i < sdl->ndensities; i++)
      SA_FreeSpikeDensity(sdl->sd[i]);
  }

  /* free(NULL) is a no-op, so no per-pointer guards are needed */
  free(sdl->sd);
  free(sdl->mean);
  free(sdl->sig);
  free(sdl->sem);
  free(sdl);
}

void SA_CalcSpikeDensity(
     SA_SpikeRaster *sr,
     SA_SpikeDensity **sdp,
     float sigma_p,
     int num_iterations)
{
  int i,s,t,c;
  float u,x,z,sp,sp2,mu;
  SA_SpikeDensity *sd;

  if (sr == NULL) return;

  sp = sigma_p;
  sp2 = sp*sp;

  /* get the normalization right.  Want height to be in spikes/sec */
  x = sqrt(sp2*log(2.0) - 1.0)/1000.0;
  z = 1.0/(2.0*x*2.7182818);

  SA_InitSpikeDensity(sdp,sr->dur,sr->nspikes);
  sd = *sdp;

  /** Calculate the "pilot" estimate **/

  /* set the initial kernal values */
  for (s=0; s<sr->nspikes; s++) {
    sd->gu[s] = sr->tm[s];
    sd->gs[s] = sp;
  }

  /* add up the kernals */
  for (i=0; i<sd->ng; i++) {
    u = sd->gu[i];
    for (t=-3.0*sp; t<=3.0*sp; t++) {
      if ((int)(u+t) < 0 || (int)(u+t) >= sd->dur)
	continue;
      sd->m[(int)(u+t)] += exp(-t*t/sp2);
    }
  }

  /* compute the geometric mean of the pilot estimates */

  for (c=0; c<num_iterations; c++) {
    mu=0.0;
    for (i=0; i<sd->ng; i++)
      if (sd->m[(int)sd->gu[i]] > 0.0) 
	mu+=log(sd->m[(int)sd->gu[i]]);
    mu = exp(mu/(float)sd->ng);

    /* estimate the local bandwidth factors */
    for (i=0; i<sd->ng; i++)
      sd->gs[i] /= sqrt(sd->m[(int)sd->gu[i]]/mu);

    /* recompute the density estimate */
    for (i=0; i<sd->dur; i++) sd->m[i]=0.0;
    for (i=0; i<sd->ng; i++) {
      u = sd->gu[i];
      sp = sd->gs[i];
      sp2 = sp*sp;
      for (t=-3.0*sp; t<=3.0*sp; t++) {
	if ((int)(u+t) < 0 || (int)(u+t) >= sd->dur)
	  continue;
	sd->m[(int)(u+t)] += exp(-t*t/sp2);
      }
    }
  }

  /* normalize */
  for (t=0; t<sd->dur; t++)
    sd->m[t] *= z;
}

/*
 * Compute a spike density for every raster in srl (via SA_CalcSpikeDensity)
 * and their per-bin mean.  The result is stored in *sdlp.
 *
 * - The list's duration is that of the first non-NULL raster (normally
 *   sr[0]).
 * - A NULL raster leaves a NULL slot in sdl->sd[] and is excluded from the
 *   mean.
 * - Bins past the end of a shorter density contribute 0 to the mean.
 * - mean[t] = (sum of contributions) / (number of non-NULL densities).
 *
 * sig and sem are allocated (zero-filled) by SA_InitSpikeDensityList but are
 * not filled in here.  NOTE(review): presumably std-dev / standard error of
 * the mean; confirm whether a caller computes them.
 *
 * If srl is NULL, has no rasters, or allocation fails, *sdlp is NULL on
 * return.
 */
void SA_CalcSpikeDensityList(
     SA_SpikeRasterList *srl,
     SA_SpikeDensityList **sdlp,
     float sigma_p,
     int num_iterations)
{
  int i, t, n, dur;
  float acc;
  SA_SpikeDensityList *sdl;
  SA_SpikeDensity *sd;

  *sdlp = NULL;
  if (srl == NULL || srl->nrasters <= 0)
    return;

  /* time axis of the list: taken from the first raster that exists */
  dur = 0;
  for (i = 0; i < srl->nrasters; i++) {
    if (srl->sr[i] != NULL) {
      dur = (int)srl->sr[i]->dur;
      break;
    }
  }

  SA_InitSpikeDensityList(sdlp, srl->nrasters, dur);
  sdl = *sdlp;
  if (sdl == NULL || sdl->sd == NULL)
    return;

  /* estimate density for each raster; a NULL raster leaves its slot NULL
     because SA_CalcSpikeDensity returns without writing its output then */
  for (i = 0; i < srl->nrasters; i++) {
    sdl->sd[i] = NULL;
    SA_CalcSpikeDensity(srl->sr[i], &sdl->sd[i], sigma_p, num_iterations);
  }

  /* number of densities that actually exist */
  n = 0;
  for (i = 0; i < sdl->ndensities; i++)
    if (sdl->sd[i] != NULL)
      n++;
  if (n == 0 || sdl->mean == NULL)
    return;			/* mean stays zero-filled */

  /* calculate the mean for the spike densities */
  for (t = 0; t < sdl->dur; t++) {
    acc = 0.0;
    for (i = 0; i < sdl->ndensities; i++) {
      sd = sdl->sd[i];
      /* skip missing densities and bins beyond a shorter raster's end */
      if (sd == NULL || t >= sd->dur)
        continue;
      acc += sd->m[t];
    }
    sdl->mean[t] = acc / (float)n;
  }
}
