#include "Distance_restraint.h"

#include "Create.h"
#include "Vector_3D_impl.h"
#include "Point_3D.h"
#include "Atom_kit.h"
#include "Atom_cache.h"
#include "Mach_eps.h"

#include <math.h>


namespace
{
    // Softening length of the distance-restraint potential
    //     U(d) = K_r * d^2 / (d + a),   d = |R - R_eqv|.
    // For d << a the restraint is approximately harmonic (K_r * d^2 / a); for
    // d >> a it grows only linearly in d.  It is added to a distance, so it
    // carries distance units.  Read-only: potential() and add_force() must
    // agree on its value for forces to be the gradient of the energy.
    double const a = 1.;
}

namespace MM
{

// Prototype constructor: builds a placeholder restraint whose atom references
// are both bound to nil<Atom>().
// R_eqv_ and K_r_ are zero-initialised so the placeholder never carries
// indeterminate values (copying or reading an uninitialised double is
// undefined behaviour).
// NOTE(review): last_potential_ is not initialised here either; it is not in
// either constructor's initializer list, so it may live in a base class —
// TODO confirm and initialise it where it is declared.
Distance_restraint::
Distance_restraint (Prototype &)
:
    atom1_ (nil<Atom>()), atom2_ (nil<Atom>()),
    R_eqv_ (0.), K_r_ (0.)
{
}


// Binds a restraint between two atoms.
//   atom1, atom2 : the restrained atoms (stored as members atom1_/atom2_).
//   type         : accepted for interface compatibility; not used here.
//   R_eqv        : target (equilibrium) inter-atomic distance.
//   K_r          : force constant scaling the restraint energy.
Distance_restraint::
Distance_restraint (Atom &       atom1,
                    Atom &       atom2,
                    Text const & /*type*/,
                    double       R_eqv,
                    double       K_r)
:
    atom1_ (atom1),
    atom2_ (atom2),
    R_eqv_ (R_eqv),
    K_r_   (K_r)
{
    // Nothing else to set up: all state is captured by the initializer list.
}

// Factory hook: creates a new Distance_restraint for the given atoms and
// parameters via the regular constructor.  No state of *this is read.
// The result is allocated with `new`; presumably the caller takes ownership
// and is responsible for deleting it — TODO confirm against callers.
Distance_restraint * Distance_restraint::
clone (Atom & atom1, 
       Atom & atom2,
       Text const & type,
       double       R_eqv,
       double       K_r) const
{
    Distance_restraint * const fresh =
        new Distance_restraint (atom1, atom2, type, R_eqv, K_r);

    return fresh;
}

double Distance_restraint::
potential ()
{
    Point_3D const & point_1 = atom1_.position();
    Point_3D const & point_2 = atom2_.position();

    double dx = point_2.x() - point_1.x();
    double dy = point_2.y() - point_1.y();
    double dz = point_2.z() - point_1.z();

    double R           = sqrt (dx * dx + dy * dy + dz * dz);
    
    double R_minus_Req      = R - R_eqv_;
    double abs_R_minus_Req  = fabs (R_minus_Req);

    //last_potential_ = K_r_ * abs_R_minus_Req;
    last_potential_ = K_r_ * (abs_R_minus_Req * abs_R_minus_Req / (abs_R_minus_Req + a) /*+ a*/);
    return last_potential_;
}

void Distance_restraint::
add_force ()
{
    //if (!cache_flag_)
    //{
    //    cache_atoms ();
    //    cache_bonds ();
    //}

    Atom_cache & cache1 = *atom1_.kit().cache___;
    Atom_cache & cache2 = *atom2_.kit().cache___;

    ////double R           = value ();
    ////double R_minus_Req = R - R_eqv_;
    ////double dPot_I_dR   = 2. * K_r_ * R_minus_Req;

    ////double dPot_I_dR_I_R =  dPot_I_dR / R;
    ////double fx = (atom_2.x___ - atom_1.x___) * dPot_I_dR_I_R;
    ////double fy = (atom_2.y___ - atom_1.y___) * dPot_I_dR_I_R;
    ////double fz = (atom_2.z___ - atom_1.z___) * dPot_I_dR_I_R;;

    ////atom_1.force_x___ += fx;
    ////atom_2.force_x___ -= fx;

    ////atom_1.force_y___ += fy;
    ////atom_2.force_y___ -= fy;

    ////atom_1.force_z___ += fz;
    ////atom_2.force_z___ -= fz;

    //Point_3D const & point_1 = atom1_.position();
    //Point_3D const & point_2 = atom2_.position();

    //double dx = point_2.x() - point_1.x();
    //double dy = point_2.y() - point_1.y();
    //double dz = point_2.z() - point_1.z();

    double dx = cache2.x___ - cache1.x___;
    double dy = cache2.y___ - cache1.y___;
    double dz = cache2.z___ - cache1.z___;

    double R           = sqrt (dx * dx + dy * dy + dz * dz);
    
    double R_minus_Req      = R - R_eqv_;
    double abs_R_minus_Req  = fabs (R_minus_Req);

    //double dPot_I_dR = K_r_ * (R_minus_Req * abs_R_minus_Req) / (R_minus_Req * R_minus_Req + 1);
    //double dPot_I_dR   = K_r_ * R_minus_Req / abs_R_minus_Req;

    double b = abs_R_minus_Req / (abs_R_minus_Req + a);
    double dPot_I_dR   = K_r_ * (2.* b - b*b) * R_minus_Req / abs_R_minus_Req;

//    if (abs_R_minus_Req < d_mach_eps_2)
    if (abs_R_minus_Req < d_mach_eps)
        return;

    double dPot_I_dR_I_R =  dPot_I_dR / R;
    double fx = dx * dPot_I_dR_I_R;
    double fy = dy * dPot_I_dR_I_R;
    double fz = dz * dPot_I_dR_I_R;;

    Vector_3D_impl f (fx, fy, fz);
  
    atom1_.kit().add_potential_force    (f);
    atom1_.kit().add_bond_stretch_force (f);

    cache1.F_bond_stretch_x___ += fx;
    cache1.F_bond_stretch_y___ += fy;
    cache1.F_bond_stretch_z___ += fz;
    cache1.force_x___ += fx;
    cache1.force_y___ += fy;
    cache1.force_z___ += fz;

    atom2_.kit().add_potential_force    (f.negate());
    atom2_.kit().add_bond_stretch_force (f);

    cache2.F_bond_stretch_x___ -= fx;
    cache2.F_bond_stretch_y___ -= fy;
    cache2.F_bond_stretch_z___ -= fz;
    cache2.force_x___ -= fx;
    cache2.force_y___ -= fy;
    cache2.force_z___ -= fz;

    //if (!cache_flag_)
    //{
    //    flush_force ();
    //    atom_.clear ();
    //    bond_.clear ();
    //}
}

}//MM
