// $Id: Basis.cc,v 1.4 2009/08/05 14:42:45 senning Exp $
//
// Copyright (c) 2009 Department of Mathematics and Computer Science
// Gordon College, 255 Grapevine Road, Wenham, MA 01984
//
// Author:  Jonathan Senning <jonathan.senning@gordon.edu>
// Written: July 23, 2009
//
// Implements a function to compute quadratic, linear, pure exponential, mixed
// exponential, rational, and indicator basis functions.  The indicator
// function logic was worked out by Senning, Lauren Berger and Christopher
// Pfohl.

#include <math.h>
#include <string>
#include <sstream>
#include <SPNetwork.h>
#include <Basis.h>

//-------------------------------------------------------------------------- 

// Constructor
//
// Input:
//   unsigned int mode  - arithmetic "or" of possible function types
//
// Creates an empty basis: no classes yet (_num == 0), indicator truncation
// 0, beta bounds set to _BetaLowerBound/_BetaUpperBound, and neither a beta
// array nor a rational truncation array.  The exponential functions in phi()
// read _beta, so setNetwork() must be called before they are evaluated.

Basis::Basis( unsigned int mode )
    : _num( 0 ),
      _mode( mode ),
      _indicatorTruncation( 0 ),
      _rationalTruncation( NULL ),
      _betaLowerBound( _BetaLowerBound ),
      _betaUpperBound( _BetaUpperBound ),
      _beta( NULL )
{
}

//-------------------------------------------------------------------------- 

// Constructor
//
// Input:
//   SPNetwork& network - instance of network class
//   unsigned int mode  - arithmetic "or" of possible function types
//
// Starts from the same state as the mode-only constructor (including
// _num = 0, so setNetwork() never observes an indeterminate class count)
// and then sizes the basis and computes the beta values for the network's
// classes via setNetwork().

Basis::Basis( SPNetwork& network, unsigned int mode )
    : _num( 0 ),
      _mode( mode ),
      _indicatorTruncation( 0 ),
      _rationalTruncation( NULL ),
      _betaLowerBound( _BetaLowerBound ),
      _betaUpperBound( _BetaUpperBound ),
      _beta( NULL )
{
    setNetwork( network );
}

//-------------------------------------------------------------------------- 

// Copy Constructor
//
// Input:
//   Basis& basis - an existing instance of this class
//
// Makes a deep copy: the beta and rational truncation arrays are duplicated
// (entries 1.._num) rather than shared.  An array that is absent (NULL) in
// the source is also NULL in the copy, so the destructor can always release
// both pointers safely.
//
// NOTE(review): this file defines no copy-assignment operator; unless the
// header declares one, assigning one Basis to another copies the raw
// pointers and both objects will later delete[] the same arrays.

Basis::Basis( Basis& basis )
    : _num( basis._num ),
      _mode( basis._mode ),
      _indicatorTruncation( basis._indicatorTruncation ),
      _rationalTruncation( NULL ),
      _betaLowerBound( basis._betaLowerBound ),
      _betaUpperBound( basis._betaUpperBound ),
      _beta( NULL )
{
    if ( basis._beta != NULL )
    {
        _beta = new double [_num + 1];
        for ( int i = 1; i <= _num; i++ )
        {
            _beta[i] = basis._beta[i];
        }
    }
    if ( basis._rationalTruncation != NULL )
    {
        _rationalTruncation = new int [_num + 1];
        for ( int i = 1; i <= _num; i++ )
        {
            _rationalTruncation[i] = basis._rationalTruncation[i];
        }
    }
}

//-------------------------------------------------------------------------- 

// Destructor
//
// Releases the beta and rational truncation arrays.  delete[] applied to a
// NULL pointer does nothing, so absent arrays need no special check.

Basis::~Basis( void )
{
    delete [] _rationalTruncation;
    delete [] _beta;
}

//-------------------------------------------------------------------------- 

// Method - set constant rational truncation values
//
// Input:
//   int N         - truncation value given to every class 1.._num
//
// Does nothing while the class count is zero (as after the mode-only
// constructor, before setNetwork()).  The array is allocated on first use
// and reused on later calls.

void Basis::setRationalTruncation( int N )
{
    if ( !( _num > 0 ) ) return;

    if ( _rationalTruncation == NULL )
    {
        _rationalTruncation = new int [_num + 1];
    }

    // index 0 is unused; classes are numbered from 1
    for ( int k = _num; k >= 1; --k )
    {
        _rationalTruncation[k] = N;
    }
}

//-------------------------------------------------------------------------- 

// Method - set individual rational truncation values
//
// Input:
//   int* N        - per-class truncation values, indexed from 1: entries
//                   N[1] .. N[_num] are read and N[0] is ignored
//
// Does nothing while the class count is zero.  The array is allocated on
// first use and reused on later calls.

void Basis::setRationalTruncation( int* N )
{
    if ( !( _num > 0 ) ) return;

    if ( _rationalTruncation == NULL )
    {
        _rationalTruncation = new int [_num + 1];
    }

    int k = 1;
    while ( k <= _num )
    {
        _rationalTruncation[k] = N[k];
        ++k;
    }
}

//-------------------------------------------------------------------------- 

// Method - construct beta array from network parameters
//
// Input:
//   SPNetwork& network - instance of network class
//
// Sets the class count from the network and computes the beta values used
// for exponential functions in phi():
//
//                         total service rate for class i
//     _beta[i] = ---------------------------------------------------
//                rate jobs enter class i through service completions
//
// The value of _beta[i] should always be between 0 and 1.  We choose to
// bound all _beta[i] values between _betaLowerBound and _betaUpperBound; a
// class with no inflow (zero denominator) gets _betaUpperBound.
//
// If the class count grows, a previously allocated rational truncation
// array is too short for the new network and is discarded.  In that case
// setRationalTruncation() must be called again before rational basis
// functions are evaluated.

void Basis::setNetwork( SPNetwork& network )
{
    int newNum = network.getClasses();

    // setRationalTruncation() reuses any existing array, so one sized for
    // fewer classes would be written (and read by phi()) past its end.
    // _num is read only when an array exists, i.e. after it has been set.
    if ( _rationalTruncation != NULL && newNum > _num )
    {
        delete [] _rationalTruncation;
        _rationalTruncation = NULL;
    }
    _num = newNum;

    int pLength;
    SPNetwork::MatrixEntry* p = network.getTransitionProbabilities( pLength );
    int muLength;
    SPNetwork::MatrixEntry* mu = network.getServiceRates( muLength );

    // get fresh array that is correct length for this network
    // (delete[] of NULL is a no-op)
    delete [] _beta;
    _beta = new double [_num + 1];

    for ( int i = 1; i <= _num; i++ )
    {
        // inflow: each transition entry with p[n].j == i contributes
        // p[n].val times every service rate mu[m] with mu[m].j == p[n].i
        double rateIn = 0.0;
        for ( int n = 0; n < pLength; n++ )
        {
            if ( p[n].j != i ) continue;
            for ( int m = 0; m < muLength; m++ )
            {
                if ( p[n].i == mu[m].j ) rateIn += p[n].val * mu[m].val;
            }
        }

        // total service rate of class i
        double rateOut = 0.0;
        for ( int m = 0; m < muLength; m++ )
        {
            if ( mu[m].j == i ) rateOut += mu[m].val;
        }

        if ( rateIn == 0.0 )
        {
            _beta[i] = _betaUpperBound;
        }
        else
        {
            _beta[i] = rateOut / rateIn;
            if ( _beta[i] < _betaLowerBound )
            {
                _beta[i] = _betaLowerBound;
            }
            else if ( _beta[i] > _betaUpperBound )
            {
                _beta[i] = _betaUpperBound;
            }
        }
    }
}

//-------------------------------------------------------------------------- 

// Method - return number of basis functions
//
// Input:
//   int n         - (optional) number of classes; a negative value (the
//                   default, -1) means use the stored class count _num
// Output:
//   int           - total number of basis functions for the families
//                   enabled in the mode: n(n+1)/2 quadratic, n linear,
//                   n pure exponential, n^2 mixed exponential, 3^n rational,
//                   and C(M+n, n) indicator functions, where M is the
//                   indicator truncation value

int Basis::getBasisSize( int n )
{
    // if n is not passed then we get the default value -1, signaling that
    // we should use the value from the constructor.

    if ( n < 0 ) n = _num;

    int count = 0;
    if ( _mode & Quadratic ) count += ( n * ( n + 1 ) ) / 2;
    if ( _mode & Linear )    count += n;
    if ( _mode & PureExp )   count += n;
    if ( _mode & MixedExp )  count += n * n;
    if ( _mode & Rational )
    {
        // 3^n in exact integer arithmetic.  int( pow( 3, n ) ) truncates a
        // floating-point result and comes out one short whenever pow()
        // returns a value just below the exact integer.
        int threeToN = 1;
        for ( int i = 0; i < n; i++ ) threeToN *= 3;
        count += threeToN;
    }
    if ( _mode & Indicator )
    {
        // binomial coefficient C(M+n, n), built up one factor at a time;
        // each intermediate value is itself C(M+i, i), so the division is
        // exact
        int c = 1;
        for ( int i = 1; i <= n; i++ )
        {
            c = ( c * ( _indicatorTruncation + i ) ) / i;
        }
        count += c;
    }
    return count;
}

//-------------------------------------------------------------------------- 

// Method - evaluate jth basis function
//
// Compute the jth basis function assuming that we are some combination of
// quadratic, linear, pure exponential, mixed exponential, rational, and
// indicator functions for the ALP approximation of h(x):
//    h(x) = 0.5 x'Qx + p'x + r beta^x + s x beta^x' + rationals + indicators.
// The families included are selected by the mode given to the constructor.
// The enabled families are numbered consecutively starting at j = 1, in the
// order below; a family that is not enabled takes up no indices.
//
//   Quadratic : n(n+1)/2 functions x_i * x_k with i <= k, ordered
//               (1,1), (1,2), ..., (1,n), (2,2), ..., (n,n).  Q is
//               symmetric, so diagonal terms carry a coefficient of 0.5.
//   Linear    : n functions x_i.
//   PureExp   : n functions beta_i^x_i.
//   MixedExp  : n^2 functions x_i * beta_k^x_k; the x index i varies
//               fastest.
//   Rational  : 3^n functions.  The offset within the family, written in
//               base 3, selects one factor per class: the least significant
//               digit belongs to class n, the next to class n-1, and so on.
//               Digit 0, 1, 2 gives (N_i - x_i)^2, x_i (N_i - x_i), x_i^2,
//               with N_i = _rationalTruncation[i].  The product is divided
//               by (1 + sum x_i/N_i)^(2(n-1)), which preserves a quadratic
//               bound.
//   Indicator : one function per state with x_i >= 0 and sum x_i <= M,
//               where M is the indicator truncation value, giving C(M+n, n)
//               functions.  Each returns 1 when x matches its state and 0
//               otherwise.
//
// Input:
//   int j       - index of the desired basis function (expected >= 1; not
//                 checked)
//   int x[]     - current state, argument of the basis function; entries
//                 x[1] .. x[n] are used
//   int n       - (optional) number of elements in x starting with x[1];
//                 -1 (the default) means use the stored class count
//
// Output:
//   double      - value of the basis function evaluated at x, or 0.0 if j
//                 is beyond the last enabled function
//
// API NOTE: The original intent was to have the prototype of this function
// match the "phi()" prototype used by ALP programs so that this function
// could be used instead by means of function pointers.  Unfortunately C++
// does not allow this.  As a work around, the instance of this class can be
// made global and another function can be defined that accesses this
// function.

double Basis::phi( int j, int x[], int n )
{
    // if n is not passed then we get the default value -1, signaling that
    // we should use the value from the constructor.

    if ( n < 0 ) n = _num;

    int base  = 0;
    int count = 0;

    if ( _mode & Quadratic )
    {
        count += ( n * ( n + 1 ) ) / 2;
        if ( j <= count )
        {
            // quadratic terms: walk down the rows of the upper triangle,
            // leaving i = row and j = column
            int i = 1;
            while ( j - 1 >= n )
            {
                j -= ( n - i++ );
            }
            return ( i == j ? 0.5 * x[i] * x[j] : double ( x[i] * x[j] ) );
        }
    }

    if ( _mode & Linear )
    {
        base = count;
        count += n;
        if ( j <= count )
        {
            // linear terms
            j -= base;
            return double ( x[j] );
        }
    }

    if ( _mode & PureExp )
    {
        base = count;
        count += n;
        if ( j <= count )
        {
            // pure exponential terms
            j -= base;
            return pow( _beta[j], double( x[j] ) );
        }
    }

    if ( _mode & MixedExp )
    {
        base = count;
        count += n * n;
        if ( j <= count )
        {
            // mixed exponential terms
            j -= ( base + 1 );
            int i = j % n + 1;
            int k = j / n + 1;
            return x[i] * pow( _beta[k], double( x[k] ) );
        }
    }

    if ( _mode & Rational )
    {
        // rational functions; 3^n computed exactly in integer arithmetic
        // (int( pow( 3, n ) ) can truncate one short)
        base = count;
        int threeToN = 1;
        for ( int i = 0; i < n; i++ ) threeToN *= 3;
        count += threeToN;
        if ( j <= count )
        {
            j -= ( base + 1 ); // j = 0 for first rational function
            double numer = 1.0;
            double denom = 1.0;
            for ( int i = n; i >= 1; i-- )
            {
                int N = _rationalTruncation[i];
                denom += double( x[i] ) / N;
                switch ( j % 3 )
                {
                    case 0:
                        numer *= ( N - x[i] ) * ( N - x[i] );
                        break;
                    case 1:
                        numer *= x[i] * ( N - x[i] );
                        break;
                    case 2:
                        numer *= x[i] * x[i];
                        break;
                }
                j /= 3;
            }
            denom = pow( denom, 2 * ( n - 1 ) ); // preserves quadratic bound
            return numer / denom;
        }
    }

    if ( _mode & Indicator )
    {
        // indicator functions
        int M = _indicatorTruncation;
        base = count;
        count = 1;
        for ( int i = 1; i <= n; i++ )
        {
            count = ( count * ( M + i ) ) / i;
        }
        count += base;
        if ( j <= count )
        {
            // Decode the state for index j one class at a time, from class n
            // down to class 1.  For a trial value v of x[i], "next" is the
            // cumulative number of states of classes 1..i-1 that fit in the
            // remaining budget M; v grows until that total passes j.  M is
            // decremented as v grows, so it tracks the remaining budget.
            j -= ( base + 1 );
            for ( int i = n; i >= 1; i-- )
            {
                int offset = 0;
                int next = 1;
                for ( int k = 1; k < i; k++ )
                {
                    next = ( next * ( M + k ) ) / k;
                }
                int v = 0;

                while ( next <= j )
                {
                    M--;
                    offset = next;
                    next = 1;
                    for ( int k = 1; k < i; k++ )
                    {
                        next = ( next * ( M + k ) ) / k;
                    }
                    next += offset;
                    v++;
                }
                j -= offset;
                if ( v != x[i] ) return 0.0;
            }
            return 1.0;
        }
    }

    return 0.0;
}

//-------------------------------------------------------------------------- 

// Method - return a string representation of the jth basis function
//
// Uses the same numbering of basis functions as phi() and describes the
// function phi() evaluates for that index, e.g. "x(1)*x(2)", "x(1)*x(1)/2",
// "x(3)", or "ind ( 0 2 1 )".
//
// Input:
//   int j       - index of the desired basis function (1-based)
//   int n       - (optional) number of classes; -1 (the default) means use
//                 the stored class count
//
// Output:
//   std::string - human-readable formula for the basis function, or
//                 "Unknown Function" if j is beyond the last enabled
//                 function

std::string Basis::phiString( int j, int n )
{
    using std::string;

    // if n is not passed then we get the default value -1, signaling that
    // we should use the value from the constructor.

    if ( n < 0 ) n = _num;

    int base  = 0;
    int count = 0;

    if ( _mode & Quadratic )
    {
        count += ( n * ( n + 1 ) ) / 2;
        if ( j <= count )
        {
            // quadratic terms; same index decoding as phi()
            int i = 1;
            while ( j - 1 >= n )
            {
                j -= ( n - i++ );
            }
            string str = "x(" + toString( i ) + ")*x(" + toString( j ) + ")";
            if ( i == j ) str += "/2";
            return str;
        }
    }

    if ( _mode & Linear )
    {
        base = count;
        count += n;
        if ( j <= count )
        {
            // linear terms
            j -= base;
            string str = "x(" + toString( j ) + ")";
            return str;
        }
    }

    if ( _mode & PureExp )
    {
        base = count;
        count += n;
        if ( j <= count )
        {
            // pure exponential terms
            j -= base;
            string str = toString( _beta[j] ) + "^x(" + toString( j ) + ")";
            return str;
        }
    }

    if ( _mode & MixedExp )
    {
        base = count;
        count += n * n;
        if ( j <= count )
        {
            // mixed exponential terms
            j -= ( base + 1 );
            int i = j % n + 1;
            int k = j / n + 1;
            string str = "x(" + toString( i ) + ")*"
                + toString( _beta[k] ) + "^x(" + toString( k ) + ")";
            return str;
        }
    }

    if ( _mode & Rational )
    {
        // rational functions; 3^n computed exactly in integer arithmetic
        // (int( pow( 3, n ) ) can truncate one short)
        base = count;
        int threeToN = 1;
        for ( int i = 0; i < n; i++ ) threeToN *= 3;
        count += threeToN;
        if ( j <= count )
        {
            j -= ( base + 1 ); // j = 0 for first rational function
            string numer = "";
            // phi() divides by (1 + sum x_i/N_i)^(2(n-1)); print that same
            // exponent so the string matches the value actually computed.
            string denom = ")^" + toString( 2 * ( n - 1 ) );
            for ( int i = n; i >= 1; i-- )
            {
                int N = _rationalTruncation[i];
                denom = "+x(" + toString( i ) + ")/" + toString( N )
                    + denom;
                switch ( j % 3 )
                {
                    case 0:
                        numer = "(" + toString( N )
                            + "-x(" + toString( i ) + "))^2" + numer;
                        break;
                    case 1:
                        numer = "x(" + toString( i ) + ")(" + toString( N )
                            + "-x(" + toString( i ) + "))" + numer;
                        break;
                    case 2:
                        numer = "x(" + toString( i ) + ")^2" + numer;
                        break;
                }
                j /= 3;
                if ( i > 1 ) numer = " " + numer;
            }
            denom = "(1" + denom;
            return numer + " / " + denom;
        }
    }

    if ( _mode & Indicator )
    {
        // indicator functions; same state decoding as phi()
        string str = " )";
        int M = _indicatorTruncation;
        base = count;
        count = 1;
        for ( int i = 1; i <= n; i++ )
        {
            count = ( count * ( M + i ) ) / i;
        }
        count += base;
        if ( j <= count )
        {
            j -= ( base + 1 );
            for ( int i = n; i >= 1; i-- )
            {
                int offset = 0;
                int next = 1;
                for ( int k = 1; k < i; k++ )
                {
                    next = ( next * ( M + k ) ) / k;
                }
                int v = 0;

                while ( next <= j )
                {
                    M--;
                    offset = next;
                    next = 1;
                    for ( int k = 1; k < i; k++ )
                    {
                        next = ( next * ( M + k ) ) / k;
                    }
                    next += offset;
                    v++;
                }
                j -= offset;
                str = " " + toString( v ) + str;
            }
            str = "ind (" + str;
            return str;
        }
    }

    return string( "Unknown Function" );
}

//----------------------------------------------------------------------------

// Private Method - returns a string representation of t
//
// Input:
//   t          - value of any type that can be written to a std::ostream
// Output:
//   string     - the text operator<< produces for t, using default stream
//                formatting (e.g. doubles at the default precision of 6)

template <class T>
std::string Basis::toString( const T& t )
{
    std::ostringstream out;
    out << t;
    return out.str();
}

//----------------------------------------------------------------------------
