#include "Replica_exchange.h"

#include "System_of_Units.h"
#include "Atom_kit.h"

#include <cmath>
#include <map>

namespace MM
{
// Returns the one process-wide Replica_exchange object.
// It is a function-local static, created the first time this is called.
Replica_exchange  & Replica_exchange::
singleton ()
{
    static Replica_exchange the_exchanger;
    return the_exchanger;
}

// Runs iterations() rounds of the replica-exchange loop, or fewer if
// stop_flag_ becomes set.
// Each round advances every model by one MD step (integrator step,
// thermostat, barostat, notification, subject update). Every
// exchange_per_step_ rounds, starting with round 0, it then attempts replica
// exchanges. A zero exchange_per_step_ disables exchanges instead of
// performing an undefined modulo-by-zero.
void Replica_exchange::
cycle ()
{
    Project & project = Project::singleton();
    int steps = iterations();

    for (int i=0;  i<steps && !stop_flag_;  ++i)
    {
        // advance every model by one MD step
        for (int j=0;  j<project.model_count();  ++j)
        {
            Model &        model = project.model (j);
            Model_kit &    kit   = model.kit ();
            MD_simulator & MD    = kit.MD ();

            MD.method ().step ();
            MD.thermostat ().execute();
            MD.barostat   ().execute();
            MD.publish_step_finished ();
            MD.subject().update ();
        }

        // Integer modulo by zero is undefined behaviour, so a zero period
        // is treated as "never exchange".
        if (exchange_per_step_ != 0 && i % exchange_per_step_ == 0)
            exchange_replica ();
    }
}


void Replica_exchange::
start ()
{
    stop_flag_ = false;
    accept_count_ = reject_count_ = 0;

    Project & project = Project::singleton();
    int count = project.model_count();

    if (count < 2)
    {
        to_user().fatal_error ("Should be at least two models.");
        stop ();
        //return;
    }

    for (int j=0;  j<count;  ++j)
    {
        Model &        model = project.model (j);
        Model_kit &    kit   = model.kit ();
        MD_simulator & MD    = kit.MD ();

        //MD_subject *subject = kit.create_md_interface ();
        //subject_.adopt ( subject);
        //MD.set_subject (*subject);

        MD.tune ();

        MD.subject().start ();
        MD.method ().start ();
    }
}

void Replica_exchange::
finish ()
{
    Project & project = Project::singleton();

    for (int j=0;  j<project.model_count();  ++j)
    {
        Model &        model = project.model (j);
        Model_kit &    kit   = model.kit ();
        MD_simulator & MD    = kit.MD ();

        MD.method ().finish ();
        MD.subject().finish ();
    }

    stop_flag_ = true;
}


// Performs attempts_ exchange attempts between randomly chosen pairs of
// distinct models.
//
// Each attempt reads both models' desired temperatures and potential
// energies. If their potentials (rm parameter) differ, it tentatively swaps
// the potentials and re-evaluates the energies. It then applies the
// acceptance test on delta.
//  - On rejection the potential swap is undone and reject_count_ grows.
//  - On acceptance velocities are rescaled by sqrt(T_new / T_old), the
//    desired temperatures are swapped, and accept_count_ grows.
// Models with different atom counts abort the run via a fatal error; the
// models are first restored to their state before the attempt.
// Does nothing when fewer than two models exist.
void Replica_exchange::
exchange_replica ()
{
    Project & project = Project::singleton();
    int const model_count = project.model_count();

    // Two distinct models are required; with fewer the pair-selection loop
    // below could never terminate.
    if (model_count < 2)
        return;

    // Potential energy U() per model index, filled lazily and reused by
    // later attempts within this call.
    std::map <int, double> U;
    typedef std::map <int, double>::iterator Iterator;

    for (int a=0;  a<attempts_;  ++a)
    {
        // choose a pair of distinct models
        // NOTE(review): assumes randInt(n) draws from [0, n] inclusive.
        int model_i_N = generator_.randInt (model_count - 1);

        int model_j_N;
        do  model_j_N = generator_.randInt (model_count - 1);
        while (model_i_N == model_j_N);

        Model & model_i = project.model (model_i_N);
        Model & model_j = project.model (model_j_N);

        Common_interactions & i_interactions
            = model_i.kit().interaction().common();

        Common_interactions & j_interactions
            = model_j.kit().interaction().common();

        // inverse temperatures beta = 1 / (k T) of both replicas
        double k = System_of_Units::singleton().gas_constant();
        MD_simulator & MD_i = model_i.kit().MD ();
        MD_simulator & MD_j = model_j.kit().MD ();
        Temperature_unit Tu_m = MD_i.desired_temperature();
        Temperature_unit Tu_n = MD_j.desired_temperature();
        double T_m = System_of_Units::singleton().temperature (Tu_m);
        double T_n = System_of_Units::singleton().temperature (Tu_n);
        double beta_m = 1. / (k * T_m);
        double beta_n = 1. / (k * T_n);

        // current energies, taken from the cache when available
        double E_i, E_j;

        Iterator it = U.find (model_i_N);
        if (it == U.end())
            U[model_i_N] = E_i = MD_i.subject().U();
        else
            E_i = it->second;

        it = U.find (model_j_N);
        if (it == U.end())
            U[model_j_N] = E_j = MD_j.subject().U();
        else
            E_j = it->second;

        // Energies each configuration would have after the exchange; they
        // differ from E_i / E_j only when the potentials are swapped too.
        double E_il = E_i;
        double E_jl = E_j;

        bool const swap_potential
            = potential_is_different (i_interactions, j_interactions);

        double delta;
        if (swap_potential)
        {
            // tentatively swap potentials and evaluate each configuration
            // under its partner's potential
            exchange_potential (i_interactions, j_interactions);

            E_il = MD_i.subject().U();
            E_jl = MD_j.subject().U();

            delta = beta_n * (E_il - E_j) - beta_m * (E_i - E_jl);
        }
        else
            delta = (beta_n - beta_m) * (E_i - E_j);

        // accept with probability min(1, exp(-delta)), assuming rand() is
        // uniform on [0, 1]
        if (std::exp (-delta) < generator_.rand())
        {
            // reject: undo the tentative potential swap
            if (swap_potential)
                exchange_potential (i_interactions, j_interactions);

            ++reject_count_;
            continue;
        }

        // accept
        int count_i = model_i.atom_count();
        int count_j = model_j.atom_count();

        if (count_i != count_j)
        {
            // Restore the potentials so the models are left as they were
            // before this attempt (temperatures were not touched yet).
            if (swap_potential)
                exchange_potential (i_interactions, j_interactions);

            to_user().fatal_error ("Different size models.");
            stop ();
            return;
        }

        // rescale velocities to the temperature each model takes over
        double scale_i = std::sqrt (T_n / T_m);
        double scale_j = std::sqrt (T_m / T_n);

        for (int i=0;  i<count_i;  ++i)
        {
            model_i.atom(i).kit().scale_velocity (scale_i);
            model_j.atom(i).kit().scale_velocity (scale_j);
        }

        MD_i.set_desired_temperature (Tu_n);
        MD_j.set_desired_temperature (Tu_m);

        // Each configuration now lives under the swapped potential, so the
        // cached energies must follow (a no-op when potentials were equal).
        U[model_i_N] = E_il;
        U[model_j_N] = E_jl;

        ++accept_count_;
    }
}

// Reports whether two models use different potentials.
// Only the rm parameter of their common interactions is compared.
bool Replica_exchange::
potential_is_different (Common_interactions const & i, 
                        Common_interactions const & j) const
{
    bool const same_rm = (i.rm () == j.rm ());
    return !same_rm;
}

// Swaps the rm parameter between two models' common interactions.
void Replica_exchange::
exchange_potential (Common_interactions & i, Common_interactions & j)
{
    double const rm_of_i = i.rm ();
    double const rm_of_j = j.rm ();

    i.set_rm (rm_of_j);
    j.set_rm (rm_of_i);
}

}//MM
