/* compute_gccphat_c.c
 *
 * MEX file: C implementation of the following MATLAB line:
 *
 *   ws.block.ssm_observed_gccphat = exp( j * angle( X( :, ws.ssm.pairs( :,1 ), : ) .* conj( X( :, ws.ssm.pairs( :,2 ), : ) ) ) );
 *
 * where "ws" is a global variable, as for example in "FULL_detect_locate.m"
 * or "FAST_detect_locate.m".
 *
 * Our goal is to speed things up a bit.
 */

#include <math.h>
#include <string.h>
#include "mex.h"

/* 1 or 0 */
#define VERBOSE             0

#if     VERBOSE == 1
#include "stdio.h"
#endif

/* Number of right-hand-side (input) arguments of the MEX call.
 * Not referenced by the code visible in this file. */

#define  N_INPUTS        0

/* Number of left-hand-side (output) arguments of the MEX call.
 * Not referenced by the code visible in this file. */

#define  N_OUTPUTS       0

/* All input/output is done through the "ws" global variable. Note
 * that the output "ws.block.ssm_observed_gccphat" must be allocated
 * by the user in advance (the gateway function looks it up and writes
 * into it; it does not create it).
 */

/* Misc helpers. MAX/MIN are only defined if no earlier header provided
 * them. Both are classic function-like macros: each argument may be
 * evaluated twice, so do not pass expressions with side effects.
 */

#if !defined(MAX)
#define MAX(A, B)       ((A) > (B) ? (A) : (B))
#endif

#if !defined(MIN)
#define MIN(A, B)       ((A) < (B) ? (A) : (B))
#endif

/* Truncated value of pi (8 decimals), well below double precision. */
#define PI 3.14159265



#if VERBOSE == 1
/* ---------------------------------------------------------------------- */
/* debugging functions */

/* Debug helper: print the N first elements of an array of double.
 *
 *   doublePtr  array to print (at least N elements)
 *   N          number of elements to print (0 prints only the label)
 *   a_string   label printed before the values; printed verbatim,
 *              i.e. it is NOT interpreted as a printf format
 */
void print_double_array( double* doublePtr, unsigned int N, const char * a_string ) {

  unsigned int n;

  /* The label goes through "%s" so that a '%' inside it cannot be
   * misinterpreted as a conversion specification. */
  printf( "\n%s[]: ", a_string );

  for( n = 0; n < N; n++ ) {
    /* %u: the index is unsigned */
    printf( " [%u]:%g,", n, doublePtr[n] );
  }

  printf( "\n" );
  
}

/* Debug helper: print the N first elements of an array of unsigned short.
 *
 *   usPtr      array to print (at least N elements)
 *   N          number of elements to print (0 prints only the label)
 *   a_string   label printed before the values; printed verbatim,
 *              i.e. it is NOT interpreted as a printf format
 */
void print_unsigned_short_array( unsigned short* usPtr, unsigned int N, const char * a_string ) {

  unsigned int n;

  /* The label goes through "%s" so that a '%' inside it cannot be
   * misinterpreted as a conversion specification. */
  printf( "\n%s[]: ", a_string );

  for( n = 0; n < N; n++ ) {
    /* %u for the unsigned index; the element is promoted to int, so %d fits */
    printf( " [%u]:%d,", n, usPtr[n] );
  }

  printf( "\n" );
  
}

/* Debug helper: print the N first elements of an array of unsigned char
 * as decimal numbers.
 *
 *   usPtr      array to print (at least N elements)
 *   N          number of elements to print (0 prints only the label)
 *   a_string   label printed before the values; printed verbatim,
 *              i.e. it is NOT interpreted as a printf format
 */
void print_unsigned_char_array( unsigned char* usPtr, unsigned int N, const char * a_string ) {

  unsigned int n;

  /* The label goes through "%s" so that a '%' inside it cannot be
   * misinterpreted as a conversion specification. */
  printf( "\n%s[]: ", a_string );

  for( n = 0; n < N; n++ ) {
    /* %u for the unsigned index; the element is promoted to int, so %d fits */
    printf( " [%u]:%d,", n, usPtr[n] );
  }

  printf( "\n" );
  
}
#endif


/* ----------------------------------------------------------------------
 * Wrapper: this is the gateway function to MATLAB
 *
 * Takes no MATLAB argument and returns none: nlhs, plhs, nrhs and prhs
 * are ignored.  Everything goes through the global structure "ws":
 *
 *   reads   ws.ssm.pairs    uint16, n_pairs rows and at least 2 columns,
 *                           1-based channel indices in 1...n_channels
 *                           (column 1: microphone #1, column 2: microphone #2)
 *   reads   ws.block.X      complex double,
 *                           n_freq-by-n_channels[-by-n_frames]
 *   writes  ws.block.ssm_observed_gccphat
 *                           complex double, pre-allocated by the caller,
 *                           n_freq-by-n_pairs[-by-n_frames]
 *
 * For each frame, pair and frequency bin, the output is the cross-spectrum
 * X1 * conj( X2 ) divided by its modulus (PHAT normalization).
 *
 * Notes:
 *  - Where the cross-spectrum evaluates to exactly zero, 0 is written,
 *    whereas the MATLAB expression exp( j * angle( 0 ) ) evaluates to 1.
 *  - The output array is modified IN PLACE through a pointer obtained
 *    from mexGetVariablePtr().  NOTE(review): if that array shares its
 *    data with another MATLAB variable, the other variable would
 *    presumably be modified too -- confirm callers allocate a fresh array.
 *  - mxGetPr()/mxGetPi() are used, i.e. the separate real/imaginary
 *    storage API is assumed.
 *
 * Any validation failure is reported through mexErrMsgTxt(), which
 * aborts the MEX call.
 * ---------------------------------------------------------------------- */

void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray*prhs[] )

{

  mxArray   *ws_mx, *ssm_mx, *pairs_mx, *block_mx, *X_mx, *ssm_observed_gccphat_mx;

  /* n_pairs-by-2 (column-major) matrix of integer indices in 1...n_channels */
  const unsigned short *pairs;

  /* input (complex values) */
  const double  *X_real, *X_imag;

  /* output (complex values), advanced sequentially while writing */
  double  *out_real, *out_imag;

  /* mwSize matches the element type returned by mxGetDimensions(),
   * whichever array-dimension API the MEX file is built with. */
  mwSize        n_pairs, n_freq, n_channels, n_frames, n_frames_out;
  mwSize        ndims;
  const mwSize *dims;

  /* mic_shift[ 2 * p ] and mic_shift[ 2 * p + 1 ]: offsets, within one
   * frame of X, of the channels of microphones #1 and #2 of pair p */
  mwSize       *mic_shift;

  mwSize        i, i_pair, i_frame, i_freq, frame_stride;

  double aa, bb, cc, rr1, ii1, rr2, ii2;

  /* The gateway arguments are intentionally unused. */
  ( void ) nlhs;
  ( void ) plhs;
  ( void ) nrhs;
  ( void ) prhs;

  /* -------------------------------------------------- */
  /* ( 1 ) Get access to the global workspace */
  /* We assume that the output "ws.block.ssm_observed_gccphat" is
   * already allocated
   */

  /* global ws */

  ws_mx = ( mxArray * ) mexGetVariablePtr( "global", "ws" );
  if ( ws_mx == NULL ) {
    mexErrMsgTxt( "compute_gccphat_c.c: could not get a pointer to the global variable 'ws'." );
  }
  if ( !mxIsStruct( ws_mx ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: the global variable 'ws' must be a structure." );
  }

  /* ws.ssm */

  ssm_mx = ( mxArray * ) mxGetField( ws_mx, 0, "ssm" );
  if ( ssm_mx == NULL ) {
    mexErrMsgTxt( "compute_gccphat_c.c: could not get a pointer to 'ws.ssm'." );
  }
  if ( !mxIsStruct( ssm_mx ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: the global variable 'ws.ssm' must be a structure." );
  }

  /* ws.ssm.pairs */

  pairs_mx = ( mxArray * ) mxGetField( ssm_mx, 0, "pairs" );
  if ( pairs_mx == NULL ) {
    mexErrMsgTxt( "compute_gccphat_c.c: could not get a pointer to 'ws.ssm.pairs'." );
  }
  if ( mxGetClassID( pairs_mx ) != mxUINT16_CLASS ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.ssm.pairs' must be an array of uint16." );
  }
  /* The first two columns are read below: make sure they exist, so that
   * we never read past the end of the array. */
  if ( mxGetN( pairs_mx ) < 2 ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.ssm.pairs' must have (at least) two columns." );
  }
  n_pairs = ( mwSize ) mxGetM( pairs_mx );
  pairs   = ( const unsigned short * ) mxGetData( pairs_mx );

  /* ws.block */

  block_mx = ( mxArray * ) mxGetField( ws_mx, 0, "block" );
  if ( block_mx == NULL ) {
    mexErrMsgTxt( "compute_gccphat_c.c: could not get a pointer to 'ws.block'." );
  }
  if ( !mxIsStruct( block_mx ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: the global variable 'ws.block' must be a structure." );
  }

  /* ws.block.X */

  X_mx = ( mxArray * ) mxGetField( block_mx, 0, "X" );
  if ( X_mx == NULL ) {
    mexErrMsgTxt( "compute_gccphat_c.c: could not get a pointer to 'ws.block.X'." );
  }
  if ( mxGetClassID( X_mx ) != mxDOUBLE_CLASS ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.block.X' must be an array of double." );
  }
  if ( !mxIsComplex( X_mx ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.block.X' must be complex." );
  }
  ndims = mxGetNumberOfDimensions( X_mx );
  if ( ( ndims != 2 ) && ( ndims != 3 ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.block.X' must be a 2-D or 3-D array." );
  }
  dims       = mxGetDimensions( X_mx );
  n_freq     = dims[ 0 ];
  n_channels = dims[ 1 ];
  /* A 2-D array is a single frame. */
  n_frames   = ( ndims == 2 ) ? 1 : dims[ 2 ];

  X_real = mxGetPr( X_mx );
  X_imag = mxGetPi( X_mx );

  /* check the pairs' definitions (first two columns) */

  for ( i = 0; i < 2 * n_pairs; i++ ) {
    if ( ( pairs[ i ] < 1 ) || ( ( mwSize ) pairs[ i ] > n_channels ) ) {
      mexErrMsgTxt( "compute_gccphat_c.c: all elements in 'ws.ssm.pairs' must be within 1...n_channels." );
    }
  }

  /* ws.block.ssm_observed_gccphat */

  ssm_observed_gccphat_mx = ( mxArray * ) mxGetField( block_mx, 0, "ssm_observed_gccphat" );
  if ( ssm_observed_gccphat_mx == NULL ) {
    mexErrMsgTxt( "compute_gccphat_c.c: could not get a pointer to 'ws.block.ssm_observed_gccphat'." );
  }
  if ( mxGetClassID( ssm_observed_gccphat_mx ) != mxDOUBLE_CLASS ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.block.ssm_observed_gccphat' must be an array of double." );
  }
  if ( !mxIsComplex( ssm_observed_gccphat_mx ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.block.ssm_observed_gccphat' must be complex." );
  }
  /* MATLAB drops trailing singleton dimensions, so with a single frame
   * the n_freq-by-n_pairs-by-1 output is reported as a 2-D array:
   * accept it, exactly as for X. */
  ndims = mxGetNumberOfDimensions( ssm_observed_gccphat_mx );
  if ( ( ndims != 2 ) && ( ndims != 3 ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.block.ssm_observed_gccphat' must be a 2-D or 3-D array." );
  }
  dims         = mxGetDimensions( ssm_observed_gccphat_mx );
  n_frames_out = ( ndims == 2 ) ? 1 : dims[ 2 ];
  if ( ( dims[ 0 ] != n_freq ) || ( dims[ 1 ] != n_pairs ) || ( n_frames_out != n_frames ) ) {
    mexErrMsgTxt( "compute_gccphat_c.c: 'ws.block.ssm_observed_gccphat' must be a  n_freq-by-n_pairs-by-n_frames  array." );
  }

  out_real = mxGetPr( ssm_observed_gccphat_mx );
  out_imag = mxGetPi( ssm_observed_gccphat_mx );

#if VERBOSE == 1
  printf( "compute_gccphat_c.c: n_pairs:%d, n_freq:%d, n_channels:%d, n_frames:%d\n",
          ( int ) n_pairs, ( int ) n_freq, ( int ) n_channels, ( int ) n_frames );
#endif

  /* Save a bit of time: prepare the list of shifts for each microphone */
  /* This way we remove two multiplications for each pair */

  /* mxMalloc: on failure the MEX call is aborted by MATLAB (no NULL to
   * check), and the memory is reclaimed automatically on error exits. */
  mic_shift = ( mwSize * ) mxMalloc( 2 * n_pairs * sizeof( *mic_shift ) );

  for ( i_pair = 0; i_pair < n_pairs; i_pair++ ) {

    /* "pairs" is column-major: column 1 starts at 0, column 2 at n_pairs */
    mic_shift[ 2 * i_pair     ] = n_freq * ( ( mwSize ) pairs[ i_pair ]           - 1 );
    mic_shift[ 2 * i_pair + 1 ] = n_freq * ( ( mwSize ) pairs[ n_pairs + i_pair ] - 1 );

  }

  /* -------------------------------------------------- */
  /* ( 2 ) Compute GCCPHAT from frequency-domain signal */

  /* number of elements in one frame of X */
  frame_stride = n_freq * n_channels;

  /* The output is filled sequentially, in memory order:
   * frequency varies fastest, then pair, then frame. */

  for ( i_frame = 0; i_frame < n_frames; i_frame++ ) {

    const double *frame_real = X_real + i_frame * frame_stride;
    const double *frame_imag = X_imag + i_frame * frame_stride;

#if VERBOSE == 1
    printf( "  compute_gccphat_c.c: frame #%d\n", ( int ) i_frame );
#endif

    for ( i_pair = 0; i_pair < n_pairs; i_pair++ ) {

      /* channels corresponding to microphones #1 and #2 of this pair */
      const double *real_mic1 = frame_real + mic_shift[ 2 * i_pair ];
      const double *imag_mic1 = frame_imag + mic_shift[ 2 * i_pair ];
      const double *real_mic2 = frame_real + mic_shift[ 2 * i_pair + 1 ];
      const double *imag_mic2 = frame_imag + mic_shift[ 2 * i_pair + 1 ];

      for ( i_freq = 0; i_freq < n_freq; i_freq++ ) {

        rr1 = real_mic1[ i_freq ];
        ii1 = imag_mic1[ i_freq ];

        rr2 = real_mic2[ i_freq ];
        ii2 = imag_mic2[ i_freq ];

        /* cross-correlation in frequency domain: X1 * conj( X2 ) */

        aa =   rr1 * rr2 + ii1 * ii2;
        bb = - rr1 * ii2 + ii1 * rr2;

        /* PHAT normalization */

        cc = aa * aa + bb * bb;

        if ( cc == 0 ) {

          /* cc = 0 is very rare.... but it does happen sometimes, even
           * on real, non-zero data.  Write 0 rather than dividing by 0. */

          *out_real++ = 0;
          *out_imag++ = 0;

        } else {

          cc = sqrt( cc );
          *out_real++ = aa / cc;
          *out_imag++ = bb / cc;

        }

      } /* end of loop through frequencies */

    } /* end of loop through pairs */

  } /* end of loop through frames */

  /* -------------------------------------------------- */
  /* ( 3 ) free memory */

  mxFree( mic_shift );

}
