/*
	File:			Mat.cc

	Function:		See header file

	Author(s):		Andrew Willmott

	Copyright:		Copyright (c) 1995-1996, Andrew Willmott

	Notes:			

	Change History:
		31/01/96	ajw		Started
*/

#include "Mat.h"
#include <math.h>
#include <ctype.h>
#include <string.h>
#include <stdarg.h>
#include <iomanip.h>
#include "Array.h"
#include "CopyMat.h"


#pragma mark -
// --- Mat Constructors & Destructors -----------------------------------------


// Construct an empty matrix: zero rows and columns, no element storage.
TMPLMat TMat::Mat() :
	isRef(0),
	rows(0),
	cols(0),
	data(0)
{
}

// Allocate an owned rows x cols matrix. The elements are not explicitly
// initialised here.
TMPLMat TMat::Mat(Int rows, Int cols) : isRef(0), rows(rows), cols(cols)
{
	UInt numElts = rows * cols;

	Assert(numElts > 0, "(Mat) illegal matrix size");
	data = new TMReal[numElts];
	Assert(data != 0, "(Mat) Out of memory");
}

// Allocate an owned rows x cols matrix and initialise it with MakeUnit(k):
// k on the leading diagonal, zero everywhere else.
TMPLMat TMat::Mat(Int rows, Int cols, ZeroOrOne k) : isRef(0), rows(rows), cols(cols)
{
	UInt numElts = rows * cols;

	Assert(numElts > 0, "(Mat) illegal matrix size");
	data = new TMReal[numElts];
	Assert(data != 0, "(Mat) Out of memory");

	MakeUnit(k);
}

// Allocate an owned rows x cols matrix and fill it in row-major order from
// the argument list: elt0 is element (0,0), and the variadic arguments supply
// every following element, row by row. The caller must pass exactly
// rows * cols values in total.
// The double is hardwired here because it is the only type that will work
// with var args and C++ real numbers.
TMPLMat TMat::Mat(Int rows, Int cols, double elt0, ...) : isRef(0), rows(rows), cols(cols)
{
	UInt elts = rows * cols;
	Assert(elts > 0, "(Mat) illegal matrix size");

	data = new TMReal[elts];
	Assert(data != 0, "(Mat) Out of memory");

	va_list args;
	va_start(args, elt0);

	SELF[0][0] = elt0;

	// Row 0 resumes at column 1 (column 0 was elt0); later rows start at 0.
	for (Int r = 0; r < rows; r++)
		for (Int c = (r == 0) ? 1 : 0; c < cols; c++)
			SELF[r][c] = va_arg(args, double);

	va_end(args);
}

// Copy constructor. When m owns real storage, an owned deep copy of its
// elements is made. When m is a reference matrix, or has no storage, only
// the data pointer is copied, so both objects share it.
TMPLMat TMat::Mat(const TMat &m) : isRef(m.isRef), rows(m.rows), cols(m.cols)
{
	if (!isRef && m.data != 0)
	{
		UInt elts = rows * cols;

		data = new TMReal[elts];
		Assert(data != 0, "(Mat) Out of memory");
		memcpy(data, m.data, elts * sizeof(TMReal));
	}
	else
		data = m.data;
}

// Wrap an existing nrows x ncols element buffer without copying it.
// isRef is set, so the destructor will not free ndata.
TMPLMat TMat::Mat(Int nrows, Int ncols, TMReal *ndata) :
	isRef(1),
	rows(nrows),
	cols(ncols),
	data(ndata)
{
}

// Construct an owned matrix holding a copy of the elements of sub-grid
// matrix m.
// isRef must be explicitly cleared: this object owns the buffer allocated
// below. The destructor and SetSize test isRef to decide whether to free
// data, so leaving it uninitialised could leak the buffer or misbehave.
TMPLMat TMat::Mat(const TSGMat &m) : isRef(0), rows(m.Rows()), cols(m.Cols())
{
	data = new TMReal[rows * cols];
	Assert(data != 0, "(Mat) Out of memory");

	CopyMat(SELF, m);
}

// View a 2x2 matrix as a general Mat. No copy is made: data is the pointer
// returned by m.Ref() (presumably m's element storage), and isRef stops the
// destructor from freeing it.
TMPLMat TMat::Mat(TMat2 &m) :
	isRef(1),
	rows(2),
	cols(2),
	data(m.Ref())
{
}

// View a 3x3 matrix as a general Mat. No copy is made: data is the pointer
// returned by m.Ref() (presumably m's element storage), and isRef stops the
// destructor from freeing it.
TMPLMat TMat::Mat(TMat3 &m) :
	isRef(1),
	rows(3),
	cols(3),
	data(m.Ref())
{
}

// View a 4x4 matrix as a general Mat. No copy is made: data is the pointer
// returned by m.Ref() (presumably m's element storage), and isRef stops the
// destructor from freeing it.
TMPLMat TMat::Mat(TMat4 &m) :
	isRef(1),
	rows(4),
	cols(4),
	data(m.Ref())
{
}

// Release the element storage, unless it is borrowed (isRef).
TMPLMat TMat::~Mat()
{
	if (isRef)
		return;		// someone else owns data

	delete[] data;
}

#pragma mark -
// --- Mat Assignment Operators -----------------------------------------------


// Copy m's elements into this matrix via CopyMat. A matrix with no rows is
// first resized to m's dimensions; otherwise its size is not changed here.
TMPLMat TMat &TMat::operator = (const TMat &m)		// Sigh. Oh, for decent templating...
{
	if (rows != 0)
		return(CopyMat(SELF, m));	// already sized: copy straight in

	SetSize(m.Rows(), m.Cols());	// empty: adopt m's shape first
	return(CopyMat(SELF, m));
}
	  
// Copy the elements of sub-grid matrix m into this matrix via CopyMat.
// A matrix with no rows is first resized to m's dimensions.
TMPLMat TMat &TMat::operator = (const TSGMat &m)
{
	if (rows != 0)
		return(CopyMat(SELF, m));	// already sized: copy straight in

	SetSize(m.Rows(), m.Cols());	// empty: adopt m's shape first
	return(CopyMat(SELF, m));
}
	  
// Set every element of the matrix to k and return it.
// NOTE(review): this fills the whole matrix, whereas the
// Mat(rows, cols, ZeroOrOne) constructor uses MakeUnit(k) (k on the diagonal
// only). For k == vl_one the two disagree; confirm which is intended.
TMPLMat TMat &TMat::operator = (ZeroOrOne k)
{
	for (Int r = 0; r < rows; r++)
	{
		for (Int c = 0; c < cols; c++)
			elt(r, c) = k;
	}

	return(SELF);
}

// Copy the elements of the 2x2 matrix m into this matrix via CopyMat.
// A matrix with no rows is first resized to m's dimensions.
TMPLMat TMat &TMat::operator = (const TMat2 &m)
{
	if (rows != 0)
		return(CopyMat(SELF, m));	// already sized: copy straight in

	SetSize(m.Rows(), m.Cols());	// empty: adopt m's shape first
	return(CopyMat(SELF, m));
}
	  
// Copy the elements of the 3x3 matrix m into this matrix via CopyMat.
// A matrix with no rows is first resized to m's dimensions.
TMPLMat TMat &TMat::operator = (const TMat3 &m)
{
	if (rows != 0)
		return(CopyMat(SELF, m));	// already sized: copy straight in

	SetSize(m.Rows(), m.Cols());	// empty: adopt m's shape first
	return(CopyMat(SELF, m));
}
	  
// Copy the elements of the 4x4 matrix m into this matrix via CopyMat.
// A matrix with no rows is first resized to m's dimensions.
TMPLMat TMat &TMat::operator = (const TMat4 &m)
{
	if (rows != 0)
		return(CopyMat(SELF, m));	// already sized: copy straight in

	SetSize(m.Rows(), m.Cols());	// empty: adopt m's shape first
	return(CopyMat(SELF, m));
}

// Pass every element to a.Process(), in row-major order. The whole pass is
// bracketed by a.Start() and a.Stop(). Returns this matrix.
TMPLMat TMat &TMat::operator >> (Action<TMReal> &a)
{
	a.Start();

	for (Int r = 0; r < rows; r++)
	{
		for (Int c = 0; c < cols; c++)
			a.Process(elt(r, c));
	}

	a.Stop();
	return(SELF);
}

// Reallocate this matrix as nrows x ncols, discarding its previous contents.
// Storage it owned is freed; storage it merely referenced (isRef) is left
// alone. The new elements are not initialised here. On return the matrix
// always owns its storage.
TMPLMat void TMat::SetSize(Int nrows, Int ncols)
{
	// Check each dimension on its own. Testing only the unsigned product
	// lets two negative sizes, or a negative product that wraps, slip through.
	Assert(nrows > 0 && ncols > 0, "(Mat::SetSize) Illegal matrix size.");

	UInt elts = nrows * ncols;

	if (!isRef)
		delete[] data;

	rows = nrows;
	cols = ncols;
	data = new TMReal[elts];
	// The fresh buffer belongs to this matrix even if it used to reference
	// external storage. Otherwise the destructor would never free it.
	isRef = 0;

	Assert(data != 0, "(Mat::SetSize) Out of memory");
}

// Set every element of the matrix to vl_zero.
TMPLMat void TMat::MakeZero()
{
	for (Int r = 0; r < rows; r++)
	{
		for (Int c = 0; c < cols; c++)
			elt(r, c) = vl_zero;
	}
}

// Set the matrix to k on the leading diagonal and vl_zero everywhere else.
// Works for non-square matrices: the diagonal stops at the shorter side.
TMPLMat void TMat::MakeUnit(TMReal k)
{
	// Clear everything first...
	for (Int r = 0; r < rows; r++)
		for (Int c = 0; c < cols; c++)
			elt(r, c) = vl_zero;

	// ...then write k along the diagonal.
	for (Int d = 0; d < rows && d < cols; d++)
		elt(d, d) = k;
}

// Set the matrix to the identity pattern: vl_one on the leading diagonal,
// vl_zero everywhere else.
TMPLMat void TMat::MakeUnit()
{
	for (Int r = 0; r < rows; r++)
		for (Int c = 0; c < cols; c++)
			if (r == c)
				elt(r, c) = vl_one;
			else
				elt(r, c) = vl_zero;
}

// Fill every element of the matrix with k.
TMPLMat void TMat::MakeBlock(TMReal k)
{
	for (Int r = 0; r < rows; r++)
	{
		for (Int c = 0; c < cols; c++)
			elt(r, c) = k;
	}
}

// Fill every element of the matrix with vl_one.
TMPLMat void TMat::MakeBlock()
{
	for (Int r = 0; r < rows; r++)
	{
		for (Int c = 0; c < cols; c++)
			elt(r, c) = vl_one;
	}
}

// Return a sub-grid matrix covering the nrows x ncols region of m whose
// top-left element is (top, left). It is built on a pointer into m's
// elements with row span m.Cols(), so no elements are copied.
TMPLMat TSGMat sub(const TMat &m, Int top, Int left, Int nrows, Int ncols)
{
	Assert(left >= 0 && ncols > 0 && left + ncols <= m.Cols(), "(sub(Mat)) illegal subset of matrix");
	Assert(top >= 0 && nrows > 0 && top + nrows <= m.Rows(), "(sub(Mat)) illegal subset of matrix");

	Int offset = top * m.Cols() + left;		// row-major index of (top, left)

	return(TSGMat(nrows, ncols, m.Cols(), m.Ref() + offset));
}

// Return a sub-grid matrix covering the top-left nrows x ncols corner of m.
// It refers directly to m's elements (row span m.Cols()); nothing is copied.
TMPLMat TSGMat sub(const TMat &m, Int nrows, Int ncols)
{
	Assert(ncols > 0 && nrows > 0 && nrows <= m.Rows() && ncols <= m.Cols(), 
		"(sub(Mat)) illegal subset of matrix");

	return(TSGMat(nrows, ncols, m.Cols(), m.Ref()));
}

// Return column i of m as a strided vector over m's elements. It starts at
// element (0, i), steps by one row (m.Cols() elements) and is m.Rows() long.
TMPLMat TMSGVec col(const TMat &m, Int i)
{
	CheckRange(i, 0, m.Cols(), "(col(Mat)) illegal column index");

	Int rowStride = m.Cols();

	return(TMSGVec(m.Rows(), m.Ref() + i, rowStride));
}

// Return diagonal diagNum of m as a strided vector over m's elements, with
// stride Cols() + 1.
//   0        : main diagonal, starting at (0, 0)
//   negative : starts at (-diagNum, 0), below the main diagonal
//   positive : starts at (0, diagNum), above the main diagonal
// The length is clipped to the matrix edge. diagNum itself is not range-checked.
TMPLMat TMSGVec diag(const TMat &m, Int diagNum)
{
	Int stride = m.Cols() + 1;

	if (diagNum > 0)
		return(TMSGVec(Min(m.Cols() - diagNum, m.Rows()), m.Ref() + diagNum, stride));

	if (diagNum < 0)
		return(TMSGVec(Min(m.Rows() + diagNum, m.Cols()), m.Ref() - diagNum * m.Cols(), stride));

	return(TMSGVec(Min(m.Rows(), m.Cols()), m.Ref(), stride));
}

#pragma mark -
// --- Mat Compound Assignment Operators --------------------------------------


// Add n to m in place, one row at a time using the row-vector +=, and
// return m.
TMPLMat TMat &operator += (TMat &m, const TMat &n)
{
	Assert(n.Rows() == m.Rows(), "(Mat::+=) matrix rows don't match");	

	for (Int r = 0; r < m.Rows(); r++)
		m[r] += n[r];

	return(m);
}

// Subtract n from m in place, one row at a time using the row-vector -=,
// and return m.
TMPLMat TMat &operator -= (TMat &m, const TMat &n)
{
	Assert(n.Rows() == m.Rows(), "(Mat::-=) matrix rows don't match");	

	for (Int r = 0; r < m.Rows(); r++)
		m[r] -= n[r];

	return(m);
}

// In-place matrix product m = m * n; returns m.
// n must be square with side m.Cols() so that m keeps its shape. This
// Assert checks the column counts; operator * checks m.Cols() == n.Rows().
TMPLMat TMat &operator *= (TMat &m, const TMat &n)
{
	Assert(m.Cols() == n.Cols(), "(Mat::*=) matrix columns don't match");	

	// Form the complete product before writing any of it back. Updating m a
	// row at a time (m[i] *= n) is wrong when n shares storage with m
	// (e.g. m *= m): later rows would be multiplied by a matrix whose earlier
	// rows had already been overwritten.
	// An empty m has nothing to update, and operator * cannot build an empty
	// result, so skip it.
	if (m.Rows() > 0)
		m = m * n;

	return(m);
}

// Multiply each row of m by the scalar s in place, and return m.
TMPLMat TMat &operator *= (TMat &m, TMReal s)
{
	for (Int r = 0; r < m.Rows(); r++)
		m[r] *= s;

	return(m);
}

// Divide each row of m by the scalar s in place, and return m.
TMPLMat TMat &operator /= (TMat &m, TMReal s)
{
	for (Int r = 0; r < m.Rows(); r++)
		m[r] /= s;

	return(m);
}

#pragma mark -
// --- Mat Comparison Operators -----------------------------------------------


// Return 1 when every row of m compares equal to the matching row of n,
// otherwise 0. Stops at the first differing row.
TMPLMat Bool operator == (const TMat &m, const TMat &n)
{
	Assert(n.Rows() == m.Rows(), "(Mat::==) matrix rows don't match");	

	Int r = 0;

	// Skip over matching rows.
	while (r < m.Rows() && !(m[r] != n[r]))
		r++;

	return((r == m.Rows()) ? 1 : 0);
}

// Return 1 as soon as some row of m differs from the matching row of n,
// otherwise 0.
TMPLMat Bool operator != (const TMat &m, const TMat &n)
{
	Assert(n.Rows() == m.Rows(), "(Mat::!=) matrix rows don't match");	

	Int r = 0;

	// Skip over matching rows.
	while (r < m.Rows() && !(m[r] != n[r]))
		r++;

	return((r < m.Rows()) ? 1 : 0);
}

#pragma mark -
// --- Mat Arithmetic Operators -----------------------------------------------


// Return a new matrix m + n, computed row by row.
TMPLMat TMat operator + (const TMat &m, const TMat &n)
{
	Assert(n.Rows() == m.Rows(), "(Mat::+) matrix rows don't match");	

	TMat	sum(m.Rows(), m.Cols());

	for (Int r = 0; r < m.Rows(); r++)
		sum[r] = m[r] + n[r];

	return(sum);
}

// Return a new matrix m - n, computed row by row.
TMPLMat TMat operator - (const TMat &m, const TMat &n)
{
	Assert(n.Rows() == m.Rows(), "(Mat::-) matrix rows don't match");	

	TMat	difference(m.Rows(), m.Cols());

	for (Int r = 0; r < m.Rows(); r++)
		difference[r] = m[r] - n[r];

	return(difference);
}

// Return a new matrix holding the negation of m, computed row by row.
TMPLMat TMat operator - (const TMat &m)
{
	TMat	negated(m.Rows(), m.Cols());

	for (Int r = 0; r < m.Rows(); r++)
		negated[r] = -m[r];

	return(negated);
}

// Return the matrix product m * n. Row r of the result is row r of m
// (a row vector) multiplied by n.
TMPLMat TMat operator * (const TMat &m, const TMat &n)
{
	Assert(m.Cols() == n.Rows(), "(Mat::*m) matrix cols don't match");	

	TMat	product(m.Rows(), n.Cols());

	for (Int r = 0; r < m.Rows(); r++)
		product[r] = m[r] * (TMMat &) n;

	return(product);
}

// Matrix times column vector: element r of the result is m[r] dot v.
TMPLMat TVec operator * (const TMat &m, const TVec &v)
{
	Assert(m.Cols() == v.Elts(), "(Mat::*v) matrix and vector sizes don't match");

	TVec	product(m.Rows());

	for (Int r = 0; r < m.Rows(); r++)
		product[r] = m[r] dot v;

	return(product);
}

// Return a new matrix equal to m scaled by s, computed row by row.
TMPLMat TMat operator * (const TMat &m, TMReal s)
{
	TMat	scaled(m.Rows(), m.Cols());

	for (Int r = 0; r < m.Rows(); r++)
		scaled[r] = m[r] * s;

	return(scaled);
}

// Return a new matrix equal to m divided by s, computed row by row.
TMPLMat TMat operator / (const TMat &m, TMReal s)
{
	TMat	quotient(m.Rows(), m.Cols());

	for (Int r = 0; r < m.Rows(); r++)
		quotient[r] = m[r] / s;

	return(quotient);
}

#pragma mark -
// --- Mat Mat-Vec Functions --------------------------------------------------


// Row vector times matrix: result[c] = sum over r of v[r] * m(r, c).
// NOTE(review): the accumulator is TReal rather than TMReal (which trace()
// uses). Confirm this is the intended precision.
TMPLMat TVec operator * (const TVec &v, const TMat &m)			// v * m
{
	Assert(v.Elts() == m.Rows(), "(Mat::v*m) vector/matrix sizes don't match");

	TVec 	result(m.Cols());

	for (Int c = 0; c < m.Cols(); c++)
	{
		TReal	colSum = 0;
		Int		r = 0;

		while (r < v.Elts())
		{
			colSum += v[r] * m.elt(r, c);
			r++;
		}

		result[c] = colSum;
	}

	return(result);
}

// In-place row-vector product v = v * m; returns v.
// The full product is formed before v is overwritten.
TMPLMat TVec &operator *= (TVec &v, const TMat &m)				// v *= m
{
	v = v * m;

	return(v);
}

#pragma mark -
// --- Mat Special Functions --------------------------------------------------


// Return a new Cols() x Rows() matrix holding the transpose of m.
TMPLMat TMat trans(const TMat &m)
{
	TMat	transposed(m.Cols(), m.Rows());

	for (Int c = 0; c < m.Cols(); c++)
		for (Int r = 0; r < m.Rows(); r++)
			transposed.elt(c, r) = m.elt(r, c);

	return(transposed);
}

// Return the sum of the leading-diagonal elements of m. For a non-square
// matrix, only the Min(Rows, Cols) elements that lie on the diagonal are
// summed. This matches the length of diag(m, 0).
TMPLMat TMReal trace(const TMat &m)
{
	Int		i;
	Int		diagLen = Min(m.Rows(), m.Cols());
	TMReal 	result = 0;

	// Bounding by Min() keeps elt(i, i) inside the matrix. Looping to Rows()
	// alone summed wrong elements, and could read past the end of the
	// storage, whenever Rows > Cols.
	for (i = 0; i < diagLen; i++) 
		result += m.elt(i,i);

	return(result);
}

#pragma mark -
// --- Mat Input & Output -----------------------------------------------------


// Write m to s one row per line. The field width in effect on entry is
// re-applied before each row.
TMPLMat ostream	&operator << (ostream &s, const TMat &m)
{
	Int fieldWidth = s.width();

	for (Int r = 0; r < m.Rows(); r++)
	{
		s << setw(fieldWidth);
		s << m[r] << endl;
	}

	return(s);
}

// Read a matrix from s. The input is first parsed into an array of arrays,
// one inner array per matrix row. m is then resized to
// (number of rows) x (length of the first row), and each row is copied in.
// Every row must have the same length (checked by Assert).
// If no rows are read, or the first row is empty, failbit is set on s and
// m is left unchanged.
TMPLMat istream	&operator >> (istream &s, TMat &m)
{
	Array< Array<TMReal> > array;	
	Int		i;

	s >> array;						// Read input into array of arrays

	// Check before touching array[0]: with no rows it does not exist. In any
	// case, SetSize rejects a zero-sized matrix.
	if (array.NumItems() == 0 || array[0].NumItems() == 0)
	{
		s.clear(s.rdstate() | ios::failbit);
		return(s);
	}

	m.SetSize(array.NumItems(), array[0].NumItems());

	for (i = 0; i < m.Rows(); i++)	// copy the result into m
	{
		Assert(m.Cols() == array[i].NumItems(), "(Mat/>>) different sized matrix rows");
		m[i] = TMVec(m.Cols(), array[i].Ref());
	}

	return(s);
}


