
// vbmm2.cpp
// VoxBo matrix multiplication
// Copyright (c) 1998-2002 by The VoxBo Development Team

// VoxBo is free software: you can redistribute it and/or modify it
// under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// 
// VoxBo is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// General Public License for more details.
// 
// You should have received a copy of the GNU General Public License
// along with VoxBo.  If not, see <http://www.gnu.org/licenses/>.
// 
// For general information on VoxBo, including the latest complete
// source code and binary distributions, manual, and associated files,
// see the VoxBo home page at: http://www.voxbo.org/
//
// original version written by Dan Kimberg

using namespace std;



// This code provides some simple matrix operations for VoxBo.
// Some of the functions are optimized to operate in parallel
// and/or on large matrices that won't fit in memory.  The
// implementation notes below are strictly for your entertainment.

#define MIN(a,b) ((a) < (b) ? (a) : (b))
#define MAX(a,b) ((a) < (b) ? (b) : (a))

#include <stdio.h>
#include <math.h>
#include <fcntl.h>
#include <unistd.h>
#include <string.h>
#include <signal.h>
#include <ctype.h>
#include <string>
#include "vbutil.h"
#include "vbio.h"

// Forward declarations for the sub-command handlers dispatched from
// main().  Each handler receives the command-line arguments that follow
// the command switch.  The int-returning handlers yield a status that
// main() uses as the process exit code; the void handlers (print,
// printsub, compare) do not report a status.
void do_print(tokenlist &args);
void do_printsub(tokenlist &args);
void do_compare(tokenlist &args);
int do_ident(tokenlist &args);
int do_zeros(tokenlist &args);
int do_random(tokenlist &args);
int do_invert(tokenlist &args);
int do_add(tokenlist &args);
int do_subtract(tokenlist &args);
int do_xyt(tokenlist &args);
int do_xy(tokenlist &args);
int do_imxy(tokenlist &args);
int do_f3(tokenlist &args);
int do_pinv(tokenlist &args);
int do_pca(tokenlist &args);
int do_assemblecols(tokenlist &args);
int do_assemblerows(tokenlist &args);
int do_xyz(tokenlist &args);
// Prints the usage summary.
void vbmm_help();

// Program entry point.  The first command-line argument selects the
// operation (e.g. "-xy", "-invert"); the remaining arguments are handed
// to the matching do_*() handler.
//
// With no arguments the help text is printed and the process exits with
// status 0.  Otherwise the exit status is the handler's return value
// (the void handlers -print, -printsub and -compare always yield 0).
// An unrecognized command is reported, the help text is printed, and
// the process exits with status 5 instead of silently succeeding.
int
main(int argc,char *argv[])
{
  tokenlist args;
  int err=0;
  string cmd;

  args.Transfer(argc-1,argv+1);
  if (args.size()<1) {
    vbmm_help();
    exit(0);
  }

  // peel off the command switch; handlers only see their own arguments
  cmd=args[0];
  args.DeleteFirst();

  if (cmd=="-xy")
    err=do_xy(args);
  else if (cmd=="-imxy")
    err=do_imxy(args);
  else if (cmd=="-f3")
    err=do_f3(args);
  else if (cmd=="-pinv")
    err=do_pinv(args);
  else if (cmd=="-pca")
    err=do_pca(args);
  else if (cmd=="-xyz")
    err=do_xyz(args);
  else if (cmd=="-xyt")
    err=do_xyt(args);
  else if (cmd=="-add")
    err=do_add(args);
  else if (cmd=="-subtract")
    err=do_subtract(args);
  else if (cmd=="-invert")
    err=do_invert(args);
  else if (cmd=="-ident")
    err=do_ident(args);
  else if (cmd=="-zeros")
    err=do_zeros(args);
  else if (cmd=="-assemblecols")
    err=do_assemblecols(args);
  else if (cmd=="-assemblerows")
    err=do_assemblerows(args);
  else if (cmd=="-random")
    err=do_random(args);
  else if (cmd=="-print")
    do_print(args);
  else if (cmd=="-printsub")
    do_printsub(args);
  else if (cmd=="-compare")
    do_compare(args);
  else {
    // previously an unknown switch fell through and exited with 0
    printf("[E] vbmm2: unrecognized command %s\n",cmd.c_str());
    vbmm_help();
    err=5;
  }

  exit(err);
}

// ident args: name size

// Creates a square identity matrix and writes it as a MAT1 file.
//   args[0]  output file name
//   args[1]  matrix size (rows == columns), must be at least 1
// Returns 0 on success, 5 if arguments are missing, 10 if the size is
// not positive, and 110 if the matrix could not be written.
int
do_ident(tokenlist &args)
{
  if (args.size() < 2) {
    printf("[E] vbmm2: need a name and a size to create a matrix.\n");
    return 5;
  }

  int r = strtol(args[1]);
  // reject 0 as well as negatives, matching the message below
  if (r < 1) {
    printf("[E] vbmm2: size for identity matrix must be > 0.\n");
    return 10;
  }

  printf("[I] vbmm2: creating identity matrix %s of size %d.\n",args[0].c_str(),r);

  VBMatrix target(r,r);
  target.ident();
  // WriteMAT1() yields non-zero on failure; propagate that to the caller
  if (target.WriteMAT1(args[0])) {
    printf("[E] vbmm2: identity matrix %s (%dx%d) not created.\n",args[0].c_str(),r,r);
    return 110;
  }
  printf("[I] vbmm2: identity matrix %s (%dx%d) created.\n",args[0].c_str(),r,r);
  return 0;
}

// zeros args: name size

// Creates a zero-filled matrix and writes it as a MAT1 file.
//   args[0]  output file name
//   args[1]  first dimension (also used as the second when args[2] is
//            absent, giving a square matrix)
//   args[2]  optional second dimension
// Returns 0 on success, 5 if arguments are missing, 10 if a dimension
// is not positive, and 110 if the matrix could not be written.
//
// NOTE(review): the matrix is constructed as (c,r) here, whereas
// do_random() constructs (r,c) from the same argument layout, and the
// help text lists the arguments as <cols> [rows].  The construction is
// left untouched; confirm which ordering is intended.
int
do_zeros(tokenlist &args)
{
  if (args.size() < 2) {
    printf("[E] vbmm2: need a name and a size to create a matrix.\n");
    return 5;
  }

  int r,c;
  c = r = strtol(args[1]);
  if (args.size() > 2)         // needed only if it's not square
    c = strtol(args[2]);
  // reject 0 as well as negatives, matching the message below
  if (r < 1 || c < 1) {
    printf("[E] vbmm2: dimensions for zero matrix must be > 0.\n");
    return 10;
  }

  printf("vbmm2: creating %dx%d zero matrix %s.\n",r,c,args[0].c_str());

  VBMatrix target(c,r);
  target.zero();
  // WriteMAT1() yields non-zero on failure; propagate that to the caller
  if (target.WriteMAT1(args[0])) {
    printf("vbmm2: zero matrix %s (%dx%d) not created.\n",args[0].c_str(),r,c);
    return 110;
  }
  printf("vbmm2: zero matrix %s (%dx%d) created.\n",args[0].c_str(),r,c);
  return 0;
}

// Creates a matrix of random numbers and writes it as a MAT1 file.
//   args[0]  output file name
//   args[1]  first dimension (also used as the second when args[2] is
//            absent, giving a square matrix)
//   args[2]  optional second dimension
// Returns 0 on success, 5 if arguments are missing, 10 if a dimension
// is not positive, and 110 if the matrix could not be written.
//
// NOTE(review): the matrix is constructed as (r,c) here, whereas
// do_zeros() constructs (c,r) from the same argument layout, and the
// help text lists the arguments as <cols> [rows].  The construction is
// left untouched; confirm which ordering is intended.
int
do_random(tokenlist &args)
{
  if (args.size() < 2) {
    printf("[E] vbmm2: need a name and a size to create a matrix.\n");
    return 5;
  }

  int r,c;
  c = r = strtol(args[1]);
  if (args.size() > 2)         // needed only if it's not square
    c = strtol(args[2]);
  // reject 0 as well as negatives, matching the message below
  if (r < 1 || c < 1) {
    printf("[E] vbmm2: dimensions for random matrix must be > 0.\n");
    return 10;
  }

  printf("vbmm2: creating random %dx%d matrix %s.\n",r,c,args[0].c_str());

  VBMatrix target(r,c);
  target.random();
  // WriteMAT1() yields non-zero on failure; propagate that to the caller
  if (target.WriteMAT1(args[0])) {
    printf("[E] vbmm2: error writing random %dx%d matrix %s\n",r,c,args[0].c_str());
    return 110;
  }
  printf("[I] vbmm2: random %dx%d matrix %s created\n",r,c,args[0].c_str());
  return 0;
}

// args: in1 in2 out

// Computes in1 - in2 element-wise and writes the result.
//   args[0]  first input matrix file (minuend)
//   args[1]  second input matrix file (subtrahend)
//   args[2]  output matrix file
// Returns 0 on success, 5 on a usage error, 101/102 if the first/second
// matrix is empty after loading, 103 if the dimensions differ, and 110
// if the result could not be written.
int
do_subtract(tokenlist &args)
{
  if (args.size() != 3) {
    printf("[E] vbmm2: usage: vbmm2 -subtract in1 in2 out\n");
    return 5;
  }

  VBMatrix mat1(args[0]);
  VBMatrix mat2(args[1]);

  // a matrix with a zero dimension is treated as a failed load
  if (mat1.m == 0 || mat1.n == 0) {
    printf("[E] vbmm2: first matrix was bad.\n");
    return 101;
  }
  if (mat2.m == 0 || mat2.n == 0) {
    printf("[E] vbmm2: second matrix was bad.\n");
    return 102;
  }
  if (mat1.m != mat2.m || mat1.n != mat2.n) {
    fprintf(stderr,"[E] vbmm2: matrix dimensions don't match.\n");
    return 103;
  }

  printf("[I] vbmm2: subtracting matrix %s from matrix %s.\n",args[1].c_str(),args[0].c_str());
  mat1-=mat2;
  // don't claim success if the output file couldn't be written
  if (mat1.WriteMAT1(args[2])) {
    printf("[E] vbmm2: couldn't write matrix %s.\n",args[2].c_str());
    return 110;
  }
  printf("[I] vbmm2: done.\n");

  return 0;
}

// Computes in1 + in2 element-wise and writes the result.
//   args[0]  first input matrix file
//   args[1]  second input matrix file
//   args[2]  output matrix file
// Returns 0 on success, 5 on a usage error, 101/102 if the first/second
// matrix is empty after loading, 103 if the dimensions differ, and 110
// if the result could not be written.
int
do_add(tokenlist &args)
{
  if (args.size() != 3) {
    printf("[E] vbmm2: usage: vbmm2 -add in1 in2 out\n");
    return 5;
  }

  VBMatrix mat1(args[0]);
  VBMatrix mat2(args[1]);

  // a matrix with a zero dimension is treated as a failed load
  if (mat1.m == 0 || mat1.n == 0) {
    printf("[E] vbmm2: first matrix was bad.\n");
    return 101;
  }
  if (mat2.m == 0 || mat2.n == 0) {
    printf("[E] vbmm2: second matrix was bad.\n");
    return 102;
  }
  if (mat1.m != mat2.m || mat1.n != mat2.n) {
    fprintf(stderr,"[E] vbmm2: matrix dimensions don't match.\n");
    return 103;
  }

  printf("[I] vbmm2: adding matrix %s and matrix %s.\n",args[0].c_str(),args[1].c_str());
  mat1+=mat2;
  // don't claim success if the output file couldn't be written
  if (mat1.WriteMAT1(args[2])) {
    printf("[E] vbmm2: couldn't write matrix %s.\n",args[2].c_str());
    return 110;
  }
  printf("[I] vbmm2: done.\n");

  return 0;
}

// Compares two matrices cell by cell and prints a summary of the
// differences: the count of differing cells, the mean absolute
// difference over those cells, and the largest absolute difference,
// reported overall, for the diagonal (i==j), and for off-diagonal
// cells.
//   args[0]  first matrix file
//   args[1]  second matrix file
// Prints an error and returns early on a usage error, if either matrix
// is empty after loading, or if the dimensions differ.
void
do_compare(tokenlist &args)
{
  if (args.size()!=2) {
    printf("[E] vbmm2: usage: vbmm2 -compare <mat1> <mat2>\n");
    return;
  }
  VBMatrix mat1(args[0]);
  VBMatrix mat2(args[1]);
  // Without these checks two unreadable files would both load as empty
  // matrices, pass the dimension checks, and be reported as identical.
  if (mat1.m==0 || mat1.n==0) {
    printf("[E] vbmm2: first matrix was bad.\n");
    return;
  }
  if (mat2.m==0 || mat2.n==0) {
    printf("[E] vbmm2: second matrix was bad.\n");
    return;
  }
  if (mat1.m!=mat2.m) {
    printf("[E] vbmm2: matrices have different row count\n");
    return;
  }
  if (mat1.n!=mat2.n) {
    printf("[E] vbmm2: matrices have different column count\n");
    return;
  }

  int diffs_all=0,diffs_diag=0,diffs_off=0;
  // the totals_* values accumulate sums here and become means below
  double totals_all=0.0,totals_diag=0.0,totals_off=0.0;
  double max_all=0.0,max_diag=0.0,max_off=0.0;

  for (int i=0; i<mat1.m; i++) {
    for (int j=0; j<mat1.n; j++) {
      double diff=fabs(mat1(i,j)-mat2(i,j));
      if (diff==0.0) continue;   // only exactly-equal cells count as the same
      diffs_all++;
      totals_all+=diff;
      if (diff>max_all) max_all=diff;
      if (i==j) {
        diffs_diag++;
        totals_diag+=diff;
        if (diff>max_diag) max_diag=diff;
      }
      else {
        diffs_off++;
        totals_off+=diff;
        if (diff>max_off) max_off=diff;
      }
    }
  }

  // convert sums to means, guarding against division by zero
  if (diffs_all)
    totals_all/=(double)diffs_all;
  if (diffs_diag)
    totals_diag/=(double)diffs_diag;
  if (diffs_off)
    totals_off/=(double)diffs_off;

  if (diffs_all==0)
    printf("[I] vbmm2: matrices are identical\n");
  else {
    printf("[I] vbmm2: %d total cells\n",mat1.m*mat1.n);
    printf("[I] vbmm2:    total: %d different cells, mean abs diff %g, max diff %g\n",diffs_all,totals_all,max_all);
    printf("[I] vbmm2: diagonal: %d different cells, mean abs diff %g, max diff %g\n",diffs_diag,totals_diag,max_diag);
    printf("[I] vbmm2: off-diag: %d different cells, mean abs diff %g, max diff %g\n",diffs_off,totals_off,max_off);
  }
}

// xy args: in1 in2 out [col1 col2]

// Multiplies in1 by in2 (or by a range of in2's columns) and writes the
// product.
//   args[0]  first input matrix file, read in full
//   args[1]  second input matrix file
//   args[2]  output matrix file
//   args[3], args[4]  optional first and last column of in2 to use
// When a column range other than the whole of in2 is requested, the
// suffix "_<c1>_<c2>" (each zero-padded to 8 digits) is appended to the
// output name so that the part can be assembled later.
// Returns 0 on success, 5 on a usage error, 100 if a header or the
// first matrix couldn't be read, 101 if the second matrix couldn't be
// read, and 110 if the product could not be written.
int
do_xy(tokenlist &args)
{
  int c1,c2;

  if (args.size() != 3 && args.size() != 5) {
    printf("[E] vbmm2: usage: vbmm2 -xy in1 in2 out [c1 c2]\n");
    return 5;
  }
  // headers first, so the default column range of in2 is known before
  // any of its data is loaded
  VBMatrix mat1,mat2;
  mat1.ReadMAT1Header(args[0]);
  mat2.ReadMAT1Header(args[1]);
  if (mat1.m==0||mat1.n==0||mat2.m==0||mat2.n==0) {
    printf("[E] vbmm2: couldn't read matrix headers\n");
    return 100;
  }
  if (mat1.ReadMAT1(args[0])) {
    printf("[E] vbmm2: first matrix was bad.\n");
    return 100;
  }

  if (args.size()==5) {
    c1 = strtol(args[3]);
    c2 = strtol(args[4]);
  }
  else {
    c1=0;
    c2=mat2.cols-1;
  }

  // figure out the outfile name for this part, if we're not doing the whole thing
  string outname=args[2];
  char tmps[128];
  if (c1!=0 || c2!=mat2.cols-1) {
    snprintf(tmps,sizeof(tmps),"_%08d_%08d",c1,c2);
    outname += tmps;
  }

  // NOTE(review): the -1,-1 pair presumably selects all rows -- confirm
  // against ReadMAT1()'s declaration
  if (mat2.ReadMAT1(args[1],-1,-1,c1,c2)) {
    printf("[E] vbmm2: second matrix was bad.\n");
    return 101;
  }

  printf("vbmm2: multiplying matrix %s by matrix %s (cols %d to %d).\n",
         args[0].c_str(),args[1].c_str(),c1,c2);
  mat1*=mat2;
  // report a failed write through the exit status, not just the log
  if (mat1.WriteMAT1(outname)) {
    printf("[E] vbmm2: failed!\n");
    return 110;
  }
  printf("[I] vbmm2: done.\n");

  return 0;
}

// Multiplies in1 by the transpose of in2 (or of a range of in2's rows)
// and writes the product.
//   args[0]  first input matrix file, read in full
//   args[1]  second input matrix file
//   args[2]  output matrix file
//   args[3], args[4]  optional first and last row of in2 to use
// When a range other than the whole of in2 is requested, the suffix
// "_<c1>_<c2>" (each zero-padded to 5 digits) is appended to the output
// name so that the part can be assembled later.
// Returns 0 on success, 5 on a usage error, 100 if a header or the
// first matrix couldn't be read, 101 if the second matrix couldn't be
// read, and 110 if the product could not be written.
int
do_xyt(tokenlist &args)
{
  int c1,c2;

  if (args.size() != 3 && args.size() != 5) {
    printf("[E] vbmm2: usage: vbmm2 -xyt in1 in2 out [c1 c2]\n");
    return 5;
  }

  // headers first, so the default row range of in2 is known before any
  // of its data is loaded
  VBMatrix mat1,mat2;
  mat1.ReadMAT1Header(args[0]);
  mat2.ReadMAT1Header(args[1]);
  if (mat1.m==0||mat1.n==0||mat2.m==0||mat2.n==0) {
    printf("[E] vbmm2: couldn't read matrix headers\n");
    return 100;
  }

  if (mat1.ReadMAT1(args[0])) {
    printf("[E] vbmm2: first matrix was bad.\n");
    return 100;
  }
  if (args.size()==5) {
    c1 = strtol(args[3]);
    c2 = strtol(args[4]);
  }
  else {
    c1 = 0;
    c2 = mat2.rows-1;
  }

  // figure out the outfile name
  string outname=args[2];
  char tmps[128];
  if (c1!=0 || c2!=mat2.rows-1) {
    snprintf(tmps,sizeof(tmps),"_%05d_%05d",c1,c2);
    outname += tmps;
  }

  // NOTE(review): the trailing -1,-1 pair presumably selects all
  // columns -- confirm against ReadMAT1()'s declaration
  if (mat2.ReadMAT1(args[1],c1,c2,-1,-1)) {
    printf("[E] vbmm2: second matrix was bad\n");
    return 101;
  }

  printf("[I] vbmm2: multiplying matrix %s by matrix %s (cols %d to %d).\n",
         args[0].c_str(),args[1].c_str(),c1,c2);
  // flag in2 as transposed rather than physically transposing it
  mat2.transposed=1;
  mat1*=mat2;
  // report a failed write through the exit status, not just the log
  if (mat1.WriteMAT1(outname)) {
    printf("[E] vbmm2: failed\n");
    return 110;
  }
  printf("[I] vbmm2: done\n");

  return 0;
}

// Computes the triple product in1 * in2 * in3 and writes the result.
//   args[0..2]  the three input matrix files, in multiplication order
//   args[3]     output matrix file
// The third matrix is only loaded after the first product has been
// formed.
// Returns 0 on success, 5 on a usage error, 100 if an input matrix is
// invalid, 104 if inner dimensions don't agree, and 110 if the result
// could not be written.
int
do_xyz(tokenlist &args)
{
  if (args.size() != 4) {
    printf("vbmm2: usage: vbmm2 -xyz in1 in2 in3 out\n");
    return 5;
  }

  VBMatrix mat1(args[0]);
  VBMatrix mat2(args[1]);

  if (!(mat1.dataValid() && mat2.dataValid())) {
    printf("[E] vbmm2: bad input matrix\n");
    return (100);
  }
  if (mat1.n != mat2.m) {
    printf("[E] vbmm2: bad matrix dimensions for xyz.\n");
    return (104);
  }

  printf("vbmm2: multiplying matrix %s by matrix %s.\n",args(0),args(1));
  mat1*=mat2;
  printf("vbmm2: multiplying matrix %s by matrix %s.\n",args(1),args(2));
  VBMatrix mat3(args[2]);
  // validate the third input the same way as the first two
  if (!mat3.dataValid()) {
    printf("[E] vbmm2: bad input matrix\n");
    return (100);
  }
  if (mat1.n != mat3.m) {
    printf("[E] vbmm2: bad matrix dimensions for xyz.\n");
    return (104);
  }
  mat1*=mat3;

  // only announce completion once the result is actually on disk
  if (mat1.WriteMAT1(args[3])) {
    printf("[E] vbmm2: couldn't write matrix %s\n",args(3));
    return (110);
  }
  printf("vbmm2: done.\n");

  return 0;
}

// Placeholder for assembling a matrix from row-wise parts.  The
// operation is not implemented: the arguments are ignored, an error is
// printed, and status 100 is always returned.
int
do_assemblerows(tokenlist &)
{
  printf("[E] vbmm2: -assemblerows is not implemented\n");
  return (100);  // error!
}

// Assembles a matrix from column-wise part files.
//   args[0]  output matrix file; the parts are every file matching the
//            pattern "<args[0]>_*_*"
// The parts' columns are concatenated in the order the glob returns
// them (NOTE(review): this presumably relies on vglob returning names
// sorted so that the zero-padded suffixes come out in column order --
// confirm).  All parts must have the same number of rows.  The part
// files are removed only after the assembled matrix has been written
// successfully, so a failed write never destroys the inputs.
// Returns 0 on success, 100 on a usage error, 101 if no parts were
// found, 102 if a part's header is invalid, 103 if a part has the wrong
// row count or the assembled size is invalid, 104 if a part's data
// couldn't be read, and 110 if the result could not be written.
int
do_assemblecols(tokenlist &args)
{
  if (args.size() < 1) {
    printf("[E] vbmm2: usage: vbmm2 -assemblecols out\n");
    return (100);
  }
  vglob vg(args[0]+"_*_*");
  if (vg.size() < 1) {
    printf("[E] vbmm2: no parts found for %s\n",args(0));
    return (101);
  }
  int rows=0,cols=0;
  // first read all the headers
  for (size_t i=0; i<vg.size(); i++) {
    VBMatrix tmp;
    tmp.ReadMAT1Header(vg[i]);
    if (!tmp.headerValid()) {
      printf("[E] vbmm2: invalid matrix in assemble list\n");
      return (102);
    }
    // the first part fixes the row count every other part must match
    if (rows==0)
      rows=tmp.m;
    cols+=tmp.n;
    if (rows != tmp.m) {
      printf("[E] vbmm2: wrong-sized matrix %s in assemble list\n",vg[i].c_str());
      return (103);
    }
  }
  if (rows < 1 || cols < 1) {
    printf("[E] vbmm2: invalid size for assembled matrix: %d x %d\n",rows,cols);
    return (103);
  }
  VBMatrix newmat(rows,cols);

  // now load each part and copy its columns into place
  int ind=0;
  for (size_t i=0; i<vg.size(); i++) {
    VBMatrix tmp(vg[i]);
    if (!tmp.dataValid()) {
      printf("[E] vbmm2: couldn't read matrix %s in assemble list\n",vg[i].c_str());
      return (104);
    }
    for (int j=0; j<tmp.n; j++) {
      VB_Vector vv=tmp.GetColumn(j);
      newmat.SetColumn(ind,vv);
      ind++;
    }
  }

  // write before unlinking: the parts are the only copy of the data
  if (newmat.WriteMAT1(args[0])) {
    printf("[E] vbmm2: couldn't write assembled matrix %s\n",args(0));
    return (110);
  }
  // try to unlink the actual files now that we're merged
  for (size_t i=0; i<vg.size(); i++)
    unlink(vg[i].c_str());
  return (0);  // no error!
}

// Computes I - (in1 * in2) and writes the result: the product is
// negated and 1 is added to every cell where row index == column index.
//   args[0]  first input matrix file
//   args[1]  second input matrix file
//   args[2]  output matrix file
// Five arguments are also accepted; the last two are not used here.
// Returns 0 on success, 100 on a usage error, 101/102 if the
// first/second matrix is empty after loading, 103 if the inner
// dimensions don't agree, and 110 if the result could not be written.
int
do_imxy(tokenlist &args)
{
  if (args.size() != 3 && args.size() != 5) {
    printf("[E] vbmm2: usage: vbmm -imxy in1 in2 out\n");
    return (100);
  }

  VBMatrix mat1(args[0]);
  VBMatrix mat2(args[1]);

  if (mat1.m <= 0 || mat1.n <= 0) {
    printf("[E] vbmm2: first matrix was bad.\n");
    return (101);
  }
  if (mat2.m == 0 || mat2.n == 0) {
    printf("[E] vbmm2: second matrix was bad.\n");
    return (102);
  }
  if (mat1.n != mat2.m) {
    printf("[E] vbmm2: incompatible matrix dimensions for I-XY.\n");
    return (103);
  }

  printf("vbmm2: I-XYing matrix %s by matrix %s.\n",
         args[0].c_str(),args[1].c_str());
  mat1*=mat2;
  // turn XY into I-XY one row at a time
  for (int i=0; i<mat1.m; i++) {
    VB_Vector tmp=mat1.GetRow(i);
    for (int j=0; j<mat1.n; j++) {
      tmp[j]*=-1.0;
      if (i==j)
        tmp[j]+=1.0;
    }
    mat1.SetRow(i,tmp);
  }

  // only announce completion once the result is actually on disk
  if (mat1.WriteMAT1(args[2])) {
    printf("[E] vbmm2: couldn't write matrix %s\n",args[2].c_str());
    return (110);
  }
  printf("vbmm2: done.\n");
  return 0;
}

// Builds the "F3" matrix V * KG * invert(KGt * KG) and writes it.
//   args[0]  V matrix file
//   args[1]  KG matrix file
//   args[2]  output matrix file
// Intermediate matrices are cleared as soon as they are no longer
// needed.
// Returns 0 on success, 100 on a usage error, unreadable input, or
// incompatible dimensions, and 110 if the result could not be written.
int
do_f3(tokenlist &args)
{
  if (args.size() != 3) {
    printf("[E] vbmm2: usage: vbmm -f3 v kg out\n");
    return (100);
  }
  printf("[I] vbmm2: creating F3 matrix (V*KG*invert(KGtKG))\n");
  printf("[I] vbmm2: V: %s\n",args(0));
  printf("[I] vbmm2: KG: %s\n",args(1));
  VBMatrix v(args[0]);
  VBMatrix kg(args[1]);
  VBMatrix kgt=kg;       // copy of KG, flagged as transposed below
  VBMatrix tmp;          // receives invert(KGt*KG)

  if (v.m==0 || kg.m==0 || kgt.m==0) {
    printf("[E] vbmm2: couldn't read matrices\n");
    return 100;
  }
  if (v.n != kg.m) {
    printf("[E] vbmm2: incompatible matrix dimensions\n");
    return 100;
  }

  // kgt = KGt * KG; the flag is reset because the product itself is
  // not a transposed matrix
  kgt.transposed=1;
  kgt*=kg;
  kgt.transposed=0;
  invert(kgt,tmp);
  kgt.clear();  // free mem
  v*=kg;
  kg.clear();   // free mem
  v*=tmp;

  // only report the file as written if the write succeeded
  if (v.WriteMAT1(args[2])) {
    printf("[E] vbmm2: couldn't write F3 matrix %s\n",args(2));
    return 110;
  }
  printf("[I] vbmm2: wrote F3 matrix %s\n",args(2));

  return 0;
}

// Inverts a square matrix and writes the inverse.
//   args[0]  input matrix file
//   args[1]  output matrix file
// Returns 0 on success, 100 on a usage error, 101 if the input is
// empty or not square, and 110 if the result could not be written.
int
do_invert(tokenlist &args)
{
  if (args.size() != 2) {
    printf("[E] vbmm2: usage: vbmm -invert in out\n");
    return (100);
  }

  VBMatrix mat(args[0]);

  if (mat.m <= 0 || mat.n <= 0 || mat.m != mat.n) {
    printf("[E] vbmm2: input matrix for invert was bad.\n");
    return (101);
  }
  VBMatrix target(mat.m,mat.m);
  printf("vbmm2: inverting matrix %s.\n",args[0].c_str());
  invert(mat,target);

  // only announce completion once the result is actually on disk
  if (target.WriteMAT1(args[1])) {
    printf("[E] vbmm2: couldn't write matrix %s\n",args[1].c_str());
    return (110);
  }
  printf("vbmm2: done.\n");
  return 0;
}

// do_pinv() computes the pseudo-inverse, which is
// inverse(KGtKG) ## KGt

// Computes the pseudo-inverse of a matrix and writes it.  For an m x n
// input the target is allocated as n x m.
//   args[0]  input matrix file
//   args[1]  output matrix file
// Returns 0 on success, 100 on a usage error, 101 if the input is
// empty, and 110 if the result could not be written.
int
do_pinv(tokenlist &args)
{
  if (args.size() != 2) {
    printf("[E] vbmm2: usage: vbmm -pinv in out\n");
    return (100);
  }

  VBMatrix mat(args[0]);

  if (mat.m <= 0 || mat.n <= 0) {
    printf("[E] vbmm2: input matrix for pinv was bad.\n");
    return (101);
  }

  VBMatrix target(mat.n,mat.m);
  printf("vbmm2: pinv'ing matrix %s.\n",args[0].c_str());
  pinv(mat,target);

  // only announce completion once the result is actually on disk
  if (target.WriteMAT1(args[1])) {
    printf("[E] vbmm2: couldn't write matrix %s\n",args[1].c_str());
    return (110);
  }
  printf("vbmm2: done.\n");
  return 0;
}

// do_pca() calculates the principal components

// Runs a principal components analysis on a matrix, prints the
// resulting lambdas vector, and writes the components matrix.
//   args[0]  input matrix file
//   args[1]  output matrix file (receives the "pcs" result of pca())
// Returns 0 on success, 100 on a usage error, 101 if the input is
// empty, and 110 if the result could not be written.
int
do_pca(tokenlist &args)
{
  if (args.size() != 2) {
    printf("[E] vbmm2: usage: vbmm -pca in out\n");
    return (100);
  }

  VBMatrix mat(args[0]);

  if (mat.m <= 0 || mat.n <= 0) {
    printf("[E] vbmm2: input matrix for pca was bad.\n");
    return (101);
  }

  VB_Vector lambdas;
  VBMatrix pcs;
  VBMatrix E;     // filled in by pca() but not used further here
  printf("vbmm2: pca'ing matrix %s.\n",args[0].c_str());
  pca(mat,lambdas,pcs,E);
  lambdas.print();

  // only announce completion once the result is actually on disk
  if (pcs.WriteMAT1(args[1])) {
    printf("[E] vbmm2: couldn't write matrix %s\n",args[1].c_str());
    return (110);
  }
  printf("vbmm2: done.\n");
  return 0;
}

// Prints every matrix named on the command line, in order.  A file
// that doesn't yield valid matrix data produces an error line and the
// remaining files are still processed.  An empty argument list prints
// nothing (the loop body never runs).
void
do_print(tokenlist &args)
{
  int idx=0;
  while (idx<args.size()) {
    VBMatrix current(args[idx]);
    if (!current.dataValid())
      printf("[E] vbmm2: couldn't open matrix %s to print.\n",args(idx));
    else
      current.print();
    ++idx;
  }
}

// Prints a rectangular part of a matrix.
//   args[0]           matrix file
//   args[1], args[2]  first and last row
//   args[3], args[4]  first and last column
// Does nothing unless exactly five arguments are supplied.
void
do_printsub(tokenlist &args)
{
  if (args.size() == 5) {
    int rowfirst=strtol(args[1]);
    int rowlast=strtol(args[2]);
    int colfirst=strtol(args[3]);
    int collast=strtol(args[4]);
    VBMatrix sub(args[0],rowfirst,rowlast,colfirst,collast);
    // rowdata is null when nothing could be loaded
    if (sub.rowdata) {
      printf("[I] vbmm: read rows %d-%d and cols %d-%d of %s\n",
             rowfirst,rowlast,colfirst,collast,args(0));
      sub.print();
    }
    else
      printf("[E] vbmm2: couldn't read data\n");
  }
}

// Prints the program banner (including the VoxBo version string) and a
// one-line usage summary for each supported command to stdout.
void
vbmm_help()
{
  printf("\nVoxBo vbmm2 (v%s)\n",vbversion.c_str());
  printf("| vbmm does various bits of matrix arithmetic with MAT1 files.\n");
  printf("| below, in and out refer to the input and output matrices.  c1 and c2\n");
  printf("| refer to the start and end columns to be produced (for parallelization).\n");
  printf("| usage:\n");
  printf("    vbmm -xyt <in1> <in2> <out> <c1> <c2>       do part of XYt\n");
  printf("    vbmm -xy <in1> <in2> <out> <c1> <c2>        do part of XY\n");
  printf("    vbmm -imxy <in1> <in2> <out>                I-XY in core\n");
  printf("    vbmm -f3 <v> <kg> <out>                     V*KG*invert(KGtKG)\n");
  printf("    vbmm -xyz <in1> <in2> <in3> <out>           XYZ in core\n");
  printf("    vbmm -assemblecols <out>                    assemble out from available parts\n");
  // printf("    vbmm -assemblerows <out>                    assemble out from available parts\n");
  printf("    vbmm -add <in1> <in2> <out>                 add two matrices\n");
  printf("    vbmm -subtract <in1> <in2> <out>            subtract in2 from in1\n");
  printf("    vbmm -invert <in> <out>                     invert in\n");
  printf("    vbmm -pinv <in> <out>                       pseudo-inverse, in core\n");
  printf("    vbmm -pca <in> <out>                        calculate pca\n");
  printf("    vbmm -ident <name> <size>                   create an identity matrix\n");
  printf("    vbmm -zeros <name> <cols> [rows]            create a zero matrix\n");
  printf("    vbmm -random <name> <cols> [rows]           create a matrix of random numbers\n");
  printf("    vbmm -print <name>                          display a matrix\n");
  printf("    vbmm -printsub <name> <r1> <r2> <c1> <c2>   display part of a matrix\n");
  printf("    vbmm -compare <mat1> <mat2>                 compare two matrices\n");
}
