/* get_ssm_activeness_c.c
 *
 * MEX file: C implementation of the "get_ssm_activeness" function,
 * used for example in "FULL_detect_locate.m" or "FAST_detect_locate.m".
 *
 * Our goal is to speed things up a bit (avoid those ugly repmat/permute things).
 */

#include <math.h>
#include <string.h>
#include "mex.h"

/* 1 or 0 */
#define VERBOSE             0

#if     VERBOSE == 1
#include "stdio.h"
#endif

/* Input Argument */

#define  N_INPUTS        0

/* Output Argument */

#define  N_OUTPUTS       0

/* All input/output is done through the "ws" global variable. Note
 * that the outputs "ws.block.ssm_binary_decision" and
 * "ws.block.ssm_activeness" must be allocated by the user in advance.
 */

/* Misc */

#if !defined(MAX)
#define MAX(A, B)       ((A) > (B) ? (A) : (B))
#endif

#if !defined(MIN)
#define MIN(A, B)       ((A) < (B) ? (A) : (B))
#endif

#define PI 3.14159265



#if VERBOSE == 1
/* ---------------------------------------------------------------------- */
/* debugging functions */

/* Debug helper: print the N doubles stored at doublePtr on one line.
 *
 * Output layout: a blank line, then the label a_string followed by "[]: ",
 * then one " [index]:value," entry per element, then a newline.
 *
 * doublePtr : array of at least N doubles (only read).
 * N         : number of elements to print.
 * a_string  : label printed verbatim in front of the values.
 */
void print_double_array( double* doublePtr, unsigned int N, const char * a_string ) {

  unsigned int n;

  printf( "\n" );
  /* The label is data, not a format: never hand it to printf() as the
   * format string, otherwise a '%' in the label is undefined behaviour. */
  printf( "%s", a_string );
  printf( "[]: " );

  for( n = 0; n < N; n++ ) {
    /* n is unsigned, hence %u */
    printf( " [%u]:%g,", n, doublePtr[n] );
  }

  printf( "\n" );
  
}

/* Debug helper: print the N unsigned shorts stored at usPtr on one line.
 *
 * Output layout: a blank line, then the label a_string followed by "[]: ",
 * then one " [index]:value," entry per element, then a newline.
 *
 * usPtr    : array of at least N unsigned shorts (only read).
 * N        : number of elements to print.
 * a_string : label printed verbatim in front of the values.
 */
void print_unsigned_short_array( unsigned short* usPtr, unsigned int N, const char * a_string ) {

  unsigned int n;

  printf( "\n" );
  /* The label is data, not a format: never hand it to printf() as the
   * format string, otherwise a '%' in the label is undefined behaviour. */
  printf( "%s", a_string );
  printf( "[]: " );

  for( n = 0; n < N; n++ ) {
    /* n is unsigned (%u); the element is promoted to int (%d) */
    printf( " [%u]:%d,", n, usPtr[n] );
  }

  printf( "\n" );
  
}

/* Debug helper: print the N unsigned chars stored at usPtr on one line,
 * as decimal numbers (not as characters).
 *
 * Output layout: a blank line, then the label a_string followed by "[]: ",
 * then one " [index]:value," entry per element, then a newline.
 *
 * usPtr    : array of at least N unsigned chars (only read).
 * N        : number of elements to print.
 * a_string : label printed verbatim in front of the values.
 */
void print_unsigned_char_array( unsigned char* usPtr, unsigned int N, const char * a_string ) {

  unsigned int n;

  printf( "\n" );
  /* The label is data, not a format: never hand it to printf() as the
   * format string, otherwise a '%' in the label is undefined behaviour. */
  printf( "%s", a_string );
  printf( "[]: " );

  for( n = 0; n < N; n++ ) {
    /* n is unsigned (%u); the element is promoted to int (%d) */
    printf( " [%u]:%d,", n, usPtr[n] );
  }

  printf( "\n" );
  
}
#endif


/* ----------------------------------------------------------------------
 * Wrapper: this is the gateway function to MATLAB
 * ---------------------------------------------------------------------- */

void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray*prhs[] )

{

  mxArray  *ws_mx, *block_mx, *a_struct_mx, *ssm_activeness_mx, *ssm_binary_decision_mx, *n_sectors_mx;
  mxArray  *ssm_observed_gccphat_mx, *conj_ssm_Z_mx;

  double   *conj_ssm_Z_real, *conj_ssm_Z_imag, *ssm_observed_gccphat_real, *ssm_observed_gccphat_imag;
  double   *ssm_activeness;
  
  unsigned char *ssm_binary_decision;
  
  int n_pairs, n_freq, n_sectors, n_frames, a, m, n;
  
  const int *dims;
  int ndims;

  int n_sectors_n_frames, n_freq_n_sectors_n_frames;

  unsigned char *iter_ssm_binary_decision_next_frame;
  
  double *iter_ssm_activeness, *iter_ssm_activeness_max;
  double *iter_conj_ssm_Z_real, *iter_conj_ssm_Z_real_max, *iter_conj_ssm_Z_imag;
  double *iter_adsp, *iter_adsp_this_sector, *iter_adsp_next_sector, *iter_adsp_next_pair;
  double *iter_ssm_observed_gccphat_real, *iter_ssm_observed_gccphat_imag;
  double *iter_real, *iter_imag, *iter_real_next_sector;

  double *adsp, *adsp_max, *max_val, *iter_max_val;
  int    *max_ind, *max_ind_max, *iter_max_ind;

  int n_freq_n_pairs, n_freq_n_sectors, n_freq_n_pairs_n_sectors, nf_ns_dble;

  int i_sector;
  
  double aa, bb;

  unsigned char *iter_ssm_binary_decision;

#if VERBOSE == 1
  int v_i_frame;
#endif  

  /* -------------------------------------------------- */
  /* ( 1 ) Get access to the global workspace */
  /* We assume that outputs "ws.block.ssm_binary_decision" and
   * "ws.block.ssm_activeness" are already allocated 
   */
  
  /* global ws */

  ws_mx = ( mxArray * ) mexGetVariablePtr( "global", "ws" );
  if ( ws_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to the global variable 'ws'." );
  }
  if ( !mxIsStruct( ws_mx ) ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: the global variable 'ws' must be a structure." );
  }

  /* Now we can look for the fields of "ws" that we are interested in */

  /* ws.a_struct */
  
  a_struct_mx = ( mxArray * ) mxGetField( ws_mx, 0, "a_struct" );
  if ( a_struct_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to 'ws.a_struct'." );
  }
  if ( !mxIsStruct( a_struct_mx ) ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: the global variable 'ws.a_struct' must be a structure." );
  }
       
  /* ws.a_struct.n_sectors */
  n_sectors_mx = ( mxArray * ) mxGetField( a_struct_mx, 0, "n_sectors" );
  if ( n_sectors_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to 'ws.a_struct.n_sectors'" );
  }
  if ( mxGetClassID( n_sectors_mx ) != mxDOUBLE_CLASS ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.a_struct.n_sectors' must be a double." );
  }
  n_sectors = (int) mxGetScalar( n_sectors_mx );
  
  /* ws.a_struct.conj_ssm_Z */

  conj_ssm_Z_mx = ( mxArray * ) mxGetField( a_struct_mx, 0, "conj_ssm_Z" );
  if ( conj_ssm_Z_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to 'ws.a_struct.conj_ssm_Z'." );
  }
  if ( mxGetClassID( conj_ssm_Z_mx ) != mxDOUBLE_CLASS ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.a_struct.conj_ssm_Z' must be an array of double." );
  }
  if ( !mxIsComplex( conj_ssm_Z_mx ) ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.a_struct.conj_ssm_Z' must be complex." );
  }
  ndims = mxGetNumberOfDimensions( conj_ssm_Z_mx );
  if ( ndims != 3 ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.a_struct.conj_ssm_Z' must be a 3-D array." );
  }
  dims = mxGetDimensions( conj_ssm_Z_mx );
  if ( dims[ 2 ] != n_sectors ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: size( ws.a_struct.conj_ssm_Z, 3 ) must be equal to n_sectors." );
  }
  n_freq  = dims[ 0 ];
  n_pairs = dims[ 1 ];

  conj_ssm_Z_real = mxGetPr( conj_ssm_Z_mx );
  conj_ssm_Z_imag = mxGetPi( conj_ssm_Z_mx );
  

  /* ws.block */

  block_mx = ( mxArray * ) mxGetField( ws_mx, 0, "block" );
  if ( block_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to 'ws.block'." );
  }
  if ( !mxIsStruct( block_mx ) ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: the global variable 'ws.block' must be a structure." );
  }
  
  /* ws.block.ssm_observed_gccphat */

  ssm_observed_gccphat_mx = ( mxArray * ) mxGetField( block_mx, 0, "ssm_observed_gccphat" );
  if ( ssm_observed_gccphat_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to 'ws.block.ssm_observed_gccphat'." );
  }
  if ( mxGetClassID( ssm_observed_gccphat_mx ) != mxDOUBLE_CLASS ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_observed_gccphat' must be an array of double." );
  }
  if ( !mxIsComplex( ssm_observed_gccphat_mx ) ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_observed_gccphat' must be complex." );
  }
  ndims = mxGetNumberOfDimensions( ssm_observed_gccphat_mx );
  if ( ndims != 3 ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_observed_gccphat' must be a 3-D array." );
  }
  dims = mxGetDimensions( ssm_observed_gccphat_mx );
  if ( ( dims[ 0 ] != n_freq ) || ( dims[ 1 ] != n_pairs ) ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_observed_gccphat' must be a  n_freq-by-n_pairs-by-n_frames  array." );
  }
  n_frames = dims[ 2 ];

  ssm_observed_gccphat_real = mxGetPr( ssm_observed_gccphat_mx );
  ssm_observed_gccphat_imag = mxGetPi( ssm_observed_gccphat_mx );
  
  /* ws.block.ssm_binary_decision */
  
  ssm_binary_decision_mx = ( mxArray * ) mxGetField( block_mx, 0, "ssm_binary_decision" );
  if ( ssm_binary_decision_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to 'ws.block.ssm_binary_decision'." );
  }
  if ( mxGetClassID( ssm_binary_decision_mx ) != mxUINT8_CLASS ) { 
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_binary_decision' must be an array of uint8." );
  }
  ndims = mxGetNumberOfDimensions( ssm_binary_decision_mx );
  if ( ndims != 3 ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_binary_decision' must be a 3-D array." );
  }
  dims = mxGetDimensions( ssm_binary_decision_mx );
  if ( ( dims[ 0 ] != n_freq ) || ( dims[ 1 ] != n_sectors ) || ( dims[ 2 ] != n_frames ) ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_binary_decision' must be a n_freq-by-n_sectors-by-n_frames array." );
  }

  ssm_binary_decision = (unsigned char*) mxGetData( ssm_binary_decision_mx );

  
  /* initialize this output with zeroes */

  n_sectors_n_frames = n_sectors * n_frames;

  n_freq_n_sectors_n_frames = n_freq * n_sectors_n_frames;

  memset( ssm_binary_decision, 0, n_freq_n_sectors_n_frames * sizeof( unsigned char ) );
  
    /* ws.block.ssm_activeness */
  
  ssm_activeness_mx = ( mxArray * ) mxGetField( block_mx, 0, "ssm_activeness" );
  if ( ssm_activeness_mx == NULL ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: could not get a pointer to 'ws.block.ssm_activeness'." );
  }
  if ( mxGetClassID( ssm_activeness_mx ) != mxDOUBLE_CLASS ) { 
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_activeness' must be an array of double." );
  }
  ndims = mxGetNumberOfDimensions( ssm_activeness_mx );
  if ( ndims != 2 ) {
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_activeness' must be a 2-D array." );
  }
  m = mxGetM( ssm_activeness_mx );
  n = mxGetN( ssm_activeness_mx );
  if ( ( m != n_sectors ) || ( n != n_frames ) ) { 
    mexErrMsgTxt( "get_ssm_activeness_c.c: 'ws.block.ssm_activeness' must be n_sectors-by-n_frames." );
  }

  ssm_activeness = mxGetPr( ssm_activeness_mx );

  /* initialize this output with zeroes */
  
  memset( ssm_activeness, 0, n_sectors_n_frames * sizeof( double ) );

  /* Some verbosity */
  
#if VERBOSE == 1
  printf( "get_ssm_activeness_c.c:  n_freq:%d  n_pairs:%d  n_sectors:%d  n_frames:%d\n", n_freq, n_pairs, n_sectors, n_frames );
#endif
  
  /* --------------------------------------------------
   * ( 2 ) Processing: SAM-SPARSE-MEAN approach to 
   * determine the activity in each sector of space. */
  
  
  n_freq_n_pairs = n_freq * n_pairs;
  n_freq_n_pairs_n_sectors = n_freq_n_pairs * n_sectors;
  n_freq_n_sectors = n_freq * n_sectors;


  nf_ns_dble = n_freq_n_sectors * sizeof( double );
  adsp = ( double * ) malloc( nf_ns_dble );
  adsp_max = adsp + n_freq_n_sectors;
  
  max_ind = ( int * ) malloc( n_freq * sizeof( int ) );
  max_ind_max = max_ind + n_freq;
  max_val = ( double * ) malloc( n_freq * sizeof( double ) );
  
  iter_ssm_observed_gccphat_real = ssm_observed_gccphat_real;
  iter_ssm_observed_gccphat_imag = ssm_observed_gccphat_imag;

  iter_conj_ssm_Z_real_max = conj_ssm_Z_real + n_freq_n_pairs_n_sectors;
  
  iter_ssm_binary_decision = ssm_binary_decision;

#if VERBOSE == 1
  v_i_frame = 0;
#endif

  /* Loop #1: through frames */
  iter_ssm_activeness     = ssm_activeness;
  iter_ssm_activeness_max = ssm_activeness + n_sectors_n_frames;
  while ( iter_ssm_activeness < iter_ssm_activeness_max ) {
    
#if VERBOSE == 1
    printf( "  frame #%d\n", v_i_frame++ );
#endif

    /* ( 2.1 ) Compute average delay-sum power (a.d.s.p.) */

    memset( adsp, 0, nf_ns_dble );
    
    /* first frequency, first sector */
    iter_adsp_this_sector = adsp;
    
    /* We are going to compare measured phase values
     * to theoretical phase values */
    
    iter_conj_ssm_Z_real = conj_ssm_Z_real;
    iter_conj_ssm_Z_imag = conj_ssm_Z_imag;
    
    /* loop through sectors */

    iter_real_next_sector = iter_ssm_observed_gccphat_real + n_freq_n_pairs;

    while ( iter_conj_ssm_Z_real < iter_conj_ssm_Z_real_max ) {
      
      /* Current frequency, current pair, current frame */
      iter_real = iter_ssm_observed_gccphat_real;
      iter_imag = iter_ssm_observed_gccphat_imag;
      
      iter_adsp_next_pair   = iter_adsp_this_sector + n_freq;
      
      /* Loop through microphone pairs */

      while ( iter_real < iter_real_next_sector ) {
	
	iter_adsp           = iter_adsp_this_sector;
	
	/* Loop through frequencies */

	while ( iter_adsp < iter_adsp_next_pair ) {
	  
	  (*iter_adsp++) += (*iter_real++) * (*iter_conj_ssm_Z_real++) - (*iter_imag++) * (*iter_conj_ssm_Z_imag++);
	  
	} /* end of loop thru frequencies */
	
      } /* end of loop thru pairs */
      
      /* for adsp computation of the next sector */
      iter_adsp_this_sector = iter_adsp;
      
    }  /* adsp loop (thru sectors) */
    
    /* prepare for next frame */
    iter_ssm_observed_gccphat_real = iter_real;
    iter_ssm_observed_gccphat_imag = iter_imag;
    
    
    /* ( 2.2 ) Sparsity assumption: for each frequency, find the
     * sector with maximum a.d.s.p. */
    
    /* At first, our "max a.d.s.p." candidate is the first sector */
    
    iter_adsp = adsp;
    iter_max_ind = max_ind;
    iter_max_val = max_val;
    while ( iter_max_ind < max_ind_max ) {
      
      (*iter_max_val++) = (*iter_adsp++);
      (*iter_max_ind++) = 0;

    }
    
    /* ... now let's look at remaining sectors */
    
    i_sector = 1;

    while ( iter_adsp < adsp_max ) {
      
      iter_max_ind = max_ind;
      iter_max_val = max_val;
      
      /* Loop through frequencies */

      while ( iter_max_ind < max_ind_max ) {

	aa = (*iter_adsp);
	bb = (*iter_max_val);

	if ( aa > bb ) {
	
	  /* Update the "maximum a.d.s.p." candidate */
  
	  (*iter_max_val) = aa;
	  (*iter_max_ind) = i_sector;
	  
	} else if ( aa == bb ) {
	  
	  /* Currently no sector has the strict maximum a.d.s.p. value -> mark it with a "-1" */

	  (*iter_max_ind) = -1;
	  
	}
	
	iter_adsp++;
	iter_max_val++;
	iter_max_ind++;

      }
      
      i_sector++;
 
    }

    
    /* ( 2.3 ) Put it together: binary decision for each frequency,
     * and total activeness for each sector.  Here I tried to
     * eliminate multiplications (e.g. for indexing) as much as
     * possible, hence the seemingly "redundant" implementation. */
    
    i_sector = 0;
    
    iter_ssm_binary_decision_next_frame = iter_ssm_binary_decision + n_freq_n_sectors;
    
    /* loop through sectors */

    while ( iter_ssm_binary_decision < iter_ssm_binary_decision_next_frame ) {
      
      iter_max_ind = max_ind;
      
      /* Loop through frequencies */
      
      while ( iter_max_ind < max_ind_max ) {
	
	if ( ( *iter_max_ind ) == i_sector ) {
	  
	  *iter_ssm_binary_decision = 1;
	  ( *iter_ssm_activeness )++;

	}
	
	/* Move to the next frequency */
	
	iter_max_ind++;
	iter_ssm_binary_decision++;

      }

      /* Normalize activeness: result is a value between 0 and 1 */
      *iter_ssm_activeness /= n_freq;

      /* Move to the next sector */

      i_sector++;
      iter_ssm_activeness++;

    } /* end of "while ( iter_ssm_binary_decision < iter_ssm_binary_decision_next_frame )" */

  }  /* end of "while ( iter_ssm_activeness < iter_ssm_activeness_max )" */
  
  /* free memory */
  
  free( adsp );
  free( max_ind );
  free( max_val );
  
}
