// MM1TAB9.H : slightly modified tables class for parameter fitting.

// Copyright (C) 1998 Tommi Hassinen.

// This program is free software; you can redistribute it and/or modify it
// under the terms of the license (GNU GPL) which comes with this package.

/*################################################################################################*/

#include "config.h"	// this is target-dependent...

#ifndef MM1TAB9_H
#define MM1TAB9_H

// forward declarations; the definitions below (and the headers included
// after this list) can refer to these types before they are fully defined.

class input_at;		// initial atomtype
struct prmfit_at;	// real atomtype

struct prmfit_bs;	// bond stretching
struct prmfit_ab;	// angle bending
struct prmfit_tr;	// torsion
struct prmfit_op;	// out of plane

struct prmfit_bs_query;		// lookup record for prmfit_bs parameters
struct prmfit_ab_query;		// lookup record for prmfit_ab parameters
struct prmfit_tr_query;		// lookup record for prmfit_tr parameters
struct prmfit_op_query;		// lookup record for prmfit_op parameters

class prmfit_tables;		// parameter tables + fitting tools
class prmfit_cg_optim;		// conjugate-gradient parameter optimizer

class file_io_handler;		// abstract interface to an external QM program

class gaussian_io_handler;	// file_io_handler for Gaussian
class mpqc_io_handler;		// file_io_handler for MPQC

/*################################################################################################*/

#include "mm1eng.h"

#include "mm1eng9.h"
#include "conjgrad.h"

#include <vector>
using namespace std;

/*################################################################################################*/

/// Initial atomtype record read during parameter fitting.
/// prmfit_tables (a friend) converts these into prmfit_at records.

class input_at
{
	public:
	
	/// Default geometry, as stored in bits 0-1 of the "flags" field.
	enum default_geom
	{
		Unknown = 0,
		Linear = 1,
		Planar = 2,
		Tetrahedral = 3
	};
	
	protected:
	
	// the following five fields must match the same-named fields of prmfit_at!!!
	
	i32s atomtype[2];
	mm1_typerule * typerule;
	char * description;
	f64 formal_charge;
	i32u flags;
	
	// bit layout of "flags":
	//   bits 0-1 (values 1 and 2) : default geometry; 0 = unknown, 1 = linear, 2 = planar, 3 = tetrahedral
	//   bit 2    (value 4)        : is a hydrogen-bond donor
	//   bit 3    (value 8)        : is a hydrogen-bond acceptor
	//   bit 4    (value 16)       : an out-of-plane term is needed
	
	// forcefield-related data:
	
	i32s atomic_number;
	i32u number_of_hits;	// tells how generic this type is; used as a sorting key.
	i32u * secondary;	// only used when searching for a secondary type; NOT used by this class itself!!!
	
	friend class prmfit_tables;
	
	public:
	
	input_at();
	input_at(ifstream &);
	input_at(const input_at &);
	~input_at();
	
	input_at & operator=(const input_at &);		// required because STL sort assigns elements.
	
	bool operator<(const input_at &) const;		// orders by element and number_of_hits.
};

/// A real (final) atomtype record.
/// The atomtype, typerule, description, formal_charge and flags fields
/// mirror the same-named fields of input_at.

struct prmfit_at
{
	i32s atomtype[2];
	mm1_typerule * typerule;
	char * description;
	
	f64 lj_r;		// presumably the Lennard-Jones radius -- confirm
	f64 lj_e;		// presumably the Lennard-Jones well depth -- confirm
	
	f64 formal_charge;
	i32u flags;		// same bit layout as input_at::flags.
};

struct prmfit_bs	// bond stretching
{
	i32s atmtp[2];
	bondtype bndtp;
	
	f64 opt;
	f64 fc;
	
	f64 cid;
	
	/*##########################*/
	/*##########################*/
	
	f64 d_fc;
	i32u obs_count;
	
	bool operator<(const prmfit_bs & p1) const
	{
		if (atmtp[0] != p1.atmtp[0]) return (atmtp[0] < p1.atmtp[0]);
		else if (atmtp[1] != p1.atmtp[1]) return (atmtp[1] < p1.atmtp[1]);
		else return (bndtp.GetValue() < p1.bndtp.GetValue());
	}
};

struct prmfit_ab	// angle bending
{
	i32s atmtp[3];
	bondtype bndtp[2];
	
	f64 opt;
	f64 fc;
	
	/*##########################*/
	/*##########################*/
	
	i32u obs_count;
	
	bool operator<(const prmfit_ab & p1) const
	{
		if (atmtp[1] != p1.atmtp[1]) return (atmtp[1] < p1.atmtp[1]);
		else if (atmtp[0] != p1.atmtp[0]) return (atmtp[0] < p1.atmtp[0]);
		else if (atmtp[2] != p1.atmtp[2]) return (atmtp[2] < p1.atmtp[2]);
		else if (bndtp[0].GetValue() != p1.bndtp[0].GetValue()) return (bndtp[0].GetValue() < p1.bndtp[0].GetValue());
		else return (bndtp[1].GetValue() < p1.bndtp[1].GetValue());
	}
};

struct prmfit_tr	// torsion
{
	i32s atmtp[4];
	bondtype bndtp[3];
	
	f64 k[3];
	f64 t[3];
	
	/*##########################*/
	/*##########################*/
	
	bool operator<(const prmfit_tr & p1) const
	{
		if (atmtp[1] != p1.atmtp[1]) return (atmtp[1] < p1.atmtp[1]);
		else if (atmtp[2] != p1.atmtp[2]) return (atmtp[2] < p1.atmtp[2]);
		else if (atmtp[0] != p1.atmtp[0]) return (atmtp[0] < p1.atmtp[0]);
		else if (atmtp[3] != p1.atmtp[3]) return (atmtp[3] < p1.atmtp[3]);
		else if (bndtp[1].GetValue() != p1.bndtp[1].GetValue()) return (bndtp[1].GetValue() < p1.bndtp[1].GetValue());
		else if (bndtp[0].GetValue() != p1.bndtp[0].GetValue()) return (bndtp[0].GetValue() < p1.bndtp[0].GetValue());
		else return (bndtp[2].GetValue() < p1.bndtp[2].GetValue());
	}
};

struct prmfit_op	// out of plane
{
	i32s atmtp[4];
	bondtype bndtp[3];
	
	f64 opt;
	f64 fc;
	
	/*##########################*/
	/*##########################*/
	
	bool operator<(const prmfit_op & p1) const
	{
		if (atmtp[1] != p1.atmtp[1]) return (atmtp[1] < p1.atmtp[1]);
		else if (atmtp[2] != p1.atmtp[2]) return (atmtp[2] < p1.atmtp[2]);
		else if (atmtp[0] != p1.atmtp[0]) return (atmtp[0] < p1.atmtp[0]);
		else if (atmtp[3] != p1.atmtp[3]) return (atmtp[3] < p1.atmtp[3]);
		else if (bndtp[0].GetValue() != p1.bndtp[0].GetValue()) return (bndtp[0].GetValue() < p1.bndtp[0].GetValue());
		else if (bndtp[1].GetValue() != p1.bndtp[1].GetValue()) return (bndtp[1].GetValue() < p1.bndtp[1].GetValue());
		else return (bndtp[2].GetValue() < p1.bndtp[2].GetValue());
	}
};

/// Lookup record for bond stretching parameters (see prmfit_tables::DoParamSearch).

struct prmfit_bs_query
{
	// --- input: filled by the client ---
	
	i32s atmtp[2];
	bondtype bndtp;
	bool strict;
	
	// --- results: presumably written by DoParamSearch() ---
	
	i32s index;
	bool dir;
	
	f64 opt;
	f64 fc;
	
	f64 cid;
};

/// Lookup record for angle bending parameters (see prmfit_tables::DoParamSearch).

struct prmfit_ab_query
{
	// --- input: filled by the client ---
	
	i32s atmtp[3];
	bondtype bndtp[2];
	bool strict;
	
	// --- results: presumably written by DoParamSearch() ---
	
	i32s index;
	bool dir;
	
	f64 opt;
	f64 fc;
};

/// Lookup record for torsion parameters (see prmfit_tables::DoParamSearch).

struct prmfit_tr_query
{
	// --- input: filled by the client ---
	
	i32s atmtp[4];
	bondtype bndtp[3];
	bool strict;
	
	// --- results: presumably written by DoParamSearch() ---
	
	i32s index;
	bool dir;
	
	f64 k[3];
	f64 t[3];
};

/// Lookup record for out-of-plane parameters (see prmfit_tables::DoParamSearch).
/// Unlike the other query records, this one has no "dir" field.

struct prmfit_op_query
{
	// --- input: filled by the client ---
	
	i32s atmtp[4];
	bondtype bndtp[3];
	bool strict;
	
	// --- results: presumably written by DoParamSearch() ---
	
	i32s index;
	
	f64 opt;
	f64 fc;
};

/*################################################################################################*/

/// MM parameter fitting : force field parameter tables + some fitting tools.

class prmfit_tables
{
	protected:
	
	char * path;			// stored from the constructor argument; presumably an owned copy released by the destructor -- confirm in the .cpp
	
	vector<input_at> at1_vector;	// initial atomtypes
	vector<prmfit_at> at2_vector;	// real atomtypes
	
	vector<prmfit_bs> bs_vector;	// bond stretching parameters
	vector<prmfit_ab> ab_vector;	// angle bending parameters
	vector<prmfit_tr> tr_vector;	// torsion parameters
	vector<prmfit_op> op_vector;	// out-of-plane parameters
	
	private:
	
	// copying is disabled. the class has a user-defined destructor and holds a
	// raw "path" pointer, so a compiler-generated memberwise copy would leave two
	// objects sharing that pointer. these are declared but intentionally never
	// defined (the pre-C++11 idiom, matching the rest of this file); any attempt
	// to copy now fails at compile or link time instead of at run time.
	
	prmfit_tables(const prmfit_tables &);
	prmfit_tables & operator=(const prmfit_tables &);
	
	public:
	
	prmfit_tables(const char *);
	virtual ~prmfit_tables(void);
	
	const prmfit_at * GetAtomType(i32s);
	
	// parameter lookups. the query record carries the client-filled keys and
	// receives the results. the ostream pointer is presumably an optional log stream -- confirm.
	
	void DoParamSearch(prmfit_bs_query &, ostream *);
	void DoParamSearch(prmfit_ab_query &, ostream *);
	void DoParamSearch(prmfit_tr_query &, ostream *);
	void DoParamSearch(prmfit_op_query &, ostream *);
	
	i32s UpdateTypes(mm1_mdl &, ostream *);
	
	void PrintAllTypeRules(ostream &);
	
	// parameter-fitting extensions (staged workflow; see the implementation file).
	
	void InitSTAGE1(void);
	
	void AddCaseSTAGE1a(file_io_handler *);
	void ValidateAtomTypesSTAGE1a(void);
	
	void AddCaseSTAGE1b(file_io_handler *);
	void ValidateAtomTypesSTAGE1b(void);
	
	void AddCaseSTAGE1c(file_io_handler *);
	void AddCaseSTAGE1c(const prmfit_bs_query *, f64, i32s);
	void AddCaseSTAGE1c(const prmfit_ab_query *, f64, i32s);
	void AddCaseSTAGE1c(const prmfit_tr_query *, i32s);
	void AddCaseSTAGE1c(const prmfit_op_query *, f64, i32s);
	
	void CalcParamSTAGE1(void);
	
	void DistortStructureSTAGE3(mm1_eng_prmfit *);
	
	void WriteAtomTypes(void);
	void WriteParamFiles(void);
};

/*################################################################################################*/

/// MM parameter fitting : conjugate-gradient parameter optimizer.

class prmfit_cg_optim : public prmfit_tables, public conjugate_gradient
{
	protected:
	
	f64 value;
	i32s debug_level;	// assigned through SetDebugLevel().
	
	public:
	
	prmfit_cg_optim(const char *);
	~prmfit_cg_optim();
	
	void SetDebugLevel(i32s level)
	{
		debug_level = level;
	}
	
	void Check(i32s);
	
	// virtual in a base class (presumably conjugate_gradient -- confirm):
	
	f64 GetValue();
	f64 GetGradient();
	
	void Calculate(i32s);
	void ProcessFiles(i32s, const char *, const char *, const char *, f64, f64);
	void CompareEnergyDiffs(i32s, file_io_handler *, file_io_handler *, f64);
	void CompareGradients(i32s, file_io_handler *, f64);
};

/*################################################################################################*/

/// Abstract interface to an external quantum-chemistry program.
/// A subclass writes an input file for a model and reads the program's
/// output back. The has_* flags presumably tell which results are valid.

class file_io_handler
{
	protected:
	
	mm1_mdl * mdl;		// original note: "all_atom_interface?!" -- a more generic model interface may have been intended; TODO confirm
	
	bool has_e; f64 e;		// energy
	bool has_d1; f64 * d1;		// gradient array (name suggests 1st derivatives -- confirm)
	bool has_crd; f64 * crd;	// coordinate array
	
	char * debug;
	
	friend class prmfit_tables;
	friend class prmfit_cg_optim;
	
	private:
	
	// copying is disabled. the handler holds raw pointers (d1, crd, debug) and
	// has a user-declared destructor, so a memberwise copy would share those
	// pointers between objects. these are declared but intentionally never
	// defined (the pre-C++11 idiom, matching the rest of this file). handlers
	// are passed around by pointer everywhere in this header.
	
	file_io_handler(const file_io_handler &);
	file_io_handler & operator=(const file_io_handler &);
	
	public:
	
	file_io_handler(mm1_mdl &);
	virtual ~file_io_handler(void);
	
	virtual void WriteInput(const char * fn, bool optimize, bool get_energy, bool get_gradient) = 0;
	virtual bool ReadOutput(const char * fn, bool optimize, bool get_energy, bool get_gradient) = 0;
};

/*################################################################################################*/

/// file_io_handler implementation for the Gaussian program.

class gaussian_io_handler : public file_io_handler
{
	public:
	
	gaussian_io_handler(mm1_mdl &);
	~gaussian_io_handler();
	
	// helpers that emit sections of the input file (presumably called by WriteInput()).
	
	void WriteSettings(ofstream &);
	void WriteMolecule(ofstream &);
	
	// overrides of the file_io_handler interface:
	
	void WriteInput(const char * fn, bool optimize, bool get_energy, bool get_gradient);
	bool ReadOutput(const char * fn, bool optimize, bool get_energy, bool get_gradient);
};

/*################################################################################################*/

/// file_io_handler implementation for the MPQC program.

class mpqc_io_handler : public file_io_handler
{
	public:
	
	mpqc_io_handler(mm1_mdl &);
	~mpqc_io_handler();
	
	// overrides of the file_io_handler interface:
	
	void WriteInput(const char * fn, bool optimize, bool get_energy, bool get_gradient);
	bool ReadOutput(const char * fn, bool optimize, bool get_energy, bool get_gradient);
};

/*################################################################################################*/

#endif	// MM1TAB9_H

// eof
