import java.io.PrintWriter;
import jamaica.*;

/**
 * the wrapper class for jamaica.Matrix
 *
 * @author Hanhua Feng - hf2048@columbia.edu
 * @version $Id: MxMatrix.java,v 1.17 2003/05/12 23:44:34 hanhua Exp $
 */
class MxMatrix extends MxDataType {
    Matrix mat;
    BitArray mask;

    MxMatrix( Matrix mat ) {
        this.mat = mat;
        mask = null;
    }

    MxMatrix( Matrix mat, BitArray mask ) {
        this.mat = mat;
        this.mask = mask;
    }

    public String typename() {
        return "matrix";
    }

    public MxDataType copy() {
        return new MxMatrix( mat, mask );
    }

    public MxDataType deepCopy() {
        return new MxMatrix( mat.copy(), mask );
    }

    public void print( PrintWriter w ) {
        if ( name != null )
            w.println( name + " = " );
        mat.print( w, 8, 4, 4 );
        if ( mask != null )
        {
            w.print( "  <mask> = " );
            w.println( mask.toString() );
        }
    }

    public void what( PrintWriter w ) {
        w.print( "<" + typename() + ">  " );
        if ( name != null )
            w.print( name + "  " );
        if ( mask != null )
            w.print( " (masked) " );
        w.println( " " + mat.height() + "*" + mat.width() );
    }

    final boolean demoteable() {
        return mat.width() == 1 && mat.height() == 1 && mask == null;
    }

    final MxDouble demote() {
        return new MxDouble( mat.get( 0, 0 ) );
    }

    public MxDataType transpose() {
        return new MxMatrix( mat.transpose() );
    }

    public MxDataType uminus() { 
        return new MxMatrix( mat.uminus() );
    }

    public MxDataType plus( MxDataType b ) { 
        if ( demoteable() )
            return demote().plus( b );
                
        if ( b instanceof MxMatrix )
            return new MxMatrix( mat.plus( ((MxMatrix)b).mat ) );

        return error( b, "+" );
    }

    public MxDataType add( MxDataType b ){
        if ( b instanceof MxMatrix )
        {
            if ( null == mask )                
                mat.selfadd( ((MxMatrix)b).mat );
            else
                mat.assign( mask, mat.plus( ((MxMatrix)b).mat ) );
            return this;
        }
        else if ( demoteable() )
        {
            mat.set( 0, 0, mat.get(0,0) + MxDouble.doubleValue(b) );
            return this;
        }

        return error( b, "+=" );
    }

    public MxDataType minus( MxDataType b ) {
        if ( demoteable() )
            return demote().minus( b );
                
        if ( b instanceof MxMatrix )
            return new MxMatrix( mat.minus( ((MxMatrix)b).mat ) );

        return error( b, "-" );
    }
    
    public MxDataType sub( MxDataType b ) {
        if ( b instanceof MxMatrix )
        {
            if ( null == mask )                
                mat.selfsub( ((MxMatrix)b).mat );
            else
                mat.assign( mask, mat.minus( ((MxMatrix)b).mat ) );
            return this;
        }
        else if ( demoteable() )
        {
            mat.set( 0, 0, mat.get(0,0) - MxDouble.doubleValue(b) );
            return this;
        }

        return error( b, "-=" );    
    }

    public MxDataType times( MxDataType b ) { 
        if ( demoteable() )
            return demote().times( b );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxMatrix( mat.times( ((MxMatrix)b).mat ) );

        return new MxMatrix( mat.times( MxDouble.doubleValue( b ) ) );
    }

    public MxDataType mul( MxDataType b ){
        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
        {
            if ( null == mask )                
                mat.selfrmul( ((MxMatrix)b).mat );
            else
                mat.assign( mask, mat.times( ((MxMatrix)b).mat ) );
            return this;
        }

        mat.selfmul( MxDouble.doubleValue( b ) );
        return this;
    }

    public MxDataType lfracts( MxDataType b ) {
        if ( demoteable() )
            return demote().lfracts( b );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxMatrix( mat.ldivide( ((MxMatrix)b).mat ) );

        return new MxMatrix( mat.times( 1.0 / MxDouble.doubleValue( b ) ) );
    }

    public MxDataType rfracts( MxDataType b ) { 
        if ( demoteable() )
            return demote().rfracts( b );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxMatrix( mat.rdivide( ((MxMatrix)b).mat ) );

        return new MxMatrix( mat.times( 1.0 / MxDouble.doubleValue( b ) ) );
    }

    public MxDataType ldiv( MxDataType b ) {
        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
        {
            if ( null == mask )                
                mat.selfldiv( ((MxMatrix)b).mat );
            else
                mat.assign( mask, mat.ldivide( ((MxMatrix)b).mat ) );
            return this;
        }
        
        mat.selfmul( 1.0 / MxDouble.doubleValue( b ) );
        return this;
    }

    public MxDataType rdiv( MxDataType b ) {
        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
        {
            if ( null == mask )                
                mat.selfrdiv( ((MxMatrix)b).mat );
            else
                mat.assign( mask, mat.rdivide( ((MxMatrix)b).mat ) );
            return this;
        }
        
        mat.selfmul( 1.0 / MxDouble.doubleValue( b ) );
        return this;
    }

    public MxDataType modulus( MxDataType b ) {
        if ( demoteable() )
            return demote().modulus( b );
        return error( b, "%" );
    }

    public MxDataType rem( MxDataType b ){
        if ( demoteable() )
        {
            mat.set( 0, 0, mat.get( 0, 0 ) % MxDouble.doubleValue( b ) );
            return this;
        }
        return error( b, "%=" );
    }

    public MxDataType gt( MxDataType b ) {
        if ( demoteable() )
            return b.lt( demote() );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxBitArray( mat.gt( ((MxMatrix)b).mat ) );

        return new MxBitArray( mat.gt( MxDouble.doubleValue( b ) ) );
    }

    public MxDataType ge( MxDataType b ) {
        if ( demoteable() )
            return b.le( demote() );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxBitArray( mat.ge( ((MxMatrix)b).mat ) );

        return new MxBitArray( mat.ge( MxDouble.doubleValue( b ) ) );
    }

    public MxDataType lt( MxDataType b ) {
        if ( demoteable() )
            return b.gt( demote() );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxBitArray( mat.lt( ((MxMatrix)b).mat ) );

        return new MxBitArray( mat.lt( MxDouble.doubleValue( b ) ) );
    }

    public MxDataType le( MxDataType b ) {
        if ( demoteable() )
            return b.ge( demote() );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxBitArray( mat.le( ((MxMatrix)b).mat ) );

        return new MxBitArray( mat.le( MxDouble.doubleValue( b ) ) );
    }

    public MxDataType eq( MxDataType b ) {
        if ( demoteable() )
            return b.eq( demote() );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxBitArray( mat.eq( ((MxMatrix)b).mat ) );

        return new MxBitArray( mat.eq( MxDouble.doubleValue( b ) ) );
    }

    public MxDataType ne( MxDataType b ) {
        if ( demoteable() )
            return b.ne( demote() );

        if ( b instanceof MxMatrix && !((MxMatrix)b).demoteable() )
            return new MxBitArray( mat.ne( ((MxMatrix)b).mat ) );

        return new MxBitArray( mat.ne( MxDouble.doubleValue( b ) ) );
    }

    private static final int getDim( MxDataType x, boolean rows ) {
        if ( x instanceof MxDouble || x instanceof MxInt )
            return 1;
        if ( x instanceof MxMatrix ) 
            return (rows ? ((MxMatrix)x).mat.height() 
                    : ((MxMatrix)x).mat.width());
        x.error( "array [none number/array element]" );
        return 0;
    }

    public static MxDataType joinVert( MxDataType [] x ) {
        if ( x.length == 0 )
            throw new IllegalArgumentException( "No data in array " );
        int nrow = getDim( x[0], true );
        int ncol = getDim( x[0], false );
        for ( int i=1; i<x.length; i++ )
        {
            if ( getDim( x[i], false ) != ncol )
                return x[i].error( "wrong width" );
            nrow += getDim( x[i], true );
        }

        MxMatrix y = new MxMatrix( new Matrix( nrow, ncol ) );

        int rowp = 0;
        for ( int i=0; i<x.length; i++ )
        {
            if ( x[i] instanceof MxMatrix )
            {
                Matrix b = ((MxMatrix)x[i]).mat;
                int m = b.height();
                y.mat.slice( new Range( rowp, m, 1 ),
                             new Range( 0, ncol, 1 ) ).assign( b );
                rowp += m;
            }
            else
            {
                y.mat.set( rowp, 0, MxDouble.doubleValue( x[i] ) );
                rowp++;
            }
        }
        
        return y;
    }

    public static MxDataType joinHori( MxDataType [] x ) {
        if ( x.length == 0 )
            throw new IllegalArgumentException( "no data in array" );
        int nrow = getDim( x[0], true );
        int ncol = getDim( x[0], false );
        for ( int i=1; i<x.length; i++ )
        {
            if ( getDim( x[i], true ) != nrow )
                return x[i].error( "wrong height" );
            ncol += getDim( x[i], false );
        }

        MxMatrix y = new MxMatrix( new Matrix( nrow, ncol ) );

        int colp = 0;
        for ( int i=0; i<x.length; i++ )
        {
            if ( x[i] instanceof MxMatrix )
            {
                Matrix b = ((MxMatrix)x[i]).mat;
                int n = b.width();
                y.mat.slice( new Range( 0, nrow, 1 ),
                             new Range( colp, n, 1 ) ).assign( b );
                colp += n;
            }
            else
            {
                y.mat.set( 0, colp, MxDouble.doubleValue( x[i] ) );
                colp++;
            }
        }

        return y;
    }
}
