#include <assert.h>
#include <unistd.h>
#include <pthread.h>
#include "springball.h"

// ---- Timing -------------------------------------------------------------
// NOTE(review): MICRONAPTIME is not referenced in the code visible here.
static const double MICRONAPTIME = 100.0;

// ---- Bodies -------------------------------------------------------------
// Real balls, and the much lighter/smaller "shadow" bodies that PseudoSpring
// places between two connected balls (1/100 of the real mass and radius).
static const double BALLMASS         = 1.0e-1;
static const double SHADOWBALLMASS   = BALLMASS * 1.0e-2;
static const double BALLRADIUS       = 1.0;
static const double SHADOWBALLRADIUS = BALLRADIUS * 1.0e-2;

static const double HOLDFORCE = 1.0e-5;

// ---- Stiffness / damping and the derived ODE ERP/CFM values ---------------
// ERP and CFM follow ODE's spring mapping for step size h, stiffness k and
// damping c:  ERP = h*k / (h*k + c),  CFM = 1 / (h*k + c).
static const double WORLDK    = 10.0;
static const double WORLDDRAG = 0.1;

static const double SPRINGK    = 10.0;
static const double SPRINGDRAG = 0.1;

// Simulation step handed to dWorldStepFast1 and used to pace dosimloop().
static const double STEPSIZE = 5.0e2;
static const double WORLDERP = (STEPSIZE*WORLDK)/(STEPSIZE*WORLDK+WORLDDRAG);
static const double WORLDCFM = 1.0/(STEPSIZE*WORLDK+WORLDDRAG);

static const double SPRINGERP = (STEPSIZE*SPRINGK)/(STEPSIZE*SPRINGK+SPRINGDRAG);
static const double SPRINGCFM = 1.0/(STEPSIZE*SPRINGK+SPRINGDRAG);

// ---- Geometry and force scales ------------------------------------------
static const double HINGESEP   = 0.1;
static const double SHADOWSEP  = 0.05;    // x/y offset of a shadow body from its ball
static const double RFACTOR    = 1.0e-6;
static const double BREATHNUM  = 4.0;     // multiplier on the Coulomb-like repulsion
static const double BREATHRATE = 1.0e-6;
static const double GFORCE     = 3.0e-9;  // per-axis magnitude used by add_gravity()

static void copyDRealToDouble(const dReal *src, double *dst);

// Release the ODE resources owned by this spring.
// Destroying the joint group destroys j1, j2 and js and detaches them from
// both the shadow bodies and the real balls. The shadow bodies are destroyed
// after that. Per-joint dJointDestroy calls are deliberately absent: ODE
// ignores dJointDestroy for joints that are members of a group.
// Requires the real balls (and the world) to still exist.
PseudoSpring::~PseudoSpring() {
    dJointGroupDestroy(jg);
    dBodyDestroy(ss1);
    dBodyDestroy(ss2);
}

// Build a soft link between real balls s1 and s2 out of two small shadow
// bodies (ss1, ss2), both given mass msmall:
//   - ss1 is ball-jointed to s1 (j1) and ss2 is ball-jointed to s2 (j2);
//   - ss1 and ss2 are joined by a hinge (js) whose CFM/ERP come from
//     SPRINGK/SPRINGDRAG.
// All three joints live in this spring's private joint group jg.
PseudoSpring::PseudoSpring(dWorldID w, dBodyID s1, dBodyID s2, dMass& msmall)
{
    const dReal *p1, *p2;
    dReal sp1[3], sp2[3];
    dReal pav[3];
    int i;
    ss1 = dBodyCreate(w);
    ss2 = dBodyCreate(w);
    dBodySetFiniteRotationMode(ss1, 0);
    dBodySetFiniteRotationMode(ss2, 0);
    dBodySetLinearVel(ss1, 0, 0, 0);
    dBodySetLinearVel(ss2, 0, 0, 0);
    dBodySetMass(ss1, &msmall);
    dBodySetMass(ss2, &msmall);
    jg = dJointGroupCreate(0);
    j1 = dJointCreateBall(w, jg);
    j2 = dJointCreateBall(w, jg);
    js = dJointCreateHinge(w, jg);
    // Snapshot both real balls' current world positions and their midpoint.
    p1 = dBodyGetPosition(s1);
    memcpy(sp1, p1, sizeof(sp1));
    p2 = dBodyGetPosition(s2);
    memcpy(sp2, p2, sizeof(sp2));
    for (i = 0; i < 3; i += 1)
      pav[i] = 0.5 * (sp1[i] + sp2[i]);
    // NOTE(review): these two calls run before dJointAttach(js, ...). ODE
    // converts hinge anchors/axes into body-relative frames using the bodies
    // attached at call time, so here they may have no effect and the joint may
    // keep its default anchor/axis. Reordering would change the simulated
    // dynamics, so it is left as-is; confirm the intended behaviour.
    dJointSetHingeAnchor(js, pav[0], pav[1], pav[2]);
    dJointSetHingeAxis(js, 0, 0, 1);
    // Place each shadow body next to its real ball, offset by SHADOWSEP in x and y.
    dBodySetPosition(ss1, sp1[0]+SHADOWSEP, sp1[1]+SHADOWSEP, sp1[2]);
    dBodySetPosition(ss2, sp2[0]+SHADOWSEP, sp2[1]+SHADOWSEP, sp2[2]);
    dJointAttach(j1, s1, ss1);
    dJointAttach(j2, s2, ss2);
    dJointAttach(js, ss1, ss2);
    // Softness/stiffness of the hinge, derived from SPRINGK and SPRINGDRAG.
    dJointSetHingeParam(js, dParamCFM, SPRINGCFM);
    dJointSetHingeParam(js, dParamStopCFM, SPRINGCFM);
    dJointSetHingeParam(js, dParamStopERP, SPRINGERP);
    dJointSetHingeParam(js, dParamSuspensionERP, SPRINGERP);
    // Ball anchors are set after attaching, at each real ball's centre.
    dJointSetBallAnchor(j1, sp1[0], sp1[1], sp1[2]);
    dJointSetBallAnchor(j2, sp2[0], sp2[1], sp2[2]);
}

// Tear down the whole system: every spring, the ODE world (with all balls),
// the target matrix, the heap arrays and finally the mutexes.
// NOTE(review): a thread started by dosimloopthread() is neither stopped nor
// joined here; if it is still running it will touch freed state.
SpringBallSystem::~SpringBallSystem() {
  int i, j;
  // adjust_spring_full(..., 0.0) deletes any existing spring. Do this while
  // the balls and world still exist.
  for (i = 0; i < _num_balls; i += 1)
    for (j = i+1; j < _num_balls; j += 1)
      adjust_spring_full(i, j, 0.0);
  dWorldDestroy(_w);
  gsl_matrix_free(_smattarget);
  free(_balls);
  free(_bpos);
  free(_springs);  // allocated in the constructor; previously leaked
  // Mutexes go last, once nothing in this object can use them any more.
  pthread_mutex_destroy(&_smatlock);
  pthread_mutex_destroy(&_balllock);
}

// Create a world holding howManyBalls free balls. The balls start with zero
// velocity and sit on a deterministic trigonometric scatter whose radius
// (initrad = 2 + 1e-8*n^3) grows with the ball count. No springs exist yet.
SpringBallSystem::SpringBallSystem(const TimeFrame &tf, int howManyBalls) : _tf(tf) {
  double initrad =  2+1e-8*pow(howManyBalls, 3);
  _num_balls = howManyBalls;
  _gravity_on = false;
  _abort_now = false;
  _coulomb_on = true;
  _motion_on = false;
  // Target spring matrix published by adjust_to_springy_matrix(). The two
  // counters start out different, so dosimloop() applies it on its first pass.
  _smattarget = gsl_matrix_calloc(howManyBalls, howManyBalls);
  _smatnum = 1;
  _smatlast = 0;
  pthread_mutex_init(&_balllock, NULL);
  pthread_mutex_init(&_smatlock, NULL);
  _coulombic_power =  (howManyBalls < 30 ? 5:1)*3e-4/pow(howManyBalls, 1.2);

  // calloc(count, element_size): zero-filled, so every spring slot starts empty.
  _balls = static_cast<dBodyID *>(calloc(howManyBalls, sizeof(_balls[0])));
  _bpos = static_cast<double *>(calloc(howManyBalls*3, sizeof(_bpos[0])));
  _springs = static_cast<PseudoSpring **>(
      calloc(howManyBalls*howManyBalls, sizeof(PseudoSpring *)));
  _w = dWorldCreate();
  dMassSetSphereTotal(&_ball_mass, BALLMASS, BALLRADIUS);
  dMassSetSphereTotal(&_shadowball_mass, SHADOWBALLMASS, SHADOWBALLRADIUS);
  for (int i = 0; i < howManyBalls; i += 1) {
    _balls[i] = dBodyCreate(_w);
    dBodySetFiniteRotationMode(_balls[i], 0);
    dBodySetLinearVel(_balls[i], 0, 0, 0);
    dBodySetMass(_balls[i], &_ball_mass);
    dBodySetPosition(_balls[i], ((10.0+i)/(10.0))*initrad*sin(i), ((20.0+i)/(15.0))*initrad*cos(i*2.2), initrad*sin(i*3.77));
  }
}

// Store spring s for the unordered ball pair {i, j}. Only the upper-triangle
// slot (row < column) of the n*n _springs table is used. i must differ from j.
void SpringBallSystem::set_spring(int i, int j, PseudoSpring *s)
{
  assert(i != j);
  const int row = (i < j) ? i : j;
  const int col = (i < j) ? j : i;
  _springs[row*_num_balls + col] = s;
}

// Return the spring for the unordered ball pair {i, j}, or NULL if there is
// none. A ball is never connected to itself, so i == j always yields NULL.
PseudoSpring *SpringBallSystem::get_spring(int i, int j) const
{
  if (i == j)
    return NULL;
  const int lo = (i < j) ? i : j;
  const int hi = (i < j) ? j : i;
  return _springs[lo*_num_balls + hi];
}

void SpringBallSystem::adjust_spring_full(int i, int j, double sk)
{
  PseudoSpring *oldspring = get_spring(i,j);
  bool should_have = (sk > 0.1);
  if (should_have && oldspring)
    return;
  if (!should_have && (oldspring == 0))
    return;
  if (!should_have && oldspring) {
    delete oldspring; // TODO: fix
    set_spring(i,j, (PseudoSpring *) 0);
    return;
  }
  assert(should_have && (oldspring == 0));
  PseudoSpring *ps = new PseudoSpring(_w,_balls[i],_balls[j],_shadowball_mass);
  set_spring(i, j, ps);
}

// Publish a new square _num_balls x _num_balls spring-strength matrix. The
// matrix is copied into _smattarget under _smatlock and the generation counter
// _smatnum is bumped. dosimloop() compares that counter against _smatlast to
// decide when to rebuild springs.
void SpringBallSystem::adjust_to_springy_matrix(const gsl_matrix *mat)
{
  assert(mat != 0);
  assert(mat->size1 == mat->size2);
  assert(mat->size1 == _num_balls);

  pthread_mutex_lock(&_smatlock);
  gsl_matrix_memcpy(_smattarget, mat);
  ++_smatnum;
  pthread_mutex_unlock(&_smatlock);
}

// Simulation thread body. Runs until stopsimloop() sets _abort_now. Each pass:
//   1. if the time source has passed tnext, advance the physics one STEPSIZE;
//   2. if a new spring matrix was published (_smatnum != _smatlast), rebuild
//      the springs from it under both locks (ball lock first, then matrix lock);
//   3. sleep until the next step is due, scaled by the time dilation.
void SpringBallSystem::dosimloop(void)
{
  double t0 = _tf.get_time();
  double tnext = t0 + STEPSIZE;
  _abort_now = false;
  while (!_abort_now) {
    double tcur = _tf.get_time();
    if (tcur > tnext) {
      dosimstep();
      t0 += STEPSIZE;
      tnext = t0 + STEPSIZE;
    }

    pthread_mutex_lock(&_smatlock);
    const bool mustdo = (_smatnum != _smatlast);
    pthread_mutex_unlock(&_smatlock);
    if (mustdo) {
      pthread_mutex_lock(&_balllock);
      pthread_mutex_lock(&_smatlock);
      for (size_t i = 0; i < _smattarget->size1; i += 1)
        for (size_t j = i+1; j < _smattarget->size2; j += 1)
          adjust_spring_full(static_cast<int>(i), static_cast<int>(j),
                             gsl_matrix_get(_smattarget, i, j));
      // Record the generation just applied (previously this assignment was inverted).
      _smatlast = _smatnum;
      pthread_mutex_unlock(&_smatlock);
      pthread_mutex_unlock(&_balllock);
    }

    // Sleep on every pass, not only after a matrix update; otherwise the loop
    // busy-spins while the matrix is unchanged. Re-read the clock because the
    // work above may have taken time.
    tcur = _tf.get_time();
    if (tcur <= tnext)
      usleep(static_cast<useconds_t>((tnext - tcur)*1e6/_tf.time_dilation()));
  }
}

// Advance the simulation by one STEPSIZE step while holding _balllock:
// apply the enabled external forces (driven motion, gravity, Coulomb-like
// repulsion), step the ODE world, then copy every ball position into _bpos and
// recentre those cached positions on their mean (adjustCOM).
void SpringBallSystem::dosimstep(void)
{
  double bco[3];
  pthread_mutex_lock(&_balllock);
  // Optional driving force: a time-varying sinusoidal push in the x/y plane on
  // the third-from-last ball. Guarded so small systems don't index out of range.
  if (_motion_on && _num_balls >= 3) {
    const double t = _tf.get_time();
    const double tidi = _tf.time_dilation();
    dBodyAddForce(_balls[_num_balls-3],
            1e-5*cos(6*t/tidi),-1e-5*sin(4.37*t/tidi),0.0);
  }
  if (_gravity_on)
    add_gravity();
  if (_coulomb_on)
    add_coulombic();
  dWorldStepFast1(_w, STEPSIZE, 80);
  for (int i = 0; i < _num_balls; i += 1) {
    const dReal *dpos = dBodyGetPosition(_balls[i]);
    copyDRealToDouble(dpos, bco);
    for (int j = 0; j < 3; j += 1)
      _bpos[3*i+j] = bco[j];
  }
  adjustCOM();
  pthread_mutex_unlock(&_balllock);
}

void SpringBallSystem::stopsimloop(void) {
  _abort_now = true;
}

int SpringBallSystem::get_ball_count(void) const
{
  return _num_balls;
}

// Copy a 3-component ODE vector (dReal, float or double depending on the ODE
// build) into a plain double[3].
static void copyDRealToDouble(const dReal *src, double *dst)
{
  dst[0] = static_cast<double>(src[0]);
  dst[1] = static_cast<double>(src[1]);
  dst[2] = static_cast<double>(src[2]);
}

// Flip the driven-motion force on/off (read by dosimstep()).
void SpringBallSystem::toggle_motion(void)
{
  _motion_on = _motion_on ? false : true;
}

// Flip the pairwise repulsion (add_coulombic) on/off.
void SpringBallSystem::toggle_coulombic(void)
{
  _coulomb_on = _coulomb_on ? false : true;
}

bool SpringBallSystem::get_motion(void) const
{
  return _motion_on;
}

bool SpringBallSystem::get_gravity(void) const
{
  return _gravity_on;
}

// Flip the add_gravity() force on/off.
void SpringBallSystem::toggle_gravity(void)
{
  _gravity_on = _gravity_on ? false : true;
}

void SpringBallSystem::add_gravity(void)
{
  int i;
  for (i = 0; i < _num_balls; i += 1) {
    double mult =  ((i%2)==0) ? -1 : 1;
    double f = GFORCE * mult;
    dBodyAddForce(_balls[i], f, f, f);
  }
}

// Apply an inverse-square repulsion between every pair of balls. Each ball of
// a pair is pushed directly away from the other with magnitude
// BREATHNUM * _coulombic_power / r^2. Pairs at zero distance are skipped
// because their direction is undefined.
void SpringBallSystem::add_coulombic(void)
{
  double here[3], there[3];
  double dir[3], force[3];
  for (int a = 0; a < _num_balls; a += 1) {
    copyDRealToDouble(dBodyGetPosition(_balls[a]), here);
    for (int b = a+1; b < _num_balls; b += 1) {
      copyDRealToDouble(dBodyGetPosition(_balls[b]), there);
      // dir = here - there, then normalise it.
      cblas_dcopy(3, here, 1, dir, 1);
      cblas_daxpy(3, -1, there, 1, dir, 1);
      const double dist = cblas_dnrm2(3, dir, 1);
      if (dist <= 0)
        continue;
      cblas_dscal(3, 1/dist, dir, 1);
      const double magnitude = BREATHNUM*_coulombic_power/(dist*dist);
      cblas_dcopy(3, dir, 1, force, 1);
      cblas_dscal(3, magnitude, force, 1);
      // Equal and opposite forces on the two balls.
      dBodyAddForce(_balls[a], force[0], force[1], force[2]);
      dBodyAddForce(_balls[b], -force[0], -force[1], -force[2]);
    }
  }
}

bool SpringBallSystem::is_spring_connecting(int i, int j) const
{
  return get_spring(i,j) != 0;
}

static void *do_simcalculations(void *udata) {
  SpringBallSystem *smt = (SpringBallSystem *) udata;
  smt->dosimloop();
  return NULL;
}

// Start dosimloop() on a new background thread.
// The thread handle is not stored anywhere, so the thread could never be
// joined. Detaching it lets the system reclaim its resources when dosimloop()
// returns. If pthread_create fails, no thread is started; as before, this
// function reports nothing.
void SpringBallSystem::dosimloopthread(void)
{
  pthread_t calcthread;
  if (pthread_create(&calcthread, NULL, do_simcalculations, (void *) this) == 0)
    pthread_detach(calcthread);
}

// Overwrite the cached position (_bpos) of ball whichBall with co[0..2].
// Only the cache changes; the ODE body is not moved.
// Locking stays disabled: adjustCOM() calls this from dosimstep() while
// _balllock is already held, and that mutex is initialised with default
// (non-recursive) attributes.
void SpringBallSystem::set_ball_position(int whichBall, double *co)
{
  assert(whichBall >= 0 && whichBall < _num_balls);
  double *slot = &_bpos[whichBall*3];
  slot[0] = co[0];
  slot[1] = co[1];
  slot[2] = co[2];
}

// Copy the cached position (_bpos) of ball whichBall into co[0..2].
// NOTE(review): reads without _balllock, so a caller on another thread may see
// a position that dosimstep() is concurrently updating.
void SpringBallSystem::get_ball_position(int whichBall, double *co)
{
  assert(whichBall >= 0 && whichBall < _num_balls);
  const double *slot = _bpos + 3*whichBall;
  for (int k = 0; k < 3; ++k)
    co[k] = slot[k];
}

// Shift every cached ball position so that their mean (getCOM) becomes the
// origin. Only _bpos is modified; the ODE bodies are left where they are.
void SpringBallSystem::adjustCOM(void)
{
  double centre[3];
  getCOM(centre);
  const int n = get_ball_count();
  for (int ball = 0; ball < n; ball += 1) {
    double co[3];
    get_ball_position(ball, co);
    for (int k = 0; k < 3; k += 1)
      co[k] -= centre[k];
    set_ball_position(ball, co);
  }
}

// Write the unweighted mean of all cached ball positions into m. All balls
// share _ball_mass, so this is also their centre of mass. With no balls, m is
// set to the origin.
void SpringBallSystem::getCOM(double m[3])
{
  double sum[3] = { 0, 0, 0 };
  const int n = get_ball_count();
  for (int ball = 0; ball < n; ball += 1) {
    double co[3];
    get_ball_position(ball, co);
    for (int k = 0; k < 3; k += 1)
      sum[k] += co[k];
  }
  if (n > 0)
    for (int k = 0; k < 3; k += 1)
      sum[k] /= n;
  for (int k = 0; k < 3; k += 1)
    m[k] = sum[k];
}

